In [1]:
import sys
import argparse
import os
import glob

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import autograd
from torch.utils import tensorboard

import imageio
from torch.utils.data import dataset, sampler, dataloader

In [2]:
NUM_INPUT_CHANNELS = 1
NUM_HIDDEN_CHANNELS = 32
KERNEL_SIZE = 3
NUM_CONV_LAYERS = 4
SUMMARY_INTERVAL = 10
SAVE_INTERVAL = 100
LOG_INTERVAL = 10
VAL_INTERVAL = LOG_INTERVAL * 5
NUM_TEST_TASKS = 600

NUM_TRAIN_CLASSES = 1100
NUM_VAL_CLASSES = 100
NUM_TEST_CLASSES = 423
NUM_SAMPLES_PER_CLASS = 20

In [3]:
def load_image(file_path):
    """Loads and transforms an Omniglot image"""
    x = imageio.imread(file_path)
    x = torch.tensor(x, dtype=torch.float32).reshape([1, 28, 28])
    x = x / 255.0
    return 1 - x


class OmniglotDataset(dataset.Dataset):
    """Omniglot dataset for meta-learning.

    Each element of the dataset is a task. A task is specified with a key,
    which is a tuple of class indices (no particular order). The corresponding
    value is the instantiated task, which consists of sampled (image, label)
    pairs.
    """

    _BASE_PATH = './omniglot_resized'

    def __init__(self, num_support, num_query):
        """Inits OmniglotDataset"""
        super().__init__()

        # get all character folders
        self._character_folders = glob.glob(
            os.path.join(self._BASE_PATH, '*/*/'))
        assert len(self._character_folders) == (
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES + NUM_TEST_CLASSES
        )

        # shuffle characters
        np.random.default_rng(0).shuffle(self._character_folders)

        # check problem arguments
        assert num_support + num_query <= NUM_SAMPLES_PER_CLASS
        self._num_support = num_support
        self._num_query = num_query

    def __getitem__(self, class_idxs):
        """Constructs a task.

        Data for each class is sampled uniformly at random without replacement.
        """
        images_support, images_query = [], []
        labels_support, labels_query = [], []

        for label, class_idx in enumerate(class_idxs):
            # get a class's examples and sample from them
            all_file_paths = glob.glob(
                os.path.join(self._character_folders[class_idx], '*.png')
            )
            sampled_file_paths = np.random.default_rng().choice(
                all_file_paths,
                size=self._num_support + self._num_query,
                replace=False
            )
            images = [load_image(file_path) for file_path in sampled_file_paths]

            # split sampled examples into support and query
            images_support.extend(images[:self._num_support])
            images_query.extend(images[self._num_support:])
            labels_support.extend([label] * self._num_support)
            labels_query.extend([label] * self._num_query)

        # aggregate into tensors
        images_support = torch.stack(images_support)  # shape (N*S, C, H, W)
        labels_support = torch.tensor(labels_support)  # shape (N*S)
        images_query = torch.stack(images_query)
        labels_query = torch.tensor(labels_query)

        return images_support, labels_support, images_query, labels_query


class OmniglotSampler(sampler.Sampler):
    """Samples task specification keys for an OmniglotDataset."""
    def __init__(self, split_idxs, num_way, num_tasks):
        """Inits OmniglotSampler"""
        super().__init__(None)
        self._split_idxs = split_idxs
        self._num_way = num_way
        self._num_tasks = num_tasks

    def __iter__(self):
        return (
            np.random.default_rng().choice(
                self._split_idxs,
                size=self._num_way,
                replace=False
            ) for _ in range(self._num_tasks)
        )

    def __len__(self):
        return self._num_tasks

def identity(x):
    return x

def get_omniglot_dataloader(
        split,
        batch_size,
        num_way,
        num_support,
        num_query,
        num_tasks_per_epoch,
        num_workers=0,
):
    """Returns a dataloader.DataLoader for Omniglot"""

    if split == 'train':
        split_idxs = range(NUM_TRAIN_CLASSES)
    elif split == 'val':
        split_idxs = range(
            NUM_TRAIN_CLASSES,
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES
        )
    elif split == 'test':
        split_idxs = range(
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES,
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES + NUM_TEST_CLASSES
        )
    else:
        raise ValueError

    return dataloader.DataLoader(
        dataset=OmniglotDataset(num_support, num_query),
        batch_size=batch_size,
        sampler=OmniglotSampler(split_idxs, num_way, num_tasks_per_epoch),
        num_workers=0,
        collate_fn=identity,
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )


In [7]:
class MAML:
    """Trains and assesses a MAML."""

    def __init__(
            self, num_outputs, num_inner_steps, inner_lr, learn_inner_lrs, outer_lr,
    ):
        meta_parameters = {}

        # construct feature extractor
        in_channels = NUM_INPUT_CHANNELS
        for i in range(NUM_CONV_LAYERS):
            meta_parameters[f'conv{i}'] = nn.init.xavier_uniform_(
                torch.empty(
                    NUM_HIDDEN_CHANNELS,
                    in_channels,
                    KERNEL_SIZE,
                    KERNEL_SIZE,
                    requires_grad=True,
                )
            )
            meta_parameters[f'b{i}'] = nn.init.zeros_(
                torch.empty(
                    NUM_HIDDEN_CHANNELS,
                    requires_grad=True,
                )
            )
            in_channels = NUM_HIDDEN_CHANNELS

        # construct linear head layer
        meta_parameters[f'w{NUM_CONV_LAYERS}'] = nn.init.xavier_uniform_(
            torch.empty(
                num_outputs,
                NUM_HIDDEN_CHANNELS,
                requires_grad=True,
            )
        )
        meta_parameters[f'b{NUM_CONV_LAYERS}'] = nn.init.zeros_(
            torch.empty(
                num_outputs,
                requires_grad=True,
            )
        )

        self._meta_parameters = meta_parameters
        self._num_inner_steps = num_inner_steps
        self._inner_lrs = {
            k: torch.tensor(inner_lr, requires_grad=learn_inner_lrs)
            for k in self._meta_parameters.keys()
        }
        self._outer_lr = outer_lr

        self._optimizer = torch.optim.Adam(
            list(self._meta_parameters.values()) +
            list(self._inner_lrs.values()),
            lr=self._outer_lr
        )

        self._start_train_step = 0

    def _forward(self, images, parameters):
        """Computes predicted classification logits"""
        x = images
        for i in range(NUM_CONV_LAYERS):
            x = F.conv2d(
                input=x,
                weight=parameters[f'conv{i}'],
                bias=parameters[f'b{i}'],
                stride=1,
                padding='same'
            )
            x = F.batch_norm(x, None, None, training=True)
            x = F.relu(x)
        x = torch.mean(x, dim=[2, 3])
        return F.linear(
            input=x,
            weight=parameters[f'w{NUM_CONV_LAYERS}'],
            bias=parameters[f'b{NUM_CONV_LAYERS}']
        )

    def _inner_loop(self, images, labels, train):
        """Computes the adapted network parameters via the MAML inner loop"""
        accuracies = []
        parameters = {
            k: torch.clone(v)
            for k, v in self._meta_parameters.items()
        }
        for i in range(self._num_inner_steps):
            logit = self._forward(images, parameters)
            loss = F.cross_entropy(logit, labels)
            gradients = autograd.grad(
                loss,
                parameters.values(),
                create_graph=train
            )

            for i, (k, v) in enumerate(parameters.items()):
                parameters[k] = v - self._inner_lrs[k] * gradients[i]

            accuracies.append(
                score(logit, labels)
            )

        logit = self._forward(images, parameters)
        accuracies.append(
            score(logit, labels)
        )
        return parameters, accuracies

    def _outer_step(self, task_batch, train):
        """Computes the MAML loss and metrics on a batch of tasks"""
        outer_loss_batch = []
        accuracies_support_batch = []
        accuracy_query_batch = []
        for task in task_batch:
            images_support, labels_support, images_query, labels_query = task
            params, support_acc = self._inner_loop(images_support, labels_support, train)

            logit = self._forward(images_query, params)
            loss = F.cross_entropy(logit, labels_query)

            outer_loss_batch.append(loss)
            accuracies_support_batch.append(support_acc)
            accuracy_query_batch.append(
                score(logit, labels_query)
            )
        outer_loss = torch.mean(torch.stack(outer_loss_batch))
        accuracies_support = np.mean(
            accuracies_support_batch,
            axis=0
        )
        accuracy_query = np.mean(accuracy_query_batch)
        return outer_loss, accuracies_support, accuracy_query

    def train(self, dataloader_train, dataloader_val):
        print(f'Starting training at iteration {self._start_train_step}.')        
        for i_step, task_batch in enumerate(
                dataloader_train,
                start=self._start_train_step
        ):
            self._optimizer.zero_grad()
            outer_loss, accuracies_support, accuracy_query = (
                self._outer_step(task_batch, train=True)
            )
            outer_loss.backward()
            self._optimizer.step()

            if i_step % LOG_INTERVAL == 0:
                print(
                    f'Iteration {i_step}: '
                    f'loss: {outer_loss.item():.3f}, '
                    f'pre-adaptation support accuracy: {accuracies_support[0]:.3f}, '
                    f'post-adaptation support accuracy: {accuracies_support[-1]:.3f}, '
                    f'post-adaptation query accuracy: {accuracy_query:.3f}'
                )

            if i_step % VAL_INTERVAL == 0:
                losses = []
                acc_pre_adapt_support = []
                acc_post_adapt_support = []
                acc_post_adapt_query = []
                for val_task_batch in dataloader_val:
                    outer_loss, accuracies_support, accuracy_query = (
                        self._outer_step(val_task_batch, train=False)
                    )
                    losses.append(outer_loss.item())
                    acc_pre_adapt_support.append(accuracies_support[0])
                    acc_post_adapt_support.append(accuracies_support[-1])
                    acc_post_adapt_query.append(accuracy_query)
                loss = np.mean(losses)
                acc_pre_adapt_support = np.mean(acc_pre_adapt_support)
                acc_post_adapt_support = np.mean(acc_post_adapt_support)
                acc_post_adapt_query = np.mean(acc_post_adapt_query)
                print(
                    f'Validation: '
                    f'loss: {loss:.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{acc_pre_adapt_support:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{acc_post_adapt_support:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{acc_post_adapt_query:.3f}'
                )

    def test(self, dataloader_test):
        """Evaluate the MAML on test tasks"""
        accuracies = []
        for task_batch in dataloader_test:
            _, _, accuracy_query = self._outer_step(task_batch, train=False)
            accuracies.append(accuracy_query)
        mean = np.mean(accuracies)
        std = np.std(accuracies)
        mean_95_confidence_interval = 1.96 * std / np.sqrt(NUM_TEST_TASKS)
        print(
            f'Accuracy over {NUM_TEST_TASKS} test tasks: '
            f'mean {mean:.3f}, '
            f'95% confidence interval {mean_95_confidence_interval:.3f}'
        )


In [5]:
def score(logits, labels):
    """Returns the mean accuracy of a model's predictions on a set of examples"""
    y = torch.argmax(logits, dim=-1) == labels
    y = y.type(torch.float)
    return torch.mean(y).item()

In [None]:
def main(args):
    print(args)
    maml = MAML(
        args.num_way, args.num_inner_steps, args.inner_lr, args.learn_inner_lrs, args.outer_lr,
    )
    if not args.test:
        num_training_tasks = args.batch_size * (args.num_train_iterations - 1)
        print(
            f'Training on {num_training_tasks} tasks with composition: '
            f'num_way={args.num_way}, '
            f'num_support={args.num_support}, '
            f'num_query={args.num_query}'
        )
        dataloader_train = get_omniglot_dataloader(
            'train', args.batch_size, args.num_way, args.num_support, args.num_query, num_training_tasks
        )
        dataloader_val = get_omniglot_dataloader(
            'val', args.batch_size, args.num_way, args.num_support, args.num_query, args.batch_size * 4
        )
        maml.train(
            dataloader_train, dataloader_val,
        )
    else:
        print(
            f'Testing on tasks with composition '
            f'num_way={args.num_way}, '
            f'num_support={args.num_support}, '
            f'num_query={args.num_query}'
        )
        dataloader_test = get_omniglot_dataloader(
            'test',
            1,
            args.num_way,
            args.num_support,
            args.num_query,
            NUM_TEST_TASKS
        )
        maml.test(dataloader_test)

In [6]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Train a MAML!')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='directory to save to or load from')
    parser.add_argument('--num_way', type=int, default=5,
                        help='number of classes in a task')
    parser.add_argument('--num_support', type=int, default=1,
                        help='number of support examples per class in a task')
    parser.add_argument('--num_query', type=int, default=15,
                        help='number of query examples per class in a task')
    parser.add_argument('--num_inner_steps', type=int, default=1,
                        help='number of inner-loop updates')
    parser.add_argument('--inner_lr', type=float, default=0.4,
                        help='inner-loop learning rate initialization')
    parser.add_argument('--learn_inner_lrs', default=False, action='store_true',
                        help='whether to optimize inner-loop learning rates')
    parser.add_argument('--outer_lr', type=float, default=0.001,
                        help='outer-loop learning rate')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='number of tasks per outer-loop update')
    parser.add_argument('--num_train_iterations', type=int, default=150,
                        help='number of outer-loop updates to train for')
    parser.add_argument('--test', default=False, action='store_true',
                        help='train or test')
    parser.add_argument('--cache', action='store_true')
    parser.add_argument('--device', type=str, default='cpu')

    args, unknown = parser.parse_known_args()

    main(args)

Namespace(log_dir=None, num_way=5, num_support=1, num_query=15, num_inner_steps=1, inner_lr=0.4, learn_inner_lrs=False, outer_lr=0.001, batch_size=16, num_train_iterations=150, test=False, cache=False, device='cpu')
Training on 2384 tasks with composition: num_way=5, num_support=1, num_query=15
Starting training at iteration 0.


  x = imageio.imread(file_path)


Iteration 0: loss: 1.585, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.363, post-adaptation query accuracy: 0.287
Validation: loss: 1.554, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.422, post-adaptation query accuracy: 0.331
Iteration 10: loss: 1.477, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.588, post-adaptation query accuracy: 0.435
Iteration 20: loss: 1.439, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.588, post-adaptation query accuracy: 0.467
Iteration 30: loss: 1.383, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.725, post-adaptation query accuracy: 0.512
Iteration 40: loss: 1.322, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.688, post-adaptation query accuracy: 0.552
Iteration 50: loss: 1.306, pre-adaptation support accuracy: 0.213, post-adaptation support accuracy: 0.700, post-adaptation que