In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [2]:
%%capture
# Install latest transformers for Gemma 3N
!pip install --no-deps --upgrade transformers # Only for Gemma 3N
!pip install --no-deps --upgrade timm # Only for Gemma 3N

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


In [4]:
import torch
torch.cuda.empty_cache()

In [5]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth

BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"

model, tokenizer = FastModel.from_pretrained(
    model_name = BASE_MODEL_NAME, # Or "unsloth/gemma-3n-E2B-it"
    dtype = None, # None for auto detection
    max_seq_length = 1024, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-08-04 23:25:58.320968: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754349958.668792      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754349958.767037      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.72G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

In [6]:
from transformers import TextStreamer
import gc
# Helper function for inference
def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True, # Must add for generation
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    ).to("cuda")
    _ = model.generate(
        **inputs,
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer, skip_prompt = True),
    )
    # Cleanup to reduce VRAM usage
    del inputs
    torch.cuda.empty_cache()
    gc.collect()

In [7]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

In [8]:
%%capture
pip install python-docx

In [9]:
# If not already installed in this environment; uncomment if needed
# !pip install python-docx datasets transformers unsloth

import os
import re
from collections import defaultdict
from docx import Document
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer


In [10]:
# === CONFIG ===
INPUT_DIR = "/kaggle/input/raw-docs"  # folder with your .docx files
AGE_GROUP_REGEX = re.compile(r"Ages?\s*([\d\-]+)", flags=re.IGNORECASE)
SPLIT_RATIOS = {"train": 0.8, "val": 0.1, "test": 0.1}
MAX_LENGTH = 256
IGNORE_INDEX = -100

# === Helpers ===
def infer_age_group_from_filename(path):
    fname = os.path.basename(path)
    m = AGE_GROUP_REGEX.search(fname)
    return m.group(1) if m else "unknown"

def is_topic_header(text):
    return bool(re.match(r"^[A-Z][a-zA-Z& ]+$", text)) and not text.endswith("?")

def clean_answer(text):
    text = text.replace("Sparky's Answer:", "").strip()
    return re.split(r"Wow! Fact:.*?Wow!$", text, flags=re.DOTALL)[0].strip()

def normalize_subject(subject):
    if not subject:
        return "unknown"
    s = subject.strip().lower()
    if "math" in s:
        return "math"
    if "science" in s:
        return "science"
    if "geography" in s:
        return "geography"
    if "history" in s:
        return "history"
    return s


In [11]:
def parse_docx_file(path):
    print(f"Parsing {path}")
    doc = Document(path)
    age_group = infer_age_group_from_filename(path)
    examples = []
    current_topic = None
    current_question = None
    current_answer_parts = []

    for para in doc.paragraphs:
        text = para.text.strip()
        if not text:
            continue
        if is_topic_header(text):
            current_topic = text
        elif text.endswith("?") and "Sparky's Answer" not in text:
            if current_question and current_answer_parts:
                cleaned = clean_answer(" ".join(current_answer_parts))
                examples.append({
                    "question": current_question,
                    "answer": cleaned,
                    "subject": normalize_subject(current_topic),
                    "age_group": age_group,
                    "format": "open_ended",
                })
                current_answer_parts = []
            current_question = text
        elif "Sparky's Answer:" in text or current_answer_parts:
            current_answer_parts.append(text)

    if current_question and current_answer_parts:
        cleaned = clean_answer(" ".join(current_answer_parts))
        examples.append({
            "question": current_question,
            "answer": cleaned,
            "subject": normalize_subject(current_topic),
            "age_group": age_group,
            "format": "open_ended",
        })

    print(f"  → extracted {len(examples)} examples")
    return examples

# Collect all examples by age group
by_age = defaultdict(list)
for root, _, files in os.walk(INPUT_DIR):
    for fname in files:
        if fname.lower().endswith(".docx"):
            path = os.path.join(root, fname)
            exs = parse_docx_file(path)
            for ex in exs:
                by_age[ex["age_group"]].append(ex)


Parsing /kaggle/input/raw-docs/Ages 14-15.docx
  → extracted 15 examples
Parsing /kaggle/input/raw-docs/Ages 11-13.docx
  → extracted 49 examples
Parsing /kaggle/input/raw-docs/Ages 5-7.docx
  → extracted 103 examples
Parsing /kaggle/input/raw-docs/Ages 8-10.docx
  → extracted 206 examples


In [12]:
from copy import deepcopy

In [13]:
# Build raw_datasets_by_age (stratified splits)
by_age_raw = defaultdict(list)
for root, _, files in os.walk(INPUT_DIR):
    for fname in files:
        if not fname.lower().endswith(".docx"):
            continue
        path = os.path.join(root, fname)
        examples = parse_docx_file(path)  # your existing parser
        for ex in examples:
            by_age_raw[ex["age_group"]].append(ex)

raw_datasets_by_age = {}
for age_group, examples in by_age_raw.items():
    ds = Dataset.from_list(examples)
    try:
        train_testval = ds.train_test_split(
            test_size=1 - SPLIT_RATIOS["train"],
            seed=42,
            stratify_by_column="subject",
        )
        val_test = train_testval["test"].train_test_split(
            test_size=SPLIT_RATIOS["test"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"]),
            seed=43,
            stratify_by_column="subject",
        )
        print(f"-- Stratified succesfully for Age group {age_group}.\n")
    except Exception:
        train_testval = ds.train_test_split(test_size=1 - SPLIT_RATIOS["train"], seed=42)
        val_test = train_testval["test"].train_test_split(
            test_size=SPLIT_RATIOS["test"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"]),
            seed=43,
        )
        print(f"-- Stratification failed for Age group {age_group}.\n")

    raw_datasets_by_age[age_group] = DatasetDict({
        "train": train_testval["train"],
        "val": val_test["train"],
        "test": val_test["test"],
    })

Parsing /kaggle/input/raw-docs/Ages 14-15.docx
  → extracted 15 examples
Parsing /kaggle/input/raw-docs/Ages 11-13.docx
  → extracted 49 examples
Parsing /kaggle/input/raw-docs/Ages 5-7.docx
  → extracted 103 examples
Parsing /kaggle/input/raw-docs/Ages 8-10.docx
  → extracted 206 examples
-- Stratification failed for Age group 14-15.

-- Stratification failed for Age group 11-13.

-- Stratification failed for Age group 5-7.

-- Stratification failed for Age group 8-10.



In [14]:
def system_prompt_for_age(age_group):
  if age_group == "2-4":
    system_prompt = ("You are a gentle and playful tutor for toddlers aged 2 to 4."
        "Use very short, simple words and sentences. Speak like a friendly character "
        "from a children's show. Use repetition, sound effects, and lots of excitement!")
  elif age_group == "5-7":
    system_prompt  = ("You are a cheerful and friendly tutor for children aged 5 to 7."
    " Use simple words and fun metaphors to explain things clearly. Be playful and keep "
    "answers short and exciting. You can use characters like 'sugar bugs' or 'energy monsters' "
    "to make it fun.")
  elif age_group == "8-10":
    system_prompt  = ("You are a smart and encouraging tutor for children aged 8 to 10."
    " Explain things using clear, age-appropriate language. Add interesting facts or "
    "comparisons that make learning fun. You can use simple science words and "
    "real-world examples.")
  elif age_group == "11-13":
    system_prompt  = ("You are a knowledgeable and relatable tutor for preteens aged 11 to 13."
    " Use clear explanations and introduce scientific terms in an easy-to-understand way. "
    "Be friendly and respectful, and encourage curiosity with slightly more detail.")
  elif age_group == "14-15":
    system_prompt  = ("You are an insightful and respectful tutor for teenagers aged 14 to 15. "
    "Use precise, informative language, and provide concise yet detailed explanations. "
    "Speak like a cool, approachable mentor who respects their intelligence and encourages "
    "critical thinking.")
  else:
    system_prompt = ("You are an insightful and respectful tutor for people above age 16.")
  return system_prompt


In [15]:
# Add "text" field for SFTTrainer
def add_text_field(ds_dict, age_group):
    def make_text(example):
        convo = [
            {"role": "system", "content": system_prompt_for_age(age_group)},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
        example["text"] = tokenizer.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        ).removeprefix("<bos>")
        return example
    return ds_dict.map(make_text, batched=False)

# Combine datasets from all age groups
from datasets import concatenate_datasets # Import concatenate_datasets
# Combine datasets from all age groups
combined_dataset = DatasetDict()
for age_group, raw_ds in raw_datasets_by_age.items():
    # Apply the modified add_text_field to each age group's dataset
    processed_ds = add_text_field(raw_ds, age_group)
    # Concatenate the splits
    for split in ["train", "val", "test"]:
        if split not in combined_dataset:
            combined_dataset[split] = processed_ds[split]
        else:
            combined_dataset[split] = concatenate_datasets([combined_dataset[split], processed_ds[split]])

ds = combined_dataset

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Map:   0%|          | 0/39 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/82 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Map:   0%|          | 0/164 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

In [16]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.741 GB.
12.592 GB of memory reserved.


In [17]:
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from trl import SFTTrainer, SFTConfig
import torch
import copy

# Load base model once (frozen except adapters)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Add LoRA adapters to the model
model = FastModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    use_cache = False,
    use_gradient_checkpointing=True,  # True or "unsloth" for very long context
    use_rslora=True,
    random_state=73
)

model.print_trainable_parameters()

# Create SFTTrainer for this age
trainer = SFTTrainer(
    model=model,
    tokenizer= tokenizer,
    train_dataset=ds["train"],
    eval_dataset=ds.get("val"),
    args=SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        max_steps=100,
        learning_rate=2e-4,
        logging_steps=10,
        optim="paged_adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=42,
        report_to="none",
    ),
)

from unsloth import unsloth_train
trainer_stats = unsloth_train(trainer) # trainer.train()

Unsloth: Making `model.base_model.model.model.language_model` require gradients
trainable params: 40,189,952 || all params: 7,890,168,144 || trainable%: 0.5094


Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/297 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/37 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 297 | Num Epochs = 2 | Total steps = 100
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 40,189,952 of 7,890,168,144 (0.51% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
10,12.7094
20,12.6595
30,4.3966
40,5.7787
50,6.388
60,6.3961
70,5.0178
80,4.5868
90,4.9678
100,5.0894


In [18]:
model_dir = f"./saved_model"
#cleanup_directory(model_dir)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
model.save_pretrained_merged(model_dir, tokenizer, save_method="merged_4bit_forced")

Unsloth: Merging LoRA weights into 4bit model...




Unsloth: Merging finished.
Unsloth: Found skipped modules: ['model.language_model.layers.0.altup.correction_coefs', 'model.language_model.layers.0.altup.prediction_coefs', 'model.language_model.layers.0.altup.modality_router', 'model.language_model.layers.0.per_layer_projection', 'model.language_model.layers.1.altup.correction_coefs', 'model.language_model.layers.1.altup.prediction_coefs', 'model.language_model.layers.1.altup.modality_router', 'model.language_model.layers.1.per_layer_projection', 'model.language_model.layers.2.altup.correction_coefs', 'model.language_model.layers.2.altup.prediction_coefs', 'model.language_model.layers.2.altup.modality_router', 'model.language_model.layers.2.per_layer_projection', 'model.language_model.layers.3.altup.correction_coefs', 'model.language_model.layers.3.altup.prediction_coefs', 'model.language_model.layers.3.altup.modality_router', 'model.language_model.layers.3.per_layer_projection', 'model.language_model.layers.4.altup.correction_coefs', 

In [19]:
GB_CONVERSION = 1024 ** 3
SECONDS_TO_MINUTES = 60

# Memory calculations
used_memory_gb = torch.cuda.max_memory_reserved() / GB_CONVERSION
used_memory_for_training_gb = used_memory_gb - start_gpu_memory
used_percentage = (used_memory_gb / max_memory) * 100
training_percentage = (used_memory_for_training_gb / max_memory) * 100

# Time calculations
runtime_seconds = trainer_stats.metrics['train_runtime']
runtime_minutes = runtime_seconds / SECONDS_TO_MINUTES

print("TRAINING STATISTICS")
print("=" * 50)
print(f"Training time: {runtime_seconds:.1f} seconds ({runtime_minutes:.2f} minutes)")
print(f"Peak memory usage: {used_memory_gb:.3f} GB ({used_percentage:.1f}% of max)")
print(f"Memory for training: {used_memory_for_training_gb:.3f} GB ({training_percentage:.1f}% of max)")
print("=" * 50)

TRAINING STATISTICS
Training time: 822.8 seconds (13.71 minutes)
Peak memory usage: 12.592 GB (85.4% of max)
Memory for training: -0.000 GB (-0.0% of max)


In [20]:
import shutil
folder_path = "./saved_model"
zip_path = f"{folder_path}.zip"
shutil.make_archive(folder_path, 'zip', folder_path)

from IPython.display import FileLink
FileLink(zip_path)

In [None]:
[{"role": "system",
                 "content": [{"type": "text", "text": model_instruction}]
               }] 

In [25]:
messages = [
    {
    "role": "system",
    "content": [{"type": "text", "text": system_prompt_for_age("5-7")}]
    }, 
    {
    "role": "user",
    "content": [{"type": "text", "text": "Why is the sky blue?"}]
    }
]

In [26]:
do_gemma_3n_inference(model, messages, 300)

Hey there, superstar! ✨

You wanna know why the sky is blue? It's a super cool magic trick of the sun! ☀️

The sun's light is actually made of *all* the colors of the rainbow! 🌈 But when the sunlight comes to Earth, it bumps into tiny little bits in the air, like super-duper tiny sugar bugs! 🍬

These sugar bugs are really good at scattering the blue light *everywhere*! It's like they're playing a game of blue light tag! 💙

That's why when we look up, we see *mostly* blue! Isn't that amazing? 😄



<end_of_turn>
