In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
import os
from omniglot import task_generator
from proto_attn_train import train, validate, count_acc
import torch.nn.functional as F
from collections import defaultdict
import torch_scatter
from torch.distributions import Normal
import torchvision.datasets


os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
config = {'dataset': 'omniglot', 'split_name': 'default', 'shot': 5, 'query': 15, 'classes_per_task': [20], 'examples_per_class': 20,
          'groups_per_class': 2, 'examples_per_group': 32, 'max_epoch': 10000,
          'nb_val_tasks': 1000, 'train_way': 2, 'test_way': 2, 'prob_xor': None, 'beta': 0,
          'out_dim': 64, 'iterations': 0, 'temp': 0.5, 'scale': 1, 'verbose': False}
config = AttrDict(config)
accs = defaultdict(list)

In [3]:
def hot_attn(Q, K, V, temp):
    return torch.softmax(Q@K.T/(temp),-1)@V  # * np.sqrt(K.shape[-1])

def euclidean_metric(a, b):
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = -((a - b)**2).sum(dim=2)
    return logits

def z_norm(x, h=1e-7):
    return (x - x.mean(0))/(x.std(0, unbiased=True) + h)

def conv_block(in_channels, out_channels):
    # bn = CustomBatchNorm()
    bn = nn.BatchNorm2d(out_channels, momentum=0.01, track_running_stats = False)
    # nn.init.uniform_(bn.weight) # for pytorch 1.2 or later
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        bn,
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

class Convnet(nn.Module):
    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
        )
        self.embeddings = nn.Linear(hid_dim, z_dim)
        self.mean = nn.Linear(z_dim, z_dim)
        self.logvar = nn.Linear(z_dim, z_dim)

    def forward(self, x):
        x = self.encoder(x)
        h = x.view(x.size(0), -1)
        h = self.embeddings(h)
        h = nn.ReLU()(h)
        mean = self.mean(h)
        logvar = self.logvar(h)
        std = torch.exp(0.5 * logvar)
        return Normal(mean, std)

In [4]:
def z_norm(x, h=1e-7):
    return (x - x.mean(0))/(x.std(0, unbiased=True) + h)

def forward_euclid(train, test, train_labels, config, attn_fn=hot_attn):
    iterations=config.iterations
    temp=config.temp
    scale=config.scale
    train, test = z_norm(train), z_norm(test)
    
    num_classes = train_labels.max() + 1
    
    # Self-attention feature selection
    for _ in range(iterations):
        for c in range(num_classes):
            t = train[train_labels==c]
            train[train_labels==c] = hot_attn(t, t, t, temp)  
    rescale = train.abs().mean(0)
    rescale = scale * (rescale - rescale.min()) / (rescale.max() - rescale.min() + 1e-7)
    
    # Compute predictions and accuracy
    distances = euclidean_metric(rescale*test, rescale*train)  # Shape=(nb_test, nb_train)
    weights = torch.softmax(distances, axis=-1) # Shape=(nb_test, nb_train)
    predictions = weights @ F.one_hot(train_labels, num_classes=train_labels.max()+1).float()
    predictions = torch.clip(predictions, 0.01, 0.99)
    
    return predictions

#### Protonets

In [5]:
def classify_proto(train, test, train_labels, **kwargs):
    proto = torch_scatter.scatter_mean(train, train_labels.type(torch.int64), dim=0)

    # Compute predictions and accuracy
    logits = euclidean_metric(test, proto)
    # predictions = torch.softmax(logits, axis=-1)
    # return predictions
    return logits

In [50]:
config.verbose = True
model =  Convnet(x_dim=1)
model, g = train(task_generator=task_generator,
                 model=model,
                 forward_fn=classify_proto,
                 config=config,
                 loss_fn=nn.CrossEntropyLoss()  # F.cross_entropy
                )
torch.save(model.state_dict(), 'proto_omniglot_10k')

Files already downloaded and verified
epoch 100, loss=2.2504, kl=0.0000, acc=0.6212
epoch 200, loss=1.2627, kl=0.0000, acc=0.7672
epoch 300, loss=0.9046, kl=0.0000, acc=0.8251
epoch 400, loss=0.7205, kl=0.0000, acc=0.8558
epoch 500, loss=0.6028, kl=0.0000, acc=0.8764
epoch 600, loss=0.5217, kl=0.0000, acc=0.8910
epoch 700, loss=0.4622, kl=0.0000, acc=0.9018
epoch 800, loss=0.4155, kl=0.0000, acc=0.9105
epoch 900, loss=0.3793, kl=0.0000, acc=0.9172
epoch 1000, loss=0.3497, kl=0.0000, acc=0.9229
epoch 1100, loss=0.3255, kl=0.0000, acc=0.9276
epoch 1200, loss=0.3052, kl=0.0000, acc=0.9315
epoch 1300, loss=0.2876, kl=0.0000, acc=0.9349
epoch 1400, loss=0.2721, kl=0.0000, acc=0.9379
epoch 1500, loss=0.2589, kl=0.0000, acc=0.9404
epoch 1600, loss=0.2465, kl=0.0000, acc=0.9430
epoch 1700, loss=0.2360, kl=0.0000, acc=0.9450
epoch 1800, loss=0.2267, kl=0.0000, acc=0.9470
epoch 1900, loss=0.2181, kl=0.0000, acc=0.9487
epoch 2000, loss=0.2099, kl=0.0000, acc=0.9505
epoch 2100, loss=0.2028, kl=0.0

In [107]:
accs_proto, losses_proto = validate(task_generator=task_generator,
                        forward_fn=classify_proto,
                        config=config,
                        model=model,
                        loss_fn=nn.CrossEntropyLoss())

Files already downloaded and verified


In [108]:
np.mean(accs_proto), np.std(accs_proto)

(0.9855400081276894, 0.014472103031131597)

In [109]:
config.iterations = 0
accs_attn_test, losses_attn_test = validate(task_generator=task_generator,
                        forward_fn=forward_euclid,
                        config=config,
                        model=model,
                        loss_fn=nn.CrossEntropyLoss())

Files already downloaded and verified


In [110]:
np.mean(accs_attn_test), np.std(accs_attn_test)

(0.9816266762018204, 0.014782069532884237)

#### Feature selection + euclid attention

In [64]:
config.verbose = True
model =  Convnet(x_dim=1)
model, g = train(task_generator=task_generator,
                 model=model,
                 forward_fn=forward_euclid,
                 config=config,
                 loss_fn=F.cross_entropy)
torch.save(model.state_dict(), 'attn_omniglot_10k')

Files already downloaded and verified
epoch 100, loss=2.7803, kl=0.0000, acc=0.3025
epoch 200, loss=2.6343, kl=0.0000, acc=0.4568
epoch 300, loss=2.5462, kl=0.0000, acc=0.5473
epoch 400, loss=2.4917, kl=0.0000, acc=0.6032
epoch 500, loss=2.4519, kl=0.0000, acc=0.6435
epoch 600, loss=2.4229, kl=0.0000, acc=0.6729
epoch 700, loss=2.4004, kl=0.0000, acc=0.6958
epoch 800, loss=2.3816, kl=0.0000, acc=0.7149
epoch 900, loss=2.3666, kl=0.0000, acc=0.7301
epoch 1000, loss=2.3534, kl=0.0000, acc=0.7434
epoch 1100, loss=2.3406, kl=0.0000, acc=0.7563
epoch 1200, loss=2.3294, kl=0.0000, acc=0.7676
epoch 1300, loss=2.3199, kl=0.0000, acc=0.7773
epoch 1400, loss=2.3113, kl=0.0000, acc=0.7859
epoch 1500, loss=2.3042, kl=0.0000, acc=0.7930
epoch 1600, loss=2.2975, kl=0.0000, acc=0.7997
epoch 1700, loss=2.2931, kl=0.0000, acc=0.8041
epoch 1800, loss=2.2883, kl=0.0000, acc=0.8090
epoch 1900, loss=2.2838, kl=0.0000, acc=0.8134
epoch 2000, loss=2.2794, kl=0.0000, acc=0.8179
epoch 2100, loss=2.2748, kl=0.0

In [65]:
accs, losses = validate(task_generator=task_generator,
                    forward_fn=forward_euclid,
                    config=config,
                    model=model,
                    loss_fn=F.cross_entropy)
np.mean(accs)

Files already downloaded and verified


0.9594400005936623

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model =  Convnet(x_dim=1)
model.load_state_dict(torch.load('attn_omniglot_10k'))
model = model.to(device)
model.eval();

In [6]:
accs, losses = validate(task_generator=task_generator,
                    forward_fn=forward_euclid,
                    config=config,
                    model=model,
                    loss_fn=F.cross_entropy)
np.mean(accs), 100 * np.std(accs)/np.sqrt(1000)

Files already downloaded and verified


(0.9624566689729691, 0.06424073820264695)

#### Test performance on alphabets

In [7]:
class alphabetOmniglot(torch.utils.data.Dataset):
    """
    A (possibly dumb) way to wrap the base Omniglot dataset to get it to
    behave like the rest of the pytorch datasets, i.e. having the data
    and targets being attributes such that
        dataset.data[i] is the i-th data
        dataset.targets[i] is the i-th target
    """
    def __init__(
        self,
        root,
        download,
        background,
    ):
        base = torchvision.datasets.Omniglot(root=root,
                                             download=download,
                                             background=background)
        
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((28,28)),
            torchvision.transforms.ToTensor()
        ])
        data, targets = [], []
        for ex in base:
            data.append(transform(ex[0]))
        self.data = torch.cat(data, 0)
    
    def __len__(self):
        return len(self.targets)
    
    def __getitem__(self, idx):
        
        return self.data[idx]
    
class alphabet_generator:
    
    def __init__(self, config):
        dataset = alphabetOmniglot(root='./data/',download=True,background=False)
        test = torchvision.datasets.Omniglot(root='DATA_FOLDER',
                                             download=True,
                                             background=False)
        self.dataset = dataset
        self.test = test
        
        # get character indices
        indices = []
        for alphabet in test._alphabets:
            alphabet_index = []
            for idx,character in enumerate(test._characters):
                if character.startswith(alphabet):
                    alphabet_index.append(idx)
            indices.append(alphabet_index)
        self.indices = indices
    
    def get_shot_query(self, config, device, **task_kwargs):
        # select n alphabets, m characters, s support from each character and the rest are queries (20-s)
        n, m, s = 3, 13, 7
        
        # choose alphabets
        dataset = self.dataset
        indices = self.indices
        alphabets = np.random.choice(len(indices),n)
        support_ids = []
        query_ids = []
        support_labels = []
        query_labels = []
        for label,a in enumerate(alphabets):
            # permute to get random sample of characters
            characters = np.random.permutation(indices[a])[:m]
            for c in characters:
                # there are 20 of each character, with dataset ids starting at 20c
                shuffled_ids = np.random.permutation(list(range(c*20,(c+1)*20)))
                support_ids += list(shuffled_ids[:s])
                query_ids += list(shuffled_ids[s:])
            support_labels += [label]*m*s
            query_labels += [label]*m*(20-s)
        support_labels, query_labels = torch.tensor(support_labels), torch.tensor(query_labels)
        support, queries = dataset.data[support_ids][:, None, ...], dataset.data[query_ids][:, None, ...]
        
        return support.to(device), support_labels.long().to(device), queries.to(device), query_labels.long().to(device)

In [97]:
model.load_state_dict(torch.load('proto_omniglot_10k'))
model.eval();

In [104]:
accs, losses = validate(task_generator=alphabet_generator,
                    forward_fn=classify_proto,
                    config=config,
                    model=model,
                    loss_fn=nn.CrossEntropyLoss())
np.mean(accs), np.std(accs)

Files already downloaded and verified
Files already downloaded and verified


(0.8341558179855346, 0.08456723601631734)

In [105]:
accs, losses = validate(task_generator=alphabet_generator,
                    forward_fn=forward_euclid,
                    config=config,
                    model=model,
                    loss_fn=nn.CrossEntropyLoss())
np.mean(accs), np.std(accs)

Files already downloaded and verified
Files already downloaded and verified


(0.948380671530962, 0.0822310599780685)

In [106]:
config.iterations = 0
accs, losses = validate(task_generator=alphabet_generator,
                    forward_fn=forward_euclid,
                    config=config,
                    model=model,
                    loss_fn=nn.CrossEntropyLoss())
np.mean(accs), np.std(accs)

Files already downloaded and verified
Files already downloaded and verified


(0.9605266293287277, 0.0738353866045988)

In [9]:
model.load_state_dict(torch.load('attn_omniglot_10k'))
model = model.to(device)
model.eval()

accs, losses = validate(task_generator=alphabet_generator,
                    forward_fn=forward_euclid,
                    config=config,
                    model=model,
                    loss_fn=F.cross_entropy)
np.mean(accs), 100 * np.std(accs)/np.sqrt(1000)

Files already downloaded and verified
Files already downloaded and verified


(0.9423057215809822, 0.26429950746671643)