LLaVA case study

In [1]:
# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText, Trainer, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlavaForConditionalGeneration
import torch
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
from PIL import Image
from torch.amp import autocast
import os
# Restrict this process to GPUs 4-7. The variable only takes effect if it is
# set before CUDA is initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"

# Single source of truth for the checkpoint used by both model and processor.
model_name = "llava-hf/llava-1.5-7b-hf"

# Load the weights in half precision; device_map="auto" leaves weight
# placement across the visible devices to the loader.
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Load the processor from the same checkpoint name so the two cannot drift apart.
processor = AutoProcessor.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.89s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


load dataset

In [2]:
# Read the train / validation / test CSV splits into three DataFrames
train_dataset, valid_dataset, test_dataset = (
    pd.read_csv(csv_path)
    for csv_path in (
        "generated_img_dataset/train.csv",
        "generated_img_dataset/val.csv",
        "generated_img_dataset/test.csv",
    )
)

In [3]:
# Preview the test split (pandas abbreviates long frames in the rendered output).
test_dataset

Unnamed: 0,EM,EN,unicode,label,strategy,image
0,👔📈,big business,U+1F454 U+1F4C8,1,1,./generated_img_dataset/test_google/0.png
1,🏢🤑🤑,big business,U+1F3E2 U+1F911 U+1F911,1,1,./generated_img_dataset/test_google/1.png
2,👨‍💻🤝,big business,U+1F468 U+200D U+1F4BB U+1F91D,1,1,./generated_img_dataset/test_google/2.png
3,🏢🧑‍🤝‍🧑🧑‍🤝‍🧑🧑‍🤝‍🧑,big business,U+1F3E2 U+1F9D1 U+200D U+1F91D U+200D U+1F9D1 ...,1,1,./generated_img_dataset/test_google/3.png
4,👩‍💻🤑,big business,U+1F469 U+200D U+1F4BB U+1F911,1,1,./generated_img_dataset/test_google/4.png
...,...,...,...,...,...,...
513,👍👣,effective entrance,U+1F44D U+1F463,0,6,./generated_img_dataset/test_google/513.png
514,👏🪜,effective entrance,U+1F44F U+1FA9C,0,6,./generated_img_dataset/test_google/514.png
515,😤🗣️💬,effective entrance,U+1F624 U+1F5E3 U+FE0F U+1F4AC,0,6,./generated_img_dataset/test_google/515.png
516,💨🤬,effective entrance,U+1F4A8 U+1F92C,0,6,./generated_img_dataset/test_google/516.png


In [4]:
import torch
from PIL import Image
from tqdm.auto import tqdm

# --- Configuration ---
# Number of DataFrame rows handed to batch_zero_shot_predict per call.
BATCH_SIZE = 8 # <<< START SMALL (e.g., 2 or 4) and increase if possible

# Helper function to yield batches from the DataFrame
def generate_batches(df, batch_size):
    """Yield consecutive row slices of ``df``, each holding at most ``batch_size`` rows."""
    n_rows = len(df)
    for start in range(0, n_rows, batch_size):
        stop = start + batch_size
        # Positional slicing clamps ``stop`` to the frame length, so the last
        # batch is simply shorter when the rows do not divide evenly.
        yield df.iloc[start:stop]

def batch_zero_shot_predict(batch_samples):
    """
    Run zero-shot yes/no classification on one batch of samples.

    For every row of ``batch_samples`` a single-turn conversation is built that
    asks whether the emoji image means the phrase in the ``EN`` column. The
    conversations are rendered with the processor's chat template and sent
    through ``model.generate`` together with the images loaded from the
    ``image`` column. Relies on the module-level ``model`` and ``processor``.

    Args:
        batch_samples: DataFrame slice with at least the columns ``EN``
            (candidate phrase) and ``image`` (path to the image file).

    Returns:
        A tuple ``(predictions, generated_texts)`` of two lists, each with one
        entry per row of ``batch_samples`` in row order. A prediction is 1 when
        the lower-cased text after the last "assistant:" marker contains "yes",
        otherwise 0. The generated text is the full decoded, stripped and
        lower-cased output. Rows that could not be processed get prediction 0
        and a placeholder text: "<IMAGE_LOAD_ERROR>" for every row when no
        image in the batch could be loaded, "<GENERATION_ERROR>" for every row
        when generation raised, and "<SKIPPED_OR_ERROR>" for an individual row
        whose image failed to load.
    """
    num_samples = len(batch_samples)

    # 1. Load images and build one conversation per sample that loaded successfully
    loaded_indices = []   # original DataFrame index of every usable sample
    loaded_images = []
    conversations = []
    for index, sample in batch_samples.iterrows():
        try:
            # The context manager closes the file handle; convert() returns an
            # in-memory image that stays usable after the file is closed.
            with Image.open(sample['image']) as image_file:
                raw_image = image_file.convert("RGB")
        except Exception as e:
            # Best effort: report the failure and keep going with the other samples.
            print(f"Warning: Error loading image {sample['image']} for index {index}: {e}. Skipping this sample.")
            continue

        prompt_message = f"Does this emoji sequence mean '{sample['EN']}'? Answer yes or no."
        conversations.append([
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_message},
                    {"type": "image"},
                ],
            },
        ])
        loaded_images.append(raw_image)
        loaded_indices.append(index)

    if not loaded_images:
        print("Warning: No valid images found in this batch. Skipping.")
        return [0] * num_samples, ["<IMAGE_LOAD_ERROR>"] * num_samples

    # 2. Render every conversation to prompt text (the template is applied per
    #    conversation; add_generation_prompt=True opens the assistant turn)
    prompt_texts = [
        processor.apply_chat_template(conversation, add_generation_prompt=True)
        for conversation in conversations
    ]

    # 3. Tokenize prompts and preprocess images as one padded batch.
    #    Left padding is what the LLaVA documentation advises for batched
    #    generation, so that new tokens directly follow each prompt.
    #    No truncation is requested: without a max_length it is a no-op that
    #    only triggers a warning, and cutting a prompt could drop the image
    #    placeholder tokens.
    processor.tokenizer.padding_side = "left"
    inputs = processor(
        images=loaded_images,
        text=prompt_texts,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # 4. Generate; on failure the whole batch falls back to placeholder results
    try:
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=50)
    except Exception as e:
        print(f"Error during model generation for a batch (possible OOM): {e}")
        return [0] * num_samples, ["<GENERATION_ERROR>"] * num_samples

    # 5. Decode the generated token ids of the whole batch back to text
    decoded_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # 6. Interpret each answer and key the result by the original DataFrame index
    results_by_index = {}
    for index, decoded_text in zip(loaded_indices, decoded_texts):
        cleaned_text = decoded_text.strip().lower()
        # The decoded text still contains the prompt; only the part after the
        # last "assistant:" marker is the model's answer.
        answer_part = cleaned_text.split("assistant:")[-1]
        prediction = 1 if "yes" in answer_part else 0
        results_by_index[index] = (prediction, cleaned_text)

    # Rebuild the outputs in the original row order, with defaults for skipped rows
    skipped_result = (0, "<SKIPPED_OR_ERROR>")
    final_predictions = []
    final_gen_texts = []
    for index in batch_samples.index:
        prediction, text = results_by_index.get(index, skipped_result)
        final_predictions.append(prediction)
        final_gen_texts.append(text)

    return final_predictions, final_gen_texts

# --- Run zero-shot inference over the whole test split, batch by batch ---
true_labels = test_dataset["label"].tolist()
predictions, generated_texts = [], []

# Ceiling division: number of batches needed to cover every row (for the progress bar)
num_batches = -(-len(test_dataset) // BATCH_SIZE)
batch_generator = generate_batches(test_dataset, BATCH_SIZE)

print(f"Starting prediction with batch size {BATCH_SIZE} on device ...")
for batch_df in tqdm(batch_generator, total=num_batches, desc="Processing Batches"):
    batch_preds, batch_gen_texts = batch_zero_shot_predict(batch_df)
    predictions += batch_preds
    generated_texts += batch_gen_texts

    # Log the first sample of every batch as a quick sanity check (text truncated to 100 chars)
    first_idx = batch_df.index[0]
    print(f"Batch starting row {first_idx}: EN = {batch_df.iloc[0]['EN']} -> Prediction: {batch_preds[0]}, Generated: {batch_gen_texts[0][:100]}...") # Print truncated generated text

print("Finished prediction.")


Starting prediction with batch size 8 on device ...


Processing Batches:   0%|          | 0/65 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Processing Batches:   2%|▏         | 1/65 [00:17<18:38, 17.48s/it]

Batch starting row 0: EN = big business -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big business'? answer yes or no. assistant: yes....


Processing Batches:   3%|▎         | 2/65 [00:20<09:35,  9.14s/it]

Batch starting row 8: EN = big expenditure -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big expenditure'? answer yes or no. assistant: yes....


Processing Batches:   5%|▍         | 3/65 [00:24<06:41,  6.47s/it]

Batch starting row 16: EN = big voice -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big voice'? answer yes or no. assistant: yes....


Processing Batches:   6%|▌         | 4/65 [00:27<05:11,  5.11s/it]

Batch starting row 24: EN = big group -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big group'? answer yes or no. assistant: yes....


Processing Batches:   8%|▊         | 5/65 [00:30<04:25,  4.42s/it]

Batch starting row 32: EN = big man -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big man'? answer yes or no. assistant: no....


Processing Batches:   9%|▉         | 6/65 [00:33<03:52,  3.94s/it]

Batch starting row 40: EN = big city -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big city'? answer yes or no. assistant: yes....


Processing Batches:  11%|█         | 7/65 [00:36<03:34,  3.69s/it]

Batch starting row 48: EN = big tipper -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big tipper'? answer yes or no. assistant: no....


Processing Batches:  12%|█▏        | 8/65 [00:39<03:23,  3.57s/it]

Batch starting row 56: EN = big day -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big day'? answer yes or no. assistant: yes....


Processing Batches:  14%|█▍        | 9/65 [00:42<03:11,  3.42s/it]

Batch starting row 64: EN = hot doll -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot doll'? answer yes or no. assistant: no....


Processing Batches:  15%|█▌        | 10/65 [00:46<03:06,  3.38s/it]

Batch starting row 72: EN = hot water -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'hot water'? answer yes or no. assistant: yes....


Processing Batches:  17%|█▋        | 11/65 [00:49<03:01,  3.36s/it]

Batch starting row 80: EN = hot stove -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'hot stove'? answer yes or no. assistant: yes....


Processing Batches:  18%|█▊        | 12/65 [00:52<02:57,  3.34s/it]

Batch starting row 88: EN = hot topic -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot topic'? answer yes or no. assistant: no....


Processing Batches:  20%|██        | 13/65 [00:56<02:53,  3.33s/it]

Batch starting row 96: EN = hot merchandise -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot merchandise'? answer yes or no. assistant: no....


Processing Batches:  22%|██▏       | 14/65 [00:57<02:18,  2.72s/it]

Batch starting row 104: EN = hot temper -> Prediction: 0, Generated: <SKIPPED_OR_ERROR>...


Processing Batches:  23%|██▎       | 15/65 [01:00<02:16,  2.72s/it]

Batch starting row 112: EN = hot argument -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot argument'? answer yes or no. assistant: no....


Processing Batches:  25%|██▍       | 16/65 [01:03<02:22,  2.91s/it]

Batch starting row 120: EN = hot forehead -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot forehead'? answer yes or no. assistant: no....


Processing Batches:  26%|██▌       | 17/65 [01:06<02:28,  3.09s/it]

Batch starting row 128: EN = full game -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full game'? answer yes or no. assistant: no....


Processing Batches:  28%|██▊       | 18/65 [01:10<02:26,  3.12s/it]

Batch starting row 136: EN = full auditorium -> Prediction: 0, Generated: <SKIPPED_OR_ERROR>...


Processing Batches:  29%|██▉       | 19/65 [01:13<02:25,  3.17s/it]

Batch starting row 144: EN = full attention -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full attention'? answer yes or no. assistant: no....


Processing Batches:  31%|███       | 20/65 [01:16<02:23,  3.19s/it]

Batch starting row 152: EN = full glass -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'full glass'? answer yes or no. assistant: yes....


Processing Batches:  32%|███▏      | 21/65 [01:19<02:19,  3.17s/it]

Batch starting row 160: EN = full life -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full life'? answer yes or no. assistant: no....


Processing Batches:  34%|███▍      | 22/65 [01:23<02:17,  3.20s/it]

Batch starting row 168: EN = little boy -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'little boy'? answer yes or no. assistant: yes....


Processing Batches:  35%|███▌      | 23/65 [01:26<02:14,  3.21s/it]

Batch starting row 176: EN = little man -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'little man'? answer yes or no. assistant: yes....


Processing Batches:  37%|███▋      | 24/65 [01:29<02:11,  3.22s/it]

Batch starting row 184: EN = little house -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'little house'? answer yes or no. assistant: yes....


Processing Batches:  38%|███▊      | 25/65 [01:32<02:09,  3.23s/it]

Batch starting row 192: EN = thin oil -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin oil'? answer yes or no. assistant: no....


Processing Batches:  40%|████      | 26/65 [01:36<02:05,  3.23s/it]

Batch starting row 200: EN = thin soup -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'thin soup'? answer yes or no. assistant: yes....


Processing Batches:  42%|████▏     | 27/65 [01:39<02:02,  3.22s/it]

Batch starting row 208: EN = thin air -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin air'? answer yes or no. assistant: no....


Processing Batches:  43%|████▎     | 28/65 [01:42<02:01,  3.28s/it]

Batch starting row 216: EN = thin line -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin line'? answer yes or no. assistant: no....


Processing Batches:  45%|████▍     | 29/65 [01:45<01:51,  3.09s/it]

Batch starting row 224: EN = ineffectual ruler -> Prediction: 0, Generated: <SKIPPED_OR_ERROR>...


Processing Batches:  46%|████▌     | 30/65 [01:48<01:52,  3.20s/it]

Batch starting row 232: EN = ineffectual therapy -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'ineffectual therapy'? answer yes or no. assistant: no....


Processing Batches:  48%|████▊     | 31/65 [01:51<01:48,  3.19s/it]

Batch starting row 240: EN = effective step -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'effective step'? answer yes or no. assistant: no....


Processing Batches:  49%|████▉     | 32/65 [01:55<01:44,  3.18s/it]

Batch starting row 248: EN = effective reprimand -> Prediction: 0, Generated: <SKIPPED_OR_ERROR>...


Processing Batches:  51%|█████     | 33/65 [01:58<01:43,  3.25s/it]

Batch starting row 256: EN = effective entrance -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'effective entrance'? answer yes or no. assistant: no....


Processing Batches:  52%|█████▏    | 34/65 [02:01<01:42,  3.31s/it]

Batch starting row 264: EN = big business -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big business'? answer yes or no. assistant: no....


Processing Batches:  54%|█████▍    | 35/65 [02:05<01:40,  3.36s/it]

Batch starting row 272: EN = big expenditure -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big expenditure'? answer yes or no. assistant: no....


Processing Batches:  55%|█████▌    | 36/65 [02:08<01:35,  3.31s/it]

Batch starting row 280: EN = big voice -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big voice'? answer yes or no. assistant: no....


Processing Batches:  57%|█████▋    | 37/65 [02:11<01:29,  3.20s/it]

Batch starting row 288: EN = big group -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big group'? answer yes or no. assistant: no....


Processing Batches:  58%|█████▊    | 38/65 [02:14<01:26,  3.21s/it]

Batch starting row 296: EN = big man -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big man'? answer yes or no. assistant: no....


Processing Batches:  60%|██████    | 39/65 [02:18<01:25,  3.28s/it]

Batch starting row 304: EN = big city -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'big city'? answer yes or no. assistant: yes....


Processing Batches:  62%|██████▏   | 40/65 [02:21<01:22,  3.32s/it]

Batch starting row 312: EN = big tipper -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big tipper'? answer yes or no. assistant: no....


Processing Batches:  63%|██████▎   | 41/65 [02:24<01:17,  3.21s/it]

Batch starting row 320: EN = big day -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'big day'? answer yes or no. assistant: no....


Processing Batches:  65%|██████▍   | 42/65 [02:27<01:11,  3.12s/it]

Batch starting row 328: EN = hot doll -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot doll'? answer yes or no. assistant: no....


Processing Batches:  66%|██████▌   | 43/65 [02:30<01:06,  3.03s/it]

Batch starting row 336: EN = hot water -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot water'? answer yes or no. assistant: no....


Processing Batches:  68%|██████▊   | 44/65 [02:33<01:02,  2.96s/it]

Batch starting row 344: EN = hot stove -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot stove'? answer yes or no. assistant: no....


Processing Batches:  69%|██████▉   | 45/65 [02:36<01:02,  3.13s/it]

Batch starting row 352: EN = hot topic -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot topic'? answer yes or no. assistant: no....


Processing Batches:  71%|███████   | 46/65 [02:39<00:59,  3.15s/it]

Batch starting row 360: EN = hot merchandise -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot merchandise'? answer yes or no. assistant: no....


Processing Batches:  72%|███████▏  | 47/65 [02:43<00:57,  3.19s/it]

Batch starting row 368: EN = hot temper -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot temper'? answer yes or no. assistant: no....


Processing Batches:  74%|███████▍  | 48/65 [02:46<00:55,  3.29s/it]

Batch starting row 376: EN = hot argument -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot argument'? answer yes or no. assistant: no....


Processing Batches:  75%|███████▌  | 49/65 [02:50<00:53,  3.36s/it]

Batch starting row 384: EN = hot forehead -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'hot forehead'? answer yes or no. assistant: no....


Processing Batches:  77%|███████▋  | 50/65 [02:53<00:51,  3.42s/it]

Batch starting row 392: EN = full game -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full game'? answer yes or no. assistant: no....


Processing Batches:  78%|███████▊  | 51/65 [02:56<00:46,  3.30s/it]

Batch starting row 400: EN = full auditorium -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full auditorium'? answer yes or no. assistant: no....


Processing Batches:  80%|████████  | 52/65 [02:59<00:42,  3.27s/it]

Batch starting row 408: EN = full glass -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full glass'? answer yes or no. assistant: no....


Processing Batches:  82%|████████▏ | 53/65 [03:03<00:39,  3.32s/it]

Batch starting row 416: EN = full life -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'full life'? answer yes or no. assistant: no....


Processing Batches:  83%|████████▎ | 54/65 [03:06<00:36,  3.29s/it]

Batch starting row 424: EN = little boy -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'little boy'? answer yes or no. assistant: no....


Processing Batches:  85%|████████▍ | 55/65 [03:09<00:32,  3.27s/it]

Batch starting row 432: EN = little man -> Prediction: 1, Generated: user:  
does this emoji sequence mean 'little man'? answer yes or no. assistant: yes....


Processing Batches:  86%|████████▌ | 56/65 [03:13<00:29,  3.25s/it]

Batch starting row 440: EN = little house -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'little house'? answer yes or no. assistant: no....


Processing Batches:  88%|████████▊ | 57/65 [03:16<00:25,  3.25s/it]

Batch starting row 448: EN = thin oil -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin oil'? answer yes or no. assistant: no....


Processing Batches:  89%|████████▉ | 58/65 [03:19<00:22,  3.22s/it]

Batch starting row 456: EN = thin soup -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin soup'? answer yes or no. assistant: no....


Processing Batches:  91%|█████████ | 59/65 [03:22<00:19,  3.22s/it]

Batch starting row 464: EN = thin air -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin air'? answer yes or no. assistant: no....


Processing Batches:  92%|█████████▏| 60/65 [03:26<00:16,  3.30s/it]

Batch starting row 472: EN = thin line -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'thin line'? answer yes or no. assistant: no....


Processing Batches:  94%|█████████▍| 61/65 [03:29<00:13,  3.28s/it]

Batch starting row 480: EN = ineffectual ruler -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'ineffectual ruler'? answer yes or no. assistant: no....


Processing Batches:  95%|█████████▌| 62/65 [03:32<00:09,  3.15s/it]

Batch starting row 488: EN = ineffectual therapy -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'ineffectual therapy'? answer yes or no. assistant: no....


Processing Batches:  97%|█████████▋| 63/65 [03:35<00:06,  3.24s/it]

Batch starting row 496: EN = effective step -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'effective step'? answer yes or no. assistant: no....


Processing Batches:  98%|█████████▊| 64/65 [03:39<00:03,  3.31s/it]

Batch starting row 504: EN = effective reprimand -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'effective reprimand'? answer yes or no. assistant: no....


Processing Batches: 100%|██████████| 65/65 [03:41<00:00,  3.41s/it]

Batch starting row 512: EN = effective entrance -> Prediction: 0, Generated: user:  
does this emoji sequence mean 'effective entrance'? answer yes or no. assistant: no....
Finished prediction.





save predictions

In [5]:
# Inspect the raw generated texts collected by the prediction loop.
generated_texts

["user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: yes.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: yes.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: yes.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  \ndoes this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  \ndoes this emoji sequence mean 'big expenditure'? answer yes or no. assistant: yes.",
 "user:  \ndoes this emoji sequence mean 'big expenditure'? answer yes or no. assistant: yes.",
 "user:  \ndoes this emoji sequence mean 'big expenditure'? answer ye

In [6]:
# Remove every newline so each generated text sits on a single line
cleaned_generated_texts = [''.join(text.split('\n')) for text in generated_texts]
cleaned_generated_texts

["user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: yes.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: yes.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: yes.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: no.",
 "user:  does this emoji sequence mean 'big expenditure'? answer yes or no. assistant: yes.",
 "user:  does this emoji sequence mean 'big expenditure'? answer yes or no. assistant: yes.",
 "user:  does this emoji sequence mean 'big expenditure'? answer yes or no. assistant: ye

In [7]:
# Attach the predictions and the cleaned texts as new columns (this mutates
# test_dataset in place), then write the frame to results.csv without the index.
test_dataset['predicted_label'] = predictions
test_dataset['generated_text'] = cleaned_generated_texts
test_dataset.to_csv("results.csv", index=False)

In [8]:
import pandas as pd

# Reload the saved results from disk and show the first generated text.
# NOTE(review): this rebinds `cleaned_generated_texts` from a list of strings
# to a DataFrame; the cells below index it as a DataFrame.
cleaned_generated_texts = pd.read_csv("results.csv")
cleaned_generated_texts["generated_text"][0]

"user:  does this emoji sequence mean 'big business'? answer yes or no. assistant: yes."

In [9]:
# Re-derive the yes/no label from every saved generated text
predictions2 = []
for text in cleaned_generated_texts["generated_text"]:
    parts = text.split("assistant:")
    # The answer is whatever follows the last "assistant:" marker; when the
    # marker is missing, fall back to the whole text.
    if len(parts) > 1:
        assistant_response = parts[-1].strip()
    else:
        assistant_response = text.strip()
    prediction = int("yes" in assistant_response.lower())

    print(parts, prediction)
    predictions2.append(prediction)

["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' yes.'] 1
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' yes.'] 1
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' no.'] 0
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' no.'] 0
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' no.'] 0
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' yes.'] 1
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' no.'] 0
["user:  does this emoji sequence mean 'big business'? answer yes or no. ", ' no.'] 0
["user:  does this emoji sequence mean 'big expenditure'? answer yes or no. ", ' yes.'] 1
["user:  does this emoji sequence mean 'big expenditure'? answer yes or no. ", ' yes.'] 1
["user:  does this emoji sequence mean 'big expenditure'? answer yes or no. ", ' yes.'] 1
["user:  does this emoji sequence mean 

In [10]:
from sklearn.metrics import accuracy_score, f1_score
# Score the re-parsed predictions against the labels stored in results.csv:
# plain accuracy and macro-averaged F1.
true_labels = cleaned_generated_texts["label"]
print(f"Accuracy: {accuracy_score(true_labels, predictions2)}")
print(f"F1 Macro: {f1_score(true_labels, predictions2, average='macro')}")

Accuracy: 0.6583011583011583
F1 Macro: 0.6265685903153676


In [11]:
# Re-read the test split from disk, replacing the in-memory frame that had
# result columns added to it earlier in the notebook.
test_dataset = pd.read_csv("generated_img_dataset/test.csv")


In [12]:
# Attach the re-parsed predictions and the generated texts loaded from
# results.csv, then write the frame to results2.csv without the index.
test_dataset['predicted_label'] = predictions2
test_dataset['generated_text'] = cleaned_generated_texts["generated_text"]
test_dataset.to_csv("results2.csv", index=False)