In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transform
import numpy as np
import copy

In [None]:
class Omniglot:
    """Episode sampler for N-way K-shot few-shot learning on Omniglot.

    Both torchvision splits (background=True for training, background=False
    for testing) are resized to 28x28 single-channel tensors. Episodes are
    built by slicing the flat dataset index, which assumes each character
    class occupies IMGS_PER_CLASS consecutive indices.
    """

    # Omniglot provides 20 drawings per character class.
    IMGS_PER_CLASS = 20

    def __init__(self, bsize=32, N=5, K=5, Q=15):
        """
        params:
            bsize = batch size (episodes returned per get_task call)
            N = N-way
            K = K-shot (support images per class)
            Q = Num of query images per class
        """
        # Support and query images are taken from the same class's drawings
        # without overlap, so together they cannot exceed the per-class count.
        assert K + Q <= self.IMGS_PER_CLASS, "num of K + num of Q should be at most 20"
        self.bsize = bsize
        self.N = N
        self.K = K
        self.Q = Q

        # One shared, stateless preprocessing pipeline for both splits.
        preprocess = transform.Compose([
            transform.Resize([28, 28], interpolation=transform.InterpolationMode.BILINEAR),
            transform.ToTensor(),
        ])
        self.dtrain = torchvision.datasets.Omniglot(
            root="./omniglot", background=True, download=True, transform=preprocess
        )
        self.dtest = torchvision.datasets.Omniglot(
            root="./omniglot", background=False, download=True, transform=preprocess
        )

        self.data_num = len(self.dtrain) + len(self.dtest)
        self.cls_num = self.data_num // self.IMGS_PER_CLASS

        print('Num of total cls :', self.cls_num)
        print('Num of total data :', self.data_num)

    @staticmethod
    def _stack_images(dset, start, count):
        """Return the images of dset[start:start+count] as a (count, 28, 28) tensor."""
        if count == 0:
            # torch.stack rejects an empty list; return an empty, correctly shaped tensor.
            return torch.zeros([0, 28, 28])
        # Each dataset item is an (image, label) pair; only the image is kept.
        images = [dset[k][0] for k in range(start, start + count)]
        return torch.stack(images).reshape(count, 28, 28)

    def get_task(self, mode='train'):
        """Sample a batch of episodes.

        params:
            mode : 'train' samples from the training split; any other value
                   samples from the test split.
        returns:
            spt_xs : (bsize, N*K, 28, 28) support images
            spt_ys : (bsize, N*K) int64 episode-local labels in [0, N)
            qry_xs : (bsize, N*Q, 28, 28) query images
            qry_ys : (bsize, N*Q) int64 episode-local labels in [0, N)
        The support and query sets of each episode are shuffled independently.
        """
        dset = self.dtrain if mode == 'train' else self.dtest
        cls_num = len(dset) // self.IMGS_PER_CLASS
        n_spt = self.N * self.K
        n_qry = self.N * self.Q

        spt_xs = torch.zeros([self.bsize, n_spt, 28, 28])
        spt_ys = torch.zeros([self.bsize, n_spt], dtype=torch.int64)
        qry_xs = torch.zeros([self.bsize, n_qry, 28, 28])
        qry_ys = torch.zeros([self.bsize, n_qry], dtype=torch.int64)

        for b in range(self.bsize):
            classes = np.random.choice(cls_num, self.N, replace=False)

            spt_x = torch.zeros([self.N, self.K, 28, 28])
            spt_y = torch.zeros([self.N, self.K], dtype=torch.int64)
            qry_x = torch.zeros([self.N, self.Q, 28, 28])
            qry_y = torch.zeros([self.N, self.Q], dtype=torch.int64)

            for label, cls in enumerate(classes):
                # First K drawings of the class form the support set, the next Q the query set.
                base = int(cls) * self.IMGS_PER_CLASS
                spt_x[label] = self._stack_images(dset, base, self.K)
                spt_y[label] = label
                qry_x[label] = self._stack_images(dset, base + self.K, self.Q)
                qry_y[label] = label

            perm = torch.randperm(n_spt)
            spt_xs[b] = spt_x.reshape(n_spt, 28, 28)[perm]
            spt_ys[b] = spt_y.reshape(n_spt)[perm]
            perm = torch.randperm(n_qry)
            qry_xs[b] = qry_x.reshape(n_qry, 28, 28)[perm]
            qry_ys[b] = qry_y.reshape(n_qry)[perm]

        return spt_xs, spt_ys, qry_xs, qry_ys



In [None]:
class Learner(nn.Module):
    """Conv-4 N-way classifier used as the fast-adapting learner.

    Four Conv(3x3, stride 2, padding 1) -> BatchNorm -> ReLU blocks with 64
    channels shrink a 1x28x28 image to 64x2x2 (28 -> 14 -> 7 -> 4 -> 2), i.e.
    256 features, followed by a linear N-way head with log-softmax output.

    params:
        N : Num of class to classify
    """

    def __init__(self, N):
        super().__init__()
        self.N = N

        in_channel = 1
        out_channel = 64

        layers = []
        # 4 layers with BN and ReLU for feature extractor
        for _ in range(4):
            layers += [nn.Conv2d(in_channel, out_channel, 3, 2, 1), nn.BatchNorm2d(out_channel), nn.ReLU()]
            in_channel = out_channel
        self.layers = nn.Sequential(*layers)

        # N-way classifier emitting log-probabilities. Passing these to
        # nn.CrossEntropyLoss (which applies log_softmax again) is redundant but
        # harmless: log_softmax is shift-invariant, so applying it twice is a no-op.
        self.classifier = nn.Sequential(
            nn.Linear(256, N),
            nn.LogSoftmax(1)
        )

    def forward(self, x):
        """
        params:
            x : input images; any shape whose elements reshape to (-1, 1, 28, 28)
        returns:
            (batch, N) log-probabilities
        """
        out = x.view(-1, 1, 28, 28)
        out = self.layers(out)
        out = out.view(len(out), -1)
        out = self.classifier(out)
        return out

    def pred(self, x, y):
        """
        params:
            x : input image to predict
            y : label for x
        returns:
            (pred, accs): predicted class indices and the fraction matching y
            as a 0-dim float tensor.
        """
        # Prediction is never differentiated, so skip building an autograd graph.
        with torch.no_grad():
            log_probs = self.forward(x)
        _, pred = log_probs.max(1)
        accs = (pred == y).sum() / float(len(y))
        return pred, accs

In [None]:
class MetaLearner(nn.Module):
    """Holds the meta-initialisation trained by Reptile.

    Stores the episode configuration and a single ``Learner(N)`` as
    ``self.net``; ``forward`` delegates to it.
    """

    def __init__(self, bsize=32, N=5, K=5, Q=15):
        """
        params:
            bsize : batchsize
            N, K : N-way k-shot
            Q : Num of query data(for each class)
        """
        super().__init__()
        # Record exactly what the caller asked for; nothing is hard-coded.
        self.bsize = bsize
        self.N = N
        self.K = K
        self.Q = Q
        self.net = Learner(N)

    def forward(self, x):
        """Run the wrapped learner on x."""
        return self.net(x)

In [None]:
def _adapt(learner, x, y, loss_fn, lr, steps):
    """Fine-tune ``learner`` in place with ``steps`` plain-SGD steps on one task.

    params:
        learner : module to update (typically a deep copy of the meta network)
        x, y : task inputs and integer labels
        loss_fn : loss applied to learner(x) and y
        lr : SGD learning rate
        steps : number of SGD steps
    returns:
        list of detached per-step loss tensors (no autograd graph retained)
    """
    optimizer = torch.optim.SGD(learner.parameters(), lr=lr)
    learner.train()
    step_losses = []
    for _ in range(steps):
        loss = loss_fn(learner(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_losses.append(loss.detach())
    return step_losses


def _query_accuracy(learner, x, y):
    """Accuracy of ``learner`` on (x, y) as a float, in eval mode and without autograd.

    Eval mode keeps BatchNorm from updating its running statistics with query data.
    """
    learner.eval()
    with torch.no_grad():
        _, acc = learner.pred(x, y)
    return float(acc)


def reptile_train(meta, dset, meta_lr, lr, train_num=5, epoch=1, bsize=32):
    """Train ``meta.net`` with Reptile and report few-shot performance every 10 epochs.

    params:
        meta : Meta learner module (its ``net`` is the initialisation being learned)
        dset : omniglot dataset exposing get_task('train' / 'test')
        meta lr : initial outer step size; annealed linearly to meta_lr/epoch
        lr : learning rate for the inner SGD of each learner
        train_num : num of inner SGD steps per task
        epoch : number of outer (meta) updates
        bsize : number of tasks per outer update (must not exceed dset's batch size)

    Data is placed on the same device as ``meta``'s parameters.
    """
    CE = nn.CrossEntropyLoss()
    device = next(meta.parameters()).device

    # A fixed batch of evaluation tasks, drawn once from the test split.
    mspt_x, mspt_y, mqry_x, mqry_y = (t.to(device) for t in dset.get_task('test'))

    for i in range(epoch):
        # Linear annealing computed from the initial value every epoch
        # (re-scaling the already-decayed value would compound and vanish).
        cur_meta_lr = meta_lr * (epoch - i) / epoch

        meta.train()

        # Reptile only uses the support sets of the training tasks.
        spt_x, spt_y, _, _ = (t.to(device) for t in dset.get_task('train'))

        meta_param = meta.net.state_dict()
        delta_sum = {name: 0 for name in meta_param}

        for j in range(bsize):
            learner = copy.deepcopy(meta.net)
            _adapt(learner, spt_x[j], spt_y[j], CE, lr, train_num)

            learner_param = learner.state_dict()
            for name in delta_sum:
                delta_sum[name] = delta_sum[name] + (learner_param[name] - meta_param[name])

        # Theta = Theta + epsilon * (1/batch size) * sigma(W - Theta), Equation in paper
        meta.net.load_state_dict(
            {name: meta_param[name] + cur_meta_lr * (delta_sum[name] / bsize) for name in meta_param}
        )

        # evaluation
        if i % 10 == 0:
            pre_acc = 0.0
            after_acc = 0.0
            losses = 0.0  # inner-loop loss summed over steps, averaged over tasks below

            for j in range(bsize):
                learner = copy.deepcopy(meta.net)

                # accuracy before adaptation
                pre_acc += _query_accuracy(learner, mqry_x[j], mqry_y[j])

                # adapt on the support set; keep only plain float loss values
                step_losses = _adapt(learner, mspt_x[j], mspt_y[j], CE, lr, train_num)
                losses += sum(float(l) for l in step_losses)

                # accuracy after adaptation
                after_acc += _query_accuracy(learner, mqry_x[j], mqry_y[j])

            print("Epoch ", i, ":")
            print("loss :", losses / bsize)
            print("accuracy before training : ", pre_acc / bsize)
            print("accuracy after training : ", after_acc / bsize)
    
        

In [None]:
# Experiment configuration: N-way K-shot episodes with Q query images per class.
bsize = 32
N = 5
K = 5
Q = 5
epoch = 1000
meta_lr = 0.5   # initial outer (Reptile) step size; reptile_train decays it over epochs
lr = 0.01       # inner SGD learning rate for each task's learner
train_num = 5   # inner SGD steps per task
SEED = 0

# Seed numpy (episode class sampling) and torch (weight init, shuffling) for reproducibility.
np.random.seed(SEED)
torch.manual_seed(SEED)

omniglot = Omniglot(bsize=bsize, N=N, K=K, Q=Q)
meta = MetaLearner(bsize=bsize, N=N, K=K, Q=Q)

if torch.cuda.is_available():
    print("cuda is on")
    meta.cuda()

reptile_train(meta=meta, dset=omniglot, meta_lr=meta_lr, lr=lr, train_num=train_num, epoch=epoch, bsize=bsize)