<a href="https://colab.research.google.com/github/prikshitkverma/Gemma_fine_tuning/blob/main/gemini_1b_it_val.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================
# 1. INSTALL LIBRARIES
# ============================================
# Clean reinstall: remove the preinstalled copies of the core libraries, then
# install a pinned CUDA 12.4 PyTorch build followed by pinned Hugging Face
# packages.
# NOTE(review): confirm that the pinned transformers/trl releases exist on PyPI
# and support both the `google/gemma-3-1b-it` checkpoint and the
# `processing_class` argument passed to SFTTrainer further down.
!pip uninstall -y torch torchvision torchaudio transformers trl accelerate datasets huggingface_hub
!pip install -q torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
!pip install -q transformers==4.45.2 trl==0.11.6 accelerate==1.1.1 datasets==3.1.0 huggingface_hub==0.28.1 sentencepiece pyarrow==18.0.0
# Unpinned install of the same packages. pip does not upgrade a package that is
# already installed unless asked to, so this line presumably only has an effect
# when the pinned install above did not succeed — TODO confirm and pin.
!pip install -q datasets trl sentencepiece huggingface_hub
# NOTE(review): `evaluate` is unpinned; consider pinning it for reproducibility.
!pip install evaluate
# ============================================
# 2. SETUP AND AUTHENTICATION
# ============================================
import os
from getpass import getpass

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    pipeline,
)
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login
import evaluate

# Login to Hugging Face.
# The access token is never written into the notebook: it is taken from the
# HF_TOKEN environment variable when set, otherwise requested through a hidden
# prompt so it does not end up in the saved cell source or output.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=HF_TOKEN)

# ============================================
# 3. CONFIGURE MODEL AND DIRECTORIES
# ============================================
# Local folder handed to the trainer as its output directory.
output_dir = "./gemma-natural-farming-qa"
# Identifier of the pretrained checkpoint that gets fine-tuned.
base_model = "google/gemma-3-1b-it"

# ============================================
# 4. LOAD AND PREPARE THE DATASET
# ============================================
# JSON Lines file read with the generic "json" loader; everything is loaded
# into a single "train" split and divided further down.
data_file = "/content/natural_farming_dataset_perplexity.jsonl"
dataset = load_dataset(
    path="json",
    split="train",
    data_files=data_file,
)

def format_dataset(sample):
    """Turn one question/answer record into a two-turn chat record.

    Args:
        sample: Mapping that provides a "question" and an "answer" entry.

    Returns:
        A dict with a single "messages" key holding the user turn (the
        question) followed by the assistant turn (the answer).

    Raises:
        KeyError: If "question" or "answer" is missing from ``sample``.
    """
    user_turn = {"role": "user", "content": sample["question"]}
    assistant_turn = {"role": "assistant", "content": sample["answer"]}
    return {"messages": [user_turn, assistant_turn]}

# Convert every record to the chat format and drop the raw columns so that only
# "messages" remains. `column_names` is the list of column names that
# `remove_columns` expects (the original passed the Features mapping instead).
formatted_dataset = dataset.map(format_dataset, remove_columns=dataset.column_names)

# Split into 80% train, 10% validation, 10% test.
# First hold out 20%, then cut that hold-out in half; both splits use a fixed
# seed so the partition is reproducible.
split_dataset = formatted_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)
val_test_split = split_dataset["test"].train_test_split(test_size=0.5, shuffle=True, seed=42)

dataset_dict = {
    "train": split_dataset["train"],
    "validation": val_test_split["train"],
    "test": val_test_split["test"]
}

print("✅ Dataset Split Summary:")
print(f"Train: {len(dataset_dict['train'])} | Validation: {len(dataset_dict['validation'])} | Test: {len(dataset_dict['test'])}")
print("\nExample data sample:")
print(dataset_dict["train"][0]["messages"])

# ============================================
# 5. LOAD MODEL AND TOKENIZER
# ============================================
# dtype and device placement are both left to the library ("auto"); attention
# uses the "eager" implementation.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Report where the weights ended up and in which precision.
print(f"✅ Model loaded on {model.device} | dtype: {model.dtype}")



In [None]:
# ============================================
# 6. CONFIGURE THE TRAINING PROCESS
# ============================================
# Effective batch size per device is 2 * 4 = 8 sequences per optimizer step.
# max_seq_length is intentionally not set, so the library default applies.
sft_config = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    optim="adamw_torch_fused",
    logging_steps=10,
    # Evaluate on the validation split once per epoch. Without an evaluation
    # strategy the eval_dataset handed to the trainer is never used during
    # training; "epoch" matches the checkpoint cadence below.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    # Use bf16 only when the GPU supports it; fp16 stays disabled.
    bf16=torch.cuda.is_bf16_supported(),
    fp16=False,
    push_to_hub=False,
    report_to="tensorboard"
)

# ============================================
# 7. TRAIN AND VALIDATE MODEL
# ============================================
# The trainer receives the chat-formatted train/validation splits and the
# tokenizer as its processing class.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["validation"],
    processing_class=tokenizer,
)

print("🚀 Starting fine-tuning...")
trainer.train()

# Persist the trained model to output_dir so later cells can load it by path.
print("💾 Saving final model...")
trainer.save_model(output_dir)



In [None]:
!pip install rouge_score
# ============================================
# 8. VALIDATE MODEL PERFORMANCE
# ============================================
print("\nüîç Validating model on validation set...")
model.eval()

# Use a simple text-generation pipeline
val_pipe = pipeline("text-generation", model=output_dir, tokenizer=tokenizer)

bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")

generated_texts = []
reference_texts = []

for i, sample in enumerate(dataset_dict["validation"]):
    user_msg = [{"role": "user", "content": sample["messages"][0]["content"]}]
    prompt = tokenizer.apply_chat_template(user_msg, tokenize=False, add_generation_prompt=True)
    output = val_pipe(prompt, max_new_tokens=128, num_return_sequences=1)[0]["generated_text"][len(prompt):].strip()
    generated_texts.append(output)
    reference_texts.append(sample["messages"][1]["content"])
    if i < 3:
        print(f"\nExample {i+1}:")
        print(f"Q: {sample['messages'][0]['content']}")
        print(f"Model: {output}")
        print(f"Ref: {sample['messages'][1]['content']}")

# Compute metrics
bleu_score = bleu.compute(predictions=generated_texts, references=reference_texts)
rouge_score = rouge.compute(predictions=generated_texts, references=reference_texts)

print("\nüìä Validation Metrics:")
print(f"BLEU Score: {bleu_score['bleu']:.4f}")
print(f"ROUGE-L: {rouge_score['rougeL']:.4f}")

In [None]:
!pip install bert_score
bertscore = evaluate.load("bertscore")
results = bertscore.compute(predictions=generated_texts, references=reference_texts, lang="en")
print(sum(results["f1"]) / len(results["f1"]))

In [None]:

# ============================================
# 9. TEST INTERACTIVELY
# ============================================
print("\n--- Interactive Testing ---")
test_pipe = pipeline("text-generation", model=output_dir, tokenizer=tokenizer)

while True:
    try:
        question = input("\nEnter your question (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Interrupting the cell or closing stdin ends the session without a traceback.
        print("\n👋 Exiting...")
        break

    if question.lower() == "exit":
        print("👋 Exiting...")
        break

    # Ignore blank input instead of sending an empty prompt to the model.
    if not question:
        continue

    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # return_full_text=False yields only the generated continuation, so the
    # prompt does not have to be sliced off the result.
    outputs = test_pipe(prompt, max_new_tokens=256, return_full_text=False)
    print(f"\n🧠 Answer:\n{outputs[0]['generated_text'].strip()}")

In [None]:
import os
import shutil

from transformers import AutoModelForCausalLM, AutoTokenizer

# Path to your best checkpoint
best_ckpt = "/content/gemma-natural-farming-qa/checkpoint-1200"

# Where to save clean offline model
save_dir = "./best_model"

# Fail early with a clear message when the chosen checkpoint does not exist
# (the step number depends on dataset size and training settings).
if not os.path.isdir(best_ckpt):
    raise FileNotFoundError(
        f"Checkpoint directory not found: {best_ckpt}. "
        "Set best_ckpt to one of the checkpoint-* folders inside the training output directory."
    )

# Load model and tokenizer from the checkpoint
model = AutoModelForCausalLM.from_pretrained(best_ckpt)
tokenizer = AutoTokenizer.from_pretrained(best_ckpt)

# Save only what's needed for inference
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print(f"✅ Saved minimal offline model to {save_dir}")


In [None]:
import os
import shutil

from google.colab import auth
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload

# 1. Authenticate the Colab user and build a Drive v3 client
auth.authenticate_user()
drive_service = build('drive', 'v3')

# 2. Zip your model folder
local_folder = "/content/best_model"
zip_path = "/content/best_model.zip"
# make_archive appends the ".zip" extension itself, so pass the path without
# it. splitext removes only the trailing extension, unlike str.replace, which
# would also strip a ".zip" occurring elsewhere in the path.
shutil.make_archive(os.path.splitext(zip_path)[0], 'zip', local_folder)
print(f"✅ Zipped folder to {zip_path}")

# 3. Upload zip to specific Drive folder by ID
folder_id = "1v7wyPcLmawtlKgFsOqoMdcB8qMZ7fqPj"
file_metadata = {
    'name': 'best_model.zip',
    'parents': [folder_id]
}
# A model archive is large, so send it as a resumable (chunked) upload rather
# than in a single request; execute() drives the chunks to completion.
media = MediaFileUpload(zip_path, mimetype='application/zip', resumable=True)
file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()

print(f"✅ Uploaded zip to Drive folder! File ID: {file.get('id')}")
