In [2]:
import os
from mistralai import Mistral
from dotenv import load_dotenv
from common import BongardDataset
import base64
import pandas as pd
from tqdm import tqdm
import time
import re
from typing import List
import json

# Load environment variables from a .env file (python-dotenv) into os.environ;
# the generation loop later reads MISTRAL_API_KEY from os.environ.
load_dotenv()


def encode_image(image_path):
    """Return the contents of `image_path` as a base64-encoded UTF-8 string.

    Best-effort: if the file does not exist or cannot be read/encoded, a
    message is printed and None is returned instead of raising.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
    except Exception as exc:
        print(f"Error: {exc}")
    return None

In [3]:
def parse_prompts(text) -> List[str]:
    """Extract the list of augmented prompts from an LLM reply.

    The reply may wrap the JSON array in prose or a Markdown code fence, may
    contain other bracketed tokens before it (e.g. "[15]"), and individual
    prompts may themselves contain square brackets. A regex that stops at the
    first "]" mishandles all of these, so instead every "[" in `text` is tried,
    in order, as the start of a JSON document. The first one that decodes to a
    non-empty list of strings is returned.

    Args:
        text: Raw model output.

    Returns:
        The decoded prompts, or an empty list (after printing an error) when no
        such array can be found or decoded.
    """
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete JSON value, so
            # trailing text (e.g. a closing code fence) is ignored.
            candidate, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, list) and candidate and all(isinstance(item, str) for item in candidate):
            return candidate
        start = text.find("[", start + 1)

    print(f"Error: Could not parse prompts from text: {text}")
    return []

In [14]:
# Instruction template sent to the LLM once per fragment. It is rendered with
# str.format(prompt=..., concept=..., n_augmentations=...), so any literal "{"
# or "}" added to this text must be written as "{{" / "}}".
# The reply is expected to contain a JSON array of strings, which
# parse_prompts extracts.
PROMPT = """
You are tasked with modifying a given prompt for a diffusion model. 
Your primary objective is to preserve the specified **concept** while altering unrelated details or environments. 
You are encouraged to create diverse, creative, and unique augmentations that stay true to the concept but introduce variety in interpretation.

### Example:
Prompt: "An empty white bowl with a thin black rim placed on a solid blue background."
Concept: "Empty picture"
Output Augmentations: [
    "An empty red plate on a wooden table.",
    "A clear glass cup sitting on a marble countertop.",
    "A white ceramic vase on a patterned fabric."
]

Now, it's your turn: 
Prompt: {prompt}
Concept: {concept}

### Instructions:
1. Generate **{n_augmentations} unique augmentations** for the provided prompt.
2. Output the results in the following JSON string array format:
[
    "<augmented_prompt_1>",
    "<augmented_prompt_2>",
    ...
]

Ensure that each augmentation aligns with the concept and introduces creative variations in other details.
"""

In [15]:
# Augment every Bongard-RWR fragment prompt with N_AUGMENTATIONS LLM-generated
# variants that keep the fragment's concept. Results are checkpointed to
# OUTPUT_FILE after every fragment, and fragments already present in that file
# are skipped, so re-running this cell resumes an interrupted job.
N_AUGMENTATIONS = 15  # keep in sync with the "_plus_15" suffix of OUTPUT_FILE
REQUEST_DELAY_S = 5  # pause after every API attempt (successful or not)

OUTPUT_FILE = 'bonagrd_rwr_prompts_augmented_plus_15.csv'
prompts = pd.read_csv("bonagrd_rwr_prompts.csv")
augmented = pd.read_csv(OUTPUT_FILE) if os.path.exists(OUTPUT_FILE) else pd.DataFrame(columns=["problem_id", "file", "side", "positive", "negative"])
api_key = os.environ["MISTRAL_API_KEY"]
model = "pixtral-12b-2409"
client = Mistral(api_key=api_key)
dataset = BongardDataset("../data/bongard-rwr")
answers = []


def _fragment_query(problem_id, file_name, side) -> str:
    """Build the DataFrame.query expression selecting one fragment's rows."""
    return f"problem_id == {problem_id} and file == '{file_name}' and side == '{side}'"


def _save_checkpoint(frame: pd.DataFrame, path: str) -> None:
    """Write `frame` to `path` via a temp file + os.replace.

    The checkpoint is rewritten in full every iteration; writing in place meant
    an interrupt mid-write could truncate it and lose all earlier results.
    """
    tmp_path = f"{path}.tmp"
    frame.to_csv(tmp_path, index=False)
    os.replace(tmp_path, path)


for problem_id, file_name, side, file_path in tqdm(dataset.all_fragments()):
    # Resume support: `augmented` holds the rows written by previous runs.
    if augmented.query(_fragment_query(problem_id, file_name, side)).shape[0] > 0:
        continue

    try:
        prompt_data = prompts.query(_fragment_query(problem_id, file_name, side))
        prompt = prompt_data['positive'].iloc[0]
        negative = prompt_data['negative'].iloc[0]

        concept = dataset.get_label(problem_id, side)
        question = PROMPT.format(prompt=prompt, concept=concept, n_augmentations=N_AUGMENTATIONS)

        chat_response = client.chat.complete(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [{"type": "text", "text": question}],
                },
            ],
        )

        new_prompts = parse_prompts(chat_response.choices[0].message.content.strip())

        # Each augmented positive prompt reuses the fragment's original negative prompt.
        answers.extend(
            {
                "problem_id": problem_id,
                "file": file_name,
                "side": side,
                "positive": positive,
                "negative": negative,
            }
            for positive in new_prompts
        )

        # Checkpoint = rows generated in this run followed by rows from earlier runs.
        df = pd.concat([pd.DataFrame(answers), augmented])
        _save_checkpoint(df, OUTPUT_FILE)

    except Exception as e:
        print(f"Error: {e}")
        print(f"Problem ID: {problem_id}, File: {file_name}, Side: {side}")

    # Throttle after failures too: previously a failed request (e.g. a rate-limit
    # error) skipped the sleep, so the next request fired immediately.
    time.sleep(REQUEST_DELAY_S)

840it [2:14:49,  9.63s/it]
