In [1]:
!pip install -q transformers accelerate peft datasets bitsandbytes

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/411.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m409.6/411.1 kB[0m [31m13.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m411.1/411.1 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from datasets import Dataset
from transformers import Trainer
from torch import autocast
import json
import re
from datasets import Dataset
from transformers import TrainingArguments, DataCollatorForLanguageModeling
from huggingface_hub import login
from peft import PeftModel

### Training of the second LoRA
This cell applies Parameter-Efficient Fine-Tuning (PEFT) using LoRA (Low-Rank Adaptation). It first loads a previously trained LoRA adapter (`first_adapter`) on top of the base model, then adds a new trainable adapter with the specified LoRA configuration. Finally, it activates the new adapter for training and prints the number of trainable parameters.

In [None]:
model_id = "google/gemma-3-4b-pt"
# Authenticate here if the checkpoint is gated. Read the token from the
# environment or a prompt instead of pasting it into the notebook.
#login(token="hf_...")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Only fall back to EOS when the tokenizer defines no pad token of its own:
# the language-modeling collator masks every pad-id position out of the loss,
# so aliasing pad to EOS would also hide genuine EOS tokens from training.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model with bfloat16
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

previous_lora_path = "./first_adapter"

# Load the finished adapter on top of the base model ...
model = PeftModel.from_pretrained(model, previous_lora_path, adapter_name="pretrained")
# ... and fold its weights into the base model. Merely keeping it as a second,
# inactive adapter would not work: activating the new adapter with
# set_adapter() leaves only that adapter in the forward pass, so the first
# adapter would contribute nothing during training.
model = model.merge_and_unload()

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Attach the new adapter to the merged model; get_peft_model makes it the
# active, trainable adapter.
model = get_peft_model(model, lora_config, adapter_name="trainable")
model.print_trainable_parameters()


In [None]:
# NOTE(review): `MyTrainer` and `dataset` are defined in cells that appear
# further down in this notebook. Those cells must be run before this one,
# otherwise this cell raises NameError on a fresh kernel.
training_args = TrainingArguments(
    output_dir="./lora_gemma",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    # A positive max_steps fixes the training length and takes precedence over
    # num_train_epochs, so only max_steps is given here.
    max_steps=2000,  # increase for real training
    learning_rate=2e-4,
    logging_steps=10,  # increase for training
    save_strategy="steps",  # save by steps, not by epochs
    save_steps=101,  # increase for training
    # Evaluation is disabled by default, so no evaluation-strategy keyword is
    # passed (its name differs between transformers releases).
    report_to="none",
)

# mlm=False selects causal-LM collation: labels are built from input_ids.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

### Custom Trainer
This part defines a custom Trainer class that not only runs the forward pass under `torch.autocast` with bfloat16 for mixed-precision training and better numerical stability, but also adds extra logging for detecting NaNs and monitoring training dynamics.

In [None]:
class MyTrainer(Trainer):
    """Trainer whose loss is computed under bfloat16 autocast, with NaN and logit diagnostics."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the forward pass under bfloat16 autocast and return the loss.

        Args:
            model: The model being trained.
            inputs: Batch mapping with "input_ids" and "labels"; an
                "attention_mask" entry is forwarded to the model when present.
            return_outputs: When True, return ``(loss, outputs)`` instead of
                only the loss.
            **kwargs: Extra arguments passed by the Trainer; accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs`` is True.
        """
        input_ids = inputs["input_ids"].to(model.device)
        labels = inputs["labels"].to(model.device)
        # Forward the attention mask so padded positions are not attended to
        # when a batch contains sequences of different lengths.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(model.device)

        # autocast with bfloat16
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss

        if torch.isnan(loss):
            print("🚨 NaN detected in loss!")

        # Print diagnostics whenever the global step is a multiple of logging_steps.
        if self.state.global_step % self.args.logging_steps == 0:
            loss_value = loss.item()
            print(f"[Step {self.state.global_step}] Loss: {loss_value:.4f}")
            # Detached: these statistics are for display only.
            logits = outputs.logits.detach()

            has_nan = torch.isnan(logits).any().item()
            max_logit = logits.max().item()
            min_logit = logits.min().item()

            print(f"    Logits NaN? {has_nan}")
            print(f"    Logits range: [{min_logit:.3f}, {max_logit:.3f}]")

            if loss_value < 0.01:
                print("⚠️ Warning: loss is very small — possible overfitting?")

        return (loss, outputs) if return_outputs else loss



def split_sentences(text):
    """Split ``text`` into pieces at whitespace that follows '.', '!' or '?'.

    The punctuation mark stays attached to the preceding piece; the run of
    whitespace after it is consumed. Returns a list of strings.
    """
    boundary = re.compile(r'(?<=[.!?])\s+')
    return boundary.split(text)

def clean_text(text, top_cut=0.1, bottom_cut=0.1):
    """Trim a fraction of characters from both ends of ``text`` and re-join its sentences.

    Args:
        text: The raw string to clean.
        top_cut: Fraction of the character length dropped from the start.
        bottom_cut: Fraction of the character length dropped from the end.

    Returns:
        The trimmed text, split with ``split_sentences`` and joined back
        together with single spaces.
    """
    total_chars = len(text)
    first_kept = int(total_chars * top_cut)
    last_kept = int(total_chars * (1 - bottom_cut))
    body = text[first_kept:last_kept]
    return " ".join(split_sentences(body))

# Sliding-window settings: each sample holds `max_length` tokens and
# consecutive windows start `stride` tokens apart.
max_length = 256
stride = 128

all_samples = []

with open("books.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        if not line.strip():
            continue
        data = json.loads(line)
        raw_text = data["text"]
        cleaned = clean_text(raw_text)

        # Plain Python list of token ids for the whole document.
        tokens = tokenizer(cleaned, truncation=False)["input_ids"]

        # "+ 1" keeps the last full window: without it, a text of exactly
        # max_length tokens produced no sample at all. Texts shorter than
        # max_length still yield nothing.
        for i in range(0, len(tokens) - max_length + 1, stride):
            chunk = tokens[i : i + max_length]
            all_samples.append({
                "input_ids": chunk,
                "labels": list(chunk),  # independent copy, not an alias of input_ids
            })

dataset = Dataset.from_list(all_samples)


In [None]:

# Sanity check: run a single sample through the model before training.
sample = dataset[0]

input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(model.device)
labels = torch.tensor(sample["labels"]).unsqueeze(0).to(model.device)

# no_grad: this forward pass is purely diagnostic, so no autograd graph is needed.
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs.loss
    print("Sample loss:", outputs.loss.item())
    print("Any logits NaN?", torch.isnan(outputs.logits).any().item())
    logits = outputs.logits

print("Logits dtype:", logits.dtype)
print("Logits min:", logits.min().item())
print("Logits max:", logits.max().item())

print("Labels min:", labels.min().item())
print("Labels max:", labels.max().item())

vocab_size = tokenizer.vocab_size
print("Tokenizer vocab size:", vocab_size)
print("Any label >= vocab_size:", (labels >= vocab_size).any().item())

# No manual clip_grad_norm_ call here: before any backward pass there are no
# gradients to clip, and Trainer already clips during training according to
# TrainingArguments.max_grad_norm.
trainer.train()

model.save_pretrained("./lora_adapter")
tokenizer.save_pretrained("./lora_adapter")