In [1]:
import copy
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from tqdm import tqdm

import backbone
from methods.baselinetrain import BaselineTrain
from io_utils import parse_args, get_resume_file, get_best_file, get_assigned_file
from datasets import miniImageNet_few_shot, ISIC_few_shot, EuroSAT_few_shot, CropDisease_few_shot, Chest_few_shot
from data.datamgr import SimpleDataManager, SetDataManager

In [2]:
class Classifier(nn.Module):
    """Linear classification head placed on top of extracted features.

    Args:
        dim: Number of input features per example.
        n_way: Number of output classes (one score per class).
    """

    def __init__(self, dim, n_way):
        super().__init__()
        self.fc = nn.Linear(dim, n_way)

    def forward(self, x):
        """Return the un-normalised class scores produced by the linear layer."""
        return self.fc(x)

def get_pretrained_model(model, method, pretrained_dataset, save_dir):
    """Build a ResNet10 ``BaselineTrain`` network and load its checkpoint.

    The checkpoint is looked up in
    ``<save_dir>/checkpoints/<pretrained_dataset>/<model>_<method>_aug``.

    Args:
        model: Model name. It is only used to build the checkpoint path; the
            backbone that is instantiated is always ``backbone.ResNet10``.
        method: Forwarded to ``backbone.ResNet10`` and used in the checkpoint path.
        pretrained_dataset: Name of the dataset the checkpoint was trained on
            (checkpoint path component).
        save_dir: Root directory that contains the ``checkpoints`` folder.

    Returns:
        The ``BaselineTrain`` instance with the checkpoint weights loaded,
        moved to CUDA and switched to training mode.

    Raises:
        FileNotFoundError: If ``get_resume_file`` returns ``None`` for the
            checkpoint directory.
    """
    # Output size handed to BaselineTrain. The state dict is loaded with
    # strict=True, so it has to match the head stored in the checkpoint.
    num_classes = 200

    pretrained_model = BaselineTrain(backbone.ResNet10(method=method), num_classes, loss_type='softmax')

    checkpoint_dir = '%s/checkpoints/%s/%s_%s_aug' %(save_dir, pretrained_dataset, model, method)

    modelfile = get_resume_file(checkpoint_dir)
    if modelfile is None:
        # Fail with the directory in the message instead of an opaque error
        # from torch.load(None).
        raise FileNotFoundError('No checkpoint found in {}'.format(checkpoint_dir))
    state = torch.load(modelfile)['state']

    pretrained_model.load_state_dict(state, strict=True)
    pretrained_model.cuda()
    pretrained_model.train()

    return pretrained_model

In [None]:
# Fine-tune the pre-trained network on the support set of each sampled episode
# and record the query accuracy before ("prt") and after ("ft") fine-tuning,
# both with the linear classifier ("w") and with class-mean prototypes ("wo").
image_size = 224
iter_num = 1

n_way = 5
n_support = 5
n_query = 15
few_shot_params = dict(n_way=n_way, n_support=n_support)

dataset_names = ["CropDisease"] # ["miniImageNet_test", "CropDisease", "EuroSAT", "ISIC", "ChestX"]
method = 'baseline'
model = 'ResNet10'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
results_dir = './NIL_results'
freeze_backbone = False

support_size = n_way * n_support
finetune_epoch = 100
batch_size = 4

# Dataset modules whose SetDataManager is called with the `n_eposide` keyword.
cross_domain_datasets = {
    "CropDisease": CropDisease_few_shot,
    "EuroSAT": EuroSAT_few_shot,
    "ISIC": ISIC_few_shot,
    "ChestX": Chest_few_shot,
}


def _top1_accuracy(scores, labels):
    """Return the top-1 accuracy in percent.

    Args:
        scores: 2-D torch tensor holding one row of class scores per example.
        labels: 1-D numpy array with the integer label of each row.
    """
    _, top_labels = scores.topk(1, 1, True, True)
    predictions = top_labels.cpu().numpy()[:, 0]
    return float(np.sum(predictions == labels)) / len(labels) * 100


def _evaluate_episode(network, classifier, x_support, y_support, x_query, y_query, n_way):
    """Score the query set with and without the linear classifier.

    Both modules are switched to eval mode and everything runs under
    ``torch.no_grad()``. Callers must restore train mode themselves.

    Args:
        network: Model exposing the feature extractor as ``network.feature``.
        classifier: Linear head applied to the extracted features.
        x_support, y_support: Support images and their labels (torch tensors).
        x_query: Query images (torch tensor).
        y_query: Query labels as a numpy array.
        n_way: Number of classes in the episode.

    Returns:
        ``(acc_with_classifier, acc_without_classifier)`` in percent. The second
        value scores each query by its dot product with the per-class mean of
        the support features.
    """
    network.eval()
    classifier.eval()
    with torch.no_grad():
        # Each feature set is extracted once and reused by both scoring rules.
        query_feats = network.feature(x_query)
        support_feats = network.feature(x_support)

        acc_with = _top1_accuracy(classifier(query_feats), y_query)

        prototypes = torch.stack(
            [support_feats[y_support == c].cpu().mean(dim=0) for c in range(n_way)])
        acc_without = _top1_accuracy(torch.mm(query_feats.cpu(), prototypes.T), y_query)
    return acc_with, acc_without


os.makedirs(results_dir, exist_ok=True)

for dataset_name in dataset_names:
    print (dataset_name)
    if dataset_name == "miniImageNet" or dataset_name == "miniImageNet_test":
        # NOTE(review): this manager is called with `n_episode`, the others with
        # `n_eposide` -- confirm against the signatures in the datasets package.
        datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_episode=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name in cross_domain_datasets:
        datamgr = cross_domain_datasets[dataset_name].SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    else:
        # Without this guard an unknown name would silently reuse the manager
        # of the previous dataset (or fail with a NameError on the first one).
        raise ValueError('Unknown dataset name: {}'.format(dataset_name))

    if dataset_name == "miniImageNet_test":
        novel_loader = datamgr.get_data_loader(aug=False, train=False)
    else:
        novel_loader = datamgr.get_data_loader(aug=False)

    prt_w = []
    prt_wo = []
    ft_w = []
    ft_wo = []

    for x, _ in tqdm(novel_loader):
        # A fresh copy of the pre-trained weights is loaded for every episode.
        pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir)

        ####################################################################################

        classifier = Classifier(pretrained_model.feature.final_feat_dim, n_way)
        # Move the head to the GPU before handing its parameters to the optimizer.
        classifier.cuda()
        classifier.train()
        classifier_opt = torch.optim.SGD(classifier.parameters(), lr = 1e-2, momentum=0.9, dampening=0.9, weight_decay=0.001)

        if not freeze_backbone:
            # Disabled experiment variants (BN running-statistics reset and
            # re-initialisation of trunk[7]) are kept below for reference.
#             var_init = 0.1
#             pretrained_model.feature.trunk[1].running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[1].running_var.data.fill_(var_init)

#             pretrained_model.feature.trunk[4].BN1.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[4].BN1.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[4].BN2.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[4].BN2.running_var.data.fill_(var_init)

#             pretrained_model.feature.trunk[5].BN1.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[5].BN1.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[5].BN2.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[5].BN2.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[5].BNshortcut.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[5].BNshortcut.running_var.data.fill_(var_init)

#             pretrained_model.feature.trunk[6].BN1.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[6].BN1.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[6].BN2.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[6].BN2.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[6].BNshortcut.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[6].BNshortcut.running_var.data.fill_(var_init)

#             pretrained_model.feature.trunk[7].BN1.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[7].BN1.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[7].BN2.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[7].BN2.running_var.data.fill_(var_init)
#             pretrained_model.feature.trunk[7].BNshortcut.running_mean.data.fill_(0.)
#             pretrained_model.feature.trunk[7].BNshortcut.running_var.data.fill_(var_init)

#             for name, p in pretrained_model.named_parameters():
#                 if 'trunk.7' in name:
#                     if 'BN' in name:
#                         if 'weight' in name:
#                             p.data.fill_(1.)
#                         else:
#                             p.data.fill_(0.)
#                     else:
#                         nn.init.kaiming_uniform_(p.data, a=math.sqrt(5)) # p.data[half:,:,:,:]

            delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad, pretrained_model.parameters()), lr = 1e-2, momentum=0.9, dampening=0.9, weight_decay=0.001)
        loss_fn = nn.CrossEntropyLoss().cuda()
        ####################################################################################

        # Split the episode into support / query images. The per-class query
        # count is read from the batch itself and kept in its own variable so
        # the configured `n_query` is not overwritten.
        n_query_task = x.size(1) - n_support
        x = x.cuda()

        x_support = x[:, :n_support].contiguous().view(n_way * n_support, *x.size()[2:])    # (n_way * n_support, ...)
        x_query = x[:, n_support:].contiguous().view(n_way * n_query_task, *x.size()[2:])   # (n_way * n_query_task, ...)
        y_support = torch.from_numpy(np.repeat(range(n_way), n_support)).cuda()              # (n_way * n_support,)
        y_query = np.repeat(range(n_way), n_query_task)                                      # (n_way * n_query_task,)

        ####################################################################################

        # Accuracy of the pre-trained network, before any fine-tuning step.
        acc_w, acc_wo = _evaluate_episode(pretrained_model, classifier, x_support, y_support, x_query, y_query, n_way)
        prt_w.append(acc_w)
        prt_wo.append(acc_wo)

        for epoch in range(finetune_epoch):
            pretrained_model.train()
            classifier.train()

            if freeze_backbone:
                pretrained_model.eval()

            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                if not freeze_backbone:
                    delta_opt.zero_grad()
                #####################################
                selected_id = torch.from_numpy(rand_id[j:j + batch_size]).cuda()
                z_batch = x_support[selected_id]
                y_batch = y_support[selected_id]
                #####################################
                output = pretrained_model.feature(z_batch)
                scores = classifier(output)
                loss = loss_fn(scores, y_batch)
                #####################################
                loss.backward()
                classifier_opt.step()
                if not freeze_backbone:
                    delta_opt.step()

        # Accuracy after the last fine-tuning epoch.
        acc_w, acc_wo = _evaluate_episode(pretrained_model, classifier, x_support, y_support, x_query, y_query, n_way)
        ft_w.append(acc_w)
        ft_wo.append(acc_wo)

    df = pd.DataFrame({'prt_wo': prt_wo, 'ft_wo': ft_wo, 'prt_w': prt_w, 'ft_w': ft_w})
    # NOTE(review): the file is suffixed `_LBreinit` although the
    # re-initialisation code above is commented out -- confirm the intended name.
    df.to_csv(os.path.join(results_dir, '{}_LBreinit.csv'.format(dataset_name)), index=False)

    # 95% interval half-width uses the number of episodes actually evaluated.
    for label, values in (('prt_w', prt_w), ('prt_wo', prt_wo), ('ft_w', ft_w), ('ft_wo', ft_wo)):
        print ('{} mean: {:2.4f}, 95% conf.: {:2.4f}'.format(label, np.mean(values), 1.96*np.std(values)/np.sqrt(len(values))))

---
Relative change of pre-trained network layers 

In [None]:
# Measure, per parameter tensor, the relative change ||before - after|| / ||before||
# caused by re-initialising part of the last block and fine-tuning on the
# support set of each episode. One CSV (one row per episode) is written per dataset.
image_size = 224
iter_num = 600

n_way = 5
n_support = 5
n_query = 15
few_shot_params = dict(n_way=n_way, n_support=n_support)

dataset_names = ["miniImageNet_test", "CropDisease", "EuroSAT", "ISIC", "ChestX"]
method = 'baseline'
model = 'ResNet10'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
freeze_backbone = False

support_size = n_way * n_support
finetune_epoch = 100
batch_size = 4

# Column names of the result table, one per parameter tensor. The order has to
# match the order in which `.parameters()` yields the feature-extractor tensors
# followed by the classifier tensors.
params_list = ["Stem.Conv.weight", "Stem.BN.scale", "Stem.BN.shift",
               "Block1.Conv1.weight", "Block1.BN1.scale", "Block1.BN1.shift", "Block1.Conv2.weight", "Block1.BN2.scale", "Block1.BN2.shift",
               "Block2.Conv1.weight", "Block2.BN1.scale", "Block2.BN1.shift", "Block2.Conv2.weight", "Block2.BN2.scale", "Block2.BN2.shift", "Block2.ShortCutConv.weight", "Block2.ShortCutBN.scale", "Block2.ShortCutBN.shift",
               "Block3.Conv1.weight", "Block3.BN1.scale", "Block3.BN1.shift", "Block3.Conv2.weight", "Block3.BN2.scale", "Block3.BN2.shift", "Block3.ShortCutConv.weight", "Block3.ShortCutBN.scale", "Block3.ShortCutBN.shift",
               "Block4.Conv1.weight", "Block4.BN1.scale", "Block4.BN1.shift", "Block4.Conv2.weight", "Block4.BN2.scale", "Block4.BN2.shift", "Block4.ShortCutConv.weight", "Block4.ShortCutBN.scale", "Block4.ShortCutBN.shift",
               "Classifier.weight", "Classifier.bias"]

# Dataset modules whose SetDataManager is called with the `n_eposide` keyword.
cross_domain_datasets = {
    "CropDisease": CropDisease_few_shot,
    "EuroSAT": EuroSAT_few_shot,
    "ISIC": ISIC_few_shot,
    "ChestX": Chest_few_shot,
}


def _reinit_last_block(network):
    """Re-initialise selected parameters whose name contains ``'trunk.7'``.

    * Parameters with ``'shortcut'`` in the name: those that also contain
      ``'BN'`` are reset (``weight`` -> 1, anything else -> 0); the remaining
      ones are re-drawn with ``kaiming_uniform_``.
    * Other parameters: ``'C2'`` ones are re-drawn with ``kaiming_uniform_``
      and ``'BN2'`` ones are reset. Everything else (e.g. ``C1`` / ``BN1``)
      keeps its loaded value.
    """
    for name, p in network.named_parameters():
        if 'trunk.7' not in name:
            continue
        if 'shortcut' in name:
            reset_bn = 'BN' in name
            redraw = not reset_bn
        else:
            reset_bn = 'BN2' in name
            redraw = 'C2' in name
        if redraw:
            nn.init.kaiming_uniform_(p.data, a=math.sqrt(5))
        if reset_bn:
            p.data.fill_(1. if 'weight' in name else 0.)


for dataset_name in dataset_names:
    print (dataset_name)
    if dataset_name == "miniImageNet" or dataset_name == "miniImageNet_test":
        # NOTE(review): this manager is called with `n_episode`, the others with
        # `n_eposide` -- confirm against the signatures in the datasets package.
        datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_episode=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name in cross_domain_datasets:
        datamgr = cross_domain_datasets[dataset_name].SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    else:
        # Without this guard an unknown name would silently reuse the manager
        # of the previous dataset (or fail with a NameError on the first one).
        raise ValueError('Unknown dataset name: {}'.format(dataset_name))

    if dataset_name == "miniImageNet_test":
        novel_loader = datamgr.get_data_loader(aug=False, train=False)
    else:
        novel_loader = datamgr.get_data_loader(aug=False)

    # One list of relative changes per episode; turned into a table at the end.
    relative_change_rows = []

    for x, _ in tqdm(novel_loader):
        pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir)

        ####################################################################################

        classifier = Classifier(pretrained_model.feature.final_feat_dim, n_way)
        # Move the head to the GPU before handing its parameters to the optimizer.
        classifier.cuda()
        classifier.train()
        classifier_opt = torch.optim.SGD(classifier.parameters(), lr = 1e-2, momentum=0.9, dampening=0.9, weight_decay=0.001)

        # Reference snapshots. They are taken before the re-initialisation
        # below, so the reported change includes the re-initialisation itself.
        before_extractor = copy.deepcopy(pretrained_model.feature)
        before_classifier = copy.deepcopy(classifier)

        if not freeze_backbone:
            _reinit_last_block(pretrained_model)
            delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad, pretrained_model.parameters()), lr = 1e-2, momentum=0.9, dampening=0.9, weight_decay=0.001)
        loss_fn = nn.CrossEntropyLoss().cuda()
        ####################################################################################

        # Only the support images are needed here: the network is fine-tuned
        # but never evaluated on the query images in this experiment.
        x = x.cuda()
        x_support = x[:, :n_support].contiguous().view(n_way * n_support, *x.size()[2:])  # (n_way * n_support, ...)
        y_support = torch.from_numpy(np.repeat(range(n_way), n_support)).cuda()            # (n_way * n_support,)

        ####################################################################################

        for epoch in range(finetune_epoch):
            pretrained_model.train()
            classifier.train()

            if freeze_backbone:
                pretrained_model.eval()

            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                if not freeze_backbone:
                    delta_opt.zero_grad()
                #####################################
                selected_id = torch.from_numpy(rand_id[j:j + batch_size]).cuda()
                z_batch = x_support[selected_id]
                y_batch = y_support[selected_id]
                #####################################
                output = pretrained_model.feature(z_batch)
                scores = classifier(output)
                loss = loss_fn(scores, y_batch)
                #####################################
                loss.backward()
                classifier_opt.step()
                if not freeze_backbone:
                    delta_opt.step()

        # Compare the snapshots with the live (fine-tuned) parameters; no
        # gradient tracking is needed for these norms.
        before_params = list(before_extractor.parameters()) + list(before_classifier.parameters())
        after_params = list(pretrained_model.feature.parameters()) + list(classifier.parameters())
        with torch.no_grad():
            relative_changed_norm = [
                (torch.norm(b - a) / torch.norm(b)).item()
                for b, a in zip(before_params, after_params)
            ]

        if len(relative_changed_norm) != len(params_list):
            raise ValueError('Expected {} parameter tensors but found {}'.format(len(params_list), len(relative_changed_norm)))
        relative_change_rows.append(relative_changed_norm)

    df = pd.DataFrame(relative_change_rows, columns=params_list, dtype=float)
    df.to_csv('./{}_relative_norm_reinit.csv'.format(dataset_name))

In [None]:

# Plot the mean relative change of every parameter tensor, one curve per dataset.
# NOTE(review): the cell that produces these statistics writes
# `*_relative_norm_reinit.csv`, while the paths below have no `_reinit`
# suffix -- confirm which set of files is meant to be plotted.
params_list = [
    "Stem.Conv.weight", "Stem.BN.scale", "Stem.BN.shift",
    "Block1.Conv1.weight", "Block1.BN1.scale", "Block1.BN1.shift",
    "Block1.Conv2.weight", "Block1.BN2.scale", "Block1.BN2.shift",
    "Block2.Conv1.weight", "Block2.BN1.scale", "Block2.BN1.shift",
    "Block2.Conv2.weight", "Block2.BN2.scale", "Block2.BN2.shift",
    "Block2.ShortCutConv.weight", "Block2.ShortCutBN.scale", "Block2.ShortCutBN.shift",
    "Block3.Conv1.weight", "Block3.BN1.scale", "Block3.BN1.shift",
    "Block3.Conv2.weight", "Block3.BN2.scale", "Block3.BN2.shift",
    "Block3.ShortCutConv.weight", "Block3.ShortCutBN.scale", "Block3.ShortCutBN.shift",
    "Block4.Conv1.weight", "Block4.BN1.scale", "Block4.BN1.shift",
    "Block4.Conv2.weight", "Block4.BN2.scale", "Block4.BN2.shift",
    "Block4.ShortCutConv.weight", "Block4.ShortCutBN.scale", "Block4.ShortCutBN.shift",
    "Classifier.weight", "Classifier.bias",
]

# (legend label, CSV file) for every curve, in drawing order.
curves = [
    ('miniImageNet_test', './miniImageNet_test_relative_norm.csv'),
    ('CropDisease', './CropDisease_relative_norm.csv'),
    ('EuroSAT', './EuroSAT_relative_norm.csv'),
    ('ISIC', './ISIC_relative_norm.csv'),
    ('ChestX', './ChestX_relative_norm.csv'),
]

plt.figure(figsize=(12, 4))

x_positions = range(len(params_list))
for curve_label, csv_path in curves:
    # Column-wise mean over the rows of the file gives one value per parameter tensor.
    df = pd.read_csv(csv_path, index_col=0)
    plt.plot(x_positions, list(df.mean()), label=curve_label, marker='o')

plt.xticks(x_positions, params_list, rotation=90)
plt.legend()
# plt.show()
plt.savefig('./src/relative_change.pdf', bbox_inches='tight', format='pdf')
plt.close()

# CKA (Centered Kernel Alignment)

In [None]:
def gram_linear(x):
    """Linear-kernel Gram matrix of a feature matrix.

    Args:
    x: A num_examples x num_features matrix of features.

    Returns:
    The num_examples x num_examples matrix of pairwise dot products.
    """
    transposed = x.T
    return x.dot(transposed)


def gram_rbf(x, threshold=1.0):
    """RBF-kernel Gram matrix of a feature matrix.

    The kernel bandwidth is tied to the data: the squared bandwidth is
    ``threshold ** 2`` times the median of all pairwise squared Euclidean
    distances (the diagonal zeros are part of that median).

    Args:
    x: A num_examples x num_features matrix of features.
    threshold: Fraction of the median Euclidean distance to use as RBF kernel
      bandwidth.

    Returns:
    A num_examples x num_examples Gram matrix of examples.
    """
    inner = x.dot(x.T)
    squared_norms = np.diag(inner)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    pairwise_sq = -2 * inner + squared_norms[:, None] + squared_norms[None, :]
    denominator = 2 * threshold ** 2 * np.median(pairwise_sq)
    return np.exp(-pairwise_sq / denominator)


def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix.

    This is equivalent to centering the (possibly infinite-dimensional) features
    induced by the kernel before computing the Gram matrix.

    The input is never modified. Integer or boolean input is converted to
    float64 first, because the centering subtracts float64 means in place and
    NumPy refuses to cast that result back into an integer array.

    Args:
    gram: A num_examples x num_examples symmetric matrix (array-like).
    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
      estimate of HSIC. Note that this estimator may be negative. The
      computation divides by (n - 2), so it needs more than two examples.

    Returns:
    A symmetric matrix with centered columns and rows.

    Raises:
    ValueError: If `gram` is not symmetric.
    """
    gram = np.asarray(gram)
    if not np.allclose(gram, gram.T):
        raise ValueError('Input must be a symmetric matrix.')

    # Work on a private copy; keep float/complex dtypes, promote everything else.
    if np.issubdtype(gram.dtype, np.inexact):
        gram = gram.copy()
    else:
        gram = gram.astype(np.float64)

    if unbiased:
        # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
        # L. (2014). Partial distance correlation with methods for dissimilarities.
        # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
        # stable than the alternative from Song et al. (2007).
        n = gram.shape[0]
        np.fill_diagonal(gram, 0)
        means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
        means -= np.sum(means) / (2 * (n - 1))
        gram -= means[:, None]
        gram -= means[None, :]
        np.fill_diagonal(gram, 0)
    else:
        # Subtracting (column mean - half the grand mean) along both axes is
        # the usual double centering: K_ij - mean_i - mean_j + grand_mean.
        means = np.mean(gram, 0, dtype=np.float64)
        means -= np.mean(means) / 2
        gram -= means[:, None]
        gram -= means[None, :]

    return gram

def cka(gram_x, gram_y, debiased=False):
    """Centered kernel alignment between two Gram matrices.

    Args:
    gram_x: A num_examples x num_examples Gram matrix.
    gram_y: A num_examples x num_examples Gram matrix.
    debiased: Use unbiased estimator of HSIC. CKA may still be biased.

    Returns:
    The value of CKA between X and Y.
    """
    centered_x = center_gram(gram_x, unbiased=debiased)
    centered_y = center_gram(gram_y, unbiased=debiased)

    # Inner product of the two centered matrices. The factor that would turn it
    # into HSIC ((n-1)**2 for the biased variant, n*(n-3) for the unbiased one)
    # is omitted because it cancels in the ratio below.
    scaled_hsic = centered_x.ravel().dot(centered_y.ravel())

    denominator = np.linalg.norm(centered_x) * np.linalg.norm(centered_y)
    return scaled_hsic / denominator


def _debiased_dot_product_similarity_helper(xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n):
    """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
    # This formula can be derived by manipulating the unbiased estimator from
    # Song et al. (2007).
    return (
      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))


def feature_space_linear_cka(features_x, features_y, debiased=False):
    """Linear-kernel CKA computed directly from the feature matrices.

    Working in feature space avoids building the num_examples x num_examples
    Gram matrices, which is typically faster when there are fewer features
    than examples.

    Args:
    features_x: A num_examples x num_features matrix of features.
    features_y: A num_examples x num_features matrix of features.
    debiased: Use unbiased estimator of dot product similarity. CKA may still be
      biased. Note that this estimator may be negative.

    Returns:
    The value of CKA between X and Y.
    """
    # Remove the per-feature mean from both representations.
    centered_x = features_x - np.mean(features_x, 0, keepdims=True)
    centered_y = features_y - np.mean(features_y, 0, keepdims=True)

    similarity = np.linalg.norm(centered_x.T.dot(centered_y)) ** 2
    norm_x = np.linalg.norm(centered_x.T.dot(centered_x))
    norm_y = np.linalg.norm(centered_y.T.dot(centered_y))

    if not debiased:
        return similarity / (norm_x * norm_y)

    num_examples = centered_x.shape[0]
    # Per-example squared norms; einsum avoids the intermediate squared array.
    row_sq_x = np.einsum('ij,ij->i', centered_x, centered_x)
    row_sq_y = np.einsum('ij,ij->i', centered_y, centered_y)
    total_sq_x = np.sum(row_sq_x)
    total_sq_y = np.sum(row_sq_y)

    similarity = _debiased_dot_product_similarity_helper(
        similarity, row_sq_x, row_sq_y, total_sq_x, total_sq_y, num_examples)
    norm_x = np.sqrt(_debiased_dot_product_similarity_helper(
        norm_x ** 2, row_sq_x, row_sq_x, total_sq_x, total_sq_x, num_examples))
    norm_y = np.sqrt(_debiased_dot_product_similarity_helper(
        norm_y ** 2, row_sq_y, row_sq_y, total_sq_y, total_sq_y, num_examples))

    return similarity / (norm_x * norm_y)

def cca(features_x, features_y):
    """Mean squared CCA correlation (R^2_{CCA}) between two feature matrices.

    Args:
    features_x: A num_examples x num_features matrix of features.
    features_y: A num_examples x num_features matrix of features.

    Returns:
    The mean squared CCA correlations between X and Y.
    """
    # Orthonormal bases from the QR decomposition of each matrix
    # (an SVD with full_matrices=False would work as well).
    basis_x = np.linalg.qr(features_x)[0]
    basis_y = np.linalg.qr(features_y)[0]
    squared_overlap = np.linalg.norm(basis_x.T.dot(basis_y)) ** 2
    smaller_dim = min(features_x.shape[1], features_y.shape[1])
    return squared_overlap / smaller_dim

In [None]:
# cka = feature_space_linear_cka()