In [1]:
import torch
import torchvision
from einops import rearrange, repeat
from loguru import logger

from cka import gram, centering_mat, centered_gram, unbiased_hsic_xy, MinibatchCKA

# Seed torch's global RNG (CPU and CUDA) so the random tensors below -- and the
# CKA values displayed for them -- are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
#########################
# sanity tests for the Gram-matrix helpers

x = torch.randn(8, 7, 5, device=DEVICE)
y = torch.randn(8, 17, 9, device=DEVICE)

g = gram(x, device=DEVICE)
# one row/column per entry of x's leading dimension
assert g.shape == (x.shape[0], x.shape[0])
# a Gram matrix holds pairwise similarities k(x_i, x_j), so it must be symmetric
assert torch.allclose(g, g.T, atol=1e-4)

cg = centered_gram(x, device=DEVICE) # centered = column and row means subtracted
assert cg.shape == g.shape
# centering on both sides preserves symmetry
assert torch.allclose(cg, cg.T, atol=1e-4)

In [3]:
# CKA between the two independent 3-D random tensors x and y from the cell above.
# A fresh metric object is built each time; it presumably accumulates across update() calls.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(x, y, device=DEVICE)
cka.compute()

tensor(0.3433, device='cuda:0')

In [20]:
# 8 samples; X and Y are drawn independently (10 and 20 features), so their true similarity is zero.
X = torch.randn((8, 10), device=DEVICE)
Y = torch.randn((8, 20), device=DEVICE)

In [21]:
# CKA of independent noise with only 8 samples -- expect a noisy, non-zero estimate.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(X, Y, device=DEVICE)
cka.compute()

tensor(0.3563, device='cuda:0')

In [34]:
# Same independent-noise setup, now with 32 samples.
X = torch.randn((32, 100), device=DEVICE)
Y = torch.randn((32, 200), device=DEVICE)

In [35]:
# CKA of independent noise with 32 samples; the estimate should shrink toward 0 as n grows.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(X, Y, device=DEVICE)
cka.compute()

tensor(0.0874, device='cuda:0')

In [24]:
# Same independent-noise setup, now with 128 samples.
X = torch.randn((128, 100), device=DEVICE)
Y = torch.randn((128, 200), device=DEVICE)

In [25]:
# CKA of independent noise with 128 samples.
# NOTE(review): the recorded output is exactly 0; presumably MinibatchCKA clamps a
# slightly negative unbiased estimate at 0 -- confirm in cka.MinibatchCKA.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(X, Y, device=DEVICE)
cka.compute()

tensor(0., device='cuda:0')

In [30]:
# Same independent-noise setup, now with 256 samples.
X = torch.randn((256, 100), device=DEVICE)
Y = torch.randn((256, 200), device=DEVICE)

In [31]:
# CKA of independent noise with 256 samples.
# NOTE(review): recorded output is exactly 0 here as well -- see the 128-sample cell.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(X, Y, device=DEVICE)
cka.compute()

tensor(0., device='cuda:0')

# CKA sanity checks

In [5]:
# Small deterministic inputs for the invariance checks.
# x is created as float so it can be multiplied by float matrices (e.g. SVD factors)
# without an explicit cast; values are 0..15 laid out as shape (4, 2, 2).
x = torch.arange(16, dtype=torch.float32).view(4, 2, 2)
y = torch.rot90(x, 1, [1, 2])   # x rotated +90 degrees in its last two dims
z = torch.rot90(x, -1, [1, 2])  # x rotated -90 degrees in its last two dims
x_2 = x * 2.5                   # isotropically scaled copy of x
# Local seeded generator: `check` (and the CKA values computed from it) is reproducible
# even when cells are run out of order.
gen = torch.Generator().manual_seed(0)
check = torch.rand(24, generator=gen).view(4, 6) * 10

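### Invariant to isotropic scaling

CKA should satisfy $\mathrm{CKA}(X, \alpha Y) = \mathrm{CKA}(X, Y)$ for any scalar $\alpha \neq 0$. Here `x_2 = 2.5 * x`, so `CKA(x, x_2)` should be 1 and `CKA(x_2, check)` should match `CKA(x, check)`.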

In [6]:
# x_2 is x scaled by 2.5, so CKA(x, x_2) should equal CKA(x, x) = 1 up to float error.
cka = MinibatchCKA().to(DEVICE)
cka.update(x, x_2, device=DEVICE)
cka_x_x2 = cka.compute()
assert torch.isclose(cka_x_x2, torch.ones_like(cka_x_x2), atol=1e-4), "CKA should be invariant to isotropic scaling"
cka_x_x2

tensor(1., device='cuda:0')

In [7]:
# Baseline: CKA between x and the unrelated random matrix `check`.
# The next cell repeats this with the scaled x_2 and should give the same value.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(x, check, device=DEVICE)
cka.compute()

tensor(0.5059, device='cuda:0')

In [8]:
# Scaling x by 2.5 must not change its CKA with `check`.
cka = MinibatchCKA().to(DEVICE)
cka.update(x_2, check, device=DEVICE)
cka_x2_check = cka.compute()

# Recompute the unscaled score here so this check does not depend on another cell's state.
baseline = MinibatchCKA().to(DEVICE)
baseline.update(x, check, device=DEVICE)
assert torch.isclose(cka_x2_check, baseline.compute(), atol=1e-4), "CKA should be invariant to isotropic scaling"
cka_x2_check

tensor(0.5059, device='cuda:0')

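### Invariant to orthogonal transformations

CKA should satisfy $\mathrm{CKA}(XU, YV) = \mathrm{CKA}(X, Y)$ for orthogonal $U$ and $V$. The SVD of a random $2 \times 2$ matrix supplies two such orthogonal matrices.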

In [11]:
# Local seeded generator so the random basis and `check` -- and therefore the CKA
# values below -- are reproducible regardless of cell execution order.
gen = torch.Generator().manual_seed(1)
same = torch.rand(2, 2, generator=gen)  # random matrix whose SVD supplies two orthogonal factors
check = torch.rand(16, generator=gen).view(4, 4) * 10  # unrelated comparison target, shape (4, 4)

In [12]:
# torch.linalg.svd returns (U, S, Vh): `epsilon` holds the singular values (unused),
# and `V` is really V^H. Both U and V^H are orthogonal, which is all the
# invariance checks below need -- verify that explicitly.
U, epsilon, V = torch.linalg.svd(same)
assert torch.allclose(U @ U.T, torch.eye(U.shape[0], dtype=U.dtype, device=U.device), atol=1e-5)
assert torch.allclose(V @ V.T, torch.eye(V.shape[0], dtype=V.dtype, device=V.device), atol=1e-5)

In [13]:
# Apply each orthogonal factor to the last (size-2) axis of x.
xu = x.float() @ U
xv = x.float() @ V

# The two transforms must actually differ, otherwise the next check is trivial.
assert torch.any(xu != xv)

In [14]:
# xu and xv are x under two different orthogonal maps, so CKA(xu, xv) should be 1.
cka = MinibatchCKA().to(DEVICE)
cka.update(xu,xv, device=DEVICE)
cka_xu_xv = cka.compute()
assert torch.isclose(cka_xu_xv, torch.ones_like(cka_xu_xv), atol=1e-4), "CKA should be invariant to orthogonal transformations"
cka_xu_xv

tensor(1.0000, device='cuda:0')

In [15]:
# Baseline for the orthogonal-invariance check: CKA(x, check) with the (4, 4) `check`.
cka = MinibatchCKA()
cka = cka.to(DEVICE)
cka.update(x, check, device=DEVICE)
cka.compute()

tensor(0.4479, device='cuda:0')

In [16]:
# Rotate check's features too. Reshaping (4, 4) -> (4, 2, 2) and right-multiplying by the
# orthogonal V applies block-diag(V, V) to each example's 4 features, which is orthogonal.
check_v = torch.matmul(check.float().reshape(-1, 2, 2), V)
cka = MinibatchCKA().to(DEVICE)
cka.update(xu, check_v, device=DEVICE)
cka_rotated = cka.compute()

# Recompute the unrotated score here so this check does not depend on another cell's state.
baseline = MinibatchCKA().to(DEVICE)
baseline.update(x, check, device=DEVICE)
assert torch.isclose(cka_rotated, baseline.compute(), atol=1e-4), "CKA should be invariant to orthogonal transformations"
cka_rotated

tensor(0.4479, device='cuda:0')