# In [None]:
import re
import os
import gc

import numpy as np
import torch
from torchvision import models, datasets, transforms

import config as config

import replica_correlations as rep

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still works on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the checkpoint onto the CPU first; the model is moved to `device` once
# its weights are in place.
weights = torch.load(f"{config.model_path}model_final_checkpoint_phase999.torch",
                     map_location='cpu')
trunk = weights['classy_state_dict']['base_model']['model']['trunk']
# Strip the '_feature_blocks.' prefix so the keys line up with the parameter
# names of torchvision's ResNet-50. A raw string keeps '\.' a regex escape
# instead of an invalid string escape.
trunk = {re.sub(r'_feature_blocks\.', '', key): val for key, val in trunk.items()}

# Random placeholder weights for the classification head are added so that
# the strict load_state_dict below finds a complete set of keys.
dummy_weight = torch.rand((1000, 2048))
dummy_bias = torch.rand((1000, ))
trunk['fc.weight'] = dummy_weight
trunk['fc.bias'] = dummy_bias

model = models.resnet50()
model.load_state_dict(trunk)
model = model.to(device)
# Inference only: in eval mode BatchNorm uses its stored running statistics
# instead of per-batch statistics and no longer updates them on forward passes.
model.eval()

# Per-channel normalisation constants applied after ToTensor().
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Deterministic preprocessing: resize, centre crop, tensor conversion, normalise.
test_trnsfrm = transforms.Compose([transforms.Resize(256),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor(), 
                                  # transforms.ToDtype(torch.float32, scale=True),
                                  transforms.Normalize(mean, std)
                                 ])


from data_util import getDatasetFromLabel, getPILDatasetFromLabel
from torch.utils.data import DataLoader
from torchvision.models.feature_extraction import create_feature_extractor

# Upper bound on the feature dimension handed to the manifold analysis;
# activations wider than this are reduced with a random projection.
embeding_dim = 4000
kappa = 0    # Specify the margin (usually 0)
n_t = 200    # Specify the number of Gaussian vectors to sample (200 or 300 is a good default)

# Per-node results. Only `c` (capacity) is filled by the active code path;
# `r` (radius) and `d` (dimension) belong to the commented-out analysis.
c, r, d = [], [], []

# Graph nodes whose activations are analysed. The full set of candidates is:
# ['x', 'layer1.0.relu', 'layer1.1.relu_2', 'layer2.0.relu_1', 'layer2.2.relu',
#  'layer2.3.relu_2', 'layer3.1.relu_1', 'layer3.3.relu', 'layer3.4.relu_2',
#  'layer4.0.relu_1', 'layer4.1.relu_1', 'layer4.2.relu', 'layer4.2.relu_1',
#  'layer4.2.relu_2', 'avgpool']
node_names = ['x', 'avgpool']

# Entries of the ImageNet directory (presumably one folder per class).
# os.listdir returns names in arbitrary order, so sort them to make any
# slice of this list select the same categories on every run and machine.
category_folder = sorted(os.listdir(config.imagenet_path))

def _make_projection(n_features, target_dim):
    """Return the matrix used to bring features down to at most `target_dim` dims.

    When `n_features` exceeds `target_dim`, the result is a
    (target_dim, n_features) Gaussian random matrix whose rows are scaled to
    unit Euclidean norm. Otherwise it is the (n_features, n_features)
    identity, which leaves the features unchanged.
    """
    if n_features > target_dim:
        projection = np.random.randn(target_dim, n_features)
        projection /= np.sqrt(np.sum(projection * projection, axis=1, keepdims=True))
        return projection
    return np.eye(n_features)


# For every node: collect one feature matrix per category, then compute the
# capacity across those per-category matrices.
for node in node_names:
    extractor = create_feature_extractor(model, return_nodes=[node])
    # Feature extraction is inference: make sure BatchNorm uses its running
    # statistics and does not update them on forward passes.
    extractor.eval()
    projection = None  # built lazily once the flattened feature width is known
    manifolds = []
    for label in category_folder[0:20]:
        dataset = getPILDatasetFromLabel(label, top=45, transform=test_trnsfrm)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=False)
        label_batches = []
        for batch in dataloader:
            with torch.no_grad():
                output = extractor(batch.to(device))[node]
            # Flatten everything but the batch axis: (batch, n_features).
            activations = torch.flatten(output, start_dim=1).cpu().numpy()
            if projection is None:
                print(f"{node}: {activations.shape[1]}")
                # The same projection is reused for every category of this
                # node so that all of them live in the same reduced space.
                projection = _make_projection(activations.shape[1], embeding_dim)
            label_batches.append(activations @ projection.T)
        if not label_batches:
            # Nothing was loaded for this category; skip it rather than
            # silently reusing the features of the previous category.
            print(f"{label}: no images found, skipping")
            continue
        # Keep every batch of the category (not only the last one) and store
        # the result as (n_features, n_samples), the layout passed to the analysis.
        manifolds.append(np.concatenate(label_batches, axis=0).T)
    # alpha, radius, dimension = manifold_analysis(manifolds, kappa, n_t)  # layer or X?
    capacity, *_ = rep.manifold_analysis_corr(manifolds, kappa, n_t)  # layer or X?
    c.append(capacity)
    # c.append(1 / np.mean(1 / alpha))
    # r.append(np.mean(radius))
    # d.append(np.mean(dimension))
    # Release the (potentially very large) arrays before the next node.
    del manifolds
    del projection
    del extractor
    gc.collect()
    
import matplotlib.pyplot as plt
# Plot the capacity obtained for each analysed node; `c` holds one value per
# entry of node_names, in the same order.
fig = plt.figure()
plt.plot(c, label='Capacity')
plt.xlabel('Node index')
plt.ylabel('Capacity')
plt.legend()  # without a legend the label passed to plot() is never shown
# fig.savefig('/n/holylabs/LABS/sompolinsky_lab/Everyone/xupan/hierarchy_manifold/figures/c.png')

# fig = plt.figure()
# plt.plot(r, label='R')
# fig.savefig('/n/holylabs/LABS/sompolinsky_lab/Everyone/xupan/hierarchy_manifold/figures/r.png')

# fig = plt.figure()
# plt.plot(d, label='D')
# fig.savefig('/n/holylabs/LABS/sompolinsky_lab/Everyone/xupan/hierarchy_manifold/figures/d.png')