In [1]:
import torch
import numpy as np
from egg.models import models
from tqdm import tqdm

In [2]:
def center_kernel(K: torch.Tensor):
    """
    Center kernel matrix K by:
    1) construcing the centering matrix H
        (see: https://en.wikipedia.org/wiki/Centering_matrix)
    2) left multiply K by H to center the columns of K
    3) right multiply K by H to center the rows of K
    """
    # construct centering matrix
    n = K.shape[0]
    unit = torch.ones([n, n], device=K.device, dtype=K.dtype)
    I = torch.eye(n, device=K.device, dtype=K.dtype)
    H = I - unit / n
    # center rows and columns
    return torch.mm(torch.mm(H, K), H)


def scaled_HSIC(K: torch.Tensor, L: torch.Tensor):
    """
    Given two kernel matrices, computes the scaled empirical estimator of the
    Hilbert-Schmidt Independence Criterion by:
    1) centering the kernel matrices
    2) computing the trace of the product of the centered kernel matrices
    """
    # center kernel matrices
    K_cent = center_kernel(K)
    L_cent = center_kernel(L)
    # compute and return trace
    return torch.sum(K_cent * L_cent)


def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """
    Given two representation matrices, computes the Centralized Kernel Alignment
    based on the linear kernel k(x_i, x_j) = <x_i, x_j>
    """
    # compute kernel matrices for linear kernel
    K = torch.mm(X, X.T)
    L = torch.mm(Y, Y.T)
    # compute CKA
    normalization = torch.sqrt(scaled_HSIC(K, K)) * torch.sqrt(scaled_HSIC(L, L))
    return scaled_HSIC(K, L) / normalization

In [3]:
images = torch.Tensor(np.load("../data/75_monkey_test_imgs.npy"))

In [31]:
images.reshape(75, -1).norm(dim=1).mean(), images.reshape(75, -1).norm(dim=1).std()

(tensor(89.7271), tensor(22.6774))

In [4]:
td_models = models['task_driven']

[2023-08-15 14:45:37,698][INFO]: Connecting @localhost:3306


datajoint connection not established, skipping model imports from nnfabrik tables


[2023-08-15 14:45:39,104][INFO]: Connecting @localhost:3306


datajoint connection not established, skipping model imports from nnfabrik tables


In [5]:

responses = []
for i in tqdm(np.arange(0, 75, 5)):
    responses.append(td_models['train'](images[i:i+5]))

100%|██████████| 15/15 [00:12<00:00,  1.25it/s]


In [6]:
val_responses = []
for i in tqdm(np.arange(0, 75, 5)):
    val_responses.append(td_models['val'](images[i:i+5]))

100%|██████████| 15/15 [00:10<00:00,  1.46it/s]


In [7]:
cross_val_responses = []
with torch.no_grad():
    for i in tqdm(np.arange(0, 75, 1)):
        cross_val_responses.append(td_models['cross-val'](images[i:i+1].unsqueeze(1), data_key='all_sessions'))

100%|██████████| 75/75 [01:06<00:00,  1.12it/s]


In [12]:
cross_val_responses_2 = []
with torch.no_grad():
    for i in tqdm(np.arange(0, 75, 1)):
        cross_val_responses_2.append(models['v4_multihead_attention']['val'](images[i:i+1].unsqueeze(1), data_key='all_sessions'))

100%|██████████| 75/75 [00:59<00:00,  1.25it/s]


In [8]:
linear_CKA(torch.cat(responses), torch.cat(val_responses))

tensor(0.9994, grad_fn=<DivBackward0>)

In [9]:
linear_CKA(torch.cat(responses), torch.cat(cross_val_responses))

tensor(0.9133, grad_fn=<DivBackward0>)

In [10]:
linear_CKA(torch.cat(val_responses), torch.cat(cross_val_responses))

tensor(0.9116, grad_fn=<DivBackward0>)

In [14]:
linear_CKA(torch.cat(cross_val_responses_2), torch.cat(cross_val_responses))

tensor(0.9949)

In [15]:
linear_CKA(torch.cat(responses), torch.cat(cross_val_responses_2))

tensor(0.9145, grad_fn=<DivBackward0>)

In [16]:
linear_CKA(torch.cat(val_responses), torch.cat(cross_val_responses_2))

tensor(0.9129, grad_fn=<DivBackward0>)

In [17]:
linear_CKA(torch.cat(val_responses), torch.cat(val_responses))

tensor(1.0000, grad_fn=<DivBackward0>)