In [1]:
from unsloth import FastModel
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_
import time
from collections import defaultdict
from utility import TTTDataset_iter, seed_worker
from torch.utils.data import DataLoader
from utility import load_grouped_data
from peft import PeftModel
import os
# Disable HF tokenizers' internal parallelism (presumably to avoid fork warnings/deadlocks
# with the multi-worker DataLoader used later) — confirm if this is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Backbone + tokenizer. Switch MODEL_NAME to one of the alternatives below to change model;
# the lm_head file loaded afterwards must match the chosen model's hidden size.
MODEL_NAME = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"
# MODEL_NAME = "unsloth/Qwen3-8B-unsloth-bnb-4bit"
# MODEL_NAME = "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit"
# MODEL_NAME = "unsloth/gemma-3-12b-pt"
# MODEL_NAME = "unsloth/gemma-3-4b-pt"
model, tokenizer = FastModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=8192,  # Choose any for long context!
    load_in_4bit=True,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-06 17:25:49 [__init__.py:235] Automatically detected platform cuda.
==((====))==  Unsloth 2025.7.11: Fast Qwen3 patching. Transformers: 4.54.1. vLLM: 0.10.0.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.635 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [2]:
# Standalone classification head. The `.T` suggests the file stores it as
# (num_classes, hidden_size) — TODO confirm; after transposing,
# hidden_states @ lm_head_weight gives one row of class logits per selected position.
# map_location pins it to the GPU the backbone runs on regardless of where it was saved;
# weights_only=True refuses to unpickle arbitrary objects from the file.
lm_head_weight = nn.Parameter(
    torch.load('./Model/Gwen4B_lm_head.pth', map_location='cuda', weights_only=True).T
)
# lm_head_weight = nn.Parameter(torch.load('./Model/Gwen8B_lm_head.pth', map_location='cuda', weights_only=True).T)
# nn.Parameter defaults to requires_grad=True, so the head is trainable without an explicit requires_grad_ call.

In [3]:
# Grouped train / holdout data, wrapped in the project's iterable dataset.
train_data, holdout_data = load_grouped_data()

# A seeded generator fixes the DataLoader's per-iterator base seed. Assuming seed_worker
# follows the PyTorch reproducibility recipe (deriving worker seeds from torch.initial_seed()),
# this makes the sampled data reproducible across kernel restarts while still differing per epoch.
DATA_SEED = 3407
data_generator = torch.Generator()
data_generator.manual_seed(DATA_SEED)

dataloader = DataLoader(
    TTTDataset_iter(train_data, holdout_data, tokenizer, samples_per_epoch=1000),
    batch_size=1,
    # NOTE(review): if TTTDataset_iter does not shard by worker id, each worker yields its own stream — confirm.
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=data_generator,
    collate_fn=lambda x: x[0],  # batch_size=1: unwrap the single item instead of collating it
)
# Each item unpacks as (test_idx_info, input_ids, vi_index, labels), as consumed by the training loops.

#### Phase 1: fine-tune the lm_head (backbone frozen)

In [4]:
# Hyper-parameters for the lm_head-only phase.
epochs, accumulation_steps = 3, 32  # passes over the dataloader; micro-batches per optimizer step
lr, clip = 2e-5, 2e-3               # Adam learning rate; per-element gradient clip value
label_smoothing = 0.1               # CrossEntropyLoss label smoothing

In [5]:
# Loss is kept per position (reduction='none') so the loop can split train vs held-out positions.
loss_fn = nn.CrossEntropyLoss(reduction='none', label_smoothing=label_smoothing)

# Phase 1: only the standalone classification head is optimised.
trainable_params = [lm_head_weight]
optimizer = torch.optim.Adam(params=trainable_params, lr=lr)
print(len(trainable_params))

1


In [6]:
# Phase 1: train only lm_head_weight on top of the frozen backbone.
# Gradients are summed (not averaged) over `accumulation_steps` micro-batches,
# so `clip` acts on the summed gradient.
start_time = time.time()
prob_list = defaultdict(list)  # TODO: track the probability of the test example in nested list
optimizer.zero_grad()  # drop stale gradients, e.g. from a previously interrupted run of this cell
for epoch in range(epochs):
    train_loss_accum = 0.0
    val_loss_accum = 0.0
    n_batches = 0
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')
        # Autocast covers only the forward pass and loss; backward runs outside it.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            with torch.no_grad():  # as we are training the lm_head only.
                output = model.model(input_ids)
            # (num_violations, hidden_size) @ (hidden_size, num_classes) -> (num_violations, num_classes)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight
            loss = loss_fn(logits, labels)  # one loss per selected position (reduction='none')
        # Assumes positions 0-1 are training examples and position 2 is held out — TODO confirm in TTTDataset_iter.
        train_loss = loss[:2].mean()
        train_loss.backward()

        train_loss_accum += train_loss.item()
        val_loss_accum += loss[2].item()
        n_batches = i + 1
        if n_batches % accumulation_steps == 0:
            clip_grad_value_(trainable_params, clip)
            optimizer.step()
            optimizer.zero_grad()
    # Apply a trailing partial accumulation instead of letting it leak into the next epoch
    # (or being silently dropped after the last one).
    if n_batches % accumulation_steps:
        clip_grad_value_(trainable_params, clip)
        optimizer.step()
        optimizer.zero_grad()
    denom = max(n_batches, 1)  # guard against an empty dataloader
    print(f"Epoch {epoch} train_loss: {train_loss_accum / denom}, val_loss: {val_loss_accum / denom}")
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.56742578125, val_loss: 0.5485537109375
Epoch 1 train_loss: 0.5771298828125, val_loss: 0.575361328125
Epoch 2 train_loss: 0.5515166015625, val_loss: 0.5511064453125
Time taken: 1.983101236820221 minutes


In [7]:
# Persist the trained head. This is the tensor *after* the `.T` applied at load time, so it is
# not in the layout of the original head file — do not transpose it again when reloading.
LM_HEAD_OUT = 'Model/lm_head_weight.pth'
torch.save(lm_head_weight, LM_HEAD_OUT)

#### Phase 2: LoRA fine-tuning of the backbone (lm_head not optimised)

In [8]:
# Hyper-parameters for the LoRA phase (smaller step and clip than the head-only phase).
epochs, accumulation_steps = 3, 32  # passes over the dataloader; micro-batches per optimizer step
lr, clip = 7e-6, 7e-4               # Adam learning rate; per-element gradient clip value
label_smoothing = 0.1               # CrossEntropyLoss label smoothing

In [9]:
# Wrap the backbone with LoRA adapters on its language-layer attention and MLP modules.
model = FastModel.get_peft_model(
    model,
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # adapt attention projections
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 32,           # Larger = higher accuracy, but might overfit
    lora_alpha = 32,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_gradient_checkpointing = "unsloth",
)

# The standalone head is not given to this optimizer, so freeze it explicitly. Otherwise every
# backward pass would still compute lm_head_weight.grad and keep accumulating it, since nothing
# zeroes it. Also drop any gradient left over from the head-only phase.
lm_head_weight.requires_grad_(False)
lm_head_weight.grad = None

trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=lr)
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing, reduction='none')
print(len(trainable_params))

Unsloth: Making `model.base_model.model.model` require gradients
504


In [10]:
# Phase 2: train the LoRA adapters. The standalone head is applied but is not optimised here.
# Gradients are summed (not averaged) over `accumulation_steps` micro-batches,
# so `clip` acts on the summed gradient.
start_time = time.time()
prob_list = defaultdict(list)  # TODO: track the probability of the test example in nested list
optimizer.zero_grad()  # drop stale gradients, e.g. from a previously interrupted run of this cell
for epoch in range(epochs):
    train_loss_accum = 0.0
    val_loss_accum = 0.0
    n_batches = 0
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')
        # Autocast covers only the forward pass and loss; backward runs outside it.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = model.base_model.model.model(input_ids)
            # (num_violations, hidden_size) @ (hidden_size, num_classes) -> (num_violations, num_classes)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight
            loss = loss_fn(logits, labels)  # one loss per selected position (reduction='none')
        # Assumes positions 0-1 are training examples and position 2 is held out — TODO confirm in TTTDataset_iter.
        train_loss = loss[:2].mean()
        train_loss.backward()

        train_loss_accum += train_loss.item()
        val_loss_accum += loss[2].item()
        n_batches = i + 1
        if n_batches % accumulation_steps == 0:
            clip_grad_value_(trainable_params, clip)
            optimizer.step()
            optimizer.zero_grad()
    # Apply a trailing partial accumulation instead of letting it leak into the next epoch
    # (or being silently dropped after the last one).
    if n_batches % accumulation_steps:
        clip_grad_value_(trainable_params, clip)
        optimizer.step()
        optimizer.zero_grad()
    denom = max(n_batches, 1)  # guard against an empty dataloader
    print(f"Epoch {epoch} train_loss: {train_loss_accum / denom}, val_loss: {val_loss_accum / denom}")
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.46651953125, val_loss: 0.5565478515625
Epoch 1 train_loss: 0.3863818359375, val_loss: 0.5776376953125


KeyboardInterrupt: 

In [None]:
# Save the LoRA adapter. For a PEFT model, save_pretrained writes the adapter weights and config
# only, not merged base weights, despite the directory name. The standalone head is not part of
# `model`; it is stored separately in Model/lm_head_weight.pth.
ADAPTER_DIR = "Model/merged_model4b"
model.save_pretrained(ADAPTER_DIR)

In [None]:
# continue training: re-attach the saved LoRA adapter to a *plain* base model.
# After the LoRA setup cell, `model` already carries adapters. PEFT does not support applying
# PEFT to such a model a second time, so refuse instead of silently nesting adapters.
# In a fresh kernel:
#   - run the base-model load first;
#   - reload the fine-tuned head from 'Model/lm_head_weight.pth' (already transposed, so no `.T`)
#     instead of the original head file;
#   - rebuild `trainable_params` / `optimizer` from this model before training.
if isinstance(model, PeftModel):
    raise RuntimeError(
        "`model` already has LoRA adapters attached; restart the kernel and load the base "
        "model before calling PeftModel.from_pretrained."
    )
model = PeftModel.from_pretrained(model, "Model/merged_model4b", is_trainable=True)