In [None]:
from unsloth import FastModel,FastLanguageModel
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_
import time
from collections import defaultdict
from utility import TTTDataset_iter, seed_worker
from torch.utils.data import DataLoader
from utility import load_grouped_data
from peft import PeftModel
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model, tokenizer = FastModel.from_pretrained(
    # model_name = "unsloth/Qwen3-8B-unsloth-bnb-4bit",
    # model_name = "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit",
    model_name = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit",
    # model_name="unsloth/gemma-3-12b-pt",
    # model_name="unsloth/gemma-3-4b-pt",
    # max_seq_length = 8192, # Choose any for long context!
    resize_model_vocab = 81005, 
    load_in_4bit = True,
)
model.model.embed_tokens.load_state_dict({'weight':torch.load('Model/reduced_embedding.pt')})

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-11 06:53:27 [__init__.py:235] Automatically detected platform cuda.
==((====))==  Unsloth 2025.7.11: Fast Qwen3 patching. Transformers: 4.54.1. vLLM: 0.10.0.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.635 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


RuntimeError: Error(s) in loading state_dict for Embedding:
	size mismatch for weight: copying a param with shape torch.Size([81005, 2560]) from checkpoint, the shape in current model is torch.Size([80963, 2560]).

In [2]:
lm_head_weight = nn.Parameter(torch.load('./Model/Gwen4B_lm_head.pth').T)
# lm_head_weight = nn.Parameter(torch.load('./Model/Gwen8B_lm_head.pth').T)
lm_head_weight.requires_grad_(True);

In [3]:
train_data, holdout_data = load_grouped_data()

dataloader = DataLoader(
    TTTDataset_iter(train_data, holdout_data, tokenizer, samples_per_epoch=2000),
    batch_size=1,
    # num_workers=4,
    worker_init_fn=seed_worker,
    collate_fn=lambda x: x[0]
)
# for test_idx_info, input_ids, vi_index, labels in dataloader:
#     break

In [22]:
count = 0
for test_idx_info, input_ids, vi_index, labels in dataloader:
    input_ids2 = vocab_mapping[input_ids]
    count += input_ids2.max() == 80962

In [23]:
count

tensor(387)

In [8]:
vocab_mapping = torch.load('Model/vocab_mapping.pt')

In [9]:
input_ids2 = vocab_mapping[input_ids]

In [15]:
input_ids2 == input_ids2.max()

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

In [11]:
input_ids2.shape, input_ids.shape

(torch.Size([1, 430]), torch.Size([1, 430]))

#### Fine-tune lm_head

In [9]:
epochs = 10
accumulation_steps = 64
lr = 2e-4
clip = 2e-2
label_smoothing = 0.1

In [10]:
trainable_params = [lm_head_weight]
optimizer = torch.optim.Adam(trainable_params,lr = lr)
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing, reduction='none')
print(len(trainable_params))

1


In [11]:
start_time = time.time()
train_loss_accum = 0
val_loss_accum = 0
prob_list = defaultdict(list)
for epoch in range(epochs):
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')
            with torch.no_grad(): # as we are training the lm_head only.
                output = model.model(input_ids)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight # (# of Violation, 4096) @ (4096, 2) -> (# of Violation, 2)
            loss = loss_fn(logits, labels) # first token is used for training
            train_loss = loss[0] / accumulation_steps
            train_loss.backward()

            # tracking the loss
            train_loss_accum += train_loss.item()
            val_loss_accum += loss[1].item()
            # TODO: track the probability of the test example in nested list
            if (i + 1) % accumulation_steps == 0:
                clip_grad_value_(trainable_params,clip)
                optimizer.step()
                optimizer.zero_grad()
    print(f"Epoch {epoch} train_loss: {train_loss_accum * accumulation_steps / (i+1)}, val_loss: {val_loss_accum / (i+1)}")
    train_loss_accum = 0
    val_loss_accum = 0
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.5748720703125, val_loss: 0.51524365234375
Epoch 1 train_loss: 0.54870361328125, val_loss: 0.50256640625
Epoch 2 train_loss: 0.5354169921875, val_loss: 0.5253955078125
Epoch 3 train_loss: 0.52655126953125, val_loss: 0.5092529296875
Epoch 4 train_loss: 0.51628564453125, val_loss: 0.5098896484375
Epoch 5 train_loss: 0.49919140625, val_loss: 0.4918056640625
Epoch 6 train_loss: 0.51589794921875, val_loss: 0.5126552734375
Epoch 7 train_loss: 0.5431318359375, val_loss: 0.539453125
Epoch 8 train_loss: 0.51102587890625, val_loss: 0.50875341796875
Epoch 9 train_loss: 0.498658203125, val_loss: 0.50115283203125
Time taken: 12.896301766236624 minutes


In [12]:
torch.save(lm_head_weight, 'Model/lm_head_weight.pth')

#### LORA

In [13]:
epochs = 10
accumulation_steps = 64
lr = 2e-5
clip = 2e-3
label_smoothing = 0.1

In [14]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
lm_head_weight.requires_grad_(True);
trainable_params = [param for param in model.parameters() if param.requires_grad]
trainable_params.append(lm_head_weight)
optimizer = torch.optim.Adam(trainable_params,lr = lr) 
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing, reduction='none')
print(len(trainable_params))

Unsloth: Making `model.base_model.model.model` require gradients
505


In [15]:
start_time = time.time()
train_loss_accum = 0
val_loss_accum = 0
prob_list = defaultdict(list)
for epoch in range(epochs):
    for i, (test_idx_info, input_ids, vi_index, labels) in enumerate(dataloader):
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids, vi_index, labels = input_ids.to('cuda'), vi_index.to('cuda'), labels.to('cuda')
            output = model.base_model.model.model(input_ids)
            logits = output.last_hidden_state[0, vi_index] @ lm_head_weight # (# of Violation, 4096) @ (4096, 2) -> (# of Violation, 2)
            loss = loss_fn(logits, labels) # first token is used for training
            train_loss = loss[0] / accumulation_steps
            train_loss.backward()

            # tracking the loss
            train_loss_accum += train_loss.item()
            val_loss_accum += loss[1].item()
            # TODO: track the probability of the test example in nested list
            if (i + 1) % accumulation_steps == 0:
                clip_grad_value_(trainable_params,clip)
                optimizer.step()
                optimizer.zero_grad()
    print(f"Epoch {epoch} train_loss: {train_loss_accum * accumulation_steps / (i+1)}, val_loss: {val_loss_accum / (i+1)}")
    train_loss_accum = 0
    val_loss_accum = 0
print(f"Time taken: {(time.time() - start_time)/60} minutes")

Epoch 0 train_loss: 0.503875, val_loss: 0.49848681640625
Epoch 1 train_loss: 0.474677734375, val_loss: 0.478013671875
Epoch 2 train_loss: 0.41661767578125, val_loss: 0.45158642578125
Epoch 3 train_loss: 0.43537109375, val_loss: 0.4673544921875
Epoch 4 train_loss: 0.39513330078125, val_loss: 0.4109423828125
Epoch 5 train_loss: 0.3757841796875, val_loss: 0.4032353515625
Epoch 6 train_loss: 0.36643994140625, val_loss: 0.41466943359375
Epoch 7 train_loss: 0.3674013671875, val_loss: 0.40312939453125
Epoch 8 train_loss: 0.36049951171875, val_loss: 0.4411083984375


OutOfMemoryError: CUDA out of memory. Tried to allocate 30.00 MiB. GPU 0 has a total capacity of 23.63 GiB of which 12.88 MiB is free. Including non-PyTorch memory, this process has 23.34 GiB memory in use. Of the allocated memory 22.57 GiB is allocated by PyTorch, and 298.68 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [16]:
model.save_pretrained("Model/merged_model4b")

In [12]:
# continue training
# model = PeftModel.from_pretrained(model, "Model/merged_model4b", is_trainable=True)