In [None]:
from datasets import load_dataset
import torch
from PIL import Image

In [None]:
# Load the 'train' split of the idiom dataset via `datasets.load_dataset`.
# NOTE(review): the repo id below looks like an anonymized placeholder and the
# author's trailing note asks for the dataset source to be swapped — confirm the
# real dataset location before running.
print("++++++Reading the Dataset++++++++++")
dataset = load_dataset("Anonymous/Final_idiom_all",split='train') #<======== Please change the dataset in csv file.

In [None]:
import os

from huggingface_hub import login

# Authenticate against the Hugging Face Hub without embedding a secret in the notebook.
# The token is read from the HF_TOKEN environment variable; when it is not set,
# `token=None` makes `login` fall back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# System prompt placed in the "system" turn of every conversation built by `format_data`.
# NOTE(review): the wording has grammatical slips ("an polyglot", "who are", "an native").
# It is runtime text shown to the model, so it is intentionally left unchanged here;
# edit it deliberately (and retrain) if a corrected prompt is wanted.
system_message='''You are an polyglot, who are having exceptional linguistic and cultural domain knowledge. Also, you are an native speaker of hindi, bengali and thai.'''

In [None]:
def format_data(sample):
    """Turn one dataset row into a three-turn chat conversation.

    The row is read through the keys ``'Actual idiom'``, ``'image'`` and
    ``'Descriptive Meaning(Human Annotation)'``.

    Returns:
        A dict with a single ``"messages"`` key holding, in order:
        a system turn (the module-level ``system_message``), a user turn
        (idiom text followed by the image) and an assistant turn (the
        human-annotated descriptive meaning).
    """
    def text_part(value):
        # Wrap a string in the {"type": "text", ...} content-element shape.
        return {"type": "text", "text": value}

    system_turn = {
        "role": "system",
        "content": [text_part(system_message)],
    }
    user_turn = {
        "role": "user",
        "content": [
            text_part(sample['Actual idiom']),
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [text_part(sample["Descriptive Meaning(Human Annotation)"])],
    }
    return {"messages": [system_turn, user_turn, assistant_turn]}

In [None]:
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every image referenced in a conversation, converted to RGB.

    Args:
        messages: Chat messages; each is a dict whose ``"content"`` is either a
            list of content elements or a single element (which is wrapped in
            a list). A missing ``"content"`` is treated as empty.

    Returns:
        The images found, in conversation order, each passed through
        ``.convert("RGB")``.

    Raises:
        ValueError: If an element declares ``"type": "image"`` but carries no
            image object (missing or ``None`` ``"image"`` value).
    """
    image_inputs = []
    for msg in messages:
        # Normalise content to a list so single-element content is handled too.
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if not isinstance(element, dict):
                continue
            declares_image = element.get("type") == "image"
            if "image" not in element and not declares_image:
                continue

            image = element.get("image")
            if image is None:
                if declares_image:
                    # Previously this fell through to `element.convert(...)` on a
                    # dict and failed with an opaque AttributeError.
                    raise ValueError(
                        "Image content element has no image object "
                        f"(keys present: {sorted(element)})"
                    )
                # A non-image element that merely carries an empty "image" slot.
                continue
            image_inputs.append(image.convert("RGB"))
    return image_inputs

In [None]:
# Print the loaded dataset object as a quick sanity check after loading.
print(dataset)

In [None]:
# print(dataset[3132])
# print(len(dataset))

In [None]:
# Slice the combined dataset into per-language subsets using hard-coded row positions
# (language per range is taken from the variable names):
#   Hindi   -> rows 0..1276
#   Thai    -> rows 1382..3132
#   Bengali -> rows 1277..1381 and 3133..3532
# NOTE(review): these offsets assume a fixed row order and at least 3533 rows in
# `dataset`; re-verify them whenever the dataset is regenerated.
print("++++++Seperating the Dataset on Lingual Basis++++++++++")
dataset_hindi = dataset.select(range(0,1277))
dataset_thai = dataset.select(range(1382,3133))
bengali_indicies = list(range(1277,1382))+list(range(3133,3533))
dataset_bengali= dataset.select(bengali_indicies)

In [None]:
from datasets import Dataset,concatenate_datasets

In [None]:
def split_dataset(dataset1):
    """Split a dataset into 70% train, 20% validation and 10% test.

    Both splits use ``seed=42`` so the partition is reproducible.

    Args:
        dataset1: Object exposing ``train_test_split(test_size=..., seed=...)``
            that returns a mapping with ``'train'`` and ``'test'`` entries.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.

    Note:
        The previous version passed ``test_size=2/3`` in the second step, which
        put 2/3 of the hold-out into the test set and only 1/3 into validation,
        the opposite of the documented 20% / 10% intent. Splits produced by
        the old code therefore differ from the ones produced here.
    """
    # Step 1: hold out 30% of the rows; the remaining 70% is the training set.
    train_testvalid = dataset1.train_test_split(test_size=0.3, seed=42)
    train_dataset = train_testvalid['train']
    temp_dataset = train_testvalid['test']

    # Step 2: split the 30% hold-out into 2/3 (validation) and 1/3 (test).
    # `test_size` is the fraction that lands under the 'test' key, so 1/3 here
    # leaves the other 2/3 under 'train', which is used as validation.
    val_test = temp_dataset.train_test_split(test_size=1/3, seed=42)
    val_dataset = val_test['train']    # 2/3 of 30% = 20%
    test_dataset = val_test['test']    # 1/3 of 30% = 10%

    return train_dataset, val_dataset, test_dataset

In [None]:
# Split each language subset on its own with `split_dataset`, then concatenate the
# per-language pieces so that every final split contains rows from all three subsets.
print("++++++Splitting the Dataset and Merging++++++++++")
train_dataset_hindi, val_dataset_hindi, test_dataset_hindi = split_dataset(dataset_hindi)
train_dataset_thai, val_dataset_thai, test_dataset_thai = split_dataset(dataset_thai)
train_dataset_bengali, val_dataset_bengali, test_dataset_bengali = split_dataset(dataset_bengali)


# Merge the per-language splits (order within each split: Hindi, Thai, Bengali).
train_dataset_final = concatenate_datasets([train_dataset_hindi,train_dataset_thai,train_dataset_bengali])
val_dataset_final = concatenate_datasets([val_dataset_hindi,val_dataset_thai,val_dataset_bengali])
test_dataset_final = concatenate_datasets([test_dataset_hindi,test_dataset_thai,test_dataset_bengali])

# Report the sizes of the merged train / validation / test splits.
print(len(train_dataset_final),len(val_dataset_final),len(test_dataset_final))


# test_dataset_final.save_to_disk("test_dataset_Gemma3")

In [None]:
# Materialise each split as a plain Python list of chat-style records
# (one `format_data` result per row), in the original row order.
print("++++++Converting the Dataset to JSON format++++++++++")
train_dataset = list(map(format_data, train_dataset_final))
eval_dataset = list(map(format_data, val_dataset_final))
test_dataset = list(map(format_data, test_dataset_final))

In [None]:
# Spot-check one formatted training record.
# NOTE(review): index 2000 is hard-coded and raises IndexError if the training
# list has 2000 or fewer entries.
print(train_dataset[2000])

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

print("+++++++++++Loading Model+++++++++++")
# Hugging Face model id of the checkpoint whose weights are fine-tuned
model_id = "google/gemma-3-12b-pt" # NOTE(review): the old comment listed this same id as an "alternative"; other sizes presumably exist — confirm exact ids

# Require CUDA compute capability major version >= 8, used here as the bfloat16 support check
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # attention backend passed to from_pretrained
    torch_dtype=torch.bfloat16, # load/compute dtype; also reused below for the quantization config
    device_map="auto", # automatic device placement of the model weights
)

# 4-bit quantization settings (NF4 with double quantization); compute and storage
# dtypes are tied to the torch_dtype chosen above
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load the model from `model_id` and the processor from the "-it" repo.
# NOTE(review): the processor deliberately comes from a different repo than the
# weights ("-it" vs "-pt"), presumably to obtain the chat template — confirm.
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it")

In [None]:
print("++++++Configuring LoRA and peft++++++++++")
from peft import LoraConfig, get_peft_model

# LoRA adapter configuration handed to the trainer as `peft_config`.
# Rank and alpha are both 16, dropout is 0.05, and no bias terms are trained.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",  # presumably attaches adapters to every linear layer — confirm against PEFT docs
    task_type="CAUSAL_LM",
    # Modules listed here are, per PEFT's `modules_to_save` option, presumably kept
    # fully trainable and stored with the adapter — confirm.
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

In [None]:
from trl import SFTConfig

# Training hyper-parameters for the supervised fine-tuning run.
args = SFTConfig(
    output_dir="Gemma_3_Idiom_VL",     # directory to save and repository id
    num_train_epochs=3,                         # number of training epochs
    per_device_train_batch_size=4,              # batch size per device during training
    gradient_accumulation_steps=8,              # accumulate gradients over 8 batches per optimizer update
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=10,                            # log every 10 steps
    save_strategy="steps",                      # checkpoint on a step interval (see save_steps)
    save_steps=20,                    # save a checkpoint every 20 steps
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    # NOTE(review): with lr_scheduler_type="constant" the warmup ratio below likely has
    # no effect ("constant_with_warmup" is the variant that warms up) — confirm.
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    push_to_hub=True,                           # push model to hub
    report_to="none",                    # do not report metrics to any external tracker
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use non-reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
# Keep all columns so the custom collator still receives the "messages" field.
args.remove_unused_columns = False # important for collator

In [None]:
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    """Encode a list of chat examples into one padded model batch with labels.

    Args:
        examples: Records that each hold a ``"messages"`` conversation (the
            structure consumed by ``process_vision_info`` and the processor's
            chat template).

    Returns:
        The processor output for the batch with an added ``"labels"`` entry:
        a copy of ``input_ids`` in which padding and image-related token ids
        are replaced by ``-100``.
    """
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels start as a copy of the input ids; ids listed below are then masked.
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    # Keep this id as a plain scalar. The previous code wrapped it in a list, and
    # `labels == [id]` is not an elementwise comparison, so the begin-of-image
    # token could silently stay unmasked.
    boi_token_id = tokenizer.convert_tokens_to_ids(
        tokenizer.special_tokens_map["boi_token"]
    )
    ignored_token_ids = (
        tokenizer.pad_token_id,
        boi_token_id,
        262144,  # hard-coded in the original; presumably the image placeholder token id — confirm
    )
    for token_id in ignored_token_ids:
        if token_id is None:
            # No such token configured (e.g. tokenizer without a pad token).
            continue
        labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch


In [None]:
from trl import SFTTrainer

# Build the trainer from the quantized model, the SFT arguments, the LoRA config,
# the processor and the custom image/text collator.
# NOTE(review): only `train_dataset` is passed; the `eval_dataset` list built
# earlier is not used here — confirm that skipping evaluation is intended.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

In [None]:
# Run the fine-tuning loop configured on `trainer`.
print("++++++Starting the training++++++++++")
trainer.train()

In [None]:
# Persist the trained model to the configured output directory.
# NOTE(review): `push_to_hub=True` is set in the training arguments, so this may
# also upload to the Hub — confirm.
print("++++++Saving the Model++++++++++")
trainer.save_model(args.output_dir)


In [None]:
import gc

# Free the memory again: drop the notebook's references, collect any objects
# still held alive by reference cycles, then release cached CUDA memory.
# NOTE: this cell is not re-runnable on its own; `model` and `trainer` must exist.
del model
del trainer
gc.collect()
torch.cuda.empty_cache()
