**You need to adapt paths based on your current working directory to run the code**

In [None]:
#import packages
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import torchvision.models as models
import numpy as np
from copy import deepcopy
from glob import glob
import os
from os.path import basename

import matplotlib.pyplot as plt
from random import randint

import time

# Mount Google Drive inside the Colab runtime and work from the results folder.
from google.colab import files, drive

drive.mount('/content/gdrive')
os.chdir('gdrive/MyDrive/results/')

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print("device is ", device)

Mounted at /content/gdrive
device is  cuda:0


In [None]:
def positive_associate_features(corr):
  """Greedily pair rows with columns of `corr`, largest signed value first.

  Entries of `corr` are visited in decreasing order of their (signed) value.
  An entry (i, j) becomes a pair when neither row i nor column j has been
  paired before; this continues until min(d1, d2) pairs exist.  Ties are
  resolved in row-major order.

  Args:
    corr: 2-D matrix of shape (d1, d2).  A numpy array, or an object exposing
      the torch tensor interface (`detach().cpu()`).

  Returns:
    associates: integer numpy array of shape (2, min(d1, d2)).  Each column is
      one (row index, column index) pair.  Columns are ordered by the index
      belonging to the smaller dimension (row index when d1 <= d2, column
      index otherwise).
    permute_corr: `corr` with its larger dimension reordered: first the paired
      indices in the order of `associates`, then the unpaired indices in
      increasing order.
  """
  d1, d2 = corr.shape
  mindim = min(d1, d2)

  # Rank on a float64 numpy copy; `corr` itself is only indexed at the end so
  # the returned matrix keeps the type of the input.
  if hasattr(corr, 'detach'):
    scores = corr.detach().cpu().double().numpy()
  else:
    scores = np.asarray(corr).astype(np.float64)

  # Flat (row-major) positions from the largest score to the smallest.
  order = np.argsort(-scores.ravel(), kind='stable')

  # Boolean masks give O(1) "already paired?" checks.
  row_used = np.zeros(d1, dtype=bool)
  col_used = np.zeros(d2, dtype=bool)
  asdim1 = np.empty(mindim, dtype=int)
  asdim2 = np.empty(mindim, dtype=int)

  associated = 0
  for flat in order:
    if associated == mindim:
      break
    i, j = divmod(int(flat), d2)
    if row_used[i] or col_used[j]:
      continue
    row_used[i] = True
    col_used[j] = True
    asdim1[associated] = i
    asdim2[associated] = j
    associated += 1

  # Order the pairs by the index of the smaller dimension.
  sortind = np.argsort(asdim1 if d1 <= d2 else asdim2)
  associates = np.array([asdim1[sortind], asdim2[sortind]])

  # Integer permutation of the larger dimension: paired indices first, then
  # whatever was left unpaired (already in increasing order).
  if d1 <= d2:
    permute = np.concatenate([associates[1], np.flatnonzero(~col_used)])
    permute_corr = corr[:, permute]
  else:
    permute = np.concatenate([associates[0], np.flatnonzero(~row_used)])
    permute_corr = corr[permute, :]

  return associates, permute_corr

In [None]:

# Pick one saved correlation matrix at random and compare the pairings
# produced by associate_features and positive_associate_features.
root = 'gdrive/MyDrive/Project/'

adds = glob(root + "*corr*")
adds = [add for add in adds if 'permute' not in add]
# print(*adds, sep='\n')

if not adds:
  raise FileNotFoundError('no correlation files found under ' + root)

# np.random.randint excludes its upper bound, so len(adds) (not len(adds) - 1)
# keeps the last file eligible and also works when only one file exists.
num = np.random.randint(0, len(adds))
print('num is', num , os.path.basename(adds[num]))
corr = torch.load(adds[num]).to(device)

asso1, _ = associate_features(corr)
asso2, _ = positive_associate_features(corr)

plt.subplots(figsize=(24,10))

# Left panel: both pairings drawn as (row index, column index) points.
plt.subplot(121)
ms = 2
plt.plot(asso1[0], asso1[1], '.', markersize=ms)
plt.plot(asso2[0], asso2[1], '.', markersize=ms)
plt.legend(['new', 'old'])

# Right panel: second index of one pairing against the other.
plt.subplot(122)
plt.plot(asso1[1], asso2[1], '.', markersize=ms)


In [None]:
# helper functions

def associate_features(corr):
  """Greedily pair rows with columns of `corr`, largest magnitude first.

  Entries of `corr` are visited in decreasing order of their absolute value.
  An entry (i, j) becomes a pair when neither row i nor column j has been
  paired before; this continues until min(d1, d2) pairs exist.  Ties are
  resolved in row-major order.

  Args:
    corr: 2-D matrix of shape (d1, d2).  A numpy array, or an object exposing
      the torch tensor interface (`detach().cpu()`).

  Returns:
    associates: integer numpy array of shape (2, min(d1, d2)).  Each column is
      one (row index, column index) pair.  Columns are ordered by the index
      belonging to the smaller dimension (row index when d1 <= d2, column
      index otherwise).
    permute_corr: `corr` with its larger dimension reordered: first the paired
      indices in the order of `associates`, then the unpaired indices in
      increasing order.
  """
  d1, d2 = corr.shape
  mindim = min(d1, d2)

  # Rank on a float64 numpy copy; `corr` itself is only indexed at the end so
  # the returned matrix keeps the type of the input.
  if hasattr(corr, 'detach'):
    scores = corr.detach().cpu().double().numpy()
  else:
    scores = np.asarray(corr).astype(np.float64)
  scores = np.abs(scores)

  # Flat (row-major) positions from the largest magnitude to the smallest.
  order = np.argsort(-scores.ravel(), kind='stable')

  # Boolean masks give O(1) "already paired?" checks.
  row_used = np.zeros(d1, dtype=bool)
  col_used = np.zeros(d2, dtype=bool)
  asdim1 = np.empty(mindim, dtype=int)
  asdim2 = np.empty(mindim, dtype=int)

  associated = 0
  for flat in order:
    if associated == mindim:
      break
    i, j = divmod(int(flat), d2)
    if row_used[i] or col_used[j]:
      continue
    row_used[i] = True
    col_used[j] = True
    asdim1[associated] = i
    asdim2[associated] = j
    associated += 1

  # Order the pairs by the index of the smaller dimension.
  sortind = np.argsort(asdim1 if d1 <= d2 else asdim2)
  associates = np.array([asdim1[sortind], asdim2[sortind]])

  # Integer permutation of the larger dimension: paired indices first, then
  # whatever was left unpaired (already in increasing order).
  if d1 <= d2:
    permute = np.concatenate([associates[1], np.flatnonzero(~col_used)])
    permute_corr = corr[:, permute]
  else:
    permute = np.concatenate([associates[0], np.flatnonzero(~row_used)])
    permute_corr = corr[permute, :]

  return associates, permute_corr


def take_correlation(covariance, data1, data2):
  """Turn a cross-covariance matrix into a correlation matrix.

  Every covariance entry (i, j) is divided by std(data1[:, i]) *
  std(data2[:, j]).

  Args:
    covariance: tensor that is element-wise divisible by a (d1, d2) matrix,
      where d1 and d2 are the column counts of `data1` and `data2`.
    data1: tensor whose columns (dim 1) are the first set of features.
    data2: tensor whose columns (dim 1) are the second set of features.

  Returns:
    The correlation tensor, located on the module-level `device`.
  """

  # send inputs to device
  covariance = covariance.to(device)
  data1 = data1.to(device)
  data2 = data2.to(device)

  # Population standard deviations (divide by N).  take_covariance divides by
  # the number of samples as well; the default N-1 estimator would scale
  # every correlation by (N-1)/N, so a feature's self-correlation would not
  # be 1.
  std_data1 = torch.std(data1, dim=0, unbiased=False).reshape(-1, 1)
  std_data2 = torch.std(data2, dim=0, unbiased=False).reshape(1, -1)

  # Outer product: entry (i, j) is std_data1[i] * std_data2[j].
  stdmatrix = torch.matmul(std_data1, std_data2)

  corr = torch.div(covariance, stdmatrix)

  return corr
  

def take_covariance(data1, data2):
  """Cross-covariance between the columns of `data1` and `data2`.

  Args:
    data1: 2-D tensor; dim 0 indexes samples, dim 1 indexes features.
    data2: 2-D tensor with the same number of rows as `data1`.

  Returns:
    Tensor of shape (data1.shape[1], data2.shape[1]) on the module-level
    `device`, normalised by the number of samples.
  """

  # send inputs to device
  data1 = data1.to(device)
  data2 = data2.to(device)

  datasize = data1.shape[0]

  # Center out of place: `.to(device)` hands back the caller's own tensor when
  # it already lives on `device`, so an in-place subtraction would silently
  # modify the caller's data.
  centered1 = data1 - torch.sum(data1, dim=0, keepdim=True) / datasize

  # Only one factor needs centering: the centered columns sum to zero, so the
  # mean of data2 contributes nothing to the product.
  cov = torch.matmul(torch.transpose(centered1, 0, 1), data2) / datasize

  return cov


def dataname_from_address(add):
  """Extract the dataset name embedded in a feature-file path.

  The name is the part of the file name that follows the first '-on-' and
  runs up to the next '-'.  An IndexError is raised when the file name holds
  no '-on-' marker.
  """
  after_marker = basename(add).split('-on-', 1)[1]
  return after_marker.split('-', 1)[0]

def get_save_cavariance(featadd1, featadd2,  
  saveroot= 'Simran/cov_corr_associate/'):
  """Compute and save covariance, correlation and pairing for two feature files.

  Both files must carry the same dataset name (see dataname_from_address).
  Results are written under `saveroot/<dataset name without '_trn'>/` as
  '<prefix>-cov.pt', '<prefix>-corr.pt', '<prefix>-asso.pt' and
  '<prefix>-permutedcorr.pt', where <prefix> joins the two base names with
  '-vs-' and drops '_features.pt'.

  Args:
    featadd1: path of the first feature file (loaded with torch.load).
    featadd2: path of the second feature file (loaded with torch.load).
    saveroot: directory under which the per-dataset folder is placed.

  Returns:
    None.  Prints 'not matched data' and does nothing when the dataset names
    differ; prints 'already exists' and does nothing when the permuted
    correlation file is already present.
  """
  data1_name = dataname_from_address(featadd1)
  data2_name = dataname_from_address(featadd2)

  if data1_name != data2_name:
    print('not matched data')
    return None

  saveroot += data1_name.replace('_trn', '') + "/"

  saveprefix = basename(featadd1) + "-vs-" + basename(featadd2)
  saveprefix = saveprefix.replace('_features.pt','')
  saveprefix = saveroot + saveprefix

  # The permuted correlation is written last, so its presence marks a pair
  # that has been fully processed.
  permute_corradd = saveprefix + '-permutedcorr.pt'
  if os.path.exists(permute_corradd):
    print('already exists')
    return None

  # Create the output folder up front; otherwise a missing folder would only
  # surface at the first save, after all the expensive work is done.
  os.makedirs(saveroot, exist_ok=True)

  data1 = torch.load(featadd1)
  data2 = torch.load(featadd2)

  cov = take_covariance(data1=data1, data2=data2)
  corr = take_correlation(covariance=cov, data1=data1, data2=data2)
  asso, permute_corr = associate_features(corr=corr)

  covadd = saveprefix + '-cov.pt'
  torch.save(cov, covadd)

  corradd = saveprefix + '-corr.pt'
  torch.save(corr, corradd)

  assoadd = saveprefix + '-asso.pt'
  torch.save(asso, assoadd)

  torch.save(permute_corr, permute_corradd)


'done'

'done'

In [None]:
# save corr data
featroot = 'Saravanan/Vggnet_features/'
saveroot = 'Simran/cov_corr_associate/'
oldfeatroot = 'Saravanan/Features/'

featadds = glob(featroot + "*features*")
oldfeatadds = glob(oldfeatroot + "*features*")

# Running counter of the pairs handled so far; shared by both passes.
ind = 1

# Pass 1: pairs of files from featroot.  Each unordered pair (a file paired
# with itself included) is visited once, and only when both files carry the
# same dataset name.
for add1 in featadds:
  for add2 in featadds:
    if add1 > add2:
      # the mirrored pair is handled when the roles are swapped
      continue

    data1_name = dataname_from_address(add1)
    data2_name = dataname_from_address(add2)
    if data1_name != data2_name:
      continue

    get_save_cavariance(add1, add2, saveroot=saveroot)
    print(ind, basename(add1), basename(add2))
    ind += 1

print('\n start cross correlations \n')

# Pass 2: every file from featroot against every file from oldfeatroot that
# carries the same dataset name.
for add1 in featadds:
  for add2 in oldfeatadds:
    data1_name = dataname_from_address(add1)
    data2_name = dataname_from_address(add2)
    if data1_name != data2_name:
      continue

    get_save_cavariance(add1, add2, saveroot=saveroot)
    print(ind, basename(add1), basename(add2))
    ind += 1


already exists
1 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt
already exists
2 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=negative_kaiming_normal-on-cifar_trn-20epochs-adam_features.pt
already exists
3 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=kaiming_uniform-on-cifar_trn-20epochs-adam_features.pt
already exists
4 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=negative_kaiming_uniform-on-cifar_trn-20epochs-adam_features.pt
already exists
5 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=kaiming_uniform-on-cifar_trn-20epochs-sgd_features.pt
6 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=negative_kaiming_uniform-on-cifar_trn-20epochs-sgd_features.pt
7 vgg19-init=kaiming_normal-on-cifar_trn-20epochs-adam_features.pt vgg19-init=kaiming_normal-on-cifar_tr