In [None]:
import pandas as pd
import pickle
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import pickle
import torch
from rapidfuzz import process, fuzz
import pandas as pd
from transformers import BlipProcessor
from models_patching import ModifiedBlipForQuestionAnswering
from plotting_utils import *

In [None]:
def generate_matches(row):
    """Map special tokens and triplet words of a row's prompt to token positions.

    The prompt in ``row['clean_prompt']`` is tokenized with the global
    ``processor.tokenizer``. Triplet words come from
    ``clean_and_combine_triplets(row['pos_triplet'], row['neg_triplet'])``
    (not defined in this notebook; it must return a collection that supports
    iteration and ``.remove``).

    Matching runs in two passes: exact string equality first, then fuzzy
    matching (``fuzz.WRatio``, score strictly above 55) for whatever is left.

    Returns:
        dict with the keys ``'[CLS]'`` (always 0), ``'[SEP]'`` (index of the
        first ``'[SEP]'`` token), ``'?'`` (index of the first ``'?'`` token, or
        ``None`` when the prompt has none), plus one entry per matched triplet
        word mapping it to a token index.
    """
    sentence = row['clean_prompt']
    triplets = clean_and_combine_triplets(row['pos_triplet'], row['neg_triplet'])
    inputs = processor.tokenizer(sentence, return_tensors="pt")
    tokens = processor.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # tokens.index raises ValueError if the tokenizer emitted no '[SEP]'.
    matches = {'[CLS]': 0, '[SEP]': tokens.index('[SEP]')}
    # Exact matches first
    for token in tokens:
        # Compare on the bare text: '##' is presumably the sub-word continuation prefix.
        clean_token = token.replace('##', '')
        # Iterate over a snapshot because `triplets` is mutated inside the loop.
        for triplet in list(triplets):
            if clean_token == triplet:
                matches[triplet] = tokens.index(token)
                triplets.remove(triplet)  # Remove matched triplet to avoid re-matching

    # Special tokens and key words
    # '?' stays mapped to None when it does not occur in the prompt.
    special_tokens = {'?': None}
    for key in special_tokens.keys():
        if key in tokens:
            special_tokens[key] = tokens.index(key)
    matches.update(special_tokens)

    # Similarity matching for remaining triplets
    for triplet in triplets:
        # Candidates are the tokens whose position has not been claimed yet.
        options = [(t, i) for i, t in enumerate(tokens) if i not in matches.values()]
        if not options:
            break
        option_tokens, _ = zip(*options)
        clean_options = [t.replace('##', '') for t in option_tokens]
        # extractOne yields (choice, score, index); it is falsy when nothing is returned.
        best_match = process.extractOne(triplet, clean_options, scorer=fuzz.WRatio)
        if best_match and best_match[1] > 55:
            best_match_token = best_match[0]
            # First still-unclaimed position whose cleaned text equals the winning choice.
            best_match_index = next(i for i, t in enumerate(tokens) if t.replace('##', '') == best_match_token and i not in matches.values())
            matches[triplet] = best_match_index
            
    return matches


def extract_selected_one_score(row):
    """Return column 0 of the scores stored at the correct answer's token position.

    The prompt (``row["clean_prompt"]``) and the answer (``row['correct_answer']``)
    are both tokenized with the global ``processor.tokenizer``; the first position
    where the answer's tokens appear contiguously in the prompt selects an entry
    of ``row['temp_list']["scores"]``.

    Returns:
        ``scores[position][:, 0]``, or ``None`` when the answer does not occur in
        the prompt (a message is printed) or the selected entry is ``None``.
    """
    scores = row['temp_list']["scores"]
    answer = row['correct_answer']
    prompt = row["clean_prompt"]

    tokenizer = processor.tokenizer
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    answer_tokens = tokenizer.tokenize(answer)
    width = len(answer_tokens)

    # First start offset at which the answer tokens appear as a contiguous run.
    start = next(
        (
            offset
            for offset in range(len(prompt_tokens) - width + 1)
            if prompt_tokens[offset:offset + width] == answer_tokens
        ),
        None,
    )

    if start is None:
        print("no correct answer found")
        return None

    picked = scores[start]
    if picked is None:
        return None
    # 0 index for restoration probability, 2 for logit difference.
    return picked[:, 0]


def get_coordinates(flattened_indices, num_rows=12, num_cols=12):
    """Convert flat indices into ``(row, col)`` pairs of a grid with ``num_cols`` columns.

    Args:
        flattened_indices: Iterable of integer indices into the flattened grid.
        num_rows: Accepted for symmetry but not used in the computation.
        num_cols: Number of columns per row; the divisor for the conversion.

    Returns:
        list of ``(index // num_cols, index % num_cols)`` tuples, in input order.
    """
    # divmod(i, n) is exactly (i // n, i % n).
    return [divmod(flat_index, num_cols) for flat_index in flattened_indices]


# 2xstd logit difference
def get_important_heads(mrr, threshold):
    """Return grid coordinates of every entry of ``mrr`` whose magnitude reaches ``threshold``.

    Args:
        mrr: Array of scores; it is flattened before the comparison.
        threshold: Entries with ``abs(value) >= threshold`` are selected.

    Returns:
        The result of ``get_coordinates`` (with its default grid size) applied to
        the flat indices of the selected entries.
    """
    print(mrr.shape)
    magnitudes = np.abs(mrr.flatten())
    # Flat positions of all entries at or above the threshold, in ascending order.
    selected = np.flatnonzero(magnitudes >= threshold)
    return get_coordinates(selected)

def get_circuits(scores):
    """Return the set of grid coordinates whose mean absolute score stands out.

    ``scores`` is reshaped to ``(num_samples, -1)`` and the absolute values are
    averaged over the sample axis. An entry is kept when that average is at
    least the mean of all averages plus two of their standard deviations.

    Args:
        scores: Array whose first axis indexes samples.

    Returns:
        set of coordinates as produced by ``get_important_heads``.
    """
    sample_count = scores.shape[0]
    per_sample_flat = np.reshape(scores, (sample_count, -1))
    mean_magnitude = np.mean(np.abs(per_sample_flat), axis=0)

    # Cut-off: grand mean plus two standard deviations of the per-entry means.
    cutoff = np.mean(mean_magnitude) + 2 * np.std(mean_magnitude)
    return set(get_important_heads(mean_magnitude, cutoff))

In [None]:
# Per-head score arrays, appended in head order (one entry per attention head).
all_scores = []

# Load the DataFrame once; its length is also part of the pickle file names below.
# NOTE(review): `df_filepath` is not defined in the visible cells - presumably it comes
# from the `plotting_utils` star import or an earlier session; confirm before a
# Restart & Run All.
df_correct = pd.read_csv(df_filepath)

task = 'svo_probes'
block_name = 'text_encoder'  # fixed: the `=` was missing, which made the whole cell a SyntaxError
kind = 'crossattention_block'
mode = 'image'

# Loop over 12 attention heads
for attn_head in range(12):
    # NOTE(security): pickle.load can execute arbitrary code - only load files you produced yourself.
    with open(f'BLIP_temp_list_{task}_{mode}_corruption_{block_name}_{kind}_head_{attn_head}_{len(df_correct)}.pkl', 'rb') as file:
        loaded_temp_list = pickle.load(file)

    # The column is overwritten on every iteration with the current head's data.
    df_correct["temp_list"] = loaded_temp_list
    # NOTE(review): `extract_selected_one_score_emotions` is not defined in this notebook
    # (only `extract_selected_one_score` is) - presumably it is provided by the
    # `plotting_utils` star import; confirm that it is the intended extractor for this task.
    df_correct['selected_scores'] = df_correct.apply(extract_selected_one_score_emotions, axis=1)

    selected_scores = np.array(df_correct['selected_scores'].tolist())
    all_scores.append(selected_scores)

# Stack the per-head arrays along a new axis 1, replace NaN/inf entries with 0,
# then swap the last two axes so that the head axis comes last.
all_scores = np.stack(all_scores, axis=1)
all_scores = np.nan_to_num(all_scores, nan=0.0, posinf=0.0, neginf=0.0)
all_scores = all_scores.transpose(0, 2, 1)

In [None]:
# Show the coordinates whose mean absolute score is at least mean + 2 * std (see get_circuits).
get_circuits(all_scores)

# VISUALIZING CROSS ATTENTION FOR IMPORTANT HEADS

In [None]:
import pickle
import torch
from rapidfuzz import process, fuzz
import pandas as pd
from transformers import BlipProcessor
from models_patching import ModifiedBlipForQuestionAnswering
from plotting_utils import *
from PIL import Image
from io import BytesIO
from scipy.ndimage import zoom
import seaborn as sns
from torchvision.transforms.functional import to_pil_image
import io 

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The model and the processor are loaded from the same pretrained checkpoint.
model = ModifiedBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
model.to(device)

processor = BlipProcessor.from_pretrained('Salesforce/blip-vqa-base')

In [None]:
from collections import defaultdict
def get_avg_attention(outputs, toi_index):
    """Average one token's cross-attention row over 12 layers x 12 heads.

    For every (layer, head) pair the row
    ``outputs['cross_attentions'][layer][0][head][toi_index]`` is read and its
    first element is dropped before averaging.

    Args:
        outputs: Mapping with a ``'cross_attentions'`` entry indexable as
            ``[layer][batch][head][token]``; only batch item 0 is used.
        toi_index: Index of the token whose attention row is read.

    Returns:
        tuple ``(mean_row, cls_attention)``: the element-wise mean of the 144
        trimmed rows, and a ``defaultdict(int)`` keyed ``"<layer>_<head>"``
        holding the first element of each trimmed row.
    """
    trimmed_rows = []
    cls_attention = defaultdict(int)
    for layer, per_layer in enumerate(outputs['cross_attentions'][:12]):
        batch_item = per_layer[0]
        for head in range(12):
            row = batch_item[head][toi_index].detach().cpu().numpy()
            trimmed = row[1:]
            # NOTE(review): this is element 0 *after* the [1:] slice, i.e. original
            # element 1 - confirm that this is what "cls attention" should hold.
            cls_attention[f"{layer}_{head}"] = trimmed[0]
            trimmed_rows.append(trimmed)

    return np.mean(np.vstack(trimmed_rows), axis=0), cls_attention

# Function to create a hook that collects the states
def create_hook(layer_outputs, layer_num):
    """Build a forward hook that stores ``output[1]`` under ``layer_num``.

    Args:
        layer_outputs: Dict the hook writes into; it is mutated in place.
        layer_num: Key under which the captured value is stored.

    Returns:
        A callable ``(module, input, output)`` that returns ``None``.
    """
    def hook(_module, _inputs, output):
        # Keep only the second element of the hooked module's output.
        layer_outputs.update({layer_num: output[1]})

    return hook

# Function to register hooks to the specified layers
def register_hooks(model):
    """Attach a capturing forward hook to the cross-attention module of layers 0-11.

    The hooked modules are looked up by name as
    ``text_encoder.encoder.layer.<n>.crossattention.self``.

    Args:
        model: Module exposing ``named_modules()`` that contains those names.

    Returns:
        dict that the hooks fill with ``{layer_number: output[1]}`` on every
        forward pass of ``model``. The hook handles are not kept, so calling
        this function again on the same model adds a second set of hooks.

    Raises:
        KeyError: if one of the expected module names is missing from ``model``.
    """
    layer_outputs = {}
    # Build the name -> module table once instead of once per layer.
    modules_by_name = dict(model.named_modules())
    for num in range(12):  # layers 0 through 11
        layer_name = f"text_encoder.encoder.layer.{num}.crossattention.self"
        modules_by_name[layer_name].register_forward_hook(create_hook(layer_outputs, num))
    return layer_outputs

# `layer_outputs` is refilled by the hooks on every forward pass of `model`.
# Re-running this cell registers another set of hooks on the same model.
layer_outputs = register_hooks(model)

In [None]:
def visualize_avg_crossattention(outputs, pixel_values, avg_attention, layer, head, toi_index):
    """Overlay the averaged cross-attention map on the input image and show the figure.

    Args:
        outputs: Not used by this function.
        pixel_values: Image tensor; it is squeezed, permuted with ``(1, 2, 0)``
            and moved to the CPU before plotting (presumably (C, H, W) in,
            (H, W, C) out - confirm against the processor output).
        avg_attention: Array of attention values that is laid out on a
            24 x 24 grid.
        layer: Not used by this function.
        head: Not used by this function.
        toi_index: Not used by this function.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    side = 24  # the attention values are arranged on a side x side grid

    # Image: reorder the axes, then min-max rescale to 0..255 and cast to uint8.
    image = pixel_values.squeeze().permute(1, 2, 0).cpu().numpy()
    darkest, brightest = image.min(), image.max()
    image = ((image - darkest) / (brightest - darkest) * 255).astype(np.uint8)

    # Attention: column vector -> resampled to side*side entries -> square grid.
    column = avg_attention.reshape(-1, 1)
    stretch = (side * side) / len(column)
    grid = zoom(column, (stretch, 1)).reshape(side, side)

    # Min-max normalize the grid, then upsample it to the image resolution.
    grid_floor = np.min(grid)
    grid = (grid - grid_floor) / (np.max(grid) - grid_floor)
    overlay = zoom(grid, (image.shape[0] / side, image.shape[1] / side))

    # Base image first, semi-transparent heatmap on top.
    _, ax = plt.subplots(figsize=(8, 8), dpi=200)
    ax.imshow(image, aspect='auto')
    sns.heatmap(overlay, cmap='seismic', alpha=0.5, ax=ax, zorder=2, cbar=True)

    # No tick marks for a cleaner look.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('Avg. Cross-Attention', fontsize=28)

    plt.show()


In [None]:
# Loop through different attention matrices
def visualize_crossattention(outputs, pixel_values, avg_attention, layer, head, toi_index):
    """Overlay one head's cross-attention, minus the average, on the input image.

    The plotted values are
    ``outputs[layer][0][head][toi_index][1:] - avg_attention``.

    Args:
        outputs: Indexable as ``[layer][batch][head][token]``; only batch
            item 0 is used.
        pixel_values: Image tensor; it is squeezed, permuted with ``(1, 2, 0)``
            and moved to the CPU before plotting (presumably (C, H, W) in,
            (H, W, C) out - confirm against the processor output).
        avg_attention: Array subtracted from the selected attention row after
            the row's first element has been dropped.
        layer: Layer key into ``outputs``; also shown in the title.
        head: Head index; also shown in the title.
        toi_index: Index of the token whose attention row is plotted.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    side = 24  # the attention values are arranged on a side x side grid

    # Selected row without its first element, relative to the average.
    head_row = outputs[layer][0][head][toi_index].detach().cpu().numpy()
    relative_attention = head_row[1:] - avg_attention

    # Image: reorder the axes, then min-max rescale to 0..255 and cast to uint8.
    image = pixel_values.squeeze().permute(1, 2, 0).cpu().numpy()
    darkest, brightest = image.min(), image.max()
    image = ((image - darkest) / (brightest - darkest) * 255).astype(np.uint8)

    # Attention: column vector -> resampled to side*side entries -> square grid.
    column = relative_attention.reshape(-1, 1)
    stretch = (side * side) / len(column)
    grid = zoom(column, (stretch, 1)).reshape(side, side)

    # Min-max normalize the grid, then upsample it to the image resolution.
    grid_floor = np.min(grid)
    grid = (grid - grid_floor) / (np.max(grid) - grid_floor)
    overlay = zoom(grid, (image.shape[0] / side, image.shape[1] / side))

    # Base image first, semi-transparent heatmap on top.
    _, ax = plt.subplots(figsize=(8, 8), dpi=200)
    ax.imshow(image, aspect='auto')
    sns.heatmap(overlay, cmap='seismic', alpha=0.5, ax=ax, zorder=2, cbar=True)

    # No tick marks for a cleaner look.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f'L.{layer} H.{head}', fontsize=34)

    plt.show()


### UNIVERSAL HEADS

In [None]:
# (layer, head) pairs; plot_sample unpacks each entry in that order.
universal_heads = {(5,3),(3,0),(0,11)}

### Plot the attention pattern over image patches

In [None]:
# NOTE: If looking at MIT STATES, use toi_index=3 to get the object. Otherwise it will find the correct adj.
def get_toi_index(row):
    """Locate the correct answer's tokens inside the tokenized prompt.

    Both ``row["clean_prompt"]`` and ``row['correct_answer']`` are tokenized
    with the global ``processor.tokenizer``.

    Returns:
        The start index of the first contiguous occurrence of the answer tokens
        in the prompt tokens (the matched tokens are printed), or - when there
        is no occurrence - a ``"NO MATCH FOUND ..."`` message string.
    """
    answer = row['correct_answer']
    prompt = row["clean_prompt"]

    tokenizer = processor.tokenizer
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    answer_tokens = tokenizer.tokenize(answer)
    width = len(answer_tokens)

    # Slide a window of the answer's length over the prompt tokens.
    start = 0
    while start + width <= len(prompt_tokens):
        window = prompt_tokens[start:start + width]
        if window == answer_tokens:
            print("TOI", window)
            return start
        start += 1
    return f"NO MATCH FOUND for {answer} in {prompt}"

def plot_sample(df_sample):
    """Plot cross-attention overlays for every row of ``df_sample``.

    For each row the image at ``row["correct_image_path"]`` and the prompt
    ``row['clean_prompt']`` are run through the global ``model``. One overlay
    is then drawn per entry of ``universal_heads`` (from the hook-captured
    ``layer_outputs``), followed by one overlay of the average attention.

    Rows for which ``get_toi_index`` finds no match are skipped: its message is
    printed and nothing is plotted for that row.

    Args:
        df_sample: DataFrame with the columns ``correct_image_path``,
            ``clean_prompt`` and ``correct_answer``.

    Returns:
        None.
    """
    for _, row in df_sample.iterrows():
        # Resolve the token of interest first so unmatched rows skip the forward pass.
        toi_index = get_toi_index(row)
        if isinstance(toi_index, str):
            # get_toi_index signals "no match" with a message string, which cannot be used as an index.
            print(toi_index)
            continue

        # Close the image file as soon as the processor has consumed it.
        with Image.open(row["correct_image_path"]) as image:
            inputs = processor(images=image, text=row['clean_prompt'], return_tensors="pt").to(device)

        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)

        avg_attention, cls_attention = get_avg_attention(outputs, toi_index)

        # Sorted so the figures appear in a reproducible (layer, head) order.
        for layer, head in sorted(universal_heads):
            visualize_crossattention(layer_outputs, inputs.pixel_values, avg_attention, layer, head, toi_index)

        # The average plot ignores layer/head, so None is passed instead of leaked loop variables.
        visualize_avg_crossattention(layer_outputs, inputs.pixel_values, avg_attention, None, None, toi_index)