In [None]:
# ---- Paths (fill in before running) ----
PARQUET_INPUT = r""      # activation table, read into `df`
ACTIVATIONS_INPUT = r""  # loaded with torch.load in the ranking cell
OUTPUT_FOLDER = r""      # destination for the generated PDFs (created if missing)

# ---- Optional ----
HUGGINGFACE_CACHE_DIR = None  # cache_dir for datasets.load_dataset; None keeps the library default

# ---- Column names in the parquet file ----
PATCH_IDX, FEATURE_IDX, IMAGE_IDX = "patchIdx", "featureIdx", "imageIdx"
ACTIVATION_VALUE = "activationValue"
LABEL, TYPE, LAYER_IDX = "label", "type", "layerIdx"

In [None]:
# imports 
import pandas as pd
import torchvision 
import matplotlib.pyplot as plt
from sae.basic_vision_api_call import UserMessage, ImageChatHistory, call_model

from torch.utils.data import Dataset
from datasets import load_dataset
import numpy as np
import os 
# PDFs are written into OUTPUT_FOLDER at the end of the notebook. os.makedirs("")
# raises FileNotFoundError, so the path must be filled in in the config cell.
os.makedirs(name=OUTPUT_FOLDER, exist_ok=True)

# Credentials: OPENAI_API_KEY and OPENAI_ORGANIZATION_ID are expected in a .env file
# (override=True lets .env values replace ones already in the environment). They can
# also be set in the environment manually or passed into call_model.
from dotenv import load_dotenv
load_dotenv(override=True)
#########

In [None]:

# Activation table that the overlay and feature-listing cells query.
df = pd.read_parquet(path=PARQUET_INPUT)

print(df.head(n=10))


In [None]:
# Segmented ImageNet-1k subset; its 'train' split is wrapped by PatchDataset below.
default_dataset = load_dataset(
    path='Prisma-Multimodal/segmented-imagenet1k-subset',
    cache_dir=HUGGINGFACE_CACHE_DIR,
)

In [None]:
class PatchDataset(Dataset):
    """Wrap a segmentation dataset, yielding resized images and per-patch label lists.

    Each item of ``dataset`` is indexed like a dict with an ``'image'`` entry and,
    when ``return_label`` is True, parallel ``'masks'`` and ``'labels'`` lists
    (labels assumed to be aligned one-to-one with masks).
    """

    def __init__(self, dataset, patch_size=32, width=224, height=224, return_label=True):
        """
        dataset: A list of dictionaries, each dictionary corresponds to an image and its details
        patch_size: side length, in pixels, of the square patches masks are reduced to
        width, height: size every image and mask is resized to before patching
        return_label: if False, __getitem__ returns only (image, idx)
        """
        self.dataset = dataset
        self.patch_size = patch_size
        self.width = width
        self.height = height
        self.return_label = return_label
        self.transform = torchvision.transforms.Compose([
            # torchvision's Resize takes (height, width). This was previously hard-coded
            # to (224, 224), silently ignoring the width/height arguments.
            torchvision.transforms.Resize((height, width)),
            torchvision.transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            torchvision.transforms.ToTensor(),
            # CHW -> HWC so the result can be shown / multiplied as an image array directly.
            torchvision.transforms.Lambda(lambda img: img.permute(1, 2, 0)),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label_array, idx)``, or ``(image, idx)`` when labels are disabled.

        ``image`` is the transformed (height, width, 3) tensor. ``label_array[r][c]``
        lists the label of every mask that touches patch (r, c). It has
        ``height // patch_size`` rows and ``width // patch_size`` columns (the old code
        used ``width // patch_size`` for both, which broke non-square sizes).
        """
        item = self.dataset[idx]
        image = self.transform(item['image'])
        if not self.return_label:
            return image, idx

        rows = self.height // self.patch_size
        cols = self.width // self.patch_size
        label_array = [[[] for _ in range(cols)] for _ in range(rows)]

        for mask, label in zip(item['masks'], item['labels']):  # labels assumed aligned with masks
            # PIL's resize takes (width, height), so the array is (height, width[, channels]).
            mask_array = np.array(mask.resize((self.width, self.height))) > 0
            reduced_mask = self.reduce_mask(mask_array)
            for r, c in zip(*np.nonzero(reduced_mask)):
                label_array[r][c].append(label)

        # label_array stays a plain nested list (ragged per cell), usable directly in Python.
        return image, label_array, idx

    def reduce_mask(self, mask):
        """
        Reduce the mask size by dividing it into patches and checking if there's at least
        one True value within each patch.

        mask: array of shape (H, W) or (H, W, ...); any extra trailing axes are reduced
            too. Rows/columns that do not fill a whole patch are ignored.
        Returns a bool array of shape (H // patch_size, W // patch_size).
        """
        p = self.patch_size
        new_h = mask.shape[0] // p
        new_w = mask.shape[1] // p
        trailing = int(np.prod(mask.shape[2:]))  # 1 for a 2-D mask
        # Crop to whole patches, view as (row-patch, p, col-patch, p, rest) and reduce
        # every axis except the two patch-grid axes.
        patches = np.asarray(mask)[:new_h * p, :new_w * p].reshape(new_h, p, new_w, p, trailing)
        return patches.any(axis=(1, 3, 4))



# Image-only view of the train split (the overlay and ranking helpers unpack (image, idx)).
patch_label_dataset = PatchDataset(dataset=default_dataset['train'], return_label=False)
im, idx = patch_label_dataset[0]  # smoke test: fetch one sample



In [None]:




# for a given feature (and 'type'), find all images and heatmaps associated with it. 
def get_overlays_and_images(feature_id, type_label="top", num_patches_w=7, low=0.25, high=0.75, display=False, return_indices=False):
    """Build heatmap overlays for every image recorded for one feature.

    Reads the module-level ``df`` (activation table) and ``patch_label_dataset``.

    feature_id: value of the FEATURE_IDX column to select.
    type_label: only rows whose TYPE column equals this value are used.
    num_patches_w: the patch grid is num_patches_w x num_patches_w. After sorting by
        PATCH_IDX, each image's first activation is dropped (presumably a CLS token,
        TODO confirm) and the remaining num_patches_w**2 values form the heatmap.
    low, high: after min-max normalisation, values in [0.1, low] are set to 0.1 and
        values >= high are set to 1.
    display: show each overlay with matplotlib.
    return_indices: return only the list of dataset indices that were used.

    Returns (overlays, images): lists of uint8 HxWx3 arrays ordered by image index.
    When return_indices is True, returns the list of image indices instead.
    """
    # Filter on type up front instead of grouping by it and skipping other types.
    selected = df[(df[FEATURE_IDX] == feature_id) & (df[TYPE] == type_label)]

    images = []
    overlays = []
    indices = []
    for image_idx, group in selected.groupby(IMAGE_IDX):
        # Order by patch index so the reshape does not depend on the file's row order;
        # float dtype so the 0.1 clip below is not truncated for integer activations.
        activations = group.sort_values(PATCH_IDX, kind="mergesort")[ACTIVATION_VALUE].to_numpy(dtype=float)
        heatmap = activations[1:].reshape((num_patches_w, num_patches_w))
        image, _ = patch_label_dataset[int(image_idx)]
        indices.append(int(image_idx))
        image = image.detach().cpu().numpy()

        # Min-max normalise. A constant map would divide by zero (NaN overlay), so it
        # is shown at full brightness instead.
        span = heatmap.max() - heatmap.min()
        heatmap = (heatmap - heatmap.min()) / span if span > 0 else np.ones_like(heatmap)

        # Squash the weak-but-visible band [0.1, low] to 0.1 and saturate strong values.
        heatmap[np.logical_and(0.1 <= heatmap, heatmap <= low)] = 0.1
        heatmap[heatmap >= high] = 1

        # Nearest-neighbour upscale of the patch grid to the image resolution
        # (derived from the image instead of a hard-coded 224).
        scale_h = image.shape[0] // num_patches_w
        scale_w = image.shape[1] // num_patches_w
        heatmap = np.repeat(np.repeat(heatmap, scale_h, axis=0), scale_w, axis=1)
        overlay = image * heatmap[..., np.newaxis]  # broadcast over colour channels

        if display:
            plt.imshow(overlay)
            plt.show()
        images.append(np.uint8(255 * image))
        overlays.append(np.uint8(255 * overlay))

    if return_indices:
        return indices
    return overlays, images


# Quick visual check: show the "top" overlays for the feature in the table's first row.
_, _ = get_overlays_and_images(feature_id=df[FEATURE_IDX].iloc[0], display=True)

In [None]:

# some hacky code to generate a pdf 
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.platypus import Image as ReportLabImage
from reportlab.lib.styles import getSampleStyleSheet
from PIL import Image 
import io

from reportlab.platypus import PageBreak



def resize_image(image, max_width, max_height, resample=None):
    """Scale a PIL image to fit inside max_width x max_height, keeping its aspect ratio.

    Both sides use the same factor (the smaller of the two ratios), so an image
    smaller than the box is scaled up as well.

    resample: PIL resampling filter; defaults to LANCZOS. The previous
        ``Image.ANTIALIAS`` was an alias for LANCZOS that Pillow 10 removed, so the
        old code raised AttributeError on current Pillow.
    Returns the resized image.
    """
    if resample is None:
        resample = Image.LANCZOS
    scale = min(max_width / float(image.size[0]), max_height / float(image.size[1]))
    new_size = (int(image.size[0] * scale), int(image.size[1] * scale))
    return image.resize(new_size, resample)
def create_pdf(content_list, file_name='output.pdf'):
    """Render a list of text blocks and images into a letter-size PDF at ``file_name``.

    content_list items:
      * str: the literal "<pagebreak>" inserts a page break. Any other string becomes a
        paragraph; newlines become line breaks, and &, < and > are escaped so model
        output cannot break ReportLab's Paragraph markup parser.
      * numpy.ndarray: an image (gray, RGB or RGBA; cast to uint8). It is shrunk to fit
        the page if it is larger.
    Raises ValueError for any other item type.
    """
    from xml.sax.saxutils import escape

    doc = SimpleDocTemplate(file_name, pagesize=letter)
    body_style = getSampleStyleSheet()['BodyText']
    elements = []

    # 456 x 636 pt appears to be a letter page minus ReportLab's default 1-inch margins
    # and 6 pt frame padding. TODO: derive from doc.width / doc.height.
    max_image_width = 456
    max_image_height = 636
    for item in content_list:
        if isinstance(item, str):
            if item == "<pagebreak>":
                elements.append(PageBreak())
                continue
            # Escape XML-special characters first, then turn newlines into <br/> tags.
            elements.append(Paragraph(escape(item).replace('\n', '<br/>'), body_style))
            elements.append(Spacer(1, 12))  # space after the paragraph

        elif isinstance(item, np.ndarray):
            # Let PIL infer the mode from the array shape (L / RGB / RGBA), then normalise
            # to RGB. Forcing 'RGB' misreads grayscale and 4-channel arrays.
            image = Image.fromarray(item.astype('uint8')).convert('RGB')
            if image.size[0] > max_image_width or image.size[1] > max_image_height:
                image = resize_image(image, max_image_width, max_image_height)
            image_buffer = io.BytesIO()
            image.save(image_buffer, format='PNG')
            image_buffer.seek(0)
            elements.append(ReportLabImage(image_buffer))

        else:
            raise ValueError("Content list must contain only strings and numpy arrays.")

    doc.build(elements)



In [None]:
# Distinct feature ids, in order of first appearance in the activation table.
all_features = pd.unique(df[FEATURE_IDX])

print(len(all_features))
from io import BytesIO

def plot_grid_of_arrays(arrays, show=True):
    """Tile images in a near-square grid and return the rendered figure as an RGB array.

    arrays: sequence of images accepted by ``Axes.imshow`` (2-D arrays use a gray colormap).
    show: also display the figure before it is closed.
    Returns a uint8 (height, width, 3) array of the figure rendered to PNG.
    """
    n = len(arrays)
    # max(1, ...) keeps an empty input from dividing by zero (renders one blank cell).
    cols = max(1, int(np.ceil(np.sqrt(n))))
    rows = max(1, int(np.ceil(n / cols)))
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid. Otherwise plt.subplots returns
    # a bare Axes without `.flat`, and a single image crashed.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
    cells = axes.flat

    for ax, array in zip(cells, arrays):
        ax.imshow(array, cmap='gray')  # cmap only affects 2-D (single-channel) arrays
        ax.axis('off')

    # Turn off any unused axes
    for ax in axes.flat[n:]:
        ax.axis('off')

    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png', transparent=False)
    buf.seek(0)
    image_np = np.array(Image.open(buf).convert('RGB'))
    buf.close()

    if show:
        plt.show()
    plt.close(fig)

    return image_np

In [None]:

# Prompt for the description step: overlays darken less relevant areas, and the model
# must name what the feature responds to.
SYSTEM_MESSAGE = "The following are images of a TinyCLIP neuron's activations. We want to determine what the neuron is responding too. Areas of the image that are LESS relevant are darker. \
      Focus on the visible areas. Notice small details and patterns across the images. Respond in a concise, descriptive phrase exactly what it's selecting for, and \
            try not to use words neuron, activation, or image, or focus/selective/etc, because that will be redundant for our purposes. \
            Keep your words simple but be pretty specific and granular. For example, instead of saying 'face,' note the \
              specific parts of a face, like ears, neck, etc."
num_images = 10  # max overlays shown to the model per feature
num_features = 25 # len(all_features)
pdf_content = []  # alternating [grid image, text] per feature, rendered by create_pdf
display_images = []
descriptions = []
final_feature_ids = []
for feature_id in all_features[0:num_features]:
    top_overlays, top_images = get_overlays_and_images(feature_id)
    # Slice instead of indexing range(num_images): features with fewer recorded images
    # previously raised IndexError here.
    shown_overlays = top_overlays[:num_images]

    chat_history = ImageChatHistory()
    chat_history.add_system_msg(SYSTEM_MESSAGE)

    user_message_with_images = UserMessage()
    user_message_with_images.add_text("Here are the images:")
    for overlay in shown_overlays:
        user_message_with_images.add_img_array(overlay)
    chat_history.add_user_msg(user_message_with_images)

    display_image = plot_grid_of_arrays(shown_overlays)

    description = call_model(chat_history, model="gpt-4o")
    strr = f"----------------\n{feature_id}\n---------------\n=====Model output:::\n{description}\n"
    print(strr)
    display_images.append(display_image)
    descriptions.append(description)
    final_feature_ids.append(feature_id)
    pdf_content.append(display_image)
    pdf_content.append(strr)



In [None]:
# Drop features whose model output is an OpenAI error payload (detected by the
# "'message':" substring) from every per-feature list, keeping them aligned.
# pdf_content holds two entries per feature (grid image at 2k, text at 2k + 1), while
# the other three lists hold one entry per feature (index k). The previous version
# deleted pdf_content positions from all four lists and hard-coded 25 features, which
# removed the wrong entries (or raised IndexError) from the per-feature lists.
failed_features = {
    k for k in range(len(descriptions))
    if "'message':" in pdf_content[2 * k + 1]
}
# Slice-assign so the existing list objects are updated in place, as `del` did before.
pdf_content[:] = [entry for pos, entry in enumerate(pdf_content) if pos // 2 not in failed_features]
display_images[:] = [x for k, x in enumerate(display_images) if k not in failed_features]
descriptions[:] = [x for k, x in enumerate(descriptions) if k not in failed_features]
final_feature_ids[:] = [x for k, x in enumerate(final_feature_ids) if k not in failed_features]
       # pdf_content[i] = strr



In [None]:
# Number of features that have a description after the error-cleanup cell.
print(f"{len(descriptions)}")

In [None]:
#sample


from scipy.stats import pearsonr



def get_image(image_idx):
    """Fetch dataset image ``image_idx`` and return it as a uint8 array (values * 255)."""
    tensor_image, _ = patch_label_dataset[int(image_idx)]
    return np.uint8(tensor_image.detach().cpu().numpy() * 255)

import random
import torch 
import json 
# NOTE(review): torch.load unpickles the file; only point ACTIVATIONS_INPUT at trusted data.
all_acts = torch.load(ACTIVATIONS_INPUT)
pdf_content_ranking = []
arrs = []
texts = []
scores = []
assert len(final_feature_ids) == num_features, "haven't set up to handle case where need to delete things above and I'm being lazy"

NAME_LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # image names; a run can show at most 26 images


def _sample_rank_positions(num_acts):
    """Choose 20 distinct positions in the descending activation order for one run.

    Picks 3 positions from the top 10, then one from each following block of 10 (ranks
    10-19 ... 120-129, so at least 130 activations are needed). Finally it adds 5 more
    drawn uniformly from the positions not yet chosen. Drawing those without
    replacement prevents the same image from being shown twice under two names with
    two different ground-truth ranks.
    """
    positions = torch.randperm(10)[:3].tolist()
    for block in range(1, 13):
        positions.append(10 * block + int(torch.randperm(10)[0]))
    available = torch.ones(num_acts, dtype=torch.bool)
    available[positions] = False
    candidates = available.nonzero().squeeze(1)
    positions += candidates[torch.randperm(candidates.numel())[:5]].tolist()
    return positions


def _parse_ranking(raw_output):
    """Return ``(info, ranking)`` from the model's JSON reply.

    Retries once with a ```json ... ``` markdown fence stripped. Malformed replies
    propagate ValueError / KeyError / TypeError / AttributeError to the caller.
    """
    try:
        parsed = json.loads(raw_output)
    except json.JSONDecodeError:
        # remove chatgpts format ```json\n...\b```
        parsed = json.loads(raw_output.replace("```json", "").replace("```", "").strip())
    return parsed["info"], parsed["ranking"]


for i in range(len(final_feature_ids)):
    # NOTE(review): assumes column i of all_acts belongs to final_feature_ids[i] — confirm.
    acts, act_indices = torch.sort(all_acts[:, i], descending=True)
    p_vals = []
    three_scores = []
    total_score = 0

    for _ in range(3):
        positions = _sample_rank_positions(acts.size(0))
        sample_inds = act_indices[positions].tolist()
        sample_acts = acts[positions].tolist()

        # Highest activation first: list position r is the ground-truth rank r.
        sorted_pairs = sorted(zip(sample_inds, sample_acts), key=lambda x: x[1], reverse=True)
        sample_inds, sample_acts = zip(*sorted_pairs)
        sample_inds = list(sample_inds)
        sample_acts = list(sample_acts)
        print(sample_inds, sample_acts)

        images = [get_image(si) for si in sample_inds]
        arr = plot_grid_of_arrays(images, show=True)

        SYSTEM_MESSAGE = f"The following are {len(images)} images of a TinyCLIP neuron's activations. The neuron was described as activating in response to the following stimulus: {descriptions[i]}.\
        Your job is to rank the images from highest activating to lowest. Each image has a name, A, B, C... Format your answer as a json dict with two keys. Under key 'info',\
            describe your thought process. Under key 'ranking' make a list of the names in order (example ['C', 'B', 'A']). Do not put comments in your json."

        # gt_names[r] is the (shuffled) name given to the image whose true rank is r.
        gt_names = list(NAME_LETTERS[:len(images)])
        gt_ranks = list(range(len(images)))
        random.shuffle(gt_names)

        chat_history = ImageChatHistory()
        user_message_with_images = UserMessage()
        chat_history.add_system_msg(SYSTEM_MESSAGE)
        # Present images in name order A, B, C, ... so the presentation order reveals nothing.
        for name in NAME_LETTERS[:len(images)]:
            user_message_with_images.add_text(name)
            user_message_with_images.add_img_array(images[gt_ranks[gt_names.index(name)]])
        user_message_with_images.add_text("remember to format your answer in json")
        chat_history.add_user_msg(user_message_with_images)
        output = call_model(chat_history, model="gpt-4o")
        print(output)
        try:
            explanation, pred_rank_names = _parse_ranking(output)
            # pred_ranks[r] = position the model gave the image whose true rank is r.
            pred_ranks = [pred_rank_names.index(name) for name in gt_names]
        except (ValueError, KeyError, TypeError, AttributeError):
            # Unparseable reply, missing key, or a name absent from the ranking.
            pred_ranks = None
            explanation = None

        if pred_ranks is None:
            total_score = 0
            three_scores = "ERROR"
            p_vals = "ERROR"
            explanation = "ERROR"
            break
        score = pearsonr([z + 1 for z in gt_ranks], [z + 1 for z in pred_ranks])
        r_val, p_val = score[0], score[1]
        p_vals.append(p_val)
        three_scores.append(r_val)
        total_score = r_val + total_score
    total_score = total_score / 3
    scores.append(total_score)
    print(three_scores)
    arrs.append(arr)

    scores_text = ', '.join(f"{x:.2f}" for x in three_scores) if isinstance(three_scores, list) else "ERROR"
    p_vals_text = ', '.join(f"{x:.4f}" for x in p_vals) if isinstance(p_vals, list) else "ERROR"
    text = (
        f"DESCRIPTION BEING JUDGED: {descriptions[i]}\n"
        f"SCORE (avg over three runs): {total_score:.3f}\n"
        f"(all scores for three runs: {scores_text})\n"
        f"PVALUE (three runs) {p_vals_text}\n"
        f"Model justification (final run):::\n{explanation}\n"
        f"Model predicted order (final run) {pred_ranks}\n"
        f"Images are what was provided to final run")
    texts.append(text)

    print(text)

# Sort on the score alone. Comparing whole tuples falls through to the numpy image
# arrays when score and text tie, which raises.
ranked = sorted(zip(scores, texts, arrs), key=lambda row: row[0], reverse=True)
sscores = [row[0] for row in ranked]
sdescriptions = [row[1] for row in ranked]
sarrs = [row[2] for row in ranked]

for d, a in zip(sdescriptions, sarrs):
    pdf_content_ranking.append(d)
    pdf_content_ranking.append(a)



In [None]:



# One page per ranked feature: [text, grid image, page break].
result = []
for start in range(0, len(pdf_content_ranking), 2):
    result += pdf_content_ranking[start:start + 2]
    result += ['<pagebreak>']

create_pdf(pdf_content, file_name=os.path.join(OUTPUT_FOLDER, "gpt4o_autointerp.pdf"))
create_pdf(result, file_name=os.path.join(OUTPUT_FOLDER, "gpt4o_autointerp_ranking.pdf"))
