# In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, AutoTokenizer
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
import wandb

import os

from huggingface_hub import login

# Authenticate before the first Hub access so gated/private repositories can be
# read. The access token is taken from the HF_TOKEN environment variable rather
# than being committed in source; when it is unset, login() receives None and
# falls back to its own prompt.
login(token=os.environ.get("HF_TOKEN"))

print("++++++Reading the Dataset++++++++++")
dataset = load_dataset("Anonymous/Final_idiom_all", split='train')

# System prompt placed in the "system" turn of every formatted conversation.
system_message = '''You are an polyglot, who are having exceptional linguistic and cultural domain knowledge. Also, you are an native speaker of hindi, bengali and thai.'''

def format_data(sample):
    """Turn one dataset row into a three-turn chat conversation.

    The system turn carries the module-level ``system_message``; the user turn
    carries the row's ``'Actual idiom'`` text followed by its ``"image"``; the
    assistant turn carries ``"Descriptive Meaning(Human Annotation)"``.

    Returns:
        A dict with a single ``"messages"`` key holding the list of turns.
    """
    def _text(value):
        # Every textual content entry has the same two-key shape.
        return {"type": "text", "text": value}

    system_turn = {"role": "system", "content": [_text(system_message)]}
    user_turn = {
        "role": "user",
        "content": [
            _text(sample['Actual idiom']),
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [_text(sample["Descriptive Meaning(Human Annotation)"])],
    }
    return {"messages": [system_turn, user_turn, assistant_turn]}

def process_vision_info(messages):
    """Collect the images referenced by a chat-style message list, as RGB.

    Args:
        messages: Iterable of message dicts. Each message's ``"content"`` is a
            list of content entries (a non-list value is treated as a single
            entry). Dict entries that carry an ``"image"`` value contribute one
            image each.

    Returns:
        A list with ``image.convert("RGB")`` for every image found, in the
        order the entries appear.
    """
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if not isinstance(element, dict):
                continue
            image = element.get("image")
            # An entry tagged as an image but without a payload (missing key or
            # None) has nothing to convert; skip it instead of calling
            # ``.convert`` on a dict or on None.
            if image is None:
                continue
            image_inputs.append(image.convert("RGB"))
    return image_inputs

# Carve the full dataset into per-language subsets by fixed row ranges.
# NOTE(review): the ranges are hard-coded; presumably they mirror the row order
# of the upstream dataset — verify if the dataset is ever re-uploaded.
print("++++++Seperating the Dataset on Lingual Basis++++++++++")
# Bengali rows live in two separate stretches, so list both explicitly.
bengali_indices = [*range(1277, 1382), *range(3133, 3533)]
dataset_hindi = dataset.select(range(1277))
dataset_bengali = dataset.select(bengali_indices)
dataset_thai = dataset.select(range(1382, 3133))

def split_dataset(dataset1):
    """Split a dataset into train, validation and test partitions.

    30% of the rows are held out first; that held-out part is then split again
    with ``test_size=2/3``, so the result is roughly 70% / 10% / 20%. Both
    splits use ``seed=42``.

    Returns:
        A ``(train, validation, test)`` tuple.
    """
    holdout_split = dataset1.train_test_split(test_size=0.3, seed=42)
    val_test_split = holdout_split['test'].train_test_split(test_size=2/3, seed=42)
    return holdout_split['train'], val_test_split['train'], val_test_split['test']

print("++++++Splitting the Dataset and Merging++++++++++")
# Split each language on its own so every partition keeps all three languages,
# then stitch the per-language pieces back together partition by partition.
(
    (train_dataset_hindi, val_dataset_hindi, test_dataset_hindi),
    (train_dataset_thai, val_dataset_thai, test_dataset_thai),
    (train_dataset_bengali, val_dataset_bengali, test_dataset_bengali),
) = map(split_dataset, (dataset_hindi, dataset_thai, dataset_bengali))

train_dataset_final = concatenate_datasets(
    [train_dataset_hindi, train_dataset_thai, train_dataset_bengali]
)
val_dataset_final = concatenate_datasets(
    [val_dataset_hindi, val_dataset_thai, val_dataset_bengali]
)
test_dataset_final = concatenate_datasets(
    [test_dataset_hindi, test_dataset_thai, test_dataset_bengali]
)

# Report the size of each merged partition.
print(*map(len, (train_dataset_final, val_dataset_final, test_dataset_final)))

print("++++++Converting the Dataset to JSON format++++++++++")
# Each partition becomes a plain Python list of chat-formatted conversations.
train_dataset = list(map(format_data, train_dataset_final))
eval_dataset = list(map(format_data, val_dataset_final))
test_dataset = list(map(format_data, test_dataset_final))
# Show one formatted training example as a sanity check.
print(train_dataset[2000])

print("+++++++++++Loading Model+++++++++++")
model_id = "llava-hf/llava-1.5-7b-hf"

# The whole pipeline runs in bfloat16; refuse to continue on GPUs whose major
# compute capability is below 8.
if not torch.cuda.get_device_capability()[0] >= 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}

# 4-bit NF4 quantization with double quantization; compute and storage dtypes
# follow the dtype chosen for the model above.
model_kwargs.update(
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
        bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
    )
)

# Separate tokenizer instance carrying named image special tokens; the collate
# function looks its "boi_token" up when masking labels.
vision_tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    extra_special_tokens={
        "image_token": "<image>",
        "boi_token": "<image_start>",
        "eoi_token": "<image_end>",
    },
)
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

print("++++++Configuring LoRA and peft++++++++++")
# Rank-16 LoRA adapters on every linear layer; the output head and the token
# embeddings are listed in modules_to_save alongside the adapters.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    modules_to_save=["lm_head", "embed_tokens"],
)

args = SFTConfig(
    # Output / reporting
    output_dir="Hypermoe_Llava_Idiom_VL",
    push_to_hub=True,
    report_to="wandb",
    logging_steps=5,
    save_strategy="epoch",
    # Schedule and batch shape
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # Optimisation
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="constant",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    bf16=True,
    # Memory
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # The custom collate function does all preprocessing, so the trainer's own
    # dataset preparation is switched off.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)
# Keep every column so the collate function still receives "messages".
args.remove_unused_columns = False

print("++++++connecting to wandb++++++++++")
# Start the Weights & Biases run and record the training arguments with it.
wandb.init(
    config=args,
    name="Llava_Idiom_VL",
    project="Llava_Idiom_VL",
)

# HyperMoE implementation
class HyperMoE(nn.Module):
    """Softly-gated mixture of MLP experts with a pull toward the expert mean.

    Every expert is a two-layer ReLU MLP mapping ``embed_dim -> hidden_dim ->
    embed_dim``. A linear gate produces softmax weights over the experts. Before
    the weighted sum, each expert output is moved toward the mean of all expert
    outputs by the factor ``blend``.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_experts: Number of expert MLPs.
        hidden_dim: Width of each expert's hidden layer.
        blend: Fraction by which each expert output is moved toward the expert
            mean (0 disables the pull; the default 0.1 matches the previously
            hard-coded value).
    """

    def __init__(self, embed_dim, num_experts=4, hidden_dim=1024, blend=0.1):
        super().__init__()
        self.blend = blend
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embed_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        """Fuse the expert outputs for ``x``.

        Args:
            x: Tensor of shape ``(..., embed_dim)``; any number of leading
                dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        weights = F.softmax(self.gate(x), dim=-1)  # (..., num_experts)
        # Stack on the second-to-last axis so leading dims are left untouched.
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )  # (..., num_experts, embed_dim)
        expert_mean = expert_outputs.mean(dim=-2, keepdim=True)  # (..., 1, embed_dim)
        enhanced = expert_outputs + self.blend * (expert_mean - expert_outputs)
        # Gate-weighted sum over the expert axis.
        return torch.einsum("...e,...ed->...d", weights, enhanced)  # (..., embed_dim)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Take the hidden size from the top-level config when it is there, otherwise
# from the nested text config.
if hasattr(model.config, "hidden_size"):
    embed_dim = model.config.hidden_size
else:
    embed_dim = model.config.text_config.hidden_size

# Several independent HyperMoE blocks whose outputs are averaged by
# multi_model_hypermoe_fusion.
num_moe = 3
moe_modules = nn.ModuleList(
    [HyperMoE(embed_dim) for _ in range(num_moe)]
).to(device)

def multi_model_hypermoe_fusion(features: torch.Tensor) -> torch.Tensor:
    """Run ``features`` through every module in ``moe_modules`` and average.

    The per-module outputs are stacked along a new axis 1 and the mean over
    that axis is returned.
    """
    per_moe_outputs = torch.stack([moe(features) for moe in moe_modules], dim=1)
    return per_moe_outputs.mean(dim=1)

def model_hook(module, inputs, outputs):
    """Forward hook: replace the first-position hidden state with its MoE fusion.

    The hidden states are taken from ``outputs.last_hidden_state`` when that
    attribute exists, otherwise from ``outputs[0]`` when ``outputs`` is a tuple
    whose first item is a tensor. Any other output is returned untouched.

    Args:
        module: The module the hook is attached to (unused).
        inputs: The module's positional inputs (unused).
        outputs: The module's forward output.

    Returns:
        The output with position 0 along axis 1 of the hidden states replaced
        by ``multi_model_hypermoe_fusion`` of that position. Attribute-style
        outputs are updated and returned; tuple outputs are rebuilt as a tuple.
    """
    has_attr = hasattr(outputs, "last_hidden_state")
    if has_attr:
        h = outputs.last_hidden_state
    elif isinstance(outputs, tuple) and outputs and isinstance(outputs[0], torch.Tensor):
        h = outputs[0]
    else:
        return outputs
    if not isinstance(h, torch.Tensor):
        # e.g. last_hidden_state is None: nothing to fuse.
        return outputs

    # Cast back to the hidden-state dtype so the concatenation below does not
    # promote the whole tensor.
    fused = multi_model_hypermoe_fusion(h[:, 0, :]).to(h.dtype)
    # Build a new tensor instead of writing into ``h``: the slice just fed to
    # the MoE shares storage with ``h``, and overwriting it in place can
    # invalidate tensors autograd saved for the backward pass.
    new_h = torch.cat([fused.unsqueeze(1), h[:, 1:, :]], dim=1)

    if has_attr:
        outputs.last_hidden_state = new_h
        return outputs
    # Tuples have no ``__dict__``; rebuild them positionally.
    return (new_h,) + tuple(outputs[1:])

# Run model_hook on every forward pass of the base model so that its output
# hidden states go through the HyperMoE fusion.
# NOTE(review): re-running this statement registers the hook a second time —
# confirm it executes only once per loaded model.
model.base_model.register_forward_hook(model_hook)

def connector_hook(module, inputs, output):
    """Forward hook: replace the first-position image feature with its MoE fusion.

    Only 3-D tensor outputs are modified; anything else is returned unchanged.

    Args:
        module: The module the hook is attached to (unused).
        inputs: The module's positional inputs (unused).
        output: The module's forward output.

    Returns:
        A tensor of the same shape and dtype as ``output`` whose position 0
        along axis 1 is ``multi_model_hypermoe_fusion`` of the original
        position 0, or ``output`` itself when it is not a 3-D tensor.
    """
    if not (isinstance(output, torch.Tensor) and output.dim() == 3):
        return output
    # NOTE(review): the shared MoE modules are sized from the language model's
    # hidden size; this assumes the hooked module emits that same width —
    # confirm for the module the hook is attached to.
    fused_image_embed = multi_model_hypermoe_fusion(output[:, 0, :]).to(output.dtype)
    # Concatenate rather than assign in place: the slice passed to the MoE
    # shares storage with ``output``, and overwriting it can invalidate tensors
    # autograd saved for the backward pass.
    return torch.cat([fused_image_embed.unsqueeze(1), output[:, 1:, :]], dim=1)

from transformers.models.idefics3.modeling_idefics3 import Idefics3VisionTransformer

# Find the first submodule whose class is named "Idefics3VisionTransformer";
# connector_hook is attached to it when one exists.
vision_module = next(
    (
        m
        for m in model.base_model.modules()
        if m.__class__.__name__ == "Idefics3VisionTransformer"
    ),
    None,
)

if vision_module is not None:
    vision_module.register_forward_hook(connector_hook)
else:
    # Make the skip visible: without this message the image-side fusion is
    # silently absent whenever the loaded model has no such submodule.
    print("++++++No Idefics3VisionTransformer found; connector_hook not registered++++++++++")

def collate_fn(examples):
    """Turn a list of chat-formatted examples into a model-ready batch.

    Each example's ``"messages"`` are rendered with the processor's chat
    template and its images are collected with ``process_vision_info``. The
    processor then tokenizes and pads the batch. ``labels`` is a copy of
    ``input_ids`` in which padding and image-related token ids are set to
    ``-100`` so they do not contribute to the loss.

    Args:
        examples: List of dicts, each with a ``"messages"`` list.

    Returns:
        The processor's batch with an added ``"labels"`` tensor.
    """
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    ignored_token_ids = [
        tokenizer.pad_token_id,
        vision_tokenizer.convert_tokens_to_ids(
            vision_tokenizer.special_tokens_map["boi_token"]
        ),
        # The placeholder the processor actually inserts for images; without
        # this the model would be trained to predict image placeholders.
        tokenizer.convert_tokens_to_ids(getattr(processor, "image_token", "<image>")),
        # Hard-coded id kept from the original code.
        # NOTE(review): presumably another model family's image token id —
        # confirm whether it is still needed for this checkpoint.
        262144,
    ]
    for token_id in ignored_token_ids:
        if token_id is not None:
            labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

# Attach moe_modules as a submodule of the model so its parameters show up in
# the model's parameter list.
model.moe_modules = moe_modules

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    processing_class=processor,
    # do not pass optimizer manually
)

# Wrapping the model with PEFT leaves only the adapter (and modules_to_save)
# parameters trainable, which freezes the attached MoE. Re-enable its gradients
# before the trainer builds the optimizer inside train().
for moe_param in moe_modules.parameters():
    moe_param.requires_grad_(True)

# The quantized model was already placed by device_map="auto" when it was
# loaded, so it is not moved again here; only the MoE needs explicit placement.
moe_modules.to(device)
model.train()
moe_modules.train()

print("++++++Starting the training++++++++++")
trainer.train()
print("++++++Saving the Model++++++++++")
# NOTE(review): presumably only the adapter weights are written for a PEFT
# model — confirm whether the MoE weights need to be saved separately.
trainer.save_model(args.output_dir)