In [None]:
import copy
import math
import os

import PIL
import torchvision.transforms as transforms
import pandas as pd
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import backbone
from methods.baselinetrain import BaselineTrain
from io_utils import parse_args, get_resume_file, get_best_file, get_assigned_file
from datasets import miniImageNet_few_shot, ISIC_few_shot, EuroSAT_few_shot, CropDisease_few_shot, Chest_few_shot
from data.datamgr import SimpleDataManager, SetDataManager

In [None]:
def get_clustering_measure(features, target, n_way):
    """Within-class vs. between-class spread of a batch of features.

    Args:
        features: tensor whose first dim indexes the N samples; any trailing
            dims (e.g. C, H, W of a conv block output) are flattened before
            measuring Euclidean distances.
        target: (N,) integer labels, expected in ``range(n_way)``.
        n_way: number of classes.

    Returns:
        ``(sigma_within, sigma_btw, sigma_within / sigma_btw)`` as Python floats:
        sigma_within is the mean distance of each sample to its class mean,
        sigma_btw is the mean distance of each class mean to the global mean.
    """
    N = target.size(0)
    C = n_way

    flat = features.reshape(N, -1)
    mu = torch.mean(flat, dim=0)

    sigma_within = 0
    sigma_btw = 0
    # Iterate over all n_way classes (not a fixed 5).
    for c in range(C):
        class_features = flat[target == c]
        class_mu = torch.mean(class_features, dim=0)
        sigma_btw += torch.norm(class_mu - mu)
        # Row-wise L2 distance to the class mean, summed over the class.
        sigma_within += torch.norm(class_features - class_mu, dim=1).sum()

    sigma_within /= N
    sigma_btw /= C

    return sigma_within.item(), sigma_btw.item(), (sigma_within / sigma_btw).item()

def get_cos(out, y):
    """Mean intra-class and inter-class cosine similarity within a batch.

    Args:
        out: (N, ...) features; trailing dims are flattened before comparing.
        y: (N,) labels.

    Returns:
        ``(mean cosine over same-label pairs i != j, mean cosine over
        different-label pairs)``. A mean is ``nan`` if there are no such pairs.
    """
    n = len(out)
    flat = out.reshape(n, -1)
    cos = torch.nn.CosineSimilarity()
    cos_mtx = torch.zeros([n, n])
    for i in range(n):
        cos_mtx[i] = cos(flat, flat[i].view(1, -1))

    same_mtx = (y.view(1, -1) == y.view(-1, 1)).float().cpu()
    diff_mtx = 1 - same_mtx
    same_mtx.fill_diagonal_(0.)  # exclude each sample's similarity with itself

    # Average over the pair masks, not over non-zero products, so that pairs
    # whose cosine is exactly 0 are still counted.
    n_same = same_mtx.sum().item()
    n_diff = diff_mtx.sum().item()
    same_mean = torch.sum(cos_mtx * same_mtx).item() / n_same if n_same else float('nan')
    diff_mean = torch.sum(cos_mtx * diff_mtx).item() / n_diff if n_diff else float('nan')
    return same_mean, diff_mean

def get_pretrained_model(model, method, pretrained_dataset, save_dir, track_bn=False):
    """Load a frozen, eval-mode ``BaselineTrain`` checkpoint onto the GPU.

    Args:
        model: backbone name; used only in the checkpoint path (the network
            built here is always ``backbone.ResNet10``).
        method: method name, forwarded to the backbone and used in the path.
        pretrained_dataset: 'miniImageNet' (64 classes) or 'tieredImageNet'
            (351 classes).
        save_dir: root directory containing ``checkpoints/``.
        track_bn: selects the ``*_aug_track`` (True) or ``*_aug`` (False)
            checkpoint directory and is forwarded to the backbone. Defaults to
            False so 4-argument calls work.

    Returns:
        The loaded model, in eval mode, with ``requires_grad=False`` on every
        parameter.

    Raises:
        ValueError: if ``pretrained_dataset`` is not one of the two above.
    """
    num_classes_by_dataset = {'miniImageNet': 64, 'tieredImageNet': 351}
    if pretrained_dataset not in num_classes_by_dataset:
        raise ValueError('Unknown pretrained_dataset: {}'.format(pretrained_dataset))
    num_classes = num_classes_by_dataset[pretrained_dataset]

    feature_net = backbone.ResNet10(method=method, track_bn=track_bn, reinit_bn_stats=False)
    pretrained_model = BaselineTrain(feature_net, num_classes, loss_type='softmax')

    suffix = '_track' if track_bn else ''
    checkpoint_dir = '%s/checkpoints/%s/%s_%s_aug%s' % (save_dir, pretrained_dataset, model, method, suffix)

    modelfile = get_resume_file(checkpoint_dir)
    state = torch.load(modelfile)['state']

    pretrained_model.load_state_dict(state, strict=True)
    pretrained_model.cuda()
    pretrained_model.eval()

    # Freeze every parameter so no autograd graph is built through the weights.
    for p in pretrained_model.parameters():
        p.requires_grad = False

    return pretrained_model

In [None]:
image_size = 224
iter_num = 100

n_way = 5
n_support = 5
n_query = 15
few_shot_params = dict(n_way=n_way, n_support=n_support)

dataset_names = ["miniImageNet", "miniImageNet_test", "CropDisease", "EuroSAT", "ISIC", "ChestX"]
method = 'baseline'
model = 'ResNet10'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
track_bn = True

# Per-block within/between clustering ratio on each episode's support set.
# The weights never change in this cell, so the model is loaded once, not per episode.
pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir, track_bn=track_bn)
block_names = ['block1', 'block2', 'block3', 'block4']

for dataset_name in dataset_names:
    print(dataset_name)
    # The cross-domain managers spell the episode-count keyword `n_eposide`.
    if dataset_name in ("miniImageNet", "miniImageNet_test"):
        datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_episode=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "CropDisease":
        datamgr = CropDisease_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "EuroSAT":
        datamgr = EuroSAT_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "ISIC":
        datamgr = ISIC_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "ChestX":
        datamgr = Chest_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    else:
        # Otherwise datamgr would silently keep the previous dataset's manager.
        raise ValueError('Unknown dataset: {}'.format(dataset_name))

    if dataset_name == "miniImageNet_test":
        novel_loader = datamgr.get_data_loader(aug=False, train=False)
    else:
        novel_loader = datamgr.get_data_loader(aug=False)

    # One row per episode, one column per backbone block.
    clustering_measure_mtx = np.zeros([len(novel_loader), len(block_names)])

    for task_num, (x, y) in enumerate(tqdm(novel_loader)):
        x = x.cuda()
        # Support images, class-major: (n_way * n_support, C, H, W); labels follow the same order.
        x_a_i = x[:, :n_support].contiguous().view(n_way * n_support, *x.size()[2:])
        y_a_i = torch.from_numpy(np.repeat(range(n_way), n_support)).cuda()

        with torch.no_grad():
            block1_out, block2_out, block3_out, block4_out = pretrained_model.feature.return_features(x_a_i)

        for k, out in enumerate([block1_out, block2_out, block3_out, block4_out]):
            _, _, ratio = get_clustering_measure(out, y_a_i, n_way)
            clustering_measure_mtx[task_num, k] = ratio

    fig, ax = plt.subplots()
    ax.errorbar(block_names, clustering_measure_mtx.mean(axis=0), yerr=clustering_measure_mtx.std(axis=0), fmt='-o', capsize=5)
    # One figure per dataset, so the title names the dataset and BN setting.
    ax.set_title('{} clustering measure ({}, track_bn={})'.format(method, dataset_name, track_bn))
    ax.set_ylabel('within / between ratio')
    plt.show()
    plt.close(fig)

In [None]:
image_size = 224
iter_num = 100

n_way = 5
n_support = 5
n_query = 15
few_shot_params = dict(n_way=n_way, n_support=n_support)

dataset_names = ["miniImageNet", "miniImageNet_test", "CropDisease", "EuroSAT", "ISIC", "ChestX"]
method = 'baseline'
model = 'ResNet10'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
track_bn = False

# Same measurement with the checkpoint trained without tracked BN statistics.
# The weights never change in this cell, so the model is loaded once, not per episode.
pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir, track_bn=track_bn)
block_names = ['block1', 'block2', 'block3', 'block4']

for dataset_name in dataset_names:
    print(dataset_name)
    # The cross-domain managers spell the episode-count keyword `n_eposide`.
    if dataset_name in ("miniImageNet", "miniImageNet_test"):
        datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_episode=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "CropDisease":
        datamgr = CropDisease_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "EuroSAT":
        datamgr = EuroSAT_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "ISIC":
        datamgr = ISIC_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "ChestX":
        datamgr = Chest_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    else:
        # Otherwise datamgr would silently keep the previous dataset's manager.
        raise ValueError('Unknown dataset: {}'.format(dataset_name))

    if dataset_name == "miniImageNet_test":
        novel_loader = datamgr.get_data_loader(aug=False, train=False)
    else:
        novel_loader = datamgr.get_data_loader(aug=False)

    # One row per episode, one column per backbone block.
    clustering_measure_mtx = np.zeros([len(novel_loader), len(block_names)])

    for task_num, (x, y) in enumerate(tqdm(novel_loader)):
        x = x.cuda()
        # Support images, class-major: (n_way * n_support, C, H, W); labels follow the same order.
        x_a_i = x[:, :n_support].contiguous().view(n_way * n_support, *x.size()[2:])
        y_a_i = torch.from_numpy(np.repeat(range(n_way), n_support)).cuda()

        with torch.no_grad():
            block1_out, block2_out, block3_out, block4_out = pretrained_model.feature.return_features(x_a_i)

        for k, out in enumerate([block1_out, block2_out, block3_out, block4_out]):
            _, _, ratio = get_clustering_measure(out, y_a_i, n_way)
            clustering_measure_mtx[task_num, k] = ratio

    fig, ax = plt.subplots()
    ax.errorbar(block_names, clustering_measure_mtx.mean(axis=0), yerr=clustering_measure_mtx.std(axis=0), fmt='-o', capsize=5)
    # One figure per dataset, so the title names the dataset and BN setting.
    ax.set_title('{} clustering measure ({}, track_bn={})'.format(method, dataset_name, track_bn))
    ax.set_ylabel('within / between ratio')
    plt.show()
    plt.close(fig)

# t-SNE

In [None]:
# Non-augmented loader over the miniImageNet data, in batches of 128.
datamgr = miniImageNet_few_shot.SimpleDataManager(
    image_size,
    batch_size=128,
)
base_loader = datamgr.get_data_loader(aug=False)

In [None]:
model = 'ResNet10'
method = 'baseline'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
# track_bn is passed explicitly; False selects the '*_aug' checkpoint directory.
pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir, track_bn=False)

In [None]:
def plot_img(img_t):
    """Undo ImageNet mean/std normalisation on a 3-channel image tensor and return it as a PIL image.

    Args:
        img_t: (3, H, W) normalised image tensor, on any device.

    Returns:
        ``PIL.Image`` of the de-normalised image.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

    img = img_t.cpu() * std + mean
    # ToPILImage scales floats by 255 and casts to uint8. Clamp first so that
    # values outside [0, 1] saturate instead of wrapping around.
    img = img.clamp(0., 1.)

    return transforms.ToPILImage()(img)

In [None]:
# Collect the final (`last_out`) representation and label of every base-class image.
# NOTE(review): return_features is unpacked into 5 outputs here but into 4 in the
# clustering cells above; confirm which arity the backbone actually returns.
repr_batches, label_batches = [], []
with torch.no_grad():
    for image, label in base_loader:
        image = image.cuda()
        block1_out, block2_out, block3_out, block4_out, last_out = pretrained_model.feature.return_features(image)
        repr_batches.append(last_out.cpu())
        label_batches.append(label)

# Concatenate once at the end; torch.cat inside the loop re-copies everything each batch.
repr_all = torch.cat(repr_batches, dim=0)
label_all = torch.cat(label_batches)

In [None]:
# Keep every sample of classes 0-4, grouped class by class
# (all of class 0 first, then class 1, ...), via a single gather.
selected_idx = torch.cat([torch.where(label_all == lab)[0] for lab in [0, 1, 2, 3, 4]])
repr_selected = repr_all[selected_idx]
label_selected = label_all[selected_idx]

In [None]:
#
#  tsne_torch.py
#
# Implementation of t-SNE in pytorch. The implementation was tested on pytorch
# > 1.0, and it requires Numpy to read files. In order to plot the results,
# a working installation of matplotlib is required.
#
#
# The example can be run by executing: `python tsne_torch.py`
#
#
#  Created by Xiao Li on 23-03-2020.
#  Copyright (c) 2020. All rights reserved.


# NOTE(review): this changes the default tensor type for the rest of the kernel
# session, not just for the t-SNE code below. Every later torch.zeros/ones/randn/
# tensor(...) call without an explicit dtype produces float64, including cells
# unrelated to t-SNE (e.g. `torch.zeros([n_way, 512])` in the zero-channel cell),
# which can then clash with float32 model outputs.
# torch.set_default_tensor_type(torch.cuda.DoubleTensor)
torch.set_default_tensor_type(torch.DoubleTensor)

def Hbeta_torch(D, beta=1.0):
    """Gaussian-kernel row and its entropy for a given precision.

    Args:
        D: 1-D tensor of distances from one point to the others.
        beta: precision (scalar or 1-element tensor).

    Returns:
        ``(H, P)`` where ``P = exp(-beta * D) / sum(exp(-beta * D))`` and
        ``H = log(sum) + beta * sum(D * exp(-beta * D)) / sum`` is its entropy.
    """
    weights = torch.exp(-D * beta)
    total = weights.sum()
    entropy = torch.log(total) + beta * (D * weights).sum() / total
    return entropy, weights / total


def x2p_torch(X, tol=1e-5, perplexity=30.0):
    """
        Performs a binary search to get P-values in such a way that each
        conditional Gaussian has the same perplexity.

        For each row i, searches the precision beta_i of exp(-beta_i * D[i, j])
        over the other points, so that the entropy returned by Hbeta_torch is
        within `tol` of log(perplexity). Gives up after 50 bisection steps.

        Args:
            X: (n, d) data matrix.
            tol: tolerance on |H - log(perplexity)|.
            perplexity: target perplexity of every conditional distribution.

        Returns:
            (n, n) tensor whose row i holds the normalised conditional
            probabilities of the other points given point i; the diagonal
            stays zero.
    """

    # Initialize some variables
    print("Computing pairwise distances...")
    (n, d) = X.shape

    # Squared Euclidean distances: ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
    sum_X = torch.sum(X*X, 1)
    D = torch.add(torch.add(-2 * torch.mm(X, X.t()), sum_X).t(), sum_X)

    P = torch.zeros(n, n)
    beta = torch.ones(n, 1)
    logU = torch.log(torch.tensor([perplexity]))  # target entropy
    n_list = [i for i in range(n)]

    # Loop over all datapoints
    for i in range(n):

        # Print progress
        if i % 500 == 0:
            print("Computing P-values for point %d of %d..." % (i, n))

        # Compute the Gaussian kernel and entropy for the current precision.
        # betamin / betamax = None means that side of the search bracket is
        # still open, so beta is doubled/halved until a bound is found.
        betamin = None
        betamax = None
        # Distances from point i to every other point (self excluded).
        Di = D[i, n_list[0:i]+n_list[i+1:n]]

        (H, thisP) = Hbeta_torch(Di, beta[i])

        # Evaluate whether the perplexity is within tolerance
        Hdiff = H - logU
        tries = 0
        while torch.abs(Hdiff) > tol and tries < 50:

            # If not, increase or decrease precision:
            # entropy too high -> larger beta (narrower kernel), too low -> smaller beta.
            if Hdiff > 0:
                betamin = beta[i].clone()
                if betamax is None:
                    beta[i] = beta[i] * 2.
                else:
                    beta[i] = (beta[i] + betamax) / 2.
            else:
                betamax = beta[i].clone()
                if betamin is None:
                    beta[i] = beta[i] / 2.
                else:
                    beta[i] = (beta[i] + betamin) / 2.

            # Recompute the values
            (H, thisP) = Hbeta_torch(Di, beta[i])

            Hdiff = H - logU
            tries += 1

        # Set the final row of P
        P[i, n_list[0:i]+n_list[i+1:n]] = thisP

    # Return final P-matrix
    return P


def pca_torch(X, no_dims=50):
    """Project X onto its top ``no_dims`` principal components.

    Args:
        X: (n, d) data matrix.
        no_dims: number of components to keep (effectively capped at d).

    Returns:
        (n, min(no_dims, d)) tensor: the mean-centred data projected onto the
        eigenvectors of X^T X with the largest eigenvalues, largest first.
    """
    print("Preprocessing the data using PCA...")
    X = X - torch.mean(X, 0)

    # X^T X is symmetric positive semi-definite, so eigh gives real eigenpairs.
    # eigh returns them in ascending order; sort descending so the leading
    # columns are the highest-variance directions.
    eigvals, eigvecs = torch.linalg.eigh(torch.mm(X.t(), X))
    order = torch.argsort(eigvals, descending=True)
    components = eigvecs[:, order[:no_dims]]

    return torch.mm(X, components)


def tsne(X, no_dims=2, initial_dims=50, perplexity=30.0, device=None):
    """
        Runs t-SNE on the dataset in the NxD tensor X to reduce its
        dimensionality to no_dims dimensions:
        `Y = tsne(X, no_dims, initial_dims, perplexity)`.

        X is first reduced to `initial_dims` with pca_torch. Input affinities
        are computed with x2p_torch, and the embedding is optimised for 1000
        iterations of momentum gradient descent with per-coordinate gains.
        P is exaggerated x4 until iteration 100.

        Args:
            X: (N, D) tensor.
            no_dims: output dimensionality (integer).
            initial_dims: dimensionality after the PCA pre-reduction.
            perplexity: target perplexity of each conditional distribution.
            device: device used for the optimisation; defaults to CUDA when
                available, otherwise CPU.

        Returns:
            (N, no_dims) tensor on `device`, or -1 if `no_dims` is not an integer.
    """

    # Check inputs
    if isinstance(no_dims, float):
        print("Error: no_dims should be an integer, not a float.")
        return -1
    if round(no_dims) != no_dims:
        print("Error: number of dimensions should be an integer.")
        return -1

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize variables
    X = pca_torch(X, initial_dims)
    (n, d) = X.shape
    max_iter = 1000
    initial_momentum = 0.5
    final_momentum = 0.8
    eta = 500
    min_gain = 0.01
    # Sample on the CPU generator then move it, so a given seed gives the same init on any device.
    Y = torch.randn(n, no_dims).to(device)
    iY = torch.zeros(n, no_dims, device=device)
    gains = torch.ones(n, no_dims, device=device)

    # Compute symmetrised, normalised P-values
    P = x2p_torch(X, 1e-5, perplexity)
    P = P + P.t()
    P = P / torch.sum(P)
    P = P * 4.    # early exaggeration
    print("get P shape", P.shape)
    P = P.to(device).clamp(min=1e-21)

    for it in range(max_iter):

        # Student-t affinities num[i, j] = 1 / (1 + ||y_i - y_j||^2), zero diagonal
        sum_Y = torch.sum(Y * Y, 1)
        num = -2. * torch.mm(Y, Y.t())
        num = 1. / (1. + torch.add(torch.add(num, sum_Y).t(), sum_Y))
        num.fill_diagonal_(0.)
        Q = (num / torch.sum(num)).clamp(min=1e-12)

        # Gradient, vectorised: dY[i] = sum_j W[j, i] * (Y[i] - Y[j]) with W = (P - Q) * num
        W = (P - Q) * num
        dY = W.sum(0).unsqueeze(1) * Y - torch.mm(W.t(), Y)

        momentum = initial_momentum if it < 20 else final_momentum

        # Grow gains where dY and the previous update iY have opposite signs
        # (steady descent direction); decay them otherwise.
        gains = torch.where((dY > 0.) != (iY > 0.), gains + 0.2, gains * 0.8)
        gains = gains.clamp(min=min_gain)
        iY = momentum * iY - eta * (gains * dY)
        Y = Y + iY
        Y = Y - torch.mean(Y, 0)

        # Compute current value of cost function (KL(P || Q))
        if (it + 1) % 10 == 0:
            C = torch.sum(P * torch.log(P / Q))
            print("Iteration %d: error is %f" % (it + 1, C))

        # Stop lying about P-values
        if it == 100:
            P = P / 4.

    # Return solution
    return Y

In [None]:
X = repr_selected
labels = label_selected.tolist()

# Exactly one label per point; the scatter below colours points by label.
assert len(X) == len(labels)

with torch.no_grad():
    Y = tsne(X, 2, 50, 20.0)

Y = Y.cpu().numpy()
label_arr = np.asarray(labels)

fig, ax = plt.subplots()
# Select points by label rather than assuming 600 consecutive rows per class.
for lab in [0, 1, 2, 3, 4]:
    pts = Y[label_arr == lab]
    ax.scatter(pts[:, 0], pts[:, 1], 20, label=str(lab))

ax.set_title('method: {}'.format(method))
ax.legend()
plt.show()
plt.close(fig)

In [None]:
# Size of a single selected representation.
repr_selected[0].size()

In [None]:
# repr_selected holds pooled features (`last_out`), not images, so it can be
# neither fed back through the backbone nor rendered. Take a real image instead.
sample_images, _ = next(iter(base_loader))
sample = sample_images[2].cuda()

with torch.no_grad():
    block1_out, block2_out, block3_out, block4_out, last_out = pretrained_model.feature.return_features(sample.unsqueeze(0))
plot_img(sample)

---

In [None]:
image_size = 224
iter_num = 600

n_way = 5
n_support = 5
n_query = 15
few_shot_params = dict(n_way=n_way, n_support=n_support)

dataset_names = ["miniImageNet_test", "CropDisease", "EuroSAT", "ISIC", "ChestX"]
method = 'baseline'
model = 'ResNet10'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
freeze_backbone = False
keep_ratios = [1.0, 0.8, 0.5, 0.2, 0.0]


def zero_last_block_channels(net, r):
    """Zero output channels of the 'trunk.7' block.

    For every parameter whose name contains 'trunk.7' but not 'BN', keep the
    first ``int(r * out_channels)`` output channels and zero the rest
    (r=1.0 keeps everything, r=0.0 zeros everything).
    """
    for name, p in net.named_parameters():
        if 'trunk.7' in name and 'BN' not in name:
            r_dim = int(p.data.shape[0] * r)
            p.data[r_dim:].fill_(0.)


def prototype_episode_accuracy(net, x):
    """Top-1 query accuracy (%) of a dot-product nearest-prototype classifier.

    Args:
        net: model exposing ``.feature``.
        x: one episode, (n_way, n_support + n_query, C, H, W).
    """
    episode_query = x.size(1) - n_support
    x = x.cuda()
    x_support = x[:, :n_support].contiguous().view(n_way * n_support, *x.size()[2:])
    x_query = x[:, n_support:].contiguous().view(n_way * episode_query, *x.size()[2:])
    y_support = torch.from_numpy(np.repeat(range(n_way), n_support))
    y_query = np.repeat(range(n_way), episode_query)

    with torch.no_grad():
        net.eval()
        # Support features are computed once and reused for every class.
        support_feat = net.feature(x_support).cpu()
        # Stacking class means keeps the feature dtype and width (no hard-coded 512,
        # unaffected by the global default tensor type).
        prototypes = torch.stack([support_feat[y_support == c].mean(dim=0) for c in range(n_way)])
        scores = torch.mm(net.feature(x_query).cpu(), prototypes.T)

    pred = scores.topk(1, 1, True, True)[1][:, 0].numpy()
    return float(np.sum(pred == y_query)) / len(y_query) * 100


for dataset_name in dataset_names:
    print(dataset_name)
    # The cross-domain managers spell the episode-count keyword `n_eposide`.
    if dataset_name in ("miniImageNet", "miniImageNet_test"):
        datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_episode=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "CropDisease":
        datamgr = CropDisease_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "EuroSAT":
        datamgr = EuroSAT_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "ISIC":
        datamgr = ISIC_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    elif dataset_name == "ChestX":
        datamgr = Chest_few_shot.SetDataManager(image_size, n_eposide=iter_num, n_query=n_query, **few_shot_params)
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset_name))

    if dataset_name == "miniImageNet_test":
        novel_loader = datamgr.get_data_loader(aug=False, train=False)
    else:
        novel_loader = datamgr.get_data_loader(aug=False)

    accuracies = {}
    for r in keep_ratios:
        # The modified weights depend only on r, so load and zero once per ratio.
        pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir, track_bn=False)
        if freeze_backbone is False:
            zero_last_block_channels(pretrained_model, r)
        accuracies['r{}'.format(r)] = [prototype_episode_accuracy(pretrained_model, x) for x, _ in tqdm(novel_loader)]

    df = pd.DataFrame(accuracies)
    os.makedirs('./delete_results', exist_ok=True)
    df.to_csv('./delete_results/{}_zero_channel.csv'.format(dataset_name), index=False)

---

In [None]:
image_size = 224
model = 'ResNet10'
method = 'baseline'
pretrained_dataset = 'miniImageNet'
save_dir = './logs'
# track_bn is passed explicitly; False selects the '*_aug' checkpoint directory.
pretrained_model = get_pretrained_model(model, method, pretrained_dataset, save_dir, track_bn=False)

In [None]:
source_dataset = 'miniImageNet'
# Non-augmented loader over the miniImageNet source data, in batches of 32.
datamgr = miniImageNet_few_shot.SimpleDataManager(
    image_size,
    batch_size=32,
)
source_loader = datamgr.get_data_loader(aug=False)

In [None]:
source_repr_all = []
source_label_all = []

with torch.no_grad():
    for image, label in source_loader:
        image = image.cuda()
        block1_out, block2_out, block3_out, block4_out, last_out = pretrained_model.feature.return_features(image)
        # last_out = pretrained_model(image)[:,:64]
        source_repr_all.append(last_out.cpu())
        source_label_all.append(label)

# torch.cat rather than torch.stack: the final batch may be smaller than batch_size.
source_repr_all = torch.cat(source_repr_all).reshape(-1, source_repr_all[0].shape[-1])
source_label_all = torch.cat(source_label_all).reshape(-1)

In [None]:
target_dataset = 'ChestX'

target_modules = {
    'miniImageNet': miniImageNet_few_shot,
    'CropDisease': CropDisease_few_shot,
    'EuroSAT': EuroSAT_few_shot,
    'ISIC': ISIC_few_shot,
    'ChestX': Chest_few_shot,
}
if target_dataset not in target_modules:
    # Fail loudly; otherwise datamgr/target_loader would keep values from an earlier cell.
    raise ValueError('Unknown target_dataset: {}'.format(target_dataset))

datamgr = target_modules[target_dataset].SimpleDataManager(image_size, batch_size=32)
if target_dataset == 'miniImageNet':
    # For miniImageNet the non-training split is requested (train=False).
    target_loader = datamgr.get_data_loader(aug=False, train=False)
else:
    target_loader = datamgr.get_data_loader(aug=False)

In [None]:
target_repr_all = []
target_label_all = []

with torch.no_grad():
    for image, label in target_loader:
        image = image.cuda()
        block1_out, block2_out, block3_out, block4_out, last_out = pretrained_model.feature.return_features(image)
        # last_out = pretrained_model(image)[:,:64]
        target_repr_all.append(last_out.cpu())
        target_label_all.append(label)

# torch.cat rather than torch.stack: the final batch may be smaller than batch_size.
target_repr_all = torch.cat(target_repr_all).reshape(-1, target_repr_all[0].shape[-1])
target_label_all = torch.cat(target_label_all).reshape(-1)

In [None]:
def make_label_repr(repr_all, label_all):
    """Map each distinct label to the mean representation of its samples.

    Keys are inserted in the ascending order returned by ``unique()``.
    """
    return {
        lab: torch.mean(repr_all[label_all == lab], dim=0)
        for lab in label_all.unique().tolist()
    }

In [None]:
# Class-mean representations for each domain (each built from its own tensors).
source_label_repr = make_label_repr(source_repr_all, source_label_all)
target_label_repr = make_label_repr(target_repr_all, target_label_all)

In [None]:
# Cosine similarity (not a distance, despite the name) between class-mean
# representations: rows follow source_label_repr's key order, columns
# target_label_repr's. Position-based, so labels need not be 0..K-1.
source_protos = torch.stack(list(source_label_repr.values()))
target_protos = torch.stack(list(target_label_repr.values()))
distance_mtx = (
    torch.mm(source_protos, target_protos.t())
    / (source_protos.norm(dim=1, keepdim=True) * target_protos.norm(dim=1).unsqueeze(0))
).double().numpy()

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))
mesh = ax.pcolor(distance_mtx)
fig.colorbar(mesh, ax=ax)
ax.set_xticks([])
ax.set_yticks([])

# Create the output directory so savefig does not fail on a fresh checkout.
os.makedirs('./src', exist_ok=True)
fig.savefig('./src/{}_miniimagenet_repr3_cosine.pdf'.format(method), bbox_inches='tight', format='pdf')
plt.close(fig)

In [None]:
# (number of source classes, number of target classes)
np.shape(distance_mtx)

In [None]:
T = 5.0  # temperature of the exp(-distance / T) kernel

# Evaluate on the target features without overwriting source_repr_all /
# source_label_all, which later cells still read.
eval_repr_all = target_repr_all
eval_label_all = target_label_all

tot = 0
for curr_repr, curr_label in tqdm(list(zip(eval_repr_all, eval_label_all))):
    kernel = torch.exp(-torch.norm(curr_repr - eval_repr_all, dim=1) / T)
    # "- 1" removes the point's own exp(0) = 1 term from both sums.
    xxx = torch.sum(kernel) - 1
    yyy = torch.sum(kernel[curr_label == eval_label_all]) - 1
    # Floor both sums so the log stays finite.
    xxx = max(float(xxx), 1e-6)
    yyy = max(float(yyy), 1e-6)
    tot += -np.log(yyy / xxx)

tot /= len(eval_repr_all)
print(tot)

miniImageNet
baseline repr3, repr4 (4.0569), logit (1.4584)
baselinebody repr3, repr4 (0.6776), logit (2.2129)

CropDisease
baseline repr3, repr4 (), logit ()
baselinebody repr3, repr4 (0.8321), logit ()

In [None]:
import torch.nn.functional as F

In [None]:
# Row-wise softmax over the representation dimension.
prob = torch.softmax(source_repr_all, dim=1)

In [None]:
# Boolean mask of entries of sample 5 whose probability exceeds 0.95.
prob[5].gt(0.95)