In [5]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [2]:
%%capture
# Install latest transformers for Gemma 3N
!pip install --no-deps --upgrade transformers # Only for Gemma 3N
!pip install --no-deps --upgrade timm # Only for Gemma 3N

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


In [4]:
import torch
torch.cuda.empty_cache()

In [6]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E4B-it", # Or "unsloth/gemma-3n-E2B-it"
    dtype = None, # None for auto detection
    max_seq_length = 1024, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-08-02 22:12:11.165843: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754172731.512435      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754172731.609455      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.72G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

In [7]:
from transformers import TextStreamer
import gc
# Helper function for inference
def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True, # Must add for generation
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    ).to("cuda")
    _ = model.generate(
        **inputs,
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer, skip_prompt = True),
    )
    # Cleanup to reduce VRAM usage
    del inputs
    torch.cuda.empty_cache()
    gc.collect()

In [8]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 8,           # Larger = higher accuracy, but might overfit
    lora_alpha = 8,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


In [9]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

In [10]:
%%capture 
pip install python-docx

In [11]:
# If not already installed in this environment; uncomment if needed
# !pip install python-docx datasets transformers unsloth

import os
import re
from collections import defaultdict
from docx import Document
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template


In [12]:
# === CONFIG ===
INPUT_DIR = "/kaggle/input/raw-docs"  # folder with your .docx files
AGE_GROUP_REGEX = re.compile(r"Ages?\s*([\d\-]+)", flags=re.IGNORECASE)
SPLIT_RATIOS = {"train": 0.8, "val": 0.1, "test": 0.1}
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"  # use this for testing; swap to your Gemma 3n checkpoint
MAX_LENGTH = 256
IGNORE_INDEX = -100

# === Helpers ===
def infer_age_group_from_filename(path):
    fname = os.path.basename(path)
    m = AGE_GROUP_REGEX.search(fname)
    return m.group(1) if m else "unknown"

def is_topic_header(text):
    return bool(re.match(r"^[A-Z][a-zA-Z& ]+$", text)) and not text.endswith("?")

def clean_answer(text):
    text = text.replace("Sparky's Answer:", "").strip()
    return re.split(r"Wow! Fact:.*?Wow!$", text, flags=re.DOTALL)[0].strip()

def normalize_subject(subject):
    if not subject:
        return "unknown"
    s = subject.strip().lower()
    if "math" in s:
        return "math"
    if "science" in s:
        return "science"
    if "geography" in s:
        return "geography"
    if "history" in s:
        return "history"
    return s


In [13]:
def parse_docx_file(path):
    print(f"Parsing {path}")
    doc = Document(path)
    age_group = infer_age_group_from_filename(path)
    examples = []
    current_topic = None
    current_question = None
    current_answer_parts = []

    for para in doc.paragraphs:
        text = para.text.strip()
        if not text:
            continue
        if is_topic_header(text):
            current_topic = text
        elif text.endswith("?") and "Sparky's Answer" not in text:
            if current_question and current_answer_parts:
                cleaned = clean_answer(" ".join(current_answer_parts))
                examples.append({
                    "question": current_question,
                    "answer": cleaned,
                    "subject": normalize_subject(current_topic),
                    "age_group": age_group,
                    "format": "open_ended",
                })
                current_answer_parts = []
            current_question = text
        elif "Sparky's Answer:" in text or current_answer_parts:
            current_answer_parts.append(text)

    if current_question and current_answer_parts:
        cleaned = clean_answer(" ".join(current_answer_parts))
        examples.append({
            "question": current_question,
            "answer": cleaned,
            "subject": normalize_subject(current_topic),
            "age_group": age_group,
            "format": "open_ended",
        })

    print(f"  → extracted {len(examples)} examples")
    return examples

# Collect all examples by age group
by_age = defaultdict(list)
for root, _, files in os.walk(INPUT_DIR):
    for fname in files:
        if fname.lower().endswith(".docx"):
            path = os.path.join(root, fname)
            exs = parse_docx_file(path)
            for ex in exs:
                by_age[ex["age_group"]].append(ex)


Parsing /kaggle/input/raw-docs/Ages 5-7.docx
  → extracted 103 examples
Parsing /kaggle/input/raw-docs/Ages 1113.docx
  → extracted 49 examples
Parsing /kaggle/input/raw-docs/Ages 8-10.docx
  → extracted 206 examples
Parsing /kaggle/input/raw-docs/Ages 1415.docx
  → extracted 15 examples


In [14]:
raw_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, use_fast=True)
if raw_tokenizer.pad_token is None:
    raw_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
chat_tokenizer_wrapper = get_chat_template(raw_tokenizer, chat_template="gemma-3")


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

In [15]:
def format_and_tokenize_batch(batch):
    input_ids_list = []
    labels_list = []
    for q, a in zip(batch["question"], batch["answer"]):
        convo_full = [{"role":"user","content":q}, {"role":"assistant","content":a}]
        convo_prompt = [{"role":"user","content":q}]
        full_text = chat_tokenizer_wrapper.apply_chat_template(
            convo_full, tokenize=False, add_generation_prompt=False
        ).removeprefix("<bos>")
        prompt_text = chat_tokenizer_wrapper.apply_chat_template(
            convo_prompt, tokenize=False, add_generation_prompt=False
        ).removeprefix("<bos>")

        tok_full = raw_tokenizer(full_text, truncation=True, max_length=MAX_LENGTH)
        tok_prompt = raw_tokenizer(prompt_text, truncation=True, max_length=MAX_LENGTH)

        input_ids = tok_full["input_ids"]
        labels = input_ids.copy()
        prompt_len = len(tok_prompt["input_ids"])
        for i in range(prompt_len):
            if i < len(labels):
                labels[i] = IGNORE_INDEX
        input_ids_list.append(input_ids)
        labels_list.append(labels)

    return {"input_ids": input_ids_list, "labels": labels_list}


In [16]:
processed_datasets = {}  # keep for inspection / training

for age_group, examples in by_age.items():
    print(f"\n--- Age group {age_group} ({len(examples)} examples) ---")
    ds = Dataset.from_list(examples)

    # If too few per subject for stratification, you can drop stratify_by_column temporarily
    try:
        train_testval = ds.train_test_split(
            test_size=1 - SPLIT_RATIOS["train"], seed=42, stratify_by_column="subject"
        )
        val_test = train_testval["test"].train_test_split(
            test_size=SPLIT_RATIOS["test"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"]),
            seed=43,
            stratify_by_column="subject"
        )
    except Exception as e:
        print("Stratified split failed:", e, "Falling back to random split.")
        train_testval = ds.train_test_split(test_size=1 - SPLIT_RATIOS["train"], seed=42)
        val_test = train_testval["test"].train_test_split(
            test_size=SPLIT_RATIOS["test"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"]),
            seed=43
        )

    dataset_dict = DatasetDict({
        "train": train_testval["train"],
        "val": val_test["train"],
        "test": val_test["test"],
    })

    # Apply formatting+tokenization
    def preprocess_split(split_ds):
        return split_ds.map(
            lambda batch: format_and_tokenize_batch(batch),
            batched=True,
            remove_columns=[c for c in split_ds.column_names if c not in ("subject","age_group","format")],
        )

    tokenized = DatasetDict({
        split: preprocess_split(dataset_dict[split]) for split in ["train","val","test"]
    })

    tokenized.set_format(type="torch", columns=["input_ids", "labels"])
    processed_datasets[age_group] = tokenized

    # Optional: persist to disk per age group
    out_dir = os.path.join("prepared_finetune_data_notebook", f"tokenized_age_{age_group}")
    tokenized.save_to_disk(out_dir)
    print(f"Saved tokenized dataset to {out_dir}")



--- Age group 5-7 (103 examples) ---
Stratified split failed: Stratifying by column is only supported for ClassLabel column, and column subject is Value. Falling back to random split.


Map:   0%|          | 0/82 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/82 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/11 [00:00<?, ? examples/s]

Saved tokenized dataset to prepared_finetune_data_notebook/tokenized_age_5-7

--- Age group 1113 (49 examples) ---
Stratified split failed: Stratifying by column is only supported for ClassLabel column, and column subject is Value. Falling back to random split.


Map:   0%|          | 0/39 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/39 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5 [00:00<?, ? examples/s]

Saved tokenized dataset to prepared_finetune_data_notebook/tokenized_age_1113

--- Age group 8-10 (206 examples) ---
Stratified split failed: Stratifying by column is only supported for ClassLabel column, and column subject is Value. Falling back to random split.


Map:   0%|          | 0/164 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/164 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/21 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/21 [00:00<?, ? examples/s]

Saved tokenized dataset to prepared_finetune_data_notebook/tokenized_age_8-10

--- Age group 1415 (15 examples) ---
Stratified split failed: Stratifying by column is only supported for ClassLabel column, and column subject is Value. Falling back to random split.


Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/12 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2 [00:00<?, ? examples/s]

Saved tokenized dataset to prepared_finetune_data_notebook/tokenized_age_1415


In [17]:
from copy import deepcopy

def add_text_field(ds_dict, chat_tokenizer_wrapper):
    # ds_dict is a DatasetDict with splits like "train","val","test", containing at least "question" and "answer"
    def make_text(example):
        convo = [
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
        # generation prompt included so SFTTrainer knows where answer starts
        example["text"] = chat_tokenizer_wrapper.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        ).removeprefix("<bos>")
        return example

    return ds_dict.map(make_text, batched=False)


In [18]:
# 1. Rebuild raw_datasets_by_age (stratified splits)
by_age_raw = defaultdict(list)
for root, _, files in os.walk(INPUT_DIR):
    for fname in files:
        if not fname.lower().endswith(".docx"):
            continue
        path = os.path.join(root, fname)
        examples = parse_docx_file(path)  # your existing parser
        for ex in examples:
            by_age_raw[ex["age_group"]].append(ex)

raw_datasets_by_age = {}
for age_group, examples in by_age_raw.items():
    ds = Dataset.from_list(examples)
    try:
        train_testval = ds.train_test_split(
            test_size=1 - SPLIT_RATIOS["train"],
            seed=42,
            stratify_by_column="subject",
        )
        val_test = train_testval["test"].train_test_split(
            test_size=SPLIT_RATIOS["test"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"]),
            seed=43,
            stratify_by_column="subject",
        )
    except Exception:
        train_testval = ds.train_test_split(test_size=1 - SPLIT_RATIOS["train"], seed=42)
        val_test = train_testval["test"].train_test_split(
            test_size=SPLIT_RATIOS["test"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"]),
            seed=43,
        )

    raw_datasets_by_age[age_group] = DatasetDict({
        "train": train_testval["train"],
        "val": val_test["train"],
        "test": val_test["test"],
    })

Parsing /kaggle/input/raw-docs/Ages 5-7.docx
  → extracted 103 examples
Parsing /kaggle/input/raw-docs/Ages 1113.docx
  → extracted 49 examples
Parsing /kaggle/input/raw-docs/Ages 8-10.docx
  → extracted 206 examples
Parsing /kaggle/input/raw-docs/Ages 1415.docx
  → extracted 15 examples


In [19]:
# 2. Add "text" field for SFTTrainer
def add_text_field(ds_dict):
    def make_text(example):
        convo = [
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
        example["text"] = chat_tokenizer_wrapper.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        ).removeprefix("<bos>")
        return example
    return ds_dict.map(make_text, batched=False)

dataset_texts_by_age = {}
for age_group, raw_ds in raw_datasets_by_age.items():
    dataset_texts_by_age[age_group] = add_text_field(raw_ds)

Map:   0%|          | 0/82 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Map:   0%|          | 0/39 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/164 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

In [20]:
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from trl import SFTTrainer, SFTConfig
import torch
import copy

# Load base model once (frozen except adapters)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    device_map=None
).to(device)

base_model.eval()  # just to be safe; SFTTrainer will set train mode when needed

adapter_checkpoints = {}
previous_adapter_path = None

# Ensure age_order is sorted as desired
# e.g., age_order = sorted(dataset_texts_by_age.keys(), key=lambda s: [int(x) for x in s.split("-")])
for i, age in enumerate(age_order):
    print(f"\n=== Curriculum step: age group {age} ===")
    ds = dataset_texts_by_age[age]  # has train/val/test with "text"

    # Reload a fresh base model to avoid in-place contamination between adapters
    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME, torch_dtype=torch.float16, device_map="auto")

    # LoRA config
    peft_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # adjust if your model uses different module names
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Wrap base with LoRA
    model = get_peft_model(model, peft_config)

    # Warm-start from previous adapter if available
    if previous_adapter_path:
        # This assumes previous_adapter_path contains the PEFT adapter directory saved by trainer.save_model()
        model = PeftModel.from_pretrained(model, previous_adapter_path)

    model.print_trainable_parameters()

    # Create SFTTrainer for this age
    trainer = SFTTrainer(
        model=model,
        tokenizer=raw_tokenizer,
        train_dataset=ds["train"],
        eval_dataset=ds.get("val"),
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=10,
            max_steps=100,
            learning_rate=2e-4,
            logging_steps=10,
            optim="paged_adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=42 + i,
            report_to="none",
        ),
    )

    trainer.train()

    # Save this age's adapter (only adapter weights + config)
    save_dir = f"./adapter_age_{age.replace('/', '_')}"
    trainer.save_model(save_dir)  # PEFT adapter checkpoint
    adapter_checkpoints[age] = save_dir
    previous_adapter_path = save_dir  # warm-start the next age from here

    # Optional: cross-age validation on neighboring age
    if i + 1 < len(age_order):
        neighbor = age_order[i + 1]
        print(f"Evaluating adapter for age {age} on neighbor age {neighbor}'s val set:")
        neighbor_val = dataset_texts_by_age[neighbor]["val"]
        eval_trainer = SFTTrainer(
            model=model,  # current adapter loaded
            tokenizer=raw_tokenizer,
            train_dataset=None,
            eval_dataset=neighbor_val,
            args=SFTConfig(
                dataset_text_field="text",
                per_device_train_batch_size=1,
                gradient_accumulation_steps=1,
                max_steps=1,
                learning_rate=2e-4,
                report_to="none",
            ),
        )
        print(eval_trainer.evaluate())


config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.08G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.66G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 60.12 MiB is free. Process 5655 has 14.68 GiB memory in use. Of the allocated memory 14.51 GiB is allocated by PyTorch, and 38.80 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

NO LONGER NEED THE BELOW

In [None]:
from unsloth.chat_templates import standardize_data_formats
dataset = standardize_data_formats(dataset)

In [None]:
dataset[100]

In [None]:
def formatting_prompts_func(examples):
   convos = examples["messages"]
   texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
   return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True)

In [None]:
dataset[100]["text"]

Train the Model

In [None]:
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size!
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 60,
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "paged_adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
    ),
)

In [None]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

In [None]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

In [None]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

In [None]:
trainer_stats = trainer.train()

In [None]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)
messages = [{
    "role": "user",
    "content": [{
        "type" : "text",
        "text" : "How do birds know where to fly when they migrate?",
    }]
}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
    tokenize = True,
    return_dict = True,
).to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens = 64, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
)
tokenizer.batch_decode(outputs)

In [None]:
model.save_pretrained("gemma-3n")  # Local saving
tokenizer.save_pretrained("gemma-3n")
# model.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving
# tokenizer.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving