In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from tqdm import tqdm
from bert_score import score
import pandas as pd
import evaluate
from peft import PeftModel
import os

from transformers import Idefics3ForConditionalGeneration, AutoProcessor

# Turn off the tokenizers library's internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Prefer the GPU whenever CUDA reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# HyperMoE: soft mixture of MLP experts with mean-smoothed expert outputs
class HyperMoE(nn.Module):
    """
    Soft mixture-of-experts acting on the last (feature) dimension.

    Each expert is a two-layer MLP (embed_dim -> hidden_dim -> embed_dim). A linear
    gate produces softmax weights over the experts. Every expert output is pulled
    towards the mean of all expert outputs by ``smoothing``, and the smoothed
    outputs are combined with the gate weights.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_experts: Number of expert MLPs (must be >= 1).
        hidden_dim: Hidden width of each expert MLP.
        smoothing: Interpolation factor towards the expert mean. 0 disables it;
            1 replaces every expert output with the mean. Defaults to 0.1, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``num_experts`` is smaller than 1.
    """

    def __init__(self, embed_dim, num_experts=4, hidden_dim=1024, smoothing=0.1):
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.smoothing = smoothing
        # Experts are created before the gate, as before, so seeded
        # initialisation produces the same parameters as the original layout.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embed_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        """
        Fuse the expert outputs for ``x``.

        Args:
            x: Tensor of shape (..., embed_dim). Any number of leading dims is
                accepted, e.g. (B, embed_dim) or (B, seq_len, embed_dim).

        Returns:
            Tensor of shape (..., embed_dim).
        """
        weights = F.softmax(self.gate(x), dim=-1)  # (..., num_experts)
        # Stack on the second-to-last axis so arbitrary leading dims are preserved.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)  # (..., E, D)
        expert_mean = expert_outputs.mean(dim=-2, keepdim=True)  # (..., 1, D)
        enhanced = expert_outputs + self.smoothing * (expert_mean - expert_outputs)  # (..., E, D)
        return torch.einsum("...e,...ed->...d", weights, enhanced)  # (..., D)

# --- Load base model and processor ---
from transformers.utils import is_flash_attn_2_available

model_id = "HuggingFaceTB/SmolVLM-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
# FlashAttention-2 needs CUDA and the flash-attn package. Fall back to the eager
# implementation otherwise, instead of failing inside from_pretrained.
attn_impl = "flash_attention_2" if is_flash_attn_2_available() else "eager"
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
)
# Load LoRA adapter
model = PeftModel.from_pretrained(model, "Anonymous/SmolVLM_Idiom_VL", adapter_name="default")
model.gradient_checkpointing_enable()


# Determine embedding dimension for hypermoe (hidden size of the text decoder)
embed_dim = model.base_model.config.text_config.hidden_size

# Create multiple HyperMoE modules. They are built with the default dtype,
# while the backbone above is loaded in bfloat16.
num_moe = 3
moe_modules = nn.ModuleList([HyperMoE(embed_dim) for _ in range(num_moe)]).to(DEVICE)

def multi_model_hypermoe_fusion(features: torch.Tensor, modules=None) -> torch.Tensor:
    """
    Pass the features through several HyperMoE modules and average their outputs.

    Args:
        features: Tensor whose last dimension is the embedding size. The hooks
            in this file pass (B, embed_dim) batches.
        modules: Iterable of modules to fuse. Defaults to the module-level
            ``moe_modules``.

    Returns:
        Element-wise mean of the module outputs, cast back to the dtype and
        device of ``features``.

    Raises:
        ValueError: If there are no modules to fuse.
    """
    module_list = list(moe_modules if modules is None else modules)
    if not module_list:
        raise ValueError("multi_model_hypermoe_fusion needs at least one module")

    # The backbone is loaded in bfloat16 (and may sit on another device), while
    # the fusion modules keep default-dtype parameters. Align the input with
    # them so the call also works outside an autocast region, then restore the
    # caller's dtype/device so the result can be written back into the backbone.
    first_param = next((p for m in module_list for p in m.parameters()), None)
    x = features
    if first_param is not None:
        x = features.to(device=first_param.device, dtype=first_param.dtype)

    # Stack on a new leading axis and average across modules (works for any output shape).
    fused = torch.stack([m(x) for m in module_list], dim=0).mean(dim=0)
    return fused.to(device=features.device, dtype=features.dtype)

# Hook to modify text model output (the combined vision-language model output)
def model_hook(module, inputs, outputs):
    """
    Forward hook that replaces sequence position 0 of the hidden states with its
    multi-HyperMoE fusion.

    Two output forms are handled: objects exposing ``last_hidden_state`` (e.g.
    transformers ModelOutput instances), and tuples whose first element is the
    hidden-state tensor. Any other output is returned unchanged. So is a hidden
    state that is not 3-D or has no positions.

    Args:
        module: The hooked module (unused).
        inputs: The module's positional inputs (unused).
        outputs: The module's forward output.

    Returns:
        The output with position 0 of its hidden states replaced. Returns
        ``outputs`` itself when nothing is changed.
    """
    if hasattr(outputs, "last_hidden_state"):
        h = outputs.last_hidden_state
        is_tuple_output = False
    elif isinstance(outputs, tuple) and outputs and isinstance(outputs[0], torch.Tensor):
        h = outputs[0]
        is_tuple_output = True
    else:
        return outputs

    # Expect (batch, seq_len, hidden). Leave anything else untouched instead of
    # raising IndexError on the [:, 0, :] slice.
    if not isinstance(h, torch.Tensor) or h.dim() != 3 or h.size(1) == 0:
        return outputs

    fused = multi_model_hypermoe_fusion(h[:, 0, :]).to(dtype=h.dtype, device=h.device)
    # Build a new tensor instead of writing into h in place, so the tensor the
    # wrapped module produced (and autograd state tied to it) stays intact.
    new_h = torch.cat([fused.unsqueeze(1), h[:, 1:, :]], dim=1)

    if is_tuple_output:
        return (new_h,) + tuple(outputs[1:])
    # On transformers ModelOutput, attribute assignment also updates the
    # mapping entry, so both outputs.last_hidden_state and outputs[0] see new_h.
    outputs.last_hidden_state = new_h
    return outputs


# Hook to modify image features (connector output) before merging into text embeddings
def connector_hook(module, inputs, output):
    """
    Forward hook for the vision-to-text connector output.

    Replaces the first token of each image feature sequence with its
    multi-HyperMoE fusion.

    Args:
        module: The hooked module (unused).
        inputs: The module's positional inputs (unused).
        output: Expected to be a tensor of shape (N, image_seq_len, embed_dim).

    Returns:
        A new tensor with position 0 replaced. ``output`` is returned unchanged
        when it is not a 3-D tensor or has an empty sequence axis.
    """
    if not isinstance(output, torch.Tensor) or output.dim() != 3 or output.size(1) == 0:
        return output

    # Fuse the first token of every image sequence (e.g. a CLS token if present).
    first_token = output[:, 0, :]  # (N, embed_dim)
    fused = multi_model_hypermoe_fusion(first_token).to(dtype=output.dtype, device=output.device)
    # Out-of-place rebuild: leaves the connector's own output tensor unmodified.
    return torch.cat([fused.unsqueeze(1), output[:, 1:, :]], dim=1)

# attach fusion hooks to the submodules whose outputs they are written for
from transformers.models.idefics3.modeling_idefics3 import (
    Idefics3Connector,
    Idefics3Model,
    Idefics3VisionTransformer,
)


def _find_submodule(root, cls):
    """Return the first submodule of ``root`` that is an instance of ``cls``."""
    for sub in root.modules():
        if isinstance(sub, cls):
            return sub
    raise RuntimeError(f"No {cls.__name__} found inside {type(root).__name__}")


# Text fusion goes on the inner Idefics3Model. Its output carries
# `last_hidden_state`, which feeds the LM head. The outer ForConditionalGeneration
# output only has loss/logits, so a hook there never reached the fusion branch.
text_module = _find_submodule(model.base_model.model, Idefics3Model)
text_hook_handle = text_module.register_forward_hook(model_hook)

# Kept for reference by later code. The vision transformer returns a ModelOutput
# in the *vision* hidden size, not a (N, seq, text_hidden) tensor, so hooking it
# made connector_hook a no-op.
vision_module = _find_submodule(model.base_model.model, Idefics3VisionTransformer)

# Image fusion goes on the connector, whose tensor output has the text hidden size
# that the HyperMoE modules were built for.
connector_module = _find_submodule(model.base_model.model, Idefics3Connector)
connector_hook_handle = connector_module.register_forward_hook(connector_hook)

# Load the fine-tuning data from local disk and split out train / validation.
data = load_from_disk("idiom_vaibhav")
train_dataset, val_dataset = data["train"], data["validation"]

# Persona text used as the system turn of every prompt built by collate_fn.
system_message = " ".join(
    [
        "You are a polyglot assistant with exceptional linguistic and cultural knowledge.",
        "You are a native speaker of Hindi, Bengali, and Thai.",
    ]
)

# Collate function for DataLoader
def collate_fn(examples):
    """
    Turn a list of dataset rows into a processor batch with training labels.

    Each row becomes a system / user (image + idiom) / assistant (meaning)
    conversation, rendered with the processor's chat template (answer included).
    Labels are a copy of ``input_ids`` with padding and image tokens set to -100.
    """

    def _render(example):
        # One full conversation, including the reference answer.
        conversation = [
            {"role": "system", "content": [{"type": "text", "text": system_message}]},
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["Actual idiom"]}]},
            {"role": "assistant", "content": [{"type": "text", "text": example["Descriptive Meaning(Human Annotation)"]}]},
        ]
        return processor.apply_chat_template(conversation, add_generation_prompt=False)

    image_lists = [[example["image"]] for example in examples]
    texts = [_render(example) for example in examples]
    batch = processor(text=texts, images=image_lists, padding=True, return_tensors="pt")

    targets = batch.input_ids.clone()
    ignored = (targets == processor.tokenizer.pad_token_id) | (targets == model.config.image_token_id)
    targets[ignored] = -100
    batch["labels"] = targets
    return batch

# DataLoaders: batches of 4 built by collate_fn; only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)

# Put the VLM and the HyperMoE stack on the target device.
model.to(DEVICE)
moe_modules.to(DEVICE)

# Full fine-tuning: unfreeze every parameter of the PEFT-wrapped model
# (base weights and LoRA adapter weights alike).
model.train()
model.requires_grad_(True)
moe_modules.train()

# A single optimizer covers both the VLM and the fusion modules.
optimizer = torch.optim.AdamW(
    [*model.parameters(), *moe_modules.parameters()],
    lr=2e-5,
)

# --- Training loop ---
# Retained only in case later code refers to `autocast`; the loop below uses the
# device-generic torch.autocast (the torch.cuda.amp variant is deprecated and
# is disabled when DEVICE is "cpu").
from torch.cuda.amp import autocast

NUM_EPOCHS = 1
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()
        # Move tensors to the device; pass through anything that is not a tensor.
        batch = {k: v.to(DEVICE) if torch.is_tensor(v) else v for k, v in batch.items()}
        with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask", None),
                pixel_values=batch.get("pixel_values", None),
                pixel_attention_mask=batch.get("pixel_attention_mask", None),
                labels=batch["labels"]
            )
            loss = outputs.loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Guard against an empty loader instead of dividing by zero.
    avg_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1} completed. Avg Loss: {avg_loss:.4f}")
