In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

class MAMLLearner(nn.Module):
      def __init__(self):
        super(MAMLLearner, self).__init__()
        self.layers = nn.ModuleList()

        self.layers+=[nn.Conv2d(3, 8,  kernel_size=3) , 
                      nn.BatchNorm2d(8) ,
                      nn.ReLU(inplace=True)]
        self.layers+=[nn.Conv2d(8, 8,  kernel_size=3, stride=2), 
                      nn.BatchNorm2d(8) ,
                      nn.ReLU(inplace=True)]
        self.layers+=[nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1), 
                      nn.BatchNorm2d(16), 
                      nn.ReLU(inplace=True)]
        self.layers+=[nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), 
                      nn.BatchNorm2d(16), 
                      nn.ReLU(inplace=True)]
        self.layers+=[nn.Conv2d(16, 32,  kernel_size=3), 
                      nn.BatchNorm2d(32) ,
                      nn.ReLU(inplace=True)]
        self.layers+=[nn.Conv2d(32, 32,  kernel_size=3, stride=2), 
                      nn.BatchNorm2d(32) ,
                      nn.ReLU(inplace=True)]

        self.fc_layers = nn.Sequential(
            nn.Linear(32*5*5, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10)
        )

        self.weight_decay = 0.01

      def forward(self, x, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        for layer in self.layers:
            x = layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return F.log_softmax(x, dim=1)


      def weight_regularization_loss(self):
        l2_reg = torch.tensor(0., requires_grad=True)
        for name, param in self.named_parameters():
            if 'bias' not in name:
                l2_reg = l2_reg + torch.norm(param, p=2)
        return self.weight_decay * l2_reg
        
      def meta_named_pars(self):
          for name, param in self.named_parameters():
              if param.requires_grad:
                  yield name, param
  
      def meta_params(self):
          return [p for _, p in self.meta_named_pars()]
        
  

In [2]:
# Train without First-order approximation
def maml_train(model, device, train_tasks, test_tasks, optimizer, epoch, num_inner_updates=50, display=True):
    model.train()
    outer_loss = 0

    for task_idx, (train_loader, test_loader) in enumerate(zip(train_tasks, test_tasks)):
        # Inner loop
        inner_optimizer = optim.Adam(model.parameters(), lr=0.0005)
        for i in range(num_inner_updates):
            data, target = next(iter(train_loader))
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = F.cross_entropy(output, target)
            model.zero_grad()
            grads = torch.autograd.grad(loss, model.meta_params(), create_graph= True)
            params = OrderedDict()
            for (name, param), grad in zip(model.meta_named_pars(),grads):
              params[name] = param - 0.01 * grad.detach()
              params[name].retain_grad()
            # perform inner update
            inner_optimizer.zero_grad()
            loss.backward()
            for name, param in model.named_parameters():
                if name in params:
                    param.grad = params[name].grad
            inner_optimizer.step()

        # Outer loop
        optimizer.zero_grad()
        test_data, test_target = next(iter(test_loader))
        test_data, test_target = test_data.to(device), test_target.to(device)
        test_output = model(test_data)
        outer_loss += F.cross_entropy(test_output, test_target)
    outer_loss.backward()
    optimizer.step()

    if display:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(
            epoch, outer_loss.item()))

In [None]:
# Train with fFirst-order approximationFirst-order approximation
def maml_train_FOA(model, device, train_tasks, test_tasks, optimizer, epoch, num_inner_updates=50, display=True):
    model.train()
    outer_loss = 0

    for task_idx, (train_loader, test_loader) in enumerate(zip(train_tasks, test_tasks)):
        # Inner loop
        inner_optimizer = optim.Adam(model.parameters(), lr=0.0001)
        for i in range(num_inner_updates):
            data, target = next(iter(train_loader))
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = F.cross_entropy(output, target)
            model.zero_grad()
            loss.backward()
            inner_optimizer.step()

        # Outer loop
        test_data, test_target = next(iter(test_loader))
        test_data, test_target = test_data.to(device), test_target.to(device)

        # Compute meta-gradients
        model.zero_grad()
        output = model(test_data)
        outer_loss = F.cross_entropy(output, test_target)
        outer_loss.backward()

        # Compute first-order approximation of meta-gradients
        meta_grads = []
        for param in model.parameters():
            meta_grads.append(param.grad.clone())
            param.grad.zero_()

        # Update model parameters using meta-gradients
        with torch.no_grad():
            for param, meta_grad in zip(model.parameters(), meta_grads):
                param -= 0.01 * meta_grad

        if display:
            print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, outer_loss.item()))

    optimizer.step()


In [15]:
def maml_test(model, device, test_tasks, num_inner_updates=5):
    model.eval()
    test_acc = []

    for task_idx, test_loader in enumerate(test_tasks):
        # Inner loop
        inner_optimizer = optim.Adam(model.parameters(), lr=0.0005)
        for i in range(num_inner_updates):
            data, target = next(iter(test_loader))
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = F.cross_entropy(output, target)
            model.zero_grad()
            grads = torch.autograd.grad(loss, model.meta_params(), create_graph=True)
            params = OrderedDict()
            for (name, param), grad in zip(model.meta_named_pars(), grads):
                params[name] = param - 0.01 * grad.detach()
                params[name].retain_grad()

            # perform inner update
            inner_optimizer.zero_grad()
            loss.backward()
            for name, param in model.named_parameters():
                if name in params:
                    param.grad = params[name].grad
            inner_optimizer.step()

        # Evaluate on test data
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                _, predicted = torch.max(output.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        test_acc.append(correct / total)

    avg_acc = sum(test_acc) / len(test_acc)
    print('Test Accuracy: {:.2f}%'.format(avg_acc * 100))
    return avg_acc


In [17]:
from numpy.random import RandomState
import numpy as np

import torch.optim as optim
from torch.utils.data import Subset
from torchvision.datasets import MNIST, SVHN, STL10, Omniglot, Caltech101, Flowers102
from collections import OrderedDict
from torchvision import datasets, transforms
import torchvision.models as models

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))

transform_val = transforms.Compose([transforms.ToTensor(), normalize]) #careful to keep this one same
transform_train = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()])# , normalize

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

svhn_train = SVHN(root="./data", split="train", transform=transform_train, download=True)
svhn_loader = torch.utils.data.DataLoader(svhn_train, batch_size=128, shuffle=True)

svhn_test = SVHN(root="./data", split="test", transform=transform_train, download=True)
svhn_test_loader = torch.utils.data.DataLoader(svhn_test, batch_size=128, shuffle=True)

stl10_train = STL10(root="./data", split="train", transform=transform_train, download=True)
stl10_loader = torch.utils.data.DataLoader(stl10_train, batch_size=128, shuffle=True)

stl10_test = STL10(root="./data", split="test", transform=transform_train, download=True)
stl10_test_loader = torch.utils.data.DataLoader(stl10_test, batch_size=128, shuffle=True)


##### Cifar Data
cifar_data = datasets.CIFAR10(root='.',train=True, transform=transform_train, download=True)
cifar_data_test = datasets.CIFAR10(root='.',train=True, transform=transform_train, download=True)
    
#We need two copies of this due to weird dataset api 
cifar_data_val = datasets.CIFAR10(root='.',train=True, transform=transform_val, download=True)
    

accs = []

for seed in range(1, 5):
  prng = RandomState(seed)
  random_permute = prng.permutation(np.arange(0, 1000))
  classes =  prng.permutation(np.arange(0,10))
  indx_train = np.concatenate([np.where(np.array(cifar_data.targets) == classe)[0][random_permute[0:25]] for classe in classes[0:2]])
  indx_val = np.concatenate([np.where(np.array(cifar_data.targets) == classe)[0][random_permute[25:225]] for classe in classes[0:2]])


  train_data = Subset(cifar_data, indx_train)
  val_data = Subset(cifar_data_val, indx_val)

  print('Num Samples For Training %d Num Samples For Val %d'%(train_data.indices.shape[0],val_data.indices.shape[0]))
  
  train_loader = torch.utils.data.DataLoader(train_data,
                                             batch_size=128, 
                                             shuffle=True)

  val_loader = torch.utils.data.DataLoader(val_data,
                                           batch_size=128, 
                                           shuffle=False)

  train_tasks = [stl10_loader, svhn_loader]
  test_tasks = [stl10_test_loader, svhn_test_loader]
  val_tasks = [val_loader]
  # train_tasks = [train_loader]
  # test_tasks = [test_loader]


  model = MAMLLearner()
  model.to(device)

  optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

  for epoch in range(200): 
    maml_train(model, device, train_tasks, test_tasks, optimizer, epoch)

  accs.append(maml_test(model, device, val_tasks))

accs = np.array(accs)
print('Acc over 5 instances: %.2f +- %.2f'%(accs.mean(),accs.std()))
 

Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Num Samples For Training 50 Num Samples For Val 400
Train Epoch: 0 	Loss: 4.591238
Train Epoch: 1 	Loss: 4.577613


KeyboardInterrupt: ignored

In [16]:
accs.append(maml_test(model, device, val_tasks))

Test Accuracy: 5.00%
