In [None]:
import pandas as pd
import numpy as np
import torch

import torch
from models_patching import *
from transformers import BlipProcessor

In [None]:
# Load the results table from BLIP_results_svo_probes.csv into `df`.
df = pd.read_csv(filepath_or_buffer="BLIP_results_svo_probes.csv")

In [None]:
# Lower-case both answer columns so that they can be compared
# without regard to capitalisation.
for answer_col in ('correct_answer', 'generated_text'):
    df[answer_col] = df[answer_col].str.lower()

In [None]:
# Keep only rows where the generated answer equals the expected answer and
# whose prompt contains the standalone word "or" (case-insensitive; rows with
# a missing prompt are dropped because of na=False).
answer_matches = df["generated_text"] == df["correct_answer"]
prompt_has_or = df['clean_prompt'].str.contains(r'\bor\b', case=False, na=False)

# .copy() makes the filtered result an independent frame, so the column
# assignment below does not write into a slice of the previous frame
# (which would trigger pandas' SettingWithCopyWarning).
df = df[answer_matches & prompt_has_or].copy()

# Make sure every remaining prompt ends with a question mark.
df['clean_prompt'] = df['clean_prompt'].apply(lambda x: x if x.endswith('?') else x + '?')

In [None]:
import pandas as pd
import numpy as np


categories = ['subj', 'obj', 'verb']
conditions = [('pos_url', 'clean_image_path'), ('neg_url', 'corrupt_image_path')]

# One sampled DataFrame per (category, condition) pair that has enough rows.
dfs_for_concat = []

for category in categories:
    for pos_cond, mc_pos_cond in conditions:
        # Rows whose "<category>_neg" flag is set and whose two columns of the
        # current condition hold the same value.
        filtered_df = df[(df[category + "_neg"] == True) & (df[pos_cond] == df[mc_pos_cond])]
        print(len(filtered_df))

        if len(filtered_df) < 500:
            # Skip this group entirely. Falling through to the append would
            # either raise NameError (first iteration) or silently re-append
            # the sample selected for the previous group.
            print("NOT ENOUGH SAMPLES")
            continue

        # Select exactly 500 rows
        if len(filtered_df) > 500:
            selected_df = filtered_df.sample(n=500, random_state=1)  # Random state for reproducibility
        else:
            selected_df = filtered_df

        dfs_for_concat.append(selected_df)

# Concatenate all the selected DataFrames into one, with a fresh 0..N-1 index
final_df = pd.concat(dfs_for_concat).reset_index(drop=True)


In [None]:
import random
import ast
from rapidfuzz import process, fuzz


def find_token_index(token, sentence, threshold=60):
    """Return the position of the word in `sentence` that best matches `token`.

    The sentence is split on whitespace with ``str.split()`` and `token` is
    fuzzy-matched against every word using rapidfuzz's ``WRatio`` scorer.

    Args:
        token: The string to look for.
        sentence: The sentence to search in.
        threshold: The best match must score strictly above this value to be
            accepted. Defaults to 60, the value previously hard-coded.

    Returns:
        The index (into ``sentence.split()``) of the first word equal to the
        best match, or None when there are no words or the best score does
        not exceed `threshold`.
    """
    words = sentence.split()
    # Fuzzy match the token against all words in the sentence
    best_match = process.extractOne(token, words, scorer=fuzz.WRatio)
    if not best_match:
        return None

    matched_word, score = best_match[0], best_match[1]
    if score > threshold:
        return words.index(matched_word)

    return None

# Helper functions for parsing triplet strings
def parse_triplets(triplets: str):
    """Turn a triplet string into a list of tuples.

    A string starting with "[" is evaluated as a Python literal and each of
    its items is split on commas; any other string is treated as one single
    comma-separated triplet.

    Returns:
        A list of tuples of strings, one tuple per triplet.
    """
    looks_like_list = triplets.startswith("[")
    raw_items = ast.literal_eval(triplets) if looks_like_list else [triplets]
    return [tuple(item.split(",")) for item in raw_items]

def get_first_triplet(triplets):
    """Return the first item of `triplets`, or ("", "", "") if it is empty."""
    for triplet in triplets:
        return triplet
    return ("", "", "")


def swap_token(vocab, svo_idx, banned_words):
    """Draw a random replacement word for one S-V-O slot.

    Args:
        vocab: Mapping from slot index to the set of all words seen for that
            slot in the dataset.
        svo_idx: Which slot to sample from: subject=0, verb=1, object=2.
        banned_words: Words that must not be returned. When corrupting two
            tokens T1 and T2 into C1 and C2, the first draw bans [T1, T2] and
            the second bans [T1, T2, C1]; the other elements of the triplet
            are banned as well so a swapped token never equals, e.g., the
            object token.

    Returns:
        A word from ``vocab[svo_idx]`` that is not in `banned_words`.

    Raises:
        IndexError: If every word of the slot is banned.
    """
    banned = set(banned_words)
    # Iteration order of a set of strings can differ between interpreter
    # runs, so sort the candidates: with a seeded RNG the same word is then
    # drawn every time.
    candidates = sorted(word for word in vocab[svo_idx] if word not in banned)
    if not candidates:
        raise IndexError(
            "no replacement word left for slot {} after removing banned words".format(svo_idx)
        )
    return random.choice(candidates)


def extract_svo_vocab(df, column):
    """Collect the words seen at each position of the triplets in `df[column]`.

    Each cell is split on commas; the word at position 0, 1 and 2 is added to
    the corresponding set.

    Args:
        df: DataFrame holding the triplet strings.
        column: Name of the column to read.

    Returns:
        A dict ``{0: set, 1: set, 2: set}`` keyed by position in the triplet.
    """
    vocab = {0: set(), 1: set(), 2: set()}

    for row in df[column]:
        # Missing cells arrive as NaN (a float), which has no .split(); skip them.
        if not isinstance(row, str):
            continue

        words = row.split(',')
        # Some neg triplets were weird and had more than 3 words.
        if len(words) > 3:
            continue

        for idx, word in enumerate(words):
            # Some neg triplets have nested lists/strings (they carry a quote
            # character); those pieces and the word "person" are left out.
            if "'" in word or word == "person":
                continue
            vocab[idx].add(word)
    return vocab

# Per-position vocabularies from the positive and the negative triplet columns.
pos_vocab, neg_vocab = (
    extract_svo_vocab(df, triplet_col) for triplet_col in ("pos_triplet", "neg_triplet")
)

# Merge both vocabularies position by position (0, 1, 2).
vocab = {slot: neg_vocab[slot] | pos_vocab[slot] for slot in (0, 1, 2)}

def _corrupt_token_pair(row, pos_token, neg_token, svo_idx, pos_triplet):
    """Swap the positive and negative token of one S-V-O slot in the row's clean prompt.

    Args:
        row: The DataFrame row being processed; its "clean_prompt" is read.
        pos_token: Token of the positive triplet for this slot.
        neg_token: Token of the negative triplet for this slot.
        svo_idx: Slot index used to pick the vocabulary (0, 1 or 2).
        pos_triplet: The full positive (subject, verb, object) triplet; all of
            its words are banned as replacements.

    Returns:
        The prompt with both tokens replaced by random vocabulary words, or
        "NO PROMPT MADE" when the prompt is missing or a token is not found.
    """
    prompt = row["clean_prompt"]
    if pd.isna(prompt):
        return "NO PROMPT MADE"

    pos_idx = find_token_index(pos_token, prompt)
    if pos_idx is None:
        print("{} NOT IN SENTENCE".format(pos_token))
        return "NO PROMPT MADE"

    neg_idx = find_token_index(neg_token, prompt)
    if neg_idx is None:
        print(row)
        print("{} NOT IN SENTENCE".format(neg_token))
        return "NO PROMPT MADE"

    # find_token_index returns positions in prompt.split(), so split the same
    # way here; splitting on ' ' would shift the positions whenever the prompt
    # contains consecutive spaces. Whitespace is normalised to single spaces
    # when the words are joined back.
    ls_prompt = prompt.split()

    # Neither replacement may be a word of the positive triplet or the
    # negative token; the second one must also differ from the first.
    banned = [*pos_triplet, neg_token]
    corrupt_1 = swap_token(vocab, svo_idx=svo_idx, banned_words=banned)
    ls_prompt[pos_idx] = corrupt_1
    ls_prompt[neg_idx] = swap_token(vocab, svo_idx=svo_idx, banned_words=banned + [corrupt_1])
    return ' '.join(ls_prompt)


def corrupt_prompt(row):
    """Build the corrupted prompt for one row.

    The first triplet of "pos_triplet" and of "neg_triplet" is parsed, and the
    slot flagged by "subj_neg", "verb_neg" or "obj_neg" (checked in that
    order) has its positive and negative token replaced in "clean_prompt" by
    two different random words from the matching vocabulary.

    Returns:
        The corrupted prompt, "NO PROMPT MADE" if it could not be built, or
        "No negation found" if none of the three flags is set.
    """
    pos_s, pos_v, pos_o = get_first_triplet(parse_triplets(row['pos_triplet']))
    neg_s, neg_v, neg_o = get_first_triplet(parse_triplets(row['neg_triplet']))
    pos_triplet = (pos_s, pos_v, pos_o)

    if row['subj_neg']:
        return _corrupt_token_pair(row, pos_s, neg_s, 0, pos_triplet)
    if row['verb_neg']:
        return _corrupt_token_pair(row, pos_v, neg_v, 1, pos_triplet)
    if row['obj_neg']:
        # The negative *object* is banned here (the previous code banned the
        # negative verb for the second replacement by mistake).
        return _corrupt_token_pair(row, pos_o, neg_o, 2, pos_triplet)
    return "No negation found"

In [None]:
# Start the `random` module from a fixed state before the replacement words
# are drawn inside corrupt_prompt.
random.seed(1)

# Build each corrupted prompt from final_df's own rows. Applying the function
# to `df` instead would return a Series carrying df's original index labels,
# which no longer correspond to final_df's reset 0..N-1 index, so the prompts
# would be attached to the wrong rows (or become NaN).
final_df['corrupt_prompt'] = final_df.apply(corrupt_prompt, axis=1)

In [None]:
# Drop the rows for which no corrupted prompt could be built. .copy() makes
# the result an independent frame, so assigning to one of its columns later
# does not write into a slice (avoids pandas' SettingWithCopyWarning).
final_df = final_df[final_df['corrupt_prompt'] != "NO PROMPT MADE"].copy()

In [None]:
from pattern.en import conjugate, lemma, PRESENT, PARTICIPLE
from rapidfuzz import process, fuzz

def convert_to_ing_form(verb):
    """Return the "-ing" form of `verb`.

    The verb is reduced with `lemma` and the result is conjugated with
    tense=PRESENT and aspect=PARTICIPLE.
    """
    return conjugate(lemma(verb), tense=PRESENT, aspect=PARTICIPLE)

def find_and_convert_verbs(sentence, verb_neg):
    """Put the two words around 'or' into their "-ing" form.

    Args:
        sentence: The prompt to rewrite.
        verb_neg: Only when this is truthy is the sentence touched.

    Returns:
        The rewritten sentence. The input is returned unchanged when
        `verb_neg` is falsy, the sentence is empty or starts with "Does",
        'or' is missing or has no word on both sides, or neither word could
        be converted. When a word is replaced the sentence is rebuilt from
        its whitespace-separated words joined by single spaces.
    """
    # Only proceed if verb_neg is True
    if not verb_neg:
        return sentence

    words = sentence.split()
    # "Does" indicates the sentence is already in a suitable form
    if not words or words[0] == "Does":
        return sentence

    try:
        # Find the index of 'or' to identify the verbs around it
        or_index = words.index('or')
        if or_index == 0 or or_index == len(words) - 1:
            # Without this guard a leading 'or' would wrap around to the last
            # word and a trailing 'or' would raise IndexError.
            raise ValueError("'or' has no word on both sides")
        positions = (or_index - 1, or_index + 1)
        ing_forms = [convert_to_ing_form(words[pos]) for pos in positions]
    except ValueError as e:
        # 'or' not found in the sentence, or other processing error
        print(f"Error processing sentence '{sentence}': {e}")
        return sentence

    # Replace by word position. A first-occurrence substring replace on the
    # whole sentence could rewrite an earlier word that merely contains the verb.
    changed = False
    for pos, ing_form in zip(positions, ing_forms):
        if ing_form:
            words[pos] = ing_form
            changed = True
        else:
            print("verb not found for: " + words[pos])

    return ' '.join(words) if changed else sentence

def _ing_corrupt_prompt(row):
    """Run find_and_convert_verbs on one row's corrupt prompt, gated by its verb_neg flag."""
    return find_and_convert_verbs(row['corrupt_prompt'], row['verb_neg'])


# Overwrite the column with the verb-converted prompts, row by row.
final_df['corrupt_prompt'] = final_df.apply(_ing_corrupt_prompt, axis=1)


In [None]:
# Write the final table to BLIP_final_svo_probes.csv without the index column.
final_df.to_csv(path_or_buf="BLIP_final_svo_probes.csv", index=False)