In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
from torch.autograd import Variable as V
import glob
from PIL import Image
from torchvision import transforms

# Data generation

Let's create our tasks. Each task is an N-way, K-shot classification problem: we sample N characters and train on K shots (example images) of each.

In [4]:
class OmniTask:
    """A single N-way, K-shot classification task drawn from image folders.

    Character directories are discovered with ``glob`` under ``image_dir``.
    ``N`` of them are sampled at random and ``K`` image files are sampled from
    each one. Every image is labelled with the index (0..N-1) of its character
    within this task, so labels are only meaningful inside the task.

    Attributes:
        K: number of examples sampled per character.
        N: number of characters (classes) in the task.
        noise_percent: stored as given; no method in this class reads it yet.
        mini_train: cached list of ``(image_tensor, label)`` pairs, or None
            until ``mini_train_set`` has been called.
        mini_test: placeholder, never populated by the code in this class.
        image_dir: glob pattern matching one directory per character.
        trainTransform: torchvision pipeline applied to every loaded image.
        mini_batch_size: maximum number of samples per mini-batch.
    """

    def __init__(self, K, N, noise_percent):
        # K, N as in N-way K-shot learning
        self.K = K
        self.N = N
        self.noise_percent = noise_percent
        self.mini_train = None
        self.mini_test = None
        self.image_dir = "omniglot_dataset/images_background/*/*"
        self.trainTransform = transforms.Compose([transforms.Grayscale(num_output_channels=1), transforms.ToTensor()])
        self.mini_batch_size = 5

    def _load_image(self, path):
        """Load one image file and return it as a transformed tensor."""
        # The context manager releases the file handle right away; a task
        # generator that opens many files would otherwise leak descriptors.
        with Image.open(path) as img:
            return self.trainTransform(img.convert('RGB'))

    def mini_train_set(self):
        """Return the task's training samples, sampling them on first use.

        Returns:
            A list of ``(image_tensor, label)`` tuples in random order. The
            same list object is returned on every call after the first.

        Raises:
            ValueError: if fewer than ``N`` character directories match
                ``image_dir`` or a chosen character has fewer than ``K`` files.
        """
        if self.mini_train is None:
            character_dirs = glob.glob(self.image_dir)
            if len(character_dirs) < self.N:
                raise ValueError(
                    "Need %d character directories but found %d matching %r"
                    % (self.N, len(character_dirs), self.image_dir)
                )
            # choose N characters for this task
            characters = random.sample(character_dirs, self.N)

            train_set = []
            # choose K examples per character
            for label, char in enumerate(characters):
                shots = glob.glob(char + "/*")
                if len(shots) < self.K:
                    raise ValueError(
                        "Need %d examples but found %d in %r" % (self.K, len(shots), char)
                    )
                k_shots = random.sample(shots, self.K)
                train_set.extend((self._load_image(shot), label) for shot in k_shots)
            self.mini_train = random.sample(train_set, len(train_set))

        return self.mini_train

    def batched_mini_train_set(self):
        """Return the training samples reshuffled and grouped into mini-batches.

        Every sample appears in exactly one batch. All batches hold
        ``mini_batch_size`` samples except possibly the last, which holds the
        remainder. An empty training set yields an empty list.

        Returns:
            A list of ``(inputs, labels)`` tuples where ``inputs`` is the
            stacked image tensors and ``labels`` is a ``LongTensor``.
        """
        train_set = self.mini_train_set()
        shuffled = random.sample(train_set, len(train_set))

        batched = []
        # Slice-based chunking: the previous index-driven loop flushed the
        # final batch *before* appending the last sample, silently dropping it
        # (and called torch.stack on an empty list for a one-sample set).
        for start in range(0, len(shuffled), self.mini_batch_size):
            chunk = shuffled[start:start + self.mini_batch_size]
            inputs = torch.stack([x for x, _ in chunk])
            labels = torch.LongTensor([y for _, y in chunk])
            batched.append((inputs, labels))

        return batched

    def mini_test_set(self):
        """Not implemented yet; currently returns None."""
        pass
        #TODO: FIGURE OUT HOW MAML DOES THIS

    def eval_set(self, size=50):
        """Not implemented yet; currently returns None regardless of ``size``."""
        pass
        #TODO

In [5]:
# Build one task with K=5, N=5, noise_percent=0 and materialize its mini-batches.
trial = OmniTask(K=5, N=5, noise_percent=0)
train_set = trial.batched_mini_train_set()

In [None]:
class DataGenerator:
    """Builds and serves a pool of ``OmniTask`` objects.

    The pool is created on demand: nothing is constructed until
    ``generate_set`` or ``shuffled_set`` is called.

    Args:
        size: number of tasks to create.
        K: forwarded to every ``OmniTask`` as its ``K``.
        N: forwarded to every ``OmniTask`` as its ``N``.
        noise_percent: forwarded to every ``OmniTask`` unchanged.
    """

    def __init__(self, size=50000, K = 5, N = 5, noise_percent=0):
        self.size, self.K, self.N = size, K, N
        self.noise_percent = noise_percent
        # Filled in by generate_set(); None means "not generated yet".
        self.tasks = None

    def generate_set(self):
        """Create ``size`` new tasks, store them on ``self.tasks`` and return them."""
        fresh_tasks = []
        for _ in range(self.size):
            fresh_tasks.append(OmniTask(self.K, self.N, self.noise_percent))
        self.tasks = fresh_tasks
        return fresh_tasks

    def shuffled_set(self):
        """Return the tasks in a new random order, generating them if needed.

        The stored ``self.tasks`` list itself is not reordered.
        """
        if self.tasks is None:
            self.generate_set()
        pool = self.tasks
        return random.sample(pool, len(pool))

# Model creation

In [None]:
class OmniglotNet(nn.Module):
    """Four-stage convolutional classifier for 28x28 single-channel images.

    Each stage is Conv2d(3x3, stride 2, padding 1) -> BatchNorm2d -> ReLU with
    64 output channels, halving the spatial size: 28 -> 14 -> 7 -> 4 -> 2.
    The resulting 2 x 2 x 64 = 256 features feed a linear layer followed by
    log-softmax, so ``forward`` returns log-probabilities.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Build the four identical stages; only the first differs in its
        # number of input channels (1 for the grayscale image).
        stages = []
        for in_channels in (1, 64, 64, 64):
            stages += [
                nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

        # 2 x 2 x 64 = 256 flattened features per image
        self.classifier = nn.Sequential(
            nn.Linear(2 * 2 * 64, num_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return log-probabilities of shape (batch, num_classes).

        ``x`` is viewed as (-1, 1, 28, 28), so any input whose element count
        is a multiple of 784 is accepted.
        """
        images = x.view(-1, 1, 28, 28)
        features = self.conv(images)
        flattened = features.view(features.size(0), -1)  # (batch, 256)
        return self.classifier(flattened)

    def predict(self, prob):
        """Return the index of the largest entry along dim 1 of ``prob``."""
        return prob.max(dim=1)[1]

# Reptile Meta-Learning Algorithm

In [None]:
class MetaLearner():
    """Shared hyperparameters and inner-loop helpers for meta-learning algorithms.

    Args:
        higher_order: stored as given; not read by any method of this class.
        lr_inner: learning rate intended for the per-task (inner) optimizer.
        lr_outer: learning rate intended for the meta (outer) optimizer.
        sgd_steps_inner: number of inner-loop repetitions per task.
    """

    def __init__(self, higher_order=False, lr_inner=0.01, lr_outer=0.001, sgd_steps_inner=10):
        self.higher_order = higher_order
        self.lr_inner, self.lr_outer = lr_inner, lr_outer
        self.sgd_steps_inner = sgd_steps_inner

    def inner_train(self, model, task, optimizer):
        """Run one optimizer step per mini-batch of ``task`` on ``model``.

        ``model`` is expected to output log-probabilities, since the loss is
        ``F.nll_loss``.
        """
        for inputs, labels in task.batched_mini_train_set():
            optimizer.zero_grad()
            loss = F.nll_loss(model(inputs), labels)
            loss.backward()
            optimizer.step()

    def init_grad(self, model):
        """Give every parameter of ``model`` a zero-filled ``.grad`` tensor."""
        for weight in model.parameters():
            weight.grad = torch.zeros_like(weight)
        
class Reptile(MetaLearner):
    """Reptile meta-learner: a first-order method built on ``MetaLearner``.

    For every task, a copy of the meta-model is adapted with plain SGD. The
    scaled difference between the meta-weights and the adapted weights is then
    written into the meta-model's ``.grad`` fields and applied by an Adam
    outer optimizer.

    Args:
        lr_inner: learning rate of the per-task SGD optimizer.
        lr_outer: learning rate of the Adam meta-optimizer.
        sgd_steps_inner: number of passes over the task's mini-batches.
    """

    def __init__(self, lr_inner=0.01, lr_outer=0.001, sgd_steps_inner=10):
        super().__init__(False, lr_inner, lr_outer, sgd_steps_inner)

    def compute_store_gradients(self, target, current):
        """Store the Reptile pseudo-gradient on ``current``'s parameters.

        For each parameter the stored value is
        ``(current - target) / (sgd_steps_inner * lr_inner)``.

        Args:
            target: the task-adapted model.
            current: the meta-model whose ``.grad`` fields are overwritten.
        """
        current_weights = dict(current.named_parameters())
        target_weights = dict(target.named_parameters())
        scale = self.sgd_steps_inner * self.lr_inner

        with torch.no_grad():
            for name, adapted in target_weights.items():
                meta_param = current_weights[name]
                # Assign `.grad` itself rather than `.grad.data`: after
                # `optimizer.zero_grad()` the gradient may be None (the
                # default `set_to_none=True` in recent PyTorch), in which
                # case `.grad.data = ...` raises AttributeError.
                meta_param.grad = (meta_param.detach() - adapted.detach()) / scale

        # NOTE(review): only named parameters are transferred. Module buffers
        # (e.g. BatchNorm running statistics) of `target` are not copied into
        # `current` - confirm this is intended.

    def train(self, model, train_data):
        """Meta-train ``model`` in place over every task in ``train_data``.

        Args:
            model: the meta-model; its parameters are updated by Adam.
            train_data: object exposing ``shuffled_set()`` that returns tasks
                accepted by ``inner_train``.
        """
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr_outer)
        self.init_grad(model)

        for i, task in enumerate(train_data.shuffled_set()):
            optimizer.zero_grad()

            # Adapt a throwaway copy so the meta-weights stay untouched until
            # the outer step.
            inner_model = copy.deepcopy(model)
            inner_optim = torch.optim.SGD(inner_model.parameters(), lr=self.lr_inner)

            for _ in range(self.sgd_steps_inner):
                self.inner_train(inner_model, task, inner_optim)

            self.compute_store_gradients(inner_model, model)
            optimizer.step()

            if i % 100 == 0:
                print("iteration:", i)

In [None]:
# Meta-train a 5-class network with Reptile, using the default
# hyperparameters of both the learner and the task generator.
train_data = DataGenerator()
reptile_learning_alg = Reptile()
reptile_model = OmniglotNet(num_classes=5)
reptile_learning_alg.train(model=reptile_model, train_data=train_data)


iteration: 0
iteration: 100
iteration: 200
iteration: 300
iteration: 400
iteration: 500
iteration: 600
iteration: 700
iteration: 800
iteration: 900
iteration: 1000
iteration: 1100
iteration: 1200
iteration: 1300
iteration: 1400
iteration: 1500
iteration: 1600
iteration: 1700
iteration: 1800
iteration: 1900
iteration: 2000
iteration: 2100
iteration: 2200
iteration: 2300
iteration: 2400
iteration: 2500
iteration: 2600
