In [None]:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from dotenv import load_dotenv
import os

In [None]:
# load model
# Hub checkpoint to load. The access token is read from the HF_TOKEN
# environment variable after load_dotenv() pulls in a local .env file (if any).
model_id = "google/translategemma-4b-it"
# Device the weights are placed on (passed as device_map); change to e.g.
# "cuda" to run on a GPU.
device = "cpu"

load_dotenv()
access_token = os.getenv("HF_TOKEN")
if access_token is None:
    # os.getenv returns None when the variable is missing; from_pretrained then
    # receives token=None, so surface that instead of failing later with an auth error.
    print("Warning: HF_TOKEN is not set; model download may fail if authentication is required")

processor = AutoProcessor.from_pretrained(model_id, token=access_token)

# bfloat16 weights use half the memory of float32.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
    token=access_token,
)

print("Model loaded successfully")

In [None]:
# set padding side (for batch generation): prompts are left-padded so that
# every row ends where generation starts, which lets the output slice below
# drop the prompt with a single column offset.
processor.tokenizer.padding_side = "left"
# Only fall back to EOS when the tokenizer has no pad token of its own;
# overwriting an existing pad token would change which id is used for padding.
if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token

SOURCE_LANG_CODE = "en"
TARGET_LANG_CODE = "nl-NL"
MAX_NEW_TOKENS = 128

source_texts = [
    "I love performing magic for people.",
    "Artificial intelligence is a fascinating field of study.",
    "I want to move to Spain to study with my mentor.",
]

# create list of prompts: one single-turn user conversation per source text,
# carrying the language codes the chat template consumes.
batch_messages = [
    [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": SOURCE_LANG_CODE,
                    "target_lang_code": TARGET_LANG_CODE,
                    "text": txt,
                }
            ],
        }
    ]
    for txt in source_texts
]

# apply template to the entire batch (padded to the longest prompt)
inputs = processor.apply_chat_template(
    batch_messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
).to(model.device)

# generate model output for the whole batch at once (greedy decoding); pass the
# pad id used for the inputs so generation pads with the same token.
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

# decode model output: with left padding all prompts share the same length,
# so slicing off the first input_ids.shape[1] columns leaves only new tokens.
decoded_outputs = processor.batch_decode(
    outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)

for original, translated in zip(source_texts, decoded_outputs):
    print(f"Original: {original}\nTranslated: {translated}\n")