# CE-40959: Advanced Machine Learning
## HW1 - Black-box Meta Learning (100 points)

#### Name: Sepehr Ghobadi
#### Student No: 400211008

In this notebook, you are going to implement a black-box meta learner using the `Omniglot` dataset.

Please write your code in specified sections and do not change anything else. If you have a question regarding this homework, please ask it on the Quera.

Also, it is recommended to use Google Colab to do this homework. You can connect to your drive using the code below:

In [2]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [1]:
results_path = '/content/drive/MyDrive/AdvancedML/HW1Q6Results/'
%mkdir '/content/drive/MyDrive/AdvancedML/HW1Q6Results'

mkdir: /content/drive/MyDrive/AdvancedML: No such file or directory


## Import Required libraries

In [4]:
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torchvision
import random
import torch.nn as nn
import math
import pickle
from tqdm import tqdm

import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import torch.utils.data as data

## Introduction

In Meta-Learning literature and in the meta-training phase, you are given some batches which consist of `support` and `query` sets. you train your model in a way that by using a support set you could predict query set labels correctly.

In this homework, you are going to implement such meta-learner like the below architecture. In this model, at each step, you give all your support images and one query to the network simultaneously (query at the end) and you expect that the model predicts query label based on your inputs.


<br><br>

<div style="text-align:center;"><img src="https://drive.google.com/uc?export=view&id=1Au9GF7FB_IChrMLmvM0z4RBPP1R3oPgY" width=300></div>

<br><br>

Don't worry if you didn't understand the architecture. we are going to explain it step by step.

So if our meta-learning is K-shot N-way then each batch will consist of N*K support images with labels and one query image which we have its label in the meta-training phase.

First we should build dataset it this way that each batch return N*K+1 images

The Omniglot data set is designed for developing more human-like learning algorithms. It contains 1623 different handwritten characters from 50 different alphabets. Each of the 1623 characters was drawn online via Amazon's Mechanical Turk by 20 different people.

Train and test dataset contains 964 and 659 classes, respectively. Torchvision-based Omniglot dataset is ordered and every 20 images in a row belong to one class.

In [5]:
# Meta learning parameters.

N = 5
K = 1
seq_length = N*K+1

## Prepare dataset (25 points)

In [6]:
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.Omniglot('./data/omniglot/', download = True, background = True, transform = transform)
test_dataset = torchvision.datasets.Omniglot('./data/omniglot/', download = True, background = False, transform = transform)

train_labels = np.repeat(np.arange(964), 20)
test_labels = np.repeat(np.arange(659), 20)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to ./data/omniglot/omniglot-py/images_background.zip


  0%|          | 0/9464212 [00:00<?, ?it/s]

Extracting ./data/omniglot/omniglot-py/images_background.zip to ./data/omniglot/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to ./data/omniglot/omniglot-py/images_evaluation.zip


  0%|          | 0/6462886 [00:00<?, ?it/s]

Extracting ./data/omniglot/omniglot-py/images_evaluation.zip to ./data/omniglot/omniglot-py


To build a dataloader, we should have a class that yields indexes of selected data in the dataset for every iteration and pass it to the `batch_sampler` attribute of dataloader.

Complete below code based on this pseudocode:


1.   select `N` classes randomly from all classes
2.   select `1` class from `N` selected classes as query-contained class
3.   select `K` images from other `N-1` classes independently and randomly
4.   select `K+1` images from the query-contained class independently and randomly
5.   shuffle dataset indexes, but don't forget to put query index at the last of the list



In [7]:
class BatchSampler(object):
    """
    BatchSampler: yield a batch of indexes at each iteration.
    __len__ returns the number of episodes per epoch (same as 'self.iterations').
    """

    def __init__(self, labels, classes_per_it, num_samples, iterations, batch_size):
        """
        Initialize the BatchSampler object
        Arguments:
        - labels: array of labels of dataset.
        - classes_per_it: number of random classes for each iteration
        - num_samples: number of samples for each iteration for each class
        - iterations: number of iterations (episodes) per epoch
        - batch_size: number of batches per iteration
        """
        super(BatchSampler, self).__init__()
        self.labels = labels
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations
        self.batch_size = batch_size
        
        self.classes = np.unique(self.labels)
        self.indices = np.arange(len(self.labels))

    def __iter__(self):
        '''
        yield a batch of indexes
        '''

        for it in range(self.iterations):
            total_batch_indexes = np.array([])

            #################################################################################
            #                  COMPLETE THE FOLLOWING SECTION (25 points)                   #
            #################################################################################
            # feel free to add/edit initialization part of sampler.
            #################################################################################
            
            for _ in range(self.batch_size):
                sample = np.array([], dtype=np.int64)
                sample_classes = np.random.choice(self.classes , size=self.classes_per_it, replace=False)
                for c in sample_classes:
                    size = K if c!=sample_classes[-1] else K+1
                    class_samples = np.random.choice(self.indices[self.labels==c], size=size, replace=False)
                    sample = np.append(sample, class_samples)
                np.random.shuffle(sample[:N*K])
                total_batch_indexes = np.append(total_batch_indexes, sample)

            #################################################################################
            #                                   THE END                                     #
            #################################################################################

            yield total_batch_indexes.astype(int)

    def __len__(self):
        return self.iterations

In [8]:
iterations = 5000
batch_size = 32

train_sampler = BatchSampler(labels=train_labels, classes_per_it=N,
                              num_samples=K, iterations=iterations,
                              batch_size=batch_size)

test_sampler = BatchSampler(labels=test_labels, classes_per_it=N,
                              num_samples=K, iterations=iterations,
                              batch_size=batch_size)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_sampler)

## Model (50 points)

Let's Build our model. the first block of our model is one encoder which is given below. you are going to implement other blocks of networks with a given explanation

In [9]:
def conv_block(in_channels, out_channels):
    return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels, momentum=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

class OmniglotNet(nn.Module):
    '''
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(OmniglotNet, self).__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)

The whole network consists of two major blocks:


1.   Causal Attention
2.   Temporal Convolution

The first block is `Causal Attention`:


<div style="text-align:center;"><img src="https://drive.google.com/uc?export=view&id=19lWuKzYTRry-UBog838o7dWYVL-r54WF" width=500></div>

<br><br>

The mechanism is so similar to self-attention (if you don't have any information about self-attention, see [this link](https://www.geeksforgeeks.org/self-attention-in-nlp/)) with one difference. the `masked softmax` has been replaced by `softmax`. It means that at each timestep when you calculate weights of the attention mechanism, you do it with just past keys/values.

In [10]:
class AttentionBlock(nn.Module):
    def __init__(self, in_channels, key_size, value_size):
        super(AttentionBlock, self).__init__()

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (2.5 points)                  #
        #################################################################################
        
        self.key_layer =  nn.Linear(in_channels, key_size)
        self.value_layer =  nn.Linear(in_channels, value_size)
        self.query_layer =  nn.Linear(in_channels, key_size)
        self.temprature = np.sqrt(key_size)

        #################################################################################
        #                                   THE END                                     #
        #################################################################################
        self.softmax_temp = math.sqrt(key_size) #don't forget to apply temperature before calculating softmax.

    def forward(self, x):
        # x is dim (N, T, in_channels) where N is the batch_size, and T is the sequence length
        mask = np.array([[True if i>j else False for i in range(x.shape[1])] for j in range(x.shape[1])])
        mask = torch.BoolTensor(mask).to(x.device)

      
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (7.5 points)                  #
        #################################################################################
    
        weights = torch.matmul( self.query_layer(x), torch.transpose(self.key_layer(x), 2, 1) )
        weights.data[:,mask] = -float('inf')
        weights = F.softmax(weights/self.temprature, dim=1)
        attention = torch.matmul(weights, self.value_layer(x))
        return torch.cat([attention, x], dim=2)

        #################################################################################
        #                                   THE END                                     #
        #################################################################################

The second block is `Temporal Convolution`:

a Temporal Convolution consists of a series of `Dense Blocks` whose dilation rates increase exponentially until their receptive field exceeds the desired sequence length. For example first time when you apply this block, sequence length is (N*K+1) and dilation is 2.
to sum up, what you will do is this:

<div style="text-align:center;"><img src="https://drive.google.com/uc?export=view&id=1_mWTFiZNQlN4sMTWp2GqolSSzNTAFJuh" width=1000></div>

<br>
Dense Block pseduocode is:
<br><br>

<div style="text-align:center;"><img src="https://drive.google.com/uc?export=view&id=1T2q6KugqBEcwSyJAAGymTaXTe__MGsv3" width=1000></div>

<br>
The `CausalConv` code is given.

<br>

In [11]:
class CasualConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, dilation=1):
        super(CasualConv1d, self).__init__()

        self.pad = nn.ConstantPad1d((dilation, 0), 0)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size=2, dilation=dilation)

    def forward(self, x):
        return self.conv1d(self.pad(x))

class DenseBlock(nn.Module):

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__()

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (2.5 points)                  #
        #################################################################################
        
        self.casualconvf = CasualConv1d(in_channels, out_channels, dilation=dilation)
        self.casualconvg = CasualConv1d(in_channels, out_channels, dilation=dilation)

        #################################################################################
        #                                   THE END                                     #
        #################################################################################


    def forward(self, x):

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (5 points)                    #
        #################################################################################

        xf, xg = self.casualconvf(x), self.casualconvg(x)
        return torch.cat((x, torch.tanh(xf) * torch.sigmoid(xg)), dim=1) #functional.tanh is deprecated
    
        #################################################################################
        #                                   THE END                                     #
        #################################################################################



class TemporalConvolutionBlock(nn.Module):

    def __init__(self, sequence_length, in_channels, dense_block_out_channels, levels):
        super().__init__()

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (2.5 points)                  #
        #################################################################################
        
        #dense_blocks = [DenseBlock(in_channels, dense_block_out_channels, 2)] + [DenseBlock(dense_block_out_channels, dense_block_out_channels, 2**(l+1)) for l in range(1,levels)]
        dense_blocks = [DenseBlock(in_channels+l*dense_block_out_channels, dense_block_out_channels, 2**(l+1)) for l in range(levels)]
        self.dense_blocks = nn.ModuleList(dense_blocks)
        
        #################################################################################
        #                                   THE END                                     #
        #################################################################################

    def forward(self, x):

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (10 points)                   #
        #################################################################################
        x = torch.transpose(x, 2, 1)
        for block in self.dense_blocks:
            x = block(x)
        x = torch.transpose(x, 2, 1)
        return x
    
        #################################################################################
        #                                   THE END                                     #
        #################################################################################
        


In [12]:
# The general mechanism of the network is as follows:
# Your input shape is (B*S, C, H, W). B is batch size, S is sequence length, C is channel, H is height and W is width of your image.
# first you should pass your input to "OmniglotNet" network to get feature vectors per data. shape: (B*S, V). V is feature vector size.
# then separate B and S dimensions and concat one-hot labels with your data. Shape: (B, S, V + N). N is your meta-learner parameter (number of classes per batch)
# pass it to a attention block with key size of 64 and value size of 32. shape: (B, S, v1)
# pass it to a temporal convolution block which consists of dense blocks with 128 output channels. shape: (B, S, v2)
# pass it to a attention block with key size of 256 and value size of 128. shape: (B, S, v3)
# pass it to a temporal convolution block which consists of dense blocks with 128 output channels. shape: (B, S, v4)
# pass it to a attention block with key size of 512 and value size of 256. shape: (B, S, v5)
# pass it to a Linear block with N outputs to predict labels. shape: (B, S, N)
# return last index of sequence which is related to query (second dimension). shape: (B, N)

class Network(nn.Module):
    def __init__(self, N, K):
        super(Network, self).__init__()

        self.N = N
        self.K = K
        self.encoder = OmniglotNet().double()
        channels_number = 64 + N
        seq_legth = N*K+1
        dense_block_out_channels = 128
        dense_block_levels = int(np.ceil(np.log2(seq_legth)))

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (10 points)                   #
        #################################################################################


        key_size, value_size = 64, 32
        self.attention1 = AttentionBlock(in_channels=channels_number, key_size=key_size, value_size=value_size)
        channels_number += value_size
        
        self.tempconv1 = TemporalConvolutionBlock(seq_legth, in_channels=channels_number, dense_block_out_channels=dense_block_out_channels, levels=dense_block_levels)
        channels_number += dense_block_levels*dense_block_out_channels
        
        key_size, value_size = 256, 128
        self.attention2 = AttentionBlock(in_channels=channels_number, key_size=key_size, value_size=value_size)
        channels_number += value_size
        
        self.tempconv2 = TemporalConvolutionBlock(seq_legth, in_channels=channels_number, dense_block_out_channels = 128, levels=dense_block_levels)
        channels_number += dense_block_levels*dense_block_out_channels
        
        key_size, value_size = 512, 256
        self.attention3 = AttentionBlock(in_channels=channels_number, key_size=key_size, value_size=value_size)
        channels_number += value_size
        
        self.fc = nn.Linear(channels_number, N)


        #################################################################################
        #                                   THE END                                     #
        #################################################################################


    def forward(self, input, labels):

        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (10 points)                   #
        #################################################################################
        # input shape is (B*S, C, H, W)
        # labels shape is (B, S, N)
        # output shape is (N, N)
        # calculate output by given description
        #################################################################################

        output = self.encoder(input)
        output = torch.cat((output, labels), dim=1)
        output = output.view((batch_size, self.N*self.K+1, -1))
        output = self.attention1(output)
        output = self.tempconv1(output)
        output = self.attention2(output)
        output = self.tempconv2(output)
        output = self.attention3(output)
        output = self.fc(output)

        return torch.squeeze(output[:,-1:])
        #################################################################################
        #                                   THE END                                     #
        #################################################################################

## Utils

In [13]:
def get_onehot_labels(labels, seq_length, batch_size, device):
    one_hots = []
    for s in range(batch_size):
        sample = labels[s*seq_length:(s+1)*seq_length]
        indices = [ np.where(sorted(sample[:-1].cpu().numpy()) == c)[0][0] for c in sample.cpu().numpy()]
        one_hot = np.zeros((seq_length, N))
        one_hot[np.arange(seq_length), indices] = 1
        one_hots.append(torch.tensor(one_hot))

    return torch.cat(one_hots, dim=0).to(device)

def mask_labels(one_hot_labels, seq_length, batch_size, device):
    query_indices = [ s*seq_length-1 for s in range(1,batch_size+1) ]
    masked_labels = one_hot_labels.detach().clone()
    masked_labels[ query_indices ] = torch.tensor(np.zeros((batch_size, N)), dtype=torch.float64, device=device) # we have #batch_size queries
    return masked_labels, query_indices

def eval_model(logits, targets, criterion):
    loss = criterion(logits, targets)
    _, output_labels = logits.max(dim=1)
    _, labels = targets.max(dim=1)
    acc = torch.eq(labels, output_labels).double().mean().item()
    return loss, acc


        

## Train (15 points)

In [31]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

lr = 0.0001
model = Network(N,K)
model = model.double()
model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)
epochs = 30

criterion = nn.CrossEntropyLoss()

train_loss, train_acc, best_acc = [], [], 0.0

for epoch in range(epochs):
    print('Epoch {}/{}'.format(epoch+1, epochs))
    model.train()
    loader_iter = iter(train_dataloader)
    epoch_loss, epoch_acc = [], []
    for inputs, labels in tqdm(loader_iter):
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (15 points)                   #
        #################################################################################
        # prepare your data as input to your model.
        # extract query label (last image label in each batch) for loss function.
        # convert your labels to one-hot form and don't forget to set all elements of
        # one-hotted query label to zero (it's trivial that we shouldn't give
        # the output of the network to model as input!).
        # train your model.
        # save loss of each iteration
        #################################################################################
        inputs = inputs.double().to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        
        # creating one hot labels
        one_hot_labels = get_onehot_labels(labels, seq_length, batch_size, device)
        
        #masking query labels to all zero
        masked_labels, query_indices = mask_labels(one_hot_labels, seq_length, batch_size, device)
        
        
        logits = model(inputs, masked_labels)
        loss, acc = eval_model(logits, one_hot_labels[query_indices], criterion)
        
        epoch_loss.append(loss.item())
        epoch_acc.append(acc)
        
        loss.backward()
        optimizer.step()
    
    train_loss = np.mean(epoch_loss)
    train_acc = np.mean(epoch_acc)
    
    model_state = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }
    torch.save(model_state, results_path+'model_train_state.pt')

    print(f"Train Loss: {train_loss}  ---------  Train Accuracy: {train_acc}")


    model.eval()
    loader_iter = iter(test_dataloader)
    epoch_loss, epoch_acc = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(loader_iter):
            #################################################################################
            #                  COMPLETE THE FOLLOWING SECTION (5 points)                    #
            #################################################################################
            # report accuracy of your model.
            # plot loss values in whole training iterations.
            #################################################################################
            inputs = inputs.double().to(device)
            labels = labels.to(device)

            # creating one hot labels
            one_hot_labels = get_onehot_labels(labels, seq_length, batch_size, device)

            #masking query labels to all zero
            masked_labels, query_indices = mask_labels(one_hot_labels, seq_length, batch_size, device)

            logits = model(inputs, masked_labels)
            loss, acc = eval_model(logits, one_hot_labels[query_indices], criterion)

            epoch_loss.append(loss.item())
            epoch_acc.append(acc)

    val_loss = np.mean(epoch_loss)
    val_acc = np.mean(epoch_acc)

    print(f"Validation Loss: {val_loss}  ---------  Validation Accuracy: {val_acc}")


    if val_acc > best_acc:
        torch.save(model_state, results_path+'best_model_state.pt')
        best_acc = val_acc
    epoch_result = {
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'model': model_state
    }
    pickle.dump(epoch_result, open(os.path.join(results_path, f"train_epoch_{epoch}.p"), 'wb'))

        #################################################################################
        #                                   THE END                                     #
        #################################################################################

Device: cuda
Epoch 1/30


100%|██████████| 5000/5000 [13:37<00:00,  6.12it/s]


Train Loss: 1.3241444389015309  ---------  Train Accuracy: 0.40656875


100%|██████████| 5000/5000 [12:24<00:00,  6.71it/s]


Validation Loss: 0.9122364922539528  ---------  Validation Accuracy: 0.65801875
Epoch 2/30


100%|██████████| 5000/5000 [13:37<00:00,  6.12it/s]


Train Loss: 0.19941467473051244  ---------  Train Accuracy: 0.9278625


100%|██████████| 5000/5000 [12:20<00:00,  6.75it/s]


Validation Loss: 0.3070856669964922  ---------  Validation Accuracy: 0.8918
Epoch 3/30


100%|██████████| 5000/5000 [13:35<00:00,  6.13it/s]


Train Loss: 0.0922963682489667  ---------  Train Accuracy: 0.96695


100%|██████████| 5000/5000 [11:47<00:00,  7.06it/s]


Validation Loss: 0.24903771602297883  ---------  Validation Accuracy: 0.9158375
Epoch 4/30


100%|██████████| 5000/5000 [12:28<00:00,  6.68it/s]


Train Loss: 0.06598298080398533  ---------  Train Accuracy: 0.97685625


100%|██████████| 5000/5000 [11:47<00:00,  7.06it/s]


Validation Loss: 0.2263014063565537  ---------  Validation Accuracy: 0.92541875
Epoch 5/30


100%|██████████| 5000/5000 [12:49<00:00,  6.50it/s]


Train Loss: 0.05312101418651987  ---------  Train Accuracy: 0.9811125


100%|██████████| 5000/5000 [11:52<00:00,  7.02it/s]


Validation Loss: 0.21911597323845236  ---------  Validation Accuracy: 0.92783125
Epoch 6/30


100%|██████████| 5000/5000 [12:51<00:00,  6.48it/s]


Train Loss: 0.04415627173569222  ---------  Train Accuracy: 0.984325


100%|██████████| 5000/5000 [11:59<00:00,  6.94it/s]


Validation Loss: 0.1777102025687355  ---------  Validation Accuracy: 0.941375
Epoch 7/30


100%|██████████| 5000/5000 [13:25<00:00,  6.20it/s]


Train Loss: 0.03847476034950947  ---------  Train Accuracy: 0.98635


100%|██████████| 5000/5000 [11:54<00:00,  7.00it/s]


Validation Loss: 0.1627938081292993  ---------  Validation Accuracy: 0.9460375
Epoch 8/30


100%|██████████| 5000/5000 [12:53<00:00,  6.46it/s]


Train Loss: 0.033860022948349174  ---------  Train Accuracy: 0.98775625


100%|██████████| 5000/5000 [11:55<00:00,  6.99it/s]


Validation Loss: 0.15936860225827423  ---------  Validation Accuracy: 0.94793125
Epoch 9/30


  2%|▏         | 81/5000 [00:13<13:37,  6.01it/s]


KeyboardInterrupt: ignored

## Test (5 points)

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lr = 0.0001
model = Network(N,K)
model = model.double()
model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)
epochs = 30

criterion = nn.CrossEntropyLoss()

best_model_path = results_path+'best_model_state.pt'
state_dict = torch.load(best_model_path, map_location=device)
model.load_state_dict(state_dict['model_state'], strict=False)
optimizer.load_state_dict(state_dict['optimizer_state'])

test_epochs = 1
for epoch in range(test_epochs):
    loader_iter = iter(test_dataloader)
    epoch_loss, epoch_acc = [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(loader_iter):
            #################################################################################
            #                  COMPLETE THE FOLLOWING SECTION (5 points)                    #
            #################################################################################
            # report accuracy of your model.
            # plot loss values in whole training iterations.
            #################################################################################
            inputs = inputs.double().to(device)
            labels = labels.to(device)

            # creating one hot labels
            one_hot_labels = get_onehot_labels(labels, seq_length, batch_size, device)

            #masking query labels to all zero
            masked_labels, query_indices = mask_labels(one_hot_labels, seq_length, batch_size, device)

            logits = model(inputs, masked_labels)
            loss, acc = eval_model(logits, one_hot_labels[query_indices], criterion)

            epoch_loss.append(loss.item())
            epoch_acc.append(acc)
    
    print(f"\nTest Accuracy: {np.mean(epoch_acc)}")



        #################################################################################
        #                                   THE END                                     #
        #################################################################################

100%|██████████| 5000/5000 [12:36<00:00,  6.61it/s]


Test Accuracy: 0.9471875





## Question (5 points)

Question) State one problem of using this network for meta-learning
<br><br>

Answer: one problem is that  few-shot learning is not inherently a sequential problem and using temporal convolutions imposes a causality hypothesis on dataset which is not a good assumption for supervised classification. also in RL this window of context causallity could not be determined easily. in total using reurrent based archittectures for classification problems create an order assumption on dataset which s assumed to be i.i.d. . also due to the design of this architecture model has high number of parameters which can lead to overfitting on extremely out-of-distribution tasks.  these out of distribution tasks can not be tested on datasets like omniglot or even imagenet. also this type of meta-learners which are black box style on completely out of distribution may not converge to optimum solution even ini long runs oppose to optimization-based meta learning algorithms.  