# Calculate gaussian std and prototypes using consine distance

In [1]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import sys
from tqdm import tqdm
import numpy as np
import h5py
from os.path import join
import json
current_path = os.getcwd()
sys.path.append(os.path.dirname(current_path))

from os.path import join
from omegaconf import DictConfig, open_dict, OmegaConf
from utils.utils import grab_arg_from_checkpoint, prepend_paths, re_prepend_paths, get_transforms
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

plt.rc('font', family='Calibri', size=18)
# matplotlib.rcParams['font.family'] = 'DejaVu Sans'
plt.rc('legend', loc='best', frameon=True, edgecolor='k')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load checkpoint
result_folder = '/home/siyi/project/mm/result/Dynamic_project/PM16/demo_DynamicTransformer_singleCLS_singleCLS_whole_none_PolyMNIST_DynamicTransformer_singleCLS_0426_151226'
if 'CAD' in result_folder or 'Infarction' in result_folder or 'CelebA' in result_folder:
    eval_metric = 'auc'
else:
    eval_metric = 'acc'
checkpoint_path = join(result_folder, f'downstream/checkpoint_best_{eval_metric}.ckpt')
ckpt = torch.load(checkpoint_path, map_location='cpu')
args = ckpt['hyper_parameters']
args = OmegaConf.create(args)
args['data_base'] = args.data_base_cq
OmegaConf.set_struct(args, False)
args.checkpoint = checkpoint_path
args.data_base = join(args.data_base, args.data_base_postfix)
args = re_prepend_paths(args)
hparams = args
if 'low' not in args.keys():
    args.low = 1.0

In [3]:
# Load dataset
if args.dataset_name == 'PolyMNIST':
    from datasets.PolyMNISTDataset import PolyMNISTDataset
    image_size = grab_arg_from_checkpoint(hparams, 'image_size')
    missing_train = 'missing_train_whole_none'
    train_dataset = PolyMNISTDataset(
        unimodal_datapaths=hparams.DATA_train, data_base=hparams.data_base, missing_path=missing_train, transform=get_transforms(image_size,hparams.target,'train'),
        target_transform=None, low=hparams.low)
elif args.dataset_name == 'MST':
        from datasets.MSTDataset import SVHNMNIST
        missing_train = 'none'
        flags = {'dir_data': hparams.data_base, 'len_sequence': 8, 'data_multiplications': 20}
        flags = OmegaConf.create(flags)
        alphabet_path = join(hparams.data_base, 'alphabet.json')
        with open(alphabet_path) as alphabet_file:
            alphabet = str(''.join(json.load(alphabet_file)))
        train_dataset = SVHNMNIST(flags, alphabet, train='train', missing_path=missing_train, transform=get_transforms(hparams.image_size, hparams.target, 'train'))
        hparams.alphabet = alphabet
elif hparams.dataset_name == 'CelebA':
    from datasets.CelebADataset import CelebaDataset
    missing_train = 'none'
    flags = {'dir_data': hparams.data_base, 'dir_text': hparams.data_base, 'len_sequence': 256, 'random_text_ordering': False, 'random_text_startindex': True}
    flags = OmegaConf.create(flags)
    alphabet_path = join(hparams.data_base, 'alphabet.json')
    with open(alphabet_path) as alphabet_file:
        alphabet = str(''.join(json.load(alphabet_file)))
    train_dataset = CelebaDataset(flags, alphabet, missing_path=missing_train, partition=0, transform=get_transforms(hparams.image_size, hparams.target, 'train'))
    hparams.alphabet = alphabet
elif hparams.dataset_name in set(['DVM', 'CAD', 'Infarction']):
    from datasets.TIPDataset import ImagingAndTabularDataset
    missing_train = 'none'
    train_dataset = ImagingAndTabularDataset(
                hparams.DATA_data_train_eval_imaging, hparams.delete_segmentation, hparams.augmentation_rate, 
                hparams.DATA_data_train_eval_tabular, hparams.DATA_field_lengths_tabular, hparams.eval_one_hot,
                hparams.DATA_labels_train_eval_imaging, hparams.image_size, hparams.live_loading, train=False, target=hparams.target,
                corruption_rate=hparams.corruption_rate, data_base=hparams.data_base, augmentation_speedup=hparams.augmentation_speedup,
                missing_tabular=hparams.missing_tabular, missing_strategy=hparams.missing_strategy, missing_rate=hparams.missing_rate,algorithm_name=hparams.algorithm_name,
                missing_path=missing_train)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=4)

Using PlyMNIST transforms for train mode
Missing mask data loaded from /bigdata/siyi/data/MoPoE/PolyMNIST/missing_modality/missing_train_whole_none.csv
Missing mask example: [False False False False False]


In [4]:
if args.dataset_name == 'PolyMNIST':
    from models.PolyMNIST.Dynamic.DynamicTransformer import DynamicTransformer
    model = DynamicTransformer(args)
elif args.dataset_name == 'MST':
    from models.MST.Dynamic.DynamicTransformer import DynamicTransformer
    model = DynamicTransformer(args)
elif args.dataset_name == 'CelebA':
    from models.CelebA.Dynamic.DynamicTransformer import DynamicTransformer
    model = DynamicTransformer(args)
elif hparams.dataset_name in set(['DVM', 'CAD', 'Infarction']):
    from models.TIPData.Dynamic.DynamicTransformer import DynamicTransformer
    model = DynamicTransformer(args)
model.load_state_dict(ckpt['state_dict'], strict=True)
model.eval()

DynamicTransformer
Randomly drop 0 modalities
Use distance metric for DynamicTransformer: cosine_similarity


DynamicTransformer(
  (m_encoders): ModuleList(
    (0): UnimodalEncoder(
      (network): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (5): ReLU()
      )
      (proj_conv): Sequential(
        (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): UnimodalEncoder(
      (network): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (5): ReLU()
      )
      (proj_conv): Sequential(
   

In [5]:
import scipy.stats as stats
# folder to store features for each subset
storage_folder = result_folder+'/gaussian'

In [6]:
def extract_feature(train_loader, model, mask0):
    feat_list = []
    y_list = []
    y_hat_list = []
    for batch in tqdm(train_loader):
        x, _, y = batch
        for key in x:
            x[key] = x[key].to(device)
        # y = y.to(device)
        mask0 = torch.tensor(mask0).to(device)
        mask = mask0.expand(len(y), -1)
        # print(mask)
        with torch.no_grad():
            y_hat, feat = model.forward_train(x, mask)
        feat_list.append(feat.cpu().numpy())
        y_list.extend(y.cpu().numpy())
        y_hat_list.extend(y_hat.cpu().numpy())
        # break
    feat = np.concatenate(feat_list, axis=0)
    y_list = np.array(y_list)
    y_hat_list = np.array(y_hat_list)
    return feat, y_list, y_hat_list

def calculate_class_prototypes(feat, y, num_classes):
    feat = np.array(feat)
    y = np.array(y)
    prototypes = np.zeros((num_classes, feat.shape[1]))
    prot_dist_std = np.zeros((num_classes))
    t_dist_params = np.zeros((num_classes, 3))
    for i in range(num_classes):
        mask = (y == i)
        feat_i = feat[mask]
        prototype = np.mean(feat_i, axis=0, keepdims=True)
        prototypes[i] = prototype
        similarity = np.dot(feat_i, prototype.T)
        if np.max(similarity) > 1:
            print(f'similarity max: {np.max(similarity)}, min: {np.min(similarity)}')
        whole_similarity = np.concatenate((1-similarity, similarity-1), axis=0)
        assert whole_similarity[:len(similarity)].sum() + whole_similarity[len(similarity):].sum() == 0
        std = np.std(whole_similarity)
        prot_dist_std[i] = std
        # fit t-distribution
        df, loc, scale = stats.t.fit(whole_similarity)
        t_dist_params[i] = [df, loc, scale]
    return prototypes, prot_dist_std, t_dist_params, whole_similarity

In [None]:
id2subset = model.id2subset
prototypes_all = []
dist_std_all = []
t_dist_params_all = []
model = model.to(device)
num_classes = hparams.num_classes
print('Number of classes:', num_classes)
for subset in id2subset.items():
    id, mask0 = subset
    print(f'Extracting {id}: {mask0} features')
    feat, y, y_hat = extract_feature(train_loader, model, mask0)
    prototypes, prot_dist_std, t_dist_params, whole_similarity = calculate_class_prototypes(feat, y, num_classes)
    print(whole_similarity.max(), whole_similarity.min())
    prototypes_all.append(prototypes)
    dist_std_all.append(prot_dist_std)
    t_dist_params_all.append(t_dist_params)
    # break
prototypes_all = torch.from_numpy(np.array(prototypes_all))
overall_prototypes = prototypes_all.mean(dim=0)
dist_std_all = torch.from_numpy(np.array(dist_std_all))
t_dist_params_all = torch.from_numpy(np.array(t_dist_params_all))


In [18]:
# store prototypes and std
if not os.path.exists(storage_folder):
    os.makedirs(storage_folder)
subset_gaussian = {
    'prototypes': prototypes_all.float(),
    'overall_prototypes': overall_prototypes.float(),
    'dist_std': dist_std_all.float(),
    't_dist_params': t_dist_params_all.float(),
    'id2subset': id2subset,
    'subset2id': model.subset2id,
}
torch.save(subset_gaussian, os.path.join(storage_folder, 'subset_gaussian.pt'))

In [19]:
subset_prototypes, subset_prot_dist_std, t_dist_params, whole_similarity = calculate_class_prototypes(feat, y, num_classes)
print('whole_similarity', whole_similarity)
print('std: ', subset_prot_dist_std)
print('t_dist_params', t_dist_params)

whole_similarity [[ 0.26287752]
 [ 0.34177506]
 [ 0.6449405 ]
 ...
 [-0.53106064]
 [-0.3899665 ]
 [-0.2327801 ]]
std:  [0.35574988 0.38305342]
t_dist_params [[ 5.81130627e+03 -5.59180739e-04  3.55789700e-01]
 [ 1.50072172e+02 -2.14038438e-04  3.80548242e-01]]


In [None]:
# Plot the similarity distribution
plt.figure(figsize=(10, 5))
plt.hist(whole_similarity, bins=100, density=True, alpha=0.6, color='g', label='whole similarity')
plt.axvline(x=0, color='r', linestyle='--', label='0')
plt.title('Similarity Distribution')
plt.xlabel('Similarity')