In [None]:
#final version

import os
import gc
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, ToPILImage
from torchvision.datasets.folder import default_loader

# ------------------------------- Dataset & Transform -------------------------------
# Preprocessing pipeline shared by the dataset and the patch extraction:
# resize to a fixed square resolution, convert to a tensor, then standardize
# each colour channel with the mean/std below.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

class TinyImageNetValDataset(Dataset):
    """Dataset over the ``.JPEG`` files found directly inside one directory.

    Every item is an ``(image, path)`` pair, so results computed on an image
    can be traced back to the file it was loaded from.
    """

    def __init__(self, image_dir, transform=None):
        """Index the image files of ``image_dir``.

        Args:
            image_dir: Directory scanned (non-recursively) for files whose
                name ends in ``.JPEG`` (case-sensitive).
            transform: Optional callable applied to each loaded image.
        """
        jpeg_paths = []
        for entry in os.listdir(image_dir):
            if not entry.endswith(".JPEG"):
                continue
            jpeg_paths.append(os.path.join(image_dir, entry))
        # Sort so the item order does not depend on the listing order.
        jpeg_paths.sort()

        self.image_paths = jpeg_paths
        self.transform = transform
        self.loader = default_loader

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load item ``idx`` and return it as ``(image, path)``.

        The image is passed through ``self.transform`` when one was given.
        """
        path = self.image_paths[idx]
        sample = self.loader(path)
        if not self.transform:
            return sample, path
        return self.transform(sample), path

# ------------------------------- Hook & Patch Extraction -------------------------------
def get_activation(name, activations):
    """Build a forward hook that records a layer's output.

    Args:
        name: Key under which the output is stored.
        activations: Mutable mapping that receives the detached output; the
            entry is overwritten each time the hook fires.

    Returns:
        A callable taking ``(module, inputs, output)`` that stores
        ``output.detach()`` in ``activations[name]``.
    """

    def _record(module, inputs, output):
        # Store a detached tensor so no autograd history is kept with it.
        detached = output.detach()
        activations[name] = detached

    return _record

def get_top_patches(activation_list, image_paths, dataset, top_k=100):
    """Collect, for every neuron, the image patches that excite it the most.

    For each channel ("neuron") of the recorded layer, images are ranked by
    the peak value of that channel's activation map. For each of the
    ``top_k`` highest-ranked images, the patch located at the peak is cut out
    of the resized image.

    Patches are taken from the resized but *un-normalized* image (pixel
    values in ``[0, 1]``): the helpers that save/plot neuron features clamp
    to ``[0, 1]`` before converting to an image, which would destroy
    mean/std-normalized values.

    Args:
        activation_list: Sequence of per-image activation tensors; indexing
            one with a neuron index must yield a 2-D ``(height, width)`` map.
        image_paths: Image file paths aligned index-by-index with
            ``activation_list``.
        dataset: Only its first item is read, to obtain the spatial size of
            the network input (last two dimensions of ``dataset[0][0]``).
        top_k: Maximum number of patches kept per neuron.

    Returns:
        Dict mapping neuron index to a list of patch tensors, ordered from
        the strongest to the weakest peak activation. Empty dict when
        ``activation_list`` is empty.
    """
    if not activation_list:
        return {}

    # Read the input size once: indexing the dataset loads and transforms an
    # image from disk, so doing it per patch is needlessly expensive.
    input_h, input_w = dataset[0][0].shape[-2:]

    neuron_patches = {}
    for neuron_idx in range(activation_list[0].shape[0]):
        peaks = [act[neuron_idx].max().item() for act in activation_list]
        # Stable descending sort: images with equal peaks keep their order.
        ranked = sorted(range(len(peaks)), key=peaks.__getitem__, reverse=True)

        patches = []
        for img_idx in ranked[:top_k]:
            # Context manager closes the file handle once pixels are decoded.
            with Image.open(image_paths[img_idx]) as raw:
                resized = TF.resize(raw.convert('RGB'), [input_h, input_w])
            img = TF.to_tensor(resized)

            act_map = activation_list[img_idx][neuron_idx]
            act_h, act_w = act_map.shape
            # argmax returns a flat index; split it into (row, column).
            y, x = divmod(int(torch.argmax(act_map)), act_w)

            # Per-axis ratio between input pixels and activation-map cells.
            scale_y = input_h / act_h
            scale_x = input_w / act_w
            y1 = min(int(y * scale_y), input_h)
            x1 = min(int(x * scale_x), input_w)
            y2 = min(y1 + max(1, int(scale_y)), input_h)
            x2 = min(x1 + max(1, int(scale_x)), input_w)

            # clone() so the patch owns its data; a bare slice is a view that
            # would keep the whole image tensor alive.
            patches.append(img[:, y1:y2, x1:x2].clone())
        neuron_patches[neuron_idx] = patches
    return neuron_patches

# ------------------------------- Neuron Feature Computation -------------------------------
def compute_neuron_features(neuron_patches, activation_list, top_k=100):
    """Build each neuron's feature image as an activation-weighted patch mean.

    Each patch is weighted by the peak activation of the image it was cut
    from, with weights normalized to sum to one.

    Assumes the patches of a neuron are ordered from strongest to weakest
    peak activation over ``activation_list`` (the order ``get_top_patches``
    produces), so the i-th patch pairs with the i-th largest peak.

    Args:
        neuron_patches: Dict mapping neuron index to a list of equally sized
            patches (tensors, or anything ``ToTensor`` accepts).
        activation_list: Sequence of per-image activation tensors indexable
            by neuron index.
        top_k: Maximum number of patches used per neuron.

    Returns:
        Dict mapping neuron index to its weighted-average patch tensor.
        Neurons with no patches are omitted.

    Raises:
        ValueError: If a neuron has more patches than there are images in
            ``activation_list``.
    """
    neuron_features = {}
    for neuron_idx, patches in neuron_patches.items():
        if not patches:
            continue

        count = min(top_k, len(patches))
        patch_tensors = torch.stack([
            p if isinstance(p, torch.Tensor) else ToTensor()(p)
            for p in patches[:count]
        ])

        # Recover the peaks the patches were ranked by: the `count` largest
        # per-image maxima, in descending order. (Reading images 0..top_k-1
        # instead would pair patches with unrelated activations.)
        peaks = sorted(
            (act[neuron_idx].max().item() for act in activation_list),
            reverse=True,
        )[:count]
        if len(peaks) != count:
            raise ValueError(
                f"neuron {neuron_idx}: {count} patches but only "
                f"{len(peaks)} activation maps"
            )

        total = sum(peaks)
        if total == 0:
            # No usable weighting information: fall back to a plain mean
            # rather than dividing by zero.
            weights = torch.full((count,), 1.0 / count)
        else:
            weights = torch.tensor(peaks) / total

        neuron_features[neuron_idx] = (
            patch_tensors * weights[:, None, None, None]
        ).sum(dim=0)
    return neuron_features

def save_neuron_features(neuron_features, out_dir, max_to_save=100):
    """Write neuron feature tensors to ``out_dir`` as PNG files.

    Args:
        neuron_features: Dict mapping neuron index to an image tensor.
        out_dir: Destination directory; created if it does not exist.
        max_to_save: Only the first ``max_to_save`` entries (in dict order)
            are written, each as ``neuron_<index>.png``.
    """
    os.makedirs(out_dir, exist_ok=True)

    remaining = max_to_save
    for neuron_idx, feature in neuron_features.items():
        if remaining <= 0:
            break
        remaining -= 1

        target = os.path.join(out_dir, f"neuron_{neuron_idx}.png")
        # Values are clamped to [0, 1] before conversion to a PIL image.
        clamped = feature.cpu().clamp(0, 1)
        ToPILImage()(clamped).save(target)

        # Drop the temporary and collect right away, one image at a time.
        del clamped
        gc.collect()

def create_neuron_grid(neuron_features, grid_size, out_path):
    """Render neuron feature images in a grid and save it as one figure.

    Args:
        neuron_features: Dict mapping neuron index to an image tensor; the
            first ``rows * cols`` entries (in dict order) are drawn.
        grid_size: ``(rows, cols)`` of the subplot grid.
        out_path: File path the figure is written to.
    """
    # squeeze=False keeps `axes` a 2-D array for every grid size; with the
    # default, a 1x1 grid yields a single Axes object that has no `.flat`.
    fig, axes = plt.subplots(*grid_size, figsize=(20, 20), squeeze=False)
    try:
        entries = iter(neuron_features.items())
        for ax in axes.flat:
            entry = next(entries, None)
            if entry is not None:
                neuron_idx, nf_tensor = entry
                ax.imshow(ToPILImage()(nf_tensor.cpu().clamp(0, 1)))
                ax.set_title(f"Neuron {neuron_idx}", fontsize=8)
            # Hide the frame/ticks on every cell, including unused ones, so
            # a partially filled grid does not show empty axes.
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)

# ------------------------------- CSI Computation -------------------------------
def rgb2opp(img_np):
    """Convert an RGB image to three opponent channels.

    The channels are computed per pixel as::

        O1 = (R - G) / sqrt(2)
        O2 = (R + G - 2B) / sqrt(6)
        O3 = (R + G + B) / sqrt(3)

    Args:
        img_np: Array-like whose last axis holds at least the R, G and B
            values (indices 0, 1, 2). Floating-point input is used as-is;
            any other dtype is converted to float64 first.

    Returns:
        Array with the same leading shape and a last axis of size 3
        holding ``(O1, O2, O3)``.
    """
    pixels = np.asarray(img_np)
    if not np.issubdtype(pixels.dtype, np.inexact):
        # Integer arithmetic (e.g. uint8) would wrap around in R - G and
        # R + G instead of producing negative / large values.
        pixels = pixels.astype(np.float64)

    red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
    opponent = [
        (red - green) / np.sqrt(2),
        (red + green - 2 * blue) / np.sqrt(6),
        (red + green + blue) / np.sqrt(3),
    ]
    return np.stack(opponent, axis=-1)

def image2max_gray(opp_img):
    """Reduce the last axis to the largest absolute value found along it.

    Args:
        opp_img: Array whose last axis holds the channel values.

    Returns:
        Array with the last axis removed.
    """
    magnitudes = np.absolute(opp_img)
    return np.amax(magnitudes, axis=-1)

# Per-image value cache keyed by file path. The value depends only on the
# image, so it is computed once instead of once per neuron.
_GRAY_PEAK_CACHE = {}


def _gray_peak(path):
    """Return the largest absolute opponent-channel value of one image."""
    if path not in _GRAY_PEAK_CACHE:
        # Context manager closes the file handle once pixels are decoded.
        with Image.open(path) as raw:
            # Force three channels: a grayscale or RGBA file would otherwise
            # yield an array whose last axis is not (R, G, B).
            rgb = raw.convert('RGB').resize((224, 224))
        img_np = np.asarray(rgb).astype(np.float32) / 255.
        _GRAY_PEAK_CACHE[path] = np.max(image2max_gray(rgb2opp(img_np)))
    return _GRAY_PEAK_CACHE[path]


def compute_csi_for_neuron(neuron_idx, activation_list, image_paths):
    """Compute the colour-selectivity score of one neuron.

    For every image the ratio ``gray / rgb`` is formed, where ``rgb`` is the
    neuron's peak activation on the image and ``gray`` is the largest
    absolute opponent-channel value of the image itself. The score is the
    mean over images of ``1 - clip(ratio, 0, 1)``.

    Dividing both terms by the same maximum, as an earlier version did,
    cancels in the ratio, so the ratio is computed directly.

    NOTE(review): ``gray`` is a statistic of the image pixels, not the
    neuron's response to a grayscale version of the image - confirm this is
    the intended definition of the index.

    Args:
        neuron_idx: Index of the neuron (channel) to score.
        activation_list: Sequence of per-image activation tensors indexable
            by neuron index.
        image_paths: Image file paths aligned with ``activation_list``.

    Returns:
        ``0.0`` when there are no images or every peak activation is zero
        (no image is read in that case); otherwise the mean score.
    """
    rgb_activations = np.array(
        [act[neuron_idx].max().item() for act in activation_list],
        dtype=np.float64,
    )
    # Decide the trivial case before touching the disk.
    if not rgb_activations.any():
        return 0.0

    gray_activations = np.array(
        [_gray_peak(image_paths[i]) for i in range(len(rgb_activations))],
        dtype=np.float64,
    )

    # A zero activation gets ratio 1 (score contribution 0). This matches the
    # clipped value of gray/0 for gray > 0 and avoids NaN from 0/0.
    ratio = np.ones_like(rgb_activations)
    nonzero = rgb_activations != 0
    ratio[nonzero] = gray_activations[nonzero] / rgb_activations[nonzero]
    return np.mean(1 - np.clip(ratio, 0, 1))

# ------------------------------- Run for Layer 0 Only -------------------------------
base_dir = "/Users/charlotteimbert/Documents/SP2025/NEUR189B/tiny-imagenet-200/val"
# Every sub-directory of base_dir whose name starts with "images" is processed.
image_dirs = sorted([os.path.join(base_dir, d) for d in os.listdir(base_dir)
                     if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("images")])
target_layers = [0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
combined_csi_rows = []

# Build the network once: it is only used for inference (eval mode, no_grad),
# so there is no need to reload the weights for every folder and layer.
# NOTE(review): `pretrained=` may be deprecated in favour of `weights=` on
# newer torchvision releases - confirm against the installed version.
vgg16 = models.vgg16(pretrained=True).eval().to(device)

for image_dir in tqdm(image_dirs, desc="Image folders"):
    folder_name = os.path.basename(image_dir)
    dataset = TinyImageNetValDataset(image_dir, transform=transform)
    if len(dataset) == 0:
        # Nothing to analyse; the steps below index the first activation map.
        print(f"Skipping {image_dir}: no .JPEG files found.")
        continue
    loader = DataLoader(dataset, batch_size=16, shuffle=False)

    for target_layer in target_layers:
        print(f"Processing {image_dir}, layer {target_layer}...")
        layer_name = f"features.{target_layer}"
        activations = {}
        hook_handle = vgg16.features[target_layer].register_forward_hook(
            get_activation(layer_name, activations))

        # NOTE(review): every activation map of the folder is accumulated in
        # memory before analysis; this can be very large for layers with big
        # spatial maps.
        activation_list, image_paths = [], []
        try:
            with torch.no_grad():
                for imgs, paths in tqdm(loader, desc=f"{folder_name} L{target_layer}"):
                    _ = vgg16(imgs.to(device))
                    # Extending with a batch tensor appends one map per image.
                    activation_list.extend(activations[layer_name].cpu())
                    image_paths.extend(paths)
        finally:
            # The model is shared across iterations, so the hook must not
            # stay registered once this layer has been recorded.
            hook_handle.remove()

        patches = get_top_patches(activation_list, image_paths, dataset)
        features = compute_neuron_features(patches, activation_list)

        nf_dir = f"neuron_features/layer{target_layer}_{folder_name}"
        save_neuron_features(features, nf_dir, max_to_save=100)
        create_neuron_grid(features, (10,10), os.path.join(nf_dir, "neuron_grid.png"))

        # CSI is computed for at most the first 100 neurons of the layer.
        for neuron_idx in range(min(100, activation_list[0].shape[0])):
            csi = compute_csi_for_neuron(neuron_idx, activation_list, image_paths)
            combined_csi_rows.append({
                "image_folder": folder_name,
                "layer_idx": target_layer,
                "neuron_idx": neuron_idx,
                "CSI": csi
            })

# Save final CSV
pd.DataFrame(combined_csi_rows).to_csv("all_csi_results.csv", index=False)
print("\n✅ Done. All neuron features and CSI values saved for layer 0.")



Processing /Users/charlotteimbert/Documents/SP2025/NEUR189B/tiny-imagenet-200/val/images, layer 0...



images L0:   0%|                                        | 0/625 [00:00<?, ?it/s][A
images L0:   0%|                                | 1/625 [00:01<13:12,  1.27s/it][A
images L0:   0%|                                | 2/625 [00:02<13:02,  1.26s/it][A
images L0:   0%|▏                               | 3/625 [00:03<12:54,  1.24s/it][A
images L0:   1%|▏                               | 4/625 [00:04<12:50,  1.24s/it][A
images L0:   1%|▎                               | 5/625 [00:06<12:53,  1.25s/it][A
images L0:   1%|▎                               | 6/625 [00:07<12:56,  1.25s/it][A
images L0:   1%|▎                               | 7/625 [00:08<12:51,  1.25s/it][A
images L0:   1%|▍                               | 8/625 [00:10<12:53,  1.25s/it][A
images L0:   1%|▍                               | 9/625 [00:11<12:55,  1.26s/it][A
images L0:   2%|▍                              | 10/625 [00:12<12:56,  1.26s/it][A
images L0:   2%|▌                              | 11/625 [00:13<12:54,  1.26