In [None]:
import pandas as pd

# SVO-Probes metadata; the CSV is distributed at https://github.com/google-deepmind/svo_probes
df = pd.read_csv(filepath_or_buffer="svo_probes.csv")
df.head()

In [None]:
import ast

def parse_triplet(triplet_str):
    """Split an SVO triplet string into its comma-separated parts.

    ``triplet_str`` may be a plain ``"subj,verb,obj"`` string, or the string
    form of a list whose first element is such a string (e.g.
    ``"['woman,lie,beach']"``); in the latter case only the first element is
    split. Any other literal (empty list, tuple, non-string first element)
    falls back to splitting ``triplet_str`` itself on commas.

    Args:
        triplet_str: Raw triplet cell value.

    Returns:
        List of the comma-separated parts (not whitespace-stripped).

    Raises:
        AttributeError: If ``triplet_str`` is not a string (e.g. a NaN cell).
    """
    try:
        parsed = ast.literal_eval(triplet_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal, e.g. a bare "subj,verb,obj" string. Only the
        # errors literal_eval documents are caught, so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return triplet_str.split(',')
    if isinstance(parsed, list) and parsed and isinstance(parsed[0], str):
        return parsed[0].split(',')
    return triplet_str.split(',')

def generate_fixed_prompt(row):
    """Build the question for one SVO-Probes row, offering the negated alternative.

    Both triplets are parsed with the module-level ``parse_triplet`` helper
    instead of a private duplicate of it. The first truthy flag among
    ``subj_neg``, ``verb_neg``, ``obj_neg`` (checked in that order) selects which
    slot lists "<positive> or <negative>". If no flag is truthy, a plain
    yes/no question is returned.

    Args:
        row: DataFrame row with ``pos_triplet``, ``neg_triplet``, ``subj_neg``,
            ``verb_neg`` and ``obj_neg`` entries.

    Returns:
        The prompt string. If building it raises (e.g. a triplet without exactly
        three comma-separated parts), ``"Error in triplet format: <reason>"``.
    """
    try:
        pos_triplet = parse_triplet(row['pos_triplet'])
        neg_triplet = parse_triplet(row['neg_triplet'])

        if len(pos_triplet) != 3 or len(neg_triplet) != 3:
            raise ValueError("Triplet does not contain exactly 3 elements")

        # The positive element is always listed first, the negative second.
        if row['subj_neg']:
            prompt = f"is {pos_triplet[0]} or {neg_triplet[0]} {pos_triplet[1]} {pos_triplet[2]}?"
        elif row['verb_neg']:
            prompt = f"is {pos_triplet[0]} {pos_triplet[1]} or {neg_triplet[1]} {pos_triplet[2]}?"
        elif row['obj_neg']:
            prompt = f"is {pos_triplet[0]} {pos_triplet[1]} {pos_triplet[2]} or {neg_triplet[2]}?"
        else:
            prompt = f"is {pos_triplet[0]} {pos_triplet[1]} {pos_triplet[2]}?"
    except Exception as e:
        # Best-effort: a malformed row yields an error string rather than
        # aborting the whole df.apply.
        prompt = "Error in triplet format: " + str(e)

    return prompt

# Templated question for every row; needs the pos/neg triplet columns and the three *_neg flags.
df = df.assign(clean_prompt=df.apply(generate_fixed_prompt, axis=1))
df['clean_prompt']

In [None]:
def remove_bad_triplets(df):
    """Drop rows whose ``pos_triplet`` or ``neg_triplet`` is empty or missing.

    A triplet is "bad" when either of these holds:
    - it is not a string (e.g. a NaN cell), since it cannot be split downstream;
    - it parses with ``ast.literal_eval`` to an empty value or to ``['[]']``.

    Strings that are not Python literals (e.g. ``"woman,lie,beach"``) are kept.

    Args:
        df: Frame with ``pos_triplet`` and ``neg_triplet`` columns.

    Returns:
        The filtered frame; original index labels are preserved.
    """
    def is_bad_triplet(triplet_str):
        if not isinstance(triplet_str, str):
            return True
        try:
            parsed = ast.literal_eval(triplet_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal, e.g. a plain "subj,verb,obj" string: keep it.
            return False
        return not parsed or parsed == ['[]']

    pos_ok = ~df['pos_triplet'].apply(is_bad_triplet)
    neg_ok = ~df['neg_triplet'].apply(is_bad_triplet)
    return df[pos_ok & neg_ok]

df = df.pipe(remove_bad_triplets)  # drop rows with empty/unusable triplets

In [None]:
def create_duplicated_rows_with_answers(df):
    """Expand every probe row into a positive-image row and a negative-image row.

    The answer slot comes from the negation flags: subject if ``subj_neg``,
    else verb if ``verb_neg``, else object. Two copies of each row are emitted:
    - positive: ``correct_answer`` is that slot of ``pos_triplet``; ``url`` is ``pos_url``;
    - negative: the same slot of ``neg_triplet``; ``url`` is ``neg_url``.

    The positive copy comes first, and rows keep their input order.

    Returns:
        New DataFrame with a fresh RangeIndex and the extra ``correct_answer``
        and ``url`` columns.

    Raises:
        IndexError: If a parsed triplet is too short for the chosen slot.
    """
    expanded = []

    for _, source_row in df.iterrows():
        if source_row['subj_neg']:
            slot = 0
        elif source_row['verb_neg']:
            slot = 1
        else:
            slot = 2

        pos_parts = parse_triplet(source_row['pos_triplet'])
        neg_parts = parse_triplet(source_row['neg_triplet'])

        for parts, url_column in ((pos_parts, 'pos_url'), (neg_parts, 'neg_url')):
            variant = source_row.copy()
            variant['correct_answer'] = parts[slot]
            variant['url'] = source_row[url_column]
            expanded.append(variant)

    return pd.DataFrame(expanded).reset_index(drop=True)

# One row per (probe, image): the positive-image and negative-image variants,
# each with its own `correct_answer` and `url`.
df = df.pipe(create_duplicated_rows_with_answers)
df.head()

In [None]:
# The per-row `url` chosen during expansion is the image the answer refers to
# ("clean"); the other image of the same pos/neg pair is the "corrupt" one.
# Non-mutating rename + vectorised `where` instead of inplace=True and a row-wise apply.
df = df.rename(columns={'url': 'clean_image_path'})
df['corrupt_image_path'] = df['neg_url'].where(df['clean_image_path'] == df['pos_url'], df['pos_url'])


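## Next, run make_chatgpt_prompts.py on the dataset built above to generate "gpt3.5_prompts.csv" (its rows are attached below by pandas index label, so it needs one row per row of `df` at this point, in the same order)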

In [None]:
# Templated prompts, one per row (not referenced by the cells below).
prompts = df["clean_prompt"].tolist()
# Output of make_chatgpt_prompts.py; the rewritten questions are in the "GPT-3.5 Prompts" column.
chatgpt_prompts = pd.read_csv("gpt3.5_prompts.csv")

In [None]:
# Drop every prompt mentioning "person": asking the model to choose between e.g.
# "woman" and "person" produces inconsistent results. The index is NOT reset
# after filtering -- the GPT-3.5 prompts are attached later by index label.
df = df.reset_index(drop=True)
mentions_person = df["clean_prompt"].str.contains("person", regex=False)
df = df[~mentions_person]

In [None]:
df_new = df.drop_duplicates(subset=["clean_prompt"], keep="first")  # first row per prompt = its positive-image row

In [None]:
# Attach the GPT-3.5 rewrite of each prompt. Series assignment aligns on index
# LABELS, not position: `df_new` keeps the labels of the full `df` (reset before
# the "person" filter), so this presumes gpt3.5_prompts.csv has one row per row
# of that full frame, in order -- TODO confirm against make_chatgpt_prompts.py.
# `assign` returns a new frame, avoiding SettingWithCopyWarning on the
# de-duplicated slice.
df_new = df_new.assign(clean_prompt_gpt=chatgpt_prompts["GPT-3.5 Prompts"])
assert df_new["clean_prompt_gpt"].notna().all(), "some prompts have no GPT-3.5 rewrite; check the row alignment of gpt3.5_prompts.csv"

In [None]:
# Keep BOTH answer variants of every unique prompt:
#   - the positive-image row (correct_answer from pos_triplet, the first option);
#   - the negative-image row (correct_answer from neg_triplet, the second option).
# `df_new` keeps only one row per prompt, so the rows must come from `df`; the
# GPT-3.5 prompt is joined in from `df_new` by prompt text. Expected size is
# 2 * len(df_new) (14905 * 2 for the released data).
df_with_duplicates = (
    df.drop(columns="clean_prompt_gpt", errors="ignore")
    .drop_duplicates(["clean_prompt", "correct_answer"])
    .merge(
        df_new[["clean_prompt", "clean_prompt_gpt"]],
        on="clean_prompt",
        how="left",
        validate="many_to_one",
    )
)


In [None]:
# Replace the templated question with its GPT-3.5 rewrite. `assign` builds a new
# frame instead of writing into a frame derived by filtering (SettingWithCopyWarning).
df_with_duplicates = df_with_duplicates.assign(clean_prompt=df_with_duplicates["clean_prompt_gpt"])

In [None]:
df_with_duplicates.to_csv(path_or_buf="svo_prompts_cleaned.csv", index=False)  # final prompt dataset, no index column

# NOTE: Image URLs in the SVO-Probes dataset are not always reachable, so the code below downloads the images once and saves them locally as tensors.
## The saved output is approximately 30GB.

In [None]:
from __future__ import annotations

import pickle
from io import BytesIO
from typing import Callable, Tuple

import PIL
import requests
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageDataset(Dataset):
    """Map-style dataset that downloads images by URL and transforms them to tensors.

    Each item is a ``(url, tensor)`` pair. Any download, HTTP or decode failure is
    printed and replaced by an all-zero tensor, so one dead link does not abort a
    long DataLoader run. Note that such a tensor is indistinguishable from a real
    black image.

    Args:
        urls: Image URLs, one per item.
        transform: Callable mapping an RGB ``PIL.Image.Image`` to a tensor.
        fallback_shape: Shape of the zero tensor returned on failure. Defaults to
            ``(3, 224, 224)``. Keep it equal to the transform's output shape so
            batches can be collated.
        timeout: Seconds to wait for each HTTP request. Defaults to 10.
    """

    def __init__(
        self,
        urls: list,
        transform: Callable[[Image.Image], torch.Tensor],
        fallback_shape: Tuple[int, ...] = (3, 224, 224),
        timeout: float = 10,
    ) -> None:
        self.urls = urls
        self.transform = transform
        self.fallback_shape = tuple(fallback_shape)
        self.timeout = timeout

    def __getitem__(self, i: int) -> Tuple[str, torch.Tensor]:
        url = self.urls[i]
        try:
            response = requests.get(url, timeout=self.timeout)
            # Fail on 4xx/5xx with an HTTP error instead of a confusing
            # "cannot identify image file" when PIL is handed an error page.
            response.raise_for_status()
            # Close the decoded source image once the RGB copy exists.
            with Image.open(BytesIO(response.content)) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = self.transform(image)
        except Exception as e:
            print(f"Error loading image from {url}: {e}")
            image_tensor = torch.zeros(self.fallback_shape)
        return url, image_tensor

    def __len__(self) -> int:
        return len(self.urls)

def save_image_tensors(image_tensors, filename="image_tensors.pkl"):
    """Pickle ``image_tensors`` to ``filename``, overwriting any existing file.

    Args:
        image_tensors: Object to serialize (here, a list of ``(url, tensor)`` pairs).
        filename: Destination path. Defaults to ``"image_tensors.pkl"``.
    """
    with open(filename, mode="wb") as handle:
        pickle.dump(image_tensors, handle)

# This file is provided in the GitHub repo of SVO-Probes (https://github.com/google-deepmind/svo_probes)
file_path = 'image_urls.txt'

# Write a checkpoint shard every CHECKPOINT_EVERY DataLoader batches (batch_size=4).
CHECKPOINT_EVERY = 100

# Load URLs from the file, one per line. Blank lines are skipped because an
# empty URL can only ever yield the zero fallback tensor.
with open(file_path, 'r') as file:
    good_urls = [line.strip() for line in file if line.strip()]

# Resize to 224x224, the shape ImageDataset's default zero fallback uses.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

image_dataset = ImageDataset(urls=good_urls, transform=transform)
# No pin_memory: the tensors never leave the CPU and are all retained until the
# final save, so page-locking each batch would only tie up host RAM.
image_data_loader = DataLoader(image_dataset, batch_size=4, num_workers=2)

# All (url, tensor) pairs, in URL-file order (the loader does not shuffle).
image_tensors = []
# How many entries of image_tensors have already been written to a shard.
saved_upto = 0

for batch_count, (urls, imgs) in enumerate(image_data_loader, start=1):
    # `imgs` is one stacked batch tensor and each `img` is a view into it.
    # Pickling a view serializes its whole underlying storage (the full batch),
    # so clone() gives every image its own compact storage.
    image_tensors.extend((url, img.clone()) for url, img in zip(urls, imgs))
    if batch_count % CHECKPOINT_EVERY == 0:
        # Each shard holds only the pairs added since the previous shard.
        # Concatenating all shards in batch order recovers the progress so far.
        # Re-pickling the whole accumulated list under a new name each time
        # would make total disk usage grow quadratically.
        save_image_tensors(image_tensors[saved_upto:], f"image_tensors_batch_{batch_count}.pkl")
        saved_upto = len(image_tensors)
        print(f"Saved checkpoint at batch {batch_count}")

# Save the complete list of (url, tensor) pairs to one file.
# NOTE(review): everything is held in RAM until this point -- presumably tens of
# GB for the full URL list; consider writing shards only if memory is tight.
save_image_tensors(image_tensors, "image_tensors_final.pkl")
print("Image tensors saved.")