In [1]:
# import faiss
import numpy as np
import pickle
import sys
import torch

from numpy.random import multivariate_normal

sys.path.insert(1, '../')
from base import BaseDataset, get_data, get_cifar100_coarse

# Pytorch
## 3 ways to make vector indices
1. standard mean of all samples
2. mean with norm of all samples
3. herding 1 element closest to the mean

In [2]:
def class_mean_norm(x):
    """
    Calculate the class mean for given x.
    Take the norm into account.
    
    param:
    x - the samples in numpy array
    
    return: the mean of xs
    """
    class_mean = np.sum(x / np.linalg.norm(x, axis=1, keepdims=True), axis=0) / x.shape[0]
    class_mean = class_mean / np.linalg.norm(class_mean)
    return class_mean 

In [9]:
def class_mean_norm(x):
    """
    Calculate the class mean for given x while taking the norm into account.
    
    Concretely, it does weighted mean and covariance:
        Mean: 
            Calculate the Frobenius norm of each element and use them as weights for a weighted mean. 
            This gives more importance to elements with larger magnitudes in the Frobenius norm sense.
        Covariance: 
            Apply the same weighting scheme to the centered data (subtract the mean calculated above from each element) 
            before calculating the covariance matrix.
            
    Caveat: the norm of the class mean is used to calculate the final class mean.
    
    param:
    x - the samples in numpy array
    
    return: the mean of xs, the standard deviation and the covariance
    """
    class_mean = np.sum(x / np.linalg.norm(x, axis=1, keepdims=True), axis=0) / x.shape[0]
    class_mean = class_mean / np.linalg.norm(class_mean)
    
    standard_dev = class_mean.std()
    
    centered_data = x - class_mean
    cov = np.cov(centered_data.T, aweights=np.linalg.norm(x, axis=1)**2)
    
    return class_mean, standard_dev, cov

In [3]:
train_embedding_path = "../cifar100_coarse_train_embedding_nn.pt"
val_embedding_path = None
test_embedding_path = "../cifar100_coarse_test_embedding_nn.pt"
ignore_super = False

n_class = 100
chosen_superclass = {1: [1, 32, 67, 73, 91],
                    2: [54, 62, 70, 82, 92],
                    14: [2, 11, 35, 46, 98],
                    17: [47, 52, 56, 59, 96]}

data, task_cla, class_order = get_cifar100_coarse(train_embedding_path, test_embedding_path, None, chosen_superclass=chosen_superclass, ignore_super=ignore_super)

# print(class_order)
idx2cls = {k: v for k, v in enumerate(class_order)}
cls2idx = {v: k for k, v in enumerate(class_order)}

superclass_order: [1, 2, 14, 17]
class_order: [1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]
total_task: 4
chosen_superclass: {1: [1, 32, 67, 73, 91], 2: [54, 62, 70, 82, 92], 14: [2, 11, 35, 46, 98], 17: [47, 52, 56, 59, 96]}


# 29/01/2024
Checking data split

In [4]:
from collections import Counter

for i in range(4):
    print(f"i-{i}")
    print(f"Train data: {len(data[i]['trn']['x'])}")
    print(f"\t{Counter(data[i]['trn']['y'])}")
    print(f"Val data: {len(data[i]['val']['x'])}")
    print(f"\t{Counter(data[i]['val']['y'])}") 
    print(f"Test data: {len(data[i]['tst']['x'])}")
    print(f"\t{Counter(data[i]['tst']['y'])}")  
    print()  

i-0
Train data: 2000
	Counter({3: 400, 2: 400, 1: 400, 0: 400, 4: 400})
Val data: 500
	Counter({0: 100, 1: 100, 2: 100, 3: 100, 4: 100})
Test data: 500
	Counter({3: 100, 1: 100, 2: 100, 4: 100, 0: 100})

i-1
Train data: 2000
	Counter({9: 400, 7: 400, 6: 400, 5: 400, 8: 400})
Val data: 500
	Counter({5: 100, 6: 100, 7: 100, 8: 100, 9: 100})
Test data: 500
	Counter({9: 100, 7: 100, 8: 100, 5: 100, 6: 100})

i-2
Train data: 2000
	Counter({11: 400, 12: 400, 13: 400, 10: 400, 14: 400})
Val data: 500
	Counter({10: 100, 11: 100, 12: 100, 13: 100, 14: 100})
Test data: 500
	Counter({11: 100, 10: 100, 12: 100, 13: 100, 14: 100})

i-3
Train data: 2000
	Counter({19: 400, 16: 400, 17: 400, 15: 400, 18: 400})
Val data: 500
	Counter({15: 100, 16: 100, 17: 100, 18: 100, 19: 100})
Test data: 500
	Counter({17: 100, 16: 100, 15: 100, 18: 100, 19: 100})



# 29/01/2024
Done checking

In [5]:
all_X, all_y = [], []
for i in range(4):
    all_X.extend(data[i]['trn']['x'])
    all_y.extend(data[i]['trn']['y'])

all_X = torch.stack(all_X).numpy()
all_y = np.vectorize(idx2cls.get)(all_y)

print(all_X.shape)
# print(len(all_y))
# print(all_y)

(8000, 768)


In [6]:
data.keys()

dict_keys([0, 1, 2, 3, 'ncla', 'ordered'])

In [7]:
# samples, samples_val, samples_tst = {}, {}, {}

# all_X_val, all_X_tst = [], []
# all_y_val, all_y_tst = [], []

# for i in range(4):
#     all_X_val.extend(data[i]['val']['x'])
#     all_X_tst.extend(data[i]['tst']['x'])
    
#     all_y_val.extend(data[i]['val']['y'])
#     all_y_tst.extend(data[i]['tst']['y'])

# all_X_val = torch.stack(all_X_val).numpy()
# all_X_tst = torch.stack(all_X_tst).numpy()

# all_y_val = np.vectorize(idx2cls.get)(all_y_val)
# all_y_tst = np.vectorize(idx2cls.get)(all_y_tst)

# for cls_ in class_order:
#     x_ = all_X[all_y == cls_]
#     x_val = all_X_val[all_y_val == cls_]
#     x_tst = all_X_tst[all_y_tst == cls_]
    
#     samples[cls_] = x_
#     samples_val[cls_] = x_val
#     samples_tst[cls_] = x_tst

# to_save = {
#     'trn': samples,
#     'val': samples_val,
#     'tst': samples_tst,
# }

# pickle.dump(to_save, open("class_samples.pkl", "wb"))

# samples = pickle.load(open("class_samples.pkl", "rb"))
# for k, v in samples['trn'].items():
#     print(f"{len(v)}\t{len(samples['val'][k])}\t{len(samples['tst'][k])}")
    

400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100


### 1. Get standard mean for all samples

In [14]:
cls_mean = {}
cls_cov = {}
for cls_ in class_order:
    x_ = all_X[all_y == cls_]
    # cls_mean[cls_] = torch.tensor(m_)
    cls_mean[cls_] = np.mean(x_, axis=0)
    cls_cov[cls_] = np.cov(x_.T)

print(cls_mean)


{1: array([ 7.27316797e-01, -4.04498816e-01, -3.06583375e-01,  1.79068546e-03,
        3.81747007e-01, -7.67703176e-01, -2.52703071e-01,  1.67838842e-01,
       -1.51580185e-01,  6.17780685e-01, -4.16703939e-01,  5.87966383e-01,
       -4.70115244e-01, -1.92558134e+00, -3.89908642e-01,  1.41020149e-01,
       -3.83115381e-01,  9.67558682e-01,  9.25975859e-01,  5.11077881e-01,
        1.44513047e+00,  1.24922164e-01, -3.62681955e-01, -1.23705077e+00,
       -2.43729293e-01, -9.45885301e-01, -5.80741227e-01,  3.83955598e-01,
        3.71582299e-01, -2.26154998e-01,  1.10685456e+00,  8.91578138e-01,
        4.07950997e-01,  4.26266640e-01, -1.40797153e-01,  2.25647497e+00,
        1.70816407e-01,  4.89319026e-01, -4.28821295e-01, -9.67811048e-01,
       -4.40722972e-01, -3.22870612e-02,  5.08437037e-01,  6.98126629e-02,
        2.65299832e-03, -1.41270542e+00,  5.50778396e-02, -4.53112096e-01,
        2.98401213e+00, -1.13026679e+00, -9.69770700e-02,  4.31977808e-02,
        7.17757463e-0

### 2. Get mean with norm for all samples

In [9]:
cls_means_norm = {}
cls_cov_norm = {}
cls_std_norm = {}
for cls_ in class_order:
    x_ = all_X[all_y == cls_]
    m_, s_, c_ = class_mean_norm(x_)
    # cls_means_norm[cls_] = torch.tensor(m_)
    cls_means_norm[cls_] = m_
    cls_cov_norm[cls_] = c_
    cls_std_norm[cls_] = s_

print(cls_means_norm)


{1: array([ 3.36071216e-02, -1.79494098e-02, -1.31756952e-02, -1.69723015e-03,
        1.92792173e-02, -3.34479287e-02, -9.45155229e-03,  7.36456504e-03,
       -5.33567090e-03,  2.72536762e-02, -1.44927436e-02,  2.70924792e-02,
       -2.40473300e-02, -9.16189998e-02, -1.78088713e-02,  1.12323267e-02,
       -1.72548033e-02,  4.61303107e-02,  4.07063179e-02,  2.26761512e-02,
        6.89955577e-02,  5.94451046e-03, -1.68567970e-02, -5.46799004e-02,
       -1.07571213e-02, -4.34063189e-02, -2.88633611e-02,  1.90803204e-02,
        1.42364353e-02, -1.44891599e-02,  5.49393184e-02,  4.24354300e-02,
        2.02329941e-02,  2.14630067e-02, -5.75589109e-03,  1.03820324e-01,
        8.49708356e-03,  2.29365118e-02, -1.89931411e-02, -4.27323245e-02,
       -1.84517782e-02,  3.24220001e-03,  2.34537404e-02,  4.60152375e-03,
        3.90285067e-03, -6.35409877e-02,  6.31788047e-03, -2.01900136e-02,
        1.40390545e-01, -5.39872982e-02, -5.53053990e-03, -2.32257662e-04,
        3.00459899e-0

In [10]:
print(cls_means_norm[1].shape)
print(cls_cov_norm[1].shape)
print(cls_std_norm[1])

(768,)
(768, 768)
0.036012094


### 3. herding 1 element closest to the mean

In [13]:
# https://avalanche-api.continualai.org/en/v0.1.0/_modules/avalanche/training/storage_policy.html#HerdingSelectionStrategy
def herding(features):
    selected_indices = []
    center = features.mean(axis=0)
    current_center = center * 0

    for i in range(len(features)):
        # Compute distances with real center
        candidate_centers = current_center * i / (i + 1) + features / (i + 1)
        distances = pow(candidate_centers - center, 2).sum(axis=1)
        distances[selected_indices] = np.inf

        # Select best candidate
        new_index = distances.argmin().tolist()
        selected_indices.append(new_index)
        current_center = candidate_centers[new_index]

    return selected_indices

def closest_to_center(features):
    center = features.mean(axis=0)
    distances = pow(features - center, 2).sum(axis=1)
    return distances.argsort()

In [12]:
herds = {}
for cls_ in class_order:
    x_ = all_X[all_y == cls_]
    herd_index = herding(x_)[0]
    # herds[cls_] = torch.tensor(x_[herd_index])
    herds[cls_] = x_[herd_index]

print(herds)

{1: array([ 3.03295553e-01, -4.68688160e-01, -5.02184359e-03, -6.86901033e-01,
        4.62199330e-01, -1.52794540e+00,  2.04632670e-01,  1.10086605e-01,
       -5.84385395e-01,  8.23821276e-02, -4.71936353e-02,  1.02075076e+00,
       -2.28946298e-01, -1.72137296e+00, -1.33222091e+00,  6.17095292e-01,
        2.99843308e-02,  8.34216535e-01,  6.20164573e-01,  3.87073159e-01,
        1.37861264e+00, -3.99404854e-01,  4.73487675e-01, -1.32023096e+00,
       -5.52278399e-01, -1.38497782e+00, -4.83816266e-01,  7.49962986e-01,
        2.28707328e-01, -9.58132893e-02,  2.03235126e+00,  2.01178834e-01,
        6.69346511e-01,  9.59112763e-01, -4.25726503e-01,  3.02583528e+00,
        1.13890395e-01, -1.22360259e-01, -8.09126437e-01, -8.54385436e-01,
       -2.19733641e-01,  1.61202073e-01,  7.02559277e-02, -6.48962736e-01,
       -2.61798084e-01, -2.38647008e+00, -5.58014095e-01,  2.34946594e-01,
        2.90392280e+00, -1.19433367e+00, -6.76755667e-01,  1.19795978e-01,
        6.67879879e-0

### 4. Save for future use

In [15]:
# to_save = {
#     "data": data,
#     "task_cla": task_cla,
#     "class_order": class_order,
#     "class_mean": cls_mean,
#     "class_mean_norm": cls_means_norm,
#     "class_std_norm": cls_std_norm,
#     "class_cov_norm": cls_cov_norm,
#     "herding": herds,    
#     "samples": {
#         "trn": samples,
#         "val": samples_val,
#         "tst": samples_tst,
#     }
# }

# pickle.dump(to_save, open("distance_data.pkl", "wb"))

## Calculate Euclidean Distance using Standard Deviation and factor $\alpha$

In [16]:
data = pickle.load(open("distance_data.pkl", "rb"))
samples = data['samples']
data.keys()

dict_keys(['data', 'task_cla', 'class_order', 'class_mean', 'class_mean_norm', 'class_std_norm', 'class_cov_norm', 'herding', 'samples'])

In [17]:
print(data['class_mean'].keys())
print(samples['trn'].keys())
print(data['class_std_norm'])

dict_keys([1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96])
dict_keys(['trn', 'val', 'tst'])
{1: 0.036012094, 32: 0.035914794, 67: 0.03594373, 73: 0.035979237, 91: 0.035972454, 54: 0.035970148, 62: 0.036010336, 70: 0.035992365, 82: 0.035989147, 92: 0.0359705, 2: 0.035924025, 11: 0.035911325, 35: 0.035909723, 46: 0.035915878, 98: 0.035920296, 47: 0.035973284, 52: 0.03597919, 56: 0.035943724, 59: 0.03595718, 96: 0.035956316}


In [18]:
class_order = data['class_order']
idx2cls = {k: v for k, v in enumerate(class_order)}
cls2idx = {v: k for k, v in enumerate(class_order)}
cls2str = {1: "aquarium_fish", 32: "flatfish", 67: "ray", 73: "shark", 91: "trout", 
           54: "orchid", 62: "poppy", 70: "rose", 82: "sunflower", 92: "tulip",
           2: "baby", 11: "boy", 35: "girl", 46: "man", 98: "woman",
           47: "maple_tree", 52: "oak_tree", 56: "palm_tree", 59: "pine_tree", 96: "willow_tree"}

In [19]:
# t1_cls = [1, 32, 54, 62, 70]
# t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
# t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]
# t2_idx2cls = {k: v for k, v in enumerate(t2_cls)}

def calculate_distance(t1_cls, t2_cls):
    t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
    t2_idx2cls = {k: v for k, v in enumerate(t2_cls)}

    t1_mean_norms, t2_mean_norms = [], []
    for c1 in t1_cls:
        t1_mean_norms.append(data['class_mean_norm'][c1])
    for c2 in t2_cls:
        t2_mean_norms.append(data['class_mean_norm'][c2])
        
    t1_mean_norms = np.stack(t1_mean_norms)
    t2_mean_norms = np.stack(t2_mean_norms)

    distances = np.linalg.norm(t1_mean_norms[:, np.newaxis] - t2_mean_norms, axis=2)

    # for i, (cls_, query_vector) in enumerate(zip(t2_cls, t2_mean_norms)):
        # closest_index = np.argmin(distances[:, i])
        # print(f"{cls2str[cls_]} - {cls2str[t1_idx2cls[closest_index]]} : {pow(distances[closest_index, i], 2):.4f}")
    
    return distances, t1_idx2cls, t2_idx2cls

In [20]:
# print(distances)

def calculate_distance_within_alphas(distances, t1_cls, t2_cls, alphas):
    dist = np.square(distances.T)
    print(dist)

    alphas = [i for i in range(1, 30)]
    distance_within_alphas = {}

    for i, t2_cls_ in enumerate(t2_cls):
        t2_str = cls2str[t2_cls_]
        t2_distance = distance_within_alphas.get(t2_str, {})
        
        for alpha in alphas:
            t2_alpha = t2_distance.get(alpha, [])
            
            for j, t1_cls_ in enumerate(t1_cls):
                t1_str = cls2str[t1_cls_]
                t1_std = data['class_std_norm'][t1_cls_]
                dist_ = dist[i][j]
                print(i, t2_str, alpha, t1_str, t1_std, dist_)
                t2_alpha.append((t1_str, dist_, t1_std) if alpha * t1_std > dist_ else False)
                
            t2_distance[alpha] = t2_alpha
            distance_within_alphas[t2_str] = t2_distance
    
    return distance_within_alphas

In [21]:
t1_cls = [1, 32, 54, 62, 70]
t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

distances, t1_idx2cls, t2_idx2cls = calculate_distance(t1_cls, t2_cls)

for i, t1_cls_ in enumerate(t1_cls):
    t1_std = data['class_std_norm'][t1_cls_]
    for j, t2_cls_ in enumerate(t2_cls):
        t1_t2_distance = distances[i][j]
        print(cls2str[t1_cls_], "--", cls2str[t2_cls_], t1_std, t1_t2_distance)

aquarium_fish -- ray 0.036012094 0.9109668
aquarium_fish -- shark 0.036012094 0.9710969
aquarium_fish -- trout 0.036012094 0.7306827
aquarium_fish -- sunflower 0.036012094 1.2119881
aquarium_fish -- tulip 0.036012094 1.1776173
aquarium_fish -- baby 0.036012094 1.2008721
aquarium_fish -- boy 0.036012094 1.1691731
aquarium_fish -- girl 0.036012094 1.165725
aquarium_fish -- man 0.036012094 1.1631287
aquarium_fish -- woman 0.036012094 1.1851892
aquarium_fish -- maple_tree 0.036012094 1.1598411
aquarium_fish -- oak_tree 0.036012094 1.1834888
aquarium_fish -- palm_tree 0.036012094 1.1483909
aquarium_fish -- pine_tree 0.036012094 1.1365764
aquarium_fish -- willow_tree 0.036012094 1.1264191
flatfish -- ray 0.035914794 0.6625631
flatfish -- shark 0.035914794 0.9191732
flatfish -- trout 0.035914794 0.7070166
flatfish -- sunflower 0.035914794 1.0658798
flatfish -- tulip 0.035914794 1.0487133
flatfish -- baby 0.035914794 0.9400064
flatfish -- boy 0.035914794 0.80232775
flatfish -- girl 0.035914794

1. Transpose the distance `distance.T` to get `target (15) -> source (5)`
2. Calculate the distance from each `target` to `source`
3. Take into account $\alpha$ * `source_std`

In [22]:
alphas = [i for i in range(1, 30)]
distance_within_alphas = calculate_distance_within_alphas(distances, t1_cls, t2_cls, alphas)
distance_within_alphas

[[0.8298605  0.43898985 1.3834373  1.3602514  1.2309551 ]
 [0.94302917 0.8448793  1.5728471  1.5670425  1.4762114 ]
 [0.53389716 0.49987245 1.4213208  1.4860897  1.357609  ]
 [1.4689151  1.1360998  1.0507425  0.99467725 0.7342686 ]
 [1.3867825  1.0997996  0.60514796 0.43467182 0.3510074 ]
 [1.4420937  0.883612   1.4255637  1.4330143  1.187301  ]
 [1.3669658  0.6437298  1.3776999  1.3778806  1.1139759 ]
 [1.3589147  0.6798042  1.349      1.374446   1.0355508 ]
 [1.3528684  0.5869079  1.4564683  1.4070889  1.2081133 ]
 [1.4046736  0.78762335 1.394456   1.4062375  1.0807223 ]
 [1.3452313  0.7581604  1.1032561  1.0070745  0.8413917 ]
 [1.4006459  0.77890366 1.2517904  1.1535447  1.005237  ]
 [1.3188016  0.7896038  1.1403838  1.1068218  0.8940955 ]
 [1.291806   0.6893483  1.1780411  1.0648863  0.919125  ]
 [1.2688199  0.61729294 1.1050403  1.0205464  0.82391214]]
0 ray 1 aquarium_fish 0.036012094 0.8298605
0 ray 1 flatfish 0.035914794 0.43898985
0 ray 1 orchid 0.035970148 1.3834373
0 ray 1 

{'ray': {1: [False, False, False, False, False],
  2: [False, False, False, False, False],
  3: [False, False, False, False, False],
  4: [False, False, False, False, False],
  5: [False, False, False, False, False],
  6: [False, False, False, False, False],
  7: [False, False, False, False, False],
  8: [False, False, False, False, False],
  9: [False, False, False, False, False],
  10: [False, False, False, False, False],
  11: [False, False, False, False, False],
  12: [False, False, False, False, False],
  13: [False, ('flatfish', 0.43898985, 0.035914794), False, False, False],
  14: [False, ('flatfish', 0.43898985, 0.035914794), False, False, False],
  15: [False, ('flatfish', 0.43898985, 0.035914794), False, False, False],
  16: [False, ('flatfish', 0.43898985, 0.035914794), False, False, False],
  17: [False, ('flatfish', 0.43898985, 0.035914794), False, False, False],
  18: [False, ('flatfish', 0.43898985, 0.035914794), False, False, False],
  19: [False, ('flatfish', 0.4389898

In [23]:
def print_min_distance(distance_within_alphas):
    sources = distance_within_alphas.keys()
    for source in sources:
        alphas = distance_within_alphas[source]
        found = False
        
        for alpha, vals in alphas.items():
            if any(isinstance(item, tuple) for item in vals) and not found:
                for item in vals:
                    if isinstance(item, tuple) and source != item[0]:
                        print(f"{source},{item[0]},{item[1]:.5f},{item[2]:.5f},{alpha}")
                        found = True
                        break
        # print()

In [24]:
print_min_distance(distance_within_alphas)

ray,flatfish,0.43899,0.03591,13
shark,flatfish,0.84488,0.03591,24
trout,flatfish,0.49987,0.03591,14
sunflower,rose,0.73427,0.03599,21
tulip,rose,0.35101,0.03599,10
baby,flatfish,0.88361,0.03591,25
boy,flatfish,0.64373,0.03591,18
girl,flatfish,0.67980,0.03591,19
man,flatfish,0.58691,0.03591,17
woman,flatfish,0.78762,0.03591,22
maple_tree,flatfish,0.75816,0.03591,22
oak_tree,flatfish,0.77890,0.03591,22
palm_tree,flatfish,0.78960,0.03591,22
pine_tree,flatfish,0.68935,0.03591,20
willow_tree,flatfish,0.61729,0.03591,18


Proof of concept: calculate the `distance_within_alphas` for cartesian product, excluding the diagonal (because it will be 0)

In [25]:
t1_cls = [1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

distances, t1_idx2cls, t2_idx2cls = calculate_distance(t1_cls, t1_cls)

for i, t1_cls_ in enumerate(t1_cls):
    t1_std = data['class_std_norm'][t1_cls_]
    for j, t1_cls_ in enumerate(t1_cls):
        t1_t2_distance = distances[i][j]
        print(cls2str[t1_cls_], "--", cls2str[t1_cls_], t1_std, t1_t2_distance)

aquarium_fish -- aquarium_fish 0.036012094 0.0
flatfish -- flatfish 0.036012094 0.93689615
ray -- ray 0.036012094 0.9109668
shark -- shark 0.036012094 0.9710969
trout -- trout 0.036012094 0.7306827
orchid -- orchid 0.036012094 1.2259445
poppy -- poppy 0.036012094 1.2079712
rose -- rose 0.036012094 1.1960306
sunflower -- sunflower 0.036012094 1.2119881
tulip -- tulip 0.036012094 1.1776173
baby -- baby 0.036012094 1.2008721
boy -- boy 0.036012094 1.1691731
girl -- girl 0.036012094 1.165725
man -- man 0.036012094 1.1631287
woman -- woman 0.036012094 1.1851892
maple_tree -- maple_tree 0.036012094 1.1598411
oak_tree -- oak_tree 0.036012094 1.1834888
palm_tree -- palm_tree 0.036012094 1.1483909
pine_tree -- pine_tree 0.036012094 1.1365764
willow_tree -- willow_tree 0.036012094 1.1264191
aquarium_fish -- aquarium_fish 0.035914794 0.93689615
flatfish -- flatfish 0.035914794 0.0
ray -- ray 0.035914794 0.6625631
shark -- shark 0.035914794 0.9191732
trout -- trout 0.035914794 0.7070166
orchid -- 

In [26]:
alphas = [i for i in range(1, 30)]
distance_within_alphas = calculate_distance_within_alphas(distances, t1_cls, t1_cls, alphas)

[[0.         0.87777436 0.8298605  0.94302917 0.53389716 1.5029399
  1.4591944  1.4304892  1.4689151  1.3867825  1.4420937  1.3669658
  1.3589147  1.3528684  1.4046736  1.3452313  1.4006459  1.3188016
  1.291806   1.2688199 ]
 [0.87777436 0.         0.43898985 0.8448793  0.49987245 1.2613235
  1.2779337  1.0930789  1.1360998  1.0997996  0.883612   0.6437298
  0.6798042  0.5869079  0.78762335 0.7581604  0.77890366 0.7896038
  0.6893483  0.61729294]
 [0.8298605  0.43898985 0.         0.38103166 0.6616935  1.3834373
  1.3602514  1.2309551  1.2123548  1.2288872  1.1250894  0.99658626
  0.9814506  0.9529264  1.0462887  1.035406   1.0570855  0.9466541
  0.9151786  0.9009826 ]
 [0.94302917 0.8448793  0.38103166 0.         0.74592865 1.5728471
  1.5670425  1.4762114  1.4490457  1.4781109  1.3846058  1.2659986
  1.2754849  1.2079417  1.2910007  1.304616   1.3164773  1.1842687
  1.2083961  1.2230257 ]
 [0.53389716 0.49987245 0.6616935  0.74592865 0.         1.4213208
  1.4860897  1.357609   1.41

In [27]:
print_min_distance(distance_within_alphas)

aquarium_fish,trout,0.53390,0.03597,15
flatfish,ray,0.43899,0.03594,13
ray,shark,0.38103,0.03598,11
shark,ray,0.38103,0.03594,11
trout,flatfish,0.49987,0.03591,14
orchid,tulip,0.60515,0.03597,17
poppy,tulip,0.43467,0.03597,13
rose,tulip,0.35101,0.03597,10
sunflower,tulip,0.70958,0.03597,20
tulip,rose,0.35101,0.03599,10
baby,boy,0.27487,0.03591,8
boy,girl,0.09022,0.03591,3
girl,boy,0.09022,0.03591,3
man,boy,0.20994,0.03591,6
woman,girl,0.23238,0.03591,7
maple_tree,oak_tree,0.08704,0.03598,3
oak_tree,maple_tree,0.08704,0.03597,3
palm_tree,pine_tree,0.35872,0.03596,10
pine_tree,willow_tree,0.10686,0.03596,3
willow_tree,maple_tree,0.10775,0.03597,3


In [28]:
distance_within_alphas['maple_tree']

{1: [False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  ('maple_tree', 0.0, 0.035973284),
  False,
  False,
  False,
  False],
 2: [False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  ('maple_tree', 0.0, 0.035973284),
  False,
  False,
  False,
  False],
 3: [False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  ('maple_tree', 0.0, 0.035973284),
  ('oak_tree', 0.087040134, 0.03597919),
  False,
  False,
  ('willow_tree', 0.107745476, 0.035956316)],
 4: [False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  ('maple_tree', 0.0, 0.035973284),
  ('oak_tree', 0.087040134, 0.03597919),
  False,
  ('pine_tree', 0.1349383, 0.03595718),
  ('willow_tree', 0.107745476, 0.03

In [29]:
distance_within_alphas['maple_tree'].keys()

source = 'maple_tree'
alphas = distance_within_alphas[source]
found = False
# while not found:
for alpha, vals in alphas.items():
    if any(isinstance(item, tuple) for item in vals) and not found:
        for item in vals:
            if isinstance(item, tuple) and source != item[0]:
                print(f"{source},{item[0]},{item[1]:.5f},{item[2]:.5f},{alpha}")
                found = True
                break


maple_tree,oak_tree,0.08704,0.03598,3


# BACK TO FAISS SIMILARITY SEARCH
### Task orders: 1: [1, 32, 54, 62, 70], 2: [67, 73, 91, 82, 92]

#### Try put [1, 32, 54, 62, 70, 67, 73, 91, 82, 92] in Faiss index

In [1]:
import faiss
import numpy as np
import pickle

In [19]:
vector_rep = pickle.load(open("class_vector_representation.pkl", "rb"))
samples = pickle.load(open("class_samples.pkl", "rb"))

In [20]:
print(vector_rep['class_mean'].keys())
print(samples['trn'].keys())

dict_keys([1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96])
dict_keys([1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96])


In [4]:
def calculate_accuracy(index, samples, verbose=False):
    total_hit, total_miss = 0, 0
    
    for i, (k, v) in enumerate(samples.items()):
        _, I = index.search(v, 1)
        hit = np.where(I[:, 0] == i)[0].shape[0]
        miss = np.where(I[:, 0] != i)[0].shape[0]
        total_hit += hit
        total_miss += miss
        
        if verbose:
            print(f"class-{k}\thit: {hit}\tmiss: {miss}")
            
    hit_pct = total_hit / (total_hit + total_miss)
    miss_pct = total_miss / (total_hit + total_miss)
    print(f"hit_pct: {(100 * hit_pct):.2f}\tmiss_pct: {(100 * miss_pct):.2f}\n")

In [5]:
all_means = []
for k, v in vector_rep['class_mean'].items():
    all_means.append(v)
    
all_means = np.stack(all_means)

In [6]:
d = 768
nb = 20

index = faiss.IndexFlatL2(d)
index.add(all_means)

print("standard mean accuracy on train samples")
calculate_accuracy(index, samples['trn'])

print("standard mean accuracy on validation samples")
calculate_accuracy(index, samples['val'])

print("standard mean accuracy on test samples")
calculate_accuracy(index, samples['tst'])

standard mean accuracy on train samples
hit_pct: 72.52	miss_pct: 27.47

standard mean accuracy on validation samples
hit_pct: 70.30	miss_pct: 29.70

standard mean accuracy on test samples
hit_pct: 72.25	miss_pct: 27.75



In [7]:
all_means_norm = []
for k, v in vector_rep['class_mean_norm'].items():
    all_means_norm.append(v)
    
all_means_norm = np.stack(all_means_norm)

index_norm = faiss.IndexFlatL2(d)
index_norm.add(all_means_norm)

print("mean norm accuracy on train samples")
calculate_accuracy(index_norm, samples['trn'])

print("mean norm accuracy on validation samples")
calculate_accuracy(index_norm, samples['val'])

print("mean norm accuracy on test samples")
calculate_accuracy(index_norm, samples['tst'])

mean norm accuracy on train samples
hit_pct: 74.34	miss_pct: 25.66

mean norm accuracy on validation samples
hit_pct: 72.45	miss_pct: 27.55

mean norm accuracy on test samples
hit_pct: 73.90	miss_pct: 26.10



In [8]:
herds = []
for k, v in vector_rep['herding'].items():
    herds.append(v)
    
herds = np.stack(herds)

index_herding = faiss.IndexFlatL2(d)
index_herding.add(herds)

print("1-herding accuracy on train samples")
calculate_accuracy(index_herding, samples['trn'])

print("1-herding accuracy on validation samples")
calculate_accuracy(index_herding, samples['val'])

print("1-herding accuracy on test samples")
calculate_accuracy(index_herding, samples['tst'])

1-herding accuracy on train samples
hit_pct: 56.41	miss_pct: 43.59

1-herding accuracy on validation samples
hit_pct: 56.50	miss_pct: 43.50

1-herding accuracy on test samples
hit_pct: 57.45	miss_pct: 42.55



In [9]:
index_ip = faiss.IndexFlatIP(d)
index_ip.add(all_means)

print("standard mean IP accuracy on train samples")
calculate_accuracy(index_ip, samples['trn'])

print("standard mean IP accuracy on validation samples")
calculate_accuracy(index_ip, samples['val'])

print("standard mean IP accuracy on test samples")
calculate_accuracy(index_ip, samples['tst'])

standard mean IP accuracy on train samples
hit_pct: 64.33	miss_pct: 35.68

standard mean IP accuracy on validation samples
hit_pct: 63.50	miss_pct: 36.50

standard mean IP accuracy on test samples
hit_pct: 63.00	miss_pct: 37.00



In [10]:
index_norm_ip = faiss.IndexFlatIP(d)
index_norm_ip.add(all_means_norm)

print("mean norm IP accuracy on train samples")
calculate_accuracy(index_norm_ip, samples['trn'])

print("mean norm IP accuracy on validation samples")
calculate_accuracy(index_norm_ip, samples['val'])

print("mean norm IP accuracy on test samples")
calculate_accuracy(index_norm_ip, samples['tst'])

mean norm IP accuracy on train samples
hit_pct: 74.34	miss_pct: 25.66

mean norm IP accuracy on validation samples
hit_pct: 72.45	miss_pct: 27.55

mean norm IP accuracy on test samples
hit_pct: 73.90	miss_pct: 26.10



In [11]:
index_herding_ip = faiss.IndexFlatIP(d)
index_herding_ip.add(herds)

print("1-herding IP accuracy on train samples")
calculate_accuracy(index_herding_ip, samples['trn'])

print("1-herding IP accuracy on validation samples")
calculate_accuracy(index_herding_ip, samples['val'])

print("1-herding IP accuracy on test samples")
calculate_accuracy(index_herding_ip, samples['tst'])

1-herding IP accuracy on train samples
hit_pct: 50.22	miss_pct: 49.78

1-herding IP accuracy on validation samples
hit_pct: 49.05	miss_pct: 50.95

1-herding IP accuracy on test samples
hit_pct: 50.15	miss_pct: 49.85



# STILL FAISS

## Figuring out the threshold for similarity search

Task: generate mean vector for T2 classes, calculate similarity to T1 classes. If some T2 classes are similar to T1, then extend the expert with new classes, and update on the experts. and then update the gate.

1. Use mean with norm for all classes
2. Check every T2 class to all T1 classes
    1. need a threshold to evaluate the similarity between the two subsets
        1. test with T1 and T2 are totally different classes
            - T1: [1 (aquarium_fish), 32 (flatfish), 67 (ray), 73 (shark), 91 (trout)]
            - T2: all other 15 classes
        2. test with T1 contains some classes under the same superclass with T2 classes
            - T1: [1 (aquarium_fish), 32 (flatfish), 54 (orchid), 62 (poppy), 70 (rose)]
            - T2: [67 (ray), 73 (shark), 91 (trout), 82 (sunflower), 92 (tulip)] + all other 10 classes

In [12]:
print(vector_rep['class_mean'].keys())
print(samples['trn'].keys())

dict_keys([1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96])
dict_keys([1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96])


### Using Flat L2 (Euclidean) index

#### Build the index on [1, 32, 67, 73, 91]

In [41]:
class_order = [1, 32, 67, 73, 91, 54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]
idx2cls = {k: v for k, v in enumerate(class_order)}
cls2idx = {v: k for k, v in enumerate(class_order)}
cls2str = {1: "aquarium_fish", 32: "flatfish", 67: "ray", 73: "shark", 91: "trout", 
           54: "orchid", 62: "poppy", 70: "rose", 82: "sunflower", 92: "tulip",
           2: "baby", 11: "boy", 35: "girl", 46: "man", 98: "woman",
           47: "maple_tree", 52: "oak_tree", 56: "palm_tree", 59: "pine_tree", 96: "willow_tree"}

In [14]:
t1_cls = [1, 32, 67, 73, 91]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

# print(t1_mean_norms)
# print(t2_mean_norms)
d = 768
nb = 5

index1 = faiss.IndexFlatL2(d)
index1.add(t1_mean_norms)

D, I = index1.search(t2_mean_norms, 1)

for i, c in enumerate(t2_cls):
    print(f"{cls2str[c]} - {cls2str[t1_idx2cls[I[i][0]]]} : {D[i][0]:.4f}")

orchid - flatfish : 1.2708
poppy - flatfish : 1.2996
rose - flatfish : 1.1161
sunflower - flatfish : 1.1430
tulip - flatfish : 1.0975
baby - flatfish : 0.8830
boy - flatfish : 0.6233
girl - flatfish : 0.6847
man - flatfish : 0.5563
woman - flatfish : 0.7797
maple_tree - flatfish : 0.7539
oak_tree - flatfish : 0.7822
palm_tree - flatfish : 0.7990
pine_tree - flatfish : 0.6965
willow_tree - flatfish : 0.6143


#### Build the index on [1, 32, 54, 62, 70]

In [15]:
t1_cls = [1, 32, 54, 62, 70]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

d = 768
nb = 5

index2 = faiss.IndexFlatL2(d)
index2.add(t1_mean_norms)

D, I = index2.search(t2_mean_norms, 1)
for i, c in enumerate(t2_cls):
    print(f"{cls2str[c]} - {cls2str[t1_idx2cls[I[i][0]]]} : {D[i][0]:.4f}")

ray - flatfish : 0.4523
shark - flatfish : 0.8352
trout - aquarium_fish : 0.5196
sunflower - rose : 0.7239
tulip - rose : 0.3427
baby - flatfish : 0.8830
boy - flatfish : 0.6233
girl - flatfish : 0.6847
man - flatfish : 0.5563
woman - flatfish : 0.7797
maple_tree - flatfish : 0.7539
oak_tree - flatfish : 0.7822
palm_tree - flatfish : 0.7990
pine_tree - flatfish : 0.6965
willow_tree - flatfish : 0.6143


### Calculate Using standard Numpy Euclidean

In [16]:
t1_cls = [1, 32, 67, 73, 91]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

distances = np.linalg.norm(t1_mean_norms[:, np.newaxis] - t2_mean_norms, axis=2)

for i, (cls_, query_vector) in enumerate(zip(t2_cls, t2_mean_norms)):
    closest_index = np.argmin(distances[:, i])
    print(f"{cls2str[cls_]} - {cls2str[t1_idx2cls[closest_index]]} : {pow(distances[closest_index, i], 2):.4f}")

orchid - flatfish : 1.2708
poppy - flatfish : 1.2996
rose - flatfish : 1.1161
sunflower - flatfish : 1.1430
tulip - flatfish : 1.0975
baby - flatfish : 0.8830
boy - flatfish : 0.6233
girl - flatfish : 0.6847
man - flatfish : 0.5563
woman - flatfish : 0.7797
maple_tree - flatfish : 0.7539
oak_tree - flatfish : 0.7822
palm_tree - flatfish : 0.7990
pine_tree - flatfish : 0.6965
willow_tree - flatfish : 0.6143


#### Numpy Cosine similarity (a.k.a FAISS IP)

In [17]:
similarities = np.dot(t1_mean_norms, t2_mean_norms.T) / (np.linalg.norm(t1_mean_norms, axis=1)[:, np.newaxis] * np.linalg.norm(t2_mean_norms, axis=1))
for i, (cls_, query_vector) in enumerate(zip(t2_cls, t2_mean_norms)):
    closest_index = np.argmax(similarities[:, i])
    print(f"{cls2str[cls_]} - {cls2str[t1_idx2cls[closest_index]]} : {similarities[closest_index, i]:.4f}")

orchid - flatfish : 0.3646
poppy - flatfish : 0.3502
rose - flatfish : 0.4419
sunflower - flatfish : 0.4285
tulip - flatfish : 0.4512
baby - flatfish : 0.5585
boy - flatfish : 0.6884
girl - flatfish : 0.6576
man - flatfish : 0.7219
woman - flatfish : 0.6102
maple_tree - flatfish : 0.6230
oak_tree - flatfish : 0.6089
palm_tree - flatfish : 0.6005
pine_tree - flatfish : 0.6518
willow_tree - flatfish : 0.6929


### Calculate Using standard Numpy Euclidean

In [18]:
t1_cls = [1, 32, 54, 62, 70]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

distances = np.linalg.norm(t1_mean_norms[:, np.newaxis] - t2_mean_norms, axis=2)

for i, (cls_, query_vector) in enumerate(zip(t2_cls, t2_mean_norms)):
    closest_index = np.argmin(distances[:, i])
    print(f"{cls2str[cls_]} - {cls2str[t1_idx2cls[closest_index]]} : {pow(distances[closest_index, i], 2):.4f}")

ray - flatfish : 0.4523
shark - flatfish : 0.8352
trout - aquarium_fish : 0.5196
sunflower - rose : 0.7239
tulip - rose : 0.3427
baby - flatfish : 0.8830
boy - flatfish : 0.6233
girl - flatfish : 0.6847
man - flatfish : 0.5563
woman - flatfish : 0.7797
maple_tree - flatfish : 0.7539
oak_tree - flatfish : 0.7822
palm_tree - flatfish : 0.7990
pine_tree - flatfish : 0.6965
willow_tree - flatfish : 0.6143


#### Numpy cosine similarity (a.k.a FAISS IP)

In [19]:
similarities = np.dot(t1_mean_norms, t2_mean_norms.T) / (np.linalg.norm(t1_mean_norms, axis=1)[:, np.newaxis] * np.linalg.norm(t2_mean_norms, axis=1))
for i, (cls_, query_vector) in enumerate(zip(t2_cls, t2_mean_norms)):
    closest_index = np.argmax(similarities[:, i])
    print(f"{cls2str[cls_]} - {cls2str[t1_idx2cls[closest_index]]} : {similarities[closest_index, i]:.4f}")

ray - flatfish : 0.7739
shark - flatfish : 0.5824
trout - aquarium_fish : 0.7402
sunflower - rose : 0.6381
tulip - rose : 0.8286
baby - flatfish : 0.5585
boy - flatfish : 0.6884
girl - flatfish : 0.6576
man - flatfish : 0.7219
woman - flatfish : 0.6102
maple_tree - flatfish : 0.6230
oak_tree - flatfish : 0.6089
palm_tree - flatfish : 0.6005
pine_tree - flatfish : 0.6518
willow_tree - flatfish : 0.6929


### Using Flat IP (dot product) index

#### Build the index on [1, 32, 67, 73, 91]

In [20]:
t1_cls = [1, 32, 67, 73, 91]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

# print(t1_mean_norms)
# print(t2_mean_norms)
d = 768
nb = 5

index1 = faiss.IndexFlatIP(d)
index1.add(t1_mean_norms)

D, I = index1.search(t2_mean_norms, 1)

for i, c in enumerate(t2_cls):
    print(f"{cls2str[c]} - {cls2str[t1_idx2cls[I[i][0]]]} : {D[i][0]:.4f}")

orchid - flatfish : 0.3646
poppy - flatfish : 0.3502
rose - flatfish : 0.4419
sunflower - flatfish : 0.4285
tulip - flatfish : 0.4512
baby - flatfish : 0.5585
boy - flatfish : 0.6884
girl - flatfish : 0.6576
man - flatfish : 0.7219
woman - flatfish : 0.6102
maple_tree - flatfish : 0.6230
oak_tree - flatfish : 0.6089
palm_tree - flatfish : 0.6005
pine_tree - flatfish : 0.6518
willow_tree - flatfish : 0.6929


In [21]:
t1_cls = [1, 32, 54, 62, 70]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

d = 768
nb = 5

index2 = faiss.IndexFlatIP(d)
index2.add(t1_mean_norms)

D, I = index2.search(t2_mean_norms, 1)
for i, c in enumerate(t2_cls):
    print(f"{cls2str[c]} - {cls2str[t1_idx2cls[I[i][0]]]} : {D[i][0]:.4f}")

ray - flatfish : 0.7739
shark - flatfish : 0.5824
trout - aquarium_fish : 0.7402
sunflower - rose : 0.6381
tulip - rose : 0.8286
baby - flatfish : 0.5585
boy - flatfish : 0.6884
girl - flatfish : 0.6576
man - flatfish : 0.7219
woman - flatfish : 0.6102
maple_tree - flatfish : 0.6230
oak_tree - flatfish : 0.6089
palm_tree - flatfish : 0.6005
pine_tree - flatfish : 0.6518
willow_tree - flatfish : 0.6929


In [31]:
chosen_superclass = {1: [1, 32, 67, 73, 91],
                    2: [54, 62, 70, 82, 92],
                    14: [2, 11, 35, 46, 98],
                    17: [47, 52, 56, 59, 96]}

sub2sup = {v_: k for k, v in chosen_superclass.items() for v_ in v}

### Try different FAISS index

In [24]:
t1_cls = [1, 32, 67, 73, 91]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [54, 62, 70, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

d = 768
nb = d * 4 # resolution of bucketed vectors

index = faiss.IndexLSH(d, nb)
index1.add(t1_mean_norms)

D, I = index1.search(t2_mean_norms, 1)

for i, c in enumerate(t2_cls):
    print(f"{cls2str[c]} - {cls2str[t1_idx2cls[I[i][0]]]} : {D[i][0]:.4f}")


orchid - flatfish : 0.3646
poppy - flatfish : 0.3502
rose - flatfish : 0.4419
sunflower - flatfish : 0.4285
tulip - flatfish : 0.4512
baby - flatfish : 0.5585
boy - flatfish : 0.6884
girl - flatfish : 0.6576
man - flatfish : 0.7219
woman - flatfish : 0.6102
maple_tree - flatfish : 0.6230
oak_tree - flatfish : 0.6089
palm_tree - flatfish : 0.6005
pine_tree - flatfish : 0.6518
willow_tree - flatfish : 0.6929


In [27]:
t1_cls = [1, 32, 54, 62, 70]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

t1_mean_norms, t2_mean_norms = [], []
for c1 in t1_cls:
    t1_mean_norms.append(vector_rep['class_mean_norm'][c1])
for c2 in t2_cls:
    t2_mean_norms.append(vector_rep['class_mean_norm'][c2])
    
t1_mean_norms = np.stack(t1_mean_norms)
t2_mean_norms = np.stack(t2_mean_norms)

d = 768
nb = d * 4

index2 = faiss.IndexLSH(d, nb)
index2.add(t1_mean_norms)

D, I = index2.search(t2_mean_norms, 1)
for i, c in enumerate(t2_cls):
    print(f"{cls2str[c]} - {cls2str[t1_idx2cls[I[i][0]]]} : {D[i][0]:.4f}")

ray - flatfish : 661.0000
shark - flatfish : 908.0000
trout - flatfish : 726.0000
sunflower - rose : 828.0000
tulip - rose : 573.0000
baby - flatfish : 952.0000
boy - flatfish : 791.0000
girl - flatfish : 824.0000
man - flatfish : 737.0000
woman - flatfish : 879.0000
maple_tree - flatfish : 872.0000
oak_tree - flatfish : 882.0000
palm_tree - flatfish : 901.0000
pine_tree - flatfish : 835.0000
willow_tree - flatfish : 808.0000


## GRAAAHH

In [51]:
vector_rep = pickle.load(open("class_vector_representation_cov_new.pkl", "rb"))
# vector_rep = pickle.load(open("class_vector_representation_cov.pkl", "rb"))
class_mean_norm = vector_rep['class_mean_norm']
class_cov_norm = vector_rep['class_cov_norm']

In [14]:
loc = torch.tensor(cls_means_norm[1], dtype=torch.double)
covariance_matrix = torch.tensor(cls_cov_norm[1], dtype=torch.double)
mvn = MultivariateNormal(loc=loc, covariance_matrix=covariance_matrix)

ValueError: Expected parameter covariance_matrix (Tensor of shape (768, 768)) of distribution MultivariateNormal(loc: torch.Size([768]), covariance_matrix: torch.Size([768, 768])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[ 5.0369e-01, -3.7660e-02,  3.5934e-02,  ...,  4.1147e-02,
         -5.5770e-02, -3.0628e-02],
        [-3.7660e-02,  3.3935e-01,  5.4759e-03,  ..., -9.7397e-03,
          6.1899e-03, -4.2780e-04],
        [ 3.5934e-02,  5.4759e-03,  4.4832e-01,  ..., -9.0165e-02,
          4.8203e-02, -2.0899e-02],
        ...,
        [ 4.1147e-02, -9.7397e-03, -9.0165e-02,  ...,  5.0359e-01,
         -1.0280e-01, -6.5978e-02],
        [-5.5770e-02,  6.1899e-03,  4.8203e-02,  ..., -1.0280e-01,
          4.9860e-01,  6.4160e-03],
        [-3.0628e-02, -4.2780e-04, -2.0899e-02,  ..., -6.5978e-02,
          6.4160e-03,  5.0805e-01]], dtype=torch.float64)

In [30]:
samples = multivariate_normal(class_mean_norm[1], class_cov_norm[1], size=100)

In [53]:
all_samples = {}
for cls_, cls_mean in class_mean_norm.items():
    cls_cov = class_cov_norm[cls_]
    all_samples[cls_] = multivariate_normal(cls_mean, cls_cov, size=200)

In [54]:
all_samples[1]

array([[-1.05075412,  0.79848423, -0.86665117, ...,  1.30014699,
         0.28981925,  0.59635764],
       [-0.56275522,  0.2777413 ,  0.02667506, ..., -0.40953576,
        -0.3214637 , -0.57361511],
       [-0.85231805, -0.01821379,  0.05739761, ...,  0.34916892,
        -0.80492347,  1.88180276],
       ...,
       [ 1.10142477,  0.3840105 ,  0.09670077, ...,  0.04410125,
         0.49412766,  0.16224936],
       [ 1.61127665, -0.97836355,  0.24188742, ..., -0.41196284,
        -1.07026838, -0.80870801],
       [ 0.01444067, -0.13050146,  0.83273259, ...,  0.34093976,
         1.71202047, -0.66042011]])

## Use CKA

In [36]:
import torch
import numpy as np
from cka import CKA, CudaCKA

In [48]:
t1_cls = [1, 32, 54, 62, 70]
t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]

In [55]:
all_samples_torch = {}
for k, v in all_samples.items():
    all_samples_torch[k] = torch.tensor(v, dtype=torch.float)

In [69]:
device = torch.device('cuda:0')
cuda_cka = CudaCKA(device)

distance = {}

for t2 in t2_cls:
    X = all_samples_torch[t2].to(device)
    for t1 in t1_cls:        
        Y = all_samples_torch[t1].to(device)
        linear_cka = cuda_cka.linear_CKA(X, Y).item()
        kernel_cka = cuda_cka.kernel_CKA(X, Y).item()
        print(f"Linear CKA between {cls2str[t2]} and {cls2str[t1]}: {linear_cka:.4f}")
        print(f"RBF Kernel CKA between {cls2str[t2]} and {cls2str[t1]}: {kernel_cka:.4f}")
        
        sub_distance = distance.get(cls2str[t2], [])
        sub_distance.append({cls2str[t1]: [linear_cka, kernel_cka]})
        
        distance[cls2str[t2]] = sub_distance
        print()
        


Linear CKA between ray and aquarium_fish: 0.1077
RBF Kernel CKA between ray and aquarium_fish: 0.1920

Linear CKA between ray and flatfish: 0.0893
RBF Kernel CKA between ray and flatfish: 0.1673

Linear CKA between ray and orchid: 0.1105
RBF Kernel CKA between ray and orchid: 0.1940

Linear CKA between ray and poppy: 0.0990
RBF Kernel CKA between ray and poppy: 0.1754

Linear CKA between ray and rose: 0.0685
RBF Kernel CKA between ray and rose: 0.1309

Linear CKA between shark and aquarium_fish: 0.1175
RBF Kernel CKA between shark and aquarium_fish: 0.1991

Linear CKA between shark and flatfish: 0.0955
RBF Kernel CKA between shark and flatfish: 0.1742

Linear CKA between shark and orchid: 0.1168
RBF Kernel CKA between shark and orchid: 0.2004

Linear CKA between shark and poppy: 0.1026
RBF Kernel CKA between shark and poppy: 0.1784

Linear CKA between shark and rose: 0.0719
RBF Kernel CKA between shark and rose: 0.1397

Linear CKA between trout and aquarium_fish: 0.1202
RBF Kernel CKA 

In [83]:
device = torch.device('cuda:0')
cuda_cka = CudaCKA(device)

for t2 in t2_cls:
    X = all_samples_torch[t2].to(device)
    for t1 in t1_cls:        
        Y = all_samples_torch[t1].to(device)
        print(f"Linear CKA between {cls2str[t2]} and {cls2str[t1]}: {cuda_cka.linear_CKA(X, Y):.4f}")
        print(f"RBF Kernel CKA between {cls2str[t2]} and {cls2str[t1]}: {cuda_cka.kernel_CKA(X, Y):.4f}")
        print()
        


Linear CKA between ray and aquarium_fish: 0.1077
RBF Kernel CKA between ray and aquarium_fish: 0.1920

Linear CKA between ray and flatfish: 0.0893
RBF Kernel CKA between ray and flatfish: 0.1673

Linear CKA between ray and orchid: 0.1105
RBF Kernel CKA between ray and orchid: 0.1940

Linear CKA between ray and poppy: 0.0990
RBF Kernel CKA between ray and poppy: 0.1754

Linear CKA between ray and rose: 0.0685
RBF Kernel CKA between ray and rose: 0.1309

Linear CKA between shark and aquarium_fish: 0.1175
RBF Kernel CKA between shark and aquarium_fish: 0.1991

Linear CKA between shark and flatfish: 0.0955
RBF Kernel CKA between shark and flatfish: 0.1742

Linear CKA between shark and orchid: 0.1168
RBF Kernel CKA between shark and orchid: 0.2004

Linear CKA between shark and poppy: 0.1026
RBF Kernel CKA between shark and poppy: 0.1784

Linear CKA between shark and rose: 0.0719
RBF Kernel CKA between shark and rose: 0.1397

Linear CKA between trout and aquarium_fish: 0.1202
RBF Kernel CKA 

In [70]:
distance

{'ray': [{'aquarium_fish': [0.1077294796705246, 0.1920444369316101]},
  {'flatfish': [0.08928253501653671, 0.16728192567825317]},
  {'orchid': [0.11046697199344635, 0.1940249502658844]},
  {'poppy': [0.09895061701536179, 0.17540933191776276]},
  {'rose': [0.06848851591348648, 0.13093379139900208]}],
 'shark': [{'aquarium_fish': [0.11753126978874207, 0.1991184651851654]},
  {'flatfish': [0.09545803815126419, 0.17418846487998962]},
  {'orchid': [0.11683563888072968, 0.20040051639080048]},
  {'poppy': [0.10263590514659882, 0.1783677488565445]},
  {'rose': [0.07185043394565582, 0.13973644375801086]}],
 'trout': [{'aquarium_fish': [0.12022633105516434, 0.21143823862075806]},
  {'flatfish': [0.10366805642843246, 0.18941736221313477]},
  {'orchid': [0.12736739218235016, 0.22162704169750214]},
  {'poppy': [0.11663080751895905, 0.20376983284950256]},
  {'rose': [0.08759048581123352, 0.16012036800384521]}],
 'sunflower': [{'aquarium_fish': [0.11182324588298798, 0.19027286767959595]},
  {'flatfis

In [81]:
for target, target_list in distance.items():
    source_, max_dist = "", 0
    for source in target_list:
        for s, source_dist in source.items():    
            if max_dist < source_dist[0]:
                source_ = s
                max_dist = source_dist[0]
    print(f"{target}\t{source_}\t{max_dist:.4f}")

ray	orchid	0.1105
shark	aquarium_fish	0.1175
trout	orchid	0.1274
sunflower	orchid	0.1151
tulip	orchid	0.1217
baby	aquarium_fish	0.1302
boy	aquarium_fish	0.1381
girl	orchid	0.1336
man	aquarium_fish	0.1479
woman	aquarium_fish	0.1324
maple_tree	aquarium_fish	0.0881
oak_tree	aquarium_fish	0.1070
palm_tree	orchid	0.1142
pine_tree	orchid	0.1163
willow_tree	orchid	0.1270


In [80]:
for target, target_list in distance.items():
    source_, max_dist = "", 0
    for source in target_list:
        for s, source_dist in source.items():    
            if max_dist < source_dist[1]:
                source_ = s
                max_dist = source_dist[1]
    print(f"{target}\t{source_}\t{max_dist:.4f}")

ray	orchid	0.1940
shark	orchid	0.2004
trout	orchid	0.2216
sunflower	orchid	0.1999
tulip	orchid	0.2141
baby	aquarium_fish	0.2141
boy	orchid	0.2310
girl	orchid	0.2276
man	aquarium_fish	0.2426
woman	aquarium_fish	0.2167
maple_tree	aquarium_fish	0.1616
oak_tree	aquarium_fish	0.1822
palm_tree	orchid	0.1989
pine_tree	orchid	0.2055
willow_tree	orchid	0.2120


In [74]:
distance['ray']

[{'aquarium_fish': [0.1077294796705246, 0.1920444369316101]},
 {'flatfish': [0.08928253501653671, 0.16728192567825317]},
 {'orchid': [0.11046697199344635, 0.1940249502658844]},
 {'poppy': [0.09895061701536179, 0.17540933191776276]},
 {'rose': [0.06848851591348648, 0.13093379139900208]}]

# 30/01/2024
Group all CIFAR-100

In [2]:
coarse_label_dict = {0: 'aquatic_mammals', 1: 'fish', 2: 'flowers',
                     3: 'food_containers', 4: 'fruit_and_vegetables', 5: 'household_electrical_devices',
                     6: 'household_furniture', 7: 'insects', 8: 'large_carnivores', 9: 'large_man-made_outdoor_things',
                     10: 'large_natural_outdoor_scenes', 11: 'large_omnivores_and_herbivores', 12: 'medium_mammals',
                     13: 'non-insect_invertebrates',14: 'people', 15: 'reptiles', 16: 'small_mammals', 
                     17: 'trees', 18: 'vehicles_1', 19: 'vehicles_2'}

fine_label_dict = {0: "apple", 1: "aquarium_fish", 2: "baby", 3: "bear", 4: "beaver",
                   5: "bed", 6: "bee", 7: "beetle", 8: "bicycle", 9: "bottle", 10: "bowl",
                   11: "boy", 12: "bridge", 13: "bus", 14: "butterfly", 15: "camel", 16: "can",
                   17: "castle", 18: "caterpillar", 19: "cattle", 20: "chair", 21: "chimpanzee",
                   22: "clock", 23: "cloud", 24: "cockroach", 25: "couch", 26: "crab", 27: "crocodile", 
                   28: "cup", 29: "dinosaur", 30: "dolphin", 31: "elephant", 32: "flatfish", 33: "forest",
                   34: "fox", 35: "girl", 36: "hamster", 37: "house", 38: "kangaroo", 39: "keyboard", 40: "lamp",
                   41: "lawn_mower", 42: "leopard", 43: "lion", 44: "lizard", 45: "lobster", 46: "man",
                   47: "maple_tree", 48: "motorcycle", 49: "mountain", 50: "mouse", 51: "mushroom", 52: "oak_tree",
                   53: "orange", 54: "orchid", 55: "otter", 56: "palm_tree", 57: "pear", 58: "pickup_truck",
                   59: "pine_tree", 60: "plain", 61: "plate", 62: "poppy", 63: "porcupine", 64: "possum",
                   65: "rabbit", 66: "raccoon", 67: "ray", 68: "road", 69: "rocket", 70: "rose", 71: "sea",
                   72: "seal", 73: "shark", 74: "shrew", 75: "skunk", 76: "skyscraper", 77: "snail", 78: "snake", 
                   79: "spider", 80: "squirrel", 81: "streetcar", 82: "sunflower", 83: "sweet_pepper", 84: "table",
                   85: "tank", 86: "telephone", 87: "television", 88: "tiger", 89: "tractor", 90: "train", 91: "trout",
                   92: "tulip", 93: "turtle", 94: "wardrobe", 95: "whale", 96: "willow_tree", 97: "wolf", 98: "woman", 99: "worm",}

cifar100_coarse_labels = [4,  1, 14,  8,  0,  6,  7,  7, 18,  3,  
                          3, 14,  9, 18,  7, 11,  3,  9,  7, 11,
                          6, 11,  5, 10,  7,  6, 13, 15,  3, 15,  
                          0, 11,  1, 10, 12, 14, 16,  9, 11,  5, 
                          5, 19,  8,  8, 15, 13, 14, 17, 18, 10, 
                          16, 4, 17,  4,  2,  0, 17,  4, 18, 17, 
                          10, 3,  2, 12, 12, 16, 12,  1,  9, 19,  
                          2, 10,  0,  1, 16, 12,  9, 13, 15, 13, 
                          16, 19,  2,  4,  6, 19,  5,  5,  8, 19, 
                          18,  1,  2, 15,  6,  0, 17,  8, 14, 13]

In [3]:
chosen_superclass = {}

for idx, c_label in enumerate(cifar100_coarse_labels):
    subs = chosen_superclass.get(c_label, [])
    subs.append(idx)
    chosen_superclass[c_label] = subs
    
print(chosen_superclass)    

{4: [0, 51, 53, 57, 83], 1: [1, 32, 67, 73, 91], 14: [2, 11, 35, 46, 98], 8: [3, 42, 43, 88, 97], 0: [4, 30, 55, 72, 95], 6: [5, 20, 25, 84, 94], 7: [6, 7, 14, 18, 24], 18: [8, 13, 48, 58, 90], 3: [9, 10, 16, 28, 61], 9: [12, 17, 37, 68, 76], 11: [15, 19, 21, 31, 38], 5: [22, 39, 40, 86, 87], 10: [23, 33, 49, 60, 71], 13: [26, 45, 77, 79, 99], 15: [27, 29, 44, 78, 93], 12: [34, 63, 64, 66, 75], 16: [36, 50, 65, 74, 80], 19: [41, 69, 81, 85, 89], 17: [47, 52, 56, 59, 96], 2: [54, 62, 70, 82, 92]}


In [18]:
train_embedding_path = "../cifar100_coarse_train_embedding_nn.pt"
val_embedding_path = None
test_embedding_path = "../cifar100_coarse_test_embedding_nn.pt"
ignore_super = False

n_class = 100

data, task_cla, class_order = get_cifar100_coarse(train_embedding_path, test_embedding_path, None, chosen_superclass=chosen_superclass, ignore_super=ignore_super)

# print(class_order)
idx2cls = {k: v for k, v in enumerate(class_order)}
cls2idx = {v: k for k, v in enumerate(class_order)}

superclass_order: [4, 1, 14, 8, 0, 6, 7, 18, 3, 9, 11, 5, 10, 13, 15, 12, 16, 19, 17, 2]
class_order: [0, 51, 53, 57, 83, 1, 32, 67, 73, 91, 2, 11, 35, 46, 98, 3, 42, 43, 88, 97, 4, 30, 55, 72, 95, 5, 20, 25, 84, 94, 6, 7, 14, 18, 24, 8, 13, 48, 58, 90, 9, 10, 16, 28, 61, 12, 17, 37, 68, 76, 15, 19, 21, 31, 38, 22, 39, 40, 86, 87, 23, 33, 49, 60, 71, 26, 45, 77, 79, 99, 27, 29, 44, 78, 93, 34, 63, 64, 66, 75, 36, 50, 65, 74, 80, 41, 69, 81, 85, 89, 47, 52, 56, 59, 96, 54, 62, 70, 82, 92]
total_task: 20
chosen_superclass: {4: [0, 51, 53, 57, 83], 1: [1, 32, 67, 73, 91], 14: [2, 11, 35, 46, 98], 8: [3, 42, 43, 88, 97], 0: [4, 30, 55, 72, 95], 6: [5, 20, 25, 84, 94], 7: [6, 7, 14, 18, 24], 18: [8, 13, 48, 58, 90], 3: [9, 10, 16, 28, 61], 9: [12, 17, 37, 68, 76], 11: [15, 19, 21, 31, 38], 5: [22, 39, 40, 86, 87], 10: [23, 33, 49, 60, 71], 13: [26, 45, 77, 79, 99], 15: [27, 29, 44, 78, 93], 12: [34, 63, 64, 66, 75], 16: [36, 50, 65, 74, 80], 19: [41, 69, 81, 85, 89], 17: [47, 52, 56, 59, 96

In [19]:
from collections import Counter

for i in range(20):
    print(f"i-{i}")
    print(f"Train data: {len(data[i]['trn']['x'])}")
    print(f"\t{Counter(data[i]['trn']['y'])}")
    print(f"Val data: {len(data[i]['val']['x'])}")
    print(f"\t{Counter(data[i]['val']['y'])}") 
    print(f"Test data: {len(data[i]['tst']['x'])}")
    print(f"\t{Counter(data[i]['tst']['y'])}")  
    print()  

i-0
Train data: 2000
	Counter({0: 400, 1: 400, 4: 400, 2: 400, 3: 400})
Val data: 500
	Counter({0: 100, 1: 100, 2: 100, 3: 100, 4: 100})
Test data: 500
	Counter({1: 100, 0: 100, 2: 100, 3: 100, 4: 100})

i-1
Train data: 2000
	Counter({5: 400, 8: 400, 6: 400, 7: 400, 9: 400})
Val data: 500
	Counter({5: 100, 6: 100, 7: 100, 8: 100, 9: 100})
Test data: 500
	Counter({8: 100, 6: 100, 7: 100, 9: 100, 5: 100})

i-2
Train data: 2000
	Counter({11: 400, 12: 400, 13: 400, 14: 400, 10: 400})
Val data: 500
	Counter({10: 100, 11: 100, 12: 100, 13: 100, 14: 100})
Test data: 500
	Counter({11: 100, 10: 100, 12: 100, 13: 100, 14: 100})

i-3
Train data: 2000
	Counter({19: 400, 16: 400, 17: 400, 18: 400, 15: 400})
Val data: 500
	Counter({15: 100, 16: 100, 17: 100, 18: 100, 19: 100})
Test data: 500
	Counter({17: 100, 19: 100, 16: 100, 15: 100, 18: 100})

i-4
Train data: 2000
	Counter({22: 400, 23: 400, 24: 400, 20: 400, 21: 400})
Val data: 500
	Counter({20: 100, 21: 100, 22: 100, 23: 100, 24: 100})
Test da

In [20]:
all_X, all_y = [], []
for i in range(20):
    all_X.extend(data[i]['trn']['x'])
    all_y.extend(data[i]['trn']['y'])

all_X = torch.stack(all_X).numpy()
all_y = np.vectorize(idx2cls.get)(all_y)

print(all_X.shape)
# print(len(all_y))
# print(all_y)

(40000, 768)


In [21]:
# samples, samples_val, samples_tst = {}, {}, {}

# all_X_val, all_X_tst = [], []
# all_y_val, all_y_tst = [], []

# for i in range(20):
#     all_X_val.extend(data[i]['val']['x'])
#     all_X_tst.extend(data[i]['tst']['x'])
    
#     all_y_val.extend(data[i]['val']['y'])
#     all_y_tst.extend(data[i]['tst']['y'])

# all_X_val = torch.stack(all_X_val).numpy()
# all_X_tst = torch.stack(all_X_tst).numpy()

# all_y_val = np.vectorize(idx2cls.get)(all_y_val)
# all_y_tst = np.vectorize(idx2cls.get)(all_y_tst)

# for cls_ in class_order:
#     x_ = all_X[all_y == cls_]
#     x_val = all_X_val[all_y_val == cls_]
#     x_tst = all_X_tst[all_y_tst == cls_]
    
#     samples[cls_] = x_
#     samples_val[cls_] = x_val
#     samples_tst[cls_] = x_tst

# to_save = {
#     'trn': samples,
#     'val': samples_val,
#     'tst': samples_tst,
# }

# pickle.dump(to_save, open("class_samples_allcifar100.pkl", "wb"))

# samples = pickle.load(open("class_samples_allcifar100.pkl", "rb"))
# for k, v in samples['trn'].items():
#     print(f"{len(v)}\t{len(samples['val'][k])}\t{len(samples['tst'][k])}")
    

400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	100	100
400	

In [22]:
cls_mean = {}
cls_cov = {}
for cls_ in class_order:
    x_ = all_X[all_y == cls_]
    # cls_mean[cls_] = torch.tensor(m_)
    cls_mean[cls_] = np.mean(x_, axis=0)
    cls_cov[cls_] = np.cov(x_.T)

print(cls_mean)

{0: array([-2.40861207e-01,  3.49343956e-01, -1.13809146e-01, -7.12401807e-01,
        8.89423728e-01, -6.87371850e-01, -1.42683074e-01, -2.16851592e-01,
        1.60634503e-01, -3.67381983e-03, -2.78855801e-01,  1.93447247e-01,
       -6.84060514e-01, -1.61514175e+00,  3.45620930e-01,  1.36827791e+00,
       -8.46898854e-01,  1.43960249e+00, -3.96136194e-01,  4.86446202e-01,
       -1.76523969e-01,  1.78533569e-01, -6.81518316e-01, -3.64261538e-01,
        1.38847604e-01,  2.85178851e-02,  6.21047080e-01,  7.19319880e-02,
       -1.11157131e+00, -2.37272322e-01,  1.11452484e+00,  1.77388608e+00,
        8.36689591e-01, -5.36057234e-01,  5.65342307e-01, -8.43356907e-01,
        9.86969694e-02,  1.72776878e-01, -6.98390305e-02,  5.98001242e-01,
        7.03756332e-01,  6.18539095e-01,  5.44102192e-01, -8.71222690e-02,
       -2.00948849e-01, -1.19109261e+00,  9.28789616e-01,  1.36056912e+00,
        1.62066531e+00, -2.66176581e-01,  2.51842588e-02,  1.53481066e-01,
        3.04889262e-0

In [23]:
cls_means_norm = {}
cls_cov_norm = {}
cls_std_norm = {}
for cls_ in class_order:
    x_ = all_X[all_y == cls_]
    m_, s_, c_ = class_mean_norm(x_)
    # cls_means_norm[cls_] = torch.tensor(m_)
    cls_means_norm[cls_] = m_
    cls_cov_norm[cls_] = c_
    cls_std_norm[cls_] = s_

print(cls_means_norm)

{0: array([-1.61074102e-02,  1.69336908e-02, -4.24122531e-03, -3.99375744e-02,
        5.15295491e-02, -3.81465033e-02, -4.66429302e-03, -1.08173965e-02,
        1.15186460e-02, -2.19092960e-03, -1.83972064e-02,  9.30678193e-03,
       -4.01314013e-02, -9.09907743e-02,  1.75263640e-02,  7.86295533e-02,
       -5.02110198e-02,  7.76337907e-02, -2.14060973e-02,  2.74337810e-02,
       -8.51691607e-03,  7.13513326e-03, -3.99994366e-02, -2.14831252e-02,
        1.13025317e-02, -1.98337063e-03,  2.65885256e-02,  6.06491324e-03,
       -6.38854951e-02, -1.11061838e-02,  6.59226403e-02,  1.02434069e-01,
        5.06167300e-02, -2.73264628e-02,  3.12250890e-02, -4.39849682e-02,
        3.90086067e-03,  9.74031724e-03, -1.94415858e-03,  3.14946324e-02,
        4.14995290e-02,  3.58460210e-02,  2.95430049e-02, -3.16200638e-03,
       -1.10809049e-02, -6.38655871e-02,  5.10959253e-02,  7.48105124e-02,
        9.26599503e-02, -1.63745694e-02,  9.55426251e-04,  7.05937715e-03,
        1.69230681e-0

In [24]:
print(cls_means_norm[1].shape)
print(cls_cov_norm[1].shape)
print(cls_std_norm[1])

(768,)
(768, 768)
0.03601106


In [25]:
herds = {}
for cls_ in class_order:
    x_ = all_X[all_y == cls_]
    herd_index = herding(x_)[0]
    # herds[cls_] = torch.tensor(x_[herd_index])
    herds[cls_] = x_[herd_index]

print(herds)

{0: array([-9.24232483e-01,  1.06419347e-01, -1.00273766e-01, -1.25623858e+00,
        7.84764171e-01, -5.46972752e-01, -3.79757166e-01,  4.25285753e-03,
        4.83712822e-01, -2.34212130e-01, -3.98340255e-01,  5.29034026e-02,
       -3.32382023e-01, -1.95651448e+00,  4.93220091e-01,  1.42913890e+00,
       -9.19387996e-01,  8.37165415e-01, -3.91325533e-01,  7.60714889e-01,
        2.50245064e-01,  2.99248040e-01, -1.00671375e+00, -5.73246107e-02,
       -9.93516669e-02, -2.29203582e-01,  1.59166843e-01, -1.54949874e-01,
       -1.23689520e+00, -2.48530045e-01,  9.34154451e-01,  2.04162502e+00,
        5.11575222e-01, -2.40863338e-01,  1.00890779e+00, -8.10359895e-01,
        1.99586257e-01,  8.29482675e-02, -1.14179812e-01,  3.21323782e-01,
        1.05952823e+00,  8.42605352e-01,  4.60666806e-01, -1.93889558e-01,
       -4.53980654e-01, -9.05089140e-01,  9.83959496e-01,  1.31709671e+00,
        1.47120595e+00, -8.25913996e-02,  3.38370763e-02,  5.39873600e-01,
        3.11624974e-0

In [26]:
# to_save = {
#     "data": data,
#     "task_cla": task_cla,
#     "class_order": class_order,
#     # "class_mean": cls_mean,
#     "class_mean_norm": cls_means_norm,
#     "class_std_norm": cls_std_norm,
#     "class_cov_norm": cls_cov_norm,
#     # "herding": herds,    
#     "samples": {
#         "trn": samples,
#         "val": samples_val,
#         "tst": samples_tst,
#     }
# }

# pickle.dump(to_save, open("distance_data_allcifar100.pkl", "wb"))

In [27]:
data = pickle.load(open("distance_data_allcifar100.pkl", "rb"))
data.keys()

dict_keys(['data', 'task_cla', 'class_order', 'class_mean_norm', 'class_std_norm', 'class_cov_norm', 'samples'])

In [28]:
# t1_cls = [1, 32, 54, 62, 70]
# t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
# t2_cls = [67, 73, 91, 82, 92, 2, 11, 35, 46, 98, 47, 52, 56, 59, 96]
# t2_idx2cls = {k: v for k, v in enumerate(t2_cls)}

def calculate_distance(t1_cls, t2_cls):
    t1_idx2cls = {k: v for k, v in enumerate(t1_cls)}
    t2_idx2cls = {k: v for k, v in enumerate(t2_cls)}

    t1_mean_norms, t2_mean_norms = [], []
    for c1 in t1_cls:
        t1_mean_norms.append(data['class_mean_norm'][c1])
    for c2 in t2_cls:
        t2_mean_norms.append(data['class_mean_norm'][c2])
        
    t1_mean_norms = np.stack(t1_mean_norms)
    t2_mean_norms = np.stack(t2_mean_norms)

    distances = np.linalg.norm(t1_mean_norms[:, np.newaxis] - t2_mean_norms, axis=2)

    # for i, (cls_, query_vector) in enumerate(zip(t2_cls, t2_mean_norms)):
        # closest_index = np.argmin(distances[:, i])
        # print(f"{cls2str[cls_]} - {cls2str[t1_idx2cls[closest_index]]} : {pow(distances[closest_index, i], 2):.4f}")
    
    return distances, t1_idx2cls, t2_idx2cls

In [29]:
t1_cls = [i for i in range(100)]

distances, t1_idx2cls, t2_idx2cls = calculate_distance(t1_cls, t1_cls)

for i, t1_cls_ in enumerate(t1_cls):
    for j, t1_cls_ in enumerate(t1_cls):
        t1_t2_distance = distances[i][j]
        print(fine_label_dict[i], "--", fine_label_dict[j], t1_t2_distance)

apple -- apple 0.0
apple -- aquarium_fish 1.212159
apple -- baby 1.1170076
apple -- bear 1.1971267
apple -- beaver 1.1555346
apple -- bed 1.1672064
apple -- bee 1.1590295
apple -- beetle 1.1498005
apple -- bicycle 1.1724029
apple -- bottle 1.0221099
apple -- bowl 0.9761488
apple -- boy 1.0808586
apple -- bridge 1.1863661
apple -- bus 1.2233052
apple -- butterfly 1.1254493
apple -- camel 1.2084166
apple -- can 1.0412776
apple -- castle 1.2096844
apple -- caterpillar 1.0577549
apple -- cattle 1.2279836
apple -- chair 1.1837124
apple -- chimpanzee 1.2187964
apple -- clock 1.1336173
apple -- cloud 1.2157489
apple -- cockroach 1.2437133
apple -- couch 1.1322285
apple -- crab 1.0961698
apple -- crocodile 1.1707826
apple -- cup 1.0376091
apple -- dinosaur 1.2248842
apple -- dolphin 1.2789527
apple -- elephant 1.232301
apple -- flatfish 1.073683
apple -- forest 1.1100656
apple -- fox 1.2351171
apple -- girl 1.0856693
apple -- hamster 1.2130108
apple -- house 1.1684637
apple -- kangaroo 1.22835

In [30]:
pickle.dump(distances, open("distance_between_classes_allcifar100.pkl", "wb"))