# METRIC-BASED META-LEARNING using Matching Networks

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')
#%cd drive/MyDrive/'Colab Notebooks'
#%cd meta-learning-course-notebooks/1_MAML/
#!ls

In [None]:
#!pip install import_ipynb --quiet

In [None]:
#!pip install learn2learn --quiet

In [None]:
import import_ipynb
import utils
import models
utils.hide_toggle('Imports 1')

In [None]:
from IPython import display
import torch
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from l2lutils import KShotLoader
from IPython import display
utils.hide_toggle('Imports 2')

# Data Generation and Loading

In [None]:
# Generate the synthetic Euclidean dataset: meta-train and meta-test
# splits plus a batch loader (full_loader) used for plain supervised training.
meta_train_ds, meta_test_ds, full_loader = utils.euclideanDataset(
    n_samples=10000, n_features=20, n_classes=10, batch_size=32
)

In [None]:
# Plain MLP: the first width must equal the data dimension (20 here); for
# classification the last width must equal the number of classes (10 here),
# and for regression it must be 1.
# torch.manual_seed(10)
net = models.MLP(dims=[20, 32, 32, 10])

In [None]:
# Train the network. Training happens in place, so running this cell again
# keeps training the same weights.
net, loss, accs = models.Train(
    net, full_loader, lr=1e-2, epochs=50, verbose=True
)

In [None]:
# Accuracy on the meta-training split.
models.accuracy(
    net, meta_train_ds.samples, meta_train_ds.labels, verbose=True
)

In [None]:
# Accuracy on the meta-test split.
models.accuracy(
    net, meta_test_ds.samples, meta_test_ds.labels
)

# Meta-Learning: Tasks

Generate a k-shot, n-way task loader from the meta-training dataset.

In [None]:
meta_train_kloader = KShotLoader(meta_train_ds, shots=5, ways=5)  # 5-shot, 5-way

Sample a task: each task has a k-shot, n-way training set and a test set of the same form.

In [None]:
d_train, d_test = meta_train_kloader.get_task()  # each is (inputs, labels)

Let's try learning directly from the task's training set despite its small size: wrap it in a dataset and loader, then further train the earlier network with the `Train` function.

In [None]:
taskds = utils.MyDS(d_train[0], d_train[1])  # wrap the task's training set

In [None]:
# One example per step, reshuffled every epoch.
d_train_loader = torch.utils.data.DataLoader(
    dataset=taskds,
    batch_size=1,
    shuffle=True,
)

In [None]:
# Continue training the already-trained net on the small task training set.
net, loss, accs = models.Train(
    net, d_train_loader, lr=1e-1, epochs=10, verbose=False
)

How does it do on the test set of the sampled task?

In [None]:
models.accuracy(net, d_test[0], d_test[1])  # accuracy on the task's test set

# Matching Networks

In [None]:
import learn2learn as l2l
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

Sampling a training task: note that each of d_train and d_test is a tuple comprising the inputs and their labels.

In [None]:
d_train, d_test = meta_train_kloader.get_task()  # each is (inputs, labels)

In [None]:
lossfn = torch.nn.NLLLoss()  # note: NLLLoss expects log-probabilities as input

Cos computes the cosine similarities between a batch of target embeddings and a given support set.

In [None]:
class Cos(nn.Module):
    """Cosine-similarity scorer between query embeddings and a support set.

    The module has no learnable parameters. ``dims`` is accepted only for
    backward compatibility with existing callers and is ignored.
    """

    def __init__(self, dims=None):
        super().__init__()

    def forward(self, target, ss):
        """Return cosine similarities between every target and support row.

        Args:
            target: query embeddings, shape (batch, embedding_dim).
            ss: support-set embeddings, shape (ss_size, embedding_dim).

        Returns:
            Tensor of shape (batch, ss_size) with values in [-1, 1].
        """
        # L2-normalise each row; shapes stay (batch, dim) and (ss_size, dim).
        target_normed = F.normalize(target, p=2, dim=1)
        ss_normed = F.normalize(ss, p=2, dim=1)
        # Dot products of unit vectors are cosine similarities.
        return target_normed @ ss_normed.t()

Matching Network (simple version, without full-context embeddings)

In [None]:
class MAN(nn.Module):
    """Simple Matching Network (without full-context embeddings).

    Support and query points are embedded with a shared MLP. Each query is
    scored against every support point by cosine similarity, and the scores
    are turned into attention weights with a softmax over the support set.
    The output is the attention-weighted sum of the support labels' one-hot
    vectors.

    Args:
        dims: layer sizes passed to ``models.MLP(task='embedding')``;
            defaults to ``[20, 32, 32]``.
        n_classes: number of classes per task; support labels must lie in
            ``range(n_classes)``.
        lr: learning rate of the Adam optimizer stored as ``self.optimizer``.
    """

    def __init__(self, dims=None, n_classes=2, lr=1e-3):
        super().__init__()
        if dims is None:
            # Fresh list per instance instead of a shared mutable default.
            dims = [20, 32, 32]
        self.n_classes = n_classes
        self.mlp = models.MLP(dims=dims, task='embedding')
        self.cos = Cos()
        self.attn = nn.Softmax(dim=1)
        # Created last so it covers every registered parameter.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, X, d_train):
        """Return class probabilities of shape (batch, n_classes) for X.

        Args:
            X: query points, shape (batch, n_features).
            d_train: ``(x_tr, y_tr)`` support set; x_tr has shape
                (ss_size, n_features) and y_tr is an integer label tensor of
                shape (ss_size,).
        """
        x_tr, y_tr = d_train
        ss_e = self.mlp(x_tr)
        X_e = self.mlp(X)
        sims = self.cos(X_e, ss_e)   # (batch, ss_size)
        attn_wts = self.attn(sims)   # softmax over the support set
        # One-hot support labels (ss_size, n_classes). They are built on the
        # attention weights' device and dtype so the model also runs on GPU or
        # in double precision.
        y_h = torch.eye(
            self.n_classes, device=attn_wts.device, dtype=attn_wts.dtype
        )[y_tr]
        return attn_wts @ y_h

In [None]:
# Toy 3-feature data: rows 0 and 1 of X form a two-point support set
# labelled 0 and 1.
X = torch.Tensor([[1, 1, 1], [-1, -1, -1], [1, 2, 3], [-1, -2, -3]])
x_tr = X[[0, 1], :]
y_tr = torch.LongTensor([0, 1])
d_tr = (x_tr, y_tr)

In [None]:
# Tiny Matching Network for the 3-feature toy data (n_classes defaults to 2).
man = MAN(dims=[3, 8, 8])
# Sanity check: classify all four toy points against the two-point support
# set d_tr. The output should have shape (4, 2), with each row summing to 1.
man(X, d_tr)

# Putting it all together: Training a Matching Network
Now let's put all of the above in a loop to train the Matching Network:

In [None]:
# Redefining the accuracy function so that it takes h (the dataset context)
# as input, since the network requires it.
def accuracy(Net, X_test, y_test, h, verbose=True):
    """Return the fraction of X_test that Net classifies correctly.

    Args:
        Net: callable invoked as ``Net(X_test, h)``; must return per-class
            scores with the class dimension at dim 1.
        X_test: query inputs; the first dimension is the number of examples m.
        y_test: integer class labels, shape (m,).
        h: context passed through to Net, e.g. the ``(x_tr, y_tr)`` support set.
        verbose: if True, print the number of correct predictions and m.

    The forward pass runs under ``torch.no_grad()`` because the result is only
    reported, never back-propagated. Setting train/eval mode is left to the
    caller.
    """
    m = X_test.shape[0]
    with torch.no_grad():
        y_pred = Net(X_test, h)
    _, predicted = torch.max(y_pred, 1)
    correct = (predicted == y_test).float().sum().item()
    if verbose:
        print(correct, m)
    return correct / m

In [None]:
# Disjoint class splits: meta-train on classes 0-4, meta-test on classes 5-9.
classes_train = list(range(5))
classes_test = list(range(5, 10))
classes_train, classes_test

In [None]:
import learn2learn as l2l
import torch.optim as optim
# Episode configuration: 5-shot, 5-way tasks.
shots, ways = 5, 5
# NLLLoss expects log-probabilities as input.
lossfn = torch.nn.NLLLoss()
net = MAN(n_classes=ways, dims=[20, 64, 32], lr=1e-4)
# Task loader restricted to the meta-training classes.
meta_train_kloader = KShotLoader(
    meta_train_ds,
    shots=shots,
    ways=ways,
    num_tasks=1000,
    classes=classes_train,
)

In [None]:
n_epochs = 100
task_count = 50   # tasks per meta-batch; one optimizer step per epoch
log_eps = 1e-8    # keeps log() finite when a predicted probability is 0
for epoch in range(n_epochs):
    test_loss = 0.0
    test_acc = 0.0
    # Sample task_count tasks and accumulate their query-set losses.
    for _ in range(task_count):
        d_train, d_test = meta_train_kloader.get_task()
        # Shuffle the support (train) and query (test) sets of the task.
        rp = torch.randperm(d_train[1].shape[0])
        x_sup, y_sup = d_train[0][rp], d_train[1][rp]
        rp1 = torch.randperm(d_test[1].shape[0])
        x_qry, y_qry = d_test[0][rp1], d_test[1][rp1]
        test_preds = net(x_qry, (x_sup, y_sup))
        # MAN outputs class probabilities, but NLLLoss expects
        # log-probabilities, so take the (clamped) log first. Only the
        # query-set loss is accumulated.
        test_loss += lossfn(torch.log(test_preds.clamp_min(log_eps)), y_qry)
        # Accuracy is measured in eval mode, then training mode is restored.
        net.eval()
        test_acc += accuracy(net, x_qry, y_qry, (x_sup, y_sup), verbose=False)
        net.train()
    print('Epoch  % 2d Loss: %2.5e Avg Acc: %2.5f'
          % (epoch, test_loss.item() / task_count, test_acc / task_count))
    display.clear_output(wait=True)
    # One meta-update per epoch on the summed loss of all sampled tasks.
    net.optimizer.zero_grad()
    test_loss.backward()
    net.optimizer.step()
    

Now test the trained Matching Network on tasks sampled from the meta_test_ds dataset:

In [None]:
meta_test_kloader = KShotLoader(meta_test_ds, shots=shots, ways=ways, classes=classes_test)
test_acc = 0.0
task_count = 50
# Evaluate in eval mode (matching how accuracy is measured during training)
# and without building autograd graphs. No weights are adapted: each test
# task's support set is only used as the matching context.
net.eval()
try:
    with torch.no_grad():
        for _ in range(task_count):
            d_train, d_test = meta_test_kloader.get_task()
            support = (d_train[0], d_train[1])
            test_acc += accuracy(net, d_test[0], d_test[1], support, verbose=False)
finally:
    net.train()
print('Avg Acc: %2.5f' % (test_acc / task_count))