Full attention on X and causal on Y

In [1]:
from unsloth import FastModel
import torch
from torch.nn.utils import clip_grad_value_
import torch.nn as nn
# from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
from functions import *
model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit",
    # model_name="unsloth/gemma-3-12b-pt",
    # model_name="unsloth/gemma-3-4b-pt",
    max_seq_length = 8192, # Choose any for long context!
    load_in_4bit = True,
    resize_model_vocab=16,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 04-06 07:49:28 [__init__.py:239] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.dev0. vLLM: 0.8.2.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.642 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


KeyboardInterrupt: 

In [None]:
model = model.base_model
model.train();
model.lm_head.weight.requires_grad_(True);
model.model.embed_tokens.padding_idx = None # otherwise token zero will be ignored

In [None]:
# resize model vocab
# model.model.embed_tokens = Gemma3TextScaledWordEmbedding(16,2560,0,50.59644256269407)
# model.lm_head = nn.Linear(2560, 16, bias=False)
# model.lm_head.weight = model.model.embed_tokens.weight

In [None]:
import json
output_path = '/home/zhenlan/Desktop/Projects/ARC2/Data/ARC-AGI-2-main/combined_data.json'
with open(output_path, 'r') as f:
    data = json.load(f)

#### Fine-tune embedding

In [None]:
epochs = 10
accumulation_steps = 32
lr = 3e-5
clip = 3e-3
MAX_LEN = 3600

In [None]:
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params,lr = lr)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# 85 minutes
train_loss = 0
val_loss = 0
for epoch in range(epochs):
    model.train()
    for i,(x,y,lengths) in enumerate(data_gen(data,True,MAX_LEN,return_lengths=True)):
        mask = create_arc_causal_attention_mask(*lengths)
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            yhat = model(x,attention_mask=mask).logits
            loss = loss_fn(yhat.view(-1,yhat.shape[-1]),y.view(-1))
        loss.backward()
        train_loss += loss.item()

        if (i + 1) % accumulation_steps == 0:
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()
    model.eval()
    for j,(x,y,lengths) in enumerate(data_gen(data,False,MAX_LEN*4,return_lengths=True)):
        mask = create_arc_causal_attention_mask(*lengths)
        with torch.no_grad():
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                yhat = model(x,attention_mask=mask).logits
                loss = loss_fn(yhat.view(-1,yhat.shape[-1]),y.view(-1))
        val_loss += loss.item()
    
    print(f"Epoch {epoch+1} - Train Loss: {train_loss/i:.4f} - Val Loss: {val_loss/j:.4f}")
    train_loss = 0
    val_loss = 0

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!
Epoch 1 - Train Loss: 6.4463 - Val Loss: 7.3893


KeyboardInterrupt: 

In [None]:
torch.save(model.lm_head.state_dict(), '../Model/lm_heads_weights_pt.pth')

#### Fine-Tune QLORA

In [None]:
epochs = 60
accumulation_steps = 32
lr = 2e-5
clip = 2e-3
MAX_LEN = 3600

In [None]:
model.max_seq_length = MAX_LEN

In [None]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # SHould leave on always!

    r = 64,           # Larger = higher accuracy, but might overfit
    lora_alpha = 64,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

Unsloth: Making `model.base_model.model.model` require gradients


In [None]:
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params,lr = lr)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
train_loss = 0
val_loss = 0
for epoch in range(epochs):
    model.train()
    for i,(x,y,lengths) in enumerate(data_gen(data,True,MAX_LEN,return_lengths=True)):
        mask = create_arc_causal_attention_mask(*lengths)
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            yhat = model(x,attention_mask=mask).logits
            loss = loss_fn(yhat.view(-1,yhat.shape[-1]),y.view(-1))
        loss.backward()
        train_loss += loss.item()

        if (i + 1) % accumulation_steps == 0:
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()
    model.eval()
    for j,(x,y,lengths) in enumerate(data_gen(data,False,MAX_LEN*4,return_lengths=True)):
        mask = create_arc_causal_attention_mask(*lengths)
        with torch.no_grad():
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                yhat = model(x,attention_mask=mask).logits
                loss = loss_fn(yhat.view(-1,yhat.shape[-1]),y.view(-1))
        val_loss += loss.item()
    
    print(f"Epoch {epoch+1} - Train Loss: {train_loss/i:.4f} - Val Loss: {val_loss/j:.4f}")
    train_loss = 0
    val_loss = 0

Epoch 1 - Train Loss: 1.4876 - Val Loss: 1.1082
Epoch 2 - Train Loss: 1.1100 - Val Loss: 0.9136
Epoch 3 - Train Loss: 0.9532 - Val Loss: 0.7667
Epoch 4 - Train Loss: 0.8027 - Val Loss: 0.6260
Epoch 5 - Train Loss: 0.6754 - Val Loss: 0.5132


KeyboardInterrupt: 

In [None]:
model.save_pretrained("../Model/merged_model_pt")
torch.save(model.lm_head.state_dict(), '../Model/lm_heads_weights_pt.pth')

