In [1]:
from unsloth import FastModel
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_
import time
from collections import defaultdict
from utility import TTTDataset, seed_worker
from torch.utils.data import DataLoader
from utility import load_grouped_data
import pandas as pd
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model, tokenizer = FastModel.from_pretrained(
    # model_name = "unsloth/Qwen3-8B-unsloth-bnb-4bit",
    model_name = "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit",
    # model_name="unsloth/gemma-3-12b-pt",
    # model_name="unsloth/gemma-3-4b-pt",
    max_seq_length = 8192, # Choose any for long context!
    load_in_4bit = True,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-05 16:41:51 [__init__.py:235] Automatically detected platform cuda.
==((====))==  Unsloth 2025.7.11: Fast Qwen3 patching. Transformers: 4.54.1. vLLM: 0.10.0.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.635 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [2]:
lm_head_weight = nn.Parameter(torch.load('./Model/Gwen8B_lm_head_base.pth').T)
# lm_head_weight = nn.Parameter(torch.load('./Model/Gwen8B_lm_head.pth').T)
lm_head_weight.requires_grad_(True);

In [3]:
train_data, holdout_data = load_grouped_data()

loader = DataLoader(
    TTTDataset(train_data, holdout_data, tokenizer, samples_per_epoch=1000),
    batch_size=1,
    num_workers=4,
    worker_init_fn=seed_worker,
)

In [4]:
iter_loader = iter(loader)

In [44]:
input_ids, vi_index, labels = next(iter_loader)
labels

tensor([[0, 1, 1]])

In [30]:
input_ids.shape

torch.Size([1, 131])

In [31]:
vi_index[0]

tensor([ 59,  69, 130])

In [32]:
input_ids[0, vi_index[0]]

tensor([25, 25, 25])

In [33]:
labels

tensor([[1, 0, 0]])

#### Fine-tune lm_head

In [4]:
epochs = 3
accumulation_steps = 64
lr = 1e-5
clip = 1e-3
label_smoothing = 0.1

In [5]:
trainable_params = [lm_head_weight]
optimizer = torch.optim.Adam(trainable_params,lr = lr)
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
print(len(trainable_params))

1


In [6]:
start_time = time.time()
train_loss = 0
prob_list = defaultdict(list)
for epoch in range(epochs):
    for i, (row_id, input_ids, vi_index, labels) in enumerate(dataloader):
        row_id = int(row_id)
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')
            with torch.no_grad(): # as we are training the lm_head only.
                output = model.model(input_ids)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight # (# of Violation, 4096) @ (4096, 2) -> (# of Violation, 2)
            loss = loss_fn(logits[:2], labels) # first 2 tokens are used for training, (N, C), (N,)
            loss.backward()
            train_loss += loss.item()
            if vi_index.shape[0] == 3:
                prob = torch.softmax(logits[2], dim=0)[1].item()
                prob_list[row_id].append(prob)
            if (i + 1) % accumulation_steps == 0:
                clip_grad_value_(trainable_params,clip)
                optimizer.step()
                optimizer.zero_grad()
    print(f"Epoch {epoch} loss: {train_loss / (i+1)}")
    train_loss = 0
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 loss: 0.8505517650320354
Epoch 1 loss: 0.7960620803043371
Epoch 2 loss: 0.7578240512567768
Time taken: 4.878338877360026 minutes
