In [None]:
# load the pre-trained VGG16 and specify which layer we'll extract feature activations from
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torchvision.transforms.functional as TF
from IPython.display import display, Image
from tqdm import tqdm
import os
import pandas as pd
from PIL import Image
import gc
from torchvision.transforms import ToTensor, ToPILImage
import matplotlib.pyplot as plt

# Pretrained VGG16 (ImageNet weights), switched to inference mode.
vgg16 = models.vgg16(pretrained=True)
vgg16.eval()

# Compute device: the CUDA GPU when one is present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Index into vgg16.features of the layer to study; 0 is conv1_1 (the first convolution).
target_layer = 0  # conv1_1

# Feature extractor that stops at the target layer: vgg16.features[0 .. target_layer].
vgg16_truncated = vgg16.features[: target_layer + 1].to(device)

# Preprocessing that turns each TinyImageNet validation image into a VGG16-ready tensor;
# the dataset class below applies it to every image it loads:
#   1. resize to the 224x224 input resolution used by VGG16,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. standardise each channel with the ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset over every ".JPEG" file in one flat directory (e.g. TinyImageNet val/images).
class TinyImageNetValDataset(Dataset):
    """Flat-folder image dataset yielding ``(image, path)`` pairs.

    The samples are the ``*.JPEG`` entries of ``image_dir``, ordered by their full
    path. Each item is loaded with torchvision's ``default_loader``, passed through
    ``transform`` when one is set, and returned together with its file path so that
    results can be traced back to the source image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        jpeg_names = [name for name in os.listdir(image_dir) if name.endswith(".JPEG")]
        self.image_paths = sorted(os.path.join(image_dir, name) for name in jpeg_names)
        self.transform = transform
        self.loader = default_loader

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        sample = self.loader(path)
        # The path is returned alongside the image for identification.
        return (self.transform(sample) if self.transform else sample), path

# Folder holding the validation images (a flat directory of .JPEG files); adjust per machine.
image_dir = "/Users/charlotteimbert/Documents/SP2025/NEUR189B/tiny-imagenet-200/tiny-tiny/images"

# Dataset over that folder, plus a loader that keeps file order (no shuffling).
dataset = TinyImageNetValDataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset, shuffle=False, batch_size=16)

# Prepare to capture one layer's outputs with a forward hook.
# Reuse the pretrained VGG16 built earlier in this cell (already in eval mode) instead of
# constructing it and loading its weights a second time.
# NOTE(review): vgg16_truncated shares its layer modules with vgg16, so a hook placed on
# one of those layers also fires when vgg16_truncated is run.
vgg16 = vgg16.to(device)

# Name of the module to hook, as listed by vgg16.named_modules();
# 'features.0' is conv1_1, the first convolutional layer.
layer_name = 'features.0'
# Filled by the hook: layer name -> that layer's output from the most recent forward pass.
activations = {}

def get_activation(name):
    """Build a forward hook that records a module's output under ``name``.

    The returned callable matches PyTorch's ``(module, input, output)`` forward-hook
    signature. It stores ``output.detach()`` in the module-level ``activations`` dict,
    overwriting any earlier entry for ``name``.
    """
    def _store_output(module, inputs, output):
        activations[name] = output.detach()

    return _store_output

# Hook the chosen layer: after every forward pass of vgg16, activations[layer_name]
# holds that layer's (detached) output.
layer = dict(vgg16.named_modules())[layer_name]
layer.register_forward_hook(get_activation(layer_name))

# Run every validation image through the network and keep, per image, the hooked
# layer's activation map (moved to CPU) together with the image's file path.
activation_list = []  # one [C, H, W] tensor per image
image_paths = []      # file path of the matching activation_list entry

# Fixed-order batches, so activation_list and image_paths follow dataset order.
loader = DataLoader(dataset, shuffle=False, batch_size=16)

with torch.no_grad():
    for images, paths in tqdm(loader):
        images = images.to(device)
        _ = vgg16(images)  # network output unused; the hook captures what we need
        batch_acts = activations[layer_name].cpu()  # [B, C, H, W]
        activation_list += list(batch_acts)
        image_paths += list(paths)
        
def _normalize_stats(tf):
    """Return ``(mean, std)`` as ``[C, 1, 1]`` tensors from the first ``transforms.Normalize``
    step of *tf* (a single transform or a ``transforms.Compose``), or ``None`` if it has none."""
    for step in getattr(tf, "transforms", [tf]):
        if isinstance(step, transforms.Normalize):
            mean = torch.as_tensor(step.mean, dtype=torch.float32).view(-1, 1, 1)
            std = torch.as_tensor(step.std, dtype=torch.float32).view(-1, 1, 1)
            return mean, std
    return None


def get_top_patches(activation_list, image_paths, dataset, top_k=100):
    """Crop, for every channel ("neuron"), the image regions behind its strongest responses.

    For each channel, images are ranked by that channel's spatial maximum activation
    (descending; equal scores keep dataset order), and the ``top_k`` best are kept. Each
    kept image is reloaded from ``image_paths`` and passed through the dataset's transform.
    If that transform contains a ``Normalize`` step, the normalization is undone, so the
    patches hold ``[0, 1]``-range pixel values that can be averaged and saved as images.
    The patch is the ``scale x scale`` block of the transformed image that lines up with
    the arg-max position of the activation map. Here ``scale`` is the input size divided
    by the activation-map size.

    Args:
        activation_list: per-image activation tensors ``[C, H, W]``, as collected by the
            extraction loop.
        image_paths: file paths aligned index-for-index with ``activation_list``.
        dataset: dataset whose ``transform`` produced the model inputs. The module-level
            ``transform`` is used when the dataset has none.
        top_k: maximum number of images (patches) per neuron.

    Returns:
        dict mapping channel index -> list of patch tensors, ordered from highest to
        lowest activation. It is empty when ``activation_list`` is empty.

    NOTE(review): the crop covers only ``scale`` pixels (one pixel for conv1_1), not the
    layer's receptive field.
    """
    if len(activation_list) == 0:
        return {}

    # Use the same preprocessing as the dataset; fall back to the module-level transform.
    tf = getattr(dataset, "transform", None) or transform
    stats = _normalize_stats(tf)

    # Spatial maximum of every channel in every image, computed once: [C][N] Python floats.
    scores_by_neuron = torch.stack(
        [act.flatten(1).amax(dim=1) for act in activation_list]
    ).t().tolist()

    neuron_patches = {}  # dict: neuron_idx -> list of patches
    for neuron_idx, scores in enumerate(scores_by_neuron):
        # Stable sort with reverse=True keeps ties in dataset order.
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:top_k]

        patches = []
        for img_idx in ranked:
            # Close the file handle as soon as the image has been decoded.
            with Image.open(image_paths[img_idx]) as raw:
                img = tf(raw.convert('RGB'))
            if stats is not None:
                mean, std = stats
                img = img * std + mean  # undo Normalize -> displayable pixel values

            act_map = activation_list[img_idx][neuron_idx]
            act_h, act_w = act_map.shape[-2:]
            img_h, img_w = img.shape[-2:]
            scale_y = img_h / act_h
            scale_x = img_w / act_w

            # Row/column of the (first) maximum in the activation map.
            y, x = divmod(int(torch.argmax(act_map)), act_w)

            # Map to pixel coordinates of the transformed image.
            y1 = int(y * scale_y)
            x1 = int(x * scale_x)
            y2 = y1 + int(scale_y)
            x2 = x1 + int(scale_x)

            # Clip each axis against its own image dimension, then crop.
            y1, y2 = (max(0, min(v, img_h)) for v in (y1, y2))
            x1, x2 = (max(0, min(v, img_w)) for v in (x1, x2))
            patches.append(TF.crop(img, y1, x1, y2 - y1, x2 - x1))

        neuron_patches[neuron_idx] = patches

    return neuron_patches
# Top-100 patches per neuron of the hooked layer.
neuron_patches = get_top_patches(activation_list, image_paths, dataset, top_k=100)

def compute_neuron_features(neuron_patches, activation_list, top_k=100):
    """Build one "neuron feature" per neuron as an activation-weighted mean of its patches.

    Each neuron's patches are expected in the order produced by ``get_top_patches``:
    images ranked by decreasing spatial-max activation of that neuron over
    ``activation_list``, with ties in dataset order. The same ranking is recomputed here,
    so every patch is weighted by the activation of the image it was cropped from.

    Args:
        neuron_patches: dict neuron_idx -> list of patches (tensors, or images convertible
            with ``ToTensor``). All patches of one neuron must share a shape.
        activation_list: per-image activation tensors ``[C, ...]`` used for the ranking.
        top_k: at most this many leading patches per neuron are averaged.

    Returns:
        dict neuron_idx -> weighted-mean patch tensor. Neurons with no patches are skipped.
        If a neuron's weights are all non-positive (or unavailable), its patches are
        averaged uniformly.
    """
    neuron_features = {}
    if not neuron_patches:
        return neuron_features

    # Spatial maximum of every channel in every image, computed once: [C][N] floats.
    scores_by_neuron = (
        torch.stack([act.flatten(1).amax(dim=1) for act in activation_list]).t().tolist()
        if len(activation_list) > 0
        else None
    )

    for neuron_idx, patches in neuron_patches.items():
        count = min(len(patches), top_k)
        if count <= 0:
            continue

        # Convert patches to tensors and stack
        patch_tensors = torch.stack([
            p if isinstance(p, torch.Tensor) else ToTensor()(p)
            for p in patches[:count]
        ])

        # Activation of the image each patch came from (same ranking as the patches).
        weights = torch.zeros(count, dtype=patch_tensors.dtype)
        if scores_by_neuron is not None:
            scores = scores_by_neuron[neuron_idx]
            ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:count]
            weights[: len(ranked)] = torch.tensor(
                [scores[i] for i in ranked], dtype=patch_tensors.dtype
            )

        # Pre-ReLU layers can yield negative maxima; a non-firing image should not
        # pull the average, and an all-zero weight vector would otherwise give NaN.
        weights = weights.clamp(min=0)
        total = weights.sum()
        if total > 0:
            weights = weights / total
        else:
            weights = torch.full_like(weights, 1.0 / count)

        # Weighted average over the patch axis (broadcast weights over the patch dims).
        shape = (-1,) + (1,) * (patch_tensors.dim() - 1)
        neuron_features[neuron_idx] = (patch_tensors * weights.view(shape)).sum(dim=0)

    return neuron_features

# Activation-weighted average patch ("neuron feature") for every neuron.
neuron_features = compute_neuron_features(neuron_patches, activation_list, top_k=100)

def save_neuron_features(neuron_features, out_dir="neuron_features", max_to_save=None):
    """Write each neuron feature to ``<out_dir>/neuron_<idx>.png``.

    ``out_dir`` is created if needed. The tensors are moved to the CPU and clamped to
    ``[0, 1]`` before being converted to PIL images. When ``max_to_save`` is given, only
    the first ``max_to_save`` entries (in dict order) are written.
    """
    os.makedirs(out_dir, exist_ok=True)
    saved = 0
    for neuron_idx, nf_tensor in neuron_features.items():
        if max_to_save is not None and saved >= max_to_save:
            break
        image = ToPILImage()(nf_tensor.cpu().clamp(0, 1))
        image.save(os.path.join(out_dir, f"neuron_{neuron_idx}.png"))
        saved += 1
        # Drop the image reference and collect eagerly to keep peak memory low.
        del image
        gc.collect()

# Write up to 100 neuron-feature PNGs into ./neuron_features.
save_neuron_features(neuron_features, out_dir="neuron_features", max_to_save=100)

# NOTE(review): rebinds image_dir (previously the dataset folder) to the NF output folder.
image_dir = "neuron_features"

def create_neuron_grid(neuron_features, grid_size=(10, 10), out_path="neuron_grid.png"):
    """Save a ``rows x cols`` grid of neuron-feature images to ``out_path``.

    Neuron features are drawn in dict order, each titled with its neuron index. Features
    beyond the grid capacity are ignored, and unused cells are left blank. Tensors are
    moved to the CPU and clamped to ``[0, 1]`` for display.
    """
    rows, cols = grid_size
    # squeeze=False keeps `axes` a 2-D array even for 1x1 / 1xN grids, so flattening
    # always works (plain plt.subplots returns a bare Axes for a 1x1 grid).
    fig, axes = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
    axes = axes.flatten()
    try:
        # zip stops at whichever runs out first: grid cells or neuron features.
        for ax, (neuron_idx, nf_tensor) in zip(axes, neuron_features.items()):
            img = ToPILImage()(nf_tensor.cpu().clamp(0, 1))
            ax.imshow(img)
            ax.set_title(f"Neuron {neuron_idx}", fontsize=8)
        # Hide ticks/frames everywhere, so unused cells render as blank space.
        for ax in axes:
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved neuron grid to {out_path}")
# 10x10 overview of the conv1_1 neuron features.
create_neuron_grid(neuron_features, out_path="neuron_grid_original_conv1_1.png", grid_size=(10, 10))

# ---------- Helper Functions from Nefesi ----------

def rgb2opp(img_np):
    """Map an ``[..., >=3]`` RGB array to three opponent-colour channels.

    Only the first three channels are read. The outputs (stacked on the last axis) are
    ``(R - G)/sqrt(2)``, ``(R + G - 2B)/sqrt(6)`` and the intensity ``(R + G + B)/sqrt(3)``.
    """
    red, green, blue = (img_np[..., ch] for ch in range(3))
    channels = (
        (red - green) / np.sqrt(2),
        (red + green - 2 * blue) / np.sqrt(6),
        (red + green + blue) / np.sqrt(3),  # intensity channel
    )
    return np.stack(channels, axis=-1)

def image2max_gray(opp_img):
    """Reduce the last (channel) axis to the largest absolute value per pixel."""
    magnitudes = np.abs(opp_img)
    return magnitudes.max(axis=-1)

# Per-image peak "max gray" value, keyed by file path. It does not depend on the neuron,
# so it is computed once per image instead of once per (neuron, image) pair.
# NOTE(review): entries go stale if an image file changes on disk; clear the dict then.
_GRAY_MAX_CACHE = {}


def _max_gray_value(path):
    """Return the peak opponent-space max-gray value of the image at *path*.

    The image is converted to RGB, resized to 224x224 and scaled to [0, 1]. It is then
    mapped through ``rgb2opp`` and ``image2max_gray``, and the result is cached per path.
    """
    cached = _GRAY_MAX_CACHE.get(path)
    if cached is None:
        with Image.open(path) as img:  # close the file once decoded
            rgb = img.convert('RGB').resize((224, 224))
        img_np = np.asarray(rgb, dtype=np.float32) / 255.0
        cached = float(np.max(image2max_gray(rgb2opp(img_np))))
        _GRAY_MAX_CACHE[path] = cached
    return cached


# ---------- Main CSI Function ----------

def compute_csi_for_neuron(neuron_idx, activation_list, image_paths, transform):
    """Return a colour-selectivity-style index for channel ``neuron_idx``.

    For every image ``i``, ``rgb_i`` is the spatial maximum of the neuron's activation
    map, and ``gray_i`` is the peak opponent-space max-gray value of ``image_paths[i]``.
    The index is ``mean_i(1 - clip(gray_i / rgb_i, 0, 1))``. Images where the neuron
    does not respond positively (``rgb_i <= 0``) contribute 0 instead of producing
    inf/NaN. An empty ``activation_list`` gives 0.0.

    ``transform`` is accepted for interface compatibility but is not used.

    NOTE(review): ``gray_i`` is an image statistic, not the network's response to a
    grayscale version of the image, so it is independent of the neuron. This differs
    from the CSI definition that compares a neuron's response to colour vs grayscale
    inputs — confirm this approximation is intended.
    """
    if len(activation_list) == 0:
        return 0.0

    rgb = np.array([act[neuron_idx].max().item() for act in activation_list], dtype=np.float64)
    gray = np.array(
        [_max_gray_value(image_paths[i]) for i in range(len(activation_list))],
        dtype=np.float64,
    )

    # gray / rgb only where the neuron fires; elsewhere use ratio 1 (no selectivity),
    # which avoids the divide-by-zero / NaN of the plain element-wise division.
    ratio = np.ones_like(rgb)
    responsive = rgb > 0
    ratio[responsive] = gray[responsive] / rgb[responsive]

    return float(np.mean(1.0 - np.clip(ratio, 0.0, 1.0)))

# Score up to the first 100 neurons and write the results to CSV.
csi_values = []
num_neurons = activation_list[0].shape[0]
num_to_score = min(100, num_neurons)

# Report the real count: conv1_1 has fewer than 100 channels.
print(f"Computing CSI for first {num_to_score} neurons...")
for neuron_idx in tqdm(range(num_to_score)):
    csi = compute_csi_for_neuron(neuron_idx, activation_list, image_paths, transform)
    csi_values.append((neuron_idx, csi))

df_csi = pd.DataFrame(csi_values, columns=["neuron_idx", "CSI"])
df_csi.to_csv("color_selectivity_indices.csv", index=False)


100%|███████████████████████████████████████████| 63/63 [01:20<00:00,  1.28s/it]


Saved neuron grid to neuron_grid_original_conv1_1.png
Computing CSI for first 100 neurons...


  csi = np.mean(1 - np.clip(norm_gray / norm_rgb, 0, 1))
 30%|████████████▊                              | 19/64 [00:25<01:01,  1.36s/it]