In [1]:
!nvidia-smi

Sat Sep 13 12:26:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   40C    P0             53W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
!pip install transformers dataset peft accelerate



In [2]:
import json

input_file = 'thirukkural_plain_text.txt'
output_file = 'thirukkural_qa_pairs.jsonl'

dataset = []

with open(input_file, 'r', encoding='utf-8') as f:
    content = f.read().split('<kural_id>:')
    
    for entry in content[1:]:
        lines = entry.strip().split('\n')
        kural_id = lines[0].strip()
        explanation = next((line.split(':', 1)[1].strip() for line in lines if line.startswith('Explanation:')), '')
        english = next((line.split(':', 1)[1].strip() for line in lines if line.startswith('English:')), '')
        
        output_text = explanation or english
        
        dataset.append({
            "input": f"What is Kural {kural_id}?",
            "output": output_text
        })

with open(output_file, 'w', encoding='utf-8') as out_f:
    for sample in dataset:
        json.dump(sample, out_f)
        out_f.write('\n')

print(f"Supervised dataset created with {len(dataset)} samples at {output_file}")

Supervised dataset created with 1331 samples at thirukkural_qa_pairs.jsonl


In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("sft_tuned_token")
model = AutoModelForCausalLM.from_pretrained("sft_tuned_model")

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1
)
model = get_peft_model(model, lora_config)

dataset = load_dataset('json', data_files={'train': 'thirukkural_qa_pairs.jsonl'})

def preprocess(examples):
    inputs = ["Question: " + q for q in examples['input']]
    model_inputs = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        padding="max_length"
    )
    labels = tokenizer(
        examples['output'],
        max_length=512,
        truncation=True,
        padding="max_length"
    )["input_ids"]
    model_inputs["labels"] = labels
    return model_inputs

tokenized_dataset = dataset.map(preprocess, batched=True)

training_args = TrainingArguments(
    output_dir="./gemma3-thirukkural-sft",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=15,
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    learning_rate=5e-5,
    report_to="tensorboard"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
)

trainer.train()
trainer.save_model("./gemma3-thirukkural-sft")
tokenizer.save_pretrained("./gemma3-thirukkural-sft")


It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Step,Training Loss
10,52.7627
20,52.6759
30,52.6548
40,52.6138
50,52.5792
60,52.5673
70,52.5965
80,52.6093
90,52.5357
100,52.562


('./gemma3-thirukkural-sft/tokenizer_config.json',
 './gemma3-thirukkural-sft/special_tokens_map.json',
 './gemma3-thirukkural-sft/tokenizer.json')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("./gemma3-thirukkural-sft")
tokenizer = AutoTokenizer.from_pretrained("./gemma3-thirukkural-sft")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def get_kural_explanation(kural_id):
    prompt = f"Question: What is Kural {kural_id}?\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    output_ids = model.generate(
        inputs["input_ids"],
        max_length=200,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

kural_1 = get_kural_explanation(1)
print(f"Kural 1 Explanation:\n{kural_1}")
