# In [None]:
# ---- Full Fine-tuning with HyperMoE integration ----

import os

from datasets import load_dataset, concatenate_datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer
import wandb

# Login to Hugging Face Hub.
# The access token is taken from the HF_TOKEN environment variable so that no
# credential is committed to source control. When the variable is unset,
# token=None is passed and `login` applies its own handling for a missing
# token (an interactive prompt, per the huggingface_hub documentation).
login(token=os.environ.get("HF_TOKEN"))

# System prompt placed in the first turn of every conversation produced by
# format_data. The wording is reproduced exactly as originally written because
# it becomes part of the text the model is trained on.
system_message = (
    "You are an polyglot, who are having exceptional linguistic and cultural "
    "domain knowledge. Also, you are a native speaker of Hindi, Bengali and Thai."
)

# Turn one dataset row into the chat-style "messages" structure
def format_data(sample):
    """Build a system / user / assistant conversation from a dataset row.

    Args:
        sample: Mapping providing the keys ``'Actual idiom'``, ``"image"`` and
            ``"Descriptive Meaning(Human Annotation)"``.

    Returns:
        A dict with a single ``"messages"`` key. The list holds three turns:
        the module-level ``system_message``, a user turn containing the idiom
        text followed by the image, and an assistant turn containing the
        human-annotated meaning.
    """

    def text_part(text):
        # Every textual content element shares this two-key layout.
        return {"type": "text", "text": text}

    system_turn = {"role": "system", "content": [text_part(system_message)]}
    user_turn = {
        "role": "user",
        "content": [
            text_part(sample['Actual idiom']),
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [text_part(sample["Descriptive Meaning(Human Annotation)"])],
    }
    return {"messages": [system_turn, user_turn, assistant_turn]}

# Helper to extract image objects from messages
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every image referenced in a conversation, converted to RGB.

    Args:
        messages: Conversation turns. Each turn is a dict whose ``"content"``
            is either a list of content elements or a single element (which
            is wrapped in a list). A missing ``"content"`` is treated as empty.

    Returns:
        The images found, in conversation order, each passed through
        ``.convert("RGB")``.

    Raises:
        ValueError: If an element declares ``"type": "image"`` but carries no
            image payload. Previously this surfaced as an opaque
            ``AttributeError`` from calling ``.convert`` on the element dict.
    """
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if not isinstance(element, dict):
                continue
            image = element.get("image")
            if image is None:
                if element.get("type") == "image":
                    raise ValueError(
                        "Content element has type 'image' but no image payload"
                    )
                # Text elements may carry an empty "image" slot (e.g. after a
                # schema-normalizing round trip); there is nothing to extract.
                continue
            image_inputs.append(image.convert("RGB"))
    return image_inputs

print("++++++Reading the Dataset++++++++++")
dataset = load_dataset("Anonymous/Final_idiom_all", split='train')

print("++++++Separating the Dataset on Lingual Basis++++++++++")
# The per-language subsets are carved out by fixed row-index ranges:
#   Hindi   -> [0, 1277)
#   Bengali -> [1277, 1382) plus [3133, 3533)
#   Thai    -> [1382, 3133)
# NOTE(review): these boundaries are hard-coded; they presumably mirror the
# row order of the hosted dataset — confirm if the dataset is ever re-uploaded.
bengali_indices = [*range(1277, 1382), *range(3133, 3533)]
dataset_hindi = dataset.select(range(1277))
dataset_bengali = dataset.select(bengali_indices)
dataset_thai = dataset.select(range(1382, 3133))

def split_dataset(dataset1, holdout_size=0.3, test_share=2/3, seed=42):
    """Split a dataset into train, validation and test parts.

    The split is done in two stages: ``holdout_size`` of the rows is held out
    from training, then ``test_share`` of that held-out part becomes the test
    set and the remainder becomes the validation set. With the defaults this
    yields roughly 70% train, 10% validation and 20% test.

    Args:
        dataset1: Object exposing ``train_test_split(test_size=..., seed=...)``
            that returns a mapping with ``'train'`` and ``'test'`` entries.
        holdout_size: Fraction of rows removed from the training set.
        test_share: Fraction of the held-out rows assigned to the test set.
        seed: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    train_testvalid = dataset1.train_test_split(test_size=holdout_size, seed=seed)
    train_dataset = train_testvalid['train']
    temp_dataset = train_testvalid['test']
    val_test = temp_dataset.train_test_split(test_size=test_share, seed=seed)
    val_dataset = val_test['train']    # held-out rows NOT sent to test (10% by default)
    test_dataset = val_test['test']    # test_share of the held-out rows (20% by default)
    return train_dataset, val_dataset, test_dataset

print("++++++Splitting the Dataset and Merging++++++++++")
# Split each language independently (Hindi, Thai, Bengali — in that order) so
# every language contributes to train, validation and test.
_per_language_splits = [
    split_dataset(language_dataset)
    for language_dataset in (dataset_hindi, dataset_thai, dataset_bengali)
]
(
    (train_dataset_hindi, val_dataset_hindi, test_dataset_hindi),
    (train_dataset_thai, val_dataset_thai, test_dataset_thai),
    (train_dataset_bengali, val_dataset_bengali, test_dataset_bengali),
) = _per_language_splits

# zip(*...) regroups the per-language triples by role: all train parts, then
# all validation parts, then all test parts.
train_dataset_final, val_dataset_final, test_dataset_final = (
    concatenate_datasets(list(same_role_parts))
    for same_role_parts in zip(*_per_language_splits)
)

print(f"Train size: {len(train_dataset_final)}, Val size: {len(val_dataset_final)}, Test size: {len(test_dataset_final)}")

print("++++++Converting the Dataset to JSON format++++++++++")
# Materialize each split as a plain list of {"messages": [...]} conversations.
train_dataset = list(map(format_data, train_dataset_final))
eval_dataset = list(map(format_data, val_dataset_final))
test_dataset = list(map(format_data, test_dataset_final))

# --------- Model & Processor Loading ---------
print("+++++++++++Loading Model+++++++++++")
model_id = "google/gemma-3-12b-pt"

# Refuse to continue when the first element of the CUDA capability tuple
# (the major version) is below 8; the script relies on bfloat16 throughout.
_cuda_capability_major = torch.cuda.get_device_capability()[0]
if _cuda_capability_major < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Loading options shared by the from_pretrained call below.
model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
# 4-bit NF4 quantization with double quantization; compute and storage dtypes
# are both bfloat16.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
# NOTE(review): the processor comes from the "-it" checkpoint while the weights
# come from the "-pt" one — presumably to obtain the chat template; confirm.
processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it")

# --------- HyperMoE Logic ---------
class HyperMoE(nn.Module):
    """Softmax-gated mixture of MLP experts with a pull toward the expert mean.

    Every expert is ``Linear(embed_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, embed_dim)``. All experts are evaluated on the input,
    each expert output is moved a fraction ``mix`` of the way toward the mean
    of all expert outputs, and the results are combined with softmax gate
    weights.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_experts: Number of expert MLPs.
        hidden_dim: Width of each expert's hidden layer.
        mix: Interpolation factor toward the expert mean (0 disables it).
    """

    def __init__(self, embed_dim, num_experts=4, hidden_dim=1024, mix=0.1):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embed_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(embed_dim, num_experts)
        self.mix = mix

    def forward(self, x):
        """Fuse expert outputs for ``x``.

        Args:
            x: Tensor of shape ``(..., embed_dim)``; any number of leading
                dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        weights = F.softmax(self.gate(x), dim=-1)                      # (..., E)
        # Stack on the second-to-last axis so leading dims stay untouched.
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )                                                              # (..., E, D)
        expert_mean = expert_outputs.mean(dim=-2, keepdim=True)        # (..., 1, D)
        # Equivalent to (1 - mix) * expert_outputs + mix * expert_mean.
        enhanced = expert_outputs + self.mix * (expert_mean - expert_outputs)
        fused = torch.einsum("...e,...ed->...d", weights, enhanced)    # (..., D)
        return fused

# Look up the hidden size on the model's config at runtime
def get_embed_dim(model):
    """Return the hidden size recorded in ``model.base_model.config``.

    The nested ``text_config.hidden_size`` is preferred; the top-level
    ``hidden_size`` is used when the nested attribute is absent.

    Args:
        model: Object exposing ``base_model.config``.

    Returns:
        The value of the first ``hidden_size`` attribute found.

    Raises:
        ValueError: If neither location defines ``hidden_size``.
    """
    absent = object()  # sentinel: distinguishes "missing" from any stored value
    config = model.base_model.config
    candidates = (getattr(config, "text_config", None), config)
    for holder in candidates:
        hidden_size = getattr(holder, "hidden_size", absent)
        if hidden_size is not absent:
            return hidden_size
    raise ValueError("Cannot find embedding dimension in the model config")

# Build the bank of HyperMoE blocks used by the forward hooks below. Each
# block is sized to the backbone's hidden width and the bank is moved to CUDA.
embed_dim = get_embed_dim(model)
num_moe = 3
_moe_bank = [HyperMoE(embed_dim) for _ in range(num_moe)]
moe_modules = nn.ModuleList(_moe_bank).to("cuda")

def multi_model_hypermoe_fusion(features: torch.Tensor) -> torch.Tensor:
    """Run ``features`` through every block in ``moe_modules`` and average.

    The input is first moved to the dtype and device of the MoE parameters,
    and the result is moved back to the dtype and device of ``features``.
    This avoids dtype/device mismatches when the caller's hidden states
    differ from the MoE bank (the bank is created with default dtype on
    "cuda", while the backbone is loaded in bfloat16 with an automatic
    device map).

    Args:
        features: Tensor of shape (B, embed_dim).

    Returns:
        Tensor of shape (B, embed_dim) with the dtype and device of
        ``features``: the mean over the outputs of all MoE blocks.
    """
    reference = next(moe_modules.parameters(), None)
    moe_input = features
    if reference is not None:
        moe_input = features.to(device=reference.device, dtype=reference.dtype)

    fused_outputs = [moe(moe_input) for moe in moe_modules]  # list of (B, embed_dim)
    stacked = torch.stack(fused_outputs, dim=1)              # (B, num_moe, embed_dim)
    final_fusion = stacked.mean(dim=1)                       # average to (B, embed_dim)
    return final_fusion.to(device=features.device, dtype=features.dtype)

def model_hook(module, inputs, outputs):
    """Forward hook that replaces the first sequence position with its MoE fusion.

    The hidden state is taken from ``outputs.last_hidden_state`` when that
    attribute exists, otherwise from ``outputs[0]`` when ``outputs`` is a
    non-empty tuple starting with a tensor. The vector at sequence index 0 is
    passed through ``multi_model_hypermoe_fusion`` and written into a clone of
    the hidden state; the original tensor is not modified.

    Args:
        module: The hooked module (unused).
        inputs: Positional inputs of the forward call (unused).
        outputs: Value returned by the module's forward.

    Returns:
        ``outputs`` unchanged when no 3-D hidden state can be located.
        Otherwise a new object of the same class with ``last_hidden_state``
        replaced, or — for tuple outputs — a tuple whose first element is the
        patched hidden state followed by the remaining original elements.
    """
    has_named_state = hasattr(outputs, "last_hidden_state")
    if has_named_state:
        h = outputs.last_hidden_state
    elif isinstance(outputs, tuple) and len(outputs) > 0 and isinstance(outputs[0], torch.Tensor):
        h = outputs[0]
    else:
        return outputs

    # Indexing [:, 0, :] below needs a (batch, seq, dim) tensor.
    if not isinstance(h, torch.Tensor) or h.dim() != 3:
        return outputs

    cls = h[:, 0, :]
    fused = multi_model_hypermoe_fusion(cls)
    # Create a new tensor instead of modifying in-place
    h_new = h.clone()
    h_new[:, 0, :] = fused

    if has_named_state:
        # Rebuild an object of the same class with only the hidden state swapped.
        return outputs.__class__(**{**outputs.__dict__, "last_hidden_state": h_new})
    # Plain tuples have no __dict__; rebuild them positionally.
    return (h_new,) + tuple(outputs[1:])

def connector_hook(module, inputs, output):
    """Forward hook that fuses the first position of a 3-D tensor output.

    Args:
        module: The hooked module (unused).
        inputs: Positional inputs of the forward call (unused).
        output: Value returned by the module's forward.

    Returns:
        ``output`` itself when it is not a 3-D tensor. Otherwise a clone of
        ``output`` in which the slice ``[:, 0, :]`` has been replaced by the
        result of ``multi_model_hypermoe_fusion`` applied to that slice.
    """
    # Guard clause: anything that is not a rank-3 tensor passes through untouched.
    if not isinstance(output, torch.Tensor) or output.dim() != 3:
        return output

    patched = output.clone()
    patched[:, 0, :] = multi_model_hypermoe_fusion(output[:, 0, :])
    return patched


# Attach the hidden-state hook to the wrapped backbone.
model.base_model.model.register_forward_hook(model_hook)

# Locate the vision encoder by class name so the connector hook can be attached.
# NOTE(review): "Idefics3VisionTransformer" is an Idefics3 class name; whether
# any Gemma3 submodule carries this name cannot be verified from this file.
# If none does, the connector hook is never registered and only the warning
# below is printed — adapt the name to Gemma3's vision encoder class if needed.
vision_module = next(
    (
        m
        for m in model.base_model.model.modules()
        if m.__class__.__name__ == "Idefics3VisionTransformer"
    ),
    None,
)

if vision_module is None:
    print("Warning: Vision module hook not registered; adjust class name for Gemma3 vision encoder.")
else:
    vision_module.register_forward_hook(connector_hook)

# --------- LoRA Config ---------
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter shape: rank 16 with alpha 16, 5% dropout on the adapter path.
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Adapters are requested for every linear layer.
    target_modules="all-linear",
    # These modules are listed to be saved alongside the adapters.
    modules_to_save=["lm_head", "embed_tokens"],
)

# --------- SFT Configuration ---------
args = SFTConfig(
    # Output, checkpointing and reporting.
    output_dir="Gemma_3_Idiom_hypermoe_VL",
    save_strategy="steps",
    save_steps=20,
    logging_steps=10,
    push_to_hub=True,
    report_to="wandb",
    # Schedule: one epoch, one sample per device, gradients accumulated over
    # 4 steps before each optimizer update.
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # Optimization.
    # NOTE(review): a plain "constant" schedule may ignore warmup_ratio —
    # confirm whether "constant_with_warmup" was intended.
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="constant",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Precision and memory.
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Dataset handling is delegated to the custom collator, so the trainer's
    # own dataset preparation is skipped.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)
# Set after construction: the collator reads the raw "messages" entry of each
# example, so unused columns must not be dropped.
args.remove_unused_columns = False

print("++++++connecting to wandb++++++++++")
# Start the Weights & Biases run; the training arguments object is passed as
# the run configuration, and project and run share the same label.
wandb.init(project="Gemma_3_Idiom_VL", name="Gemma_3_Idiom_VL", config=args)

# --------- Data Collator for SFTTrainer ---------
def collate_fn(examples):
    """Turn a list of conversations into a padded model batch with labels.

    For each example the images are extracted with ``process_vision_info`` and
    the conversation is rendered to text with the processor's chat template.
    The processor then tokenizes and pads the whole batch. Labels are a copy
    of ``input_ids`` in which padding and image-related token ids are replaced
    by -100.

    Args:
        examples: List of dicts, each with a ``"messages"`` conversation.

    Returns:
        The processor's batch with an added ``"labels"`` entry.
    """
    texts, images = [], []
    for example in examples:
        messages = example["messages"]
        image_inputs = process_vision_info(messages)
        text = processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
        texts.append(text.strip())
        images.append(image_inputs)

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    # 262144 is masked unconditionally, as in the original code.
    # NOTE(review): presumably the image placeholder id of this vocabulary — confirm.
    ignored_ids = [tokenizer.pad_token_id, 262144]

    # Only look up the begin-of-image token when the tokenizer defines one.
    # The previous guard compared the resulting *id* with "", which is always
    # true, so a missing token was looked up as the empty string.
    boi_token = tokenizer.special_tokens_map.get("boi_token")
    if boi_token:
        ignored_ids.append(tokenizer.convert_tokens_to_ids(boi_token))

    for token_id in ignored_ids:
        if token_id is not None:
            labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

# --------- Trainer initialization ---------
# NOTE(review): eval_dataset is built earlier in this script but is not handed
# to the trainer here — confirm that skipping evaluation is intentional.
trainer = SFTTrainer(
    model=model,
    processing_class=processor,
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

# Put the MoE bank on CUDA, switch it to training mode and mark every one of
# its parameters as requiring gradients.
# NOTE(review): moe_modules is a standalone ModuleList that this script never
# attaches to `model`; whether the trainer's optimizer ever updates these
# parameters should be verified.
moe_modules.to("cuda")
moe_modules.train()
moe_modules.requires_grad_(True)

print("++++++Starting the training++++++++++")
trainer.train()

print("++++++Saving the Model++++++++++")
# The HyperMoE bank is created outside `model` and is only reached through
# forward hooks, so it is not part of what trainer.save_model() writes.
# Persist its weights explicitly; otherwise the exact MoE weights that were
# active during training cannot be restored later. It is written first so the
# file already exists in the output directory when save_model runs.
os.makedirs(args.output_dir, exist_ok=True)
torch.save(moe_modules.state_dict(), os.path.join(args.output_dir, "hypermoe_modules.pt"))
trainer.save_model(args.output_dir)