In [2]:
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
    Trainer,PretrainedConfig,
    DataCollatorWithPadding,
    default_data_collator

)
from datasets import load_dataset
import evaluate
from distillation import DistilModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Local checkpoint directory of the teacher classifier
# (the path name suggests an MRPC run -- TODO confirm).
TEACHER_CHECKPOINT = "/scratch/pb2276/GLUE-pretrain/out_p2_l/mrpc"
teacher = AutoModelForSequenceClassification.from_pretrained(TEACHER_CHECKPOINT)

In [None]:
teacher

In [1]:
import torch

In [11]:
# Prefer the GPU, but fall back to the CPU so that the cells below still run
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [57]:
# Random (12, 128, 1024) tensor moved to `device`. requires_grad_() is applied
# after the move so the tensor living on `device` is the one autograd tracks
# (the `wx.is_leaf` / `wx.grad` cells further down inspect it).
wx=torch.rand((12,128, 1024)).to(device).requires_grad_()

In [58]:
# Random (12, 128, 1024) tensor moved to `device`; requires_grad_() is applied
# after the move so gradients accumulate on the on-device tensor.
wy= torch.rand((12,128,1024)).to(device).requires_grad_()

In [50]:
wxy = torch.bmm(wx.transpose(1, 2), wy)

In [42]:
wxy.shape

torch.Size([12, 1024, 1024])

In [51]:
with torch.no_grad():
    # The SVD only supplies fixed factors for the loss cell, so both the
    # cross-product and the decomposition stay outside the autograd graph.
    wxy = torch.bmm(wx.transpose(1, 2), wy)
    # `driver` selects a cuSOLVER routine and is only accepted for CUDA inputs;
    # leave it as None when the tensors live on the CPU.
    # NOTE: the third return value of torch.linalg.svd is V^H, stored here as `Vt`.
    U, _, Vt = torch.linalg.svd(wxy, driver="gesvd" if wxy.is_cuda else None)

In [52]:
# Frobenius norm of (wx @ U - wy @ V) taken over dims (1, 2), with V = Vt^T,
# which leaves one loss value per entry of the leading (batch) dimension.
loss=torch.linalg.norm( torch.bmm(wx, U) - torch.bmm(wy, Vt.transpose(1,2)), ord="fro", dim=(1, 2))

In [53]:
loss.mean().backward()

In [54]:
wx.is_leaf

True

In [59]:
wx.grad

# CKA (Centered Kernel Alignment)

In [1]:
import torch
from torch import Tensor
from torch.nn.functional import pad

from typing import Literal, Tuple, Optional, List


In [40]:
r1=torch.rand((20, 128,1024))

In [41]:
r2=torch.rand((20,128,1024))

In [23]:
r2_sim= torch.bmm(r2, torch.transpose(r2, 1, 2)) 

In [8]:
# r1 / r2 are 3-D (batch, tokens, features). `.T` on a 3-D tensor reverses *all*
# dimensions, so `r1 @ r1.T` does not build a Gram matrix (and normally fails).
# Collapse batch and token dims first so every token is one sample; this yields
# the (batch * tokens, batch * tokens) similarity matrices used by later cells.
# flatten(end_dim=-2) is a no-op for inputs that are already 2-D.
r1_flat = r1.flatten(end_dim=-2)
r2_flat = r2.flatten(end_dim=-2)
r1_sim = r1_flat @ r1_flat.T
r2_sim = r2_flat @ r2_flat.T

In [None]:
128*128 == 128

In [None]:
# Per-sample Gram matrices: swap only the last two dims. `r1.T` would reverse
# every dim of the 3-D tensor, leaving torch.bmm with mismatched batch sizes.
torch.bmm(r1, r1.transpose(1, 2))

In [15]:
mask=torch.eye(r1_sim.shape[-1], dtype=torch.bool).unsqueeze(0)

In [22]:
r1_sim.masked_fill_(mask[0], 0).shape

torch.Size([2560, 2560])

In [27]:
kl=r1_sim@r2_sim

In [30]:
kl.diagonal().sum()

tensor(4.2904e+11)

In [31]:
r1_sim.sum()

tensor(1.6763e+09)

In [10]:
r1.diagonal(dim1=1, dim2=2).shape

torch.Size([20, 128])

In [None]:
def hsic(K,L):
    """MiniBatch CKA using unbiased estimator of HSIC from Song et al 2012 (https://www.jmlr.org/papers/volume13/song12a/song12a.pdf)

    Computes

        HSIC(K, L) = [ tr(KL) + (1'K1)(1'L1) / ((n-1)(n-2)) - 2 (1'KL1) / (n-2) ] / (n (n-3))

    Args:
        K, L: similarity matrices of shape (N, N) or (B, N, N). The estimator
            assumes their diagonals have already been set to zero.

    Returns:
        A scalar tensor for (N, N) inputs, or a tensor of shape (B,) for
        batched inputs.

    Raises:
        ValueError: if N < 4, where the estimator's denominators vanish.
    """
    n = K.shape[-1]
    if n < 4:
        raise ValueError(f'Unbiased HSIC needs at least 4 samples but got {n}.')

    # tr(KL) = sum_ij K_ij * L_ji, so the full matrix product is never needed.
    trace = (K * L.transpose(-1, -2)).sum(dim=(-2, -1))
    # 1'KL1 = <column sums of K, row sums of L>.
    kl_sum = (K.sum(dim=-2) * L.sum(dim=-1)).sum(dim=-1)

    middle = K.sum(dim=(-2, -1)) * L.sum(dim=(-2, -1)) / ((n - 1) * (n - 2))
    last = -2 * kl_sum / (n - 2)
    return (trace + middle + last) / (n * (n - 3))
    
    

In [44]:
import torch
from torch import Tensor
from torch.nn.functional import pad

from typing import Literal, Tuple, Optional, List
import random

class CKA(torch.nn.Module):
    """Minibatch Centered Kernel Alignment (CKA) between two activation tensors.

    The score is ``HSIC(K, L) / (sqrt(HSIC(K, K)) * sqrt(HSIC(L, L)))`` where K and
    L are zero-diagonal similarity matrices of the two inputs and HSIC is the
    unbiased estimator of Song et al. 2012
    (https://www.jmlr.org/papers/volume13/song12a/song12a.pdf).

    Args:
        dim_matching: how to reconcile different feature sizes of X and Y:
            ``None``/``'none'`` (raise), ``'zero_pad'`` (right-pad X with zeros)
            or ``'pca'`` (not implemented).
        reduction: ``'mean'``, ``'sum'`` or ``'none'``/``None``; applied to the
            CKA tensor returned by ``forward``.
        kernel: similarity kernel; only ``"linear"`` is implemented.
        similarity_token_strategy: which rows act as samples:
            ``"flatten"`` -- every token of every batch element is a sample;
            ``"random"``  -- a cached random subset of token positions is kept
            for each batch element, then flattened;
            ``None``/``"none"`` -- keep the batch dimension and compute one
            token-by-token similarity matrix (and one CKA value) per element.
    """

    def __init__(self,dim_matching='zero_pad', reduction='mean', kernel="linear", similarity_token_strategy="flatten"):
        super().__init__()
        assert dim_matching in [None, 'none', 'zero_pad', 'pca']
        self.dim_matching = dim_matching
        self.reduction = reduction
        self.kernel=kernel
        self.similarity_token_strategy=similarity_token_strategy
        # Token positions used by the "random" strategy; drawn lazily on first use
        # and reused afterwards so X and Y are always compared on the same tokens.
        self.random_tokens=None

    def generate_random_token_index(self, token_size, selected_size=10):
        """Draw and cache distinct token positions for the "random" strategy.

        Args:
            token_size: number of token positions available.
            selected_size: number of positions to keep; clamped to ``token_size``
                so that short sequences use every token instead of making
                ``random.sample`` raise.
        """
        selected_size = min(selected_size, token_size)
        self.random_tokens = random.sample(range(token_size), selected_size)

    def create_sim_matrix(self, X:Tensor, diag_zero=True):
        """Build the similarity (Gram) matrix of ``X``.

        Args:
            X: activations; ``forward`` passes a (batch, tokens, features) tensor.
            diag_zero: zero the diagonal, as the unbiased HSIC estimator requires.

        Returns:
            An (N, N) matrix for the "flatten"/"random" strategies, or a
            (batch, tokens, tokens) tensor when the batch dimension is kept.

        Raises:
            ValueError: for an unknown ``similarity_token_strategy``.
            NotImplementedError: for any kernel other than ``"linear"``.
        """
        strategy = self.similarity_token_strategy
        if strategy == "flatten":
            X = torch.flatten(X, end_dim=-2)
        elif strategy == "random":
            num_tokens = X.shape[1]
            # Redraw when nothing is cached yet or the cached positions do not
            # fit the current sequence length.
            if not self.random_tokens or max(self.random_tokens) >= num_tokens:
                self.generate_random_token_index(num_tokens)
            X = torch.flatten(X[:, self.random_tokens, :], end_dim=-2)
        elif strategy is not None and strategy != "none":
            raise ValueError(f'Unrecognized similarity token strategy {strategy}')

        if self.kernel == "linear":
            # transpose(-1, -2) is correct for 2-D and batched input alike, whereas
            # `.T` reverses every dimension of a 3-D tensor.
            sim_m = X @ X.transpose(-1, -2)
        else:
            raise NotImplementedError

        if diag_zero:
            # Build the mask on the matrix's own device (a CPU mask cannot be used
            # with a CUDA tensor); it broadcasts over a leading batch dimension.
            diagonal_mask = torch.eye(sim_m.shape[-1], dtype=torch.bool, device=sim_m.device)
            sim_m.masked_fill_(diagonal_mask, 0)

        return sim_m

    def HSIC(self, K,L):
        """Unbiased HSIC estimate for zero-diagonal similarity matrices.

        Args:
            K, L: similarity matrices of form (B, N, N) where B is batch, or (N, N).

        Returns:
            A scalar tensor for (N, N) inputs, or a tensor of shape (B,).

        Raises:
            ValueError: if N < 4, where the estimator's denominators vanish.
        """
        n = K.shape[-1]
        if n < 4:
            raise ValueError(f'Unbiased HSIC needs at least 4 samples but got {n}.')

        # tr(KL) = sum_ij K_ij * L_ji and 1'KL1 = <column sums of K, row sums of L>,
        # so the O(n^3) product K @ L is never materialised.
        trace = (K * L.transpose(-1, -2)).sum(dim=(-2, -1))
        kl_sum = (K.sum(dim=-2) * L.sum(dim=-1)).sum(dim=-1)

        middle = K.sum(dim=(-2, -1)) * L.sum(dim=(-2, -1)) / ((n - 1) * (n - 2))
        last = -2 * kl_sum / (n - 2)
        return (trace + middle + last) / (n * (n - 3))

    def forward(self, X: Tensor, Y: Tensor):
        """Compute the CKA score between ``X`` and ``Y``.

        Args:
            X, Y: 3-D tensors that agree in every dimension but the last. With
                ``dim_matching='zero_pad'`` X may have fewer features than Y.

        Returns:
            The CKA tensor after ``reduction``. It is a scalar for the
            "flatten"/"random" strategies and has one entry per batch element
            when the batch dimension is kept.
            NOTE: the unbiased self-HSIC terms are not guaranteed to be positive;
            a non-positive value makes the result NaN/inf.

        Raises:
            ValueError: on incompatible shapes, unknown ``dim_matching`` or
                unknown ``reduction``.
            NotImplementedError: for ``dim_matching='pca'``.
        """
        if X.shape[:-1] != Y.shape[:-1] or X.ndim != 3 or Y.ndim != 3:
            raise ValueError('Expected 3D input matrices to match in all dimensions but last.'
                             f'But got {X.shape} and {Y.shape} instead.')

        if X.shape[-1] != Y.shape[-1]:
            if self.dim_matching is None or self.dim_matching == 'none':
                raise ValueError(f'Expected same dimension matrices got instead {X.shape} and {Y.shape}. '
                                 f'Set dim_matching or change matrix dimensions.')
            elif self.dim_matching == 'zero_pad':
                size_diff = Y.shape[-1] - X.shape[-1]
                if size_diff < 0:
                    raise ValueError(f'With `zero_pad` dimension matching expected X dimension to be smaller then Y. '
                                     f'But got {X.shape} and {Y.shape} instead.')
                # Right-pad the feature dimension of X up to Y's width.
                X = pad(X, (0, size_diff))
            elif self.dim_matching == 'pca':
                raise NotImplementedError
            else:
                raise ValueError(f'Unrecognized dimension matching {self.dim_matching}')

        X_sim_matrix = self.create_sim_matrix(X)
        Y_sim_matrix = self.create_sim_matrix(Y)

        self_hsic_x = self.HSIC(X_sim_matrix, X_sim_matrix)
        self_hsic_y = self.HSIC(Y_sim_matrix, Y_sim_matrix)
        cross_hsic = self.HSIC(X_sim_matrix, Y_sim_matrix)

        batched_cka = cross_hsic / (torch.sqrt(self_hsic_x) * torch.sqrt(self_hsic_y))

        if self.reduction == 'mean':
            return batched_cka.mean()
        elif self.reduction == 'sum':
            return batched_cka.sum()
        elif self.reduction == 'none' or self.reduction is None:
            return batched_cka
        else:
            raise ValueError(f'Unrecognized reduction {self.reduction}')

In [48]:
s=CKA(reduction="sum")

In [49]:
out= s(r1, r2)

In [50]:
out

tensor(0.0014)