In [8]:
import numpy as np
import torch
from utils import *
import matplotlib.pyplot as plt
from scipy.stats import special_ortho_group
import sys
from torch import nn

# Two Gaussians (or two cubes) translation experiment

In [3]:
# Device that every tensor in this notebook is moved to. The name `cuda` is
# kept because the cells below refer to it; it falls back to the CPU when no
# CUDA device is available so the notebook can still run (slowly) without a GPU.
cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _uniform_cube(side, center, num_pts, num_dims):
    """Sample `num_pts` points uniformly from an axis-aligned cube of edge `side`.

    The cube is centred at the origin and then shifted by `center` along the
    first coordinate only.
    """
    # rand() is uniform on [0, 1); subtracting 0.5 centres it *before* scaling,
    # so the cube stays centred for any edge length (not only side == 1).
    pts = side * (np.random.rand(num_pts, num_dims) - 0.5)
    pts[:, 0] += center
    return pts


def get_cka_test(mean1 = 0,
                 mean2 = 0,
                 var1 = 1,
                 var2 = 1,
                 num_dims = 100,
                 num_pts = 1000,
                 seed = 0,
                 c = 1000,
                 verbose = False,
                 distribution = 'gaussian',
                 median = median):
    """RBF-CKA between two point clouds and a copy with the second cloud translated.

    `X` stacks two clouds of `num_pts` points each. `Y` equals `X` except that
    every point of the second cloud is shifted by `c * d`, where `d` is a random
    unit vector drawn from the seeded NumPy global RNG.

    Parameters
    ----------
    mean1, mean2 : float
        'gaussian': passed as `loc` to `np.random.normal` (applies to every
        coordinate). 'uniform': offset of the cube centre along the first
        coordinate only.
    var1, var2 : float
        'gaussian': passed as `scale` to `np.random.normal`, i.e. used as a
        standard deviation despite the name. 'uniform': edge length of the cube.
    num_dims : int
        Dimensionality of the points.
    num_pts : int
        Number of points in each of the two clouds (`X` has `2 * num_pts` rows).
    seed : int
        Seed for NumPy's global RNG; note that this reseeds global state.
    c : float
        Length of the translation applied to the second cloud.
    verbose : bool
        If True, also return `torch.where(X == Y)`.
    distribution : {'gaussian', 'uniform'}
        How the two clouds are sampled.
    median
        Forwarded unchanged to `rbfCKA(median=...)`.
        NOTE(review): the default is bound to whatever the name `median` refers
        to when this `def` is executed; that name is not defined in this cell
        (presumably it comes from `utils` or earlier kernel state) -- confirm.

    Returns
    -------
    float, or (float, tuple of tensors) when `verbose` is True.

    Raises
    ------
    ValueError
        If `distribution` is not one of the supported names.
    """
    np.random.seed(seed)

    # Random unit vector: the direction of the translation.
    d = np.random.normal(0,1,[num_dims])
    d /= np.linalg.norm(d)

    if distribution == 'gaussian':
        X = np.concatenate([np.random.normal(mean1, var1, [num_pts, num_dims]),
                            np.random.normal(mean2, var2, [num_pts, num_dims])], axis = 0)
    elif distribution == 'uniform':
        # in this case var = side and mean = center
        X = np.concatenate([_uniform_cube(var1, mean1, num_pts, num_dims),
                            _uniform_cube(var2, mean2, num_pts, num_dims)], axis = 0)
    else:
        raise ValueError(
            "distribution must be 'gaussian' or 'uniform', got {!r}".format(distribution))

    # Only the second cloud (rows num_pts onward) is translated, by c * d.
    shift = np.zeros_like(X)
    shift[num_pts:] = c * d

    # The sum is formed in NumPy first; torch.Tensor then converts to float32.
    Y = torch.Tensor(X + shift).to(cuda)
    X = torch.Tensor(X).to(cuda)

    cka_value = rbfCKA(median=median)(X, Y).item()
    if verbose:
        return cka_value, torch.where(X == Y)
    return cka_value

## Multiple seeds

In [6]:
# Sweep configuration: each call builds two cubes of `num_pts` points in
# `num_dims` dimensions; every translation size in `c_list` is evaluated for
# `num_seeds` different seeds.
num_pts = 10000
num_dims = 1000
num_seeds = 10
c_list = [1] + list(range(5, 51, 5))

for median in (0.2, 0.5, 1):
    # NOTE(review): `diff` collects every torch.where(X == Y) result of the
    # sweep and is not read again in this cell -- confirm a later cell needs it.
    diff = []
    # One row per seed, one column per entry of c_list.
    data = np.zeros((num_seeds, len(c_list)))
    for seed in range(num_seeds):
        print(f'seed {seed}')
        for i, c in enumerate(c_list):
            data[seed, i], v = get_cka_test(
                mean2=1.1,
                num_dims=num_dims,
                num_pts=num_pts,
                c=c,
                seed=seed,
                distribution='uniform',
                verbose=True,
                median=median,
            )
            diff.append(v)

    # The median == 1 run is saved without the median value in its file name.
    np.save(
        'two_cubes_exp_median_rbfcka2__means_0_1.1_{}seeds_v2.npy'.format(num_seeds)
        if median == 1
        else 'two_cubes_exp_median{}_rbfcka2__means_0_1.1_{}seeds_v2.npy'.format(median, num_seeds),
        data,
    )

seed 0
seed 1
seed 2
seed 3
seed 4
seed 5
seed 6
seed 7
seed 8
seed 9
