GPU check

In [2]:
# Confirm a CUDA GPU is visible (name, driver/CUDA version, memory usage) before loading the model.
!nvidia-smi

Wed Aug 27 18:43:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 561.19                 Driver Version: 561.19         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   42C    P3             12W /   30W |       0MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!pip install -q diffusers transformers accelerate peft

In [None]:
import os
# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. That helps pinpoint which op
# raised a CUDA error, but it slows down training, so it is opt-in. It must be set before torch
# initialises CUDA, which is why it stays above `import torch`.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.nn import functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
    logging,
)

# Module-level logger from transformers' logging utilities (not referenced by the cells shown here).
logger = logging.get_logger(name=__name__)

In [None]:
# --- Fine-tuning configuration ---------------------------------------------------------------
q_lora = False  # True: load the base model in 4-bit NF4 (bitsandbytes) and train LoRA on top

# LoRA adapter hyperparameters
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05

# Length term used by CustomTrainer: brevity_lambda * max(answer_tokens - brevity_B, 0)
brevity_B = 10
brevity_lambda = 0.1

# Paths
model_path = "gemma_model"          # local base-model directory (local-loading cell)
tokenizer_path = "gemma_tokenizer"  # local tokenizer directory (local-loading cell)
dataset_path = "/content/100_sample_query_finetuning.csv"
output_dir = "./gemma-lora-finetuned"

In [None]:
# If loading from huggingface.
# NOTE: the next cell reloads `tokenizer` and `model` from local paths; run only one of the two
# loading cells, otherwise the model is loaded twice and the later cell wins.
base_model_id = "google/gemma-3-270m"

# Token from the `hf_token` env var, falling back to the standard `HF_TOKEN`. If neither is set,
# login(token=None) asks for a token interactively.
login(token=os.getenv("hf_token") or os.getenv("HF_TOKEN"))
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)

# Eager attention in both branches (recommended by transformers for training Gemma 3).
if q_lora:  # the config cell defines `q_lora`; the previous `qlora` raised NameError
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto", attn_implementation="eager")
    model = prepare_model_for_kbit_training(model)
else:
    model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager")

In [None]:
# Alternative to the Hub cell above: load tokenizer and base model from local directories
# (tokenizer_path / model_path). Whichever loading cell runs last defines `tokenizer` and `model`.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Eager attention, matching the Hub-loading cell (recommended by transformers for training Gemma 3).
if q_lora:
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config, device_map="auto", attn_implementation="eager")
    model = prepare_model_for_kbit_training(model)
else:
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager")

In [None]:
ds = load_dataset("csv", data_files=dataset_path)
dataset = ds["train"]

# build_prompt() reads "query" and "answer"; fail fast with a clear message if the CSV lacks them
# (instead of a KeyError deep inside Dataset.map).
missing_columns = {"query", "answer"} - set(dataset.column_names)
if missing_columns:
    raise ValueError(f"{dataset_path} is missing required column(s): {sorted(missing_columns)}")

# Keep only the columns this notebook uses.
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in ["query", "answer", "theme"]])

In [None]:
def build_prompt(ex):
    """Build the training text for one example.

    Returns a ``(prompt, full)`` pair: ``prompt`` is ``"Query: <query>\\nAnswer:"`` and ``full``
    is that same prompt followed by a space and ``ex['answer']``.
    """
    prompt = "Query: {}\nAnswer:".format(ex["query"])
    return prompt, f"{prompt} {ex['answer']}"

def preprocess(ex):
    """Tokenize one example for causal-LM fine-tuning with the loss restricted to the answer.

    Uses the global ``tokenizer``. Tokenizes ``build_prompt(ex)[1]`` (plus EOS), truncated and
    padded to 512 tokens, and adds a ``labels`` list equal to ``input_ids`` except that prompt
    tokens and padding positions are -100 (ignored by the loss).
    """
    max_length = 512
    prompt, full = build_prompt(ex)
    # Supervise an explicit end-of-answer token so the model can learn where an answer stops.
    # NOTE(review): assumes the tokenizer does not append EOS on its own — verify if swapping tokenizers.
    if tokenizer.eos_token:
        full += tokenizer.eos_token

    prompt_ids_len = len(tokenizer(prompt, truncation=True, max_length=max_length)["input_ids"])
    full_enc = tokenizer(full, truncation=True, padding="max_length", max_length=max_length)

    input_ids = full_enc["input_ids"]
    attention_mask = full_enc["attention_mask"]
    # Locate the first real token so prompt masking is correct for left- as well as right-padding
    # (slicing labels[:prompt_ids_len] masks leading pads instead of the prompt when left-padded).
    first_token = attention_mask.index(1) if 1 in attention_mask else len(attention_mask)
    answer_start = first_token + prompt_ids_len
    full_enc["labels"] = [
        token_id if attended and pos >= answer_start else -100
        for pos, (token_id, attended) in enumerate(zip(input_ids, attention_mask))
    ]
    return full_enc

if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

model.resize_token_embeddings(len(tokenizer))

tokenized_ds = dataset.map(preprocess, remove_columns=dataset.column_names)
# Rows whose answer was truncated away have no supervised tokens; a micro-batch made only of
# such rows would give a NaN cross-entropy, so drop them.
tokenized_ds = tokenized_ds.filter(lambda row: any(label != -100 for label in row["labels"]))

target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"]
config = LoraConfig(r=lora_r, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=lora_dropout, task_type="CAUSAL_LM")
model = get_peft_model(model, config)

# preprocess() already pads to a fixed length and builds masked labels, so the collator only has
# to stack tensors. DataCollatorForLanguageModeling(mlm=False) would discard those labels and
# rebuild them from input_ids, putting the prompt back into the loss and the brevity count.
data_collator = default_data_collator

In [None]:
class CustomTrainer(Trainer):
    """`Trainer` that adds a length ("brevity") term to the loss and logs extra answer-token metrics.

    Reads the notebook globals ``brevity_B`` and ``brevity_lambda``. On every training forward
    pass (i.e. every micro-batch, not every optimizer step) it logs:

    - ``loss``: ``ce_loss + brevity_loss`` for this micro-batch
    - ``ce_loss``: mean next-token cross-entropy over answer tokens (labels != -100)
    - ``brevity_loss``: ``brevity_lambda * mean(max(answer_len - brevity_B, 0))``
    - ``cfs``: next-token accuracy (argmax vs. the shifted label) on answer tokens
    - ``mci``: per-example cosine similarity between the mean predicted token distribution and
      the answer-token histogram, averaged over the batch
    - ``cgvr``: distinct predicted tokens / number of answer tokens

    The logged ``loss``/``ce_loss`` are recomputed as per-token means, so they stay comparable
    across steps regardless of how the Trainer scales the returned loss for gradient accumulation.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        # Let the parent run the forward pass and select the CE loss. Newer transformers versions
        # forward `num_items_in_batch` to the model there, so the loss is normalised correctly
        # across gradient-accumulation micro-batches; a plain `model(**inputs)` call skips that.
        ce_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)

        total_loss = ce_loss
        brevity_loss = torch.zeros((), device=ce_loss.device)
        if self.is_in_train and labels is not None:
            # NOTE: answer_lengths comes from the ground-truth labels, not from the model outputs,
            # so this term is constant w.r.t. the parameters. It raises the reported loss but adds
            # no gradient. A penalty that actually shortens generations must depend on the logits.
            answer_lengths = (labels != -100).sum(dim=1)
            penalty = brevity_lambda * torch.clamp(answer_lengths.float() - brevity_B, min=0)
            brevity_loss = penalty.mean()
            total_loss = ce_loss + brevity_loss

        # Only log training passes; `labels` may be absent (e.g. pure inference inputs).
        if model.training and labels is not None:
            metrics = self._answer_token_metrics(outputs.logits, labels)
            if metrics is not None:
                brevity_value = brevity_loss.item()
                self.log({
                    "loss": metrics["ce_loss"] + brevity_value,
                    "ce_loss": metrics["ce_loss"],
                    "brevity_loss": brevity_value,
                    "cfs": metrics["cfs"],
                    "mci": metrics["mci"],
                    "cgvr": metrics["cgvr"],
                })

        return (total_loss, outputs) if return_outputs else total_loss

    @torch.no_grad()
    def _answer_token_metrics(self, logits, labels):
        """Compute next-token metrics on answer positions only; return None if there are none."""
        labels = labels.to(logits.device)
        # Causal LM: the logits at position t predict token t + 1, so align logits[:, :-1] with
        # labels[:, 1:] (comparing unshifted positions measures the wrong thing).
        shift_labels = labels[:, 1:]
        mask = shift_labels != -100
        if not mask.any():
            return None

        # Select answer positions before any vocab-sized op: memory stays O(n_answer * vocab)
        # instead of O(batch * seq_len * vocab). No autograd graph is kept (no_grad).
        answer_logits = logits[:, :-1, :][mask].float()  # (n_answer, vocab)
        targets = shift_labels[mask]                      # (n_answer,)
        example_of_token = mask.nonzero(as_tuple=True)[0]  # batch row of each answer token

        ce = F.cross_entropy(answer_logits, targets).item()
        preds = answer_logits.argmax(dim=-1)
        cfs = (preds == targets).float().mean().item()

        # Per-example sums of predicted distributions and of target one-hots (built by scatter
        # instead of a dense F.one_hot). Cosine similarity is scale-invariant, so sums give the
        # same value as means.
        n_examples, vocab = labels.size(0), answer_logits.size(-1)
        pred_hist = answer_logits.new_zeros(n_examples, vocab).index_add_(
            0, example_of_token, answer_logits.softmax(dim=-1)
        )
        label_hist = answer_logits.new_zeros(n_examples, vocab).index_put_(
            (example_of_token, targets),
            torch.ones_like(targets, dtype=answer_logits.dtype),
            accumulate=True,
        )
        has_answer = mask.any(dim=1)
        mci = F.cosine_similarity(pred_hist[has_answer], label_hist[has_answer], dim=-1).mean().item()

        cgvr = preds.unique().numel() / preds.numel()
        return {"ce_loss": ce, "cfs": cfs, "mci": mci, "cgvr": cgvr}


In [None]:
class CustomLogCallback(TrainerCallback):
    """Append a one-line summary of every Trainer log event to a plain-text file.

    The file is truncated and given a header when the callback is constructed. Keys missing from
    a particular log event are written as ``N/A``.
    """

    # (label, logs key) pairs written on each line, in order.
    _FIELDS = (
        ("Loss", "loss"),
        ("CE Loss", "ce_loss"),
        ("Brevity Loss", "brevity_loss"),
        ("CFS", "cfs"),
        ("MCI", "mci"),
        ("CGVR", "cgvr"),
        ("LR", "learning_rate"),
    )

    def __init__(self, log_file_path):
        self.log_file_path = log_file_path
        # os.makedirs("") raises, so only create a directory when the path actually has one.
        log_dir = os.path.dirname(log_file_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Clear the log file at the start of training.
        with open(self.log_file_path, "w", encoding="utf-8") as f:
            f.write("Custom Training Logs\n" + "=" * 22 + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        fields = ", ".join(f"{label}: {logs.get(key, 'N/A')}" for label, key in self._FIELDS)
        with open(self.log_file_path, "a", encoding="utf-8") as f:
            f.write(f"Step: {state.global_step}, {fields}\n")

In [None]:
custom_log_path = os.path.join(output_dir, "custom_metrics.log")
custom_logger = CustomLogCallback(log_file_path=custom_log_path)

# The model is loaded in bfloat16 (and QLoRA computes in bfloat16), so use bf16 mixed precision.
# fp16 autocast would run those weights in float16, whose narrower range can overflow to inf/NaN.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch: 2 * 8 = 16 examples per optimizer step
    learning_rate=2e-4,
    bf16=use_bf16,
    fp16=False,
    # NOTE(review): with ~100 examples (per the dataset filename) and an effective batch of 16,
    # there are only ~7 optimizer steps per epoch. Logging/saving every 50 steps may then never
    # trigger; confirm the dataset size and lower these if needed.
    logging_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    remove_unused_columns=False,
    report_to="tensorboard",  # view with: tensorboard --logdir gemma-lora-finetuned/runs
)
trainer = CustomTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    data_collator=data_collator,
    callbacks=[custom_logger],
)

In [None]:
# Run training; the value returned by trainer.train() is shown as this cell's output.
train_output = trainer.train()
train_output

In [None]:
# Save the PEFT-wrapped model (see get_peft_model above) and the tokenizer to the same directory.
model.save_pretrained(save_directory=output_dir)
tokenizer.save_pretrained(save_directory=output_dir)