In [None]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Single checkpoint id shared by the model weights and the matching processor
checkpoint = "llava-hf/llava-1.5-7b-hf"

# Load the model in half-precision (float16)
model = LlavaForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    device_map="auto",
)

# The processor is loaded from the same checkpoint as the model
processor = AutoProcessor.from_pretrained(checkpoint)

In [None]:
import re
# Expected shape of a model reply: it starts with "ASSISTANT:" and contains the
# three numbered sections in order. The three capture groups hold the text of
# "Cause-Effect", "Figurative Understanding" and "Mental State" respectively.
_response_regex = r'^ASSISTANT:\s*.*1\.\s*Cause-Effect:\s*(.*?)\s*2\.\s*Figurative Understanding:\s*(.*?)\s*3\.\s*Mental State:\s*(.*)$'

# DOTALL lets "." match newlines, so each section may span several lines
pattern = re.compile(_response_regex, flags=re.DOTALL)

In [None]:
import os
import json
from PIL import Image
import torch
from tqdm import tqdm

# Set required attributes for image handling
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"

# Define the instruction for each image
instruction = """Analyze the following depression meme image to extract common sense reasoning in the form of triples. These relationships should
capture the following elements:
• 1. Cause-Effect: Identify concrete causes or results of the situation depicted in the meme.
• 2. Figurative Understanding: Capture underlying metaphors, analogies, or symbolic meanings that convey the meme's deeper message, including any ironic or humorous undertones.
• 3. Mental State: Capture specific mental or emotional states depicted in the meme."""

# Directory containing training images
train_dir = "/kaggle/input/meme-dataset/meme-ocr"
# File extensions treated as images (compared case-insensitively)
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
# Number of processed images between two intermediate JSON dumps
SAVE_EVERY = 100

output_data = []  # one {"sample_id", "figurative_reasoning"} dict per answered image
cnt = 0           # number of images processed so far


def _save_responses(records, count):
    """Dump every response collected so far to ``train_figurative_<count>.json``.

    Args:
        records: list of response dicts accumulated by the loop below.
        count: number of images processed so far; used in the file name.
    """
    output_file = f"train_figurative_{count}.json"
    with open(output_file, "w") as json_file:
        json.dump(records, json_file, indent=4)
    print(f"Responses saved to {output_file}")


# Process each image in the "train" directory.
# The listing is sorted so the processing order (and therefore the content of
# every intermediate dump) is the same on each run.
for filename in tqdm(sorted(os.listdir(train_dir)), desc="Processing images"):
    # Skip anything that is not an image; the check ignores case so that
    # files such as "x.JPG" are not silently dropped.
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue

    # Open the image; the context manager releases the file handle once the
    # RGB copy has been made.
    image_path = os.path.join(train_dir, filename)
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Construct conversation input for the model.
    # NOTE(review): both "url" and "image" keys are passed; which of them the
    # processor actually reads depends on the installed transformers version —
    # confirm against the version in use.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": "local://dummy", "image": image},
                {"type": "text", "text": instruction}
            ]
        }
    ]

    # Process the conversation into model inputs
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Compute image features manually if the chat template did not provide them
    if "pixel_values" not in inputs or inputs["pixel_values"] is None:
        pixel_values = processor.image_processor(images=image, return_tensors="pt")["pixel_values"]
        inputs["pixel_values"] = pixel_values.to(model.device, torch.float16)

    # Generate a response
    generate_ids = model.generate(**inputs, max_new_tokens=300)
    response_text = processor.batch_decode(generate_ids, skip_special_tokens=True)

    # NOTE: a sampling variant was tried previously: call
    # model.generate(**inputs, max_new_tokens=300, temperature=0.5, do_sample=True)
    # in a `while True` loop and stop once re.match(pattern, model_response)
    # succeeds. It is currently disabled.

    # Extract model's response after user's prompt
    if response_text:
        model_response = response_text[0].split(instruction, 1)[-1].strip()

        # Sample ID is the file name without its extension; splitext keeps
        # any dots that are part of the name itself.
        sample_id = os.path.splitext(filename)[0]

        # Store the response in the required JSON format
        output_data.append({
            "sample_id": sample_id,
            "figurative_reasoning": model_response
        })

    # Increment the counter
    cnt += 1

    # Dump the content into a JSON file after every SAVE_EVERY images
    if cnt % SAVE_EVERY == 0:
        _save_responses(output_data, cnt)

# Save any remaining responses once, after the loop has finished. When cnt is
# a multiple of SAVE_EVERY the last in-loop dump already wrote this exact file.
if output_data and cnt % SAVE_EVERY != 0:
    _save_responses(output_data, cnt)
