# Simple Neural Attentive Meta-Learner (SNAIL)

## Imports

In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from collections import OrderedDict

## `blocks.py`

In [3]:
class CasualConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, dilation=1, groups=1, bias=True):
        super(CasualConv1d, self).__init__()
        self.dilation = dilation
        padding = dilation * (kernel_size - 1)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                                padding, dilation, groups, bias)

    def forward(self, input):
        # Takes something of shape (N, in_channels, T),
        # returns (N, out_channels, T)
        out = self.conv1d(input)
        return out[:, :, :-self.dilation] # TODO: make this correct for different strides/padding

    
class DenseBlock(nn.Module):
    def __init__(self, in_channels, dilation, filters, kernel_size=2):
        super(DenseBlock, self).__init__()
        self.casualconv1 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)
        self.casualconv2 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)

    def forward(self, input_):
        # input is dimensions (N, in_channels, T)
        xf = self.casualconv1(input_)
        xg = self.casualconv2(input_)
        activations = torch.tanh(xf) * torch.sigmoid(xg) # shape: (N, filters, T)

        return torch.cat((input_, activations), dim=1)

class TCBlock(nn.Module):
    def __init__(self, in_channels, seq_length, filters):
        super(TCBlock, self).__init__()
        self.dense_blocks = nn.ModuleList([DenseBlock(in_channels + i * filters, 2 ** (i+1), filters)
                                           for i in range(int(math.ceil(math.log(seq_length, 2))))])

    def forward(self, input):
        # input is dimensions (N, T, in_channels)
        input = torch.transpose(input, 1, 2)
        for block in self.dense_blocks:
            input = block(input)
        return torch.transpose(input, 1, 2)

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, key_size, value_size):
        super(AttentionBlock, self).__init__()
        self.linear_query = nn.Linear(in_channels, key_size)
        self.linear_keys = nn.Linear(in_channels, key_size)
        self.linear_values = nn.Linear(in_channels, value_size)
        self.sqrt_key_size = math.sqrt(key_size)

    def forward(self, input):
        # input is dim (N, T, in_channels) where N is the batch_size, and T is
        # the sequence length
        mask = np.array([[1 if i>j else 0 for i in range(input.shape[1])] for j in range(input.shape[1])])
        mask = torch.BoolTensor(mask).cuda()

        #import pdb; pdb.set_trace()
        keys = self.linear_keys(input) # shape: (N, T, key_size)
        query = self.linear_query(input) # shape: (N, T, key_size)
        values = self.linear_values(input) # shape: (N, T, value_size)
        temp = torch.bmm(query, torch.transpose(keys, 1, 2)) # shape: (N, T, T)
        temp.data.masked_fill_(mask, -float('inf'))
        temp = F.softmax(temp / self.sqrt_key_size, dim=1) # shape: (N, T, T), broadcasting over any slice [:, x, :], each row of the matrix
        temp = torch.bmm(temp, values) # shape: (N, T, value_size)
        return torch.cat((input, temp), dim=2) # shape: (N, T, in_channels + value_size)

## `resnet_blocks.py`

In [4]:
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1):
    """convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, bias=False)

def conv_block(in_channels, out_channels):
    '''
    returns a block conv-bn-relu-pool
    '''
    return nn.Sequential(OrderedDict([
        ('conv', nn.Conv2d(in_channels, out_channels, 3, padding=1)),
        ('bn', nn.BatchNorm2d(out_channels, momentum=1)),
        #('bn', nn.BatchNorm2d(out_channels)),
        ('relu', nn.ReLU()),
        ('pool', nn.MaxPool2d(2))
    ]))

def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None, training=True,eps=1e-5, momentum=0.1):
    # momentum = 1 restricts stats to the current mini-batch
    # This hack only works when momentum is 1 and avoids needing to track
    # running stats by substituting dummy variables
    size = int(np.prod(np.array(input.data.size()[1])))
    running_mean = torch.zeros(size).cuda()
    running_var = torch.ones(size).cuda()
    return F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)

class OmniglotNet(nn.Module):
    '''
    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(OmniglotNet, self).__init__()
        self.encoder = nn.Sequential(OrderedDict([
            ('block1', conv_block(x_dim, hid_dim)),
            ('block2', conv_block(hid_dim, hid_dim)),
            ('block3', conv_block(hid_dim, hid_dim)),
            ('block4', conv_block(hid_dim, z_dim)),
        ]))

    def forward(self, x, weights=None):
        if weights is None:
            x = self.encoder(x)
        else:
            x = F.conv2d(x, weights['encoder.block1.conv.weight'], weights['encoder.block1.conv.bias'])
            x = batchnorm(x, weight=weights['encoder.block1.bn.weight'], bias=weights['encoder.block1.bn.bias'])
            x = F.relu(x)
            x = F.max_pool2d(x, 2, 2)
            x = F.conv2d(x, weights['encoder.block2.conv.weight'], weights['encoder.block2.conv.bias'])
            x = batchnorm(x, weight=weights['encoder.block2.bn.weight'], bias=weights['encoder.block2.bn.bias'])
            x = F.relu(x)
            x = F.max_pool2d(x, 2, 2)
            x = F.conv2d(x, weights['encoder.block3.conv.weight'], weights['encoder.block3.conv.bias'])
            x = batchnorm(x, weight=weights['encoder.block3.bn.weight'], bias=weights['encoder.block3.bn.bias'])
            x = F.relu(x)
            x = F.max_pool2d(x, 2, 2)
            x = F.conv2d(x, weights['encoder.block4.conv.weight'], weights['encoder.block4.conv.bias'])
            x = batchnorm(x, weight=weights['encoder.block4.bn.weight'], bias=weights['encoder.block4.bn.bias'])
            x = F.relu(x)
            x = F.max_pool2d(x, 2, 2)
        return x.view(x.size(0), -1)

class ResBlock(nn.Module):

    def __init__(self, in_channels, filters, pool_padding=0):
        super(ResBlock, self).__init__()
        self.conv1 = conv(in_channels, filters)
        self.bn1 = nn.BatchNorm2d(filters)
        self.relu1 = nn.LeakyReLU()
        self.conv2 = conv(filters, filters)
        self.bn2 = nn.BatchNorm2d(filters)
        self.relu2 = nn.LeakyReLU()
        self.conv3 = conv(filters, filters)
        self.bn3 = nn.BatchNorm2d(filters)
        self.relu3 = nn.LeakyReLU()
        self.conv4 = conv(in_channels, filters, kernel_size=1, padding=0)

        self.maxpool = nn.MaxPool2d(2, padding=pool_padding)
        self.dropout = nn.Dropout(p=0.9)

    def forward(self, x):
        residual = self.conv4(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu3(out)

        out += residual
        out = self.maxpool(out)
        out = self.dropout(out)

        return out

## `snail.py`

In [5]:
class SnailFewShot(nn.Module):
    def __init__(self, N, K, task, use_cuda=True):
        # N-way, K-shot
        super(SnailFewShot, self).__init__()
        if task == 'omniglot':
            self.encoder = OmniglotNet()
            num_channels = 64 + N
        elif task == 'mini_imagenet':
            self.encoder = MiniImagenetNet()
            num_channels = 384 + N
        else:
            raise ValueError('Not recognized task value')
        num_filters = int(math.ceil(math.log(N * K + 1, 2)))
        self.attention1 = AttentionBlock(num_channels, 64, 32)
        num_channels += 32
        self.tc1 = TCBlock(num_channels, N * K + 1, 128)
        num_channels += num_filters * 128
        self.attention2 = AttentionBlock(num_channels, 256, 128)
        num_channels += 128
        self.tc2 = TCBlock(num_channels, N * K + 1, 128)
        num_channels += num_filters * 128
        self.attention3 = AttentionBlock(num_channels, 512, 256)
        num_channels += 256
        self.fc = nn.Linear(num_channels, N)
        self.N = N
        self.K = K
        self.use_cuda = use_cuda

    def forward(self, input, labels):
        x = self.encoder(input)
        batch_size = int(labels.size()[0] / (self.N * self.K + 1))

        # TODO: Maybe move this zeroing to process_torchmeta_batch() at train.py
        last_idxs = [(i + 1) * (self.N * self.K + 1) - 1 for i in range(batch_size)]
        if self.use_cuda:
            labels[last_idxs] = torch.Tensor(np.zeros((batch_size, labels.size()[1]))).cuda()
        else:
            labels[last_idxs] = torch.Tensor(np.zeros((batch_size, labels.size()[1])))

        x = torch.cat((x, labels), 1)
        x = x.view((batch_size, self.N * self.K + 1, -1))
        x = self.attention1(x)
        x = self.tc1(x)
        x = self.attention2(x)
        x = self.tc2(x)
        x = self.attention3(x)
        x = self.fc(x)
        return x

## `train.py`

In [6]:
# coding=utf-8
"""
Main script for training SNAIL on Omniglot.
"""

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

from networks import SnailFewShot

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader


def get_acc(last_model, last_targets):
    _, preds = last_model.max(1)
    acc = torch.eq(preds, last_targets).float().mean()
    return acc.item()


def process_torchmeta_batch(batch, options):
    """
    Process batch from torchmeta dataset for the SNAIL model

    Parameters
    ----------
    batch : dict
        A dictionary given by the torchmeta dataset
    options : SimpleNamespace
        A namespace with configuration details.

    Returns
    -------
    input_images : torch.tensor
        Input images to the SNAIL model of shape (batch_size*(N*K+1), img_channels, img_height,
        img_width) where N and K denote the same variables as in the N-way K-shot problem.
    input_onehot_labels : torch.tensor
        Input one-hot labels to the SNAIL model of shape (batch_size*(N*K+1), N) where N and K
        denote the same variables as in the N-way K-shot problem. The last label for each (N*K+1)
        is not used since it is the target label.
    target_labels : torch.tensor
        Test set labels to evaluate the SNAIL model by comparing with its outputs. Has shape
        (batch_size).

    """
    train_inputs, train_labels = batch["train"]
    test_inputs, test_labels = batch["test"]

    # Select one image from N images in the test set
    chosen_indices = torch.randint(test_inputs.shape[1], size=(test_inputs.shape[0],))
    chosen_test_inputs = test_inputs[torch.arange(test_inputs.shape[0]), chosen_indices, :, :, :].unsqueeze(1)
    chosen_test_labels = test_labels[torch.arange(test_labels.shape[0]), chosen_indices].unsqueeze(1)

    # Concatenate train and test set for SNAIL-style input images and labels
    input_images = torch.cat((train_inputs, chosen_test_inputs), dim=1).reshape((-1, *train_inputs.shape[2:]))
    input_labels = torch.cat((train_labels, chosen_test_labels), dim=1).reshape((-1, *train_labels.shape[2:]))

    # Convert labels to one-hot
    input_onehot_labels = F.one_hot(input_labels).float()

    # Separate out target labels
    target_labels = input_labels[::(options.num_cls * options.num_samples + 1)].long()

    # Move to correct device
    if options.cuda:
        input_images, input_onehot_labels = input_images.cuda(), input_onehot_labels.cuda()
        target_labels = target_labels.cuda()

    return input_images, input_onehot_labels, target_labels


def train(model, optimizer, train_dataloader, val_dataloader, opt):
    if val_dataloader is None:
        best_state = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0

    best_model_path = os.path.join(opt.exp, 'best_model.pth')
    last_model_path = os.path.join(opt.exp, 'last_model.pth')

    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(opt.epochs):
        print('=== Epoch: {} ==='.format(epoch))

        # Training phase
        model.train()
        for i, batch in tqdm(enumerate(train_dataloader), total=1000):
            if i >= 1000:
                break
            input_images, input_onehot_labels, target_labels = process_torchmeta_batch(batch, opt)
            predicted_labels = model(input_images, input_onehot_labels)[:, -1, :]
            loss = loss_fn(predicted_labels, target_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            train_acc.append(get_acc(predicted_labels, target_labels))

        avg_loss = np.mean(train_loss[-opt.iterations:])
        avg_acc = np.mean(train_acc[-opt.iterations:])
        print('Avg Train Loss: {}, Avg Train Acc: {}'.format(avg_loss, avg_acc))

        # Validation phase
        model.eval()
        for i, batch in tqdm(enumerate(val_dataloader), total=1000):
            if i >= 1000:
                break
            input_images, input_onehot_labels, target_labels = process_torchmeta_batch(batch, opt)
            predicted_labels = model(input_images, input_onehot_labels)[:, -1, :]
            loss = loss_fn(predicted_labels, target_labels)

            val_loss.append(loss.item())
            val_acc.append(get_acc(predicted_labels, target_labels))

        avg_loss = np.mean(val_loss[-opt.iterations:])
        avg_acc = np.mean(val_acc[-opt.iterations:])

        postfix = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(best_acc)
        print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(avg_loss, avg_acc, postfix))
        if avg_acc >= best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            best_state = model.state_dict()

        # TODO(seungjaeryanlee): Understand this code better
        for name in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
            with open(os.path.join(opt.exp, name + '.txt'), 'w') as f:
                for item in locals()[name]:
                    f.write("%s\n" % item)

    torch.save(model.state_dict(), last_model_path)

    return best_state, best_acc, train_loss, train_acc, val_loss, val_acc


def test(model, test_dataloader, opt):
    """
    Test model on given dataset and options.
    """
    model.eval()
    acc_per_epoch = []
    for epoch in range(opt.test_epochs):
        for i, batch in tqdm(enumerate(test_dataloader), total=1000):
            if i >= 1000:
                break
            input_images, input_onehot_labels, target_labels = process_torchmeta_batch(batch, opt)
            predicted_labels = model(input_images, input_onehot_labels)[:, -1, :]

            acc_per_epoch.append(get_acc(predicted_labels, target_labels))

    avg_acc = np.mean(acc_per_epoch)
    print('Test Acc: {}'.format(avg_acc))

    return avg_acc


In [7]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--exp', type=str, default='default')
# parser.add_argument('--epochs', type=int, default=100)
# parser.add_argument('--test-epochs', type=int, default=1)
# parser.add_argument('--iterations', type=int, default=10000)
# parser.add_argument('--dataset', type=str, default='omniglot')
# parser.add_argument('--num_cls', type=int, default=5)
# parser.add_argument('--num_samples', type=int, default=1)
# parser.add_argument('--lr', type=float, default=0.0001)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--cuda', action='store_true')
# options = parser.parse_args()

from types import SimpleNamespace
options = SimpleNamespace(**{
    "exp": "default",
    "epochs": 1,
    "test_epochs": 1,
    "iterations": 10000,
    "dataset": "omniglot",
    "num_cls": 5,
    "num_samples": 1,
    "lr": 0.0001,
    "batch_size": 32,
    "cuda": True,
})

if not os.path.exists(options.exp):
    os.makedirs(options.exp)

if torch.cuda.is_available() and not options.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Setup dataset
train_dataset = omniglot("data", ways=5, shots=1, test_shots=1, meta_train=True, download=True)
train_dataloader = BatchMetaDataLoader(train_dataset, batch_size=options.batch_size, num_workers=8)
val_dataset = omniglot("data", ways=5, shots=1, test_shots=1, meta_val=True, download=True)
val_dataloader = BatchMetaDataLoader(val_dataset, batch_size=options.batch_size, num_workers=8)
test_dataset = omniglot("data", ways=5, shots=1, test_shots=1, meta_test=True, download=True)
test_dataloader = BatchMetaDataLoader(test_dataset, batch_size=options.batch_size, num_workers=8)
# Setup model
model = SnailFewShot(options.num_cls, options.num_samples, options.dataset, options.cuda)
model = model.cuda() if options.cuda else model
# Setup optimizer
optimizer = optim.Adam(params=model.parameters(), lr=options.lr)

# Train model
train_result = train(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    opt=options,
)
best_state, best_acc, train_loss, train_acc, val_loss, val_acc = train_result

# Test last model
print('Testing with last model..')
test(
    model=model,
    test_dataloader=test_dataloader,
    opt=options,
)

# Test best model
model.load_state_dict(best_state)
print('Testing with best model..')
test(
    opt=options,
    test_dataloader=test_dataloader,
    model=model,
)

=== Epoch: 0 ===


100%|██████████| 1000/1000 [00:45<00:00, 21.93it/s]


Avg Train Loss: 0.5472219952940941, Avg Train Acc: 0.765125


100%|██████████| 1000/1000 [00:42<00:00, 23.72it/s]


Avg Val Loss: 0.0045445125997066495, Avg Val Acc: 1.0 (Best)
Testing with last model..


100%|██████████| 1000/1000 [00:42<00:00, 23.40it/s]


Test Acc: 1.0
Testing with best model..


100%|██████████| 1000/1000 [00:42<00:00, 23.59it/s]


Test Acc: 1.0
