In [None]:
import os
import joblib
import json
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForPreTraining
import torch

In [None]:
# Hugging Face model id of the instruction-tuned Llama 3.2 Vision checkpoint.
model_name_or_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = AutoProcessor.from_pretrained(model_name_or_path)
# Load weights in bfloat16 when a CUDA GPU is present: float32 weights for an
# 11B-parameter model take roughly 44 GB (11e9 params x 4 bytes), which does not
# fit on most single GPUs. Stay in float32 on CPU, where bfloat16 matmuls may be
# slow or unsupported depending on the hardware.
model = AutoModelForPreTraining.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Assignment form: nn.Module.to() moves parameters in place and returns the
# module. As a bare last expression it would make Jupyter print the whole
# (very large) model repr as cell output.
model = model.to(device)

In [None]:
# Input captions file. Defaults to the original author's local path; set the
# FLICKR30K_CAPTIONS environment variable to point at your own copy instead of
# editing this cell.
file = os.environ.get(
    "FLICKR30K_CAPTIONS",
    "/Users/stanislav/Invisible-Relevance-Bias/flickr/Flickr30k/captions.txt",
)
# Output file: one JSON record per processed input line (relative to the CWD).
new_file = "./flickr_merge/flickr30k_test_llama_caps.txt"

In [None]:
# Make sure the directory that will hold `new_file` exists.
output_dir = os.path.dirname(new_file)
# exist_ok=True avoids the check-then-create race. The guard skips the call when
# `new_file` is a bare filename: dirname() then returns "" and
# os.makedirs("") would raise FileNotFoundError.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Consolidate the captions on each input line into one summary with Llama and
# write the results to `new_file` (one JSON object per line) plus a joblib
# checkpoint next to it.

# --- Configuration for this cell ---------------------------------------------
# Upper bound on the number of input lines read. Each well-formed line costs one
# local model.generate call (this is local inference, not an external API).
MAX_REQUESTS = 10
# Tokens to generate *after* the prompt. `max_length` would also count the
# prompt tokens, so a long caption list could leave no room for any output.
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.7
SEED = 0
# Set True to continue from an existing checkpoint instead of starting over.
# Resuming assumes the same input file and parsing rule as the earlier run.
RESUME = False
# Checkpoint path: `new_file` without its extension
# (./flickr_merge/flickr30k_test_llama_caps for the default paths).
checkpoint_path = os.path.splitext(new_file)[0]

# Prompt to consolidate captions
prompt_template = (
    'Consolidate the five descriptions, avoid redundancy while including the scene described in each sentence, '
    'and make a concise summary:\n'
)


def _parse_line(line):
    """Split one input line into ``(image_name, label, captions)``.

    Returns None when the line has fewer than three comma-separated fields.
    """
    parts = line.strip().split(',')
    if len(parts) < 3:
        return None
    # NOTE(review): splitting on *every* comma means a single caption that
    # itself contains commas becomes several numbered items in the prompt.
    # Confirm the file really stores its captions as separate comma fields.
    return parts[0], parts[1], parts[2:]


def _generate_caption(text):
    """Generate a consolidated caption for ``text``; return only the new text."""
    # The Instruct checkpoint is prompted through its chat template; the input
    # here is text-only (no image).
    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
    chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    # Pass `text=` by keyword: the processor's first positional parameter is
    # `images`. add_special_tokens=False because the chat-formatted prompt
    # already carries the BOS token (as in the model card's example).
    inputs = processor(text=chat_prompt, add_special_tokens=False, return_tensors="pt").to(device)
    # `temperature` only takes effect when sampling is enabled.
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
    )
    # generate() returns prompt + continuation; drop the prompt tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


# --- Run ---------------------------------------------------------------------
torch.manual_seed(SEED)  # sampling is stochastic; seed for reproducible captions

# Records produced so far; with RESUME, reload them from the checkpoint.
# joblib.load unpickles the file: only point it at checkpoints written here.
if RESUME and os.path.exists(checkpoint_path):
    data_final = joblib.load(checkpoint_path)
else:
    data_final = []
already_done = len(data_final)

with open(file, 'r') as f:
    lines = f.readlines()
n_lines = min(MAX_REQUESTS, len(lines))

with open(new_file, 'w') as f_write:
    # 'w' truncates the output file, so re-emit any checkpointed records first
    # to keep the JSONL file in sync with the checkpoint.
    for record in data_final:
        f_write.write(json.dumps(record) + "\n")

    valid_seen = 0  # well-formed lines encountered so far in this pass
    for idx, line in tqdm(enumerate(lines[:n_lines]), total=n_lines):
        parsed = _parse_line(line)
        if parsed is None:
            print(f"Skipping line {idx} due to incorrect format")
            continue

        # Each checkpointed record came from one well-formed line, in order,
        # so the first `already_done` well-formed lines are already processed.
        valid_seen += 1
        if valid_seen <= already_done:
            continue

        image_name, label, captions = parsed
        text = prompt_template + "\n".join(f"{i + 1}. {caption}" for i, caption in enumerate(captions))
        new_one_data = {
            'image_name': image_name,
            'label': label,
            'caption': _generate_caption(text),
        }

        f_write.write(json.dumps(new_one_data) + "\n")
        f_write.flush()  # keep the JSONL output consistent with the checkpoint on a crash
        data_final.append(new_one_data)
        # Save intermediate results after every generated caption.
        joblib.dump(data_final, checkpoint_path)

In [None]:
# Report that the processing cell above has finished.
completion_message = "Processing completed and file saved."
print(completion_message)