### 🛠️ Environment Setup


In [1]:
import os
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '2'  # Faster HF downloads
os.environ['PYTHONIOENCODING'] = 'utf-8'       # Text encoding consistency
os.environ['PYTHONUTF8'] = '1'                 # Enable UTF-8 mode for Python

# GPU setup
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # for single gpu

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    for i in range(gpu_count):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)} - {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

PyTorch version: 2.6.0+cu124
CUDA available: True
GPU 0: NVIDIA A100-SXM4-40GB - 42.5 GB


### 📦Installing Required Packages

In [2]:
from IPython.display import Markdown, FileLink, display, clear_output

In [3]:
%%capture

# Memory & performance optimization: Quantization, acceleration, efficient attention, GPU kernels
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 triton

# Unsloth fine-tuning ecosystem and parameter-efficient training
!pip install --no-deps unsloth unsloth_zoo peft trl cut_cross_entropy

# Data pipeline essentials
!pip install "datasets>=3.4.1" sentencepiece protobuf hf_transfer
!pip install -U "huggingface-hub>=0.34.0,<1.0"

# Computer vision model support (for multimodal capabilities)
!pip install --no-deps --upgrade timm

# Latest Transformers library from development branch
!pip install --no-deps git+https://github.com/huggingface/transformers.git

# Evaluation and logging tools
#!pip install evaluate sacrebleu jiwer wandb

## 🤖 Loading the Base Model


In [4]:
from unsloth import FastModel
import torch, gc

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E4B-it",
    max_seq_length = 2048,
    load_in_4bit = True,
    load_in_8bit = False,
    full_finetuning = False,
    #max_memory={0: "6GB", "cpu": "14GB"}
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.56.0.dev0.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.72G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

### 🔍 Understanding the Parameters:

We begin by loading a compact, memory-efficient version of the model using [Unsloth](https://github.com/unslothai/unsloth) — a lightweight wrapper for fast LLM training with LoRA. Here's what each parameter does:

- **`model_name`**: The base model to customize with LoRA fine-tuning

- **`max_seq_length`** Maximum number of tokens the model can handle per input. This defines the context window (i.e., how much text the model can "see" at once):
    - 2048 tokens ≈ 1,500 words (suitable for Q&A, short tasks)
    - 4096 tokens ≈ 3,000 words (multi-turn conversations, short docs)
    - 8192 tokens ≈ 6,000 words (large context, code, papers)
    - 32768 tokens ≈ 24,000 words (for extended-context models)

- **`load_in_4bit`**: Reduces memory usage by ~75% with minimal quality loss.
- **`load_in_8bit`**: Near precesion of original model.
- **`load_in_16bit`**: This is good with representing things with details.

In [5]:
# To Render response in Markdown
from transformers import TextStreamer
from IPython.display import Markdown, display, clear_output
import torch, gc, time

class SimpleJupyterStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.generated_text = ""
        self.last_update = time.time()

    def put(self, value):
        if value.ndim > 1:
            if value.shape[0] > 1:
                raise ValueError("TextStreamer only supports batch size 1")
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        text = self.tokenizer.decode(value, **self.decode_kwargs)
        if text:
            self.generated_text += text
            if time.time() - self.last_update > 0.1:
                clear_output(wait=True)
                display(Markdown(f"🤖 **Generating...**\n\n{self.generated_text}"))
                self.last_update = time.time()

def chat_inference(messages, model, tokenizer, max_new_tokens=2048):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cuda")

    streamer = SimpleJupyterStreamer(tokenizer, skip_prompt=True)

    _ = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=1.0,
        top_k=64,
        top_p=0.95,
        streamer=streamer,
    )

    # Final output render
    clear_output(wait=True)
    display(Markdown(f"🤖 **Response :**\n\n{streamer.generated_text.strip()}"))

    # Free memory
    del inputs
    torch.cuda.empty_cache()
    gc.collect()


In [6]:
model_instruction = (

    "Your answer to Partha, using poetic and direct Gita-like language, 2–4 sentences. Do NOT mention Arjuna; use 'Partha'. Do NOT copy verses, but paraphrase the wisdom as Krishna would speak.\n\n"
    "A short real-world story involving {a famouse personality}, that reflects Krishna's advice. Make the story relevant, grounded, and within 2–4 sentences."

)

In [7]:
import random
import numpy as np

# For reproducibility
set_all_seeds = lambda seed: seed is not None and [torch.manual_seed(seed), torch.cuda.manual_seed(seed), torch.cuda.manual_seed_all(seed), random.seed(seed), np.random.seed(seed)]

# Simple utility to wrap user content in chat format
def create_message(content_list, role="user"):
    return [{"role": role, "content": content_list}]

# Adds system instruction and delegates to chat inference
def ask_multimodal(content_list, model, tokenizer, max_new_tokens=256, role="user", model_instruction=model_instruction, seed=73127):
    set_all_seeds(seed)
    messages = [{"role": "system",
                 "content": [{"type": "text", "text": model_instruction}]
               }] + create_message(content_list, role)
    chat_inference(messages, model, tokenizer, max_new_tokens=max_new_tokens)

## ⚙️ Understanding LoRA Configuration


In [8]:
# Add LoRA adapters to the model
model = FastModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    r=64,
    lora_alpha=64,
    lora_dropout=0,
    bias="none",
    use_cache = False,
    use_gradient_checkpointing=True,  # True or "unsloth" for very long context
    use_rslora=True,
    random_state=73
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


### 🔧 LoRA Adapter Configuration

We add lightweight LoRA adapters to the model to enable efficient fine-tuning. Here's what the key parameters mean:

<hr style="border: none; height: 1px; background: #eee;" />

#### 🎯 Target Modules – Where LoRA is applied

These layers are fine-tuned while keeping the rest of the model frozen:

- **Attention Layers** (how the model "focuses"):
  - `q_proj`: Query - "**What am I looking for?**"
  - `k_proj`: Key - "**What information is available?**"
  - `v_proj`: Value - "**Here’s the information itself.**"
  - `o_proj`: Output - "**Projects attention output back to model’s hidden space**"

- **MLP Layers** (how the model "thinks"):
  - `gate_proj`: Controls flow of information
  - `up_proj`: Expands dimensionality for processing
  - `down_proj`: Compresses back to original size

> More modules = stronger fine-tuning, but also more memory & compute

<hr style="border: none; height: 1px; background: #eee;" />

#### 📐 LoRA Hyperparameters

- **`r` (Rank)** – Controls the capacity of each adapter:
  - `8`: Very lightweight (~0.3M trainable params)
  - `16`: ⭐ Optimal trade-off (~0.7M params)
  - `32`: Expanded adaptation (~1.4M params)
  - `64`: High‑rank adaptation (~2.8M params)

- **`lora_alpha`** – Scales the adapted weights; typically equal to `r`.

- **`lora_dropout`** – Dropout rate for LoRA layers. Often set to `0` for stability.

<hr style="border: none; height: 1px; background: #eee;" />

#### 🧠 Memory Optimization for training

- **`use_gradient_checkpointing="unsloth"`** – Reduces memory usage by trading off compute. Useful for training larger models on limited hardware.
- **`use_rslora=True`** – Enables *Rank-Stabilized LoRA*, improving training quality on small batch sizes.
- **`random_state=73`** – Sets the random seed for reproducibility.


In [9]:
from unsloth.chat_templates import get_chat_template

# Set up the chat template for Gemma 3
tokenizer = get_chat_template(
    tokenizer,
    chat_template="gemma-3",
)

## Loading the Dataset


In [10]:
from datasets import load_dataset

ds = load_dataset("json", data_files="/content/Krishna_dataset_v.jsonl", split="train")

print(ds.features)

Generating train split: 0 examples [00:00, ? examples/s]

{'question': Value('string'), 'personality': Value('string'), 'answer': Value('string')}


In [11]:
import pandas as pd
clear_output()
sample_df = ds.select(range(3)).to_pandas()
display(sample_df)

Unnamed: 0,question,personality,answer
0,Can you guide me on why do I feel lost sometim...,Yuval Noah Harari,"O Partha, feeling lost is the veil of Maya tha..."
1,What is the value of silence?,Rihanna,"O Partha, silence is the serene space where wi..."
2,How do I handle spiritual loneliness?.,Kiran Mazumdar-Shaw,"Partha, in the silence of solitude, know that ..."


### 🧹 Cleaning the Dataset



In [12]:
dataset = ds.remove_columns(['personality'])
print(dataset)

Dataset({
    features: ['question', 'answer'],
    num_rows: 2010
})


In [13]:
pd.set_option('display.max_colwidth', None)
pd.set_option('display.expand_frame_repr', False)

sample_df = dataset.select(range(5)).to_pandas()
display(sample_df)

Unnamed: 0,question,answer
0,Can you guide me on why do I feel lost sometimes? in tough times?,"O Partha, feeling lost is the veil of Maya that clouds the mind. In difficult times, remember your true self beyond this illusion, and stay anchored in unwavering faith. The essence of life is in embracing the inner light, which guides you through the shadows. Yuval Noah Harari, feeling overwhelmed by the complexities of modern life, took a silent meditation retreat. In the silence, he found clarity and peace, realizing that the key to overcoming confusion lies in understanding one's own mind and staying present."
1,What is the value of silence?,"O Partha, silence is the serene space where wisdom takes root, and the soul finds its true voice. In the stillness, the noise of the world fades, allowing the heart to align with the divine rhythm of truth. Rihanna, amidst the clamor of fame, often retreats to a quiet sanctuary where she can reflect and listen to her inner voice. In these moments of silence, she finds clarity and strength, guiding her creative spirit and personal journey."
2,How do I handle spiritual loneliness?.,"Partha, in the silence of solitude, know that the soul is never alone. The divine presence resides within, a companion in the heart's inner sanctum. Embrace this eternal connection, and the loneliness shall dissolve into the oneness of spirit. Kiran Mazumdar-Shaw, while leading Biocon through challenging times, found herself isolated in decision-making. She turned to meditation, discovering a profound inner peace and connection that guided her through solitude, transforming her loneliness into strength and clarity."
3,How should I react to negative news? for someone like me?,"O Partha, remain steadfast in your inner tranquility amidst the tumult of the world. Let not the waves of despair shake your resolve, but view them as passing clouds, impermanent and fleeting. Anchor yourself in wisdom and discernment, for the soul's true nature is beyond the reach of external shadows. Michelle Obama, faced with criticism throughout her time in the public eye, chose to remain gracious and focused on her mission. She often spoke about how she filtered out negativity, concentrating on the positive impact she could make, and by doing so, she inspired countless individuals to rise above the noise and work towards meaningful change."
4,How can I manage my time more effectively? in my situation?,"O Partha, time is but a river flowing eternally; be the vigilant boatman who steers with purpose. Attend to each moment with mindful action, for diligence and focus shall guide thee to your rightful destination. Reed Hastings, when founding Netflix, realized the importance of focusing on core objectives. By prioritizing innovation and customer experience over distractions, he steered Netflix towards becoming a leading entertainment provider."


Unsloth provides a helpful utility called `standardize_data_formats()` that automatically converts many popular dataset formats into a structure that's compatible with Unsloth's fine-tuning pipeline.

In [14]:
from unsloth.chat_templates import standardize_data_formats
dataset = standardize_data_formats(dataset)

In [15]:
dataset[5]

{'question': 'How can I discover my true purpose? in my situation?',
 'answer': 'O Partha, seek within the depths of your own heart, where the divine whisper resides. Embrace your dharma, the duty that aligns with your nature, and act without attachment to the fruits. In the clarity of selfless action, your true purpose shall unveil itself. Isaac Asimov found his purpose in writing when he realized that his passion for storytelling could inspire others. Despite being a professor, he embraced writing as his true calling, producing works that continue to enlighten and entertain, showing that purpose often reveals itself through what we love to do.'}

**Transform Q&A pairs into chat-style conversation text** – Uses Unsloth’s tokenizer chat template to wrap each question and answer in Gemma‑3n’s expected chat template (`<start_of_turn>user/model<end_of_turn>`) , and strips the `<bos>` token since the processor will add it during training. The model expects only one `<bos>` token per sequence.


In [16]:
from datasets import Dataset, concatenate_datasets

# Convert your custom examples to match original dataset keys
custom_examples = [
    {
        "question": "There’s been a flood in my area and I’ve lost internet. What should I do?",
        "answer": "Stay calm. Move to higher ground immediately if you’re in danger. Avoid walking or driving through floodwaters. Use a battery-powered radio to get updates. Keep your phone charged using a power bank. Store clean water and dry food. Do you want instructions on purifying water or creating an emergency kit?"
    },
    {
        "question": "Explain how the water cycle works using simple language.",
        "answer": "The water cycle is how water moves through the Earth. First, water in lakes and oceans heats up and becomes steam (evaporation). Then, the steam becomes clouds (condensation). When the clouds get heavy, water falls as rain (precipitation). That water goes back to rivers, lakes, or the ground, and the cycle starts again."
    },
    {
        "question": "What's my favorite programming language and why do I prefer it?",
        "answer": "Based on our previous conversations, your favorite is R Markdown (Rmd). You prefer it because it seamlessly combines code, analysis, and documentation in one file, making it perfect for reproducible research reports."
    }
]

# Convert to HF dataset
custom_dataset = Dataset.from_list(custom_examples)

def merged_dataset(dataset, n_len=-1):
    if n_len == -1:
        original_sample = dataset.shuffle(seed=73)
    else:
        original_sample = dataset.select(range(n_len)).shuffle(seed=73)
    return concatenate_datasets([original_sample, custom_dataset])


In [17]:
n_samples = 2010

In [18]:
def formatting_prompts_func(examples):
    questions = examples["question"]
    responses = examples["answer"]

    texts = []
    for question, response in zip(questions, responses):
        conversation = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": response},
        ]

        formatted_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        ).removeprefix('<bos>')

        texts.append(formatted_text)

    return {"text": texts}

# Merge, format, and reduce columns
dataset_merged = merged_dataset(dataset, n_samples)
dataset_formatted = dataset_merged.map(formatting_prompts_func, batched=True)
dataset = dataset_formatted.select_columns(["text"])

print("After formatting columns:", dataset.column_names)

Map:   0%|          | 0/2013 [00:00<?, ? examples/s]

After formatting columns: ['text']


### 🧱 What the Function Does

* It loops through each `Question`–`Response` pair in the dataset
* Creates a **conversation structure**:

  ```python
  [
    {"role": "user", "content": question},
    {"role": "assistant", "content": response}
  ]
  ```
* Applies the tokenizer's `apply_chat_template()` method, which wraps the conversation in special tokens used by **Gemma-3n**


In [19]:
# example
dataset[-1]["text"]

"<start_of_turn>user\nWhat's my favorite programming language and why do I prefer it?<end_of_turn>\n<start_of_turn>model\nBased on our previous conversations, your favorite is R Markdown (Rmd). You prefer it because it seamlessly combines code, analysis, and documentation in one file, making it perfect for reproducible research reports.<end_of_turn>\n"

<a name="Train"></a>
### 🏋️ Train the model

In [20]:
# To Enable evaluation training
use_eval_set = False
patience = 7

In [21]:
# Callbacks
from transformers import EarlyStoppingCallback, TrainerCallback, TrainerControl, TrainerState
import torch
from typing import Dict, Any

class TrainingLossEarlyStoppingCallback(TrainerCallback):
    def __init__(self, early_stopping_patience: int = 10, min_delta: float = 0.001, min_steps: int = 20):
        self.early_stopping_patience = early_stopping_patience
        self.min_delta = min_delta
        self.min_steps = min_steps
        self.best_loss = float('inf')
        self.patience_counter = 0
        self.best_step = 0

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs: Dict[str, float] = None, **kwargs):
        if logs is None or logs.get('loss') is None:
            return

        current_loss = logs.get('loss')

        if state.global_step < self.min_steps:
            if current_loss < self.best_loss:
                self.best_loss = current_loss
                self.best_step = state.global_step
                print(f"🎯 New best training loss: {current_loss:.6f} at step {state.global_step} (warmup phase)")
            else:
                if state.global_step > 1:
                    print(f"No improvement at step {state.global_step} (warmup phase, < min_steps ({self.min_steps}))")
            return

        if current_loss < self.best_loss - self.min_delta:
            self.best_loss = current_loss
            self.patience_counter = 0
            self.best_step = state.global_step
            print(f"🎯 New best training loss: {current_loss:.6f} at step {state.global_step}")
        else:
            self.patience_counter += 1
            if self.patience_counter <= 3:
                print(f"No improvement for {self.patience_counter}/{self.early_stopping_patience} steps")

        if self.patience_counter >= self.early_stopping_patience:
            print(f"⏹️ Early stopping at step {state.global_step}. Best loss: {self.best_loss:.6f}")
            control.should_training_stop = True

class StepFinalCallback(TrainerCallback):
    def __init__(self, use_eval_set: bool = False):
        self.use_eval_set = use_eval_set
        self.step_losses = []
        self.final_logged = False

    def on_step_end(self, args, state, control, **kwargs):
        # Force logging for final step if not already logged
        if (state.global_step == args.max_steps and
            state.global_step % args.logging_steps != 0 and
            not self.final_logged):
            control.should_log = True
            self.final_logged = True

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and state.global_step > 0:
            step_loss = logs.get('loss')
            if step_loss is not None:
                self.step_losses.append({'step': state.global_step, 'loss': step_loss})

            print(f"\n=== Step {state.global_step} Results ===")
            for key, value in logs.items():
                if key == 'train_loss':  # Skip the average train_loss
                    continue
                if isinstance(value, float):
                    print(f"{key}: {value:.6f}")
                else:
                    print(f"{key}: {value}")
            print("-" * 40)

    def on_train_end(self, args, state, control, **kwargs):
        if not self.step_losses:
            return

        trainer = kwargs.get('trainer')
        first_loss = self.step_losses[0]['loss']
        final_loss = self.step_losses[-1]['loss']
        best_loss = min(entry['loss'] for entry in self.step_losses)
        improvement = first_loss - final_loss
        improvement_pct = (improvement / first_loss) * 100

        print("\n" + "="*50)
        print("🎯 FINAL MODEL EVALUATION")
        print("="*50)
        print(f"📈 Training Summary:")
        print(f"   Initial Loss: {first_loss:.6f}")
        print(f"   Last Step Loss: {final_loss:.6f}")
        print(f"   Best Loss: {best_loss:.6f}")
        print(f"   Improvement: {improvement:.6f} ({improvement_pct:.2f}%)")
        print(f"   Total Steps: {len(self.step_losses)}")

        if len(self.step_losses) >= 5:
            print(f"\n📊 Loss Progression (Last 5 Steps):")
            for entry in self.step_losses[-5:]:
                print(f"   Step {entry['step']:3d}: {entry['loss']:.6f}")

        if trainer and self.use_eval_set and trainer.eval_dataset:
            try:
                eval_results = trainer.evaluate()
                print(f"\n🔍 Final Evaluation Results:")
                for key, value in eval_results.items():
                    if isinstance(value, float):
                        print(f"   {key}: {value:.6f}")
            except:
                pass

        print("="*50)

# Callbacks function
def setup_callbacks(use_eval_set=use_eval_set, patience=patience):
    callbacks = []
    if use_eval_set:
        from transformers import EarlyStoppingCallback
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=patience))
    else:
        callbacks.append(TrainingLossEarlyStoppingCallback(early_stopping_patience=patience))
    callbacks.append(StepFinalCallback(use_eval_set=use_eval_set))
    return callbacks


# Helpers
def get_hardware_factors(ds_size):
    #  Detects GPU memory availability and calculates scaling factors.
    gpu_stats = torch.cuda.get_device_properties(0)
    available_memory = round(
        gpu_stats.total_memory / 1024**3
        - torch.cuda.max_memory_reserved() / 1024**3, 1
    )
    size_factor = min(1.0, ds_size / (200 + ds_size * 0.8))
    mem_factor = min(1.0, available_memory / 16)

    return available_memory, size_factor, mem_factor

def efficient_bs(ds_size, mem_factor, size_factor, mini_batch):
    # Max batch allowed by memory and dataset size
    max_bs_mem  = 2 ** int(1 + 2 * mem_factor)
    max_bs_data = int(1 + 15 * mem_factor / (1 + 10 / ds_size))
    batch_size  = max(1, min(max_bs_mem, max_bs_data))

    # Cap for mini-batch mode
    if mini_batch:
        batch_size = min(batch_size, 2)

    # Target effective size
    cal_target = int(4 + 28 * size_factor / (1 + 500 / ds_size))
    target_eff = min(14, cal_target) if mini_batch else cal_target

    # Gradient accumulation steps
    min_accumulation = max(7, int(10 * math.exp(-ds_size / 500)))
    accumulation_steps = max(min_accumulation, target_eff // batch_size)

    return batch_size, accumulation_steps

def is_t4():
    gpu_name = torch.cuda.get_device_name(0)
    return 'T4' in gpu_name

In [22]:
# Patch model to avoid Unsloth AttributeError
if not hasattr(model, "_flag_for_generation"):
    model._flag_for_generation = True


In [23]:
from trl import SFTConfig, SFTTrainer
from unsloth import is_bfloat16_supported
from transformers import EarlyStoppingCallback
import math

# Dataset splitting logic
if use_eval_set:
    split_dataset = dataset.train_test_split(test_size=0.1, seed=73)
    train_dataset = split_dataset['train']
    eval_dataset = split_dataset['test']
else:
    train_dataset = dataset
    eval_dataset = None

# Auto-calculated training parameters
ds_size = len(train_dataset)
available_memory, size_factor, memory_factor = get_hardware_factors(ds_size)

# Batch configuration
mini_batch = True  # False → batch size based on VRAM
batch_size, accumulation = efficient_bs(ds_size, memory_factor, size_factor, mini_batch)
effective_batch_size = batch_size * accumulation

# Training steps
steps_per_epoch = max(1, ds_size // effective_batch_size)
epoch_scale = max(3, int(5 + 15 * math.exp(-ds_size/1000)))
max_steps = max(50, min(3000, steps_per_epoch * epoch_scale))
max_steps = int(max_steps * 1.3) if is_t4() else max_steps

# Learning rate
dataset_stability = math.sqrt(50) / math.sqrt(50 + ds_size)
base_lr = (3e-5 + 2e-4 * size_factor) * (0.3 + 0.7 / (1 + dataset_stability * 10))
adaptive_lr = max(1e-6, min(5e-4, base_lr))

# Intervals and scheduling
log_interval, eval_interval = max(1, max_steps // 20), max(1, steps_per_epoch)
warmup_ratio = max(0.1, 0.4 * math.exp(-ds_size/300))
warmup_steps = max(5, int(max_steps * warmup_ratio))

# Regularization
weight_decay = max(0.005, 0.08 * math.exp(-ds_size/400))
max_grad_norm = max(0.2, 1.0 - 0.8 * size_factor / (1 + 100 / ds_size))
max_grad_norm *= 0.75 if is_t4() else 1
scheduler_type = "linear" if ds_size < 600 else "cosine"

# Initialize the trainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
    dataset_text_field = "text",
    packing = False,  # True → Multi‑turn conversations
    callbacks = setup_callbacks(use_eval_set=use_eval_set, patience=patience),

    args = SFTConfig(
        # Training config
        per_device_train_batch_size = batch_size,
        gradient_accumulation_steps = accumulation,
        **{"max_steps": max_steps},

        # Learning rate scheduling
        learning_rate = adaptive_lr,
        warmup_steps = warmup_steps,
        optim = "adafactor", # More adaptive
        weight_decay = weight_decay,
        lr_scheduler_type = scheduler_type,

        # Performance
        dataset_num_proc = 1,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        dataloader_pin_memory = True,
        max_grad_norm = max_grad_norm,
        dataloader_drop_last = True,
        remove_unused_columns = True,

        # Checkpointing
        save_steps = log_interval,
        save_total_limit = patience + 1,
        save_strategy = "steps",
        output_dir = "outputs",

        # Evaluation settings (conditional)
        **({
            "do_eval": True,
            "eval_steps": eval_interval,
            "eval_strategy": "steps",
            "per_device_eval_batch_size": 1,  # Smaller batch size for evaluation
            "eval_accumulation_steps": 1,
            "greater_is_better": False,
            "metric_for_best_model": "eval_loss",
            "load_best_model_at_end": True,
        } if use_eval_set else {
            "eval_strategy": "no",
        }),

        # Logging
        seed = 73,
        logging_steps = log_interval,
        logging_first_step = True,
        disable_tqdm = False,
        report_to = "none",  # Set this to "wandb" if using Weights & Biases
    ),
)

# Configuration summary
constraint = "Memory" if batch_size == int(1 + 7 * memory_factor) else "Dataset"
print(f"{'='*70}")
print(f"TRAINING CONFIGURATION SUMMARY")
print(f"Dataset: {ds_size} samples | GPU: {available_memory}GB | Factors: size={size_factor:.2f}, memory={memory_factor:.2f}")
print(f"Batch: {batch_size} x {accumulation} = {effective_batch_size} (limited by {constraint})")
print(f"Training: {max_steps} steps ({epoch_scale} epochs, {steps_per_epoch} steps/epoch)")
print(f"Learning: {adaptive_lr:.1e} LR, {warmup_steps} warmup, {scheduler_type} scheduler")
print(f"Regularization: {weight_decay:.4f} weight decay, {max_grad_norm:.1f} grad norm")
print(f"Monitoring: log every {log_interval}, eval every {eval_interval}, patience {patience}")
print(f"{'='*70}")


Unsloth: Tokenizing ["text"]:   0%|          | 0/2013 [00:00<?, ? examples/s]

TRAINING CONFIGURATION SUMMARY
Dataset: 2013 samples | GPU: 27.5GB | Factors: size=1.00, memory=1.00
Batch: 2 x 7 = 14 (limited by Dataset)
Training: 1001 steps (7 epochs, 143 steps/epoch)
Learning: 1.3e-04 LR, 100 warmup, cosine scheduler
Regularization: 0.0050 weight decay, 0.2 grad norm
Monitoring: log every 50, eval every 143, patience 7


In [24]:
# Apply response-only training
from unsloth.chat_templates import train_on_responses_only

# This ensures we only train on the assistant's responses, not the user's questions
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
    num_proc         = 1,
)

Map:   0%|          | 0/2013 [00:00<?, ? examples/s]

**Why Response-Only Training?**

- We don't want the model to learn to predict user inputs
- We only want it to learn better responses
- This significantly improves training efficiency and model quality

In [25]:
tokenizer.decode(trainer.train_dataset[8]["input_ids"])

'<bos><start_of_turn>user\nHow do I balance work and personal life?<end_of_turn>\n<start_of_turn>model\nO Partha, just as a river flows steadily towards the ocean, so should you strive for harmony in the realms of duty and desire. Engage in your tasks with unwavering focus, yet nurture the bonds of the heart, for balance in action leads to true fulfillment. PV Sindhu, amidst intense training for the Olympics, always ensures she spends quality time with her family. This balance between her career and personal life fuels her spirit, allowing her to perform with joy and dedication on the court.<end_of_turn>\n'

In [26]:
def colored_print(text, color_code):
    return f"\033[1;{color_code}m\033[1m{text}\033[0m"

print(colored_print("🔦 What model sees:", "94"), tokenizer.decode(trainer.train_dataset[8]["input_ids"])[:100] + "...")
print(colored_print("💡 What model learns:", "92"), tokenizer.decode([x for x in trainer.train_dataset[8]["labels"] if x != -100])[:100] + "...")

[1;94m[1m🔦 What model sees:[0m <bos><start_of_turn>user
How do I balance work and personal life?<end_of_turn>
<start_of_turn>model
...
[1;92m[1m💡 What model learns:[0m O Partha, just as a river flows steadily towards the ocean, so should you strive for harmony in the ...


Notice: Full context provided for understanding, but gradients only flow through the answer portion

In [27]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
12.078 GB of memory reserved.


In [28]:
from unsloth import unsloth_train
trainer_stats = unsloth_train(trainer) # trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,013 | Num Epochs = 7 | Total steps = 1,001
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 7
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 7 x 1) = 14
 "-____-"     Trainable parameters = 160,759,808 of 8,010,738,000 (2.01% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,8.118
50,3.0073
100,1.9401
150,1.7422
200,1.3891
250,1.3706
300,1.2387
350,0.8098
400,0.833
450,0.6846


🎯 New best training loss: 8.118000 at step 1 (warmup phase)

=== Step 1 Results ===
loss: 8.118000
grad_norm: 26427788.000000
learning_rate: 0.000000
epoch: 0.006958
----------------------------------------
🎯 New best training loss: 3.007300 at step 50

=== Step 50 Results ===
loss: 3.007300
grad_norm: 12.885169
learning_rate: 0.000065
epoch: 0.347913
----------------------------------------
🎯 New best training loss: 1.940100 at step 100

=== Step 100 Results ===
loss: 1.940100
grad_norm: 5.309426
learning_rate: 0.000131
epoch: 0.695825
----------------------------------------
🎯 New best training loss: 1.742200 at step 150

=== Step 150 Results ===
loss: 1.742200
grad_norm: 3.538751
learning_rate: 0.000131
epoch: 1.041750
----------------------------------------
🎯 New best training loss: 1.389100 at step 200

=== Step 200 Results ===
loss: 1.389100
grad_norm: 3.803827
learning_rate: 0.000128
epoch: 1.389662
----------------------------------------
🎯 New best training loss: 1.370600 at 

In [29]:
GB_CONVERSION = 1024 ** 3
SECONDS_TO_MINUTES = 60

# Memory calculations
used_memory_gb = torch.cuda.max_memory_reserved() / GB_CONVERSION
used_memory_for_training_gb = used_memory_gb - start_gpu_memory
used_percentage = (used_memory_gb / max_memory) * 100
training_percentage = (used_memory_for_training_gb / max_memory) * 100

# Time calculations
runtime_seconds = trainer_stats.metrics['train_runtime']
runtime_minutes = runtime_seconds / SECONDS_TO_MINUTES

print("TRAINING STATISTICS")
print("=" * 50)
print(f"Training time: {runtime_seconds:.1f} seconds ({runtime_minutes:.2f} minutes)")
print(f"Peak memory usage: {used_memory_gb:.3f} GB ({used_percentage:.1f}% of max)")
print(f"Memory for training: {used_memory_for_training_gb:.3f} GB ({training_percentage:.1f}% of max)")
print("=" * 50)

TRAINING STATISTICS
Training time: 10448.1 seconds (174.13 minutes)
Peak memory usage: 12.338 GB (31.2% of max)
Memory for training: 0.260 GB (0.7% of max)


<a name="Inference"></a>
### Inference After Training
According to the `Gemma-3n` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`

In [30]:
# After Training
ask_multimodal([
    {"type": "text", "text": "How do I accept things I cannot change?"}
], model, tokenizer, max_new_tokens=300, model_instruction="")

🤖 **Response :**

O Partha, embrace the flow of existence with serenity, for not all is within your realm to alter. Find peace in understanding the divine order and focus your energies on what truly lies within your grasp. Act with wisdom and grace, knowing the ultimate outcome is beyond your control. Chanakya, faced with the unyielding strength of the Mauryan Empire, chose to focus his energies on building a new empire in the east. Recognizing that he could not change the empire's core, he skillfully maneuvered to create a buffer state, expanding his influence without direct conquest.<end_of_turn>

In [31]:
# After Training
ask_multimodal([
    {"type": "text", "text": "What can I do to stop overthinking?"}
], model, tokenizer, max_new_tokens=300, model_instruction="")

🤖 **Response :**

O Partha, the mind is restless and difficult to control, yet with practice and determination, it can be tamed. Focus on your duty without attachment to the fruits, and let your actions flow like a river to the sea, unperturbed by the obstacles in its path. Olivia Rodrigo, while preparing for a major concert, found herself overwhelmed with thoughts of potential mistakes. She decided to focus solely on her vocal warm-ups and stage presence, letting go of the fear of imperfection, and found that her performance was electric, connecting deeply with her fans.<end_of_turn>

In [32]:
# After Training
ask_multimodal([
    {"type": "text", "text": "What is the value of silence?"}
], model, tokenizer, max_new_tokens=300, model_instruction="")

🤖 **Response :**

O Partha, in the embrace of silence, one finds the true voice of the self, unburdened by the clamor of the world. It is a sanctuary where wisdom whispers and the soul finds its eternal dance. Plato, renowned for his dialogues, often engaged in long periods of silence before speaking. This practice allowed him to formulate his profound ideas, free from the immediate influence of others, thus capturing the essence of his philosophical inquiries.<end_of_turn>

### 💾 To Save LoRA Adapters

In [33]:
# to save lora adapters (~100mb)
model.save_pretrained("gemma-3n-lora-model")
tokenizer.save_pretrained("gemma-3n-lora-model")

import shutil
folder_path = "./gemma-3n-lora-model"
zip_path = f"{folder_path}.zip"
shutil.make_archive(folder_path, 'zip', folder_path)

from IPython.display import FileLink
FileLink(zip_path)

**Benefits of saving LoRA adapters:**

- **Small file size**: Only a few MB instead of several GB
- **Portable**: Can be shared easily, uploaded to Hugging Face Hub
- **Flexible**: Can be loaded on top of any compatible base model

### 🌐 Save Full Model  
Merging the base model with the trained adapter weights and saving in **float16** format for VLLM.

In [34]:
import shutil

# Remove unwanted directory to free up disk space before merging
def cleanup_dir(dir_="dir_name"):
    if os.path.exists(dir_):
        shutil.rmtree(dir_)
        print(f"{dir_} directory removed successfully")

In [36]:
# Merge to 16bit
model_dir = "gemma-3n-finetune"
cleanup_dir(model_dir)
model.save_pretrained_merged(model_dir, tokenizer, save_method="merged_8bit")
# model.save_pretrained_merged(model_dir, tokenizer, save_method="merged_16bit")

Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model-00001-of-00004.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Downloading safetensors index for unsloth/gemma-3n-e4b-it...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Unsloth: Merging weights into 16bit:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.08G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  25%|██▌       | 1/4 [00:21<01:05, 21.96s/it]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  50%|█████     | 2/4 [01:00<01:03, 31.51s/it]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  75%|███████▌  | 3/4 [01:57<00:43, 43.39s/it]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.66G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit: 100%|██████████| 4/4 [02:21<00:00, 35.37s/it]


In [None]:
from unsloth import FastLanguageModel
import torch
import os, shutil

# Your fine-tuned model name
model_name = "/content/gemma-3n-finetune"

# Load the model (even if it was trained in 4bit or 16bit)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name,
    load_in_4bit=False,        # can still merge to 8bit without loading in 4bit
    max_seq_length=2048,
    dtype=None,
    device_map="auto"
)

# Switch to eval mode
model.eval()

# Define output folder
save_folder = "gemma3n_finetuned_8bit"

# Optional: clean the folder first
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)

# 🔹 Merge and save in 8-bit
model.save_pretrained_merged(
    save_folder,
    tokenizer,
    save_method="merged_8bit"   # ✅ faster and smaller than 16-bit
)

print(f"✅ 8-bit model merged and saved to: {save_folder}")


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma patching. Transformers: 4.55.0.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


## Push model to HuggingFace

In [None]:
model.push_to_hub_merged(
        "p2kalita/gemma-3n-E4B-it-finetuned-KrishnaAI_8bit", tokenizer
    )

In [38]:
pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm

Collecting unsloth
  Downloading unsloth-2025.8.1-py3-none-any.whl.metadata (47 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/47.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.3/47.3 kB[0m [31m127.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting unsloth_zoo
  Downloading unsloth_zoo-2025.8.1-py3-none-any.whl.metadata (8.1 kB)
Collecting transformers
  Downloading transformers-4.55.0-py3-none-any.whl.metadata (39 kB)
Collecting timm
  Downloading timm-1.0.19-py3-none-any.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.8/60.8 kB[0m [31m254.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading unsloth-2025.8.1-py3-none-any.whl (299 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m299.8/299.8 kB[0m [31m339.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading unsloth_zoo-2025.8.1-py3-none-any.whl (166 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "/content/gemma-3n-finetune",
        max_seq_length = 2048,
        dtype = None,
        load_in_4bit=True
)

In [None]:
model.push_to_hub("p2kalita/gemma-3n-E4B-it-finetuned-KrishnaAI-q_4")