# 1. Setup and Installation

In [None]:
%%capture
# Install/upgrade the runtime dependencies for this notebook.
# `%%capture` hides pip's output, so an install failure is not visible here — rerun the
# cell without `%%capture` when debugging environment problems.
# NOTE(review): only jinja2 is pinned; unsloth/transformers/timm/peft float to the latest
# release, so results may not reproduce across runs — consider pinning versions.
# NOTE(review): the `rarfile` package presumably needs an external unrar/bsdtar tool to
# extract archives — confirm one is available in this runtime.
!pip install unsloth
!pip install --upgrade transformers timm peft
!pip install jinja2==3.1.0 rarfile

# 2. Importing Libraries

In [None]:
import unsloth
from unsloth import FastVisionModel
from datasets import Dataset
from transformers import TrainingArguments, Trainer
from PIL import Image
from peft import PeftModel
from huggingface_hub import login, hf_hub_download, HfApi
import rarfile

import os
import json
import torch
import time
import logging
import pandas as pd
import numpy as np
from datetime import datetime
from getpass import getpass

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set random seed for reproducibility
def set_seed(seed=3407):
    """Seed every random number generator the pipeline relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and asks cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, including the current one (a separate
    # torch.cuda.manual_seed call would be redundant).
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN algorithms and disable autotuning, which can
    # select different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(3407)

# 3. Configuration
This section sets up all the key configurations for our training run, including model IDs, output directories, and Hugging Face authentication.

In [None]:
# Model configuration
base_model_id = "unsloth/gemma-3n-e2b-it-unsloth-bnb-4bit"
previous_adapter_hub_id = "shreyansh24/gemma3n-e2b-cddm-finetune"

# Training configuration
output_dir = "./gemma-cddm-agrillava-finetuned"
hub_model_id = "shreyansh24/gemma3n-cddm-AgriLLava-finetune"
hub_repo_id = "shreyansh24/AgriVision-gemma3n"
hub_private_repo = False

# Secure token handling: prefer the HF_TOKEN environment variable, otherwise prompt
# without echoing. Surrounding whitespace (common when pasting) is stripped, so a
# blank or whitespace-only token is rejected here instead of failing inside login().
hf_token = (os.getenv("HF_TOKEN") or getpass("Enter Hugging Face token: ")).strip()
if not hf_token:
    raise ValueError("HF_TOKEN environment variable or token input required")

login(token=hf_token)

# System message
system_message = "You are a helpful AI assistant specialized in crop disease diagnosis. Provide concise and accurate information."
print("✅ Configuration and login complete.")

# 4. Data Preparation

In [None]:
print("\n--- 🚜 Preparing Agri-LLaVA Dataset ---")

# --- 4.1. Download and Extract Agri-LLaVA Data ---
# Hub dataset that provides the instruction-tuning JSON files and the image archive.
agri_repo = "Agri-LLaVA-Anonymous/Agricultural_pests_and_diseases_instruction_tuning_data"
data_dir = "./agri_llava_data/"  # relative to the notebook's working directory
os.makedirs(data_dir, exist_ok=True)
image_dir = os.path.join(data_dir, "Img")  # populated by extracting Img.rar below

# Download function with error handling
def download_hf_file(filename, repo, local_dir, max_retries=3, backoff_seconds=2.0):
    """Download ``filename`` from a Hugging Face *dataset* repo into ``local_dir``.

    Files already present locally are reused without contacting the Hub.

    Args:
        filename: Path of the file inside the repo.
        repo: Dataset repo id (``"owner/name"``).
        local_dir: Directory the file is written to.
        max_retries: Total number of download attempts; values below 1 are treated
            as a single attempt so the file is always actually fetched.
        backoff_seconds: Base delay before a retry; it doubles after each failure.

    Returns:
        The local path ``os.path.join(local_dir, filename)``.

    Raises:
        Exception: whatever ``hf_hub_download`` raised on the final attempt.
    """
    path = os.path.join(local_dir, filename)
    if os.path.exists(path):
        print(f"{filename} already exists.")
        return path

    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            print(f"Downloading {filename}...")
            hf_hub_download(repo_id=repo, filename=filename, repo_type="dataset", local_dir=local_dir)
            return path
        except Exception:
            if attempt == attempts - 1:
                raise  # bare raise keeps the original traceback
            print(f"Retry {attempt+1}/{max_retries} for {filename}")
            # Exponential backoff gives transient Hub/network errors time to clear.
            time.sleep(backoff_seconds * (2 ** attempt))

# Download components
try:
    json_full_path = download_hf_file("agri_llava_instruction_tuning.json", agri_repo, data_dir)
    json_1k_path = download_hf_file("agri_llava_instruction_tuning_1k.json", agri_repo, data_dir)
    rar_path = download_hf_file("Img.rar", agri_repo, data_dir)
except Exception as e:
    logger.error(f"Failed to download files: {e}")
    raise

# Extract images with error handling.
# NOTE: only the folder's existence is checked, so an interrupted extraction leaves a
# partial Img/ that is reused on the next run — delete it to force a clean re-extract.
if not os.path.exists(image_dir):
    try:
        print(f"Extracting {rar_path}...")
        with rarfile.RarFile(rar_path) as rf:
            rf.extractall(data_dir)
        print("Extraction complete.")
    except Exception as e:
        logger.error(f"Failed to extract images: {e}")
        raise
else:
    print("Image directory already exists.")

# The later filtering step drops every sample whose image file is missing, so fail
# loudly here if the archive did not produce the expected folder.
if not os.path.isdir(image_dir):
    raise FileNotFoundError(
        f"Expected extracted images in {image_dir}, but that directory does not exist; "
        f"check the folder layout inside {rar_path}."
    )

# --- 4.2. Load and Merge JSON files ---
try:
    # Explicit UTF-8 so loading does not depend on the platform's default encoding.
    with open(json_full_path, 'r', encoding='utf-8') as f:
        data_full = json.load(f)
    with open(json_1k_path, 'r', encoding='utf-8') as f:
        data_1k = json.load(f)
except Exception as e:
    logger.error(f"Failed to load JSON files: {e}")
    raise

# Combine and remove duplicates (one record per image; the full file's record wins because
# it is concatenated first), then convert to a Hugging Face Dataset. reset_index restores a
# contiguous RangeIndex so Dataset.from_pandas does not add an `__index_level_0__` column.
# NOTE(review): de-duplicating on 'image' alone also discards distinct Q&A records that
# share an image — confirm that is intended.
df = (
    pd.concat([pd.DataFrame(data_full), pd.DataFrame(data_1k)], ignore_index=True)
    .drop_duplicates(subset=['image'], keep='first')
    .reset_index(drop=True)
)
raw_dataset = Dataset.from_pandas(df)
print(f"Loaded and merged JSONs. Total unique samples: {len(raw_dataset)}")

def sample_has_image(sample):
    """Tell whether ``sample`` names an image file that exists under ``image_dir``.

    Samples whose ``"image"`` value is empty or missing count as having no image.
    """
    image_name = sample['image']
    return bool(image_name) and os.path.exists(os.path.join(image_dir, image_name))

# Keep only image-backed samples; the check runs in two worker processes.
print("\nFiltering dataset to include only samples with valid images...")
image_dataset_raw = raw_dataset.filter(sample_has_image, num_proc=2)
total_count, kept_count = len(raw_dataset), len(image_dataset_raw)
print(f"Original dataset size: {total_count}. Filtered image-only dataset size: {kept_count}.")

In [None]:
TARGET = (512, 512)  # or (256, 256); exact (width, height) — aspect ratio is not preserved

def format_agri_data_for_trl(sample):
    """Convert one Agri-LLaVA record into a TRL-style ``{"messages": [...]}`` dict.

    The record's image is loaded from ``image_dir``, converted to RGB and resized to
    ``TARGET``. It is attached as an ``{"type": "image"}`` part to the first human turn,
    whose ``<image>`` placeholder text is removed. A system turn with ``system_message``
    is prepended. Turns from ``"human"`` become ``"user"``; every other speaker becomes
    ``"assistant"``.

    Args:
        sample: Record with an ``"image"`` filename and a ``"conversations"`` list of
            ``{"from": ..., "value": ...}`` turns.

    Returns:
        The formatted dict, or ``None`` when the image cannot be loaded or the record has
        no human turn to carry the image (the caller drops ``None`` results).
    """
    image_path = os.path.join(image_dir, sample['image'])
    try:
        # `with` closes the file handle; convert()/resize() return fully loaded copies.
        with Image.open(image_path) as img:
            loaded_image = img.convert("RGB").resize(TARGET, resample=Image.BILINEAR)
    except Exception as e:
        print(f"Warning: Could not load image {image_path}: {e}")
        return None

    trl_messages = [{"role": "system", "content": [{"type": "text", "text": system_message}]}]
    image_attached = False
    for turn in sample['conversations'] or []:
        role = "user" if turn["from"] == "human" else "assistant"
        text_content = turn["value"] or ""
        content_parts = []
        # Attach the image to the first human turn (not only to turn 0), so records that
        # open with another speaker keep their image.
        if role == "user" and not image_attached:
            content_parts.append({"type": "image", "image": loaded_image})
            text_content = text_content.replace("<image>", "").strip()
            image_attached = True
        if text_content:
            content_parts.append({"type": "text", "text": text_content})
        if not content_parts:
            content_parts.append({"type": "text", "text": ""})
        trl_messages.append({"role": role, "content": content_parts})

    if not image_attached:
        # No human turn to pair the image with; such a sample cannot train the vision path.
        return None
    return {"messages": trl_messages}

# Apply formatting and build the final dataset list.
# Images are decoded eagerly, so every kept sample holds a resized PIL image in memory.
print("\nApplying formatting and loading images directly into a list...")
formatted_samples = (format_agri_data_for_trl(sample) for sample in image_dataset_raw)
train_dataset = [item for item in formatted_samples if item is not None]
print(f"Final dataset size: {len(train_dataset)}")
if not train_dataset:
    # Fail with an actionable message instead of an IndexError on train_dataset[0].
    raise RuntimeError(
        "No usable training samples after formatting; check that the images were "
        "extracted into image_dir and that records contain a human turn."
    )
print("Sample data structure:")
print(train_dataset[0])

# 5. Model Loading (4-bit Base Model + Previous LoRA Adapter)

In [None]:
print("\n--- ⚙️ Loading Model and Adapter (Final) ---")

# Load the 4-bit quantized Gemma 3n base model and its processor through Unsloth.
# Any failure is logged and re-raised so the notebook stops here.
try:
    model, processor = FastVisionModel.from_pretrained(
        model_name = base_model_id,
        max_seq_length = 2048,  # same limit as max_length in collate_fn
        dtype = torch.float16,  # consistent with fp16=True in TrainingArguments
        load_in_4bit = True,
        trust_remote_code=True,
        unsloth_force_compile = True,  
    )
except Exception as e:
    logger.error(f"Failed to load base model: {e}")
    raise

# Continue training from the previously fine-tuned LoRA adapter on the Hub.
# is_trainable=True loads the adapter with its weights unfrozen (PEFT's default would
# load it for inference only); print_trainable_parameters() reports the result.
# NOTE(review): Unsloth's own LoRA path is FastVisionModel.get_peft_model / for_training —
# confirm that wrapping with PeftModel directly keeps Unsloth's training patches active.
print(f"Applying adapter from: {previous_adapter_hub_id}...")
try:
    model = PeftModel.from_pretrained(model, previous_adapter_hub_id, is_trainable=True)
    print("✅ Previous adapter has been applied successfully.")
    model.print_trainable_parameters()
except Exception as e:
    logger.error(f"Failed to load adapter: {e}")
    raise

# 6. Training Setup (Data Collator, Arguments, Trainer)

In [None]:
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect, in order, every PIL image embedded in a list of chat messages.

    Each message's ``"content"`` may be a list of parts or a single value (treated as a
    one-element list). Only dict parts with ``"type" == "image"`` whose ``"image"`` entry
    is a PIL image are collected; each is returned converted to RGB.
    """
    collected = []
    for message in messages:
        parts = message.get("content", [])
        parts = parts if isinstance(parts, list) else [parts]
        collected.extend(
            part["image"].convert("RGB")
            for part in parts
            if isinstance(part, dict)
            and part.get("type") == "image"
            and "image" in part
            and isinstance(part["image"], Image.Image)
        )
    return collected

def format_messages_to_gemma_chat(messages):
    """Render chat messages as Gemma turn markup without using a Jinja2 template.

    Gemma has no ``system`` role: system text is prepended (followed by a blank line)
    to the first non-system turn, as Gemma's chat template does. ``assistant`` turns are
    tagged ``model``, and roles other than system/user/assistant are skipped. Content
    may be a plain string or a list of parts. Text parts contribute their text, and
    image parts contribute an ``<image>`` marker; parts are joined with single spaces.
    If only system text is present, it is emitted as a user turn so it is not lost.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The formatted conversation, stripped of surrounding whitespace.
    """
    def _render_content(content):
        # Plain-string content (e.g. text-only prompts) is used verbatim.
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        pieces = []
        for part in content:
            kind = part.get("type")
            if kind == "text":
                pieces.append(part.get("text", ""))
            elif kind == "image":
                # NOTE(review): placeholder text; the processor's own image token may differ — confirm.
                pieces.append("<image>")
        return " ".join(pieces)

    role_tags = {"user": "user", "assistant": "model"}
    system_prefix = ""
    turns = []
    for message in messages:
        role = message["role"]
        content = _render_content(message.get("content"))
        if role == "system":
            system_prefix += content + "\n\n"
            continue
        tag = role_tags.get(role)
        if tag is None:
            continue
        if system_prefix:
            content = system_prefix + content
            system_prefix = ""
        turns.append(f"<start_of_turn>{tag}\n{content}<end_of_turn>\n")

    if system_prefix:
        turns.append(f"<start_of_turn>user\n{system_prefix.rstrip()}<end_of_turn>\n")
    return "".join(turns).strip()

# `collate_fn` turns a list of formatted samples into a model-ready batch: it renders
# each conversation with the processor's chat template, tokenizes text and images
# together, and builds `labels` where every token outside the assistant ("model")
# replies is set to -100. The loss is therefore computed only on the assistant's
# responses, not on the system prompt or the user's questions.

def _assistant_token_mask(token_ids, marker_ids, end_id):
    """Return one boolean per token: True for tokens inside an assistant response.

    A response starts right after each occurrence of ``marker_ids`` and runs up to and
    including the next ``end_id``; if no ``end_id`` follows (e.g. truncation), it runs
    to the end of the sequence.
    """
    keep = [False] * len(token_ids)
    width = len(marker_ids)
    if width == 0:
        return keep
    j = 0
    last_start = len(token_ids) - width
    while j <= last_start:
        if token_ids[j:j + width] == marker_ids:
            start = j + width
            try:
                end = token_ids.index(end_id, start)
            except ValueError:
                end = len(token_ids) - 1
            keep[start:end + 1] = [True] * max(0, end + 1 - start)
            j = end + 1  # resume scanning after this response
        else:
            j += 1
    return keep

def collate_fn(examples):
    """Build a training batch (processor inputs plus masked ``labels``) from samples.

    Args:
        examples: Items shaped like ``{"messages": [...]}`` as produced by
            ``format_agri_data_for_trl``.

    Returns:
        The processor's batch with an added ``"labels"`` tensor.
    """
    all_texts = []
    all_images = []
    for ex in examples:
        all_images.append(process_vision_info(ex["messages"]))
        text = processor.apply_chat_template(
            ex["messages"], add_generation_prompt=False, tokenize=False
        ).strip()
        all_texts.append(text)

    # NOTE(review): truncation that cuts into an image's token span may trigger
    # image-token count errors — confirm 2048 is enough for image + conversation.
    batch = processor(
        text=all_texts,
        images=all_images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    )

    tokenizer = processor.tokenizer
    # Gemma closes every turn with <end_of_turn>. tokenizer.eos_token_id can be a different
    # token (<eos>) that never appears in the rendered chat, which would leave every later
    # user turn (and padding) unmasked. Look the turn terminator up explicitly.
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if end_of_turn_id is None or end_of_turn_id == tokenizer.unk_token_id:
        end_of_turn_id = tokenizer.eos_token_id
    # Anchor on the full turn header so the word "model" in user text can't open a span.
    response_marker = tokenizer.encode("<start_of_turn>model\n", add_special_tokens=False)

    labels = batch["input_ids"].clone()
    for i in range(labels.shape[0]):
        keep = torch.tensor(
            _assistant_token_mask(labels[i].tolist(), response_marker, end_of_turn_id),
            dtype=torch.bool,
        )
        if not keep.any():
            logger.warning("collate_fn: no assistant tokens found in sample %d; it contributes no loss.", i)
        labels[i][~keep] = -100
    # Padding never contributes to the loss, whatever the span search found.
    if "attention_mask" in batch:
        labels[batch["attention_mask"] == 0] = -100
    batch["labels"] = labels
    return batch

# Training arguments
args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch of 8 samples per device
    max_steps=-1,                   # -1: length is governed by num_train_epochs
    gradient_checkpointing=True,
    optim="adamw_8bit",
    logging_steps=5,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=2,
    save_only_model=False,
    learning_rate=2e-4,
    fp16=True,                      # consistent with the float16 dtype the model was loaded with
    bf16=False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    # "constant" ignores warmup_ratio entirely; "constant_with_warmup" ramps the LR up over
    # the first 3% of steps and then holds it at learning_rate.
    lr_scheduler_type="constant_with_warmup",
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="checkpoint",
    hub_token=hf_token,
    hub_private_repo=hub_private_repo,
    report_to="tensorboard",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Keep the raw "messages" key: otherwise the Trainer strips fields the model's
    # forward() does not accept before they reach collate_fn.
    remove_unused_columns=False,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

# Configure TorchDynamo recompilation limits.
# Recent PyTorch renamed `cache_size_limit` to `recompile_limit` and keeps the old name as
# an alias, so assigning both silently let the second value (100) win. Set exactly one.
dynamo_recompile_limit = 100
if hasattr(torch._dynamo.config, "recompile_limit"):
    torch._dynamo.config.recompile_limit = dynamo_recompile_limit
else:
    torch._dynamo.config.cache_size_limit = dynamo_recompile_limit
torch._dynamo.config.automatic_dynamic_shapes = True

# 7. Training Execution

In [None]:
print("\n--- 🚀 Starting Continued Fine-tuning ---")
try:
    trainer_stats = trainer.train()
except Exception as e:
    logger.error(f"Training failed: {e}")
    raise
print("\n--- ✅ Training Complete ---")

# trainer.train() only writes checkpoint-* subfolders; save the final adapter (and the
# processor) at the root of output_dir, which is where Section 8 loads it from
# (adapter_path = output_dir). With push_to_hub=True this also pushes the final model.
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)

# TrainOutput.training_loss is one float (mean loss over the run), not a per-step list;
# per-step values are in trainer.state.log_history.
average_loss = getattr(trainer_stats, "training_loss", None)
if average_loss is not None:
    print(f"Average training loss: {average_loss:.4f}")

# 8. Save the Final Model

In [None]:
print("\n--- 💾 Saving Final Model ---")

# Set up paths
# The LoRA adapter is read back from the training output directory.
# NOTE(review): Trainer.train() by itself does not write the final adapter to the root of
# output_dir (save_strategy="steps" writes checkpoint-* subfolders) — confirm that
# trainer.save_model(output_dir) ran, or point adapter_path at a checkpoint folder.
adapter_path = output_dir
merged_model_path = "./gemma3n-agrivision-merged"

# Load the base model
# A fresh copy of the same 4-bit base used for training; this rebinds the globals
# `model` and `processor`.
# NOTE(review): the training model is still referenced by `trainer`, so both copies may be
# resident in GPU memory at once — consider freeing the trainer first.
print("Loading base model...")
try:
    model, processor = FastVisionModel.from_pretrained(
        model_name=base_model_id,
        max_seq_length=2048,
        dtype=torch.float16,
        load_in_4bit=True,
        trust_remote_code=True
    )
except Exception as e:
    logger.error(f"Failed to load base model: {e}")
    raise

# Load the adapter
# is_trainable is not passed here, so PEFT's default (inference-only adapter) applies.
print("Loading adapter...")
try:
    model = PeftModel.from_pretrained(model, adapter_path)
    print("Adapter loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load adapter: {e}")
    raise

# Merge the adapter with the base model
# merge_and_unload folds the LoRA deltas into the base layers and returns the plain model.
# NOTE(review): the base was loaded with load_in_4bit=True, so the merge is done against
# quantized weights — confirm the resulting precision is acceptable, or merge into a
# 16-bit base instead.
print("Merging adapter with base model...")
try:
    merged_model = model.merge_and_unload()
    print("Model merged successfully!")
except Exception as e:
    logger.error(f"Failed to merge model: {e}")
    raise

# Save the merged model
# Weights and processor files go to the same folder so it can be loaded as one unit.
print(f"Saving merged model to {merged_model_path}...")
os.makedirs(merged_model_path, exist_ok=True)

try:
    merged_model.save_pretrained(merged_model_path)
    processor.save_pretrained(merged_model_path)
    print("Merged model saved successfully!")
except Exception as e:
    logger.error(f"Failed to save merged model: {e}")
    raise

# Create repository and push to Hub
# exist_ok=True makes repo creation a no-op when it already exists.
# NOTE(review): HfApi() is built without an explicit token; it presumably relies on the
# earlier login() — confirm.
print(f"Creating repository: {hub_repo_id}")
try:
    api = HfApi()
    api.create_repo(
        repo_id=hub_repo_id,
        repo_type="model",
        private=hub_private_repo,
        exist_ok=True
    )
except Exception as e:
    logger.error(f"Failed to create repository: {e}")
    raise

# Upload the whole merged-model folder as a single commit.
print(f"Pushing merged model to Hub: {hub_repo_id}...")
try:
    api.upload_folder(
        folder_path=merged_model_path,
        repo_id=hub_repo_id,
        repo_type="model",
        commit_message="Upload merged Gemma3N-AgriVision model"
    )
    print("Model successfully pushed to Hub!")
except Exception as e:
    logger.error(f"Failed to push model to Hub: {e}")
    raise

# 9. Inference

In [None]:
print("\n--- 🤖 Model Inference ---")

# Dependencies come from Section 1. Upgrading `transformers` at this point would not take
# effect anyway: already-imported modules stay loaded until the kernel restarts.
import torch
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoProcessor

# Load model and processor
print(f"Loading model from: {hub_repo_id}")
try:
    # Gemma 3n is a vision-language model; its documented auto class is the
    # image-text-to-text one (Gemma3nForConditionalGeneration), which includes the
    # vision components needed for image prompts.
    model = AutoModelForImageTextToText.from_pretrained(
        hub_repo_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(hub_repo_id)
    print("✅ Model and processor loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model for inference: {e}")
    raise

# Enhanced inference function
def run_inference(image_path=None, question="What's wrong with this crop?", max_new_tokens=512, temperature=0.7):
    """Generate the model's answer to ``question``, optionally about an image.

    When ``image_path`` names an existing file, the image is sent with the question;
    otherwise (no path, or a missing file) a text-only prompt is used. Only the newly
    generated tokens are decoded, so the prompt is never echoed back.

    Args:
        image_path: Optional path to an image file.
        question: The user's question.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``None`` or values <= 0 select greedy decoding.

    Returns:
        The response text, or ``"Error during inference: ..."`` if any step raises
        (the error is also logged).
    """
    try:
        if image_path and os.path.exists(image_path):
            # Load and process image
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            messages = [
                {"role": "user", "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question}
                ]}
            ]
            inputs = processor(
                text=processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
                images=[image],
                return_tensors="pt",
            )
        else:
            # Text-only inference
            messages = [
                {"role": "user", "content": question}
            ]
            inputs = processor(
                text=processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
                return_tensors="pt",
            )

        # The processor returns CPU tensors while device_map="auto" may place the model on
        # GPU: move inputs to the model's device and cast floating-point inputs (pixel
        # values) to the model dtype.
        inputs = inputs.to(model.device, dtype=model.dtype)
        prompt_length = inputs["input_ids"].shape[-1]

        # `temperature` only takes effect when sampling is enabled.
        generation_kwargs = {"max_new_tokens": max_new_tokens}
        if temperature is not None and temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature)
        else:
            generation_kwargs["do_sample"] = False

        # Generate response
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # Decode only the continuation rather than splitting the full text on "model\n",
        # which breaks when the answer itself contains that string.
        response = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        logger.error(f"Inference failed: {e}")
        return f"Error during inference: {str(e)}"

# Test inference
print("\n🧪 Testing inference...")
test_response = run_inference(question="Hii, can you help me identify this crop disease?")
print(f"Model's Reply: {test_response}")

# 10. Training Visualization (Optional)

In [None]:
import matplotlib.pyplot as plt

def plot_training_metrics(trainer_stats=None, log_history=None, output_path="training_metrics.png"):
    """Plot per-step training loss and learning rate, save the figure and show it.

    ``trainer.train()`` returns a ``TrainOutput`` whose ``training_loss`` is a single
    averaged float, so it cannot be drawn as a curve. Per-step values live in
    ``trainer.state.log_history``.

    Args:
        trainer_stats: Optional object exposing ``.state.log_history`` (e.g. the
            ``Trainer``). Used only when ``log_history`` is not given.
        log_history: List of logged dicts (as in ``trainer.state.log_history``). Entries
            with a ``"loss"`` / ``"learning_rate"`` key are plotted against ``"step"``.
        output_path: File the figure is saved to.
    """
    if log_history is None:
        state = getattr(trainer_stats, "state", None)
        log_history = getattr(state, "log_history", None)
    if not log_history:
        print("No per-step log history found; call plot_training_metrics(trainer) "
              "or pass log_history=trainer.state.log_history.")
        return

    def _series(key):
        # (step, value) pairs for entries that logged `key`; falls back to the entry index.
        points = [(entry.get("step", idx), entry[key])
                  for idx, entry in enumerate(log_history) if key in entry]
        return [step for step, _ in points], [value for _, value in points]

    loss_steps, losses = _series("loss")
    lr_steps, learning_rates = _series("learning_rate")

    fig, (ax_loss, ax_lr) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(loss_steps, losses)
    ax_loss.set(title="Training Loss", xlabel="Steps", ylabel="Loss")
    if learning_rates:
        ax_lr.plot(lr_steps, learning_rates)
        ax_lr.set(title="Learning Rate", xlabel="Steps", ylabel="LR")

    fig.tight_layout()
    fig.savefig(output_path)
    plt.show()
    print(f"Training metrics saved as '{output_path}'")

# Uncomment to generate training plots
# plot_training_metrics(trainer)

print("\n🎉 Fine-tuning pipeline completed successfully!")