In [None]:
import os
from mistralai import Mistral
from dotenv import load_dotenv
from common import BongardDataset
import base64
import pandas as pd
from tqdm import tqdm
import time
import re

load_dotenv()


def encode_image(image_path):
    """Read a file and return its contents as a base64-encoded UTF-8 string.

    Args:
        image_path: Path of the file to read in binary mode.

    Returns:
        The base64 text of the file's bytes, or ``None`` when the file is
        missing or any other exception occurs while reading/encoding it
        (an error message is printed in both cases).
    """
    encoded = None
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes).decode('utf-8')
    except FileNotFoundError:
        # A missing file gets its own, more specific message.
        print(f"Error: The file {image_path} was not found.")
    except Exception as err:
        print(f"Error: {err}")
    return encoded

In [25]:
def parse_prompts(text):
    """Extract the positive and negative prompt from a model reply.

    The reply is expected to contain ``Positive prompt:`` followed later by
    ``Negative prompt:``; each captured section may span several lines.

    Args:
        text: Raw reply text to parse.

    Returns:
        A ``(positive, negative)`` tuple with surrounding whitespace and all
        double quotes removed, or ``("", "")`` (after printing an error)
        when the expected labels are not found.
    """
    regex = r"Positive prompt:\s*(.+?)\s*Negative prompt:\s*(.+)"
    found = re.compile(regex, re.DOTALL).search(text)

    if found is None:
        print(f"Error: Could not parse prompts from text: {text}")
        return "", ""

    # Apply the same clean-up to both captured sections.
    positive_prompt, negative_prompt = (
        section.strip().replace("\"", "") for section in found.groups()
    )
    return positive_prompt, negative_prompt

In [None]:
# Prompt template sent to the model alongside each image. `{concept}` is the
# only placeholder and is filled with `QUESTION.format(concept=...)` in the
# request loop, so any literal braces added here would need to be doubled.
# The answer format requested at the end ("Positive prompt:" / "Negative
# prompt:") is what the regex in `parse_prompts` looks for — keep them in sync.
# `.strip()` drops the leading/trailing newlines of the triple-quoted literal.
QUESTION = """
You are given an image and a general concept. Your task is to refine the concept into two concise, visually descriptive prompts: one that aligns with the image (positive prompt) and one that contrasts with it (negative prompt). Focus on making each prompt specific, clearly grounded in the image, and reflective of the core idea. You don’t need to match every detail—just convey the main visual concept.

### Example:
Image: Human legs wearing socks with vertical lines
Concept: Vertical lines.
Positive prompt: Socks with vertical lines
Negative prompt: Socks with horizontal lines 

Now, it's your turn: 
Concept: {concept}

### Instructions:
1. Generate a positive and negative prompt based on the provided image and concept.
2. Answer using the following format:
Positive prompt:
Negative prompt:
""".strip()

In [None]:
# Output CSV. It doubles as a checkpoint: rows already present are skipped on
# the next run, so an interrupted run can be resumed.
FILE_NAME = "bonagrd_rwr_prompts.csv"
# Pause after every API request (successful or failed) — presumably to stay
# within the provider's rate limit; TODO confirm the required interval.
REQUEST_DELAY_SECONDS = 5

already_answered = pd.read_csv(FILE_NAME) if os.path.exists(FILE_NAME) else pd.DataFrame(columns=["problem_id", "file", "side", "positive", "negative"])
api_key = os.environ["MISTRAL_API_KEY"]
model = "pixtral-12b-2409"
client = Mistral(api_key=api_key)
dataset = BongardDataset("../data/bongard-rwr")
answers = []

for problem_id, file_name, side, file_path in tqdm(dataset.all_fragments()):
    # Resume check. The string values are passed to `query` as `@` variables
    # instead of being pasted into quoted literals, so a name containing a
    # quote character cannot break the expression. `str()` reproduces what
    # the f-string interpolation used to compare against.
    file_key = str(file_name)
    side_key = str(side)
    if already_answered.query(f"problem_id == {problem_id} and file == @file_key and side == @side_key").shape[0] > 0:
        continue

    # `encode_image` reports its own error and returns None on failure; skip
    # the fragment rather than sending "base64,None" to the API.
    encoded_image = encode_image(file_path)
    if encoded_image is None:
        print(f"Problem ID: {problem_id}, File: {file_name}, Side: {side}")
        continue

    # Tracks whether the API was actually contacted, so that purely local
    # failures do not incur the throttle delay.
    request_sent = False

    try:
        left_label, right_label = dataset.get_labels(problem_id)
        concept = left_label if side == "left" else right_label

        request_sent = True
        chat_response = client.chat.complete(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": QUESTION.format(concept=concept)
                        },
                        {
                            "type": "image_url",
                            "image_url": f"data:image/jpeg;base64,{encoded_image}"
                        }
                    ]
                },
            ]
        )

        # NOTE(review): when the reply cannot be parsed, `parse_prompts`
        # returns ("", "") and that empty row is still stored, so the
        # fragment counts as answered on the next run — confirm this is
        # intended.
        positive_prompt, negative_prompt = parse_prompts(chat_response.choices[0].message.content)

        answers.append({
            "problem_id": problem_id,
            "file": file_name,
            "side": side,
            "positive": positive_prompt,
            "negative": negative_prompt,
        })

        # Rewrite the whole CSV after every answer so that progress survives
        # an interruption; new rows are placed before the previously saved ones.
        df = pd.concat([pd.DataFrame(answers), already_answered])
        df.to_csv(FILE_NAME, index=False)

        time.sleep(REQUEST_DELAY_SECONDS)

    except Exception as e:
        print(f"Error: {e}")
        print(f"Problem ID: {problem_id}, File: {file_name}, Side: {side}")
        # Throttle after a failed request as well; otherwise an API error
        # would be followed immediately by the next call.
        if request_sent:
            time.sleep(REQUEST_DELAY_SECONDS)
        continue

840it [00:08, 98.29it/s] 


In [29]:
# Reload the saved prompts from disk and show the table's (rows, columns)
# shape as the cell output. The path literal repeats the value of FILE_NAME
# from the generation cell; `df` is rebound here to the reloaded frame.
df = pd.read_csv("bonagrd_rwr_prompts.csv")
df.shape

(840, 6)