In [None]:
import os

# Restrict the process to GPU 0. This must run before anything initialises CUDA
# (e.g. a module-level torch.cuda call inside `train`), otherwise it is ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import torch

from train import *  # star import supplies `train` and `validate` used below
from generators.colourless_polythetic_MNIST import task_generator

In [None]:
class Config:
    """Hyper-parameters passed to `train`, `validate` and the forward heads in this notebook."""

    # Dataset root. Set the POLYTHETIC_DATA_PATH environment variable instead of
    # editing the notebook; the placeholder remains the fallback value.
    path = os.environ.get('POLYTHETIC_DATA_PATH', 'PATH_TO_DATA_FOLDER')
    shot = 24  # presumably support examples per class per task — confirm in `train`
    query = 8  # presumably query examples per class per task — confirm in `train`
    train_way = 2  # classes per task; the heads below split labels into 0 / 1
    examples_per_group = 32  # consumed by the task generator / `train` (not read here)
    groups_per_class = 2  # consumed by the task generator / `train` (not read here)
    nb_val_tasks = 1000  # presumably number of tasks sampled by `validate`
    max_epoch = 10000  # presumably training length used by `train`
    prob_xor = None  # not read in this notebook; meaning defined in `train`
    iterations = 2  # self-attention rounds in forward_euclid
    temp = 0.5  # attention softmax temperature used by forward_euclid
    scale = 1  # upper end of the min-max feature rescaling in forward_euclid
    out_dim = 64  # presumably embedding size of the model built by `train`
    verbose = False  # not read in this notebook
config = Config()

In [None]:
def hot_attn(Q, K, V, temp):
    """Dot-product attention with an explicit softmax temperature.

    Scores are Q @ K.T divided by `temp` (no 1/sqrt(d) factor is applied), then
    normalised with a softmax over the last dimension and used to mix the rows of V.
    Lower `temp` gives sharper attention.
    """
    scores = Q @ K.T
    attention = torch.softmax(scores / temp, dim=-1)
    return attention @ V

def euclidean_metric(a, b):
    """Pairwise negative squared Euclidean distances between the rows of `a` and `b`.

    Args:
        a: 2-D tensor (n, d).
        b: 2-D tensor (m, d).

    Returns:
        Tensor (n, m) whose [i, j] entry is -||a[i] - b[j]||^2, usable directly as logits.
    """
    pairwise_diff = a[:, None, :] - b[None, :, :]
    return -(pairwise_diff ** 2).sum(dim=2)

def z_norm(x, h=1e-7):
    """Standardise each column of `x` over dimension 0.

    Subtracts the per-column mean and divides by the per-column unbiased standard
    deviation plus `h`, which guards against division by zero for constant columns.
    """
    mu = x.mean(dim=0)
    sigma = x.std(dim=0, unbiased=True)
    return (x - mu) / (sigma + h)

In [None]:
def forward_euclid(train, test, train_labels, config, attn_fn=hot_attn):
    """FS + Attn head: self-attention feature selection, then a soft nearest-neighbour vote.

    Steps:
    1. Both sets are z-normalised.
    2. The support rows of each class are refined by `config.iterations` rounds of
       self-attention (`attn_fn`).
    3. The mean absolute refined activations of the two classes give per-feature
       weights, min-max rescaled to roughly [0, config.scale].
    4. Each test row takes a softmax-weighted vote over the training labels, using
       negative squared Euclidean distance in the reweighted feature space.

    Args:
        train: support embeddings, one row per example (assumed 2-D — TODO confirm).
        test: query embeddings with the same feature dimension as `train`.
        train_labels: 1-D 0/1 labels for the rows of `train` (float or integer).
        config: object providing `iterations`, `temp` and `scale`.
        attn_fn: attention callable `(Q, K, V, temp) -> Tensor` used for the
            self-attention rounds; defaults to `hot_attn`.

    Returns:
        Tensor (n_test,): softmax-weighted fraction of label-1 support examples per
        query, clipped to [0.05, 0.95].
    """
    iterations = config.iterations
    temp = config.temp
    scale = config.scale

    # NOTE(review): the query set is standardised with its own statistics, not the
    # support set's. Confirm this transductive choice is intended.
    train = z_norm(train)
    test = z_norm(test)
    tr0, tr1 = train[train_labels == 0], train[train_labels == 1]

    # Self-attention feature selection. Use the injected attention function so the
    # `attn_fn` argument actually takes effect.
    for _ in range(iterations):
        tr0 = attn_fn(tr0, tr0, tr0, temp)
        tr1 = attn_fn(tr1, tr1, tr1, temp)

    # Per-feature importance, min-max rescaled; the 1e-5 avoids 0/0 when all
    # features have equal importance.
    rescale = tr0.abs().mean(0) + tr1.abs().mean(0)
    rescale = scale * (rescale - rescale.min()) / (rescale.max() - rescale.min() + 1e-5)

    # Soft nearest-neighbour vote in the reweighted feature space.
    neg_sq_dist = euclidean_metric(rescale * test, rescale * train)  # (n_test, n_train)
    weights = torch.softmax(neg_sq_dist, dim=-1)  # each row sums to 1
    # Cast labels so the matmul also works for integer / mismatched-dtype label tensors.
    predictions = weights @ train_labels.to(weights.dtype)
    # Clipped away from 0 and 1, presumably to keep a downstream log-loss finite.
    return torch.clip(predictions, 0.05, 0.95)

#### Prototypical networks (Protonets)

In [None]:
def classify_proto(train, test, train_labels, **kwargs):
    """Prototypical-network head for binary (0/1) tasks.

    Each class prototype is the mean of its support rows. A query is scored by a
    softmax over its negative squared Euclidean distances to the two prototypes.
    Extra keyword arguments (e.g. `config`) are accepted and ignored.

    Returns:
        Tensor (n_test,) holding the softmax probability assigned to class 1
        (not clipped, unlike the other heads in this notebook).
    """
    prototypes = torch.stack(
        [train[train_labels == label].mean(0) for label in (0, 1)]
    )
    logits = euclidean_metric(test, prototypes)  # (n_test, 2)
    return torch.softmax(logits, dim=-1)[:, 1]

In [None]:
# Train a model with the prototypical-network head on xor_task=False tasks.
# NOTE: `model` is overwritten by the later training cells; run the evaluation
# cells below before those.
model, g = train(
    task_generator=task_generator,
    forward_fn=classify_proto,
    config=config,
    xor_task=False,
)

In [None]:
# Protonet-trained model, protonet head, xor_task=False; shows the mean of the
# first value returned by `validate`.
accs_proto, _ = validate(
    task_generator=task_generator,
    forward_fn=classify_proto,
    config=config,
    model=model,
    xor_task=False,
)
np.mean(accs_proto)

In [None]:
# Same protonet head and model, evaluated with xor_task=True.
accs_proto_xor, _ = validate(
    task_generator=task_generator,
    forward_fn=classify_proto,
    config=config,
    model=model,
    xor_task=True,
)
np.mean(accs_proto_xor)

In [None]:
# FS+Attn applied only at test time: the `model` trained with the protonet head
# is evaluated through the forward_euclid head (xor_task=False).
accs_proto_attn_test, _ = validate(
    task_generator=task_generator,
    forward_fn=forward_euclid,
    config=config,
    model=model,
    xor_task=False,
)
np.mean(accs_proto_attn_test)

In [None]:
# Test-time FS+Attn on the protonet-trained model, xor_task=True.
accs_proto_attn_test_xor, _ = validate(
    task_generator=task_generator,
    forward_fn=forward_euclid,
    config=config,
    model=model,
    xor_task=True,
)
np.mean(accs_proto_attn_test_xor)

#### Matching networks

In [None]:
def cosine_attn(Q, K, V):
    """Soft attention whose scores are cosine similarities between rows of Q and K.

    Rows are L2-normalised (with a 1e-5 guard added to the norm). The
    similarity matrix is softmax-normalised over the last dimension, without
    a temperature, and used to mix the rows of V.
    """
    def _unit_rows(mat):
        # Divide each row by its L2 norm; the epsilon avoids division by zero.
        return mat / (torch.sum(mat ** 2, dim=1, keepdim=True).sqrt() + 1e-5)

    similarities = _unit_rows(Q) @ _unit_rows(K).T
    return torch.softmax(similarities, dim=-1) @ V

def forward_cosine(train, test, train_labels, config, attn_fn=cosine_attn):
    """Matching-network head: each query attends over the support set and mixes its labels.

    Args:
        train: support embeddings, one row per example.
        test: query embeddings with the same feature dimension.
        train_labels: 1-D 0/1 labels for the rows of `train` (float or integer).
        config: not used here; kept so the signature matches forward_euclid.
        attn_fn: attention callable `(Q, K, V) -> Tensor`; defaults to `cosine_attn`.
            It is now actually used instead of the hard-coded `cosine_attn`.

    Returns:
        Tensor (n_test,): attention-weighted label average per query, clipped to
        [0.05, 0.95].
    """
    # Cast labels to the feature dtype so the attention matmul also works when
    # labels are stored as integers.
    labels = train_labels.to(train.dtype)
    predictions = attn_fn(test, train, labels)
    return torch.clip(predictions, 0.05, 0.95)

In [None]:
# Train a fresh model with the matching-network (cosine attention) head.
# NOTE: this rebinds `model`, replacing the protonet-trained model above.
model, g = train(
    task_generator=task_generator,
    forward_fn=forward_cosine,
    config=config,
    xor_task=False,
)

In [None]:
# Matching-network model and head, xor_task=False.
accs_matching, _ = validate(
    task_generator=task_generator,
    forward_fn=forward_cosine,
    config=config,
    model=model,
    xor_task=False,
)
np.mean(accs_matching)

In [None]:
# Matching-network model and head, xor_task=True.
accs_matching_xor, _ = validate(
    task_generator=task_generator,
    forward_fn=forward_cosine,
    config=config,
    model=model,
    xor_task=True,
)
np.mean(accs_matching_xor)

#### FS + Attn (self-attention feature selection, used during training)

In [None]:
# Train a fresh model end-to-end with the FS+Attn head (forward_euclid).
# NOTE: this rebinds `model` again, replacing the matching-network model.
model, g = train(
    task_generator=task_generator,
    forward_fn=forward_euclid,
    config=config,
    xor_task=False,
)

In [None]:
# FS+Attn-trained model with the FS+Attn head, xor_task=False.
accs_fs_attn, _ = validate(
    task_generator=task_generator,
    forward_fn=forward_euclid,
    config=config,
    model=model,
    xor_task=False,
)
np.mean(accs_fs_attn)

In [None]:
# FS+Attn-trained model with the FS+Attn head, xor_task=True.
accs_fs_attn_xor, _ = validate(
    task_generator=task_generator,
    forward_fn=forward_euclid,
    config=config,
    model=model,
    xor_task=True,
)
np.mean(accs_fs_attn_xor)