In [None]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

from utils.distance import WTE
from utils.datasets import *
from scipy.spatial import distance

In [None]:
# Device handed to WTE below; everything in this notebook runs on the CPU.
device = torch.device(type='cpu')

## Load datasets

In [None]:
names = ['MNIST', 'EMNIST', 'FashionMNIST', 'KMNIST', 'USPS']
# Build each loader once and take both splits from a single get_dataset() call.
# Calling it separately for train and test loads every dataset twice, and, if
# get_dataset() splits non-deterministically (not visible here — TODO confirm),
# could pair a train split with a test split from a different call.
dataset_splits = [loaders_NIST(name).get_dataset() for name in names]
train_all = [split[0] for split in dataset_splits]
test_all = [split[1] for split in dataset_splits]
train_dict = dict(zip(names, train_all))
test_dict = dict(zip(names, test_all))

## Generate reference

In [None]:
def generate_reference(num, dim_low, dim, attached_dim, seed=0):
    """Build a seeded random reference set of ``num`` points.

    Each point is a ``dim_low x dim_low`` image of Uniform[0, 1) noise,
    bilinearly upsampled to ``dim x dim`` and flattened, followed by
    ``attached_dim`` standard-normal coordinates.

    Args:
        num: number of reference points (rows of the result).
        dim_low: side length of the low-resolution noise image.
        dim: side length of the upsampled image.
        attached_dim: number of Gaussian coordinates appended to each row.
        seed: value passed to ``torch.manual_seed``. Note that this reseeds
            torch's global RNG, so any later torch randomness is affected too.

    Returns:
        float32 tensor of shape ``(num, dim * dim + attached_dim)``.
    """
    torch.manual_seed(seed)
    # Treat the ``num`` images as channels of a single batch item so one
    # interpolation call upsamples all of them independently.
    low_res = torch.rand(num, dim_low, dim_low).unsqueeze(0)
    # Request the output size directly instead of scale_factor=dim/dim_low:
    # with a float scale the output size is floor(dim_low * scale), which relies
    # on float round-tripping; ``size`` pins it to exactly dim x dim.
    # align_corners=False is the existing default, made explicit.
    upsample = nn.Upsample(size=(dim, dim), mode='bilinear', align_corners=False)
    images = upsample(low_res).reshape(num, -1)
    attached = torch.randn(num, attached_dim)
    return torch.cat((images, attached), dim=1).float()

In [None]:
# Shared reference set: 200 points, each a 4x4 uniform-noise image upsampled
# to 28x28 (784 values) followed by 10 standard-normal coordinates.
# NOTE(review): attached_dim=10 presumably has to match label_dim=10 in the WTE
# call, and 28 the image side length produced by loaders_NIST — confirm.
reference = generate_reference(num=200, dim_low=4, dim=28, attached_dim=10)

## WTE

In [None]:
# Embed every training set against the shared reference, flatten each embedding
# to one row (presumably one row per dataset, in `names` order — confirm), and
# compare datasets by pairwise Euclidean distance.
wtes = WTE(train_all, label_dim=10, device=device, ref=reference.cpu(), maxsamples=10000)
wtes = wtes.reshape(wtes.shape[0], -1)

# scipy needs host arrays. If WTE returned a torch tensor (possibly on a non-CPU
# `device`), use a detached CPU NumPy copy; `wtes` itself is left unchanged.
wtes_np = wtes.detach().cpu().numpy() if torch.is_tensor(wtes) else np.asarray(wtes)
wte_distance = distance.cdist(wtes_np, wtes_np, 'euclidean')
wte_distance