## 1) Install & restart

In [None]:

# Quietly (-q) install the libraries imported by the cells below.
# NOTE(review): the section title says "Install & restart", but nothing in this cell
# restarts the runtime — presumably that has to be done by hand afterwards; confirm.
!pip -q install unsloth datasets accelerate bitsandbytes peft transformers trl huggingface_hub


## 2) Imports, login & basic config

In [None]:
from datasets import load_dataset
from unsloth import (
    FastLanguageModel,
    UnslothTrainer,
    UnslothTrainingArguments,
    is_bfloat16_supported,
)
from huggingface_hub import login
import torch, os, json

# === Hugging Face authentication ===
# A token found in the HF_TOKEN environment variable is used directly (and also
# stored as a git credential); without one, login() is called with no arguments,
# which will prompt in Colab. Any failure is reported and the notebook continues.
token = os.environ.get("HF_TOKEN")
try:
    if not token:
        login()
    else:
        login(token=token, add_to_git_credential=True)
except Exception as err:
    print("Login skipped or failed:", err)

# === Config ===
# Central knobs for the whole notebook; later cells read these names directly.
DATASET_NAME = "Hindi-data-hub/odaigen_hindi_pre_trained_sp"  # Hub dataset id used for every dataset call
MODEL_NAME   = "unsloth/llama-3-8b-bnb-4bit"   # Use a BASE model (not Instruct) for CPT
MAX_SEQ_LEN  = 2048  # used both as the model's max_seq_length and as the tokenizer truncation length
LOAD_4BIT    = True  # forwarded as load_in_4bit when the model is loaded
DTYPE        = None  # auto-pick bf16/fp16
OUTPUT_DIR   = "cpt_lang_hi"  # trainer output_dir; the "lora" and "merged" exports are written beneath it
SEED         = 42
# Seed torch's global RNG.
torch.manual_seed(SEED)

## 3) Load data set from Hugging Face Hub

In [None]:
# Inspect the dataset on the Hub before loading it: which configs exist and which
# splits each config offers. If access is restricted, make sure you've accepted the
# conditions on the dataset page.
from datasets import get_dataset_config_names, get_dataset_split_names

print("Checking dataset configs and splits...")
try:
    configs = get_dataset_config_names(DATASET_NAME, token=True)
except Exception as e:
    print("Could not list configs (may be gated). Proceeding with default config. Error:", e)
    configs = [None]

print("Configs:", configs)

# One (config, [split names]) pair per config.
split_names = []
for cfg in configs:
    try:
        splits = get_dataset_split_names(DATASET_NAME, config_name=cfg, token=True)
    except Exception as e:
        # Best effort: assume a lone "train" split, but report the failure instead of
        # silently presenting a guessed value as if it came from the Hub.
        print(f"Could not list splits for config {cfg!r}; assuming ['train']. Error:", e)
        splits = ["train"]
    split_names.append((cfg, splits))
print("Splits:", split_names)

# Training uses the first 10% of "train". Evaluation uses the dataset's own
# "validation" split when one exists; otherwise the next ~2% of "train"
# (rows 10%-12%) is held out, so train and eval never overlap and the whole
# 10% subset stays available for training.
_config_name = configs[0] if configs else None


def _load_split(split):
    """Load one split (or split slice) of DATASET_NAME with the selected config."""
    # A config name of None makes load_dataset use the dataset's default config.
    return load_dataset(DATASET_NAME, _config_name, split=split, token=True)


_holdout = None
try:
    ds_train = _load_split("train[:10%]")
except Exception as e:
    print("Direct 'train[:10%]' split failed; trying explicit slicing fallback. Error:", e)
    # Actual fallback (retrying the identical call cannot succeed): load the whole
    # split and cut the same two ranges by row index.
    _full_train = _load_split("train")
    _n_train = max(1, len(_full_train) // 10)
    _n_eval = max(1, len(_full_train) // 50)
    ds_train = _full_train.select(range(_n_train))
    _holdout = _full_train.select(
        range(_n_train, min(len(_full_train), _n_train + _n_eval))
    )

try:
    # Goes through the same config as the training data.
    ds_eval = _load_split("validation")
except Exception as e:
    print("No usable 'validation' split; holding out train[10%:12%] for eval. Reason:", e)
    ds_eval = _holdout if _holdout is not None else _load_split("train[10%:12%]")

print(ds_train)
print(ds_eval)

# Pick the column that holds the raw text.
# Well-known names are tried first, in priority order; if none is present, the first
# column whose value in row 0 is a string is used. Failing both is an error.
_KNOWN_TEXT_COLUMNS = ("text", "sentence", "content", "raw_text", "document", "data")

text_column = next(
    (col for col in _KNOWN_TEXT_COLUMNS if col in ds_train.column_names),
    None,
)

if text_column is None:
    text_column = next(
        (col for col in ds_train.column_names if isinstance(ds_train[0][col], str)),
        None,
    )

if text_column is None:
    raise ValueError("Could not find a text column. Please inspect ds_train.column_names and set one.")

print("Using text column:", text_column)

## 4) Tokenizer/model (4‑bit)

In [None]:

# Load the checkpoint and its tokenizer in one call, driven entirely by the
# constants from the config cell.
_model_load_kwargs = {
    "model_name": MODEL_NAME,
    "max_seq_length": MAX_SEQ_LEN,
    "dtype": DTYPE,
    "load_in_4bit": LOAD_4BIT,
}
model, tokenizer = FastLanguageModel.from_pretrained(**_model_load_kwargs)


## 5) Prepare tokenized data set

In [None]:

def tok_fn(batch):
    """Tokenize the `text_column` entries of one batch.

    Inputs are truncated to MAX_SEQ_LEN tokens and the tokenizer is asked not to
    return an attention mask. The tokenizer's output is returned unchanged.
    """
    texts = batch[text_column]
    encoded = tokenizer(
        texts,
        max_length=MAX_SEQ_LEN,
        truncation=True,
        return_attention_mask=False,
    )
    return encoded

# Tokenize both subsets in batched mode. Every column except the text column is
# dropped from the result; the text column stays alongside the tokenizer output.
_train_drop_cols = [col for col in ds_train.column_names if col != text_column]
_eval_drop_cols = [col for col in ds_eval.column_names if col != text_column]

tokenized_train = ds_train.map(tok_fn, batched=True, remove_columns=_train_drop_cols)
tokenized_eval = ds_eval.map(tok_fn, batched=True, remove_columns=_eval_drop_cols)


## 6) QLoRA (incl. embeddings & lm_head) and train (CPT)

In [None]:
# Attach LoRA adapters for continued pretraining (CPT).
# In addition to the attention and MLP projections, "embed_tokens" and "lm_head"
# are targeted so that the embeddings and output head are actually trained — this is
# what the section title promises and what `embedding_learning_rate` in the training
# arguments applies to. Training these two modules needs noticeably more GPU memory;
# drop them from the list if the run does not fit.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",
    ],
    random_state=SEED,  # tie adapter initialisation to the notebook-wide seed
)

# Mixed-precision mode: bf16 where the hardware supports it, fp16 otherwise.
_use_bf16 = is_bfloat16_supported()

args = UnslothTrainingArguments(
    # --- output & checkpoints ---
    output_dir=OUTPUT_DIR,
    save_steps=1000,
    save_total_limit=2,
    # --- schedule & batch sizes ---
    num_train_epochs=1,                 # increase for more data
    per_device_train_batch_size=1,      # tune to your GPU
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    # --- optimisation ---
    optim="adamw_8bit",
    learning_rate=5e-5,                 # main LR for LoRA blocks
    embedding_learning_rate=5e-6,       # smaller for embed/lm_head
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    max_grad_norm=1.0,
    # --- memory & precision ---
    gradient_checkpointing=True,
    bf16=_use_bf16,
    fp16=not _use_bf16,
    # --- logging & evaluation ---
    logging_steps=50,
    eval_strategy="steps",              # current name of the former `evaluation_strategy`
    eval_steps=1000,
)

# Disable the generation cache on the model config for training (saves memory).
model.config.use_cache = False

# Build the trainer from the tokenized subsets, with packing enabled, then run training.
_trainer_kwargs = {
    "model": model,
    "tokenizer": tokenizer,  # passed explicitly
    "args": args,
    "train_dataset": tokenized_train,
    "eval_dataset": tokenized_eval,
    "packing": True,
}
trainer = UnslothTrainer(**_trainer_kwargs)

trainer.train()

## 7) Save LoRA and (optional) merged weights

In [None]:

import os

# Two exports under OUTPUT_DIR: the adapter on its own, and a standalone merged model.
_lora_dir = f"{OUTPUT_DIR}/lora"
_merged_dir = f"{OUTPUT_DIR}/merged"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Adapter weights plus tokenizer.
model.save_pretrained(_lora_dir)
tokenizer.save_pretrained(_lora_dir)

# Merge LoRA into a single checkpoint (optional).
# `save_pretrained(..., merge=True)` is not a merge API — it would only write the
# adapter a second time. Unsloth's `save_pretrained_merged` folds the adapter into
# the base weights and writes a 16-bit checkpoint.
model.save_pretrained_merged(_merged_dir, tokenizer, save_method="merged_16bit")
tokenizer.save_pretrained(_merged_dir)

print("Saved LoRA to", _lora_dir)
print("Saved merged to", _merged_dir)


## 8) Quick perplexity & generation sanity check

In [None]:
import math
from torch.utils.data import DataLoader
from tqdm import tqdm

# Use the trainer's built-in evaluate method which handles packing.
# Perplexity is exp(mean eval loss).
metrics = trainer.evaluate()
pp = math.exp(metrics["eval_loss"])
print("Eval Perplexity:", pp)

# The model is still configured for training at this point (use_cache was turned off
# and gradient checkpointing is on). Switch it to Unsloth's inference mode before
# sampling; call FastLanguageModel.for_training(model) to resume training afterwards.
FastLanguageModel.for_inference(model)

# Simple generation test (raw text — no chat template during CPT)
prompt = "हिंदी में एक छोटा अनुच्छेद लिखें जो इस मॉडल की समझ का परीक्षण करे।"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))