In [1]:
from unsloth import FastModel,FastLanguageModel
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_
import time
from collections import defaultdict
from utility import TTTDataset_iter, seed_worker
from torch.utils.data import DataLoader
from utility import load_grouped_data
from peft import PeftModel
import os
# Disable tokenizers' internal parallelism (presumably to avoid fork-related
# warnings once DataLoader workers are spawned -- confirm if this matters here).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Checkpoint to fine-tune. Alternatives that were tried / considered:
#   "unsloth/Qwen3-8B-unsloth-bnb-4bit"
#   "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit"
#   "unsloth/gemma-3-12b-pt"
#   "unsloth/gemma-3-4b-pt"
base_checkpoint = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"

# max_seq_length is intentionally left at the library default
# (e.g. max_seq_length = 8192 for long context).
model, tokenizer = FastModel.from_pretrained(
    model_name=base_checkpoint,
    load_in_4bit=True,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-09 08:35:07 [__init__.py:235] Automatically detected platform cuda.
==((====))==  Unsloth 2025.7.11: Fast Qwen3 patching. Transformers: 4.54.1. vLLM: 0.10.0.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.635 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [2]:
# Load the saved head tensor and transpose it so it can be used as the
# right-hand operand of `hidden_states @ lm_head_weight` in the training loops.
saved_head = torch.load('./Model/Gwen4B_lm_head.pth')
# saved_head = torch.load('./Model/Gwen8B_lm_head.pth')
lm_head_weight = nn.Parameter(saved_head.T)
# nn.Parameter already defaults to requires_grad=True; kept explicit on purpose.
lm_head_weight.requires_grad_(True);

In [3]:
train_data, holdout_data = load_grouped_data()


def _take_only_item(batch):
    """Return the single element of a size-1 batch, without adding a batch axis."""
    return batch[0]


# NOTE(review): with num_workers=4 the dataset is replicated in every worker;
# confirm that TTTDataset_iter / seed_worker give each worker distinct samples.
ttt_dataset = TTTDataset_iter(train_data, holdout_data, tokenizer, samples_per_epoch=1000)
dataloader = DataLoader(
    ttt_dataset,
    batch_size=1,
    num_workers=4,
    worker_init_fn=seed_worker,
    collate_fn=_take_only_item,
)
# Quick smoke test: pull one sample.
# for test_idx_info, input_ids, vi_index, labels in dataloader:
#     break

#### Fine-tune lm_head

In [4]:
# Stage 1 (lm_head only) hyper-parameters.
epochs, accumulation_steps = 10, 32  # passes over the dataloader / micro-batches per optimizer step
lr, clip = 1e-5, 1e-3                # Adam learning rate / bound used by clip_grad_value_
label_smoothing = 0.1                # forwarded to CrossEntropyLoss

In [5]:
# Only the head is optimised in this stage; the backbone forward pass in the
# training loop runs under torch.no_grad().
trainable_params = [lm_head_weight]

# reduction='none' keeps one loss value per selected position so the training
# loop can index individual entries.
loss_fn = torch.nn.CrossEntropyLoss(
    reduction='none',
    label_smoothing=label_smoothing,
)
optimizer = torch.optim.Adam(trainable_params, lr=lr)
print(len(trainable_params))

1


In [6]:
# Stage 1 training loop: fit `lm_head_weight` on top of frozen backbone features.
#
# Each item from `dataloader` is one sample (batch_size=1, un-collated). The
# per-position loss is indexed: entry 0 is back-propagated (train_loss) and
# entry 1 is only logged (val_loss).
start_time = time.time()
train_loss_accum = 0
val_loss_accum = 0
prob_list = defaultdict(list)

# Start from clean gradients: a previous or interrupted run of this cell may
# have left partially accumulated gradients on the parameters.
optimizer.zero_grad()

for epoch in range(epochs):
    num_batches = 0  # micro-batches seen this epoch (denominator of the averages)
    pending = 0      # micro-batches accumulated since the last optimizer step
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')

        # Autocast wraps the forward pass and the loss only; backward runs
        # outside of it, as recommended by the PyTorch AMP documentation.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            with torch.no_grad():  # as we are training the lm_head only.
                output = model.model(input_ids)
            # (num selected positions, hidden) @ (hidden, num classes)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight
            loss = loss_fn(logits, labels)  # first token is used for training
        train_loss = loss[0]
        train_loss.backward()

        # tracking the loss
        train_loss_accum += train_loss.item()
        val_loss_accum += loss[1].item()
        num_batches = i + 1
        pending += 1
        # TODO: track the probability of the test example in nested list

        if pending == accumulation_steps:
            clip_grad_value_(trainable_params, clip)
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    # Flush the trailing partial accumulation window. Without this, the last
    # `num_batches % accumulation_steps` gradients leak into the next epoch's
    # first step and are silently discarded after the final epoch.
    if pending:
        clip_grad_value_(trainable_params, clip)
        optimizer.step()
        optimizer.zero_grad()

    denom = max(num_batches, 1)  # avoid a crash if the dataloader yielded nothing
    print(f"Epoch {epoch} train_loss: {train_loss_accum / denom}, val_loss: {val_loss_accum / denom}")
    train_loss_accum = 0
    val_loss_accum = 0
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.5570830078125, val_loss: 0.5913955078125
Epoch 1 train_loss: 0.6126015625, val_loss: 0.58017578125
Epoch 2 train_loss: 0.623546875, val_loss: 0.566310546875
Epoch 3 train_loss: 0.5957783203125, val_loss: 0.590337890625
Epoch 4 train_loss: 0.6101357421875, val_loss: 0.5660537109375
Epoch 5 train_loss: 0.5791826171875, val_loss: 0.543955078125
Epoch 6 train_loss: 0.586201171875, val_loss: 0.5533984375
Epoch 7 train_loss: 0.62674609375, val_loss: 0.568876953125
Epoch 8 train_loss: 0.5879658203125, val_loss: 0.56141015625
Epoch 9 train_loss: 0.5730390625, val_loss: 0.5651708984375
Time taken: 6.346543021996816 minutes


In [7]:
# Checkpoint the head as it stands after the lm_head-only stage.
torch.save(obj=lm_head_weight, f='Model/lm_head_weight.pth')

#### LoRA

In [8]:
# Stage 2 (LoRA adapters + lm_head) hyper-parameters.
epochs, accumulation_steps = 10, 32  # passes over the dataloader / micro-batches per optimizer step
lr, clip = 6e-7, 6e-5                # Adam learning rate / bound used by clip_grad_value_
label_smoothing = 0.1                # forwarded to CrossEntropyLoss

In [9]:
# Attach LoRA adapters to the attention and MLP projection layers.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                                  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=lora_target_modules,
    lora_alpha=16,
    lora_dropout=0,                        # Supports any, but = 0 is optimized
    bias="none",                           # Supports any, but = "none" is optimized
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
    random_state=3407,
    use_rslora=False,                      # We support rank stabilized LoRA
    loftq_config=None,                     # And LoftQ
)

# The head lives outside `model`, so it has to be added to the optimised
# parameters by hand, after the trainable parameters of the wrapped model.
lm_head_weight.requires_grad_(True);
adapter_params = [p for p in model.parameters() if p.requires_grad]
trainable_params = adapter_params + [lm_head_weight]

loss_fn = torch.nn.CrossEntropyLoss(
    reduction='none',
    label_smoothing=label_smoothing,
)
optimizer = torch.optim.Adam(trainable_params, lr=lr)
print(len(trainable_params))

Unsloth: Making `model.base_model.model.model` require gradients
505


In [10]:
# Stage 2 training loop: update the LoRA adapters and `lm_head_weight` jointly.
#
# Each item from `dataloader` is one sample (batch_size=1, un-collated). The
# per-position loss is indexed: entry 0 is back-propagated (train_loss) and
# entry 1 is only logged (val_loss).
start_time = time.time()
train_loss_accum = 0
val_loss_accum = 0
prob_list = defaultdict(list)

# Start from clean gradients. `lm_head_weight` can still carry gradients left
# over from the lm_head-only stage (its last partial accumulation window), and
# they would otherwise be folded into the first optimizer step here.
optimizer.zero_grad()

for epoch in range(epochs):
    num_batches = 0  # micro-batches seen this epoch (denominator of the averages)
    pending = 0      # micro-batches accumulated since the last optimizer step
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')

        # Autocast wraps the forward pass and the loss only; backward runs
        # outside of it, as recommended by the PyTorch AMP documentation.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = model.base_model.model.model(input_ids)
            # (num selected positions, hidden) @ (hidden, num classes)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight
            loss = loss_fn(logits, labels)  # first token is used for training
        train_loss = loss[0]
        train_loss.backward()

        # tracking the loss
        train_loss_accum += train_loss.item()
        val_loss_accum += loss[1].item()
        num_batches = i + 1
        pending += 1
        # TODO: track the probability of the test example in nested list

        if pending == accumulation_steps:
            clip_grad_value_(trainable_params, clip)
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    # Flush the trailing partial accumulation window. Without this, the last
    # `num_batches % accumulation_steps` gradients leak into the next epoch's
    # first step and are silently discarded after the final epoch.
    if pending:
        clip_grad_value_(trainable_params, clip)
        optimizer.step()
        optimizer.zero_grad()

    denom = max(num_batches, 1)  # avoid a crash if the dataloader yielded nothing
    print(f"Epoch {epoch} train_loss: {train_loss_accum / denom}, val_loss: {val_loss_accum / denom}")
    train_loss_accum = 0
    val_loss_accum = 0
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.5915966796875, val_loss: 0.56445703125
Epoch 1 train_loss: 0.5783642578125, val_loss: 0.5507529296875
Epoch 2 train_loss: 0.5625537109375, val_loss: 0.588408203125
Epoch 3 train_loss: 0.54944921875, val_loss: 0.57584765625
Epoch 4 train_loss: 0.5615791015625, val_loss: 0.5427177734375
Epoch 5 train_loss: 0.5929501953125, val_loss: 0.569513671875
Epoch 6 train_loss: 0.6152080078125, val_loss: 0.57633984375
Epoch 7 train_loss: 0.5685703125, val_loss: 0.5600537109375
Epoch 8 train_loss: 0.56568359375, val_loss: 0.5652109375
Epoch 9 train_loss: 0.572830078125, val_loss: 0.555408203125
Time taken: 20.43752895196279 minutes


In [None]:
# Continuation of the LoRA stage: runs another `epochs` passes using the
# current `model`, `lm_head_weight` and `optimizer` (nothing is re-created here).
#
# Each item from `dataloader` is one sample (batch_size=1, un-collated). The
# per-position loss is indexed: entry 0 is back-propagated (train_loss) and
# entry 1 is only logged (val_loss).
start_time = time.time()
train_loss_accum = 0
val_loss_accum = 0
prob_list = defaultdict(list)

# Start from clean gradients: if the previous run was interrupted mid-window,
# its partially accumulated gradients are still attached to the parameters.
optimizer.zero_grad()

for epoch in range(epochs):
    num_batches = 0  # micro-batches seen this epoch (denominator of the averages)
    pending = 0      # micro-batches accumulated since the last optimizer step
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')

        # Autocast wraps the forward pass and the loss only; backward runs
        # outside of it, as recommended by the PyTorch AMP documentation.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = model.base_model.model.model(input_ids)
            # (num selected positions, hidden) @ (hidden, num classes)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight
            loss = loss_fn(logits, labels)  # first token is used for training
        train_loss = loss[0]
        train_loss.backward()

        # tracking the loss
        train_loss_accum += train_loss.item()
        val_loss_accum += loss[1].item()
        num_batches = i + 1
        pending += 1
        # TODO: track the probability of the test example in nested list

        if pending == accumulation_steps:
            clip_grad_value_(trainable_params, clip)
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    # Flush the trailing partial accumulation window. Without this, the last
    # `num_batches % accumulation_steps` gradients leak into the next epoch's
    # first step and are silently discarded after the final epoch.
    if pending:
        clip_grad_value_(trainable_params, clip)
        optimizer.step()
        optimizer.zero_grad()

    denom = max(num_batches, 1)  # avoid a crash if the dataloader yielded nothing
    print(f"Epoch {epoch} train_loss: {train_loss_accum / denom}, val_loss: {val_loss_accum / denom}")
    train_loss_accum = 0
    val_loss_accum = 0
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.5560966796875, val_loss: 0.5601484375
Epoch 1 train_loss: 0.5756650390625, val_loss: 0.5670615234375
Epoch 2 train_loss: 0.562619140625, val_loss: 0.57713671875


In [11]:
# Persist the result of the LoRA stage.
# NOTE(review): despite the directory name, save_pretrained on a PEFT-wrapped
# model is documented to write the adapter weights/config, not a merged model.
model.save_pretrained("Model/merged_model4b")

# `lm_head_weight` is optimised together with the adapters in the LoRA stage but
# lives outside `model`, so save_pretrained does not include it. Store it
# alongside the adapter; a separate file keeps the stage-1 head
# ('Model/lm_head_weight.pth') intact.
torch.save(lm_head_weight, "Model/lm_head_weight_lora.pth")

In [12]:
# continue training
# model = PeftModel.from_pretrained(model, "Model/merged_model4b", is_trainable=True)