In [1]:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from dotenv import load_dotenv
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load model
model_id = "google/translategemma-4b-it"
load_dotenv()
access_token = os.getenv("HF_TOKEN")

processor = AutoProcessor.from_pretrained(model_id, token=access_token)

model = AutoModelForImageTextToText.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    device_map="cpu", 
    token=access_token
)

"""
model = AutoModelForImageTextToText.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
    token=access_token
)
with torch.no_grad():
    model.to(device)
"""

print("Model loaded successfully")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00,  9.30it/s]


Model loaded successfully


In [3]:
import pandas as pd
from datasets import load_dataset

In [4]:
ds = load_dataset("google/wmt24pp", "en-nl_NL")
df = pd.DataFrame(ds)

sources = [row['source'] for row in df['train']]
targets = [row['target'] for row in df['train']]
original_targets = [row['original_target'] for row in df['train']]

In [None]:
# set padding (for batch generation)
processor.tokenizer.padding_side = "left" 
processor.tokenizer.pad_token = processor.tokenizer.eos_token # Ensure pad token exists

# create list of prompts
batch_messages = [
    [{"role": "user", "content": [{"type": "text", "source_lang_code": "en", "target_lang_code": "nl-NL", "text": txt}]}]
    for txt in sources[:50]
]

# apply template to the entire batch
inputs = processor.apply_chat_template(
    batch_messages, 
    tokenize=True, 
    add_generation_prompt=True, 
    return_dict=True, 
    return_tensors="pt",
    padding=True
).to(model.device)

# generate model output for the whole batch at once
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# decode model output 
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

# save model output to a list
all_translations = [translated for translated in decoded_outputs]

# verify that translation worked
print(f"Captured {len(all_translations)} translations.")

Captured 50 translations.


In [13]:
from comet import download_model, load_from_checkpoint

comet_model_path = download_model("Unbabel/wmt22-comet-da")
comet_model = load_from_checkpoint(comet_model_path)
comet_data = [
    {"src": s, "mt": t, "ref": r} 
    for s, t, r in zip(sources[:50], all_translations, targets[:50])
]

model_output = comet_model.predict(comet_data, batch_size=16, gpus=0)
print(model_output)

Fetching 5 files: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5/5 [00:00<?, ?it/s]
Lightning automatically upgraded your loaded checkpoint from v1.8.3.post1 to v2.6.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\bramv\.cache\huggingface\hub\models--Unbabel--wmt22-comet-da\snapshots\2760a223ac957f30acfb18c8aa649b01cf1d75f2\checkpoints\model.ckpt`
Encoder model frozen.
ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Predicting DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 4/4 [00:21<00:00,  5.27s/it]

Prediction({'scores': [0.9530108571052551, 0.8261133432388306, 0.8326753973960876, 0.8731812238693237, 0.7812501192092896, 0.9162182211875916, 0.7998655438423157, 0.6808553338050842, 0.8368621468544006, 0.8672711253166199, 0.9156062006950378, 0.9216163158416748, 0.8569576144218445, 0.9203564524650574, 0.8951510190963745, 0.8310083746910095, 0.852101743221283, 0.8589473366737366, 0.8727201819419861, 0.908475935459137, 0.862957775592804, 0.8283209800720215, 0.7891461253166199, 0.7820630669593811, 0.8237118124961853, 0.7510408163070679, 0.8159080147743225, 0.8214782476425171, 0.8129777312278748, 0.8479556441307068, 0.8799853920936584, 0.8224462866783142, 0.8798291087150574, 0.8511192202568054, 0.8972647786140442, 0.8711307644844055, 0.861009418964386, 0.8762597441673279, 0.877729594707489, 0.7296474575996399, 0.8510574102401733, 0.7705533504486084, 0.7788724899291992, 0.8370561003684998, 0.8295997977256775, 0.8388290405273438, 0.7343507409095764, 0.8457963466644287, 0.7907679080963135, 0.




In [15]:
print(len(model_output["scores"]))

50
