LLaVA case study

In [1]:
# Load model directly
import re

import pandas as pd
import torch
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoProcessor, AutoModelForImageTextToText, Trainer, TrainingArguments

MODEL_ID = "llava-hf/llava-1.5-7b-hf"
device = "cuda" if torch.cuda.is_available() else "cpu"
# A 7B checkpoint needs ~28 GB in fp32 but ~14 GB in fp16, so use half precision on GPU.
# fp16 is slow/poorly supported on CPU, so keep fp32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=model_dtype).to(device)

  from .autonotebook import tqdm as notebook_tqdm
The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 
Loading checkpoint shards: 100%|██████████| 3/3 [01:25<00:00, 28.63s/it]


load dataset

In [2]:
# Read the train / validation / test splits (read in that order).
train_dataset, valid_dataset, test_dataset = (
    pd.read_csv(csv_path)
    for csv_path in (
        "generated_img_dataset/train.csv",
        "generated_img_dataset/val.csv",
        "generated_img_dataset/test.csv",
    )
)

In [3]:
# Inspect the test split: EM = emoji sequence, EN = English phrase, label = gold 0/1
# label used for scoring, image = path to the rendered emoji image fed to LLaVA.
test_dataset

Unnamed: 0,EM,EN,unicode,label,strategy,image
0,👔📈,big business,U+1F454 U+1F4C8,1,1,./generated_img_dataset/test_google/0.png
1,🏢🤑🤑,big business,U+1F3E2 U+1F911 U+1F911,1,1,./generated_img_dataset/test_google/1.png
2,👨‍💻🤝,big business,U+1F468 U+200D U+1F4BB U+1F91D,1,1,./generated_img_dataset/test_google/2.png
3,🏢🧑‍🤝‍🧑🧑‍🤝‍🧑🧑‍🤝‍🧑,big business,U+1F3E2 U+1F9D1 U+200D U+1F91D U+200D U+1F9D1 ...,1,1,./generated_img_dataset/test_google/3.png
4,👩‍💻🤑,big business,U+1F469 U+200D U+1F4BB U+1F911,1,1,./generated_img_dataset/test_google/4.png
...,...,...,...,...,...,...
488,👍👣,effective entrance,U+1F44D U+1F463,0,6,./generated_img_dataset/test_google/513.png
489,👏🪜,effective entrance,U+1F44F U+1FA9C,0,6,./generated_img_dataset/test_google/514.png
490,😤🗣️💬,effective entrance,U+1F624 U+1F5E3 U+FE0F U+1F4AC,0,6,./generated_img_dataset/test_google/515.png
491,💨🤬,effective entrance,U+1F4A8 U+1F92C,0,6,./generated_img_dataset/test_google/516.png


Zero-shot prediction with a yes/no prompt

In [None]:
def zero_shot_predict(sample, max_new_tokens=50):
    """
    Zero-shot yes/no prediction for one emoji-image / English-phrase pair.

    Asks LLaVA "Does this emoji sequence mean '<EN>'? Answer yes or no." about the
    emoji image and maps the model's reply to a binary label.

    Args:
        sample: a dataset row (pandas Series or dict) with at least the keys 'EN'
            (English phrase) and 'image' (path to the emoji image).
        max_new_tokens: generation budget for the model's reply.

    Returns:
        (prediction, generated_text): prediction is 1 if the reply contains the word
        "yes", else 0; generated_text is the lower-cased reply only (prompt excluded).
        If the image cannot be opened, the error is printed and (0, "") is returned.
        NOTE: such rows are then scored as "no", which biases the metrics.
    """
    # Load the image first so an unreadable file is reported before any model work.
    try:
        raw_image = Image.open(sample["image"]).convert("RGB")
    except OSError as e:  # FileNotFoundError and PIL.UnidentifiedImageError subclass OSError
        print(f"Error loading image {sample['image']}: {e}")
        return 0, ""

    prompt_message = f"Does this emoji sequence mean '{sample['EN']}'? Answer yes or no."
    # {"type": "image"} only marks where the chat template inserts the image token;
    # the actual pixels are passed to the processor below.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_message},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = processor(images=raw_image, text=prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation. Decode only the new tokens: the prompt
    # itself contains "Answer yes or no.", so decoding the whole sequence would always
    # match "yes" and predict 1 for every row.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = processor.tokenizer.decode(
        outputs[0, prompt_len:], skip_special_tokens=True
    ).strip().lower()

    # Whole-word match so substrings like "eyes" or "yesterday" are not counted as "yes".
    prediction = 1 if re.search(r"\byes\b", generated_text) else 0
    return prediction, generated_text

# Score every test row with the prompted classifier, keeping the raw replies next to
# the binary predictions so they can be inspected afterwards.
true_labels = test_dataset["label"].tolist()
predictions = []
generated_texts = []

for row_idx, sample in test_dataset.iterrows():
    label_pred, answer_text = zero_shot_predict(sample)
    predictions.append(label_pred)
    generated_texts.append(answer_text)
    print(f"Row {row_idx}: EN = {sample['EN']} -> Prediction: {label_pred}, Generated: {answer_text}")

evaluation

In [None]:

# Compare predictions with the gold labels. sklearn raises ValueError if the two lists
# differ in length, e.g. when the prediction loop above was interrupted.
accuracy, f1_macro = (
    accuracy_score(true_labels, predictions),
    f1_score(true_labels, predictions, average='macro'),
)

print(f"Zero-shot accuracy: {accuracy:.4f}")
print(f"F1 Macro: {f1_macro:.4f}")

ValueError: Found input variables with inconsistent numbers of samples: [493, 4]

Zero-shot prediction without an instruction prompt (the model sees only the image and the English phrase)

In [None]:
def zero_shot_predict(sample, max_new_tokens=50):
    """
    Zero-shot prediction WITHOUT an instruction prompt.

    The model receives the emoji image plus the raw English phrase (sample['EN']) as the
    user turn. It is not asked whether the emoji means the phrase. This redefines the
    prompted `zero_shot_predict` defined earlier in the notebook, so later cells use
    this variant.

    Args:
        sample: a dataset row (pandas Series or dict) with at least the keys 'EN'
            (English phrase) and 'image' (path to the emoji image).
        max_new_tokens: generation budget for the model's reply.

    Returns:
        (prediction, generated_text): prediction is 1 if the reply contains the word
        "yes", else 0; generated_text is the lower-cased reply only (prompt excluded).
        If the image cannot be opened, the error is printed and (0, "") is returned.

    NOTE(review): nothing asks for a yes/no answer here, so the reply is free-form and
    presumably rarely contains "yes". Inspect `generated_texts` before trusting the
    metrics computed from this variant.
    """
    try:
        raw_image = Image.open(sample["image"]).convert("RGB")
    except OSError as e:  # FileNotFoundError and PIL.UnidentifiedImageError subclass OSError
        print(f"Error loading image {sample['image']}: {e}")
        return 0, ""

    # The phrase alone is the text turn; {"type": "image"} only marks where the chat
    # template places the image token -- the pixels are passed to the processor below.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": sample["EN"]},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = processor(images=raw_image, text=prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; decode only the continuation so words in
    # the phrase itself (e.g. "eyes") cannot be mistaken for the model's answer.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = processor.tokenizer.decode(
        outputs[0, prompt_len:], skip_special_tokens=True
    ).strip().lower()

    # Whole-word match: substrings such as "eyes" or "yesterday" must not count as "yes".
    prediction = 1 if re.search(r"\byes\b", generated_text) else 0
    return prediction, generated_text

# Score every test row with the no-prompt variant. This rebinds `predictions` and
# `generated_texts`, replacing the prompted run's results.
true_labels = test_dataset["label"].tolist()
predictions = []
generated_texts = []

for row_idx, sample in test_dataset.iterrows():
    label_pred, answer_text = zero_shot_predict(sample)
    predictions.append(label_pred)
    generated_texts.append(answer_text)
    print(f"Row {row_idx}: EN = {sample['EN']} -> Prediction: {label_pred}, Generated: {answer_text}")

evaluation

In [None]:
# Score the no-prompt predictions with the same sklearn metrics as the prompted run.
# accuracy_score checks that labels and predictions have equal length and raises a
# ValueError otherwise. A zip()-based ratio would silently truncate to the shorter
# list, and would divide by zero when there are no predictions.
accuracy = accuracy_score(true_labels, predictions)
f1_macro = f1_score(true_labels, predictions, average='macro')

print(f"Zero-shot accuracy: {accuracy:.4f}")
print(f"F1 Macro: {f1_macro:.4f}")

Batch prediction (currently disabled)

In [None]:
# NOTE(review): this batched variant is disabled. Running it raised the ValueError shown
# below ("Unable to create tensor ... padding=True"): the prompts in one batch tokenize
# to different lengths and the processor call does not pad them. Before re-enabling it,
# presumably also:
#   - pass padding=True to processor(...) and set processor.tokenizer.padding_side = "left"
#     (recommended for batched generation with decoder-only models),
#   - build each prompt with processor.apply_chat_template(...) and a {"type": "image"}
#     entry as zero_shot_predict does; the bare strings below contain no image placeholder,
#   - decode only outputs[:, inputs["input_ids"].shape[1]:]. Decoding the full sequence
#     includes the prompt, whose "Answer yes or no." makes every prediction 1.
# def batch_zero_shot_predict(samples):
#     """
#     Given a list of samples (each sample is a row from your CSV with columns 'EN' and 'image'),
#     build prompts and load images in a batch. Then, use the model to generate outputs for all
#     samples at once.
    
#     Returns:
#         predictions: list of binary predictions (1 if entailment, 0 otherwise)
#         results: list of generated text strings for each sample
#     """
#     prompts = []
#     images = []
    
#     for sample in samples:
#         # Create a prompt for each sample
#         prompt = f"Does this emoji sequence mean '{sample['EN']}'? Answer yes or no."
#         prompts.append(prompt)
        
#         # Load the image using PIL and convert to RGB
#         try:
#             image = Image.open(sample['image']).convert("RGB")
#         except Exception as e:
#             print(f"Error loading image {sample['image']}: {e}")
#             # If image loading fails, append a placeholder (or you can choose to skip the sample)
#             # Here we simply append a blank image.
#             image = Image.new("RGB", (224, 224), color=(255, 255, 255))
#         images.append(image)
    
#     # Process all images and prompts in one go to get batched tensors
#     inputs = processor(images=images, text=prompts, return_tensors="pt")
#     inputs = {k: v.to(device) for k, v in inputs.items()}
    
#     with torch.no_grad():
#         outputs = model.generate(**inputs, max_new_tokens=50)
    
#     # Decode each generated output into text and interpret the response
#     results = [processor.tokenizer.decode(out, skip_special_tokens=True).strip().lower() for out in outputs]
#     predictions = [1 if "yes" in result else 0 for result in results]
#     return predictions, results

# # Example usage:
# # Suppose you have a DataFrame 'test_dataset' and you want to batch, e.g., 32 samples at a time.
# batch_size = 32
# all_predictions = []
# all_generated_texts = []
# samples = test_dataset.to_dict("records")  # List of dictionaries for each row

# for i in range(0, len(samples), batch_size):
#     batch_samples = samples[i:i+batch_size]
#     preds, gen_texts = batch_zero_shot_predict(batch_samples)
#     all_predictions.extend(preds)
#     all_generated_texts.extend(gen_texts)

# # Now you can compute metrics using sklearn
# from sklearn.metrics import accuracy_score, f1_score
# y_true = test_dataset["label"].tolist()
# accuracy = accuracy_score(y_true, all_predictions)
# f1_macro = f1_score(y_true, all_predictions, average="macro")

# print("Zero-shot Accuracy:", accuracy)
# print("F1 Macro:", f1_macro)

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).