In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np



In [8]:
!pip3 install googledrivedownloader

Collecting googledrivedownloader
  Downloading googledrivedownloader-0.4-py2.py3-none-any.whl (3.9 kB)
Installing collected packages: googledrivedownloader
Successfully installed googledrivedownloader-0.4


In [10]:
"""Dataloading for Omniglot."""
import os
import glob

import google_drive_downloader as gdd
import imageio
import numpy as np
import torch
from torch.utils.data import dataset, sampler, dataloader

NUM_TRAIN_CLASSES = 1100
NUM_VAL_CLASSES = 100
NUM_TEST_CLASSES = 423
NUM_SAMPLES_PER_CLASS = 20


def load_image(file_path):
    """Loads and transforms an Omniglot image.

    Args:
        file_path (str): file path of image

    Returns:
        a Tensor containing image data
            shape (1, 28, 28)
    """
    x = imageio.imread(file_path)
    x = torch.tensor(x, dtype=torch.float32).reshape([1, 28, 28]) # (channel, width, height)
    x = x / 255.0
    return 1 - x


class OmniglotDataset(dataset.Dataset):
    """Omniglot dataset for meta-learning.

    Each element of the dataset is a task. A task is specified with a key,
    which is a tuple of class indices (no particular order). The corresponding
    value is the instantiated task, which consists of sampled (image, label)
    pairs.
    """

    _BASE_PATH = './omniglot_resized'
    _GDD_FILE_ID = '1iaSFXIYC3AB8q9K_M-oVMa4pmB7yKMtI'

    def __init__(self, num_support, num_query):
        """Inits OmniglotDataset.

        Args:
            num_support (int): number of support examples per class
            num_query (int): number of query examples per class
        """
        super().__init__()


        # if necessary, download the Omniglot dataset
        if not os.path.isdir(self._BASE_PATH):
            gdd.GoogleDriveDownloader.download_file_from_google_drive(
                file_id=self._GDD_FILE_ID,
                dest_path=f'{self._BASE_PATH}.zip',
                unzip=True
            )

        # get all character folders
        self._character_folders = glob.glob(
            os.path.join(self._BASE_PATH, '*/*/'))
        assert len(self._character_folders) == (
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES + NUM_TEST_CLASSES
        )

        # shuffle characters
        np.random.default_rng(0).shuffle(self._character_folders)

        # check problem arguments
        assert num_support + num_query <= NUM_SAMPLES_PER_CLASS
        self._num_support = num_support
        self._num_query = num_query

    def __getitem__(self, class_idxs):
        """Constructs a task.

        Data for each class is sampled uniformly at random without replacement.
        The ordering of the labels corresponds to that of class_idxs.

        Args:
            class_idxs (tuple[int]): class indices that comprise the task

        Returns:
            images_support (Tensor): task support images
                shape (num_way * num_support, channels, height, width)
            labels_support (Tensor): task support labels
                shape (num_way * num_support,)
            images_query (Tensor): task query images
                shape (num_way * num_query, channels, height, width)
            labels_query (Tensor): task query labels
                shape (num_way * num_query,)
        """
        images_support, images_query = [], []
        labels_support, labels_query = [], []

        for label, class_idx in enumerate(class_idxs):
            # get a class's examples and sample from them
            all_file_paths = glob.glob(
                os.path.join(self._character_folders[class_idx], '*.png')
            )
            sampled_file_paths = np.random.default_rng().choice(
                all_file_paths,
                size=self._num_support + self._num_query,
                replace=False
            )
            images = [load_image(file_path) for file_path in sampled_file_paths]

            # split sampled examples into support and query
            images_support.extend(images[:self._num_support])
            images_query.extend(images[self._num_support:])
            labels_support.extend([label] * self._num_support)
            labels_query.extend([label] * self._num_query)

        # aggregate into tensors
        images_support = torch.stack(images_support)  # shape (N*S, C, H, W)
        labels_support = torch.tensor(labels_support)  # shape (N*S)
        images_query = torch.stack(images_query)
        labels_query = torch.tensor(labels_query)

        return images_support, labels_support, images_query, labels_query


class OmniglotSampler(sampler.Sampler):
    """Samples task specification keys for an OmniglotDataset."""

    def __init__(self, split_idxs, num_way, num_tasks):
        """Inits OmniglotSampler.

        Args:
            split_idxs (range): indices that comprise the
                training/validation/test split
            num_way (int): number of classes per task
            num_tasks (int): number of tasks to sample
        """
        super().__init__(None)
        self._split_idxs = split_idxs
        self._num_way = num_way
        self._num_tasks = num_tasks

    def __iter__(self):
        return (
            np.random.default_rng().choice(
                self._split_idxs,
                size=self._num_way,
                replace=False
            ) for _ in range(self._num_tasks)
        )

    def __len__(self):
        return self._num_tasks


def identity(x):
    return x


def get_omniglot_dataloader(
        split,
        batch_size,
        num_way,
        num_support,
        num_query,
        num_tasks_per_epoch
):
    """Returns a dataloader.DataLoader for Omniglot.

    Args:
        split (str): one of 'train', 'val', 'test'
        batch_size (int): number of tasks per batch
        num_way (int): number of classes per task
        num_support (int): number of support examples per class
        num_query (int): number of query examples per class
        num_tasks_per_epoch (int): number of tasks before DataLoader is
            exhausted
    """

    if split == 'train':
        split_idxs = range(NUM_TRAIN_CLASSES)
    elif split == 'val':
        split_idxs = range(
            NUM_TRAIN_CLASSES,
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES
        )
    elif split == 'test':
        split_idxs = range(
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES,
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES + NUM_TEST_CLASSES
        )
    else:
        raise ValueError

    return dataloader.DataLoader(
        dataset=OmniglotDataset(num_support, num_query),
        batch_size=batch_size,
        sampler=OmniglotSampler(split_idxs, num_way, num_tasks_per_epoch),
        num_workers=2,
        collate_fn=identity,
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )

In [11]:
train_loader = get_omniglot_dataloader("train", 10, 3, 5, 1, 1000)

Downloading 1iaSFXIYC3AB8q9K_M-oVMa4pmB7yKMtI into ./omniglot_resized.zip... Done.
Unzipping...Done.


In [12]:
len(train_loader)

100

In [13]:
for data in train_loader:
  print(data[0][0].shape)
  break

  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


torch.Size([15, 1, 28, 28])


In [14]:
"""Utilities for scoring the model."""
import torch


def score(logits, labels):
    """Returns the mean accuracy of a model's predictions on a set of examples.

    Args:
        logits (torch.Tensor): model predicted logits
            shape (examples, classes)
        labels (torch.Tensor): classification labels from 0 to num_classes - 1
            shape (examples,)
    """

    assert logits.dim() == 2
    assert labels.dim() == 1
    assert logits.shape[0] == labels.shape[0]
    y = torch.argmax(logits, dim=-1) == labels
    y = y.type(torch.float)
    return torch.mean(y).item()

In [15]:
"""Implementation of model-agnostic meta-learning for Omniglot."""

import argparse
import os

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import autograd  # pylint: disable=unused-import
from torch.utils import tensorboard


NUM_INPUT_CHANNELS = 1
NUM_HIDDEN_CHANNELS = 64
KERNEL_SIZE = 3
NUM_CONV_LAYERS = 4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SUMMARY_INTERVAL = 10
SAVE_INTERVAL = 100
LOG_INTERVAL = 10
VAL_INTERVAL = LOG_INTERVAL * 5
NUM_TEST_TASKS = 600


class MAML:
    """Trains and assesses a MAML."""

    def __init__(
            self,
            num_outputs,
            num_inner_steps,
            inner_lr,
            learn_inner_lrs,
            outer_lr,
            log_dir
    ):
        """Inits MAML.

        The network consists of four convolutional blocks followed by a linear
        head layer. Each convolutional block comprises a convolution layer, a
        batch normalization layer, and ReLU activation.

        Note that unlike conventional use, batch normalization is always done
        with batch statistics, regardless of whether we are training or
        evaluating. This technically makes meta-learning transductive, as
        opposed to inductive.

        Args:
            num_outputs (int): dimensionality of output, i.e. number of classes
                in a task
            num_inner_steps (int): number of inner-loop optimization steps
            inner_lr (float): learning rate for inner-loop optimization
                If learn_inner_lrs=True, inner_lr serves as the initialization
                of the learning rates.
            learn_inner_lrs (bool): whether to learn the above
            outer_lr (float): learning rate for outer-loop optimization
            log_dir (str): path to logging directory
        """
        meta_parameters = {}

        # construct feature extractor
        in_channels = NUM_INPUT_CHANNELS
        for i in range(NUM_CONV_LAYERS):
            meta_parameters[f'conv{i}'] = nn.init.xavier_uniform_(
                torch.empty(
                    NUM_HIDDEN_CHANNELS,
                    in_channels,
                    KERNEL_SIZE,
                    KERNEL_SIZE,
                    requires_grad=True,
                    device=DEVICE
                )
            )
            meta_parameters[f'b{i}'] = nn.init.zeros_(
                torch.empty(
                    NUM_HIDDEN_CHANNELS,
                    requires_grad=True,
                    device=DEVICE
                )
            )
            in_channels = NUM_HIDDEN_CHANNELS

        # construct linear head layer
        meta_parameters[f'w{NUM_CONV_LAYERS}'] = nn.init.xavier_uniform_(
            torch.empty(
                num_outputs,
                NUM_HIDDEN_CHANNELS,
                requires_grad=True,
                device=DEVICE
            )
        )
        meta_parameters[f'b{NUM_CONV_LAYERS}'] = nn.init.zeros_(
            torch.empty(
                num_outputs,
                requires_grad=True,
                device=DEVICE
            )
        )

        self._meta_parameters = meta_parameters
        self._num_inner_steps = num_inner_steps
        self._inner_lrs = {
            k: torch.tensor(inner_lr, requires_grad=learn_inner_lrs)
            for k in self._meta_parameters.keys()
        }
        self._outer_lr = outer_lr

        self._optimizer = torch.optim.Adam(
            list(self._meta_parameters.values()) +
            list(self._inner_lrs.values()),
            lr=self._outer_lr
        )

        self._log_dir = log_dir
        os.makedirs(self._log_dir, exist_ok=True)

        self._start_train_step = 0

    def _forward(self, images, parameters):
        """Computes predicted classification logits.

        Args:
            images (Tensor): batch of Omniglot images
                shape (num_images, channels, height, width)
            parameters (dict[str, Tensor]): parameters to use for
                the computation

        Returns:
            a Tensor consisting of a batch of logits
                shape (num_images, classes)
        """
        x = images
        for i in range(NUM_CONV_LAYERS):
            x = F.conv2d(
                input=x,
                weight=parameters[f'conv{i}'],
                bias=parameters[f'b{i}'],
                stride=1,
                padding='same'
            )
            x = F.batch_norm(x, None, None, training=True)
            x = F.relu(x)
        x = torch.mean(x, dim=[2, 3])
        return F.linear(
            input=x,
            weight=parameters[f'w{NUM_CONV_LAYERS}'],
            bias=parameters[f'b{NUM_CONV_LAYERS}']
        )

    def _inner_loop(self, images, labels, train):   # pylint: disable=unused-argument
        """Computes the adapted network parameters via the MAML inner loop.

        Args:
            images (Tensor): task support set inputs
                shape (num_images, channels, height, width)
            labels (Tensor): task support set outputs
                shape (num_images,)
            train (bool): whether we are training or evaluating (not necessary?)

        Returns:
            parameters (dict[str, Tensor]): adapted network parameters
            accuracies (list[float]): support set accuracy over the course of
                the inner loop, length num_inner_steps + 1
        """
        accuracies = []
        parameters = {
            k: torch.clone(v)
            for k, v in self._meta_parameters.items()
        }
        # ********************************************************
        # ******************* YOUR CODE HERE *********************
        # ********************************************************
        # TODO: finish implementing this method.
        # This method computes the inner loop (adaptation) procedure for one
        # task. It also scores the model along the way.
        # Make sure to populate accuracies and update parameters.
        # Use F.cross_entropy to compute classification losses.
        # Use util.score to compute accuracies.

        # here we are doing \phi_i = \theta - inner_lr * grad(\theta, L, D_{tr})
        for _ in range(self._num_inner_steps):
            logits = self._forward(images, parameters)
            loss = F.cross_entropy(logits, labels)

            # create graph due to computing second order derivatives in MAML
            gradients = autograd.grad(
                loss, parameters.values(), create_graph=True)

            # update parameters
            for i in range(len(parameters.keys())):
                k = list(parameters.keys())[i]
                v = list(parameters.values())[i]
                assert v.shape == gradients[i].shape, 'Not proper shape'

                parameters[k] = v - self._inner_lrs[k] * gradients[i]

            acc = score(logits, labels)
            accuracies.append(acc)

        final_logits = self._forward(images, parameters)
        final_acc = score(final_logits, labels)
        accuracies.append(final_acc)
        # ********************************************************
        # ******************* YOUR CODE HERE *********************
        # ********************************************************
        return parameters, accuracies

    def _outer_step(self, task_batch, train):  # pylint: disable=unused-argument
        """Computes the MAML loss and metrics on a batch of tasks.

        Args:
            task_batch (tuple): batch of tasks from an Omniglot DataLoader
            train (bool): whether we are training or evaluating

        Returns:
            outer_loss (Tensor): mean MAML loss over the batch, scalar
            accuracies_support (ndarray): support set accuracy over the
                course of the inner loop, averaged over the task batch
                shape (num_inner_steps + 1,)
            accuracy_query (float): query set accuracy of the adapted
                parameters, averaged over the task batch
        """
        outer_loss_batch = []
        accuracies_support_batch = []
        accuracy_query_batch = []
        for task in task_batch:
            images_support, labels_support, images_query, labels_query = task
            images_support = images_support.to(DEVICE)
            labels_support = labels_support.to(DEVICE)
            images_query = images_query.to(DEVICE)
            labels_query = labels_query.to(DEVICE)
            # ********************************************************
            # ******************* YOUR CODE HERE *********************
            # ********************************************************
            # TODO: finish implementing this method.
            # For a given task, use the _inner_loop method to adapt, then
            # compute the MAML loss and other metrics.
            # Use F.cross_entropy to compute classification losses.
            # Use util.score to compute accuracies.
            # Make sure to populate outer_loss_batch, accuracies_support_batch,
            # and accuracy_query_batch.

            # computes \phi_L
            parameters, supp_accs = self._inner_loop(
                images_support, labels_support, train)

            # gets the loss w.r.t. \phi_L
            logits = self._forward(images_query, parameters)
            loss = F.cross_entropy(logits, labels_query)
            outer_loss_batch.append(loss)

            accuracies_support_batch.append(supp_accs)
            q_accs = score(logits, labels_query)
            accuracy_query_batch.append(q_accs)

            # ********************************************************
            # ******************* YOUR CODE HERE *********************
            # ********************************************************
        outer_loss = torch.mean(torch.stack(outer_loss_batch))
        accuracies_support = np.mean(
            accuracies_support_batch,
            axis=0
        )

        accuracy_query = np.mean(accuracy_query_batch)
        return outer_loss, accuracies_support, accuracy_query

    def train(self, dataloader_train, dataloader_val, writer):
        """Train the MAML.

        Consumes dataloader_train to optimize MAML meta-parameters
        while periodically validating on dataloader_val, logging metrics, and
        saving checkpoints.

        Args:
            dataloader_train (DataLoader): loader for train tasks
            dataloader_val (DataLoader): loader for validation tasks
            writer (SummaryWriter): TensorBoard logger
        """
        print(f'Starting training at iteration {self._start_train_step}.')

        for i_step, task_batch in enumerate(
                dataloader_train,
                start=self._start_train_step
        ):

            self._optimizer.zero_grad()
            outer_loss, accuracies_support, accuracy_query = (
                self._outer_step(task_batch, train=True)
            )

            outer_loss.backward()
            self._optimizer.step()

            if i_step % LOG_INTERVAL == 0:
                print(
                    f'Iteration {i_step}: '
                    f'loss: {outer_loss.item():.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{accuracies_support[0]:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{accuracies_support[-1]:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{accuracy_query:.3f}'
                )
                writer.add_scalar('loss/train', outer_loss.item(), i_step)
                writer.add_scalar(
                    'train_accuracy/pre_adapt_support',
                    accuracies_support[0],
                    i_step
                )
                writer.add_scalar(
                    'train_accuracy/post_adapt_support',
                    accuracies_support[-1],
                    i_step
                )
                writer.add_scalar(
                    'train_accuracy/post_adapt_query',
                    accuracy_query,
                    i_step
                )

            if i_step % VAL_INTERVAL == 0:
                losses = []
                accuracies_pre_adapt_support = []
                accuracies_post_adapt_support = []
                accuracies_post_adapt_query = []
                for val_task_batch in dataloader_val:
                    outer_loss, accuracies_support, accuracy_query = (
                        self._outer_step(val_task_batch, train=False)
                    )
                    losses.append(outer_loss.item())
                    accuracies_pre_adapt_support.append(accuracies_support[0])
                    accuracies_post_adapt_support.append(
                        accuracies_support[-1])
                    
                    accuracies_post_adapt_query.append(accuracy_query)
                loss = np.mean(losses)
                
                accuracy_pre_adapt_support = np.mean(
                    accuracies_pre_adapt_support
                )
                accuracy_post_adapt_support = np.mean(
                    accuracies_post_adapt_support
                )
                accuracy_post_adapt_query = np.mean(
                    accuracies_post_adapt_query
                )

                print(
                    f'Validation: '
                    f'loss: {loss:.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{accuracy_pre_adapt_support:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{accuracy_post_adapt_support:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{accuracy_post_adapt_query:.3f}'
                )
                
                writer.add_scalar('loss/val', loss, i_step)
                writer.add_scalar(
                    'val_accuracy/pre_adapt_support',
                    accuracy_pre_adapt_support,
                    i_step
                )
                writer.add_scalar(
                    'val_accuracy/post_adapt_support',
                    accuracy_post_adapt_support,
                    i_step
                )
                writer.add_scalar(
                    'val_accuracy/post_adapt_query',
                    accuracy_post_adapt_query,
                    i_step
                )

            if i_step % SAVE_INTERVAL == 0:
                self._save(i_step)

    def test(self, dataloader_test):
        """Evaluate the MAML on test tasks.

        Args:
            dataloader_test (DataLoader): loader for test tasks
        """
        accuracies = []
        for task_batch in dataloader_test:
            _, _, accuracy_query = self._outer_step(task_batch, train=False)
            accuracies.append(accuracy_query)
        mean = np.mean(accuracies)
        std = np.std(accuracies)
        mean_95_confidence_interval = 1.96 * std / np.sqrt(NUM_TEST_TASKS)
        print(
            f'Accuracy over {NUM_TEST_TASKS} test tasks: '
            f'mean {mean:.3f}, '
            f'95% confidence interval {mean_95_confidence_interval:.3f}'
        )

    def load(self, checkpoint_step):
        """Loads a checkpoint.

        Args:
            checkpoint_step (int): iteration of checkpoint to load

        Raises:
            ValueError: if checkpoint for checkpoint_step is not found
        """
        target_path = (
            f'{os.path.join(self._log_dir, "state")}'
            f'{checkpoint_step}.pt'
        )
        if os.path.isfile(target_path):
            state = torch.load(target_path)
            self._meta_parameters = state['meta_parameters']
            self._inner_lrs = state['inner_lrs']
            self._optimizer.load_state_dict(state['optimizer_state_dict'])
            self._start_train_step = checkpoint_step + 1
            print(f'Loaded checkpoint iteration {checkpoint_step}.')
        else:
            raise ValueError(
                f'No checkpoint for iteration {checkpoint_step} found.'
            )

    def _save(self, checkpoint_step):
        """Saves parameters and optimizer state_dict as a checkpoint.

        Args:
            checkpoint_step (int): iteration to label checkpoint with
        """
        optimizer_state_dict = self._optimizer.state_dict()
        torch.save(
            dict(meta_parameters=self._meta_parameters,
                 inner_lrs=self._inner_lrs,
                 optimizer_state_dict=optimizer_state_dict),
            f'{os.path.join(self._log_dir, "state")}{checkpoint_step}.pt'
        )
        print('Saved checkpoint.')


def main(args):
    log_dir = args.log_dir
    if log_dir is None:
        log_dir = f'./logs/maml/omniglot.way:{args.num_way}.support:{args.num_support}.query:{args.num_query}.inner_steps:{args.num_inner_steps}.inner_lr:{args.inner_lr}.learn_inner_lrs:{args.learn_inner_lrs}.outer_lr:{args.outer_lr}.batch_size:{args.batch_size}'  # pylint: disable=line-too-long
    print(f'log_dir: {log_dir}')
    writer = tensorboard.SummaryWriter(log_dir=log_dir)

    maml = MAML(
        args.num_way,
        args.num_inner_steps,
        args.inner_lr,
        args.learn_inner_lrs,
        args.outer_lr,
        log_dir
    )

    if args.checkpoint_step > -1:
        maml.load(args.checkpoint_step)
    else:
        print('Checkpoint loading skipped.')

    if not args.test:
        num_training_tasks = args.batch_size * (args.num_train_iterations -
                                                args.checkpoint_step - 1)
        print(
            f'Training on {num_training_tasks} tasks with composition: '
            f'num_way={args.num_way}, '
            f'num_support={args.num_support}, '
            f'num_query={args.num_query}'
        )
        dataloader_train = get_omniglot_dataloader(
            'train',
            args.batch_size,
            args.num_way,
            args.num_support,
            args.num_query,
            num_training_tasks
        )

        dataloader_val = get_omniglot_dataloader(
            'val',
            args.batch_size,
            args.num_way,
            args.num_support,
            args.num_query,
            args.batch_size * 4
        )

        maml.train(
            dataloader_train,
            dataloader_val,
            writer
        )

    else:
        print(
            f'Testing on tasks with composition '
            f'num_way={args.num_way}, '
            f'num_support={args.num_support}, '
            f'num_query={args.num_query}'
        )
        dataloader_test = get_omniglot_dataloader(
            'test',
            1,
            args.num_way,
            args.num_support,
            args.num_query,
            NUM_TEST_TASKS
        )
        maml.test(dataloader_test)

In [19]:
class Args:
  def __init__(self):
    self.log_dir = None
    self.num_way = 5
    self.num_support = 1
    self.num_query = 15
    self.num_inner_steps = 10
    self.inner_lr = 0.4
    self.learn_inner_lrs = False
    self.outer_lr = 0.1
    self.batch_size = 16
    self.num_train_iterations = 15_000
    self.test = False
    self.checkpoint_step = -1

In [20]:
args = Args()

In [21]:
main(args)

log_dir: ./logs/maml/omniglot.way:5.support:1.query:15.inner_steps:10.inner_lr:0.4.learn_inner_lrs:False.outer_lr:0.1.batch_size:16
Checkpoint loading skipped.
Training on 240000 tasks with composition: num_way=5, num_support=1, num_query=15
Starting training at iteration 0.


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Iteration 0: loss: 1.608, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.209


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.192
Saved checkpoint.
Iteration 10: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.275, post-adaptation query accuracy: 0.243
Iteration 20: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.125, post-adaptation query accuracy: 0.153
Iteration 30: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.232
Iteration 40: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.184
Iteration 50: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.223


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.259, post-adaptation query accuracy: 0.222
Iteration 60: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.207
Iteration 70: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.178
Iteration 80: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.247
Iteration 90: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.243
Iteration 100: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.199


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.181, post-adaptation query accuracy: 0.179
Saved checkpoint.
Iteration 110: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.177
Iteration 120: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.125, post-adaptation query accuracy: 0.178
Iteration 130: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.226
Iteration 140: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.193
Iteration 150: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.194


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.134, post-adaptation query accuracy: 0.163
Iteration 160: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.183
Iteration 170: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.198
Iteration 180: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.253
Iteration 190: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.275, post-adaptation query accuracy: 0.256
Iteration 200: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.198


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.191, post-adaptation query accuracy: 0.186
Saved checkpoint.
Iteration 210: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.199
Iteration 220: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.173
Iteration 230: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.193
Iteration 240: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.150, post-adaptation query accuracy: 0.182
Iteration 250: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.250, post-adaptation query accuracy: 0.233


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.209, post-adaptation query accuracy: 0.207
Iteration 260: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.203
Iteration 270: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.203
Iteration 280: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.250, post-adaptation query accuracy: 0.213
Iteration 290: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.202
Iteration 300: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.150, post-adaptation query accuracy: 0.205


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.278, post-adaptation query accuracy: 0.240
Saved checkpoint.
Iteration 310: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.168
Iteration 320: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.141
Iteration 330: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.173
Iteration 340: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.218
Iteration 350: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.141


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.178, post-adaptation query accuracy: 0.198
Iteration 360: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.206
Iteration 370: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.125, post-adaptation query accuracy: 0.197
Iteration 380: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.197
Iteration 390: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.150, post-adaptation query accuracy: 0.146
Iteration 400: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.206


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.202
Saved checkpoint.
Iteration 410: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.250, post-adaptation query accuracy: 0.235
Iteration 420: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.233
Iteration 430: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.191
Iteration 440: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.207
Iteration 450: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.205


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.206, post-adaptation query accuracy: 0.204
Iteration 460: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.182
Iteration 470: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.253
Iteration 480: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.184
Iteration 490: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.168
Iteration 500: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.179


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.203, post-adaptation query accuracy: 0.185
Saved checkpoint.
Iteration 510: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.177
Iteration 520: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.163
Iteration 530: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.196
Iteration 540: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.173
Iteration 550: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.113, post-adaptation query accuracy: 0.180


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.187
Iteration 560: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.216
Iteration 570: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.288, post-adaptation query accuracy: 0.256
Iteration 580: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.223
Iteration 590: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.202
Iteration 600: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.125, post-adaptation query accuracy: 0.154


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.203
Saved checkpoint.
Iteration 610: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.175
Iteration 620: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.125, post-adaptation query accuracy: 0.157
Iteration 630: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.210
Iteration 640: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.189
Iteration 650: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.176


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.206, post-adaptation query accuracy: 0.209
Iteration 660: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.275, post-adaptation query accuracy: 0.224
Iteration 670: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.263, post-adaptation query accuracy: 0.223
Iteration 680: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.220
Iteration 690: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.181
Iteration 700: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.144


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.231, post-adaptation query accuracy: 0.210
Saved checkpoint.
Iteration 710: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.162
Iteration 720: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.166
Iteration 730: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.163
Iteration 740: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.183
Iteration 750: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.195


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.219, post-adaptation query accuracy: 0.216
Iteration 760: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.163, post-adaptation query accuracy: 0.195
Iteration 770: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.234
Iteration 780: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.204
Iteration 790: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.190
Iteration 800: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.219


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.216, post-adaptation query accuracy: 0.203
Saved checkpoint.
Iteration 810: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.224
Iteration 820: loss: 1.606, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.218
Iteration 830: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.182
Iteration 840: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.225
Iteration 850: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.251


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.209, post-adaptation query accuracy: 0.207
Iteration 860: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.150, post-adaptation query accuracy: 0.187
Iteration 870: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.237
Iteration 880: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.241
Iteration 890: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.170
Iteration 900: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.198


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.166, post-adaptation query accuracy: 0.188
Saved checkpoint.
Iteration 910: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.169
Iteration 920: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.243
Iteration 930: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.215
Iteration 940: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.252
Iteration 950: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.220


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.178, post-adaptation query accuracy: 0.206
Iteration 960: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.188, post-adaptation query accuracy: 0.191
Iteration 970: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.150, post-adaptation query accuracy: 0.173
Iteration 980: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.233
Iteration 990: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.199
Iteration 1000: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.242


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.200, post-adaptation query accuracy: 0.190
Saved checkpoint.
Iteration 1010: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.150, post-adaptation query accuracy: 0.162
Iteration 1020: loss: 1.610, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.138, post-adaptation query accuracy: 0.161
Iteration 1030: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.213, post-adaptation query accuracy: 0.232
Iteration 1040: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.238, post-adaptation query accuracy: 0.213
Iteration 1050: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.175, post-adaptation query accuracy: 0.213


  x = imageio.imread(file_path)
  x = imageio.imread(file_path)


Validation: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.191, post-adaptation query accuracy: 0.188
Iteration 1060: loss: 1.609, pre-adaptation support accuracy: 0.200, post-adaptation support accuracy: 0.225, post-adaptation query accuracy: 0.209


KeyboardInterrupt: 

In [24]:
%load_ext tensorboard


In [25]:
!tensorboard --logdir="logs/"


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.13.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [20]:
"""Implementation of model-agnostic meta-learning for Omniglot."""

import argparse
import os

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import autograd  # pylint: disable=unused-import
from torch.utils import tensorboard


NUM_INPUT_CHANNELS = 1
NUM_HIDDEN_CHANNELS = 64
KERNEL_SIZE = 3
NUM_CONV_LAYERS = 4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SUMMARY_INTERVAL = 10
SAVE_INTERVAL = 100
LOG_INTERVAL = 10
VAL_INTERVAL = LOG_INTERVAL * 5
NUM_TEST_TASKS = 600


class MAML(tf.keras.Model):

  def __init__(self, dim_input=1, dim_output=1,
               num_inner_updates=1,
               inner_update_lr=0.4, num_filters=32, k_shot=5, learn_inner_update_lr=False):
    super(MAML, self).__init__()
    self.dim_input = dim_input
    self.dim_output = dim_output
    self.inner_update_lr = inner_update_lr
    self.loss_func = cross_entropy_loss
    self.dim_hidden = num_filters
    self.channels = 1
    self.img_size = int(np.sqrt(self.dim_input/self.channels))
    self._num_inner_steps = 10

    # outputs_ts[i] and losses_ts_post[i] are the output and loss after i+1 inner gradient updates
    losses_tr_pre, outputs_tr, losses_ts_post, outputs_ts = [], [], [], []
    accuracies_tr_pre, accuracies_ts = [], []

    # for each loop in the inner training loop
    outputs_ts = [[]]*num_inner_updates
    losses_ts_post = [[]]*num_inner_updates
    accuracies_ts = [[]]*num_inner_updates

    # Define the weights - these should NOT be directly modified by the
    # inner training loop
    tf.random.set_seed(100)
    self.conv_layers = ConvLayers(self.channels, self.dim_hidden, self.dim_output, self.img_size)

    self.learn_inner_update_lr = learn_inner_update_lr

  def call(self, images, parameters):

    outputs = self.conv_layers(images, parameters)
    return outputs

  def _inner_loop(self, images, labels, train):   # pylint: disable=unused-argument

    accuracies = []

        # ********************************************************
        # ******************* YOUR CODE HERE *********************
        # ********************************************************
        # TODO: finish implementing this method.
        # This method computes the inner loop (adaptation) procedure for one
        # task. It also scores the model along the way.
        # Make sure to populate accuracies and update parameters.
        # Use F.cross_entropy to compute classification losses.
        # Use util.score to compute accuracies.

    weights = self.conv_layers.conv_weights

    # here we are doing \phi_i = \theta - inner_lr * grad(\theta, L, D_{tr})
    for _ in range(self._num_inner_steps):

      with tf.GradientTape(persistent=True) as tape:

        logits = self.call(images, weights)
        print(logits)
        loss = self.loss_func(pred = logits, label=labels)


      grads = tape.gradient(loss, list(weights.values()))
      gradients = dict(zip(weights.keys(), grads))

      if self.learn_inner_update_lr:
        weights = dict(zip(weights.keys(), [weights[key]-self.inner_update_lr_dict[key][0]*gradients[key] for key in weights.keys()]))
      else:
        weights = dict(zip(weights.keys(), [weights[key] - self.inner_update_lr*gradients[key] for key in weights.keys()]))

      acc = accuracy(labels, logits)
      accuracies.append(acc)

    final_logits = self.call(images, weights)
    final_acc = accuracy(final_logits, labels)
    accuracies.append(final_acc)

    return weights, accuracy



In [21]:
maml_model = MAML()




In [22]:
maml_model._inner_loop(tf.zeros((10, 100, 100, 3)), tf.zeros((10)), False)

tf.Tensor(

[[nan]

 [nan]

 [nan]

 ...

 [nan]

 [nan]

 [nan]], shape=(300000, 1), dtype=float32)


InvalidArgumentError: ignored