# LSTM Meta Learner on MiniImageNet Dataset

Please download the data from the following link, extract it, and save it in the same folder as this notebook: https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR


In [1]:
#import libraries
from __future__ import division, print_function, absolute_import
import os
import re
import pdb
import copy
import glob
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as PILI
from tqdm import tqdm
from collections import OrderedDict
import random
import logging
import os
from torchvision.datasets.utils import download_file_from_google_drive, extract_archive

#### Step 1: Data Loader 

In [2]:
class EpisodeDataset(data.Dataset):
    """Dataset whose items are per-class image batches used to build episodes.

    Every entry of ``root/phase`` is treated as one class. Indexing with a
    class index draws one shuffled batch of ``n_shot + n_eval`` samples from
    that class through a dedicated ``DataLoader``.
    """

    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
        """Build one shuffling loader per class directory.

        Args:
            root (str): path to data
            phase (str): train, val or test (sub-directory of ``root``)
            n_shot (int): how many examples per class for training (k/n_support)
            n_eval (int): how many examples per class for evaluation
                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset
            transform (torchvision.transforms): data augmentation
        """
        phase_dir = os.path.join(root, phase)
        self.labels = sorted(os.listdir(phase_dir))

        samples_per_class = n_shot + n_eval
        self.episode_loader = []
        for class_idx, class_name in enumerate(self.labels):
            # Every file directly inside the class folder counts as one image.
            paths = glob.glob(os.path.join(phase_dir, class_name, '*'))
            class_set = ClassDataset(images=paths, label=class_idx, transform=transform)
            self.episode_loader.append(
                data.DataLoader(class_set, batch_size=samples_per_class,
                                shuffle=True, num_workers=0))

    def __getitem__(self, idx):
        """Return the first batch of a fresh pass over the loader of class ``idx``."""
        class_loader = self.episode_loader[idx]
        return next(iter(class_loader))

    def __len__(self):
        """Number of classes found for this phase."""
        return len(self.labels)

In [3]:
class ClassDataset(data.Dataset):
    """Dataset over the images of a single class; every item carries the same label."""

    def __init__(self, images, label, transform=None):
        """Store the image paths, their shared label and an optional transform.

        Args:
            images (list of str): each item is a path to an image of the same label
            label (int): the label of all the images
            transform (callable, optional): applied to each loaded RGB image
        """
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform if any, and pair it with the label."""
        picture = PILI.open(self.images[idx]).convert('RGB')
        if self.transform is None:
            return picture, self.label
        return self.transform(picture), self.label

    def __len__(self):
        """Number of image paths held for this class."""
        return len(self.images)

In [4]:
class EpisodicSampler(data.Sampler):
    """Batch sampler that yields, per episode, a random subset of class indices."""

    def __init__(self, total_classes, n_class, n_episode):
        """Args:
            total_classes (int): number of classes to draw from
            n_class (int): number of distinct class indices per episode
            n_episode (int): number of episodes produced by one iteration pass
        """
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __iter__(self):
        """Yield ``n_episode`` index tensors, each the head of a fresh random permutation."""
        for _ in range(self.n_episode):
            shuffled = torch.randperm(self.total_classes)
            yield shuffled[:self.n_class]

    def __len__(self):
        """Number of episodes per pass."""
        return self.n_episode

In [5]:
def prepare_data(args):
    """Create the episodic train / val / test data loaders.

    Args:
        args (dict): reads 'image_size', 'data_root', 'n_shot', 'n_eval',
            'n_class', 'episode' (train episodes) and 'episode_val'
            (val and test episodes).

    Returns:
        tuple: (train_loader, val_loader, test_loader)
    """
    image_size = args['image_size']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Deterministic pipeline for val/test: resize slightly larger, then center crop.
    eval_transform = transforms.Compose([
        transforms.Resize(image_size * 8 // 7),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    # Randomized augmentation pipeline for training.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    def build_set(phase, transform):
        # One EpisodeDataset per phase, sharing the shot/eval configuration.
        return EpisodeDataset(args['data_root'], phase, args['n_shot'], args['n_eval'],
                              transform=transform)

    def build_loader(dataset, n_episode):
        # Each batch is one episode: n_class class indices picked by the sampler.
        sampler = EpisodicSampler(len(dataset), args['n_class'], n_episode)
        return data.DataLoader(dataset, num_workers=0, pin_memory=True, batch_sampler=sampler)

    train_set = build_set('train', train_transform)
    val_set = build_set('val', eval_transform)
    test_set = build_set('test', eval_transform)

    train_loader = build_loader(train_set, args['episode'])
    val_loader = build_loader(val_set, args['episode_val'])
    test_loader = build_loader(test_set, args['episode_val'])
    return train_loader, val_loader, test_loader


#### Step 2: Learner 

In [6]:
class Learner(nn.Module):
    """Small convolutional classifier whose flat parameter vector is driven by the meta-learner.

    The network is four conv -> batch-norm -> ReLU -> 2x2 max-pool stages
    ('features') followed by one linear layer ('cls').
    """

    def __init__(self, image_size, bn_eps, bn_momentum, n_classes):
        """Args:
            image_size (int): side length of the (square) input images
            bn_eps (float): eps passed to every BatchNorm2d
            bn_momentum (float): momentum passed to every BatchNorm2d
            n_classes (int): number of output logits
        """
        super(Learner, self).__init__()
        n_filters = 32
        n_stages = 4

        stages = OrderedDict()
        in_channels = 3
        for k in range(1, n_stages + 1):
            stages['conv{}'.format(k)] = nn.Conv2d(in_channels, n_filters, 3, padding=1)
            stages['norm{}'.format(k)] = nn.BatchNorm2d(n_filters, bn_eps, bn_momentum)
            stages['relu{}'.format(k)] = nn.ReLU(inplace=False)
            stages['pool{}'.format(k)] = nn.MaxPool2d(2)
            in_channels = n_filters
        self.model = nn.ModuleDict({'features': nn.Sequential(stages)})

        # Each of the four poolings halves the spatial side.
        side = image_size // 2**4
        self.model.update({'cls': nn.Linear(n_filters * side * side, n_classes)})
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the conv stack, flatten per sample, and return the class logits."""
        feats = self.model.features(x)
        flat = feats.reshape(feats.size(0), -1)
        return self.model.cls(flat)

    def get_flat_params(self):
        """Concatenate every parameter of ``self.model`` into one 1-D tensor."""
        pieces = [param.view(-1) for param in self.model.parameters()]
        return torch.cat(pieces, 0)

    def copy_flat_params(self, cI):
        """Copy consecutive slices of the flat vector ``cI`` into the parameters, in place."""
        offset = 0
        for param in self.model.parameters():
            count = param.numel()
            param.data.copy_(cI[offset:offset + count].view_as(param))
            offset += count

    def transfer_params(self, learner_w_grad, cI):
        """Load ``learner_w_grad``'s state, then substitute weights/biases with slices of ``cI``.

        ``load_state_dict`` is used only to bring over the batch-norm running
        statistics; the parameter values themselves are overwritten below.
        The substituted entries are plain tensors cloned from ``cI`` (NOT
        nn.Parameters anymore) - presumably so that a loss computed with this
        module remains differentiable with respect to ``cI``.
        """
        self.load_state_dict(learner_w_grad.state_dict())

        parametrized = (nn.Conv2d, nn.BatchNorm2d, nn.Linear)
        offset = 0
        for module in self.model.modules():
            if not isinstance(module, parametrized):
                continue
            slots = module._parameters

            old_weight = slots['weight']
            n_weight = old_weight.numel()
            slots['weight'] = cI[offset:offset + n_weight].view_as(old_weight).clone()
            offset += n_weight

            old_bias = slots['bias']
            if old_bias is not None:
                n_bias = old_bias.numel()
                slots['bias'] = cI[offset:offset + n_bias].view_as(old_bias).clone()
                offset += n_bias

    def reset_batch_stats(self):
        """Reset the running statistics of every BatchNorm2d layer."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.reset_running_stats()

#### Step 3: Meta Learner 

In [7]:
class MetaLSTMCell(nn.Module):
    r"""Meta-LSTM cell whose cell state is the learner's flat parameter vector.

    Update rule: C_t = f_t * C_{t-1} + i_t * \tilde{C_t}, where the candidate
    \tilde{C_t} is the negative learner gradient.
    """

    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): cell input size, default = 20
            hidden_size (int): should be 1
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params

        def blank(*shape):
            # Uninitialised storage; real values are set in reset_parameters().
            return nn.Parameter(torch.Tensor(*shape))

        # Gate weights act on [x_all, c_prev, gate_prev], hence input_size + 2 rows.
        self.WF = blank(input_size + 2, hidden_size)
        self.WI = blank(input_size + 2, hidden_size)
        self.cI = blank(n_learner_params, 1)
        self.bI = blank(1, hidden_size)
        self.bF = blank(1, hidden_size)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw small uniform values for everything, then bias the two gates."""
        for param in self.parameters():
            nn.init.uniform_(param, -0.01, 0.01)

        # want initial forget value to be high and input value to be low so that
        #  model starts with gradient descent
        nn.init.uniform_(self.bF, 4, 6)
        nn.init.uniform_(self.bI, -5, -4)

    def init_cI(self, flat_params):
        """Copy a 1-D flat parameter vector into the initial cell state ``cI``."""
        as_column = flat_params.unsqueeze(1)
        self.cI.data.copy_(as_column)

    def forward(self, inputs, hx=None):
        """Compute the next flat learner parameters.

        Args:
            inputs = [x_all, grad]:
                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM
                grad (torch.Tensor of size [n_learner_params]): gradients from learner
            hx = [f_prev, i_prev, c_prev]:
                f (torch.Tensor of size [n_learner_params, 1]): forget gate
                i (torch.Tensor of size [n_learner_params, 1]): input gate
                c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters

        Returns:
            (c_next, [f_next, i_next, c_next]); the gate entries are
            pre-sigmoid activations.
        """
        x_all, grad = inputs
        n_rows, _ = x_all.size()

        if hx is None:
            # First step: zero gate history, cell state starts at the learned cI.
            zero_f = torch.zeros((n_rows, self.hidden_size)).to(self.WF.device)
            zero_i = torch.zeros((n_rows, self.hidden_size)).to(self.WI.device)
            hx = [zero_f, zero_i, self.cI]

        f_prev, i_prev, c_prev = hx

        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)
        forget_in = torch.cat((x_all, c_prev, f_prev), 1)
        f_next = torch.mm(forget_in, self.WF) + self.bF.expand_as(f_prev)

        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)
        input_in = torch.cat((x_all, c_prev, i_prev), 1)
        i_next = torch.mm(input_in, self.WI) + self.bI.expand_as(i_prev)

        # next cell/params
        kept = torch.sigmoid(f_next).mul(c_prev)
        step = torch.sigmoid(i_next).mul(grad)
        c_next = kept - step

        return c_next, [f_next, i_next, c_next]

    def extra_repr(self):
        """Summarise the three size attributes for ``repr``."""
        return '{}, {}, {}'.format(self.input_size, self.hidden_size, self.n_learner_params)


class MetaLearner(nn.Module):
    """Two-layer meta-learner: a standard LSTM cell feeding a MetaLSTMCell."""

    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): for the first LSTM layer, default = 4
            hidden_size (int): for the first LSTM layer, default = 20
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLearner, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1,
                                     n_learner_params=n_learner_params)

    def forward(self, inputs, hs=None):
        """Produce the next flat learner parameters and the updated states.

        Args:
            inputs = [loss, grad_prep, grad]
                loss (torch.Tensor of size [1, 2])
                grad_prep (torch.Tensor of size [n_learner_params, 2])
                grad (torch.Tensor of size [n_learner_params])
            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]

        Returns:
            (squeezed output of the meta cell, [(lstm_h, lstm_c), metalstm state])
        """
        loss, grad_prep, grad = inputs

        # Broadcast the loss features to every parameter row and append the
        # preprocessed gradient features.
        features = torch.cat((loss.expand_as(grad_prep), grad_prep), 1)   # [n_learner_params, 4]

        lstm_state = None if hs is None else hs[0]
        meta_state = None if hs is None else hs[1]

        lstm_h, lstm_c = self.lstm(features, lstm_state)
        new_params, next_meta_state = self.metalstm([lstm_h, grad], meta_state)

        return new_params.squeeze(), [(lstm_h, lstm_c), next_meta_state]

#### Step 4: utils 

In [8]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output (torch.Tensor): scores of size [batch, n_classes]
        target (torch.Tensor): integer class labels of size [batch]
        topk (tuple of int): the k values to evaluate

    Returns:
        float when ``topk`` holds a single k, otherwise a list of floats in
        the order of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best-scoring classes per sample, transposed to [maxk, batch].
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (not view): a slice of the transposed comparison is not
            # guaranteed to be contiguous, and view() raises in that case.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res[0].item() if len(res) == 1 else [r.item() for r in res]

In [9]:
def preprocess_grad_loss(x):
    """Map each value of ``x`` to a two-feature (log-magnitude, sign) encoding.

    With p = 10 and threshold exp(-p):
      * |x| >= threshold: (log(|x| + 1e-8) / p, sign(x))
      * otherwise:        (-1, exp(p) * x)

    Returns:
        torch.Tensor: the two features stacked along dim 1.
    """
    p = 10
    magnitude = x.abs()
    is_large = (magnitude >= np.exp(-p)).to(torch.float32)
    is_small = 1 - is_large

    # Feature 1: scaled log-magnitude for large values, constant -1 for small ones.
    log_feature = is_large * torch.log(magnitude + 1e-8) / p + is_small * -1
    # Feature 2: sign for large values, rescaled raw value for small ones.
    sign_feature = is_large * torch.sign(x) + is_small * np.exp(p) * x

    return torch.stack((log_feature, sign_feature), 1)

#### Step 5: Main

In [10]:
def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args):
    """Evaluate the meta-learner on every episode of ``eval_loader``.

    For each episode the learner is trained on the support samples by
    ``train_learner`` and then scored on the query samples.

    Args:
        eps: index of the current meta-training episode (not used here)
        eval_loader: iterable yielding (episode_x, episode_y) pairs; only
            episode_x is used, indexed as [class, sample, ...image dims]
        learner_w_grad: learner used to obtain losses and gradients
        learner_wo_grad: learner that receives the produced parameters
        metalearner: the meta-learner being evaluated
        args (dict): reads 'n_shot', 'n_eval', 'n_class' (and whatever
            ``train_learner`` reads)

    Returns:
        float: mean query accuracy over all evaluated episodes, or 0.0 when
        the loader yields no episode.
    """
    n_class, n_shot, n_eval = args['n_class'], args['n_shot'], args['n_eval']

    # Labels are positional within the episode, so they are identical for every episode.
    train_target = torch.LongTensor(np.repeat(range(n_class), n_shot))  # [n_class * n_shot]
    test_target = torch.LongTensor(np.repeat(range(n_class), n_eval))   # [n_class * n_eval]

    episode_accs = []
    for episode_x, _ in tqdm(eval_loader, ascii=True):
        image_shape = episode_x.shape[-3:]
        train_input = episode_x[:, :n_shot].reshape(-1, *image_shape)  # [n_class * n_shot, :]
        test_input = episode_x[:, n_shot:].reshape(-1, *image_shape)   # [n_class * n_eval, :]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.eval()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)
        print ("Validation/Test Loss: {}, and Accuracy {}".format(loss.item(), acc))
        episode_accs.append(acc)

    # Average over episodes: a single episode's accuracy is too noisy to
    # compare runs, and the loader may be empty.
    if not episode_accs:
        return 0.0
    return sum(episode_accs) / len(episode_accs)


In [11]:
def train_learner(learner_w_grad, metalearner, train_input, train_target, args):
    """Let the meta-learner update the learner's parameters on one episode's training set.

    Args:
        learner_w_grad: learner used to compute loss and gradients
        metalearner: maps (loss, gradient) features to new flat parameters
        train_input: training samples, sliced along dim 0 into mini-batches
        train_target: labels aligned with ``train_input``
        args (dict): reads 'epoch' and 'batch_size'

    Returns:
        The flat parameter vector produced by the last meta-learner step (or
        the meta-learner's initial ``cI`` data if no step was taken).
    """
    batch_size = args['batch_size']
    n_samples = len(train_input)

    cI = metalearner.metalstm.cI.data
    state = None  # meta-learner state from the previous step; None on the first one

    for _ in range(args['epoch']):
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            x = train_input[start:stop]
            y = train_target[start:stop]

            # get the loss/grad
            learner_w_grad.copy_flat_params(cI)
            output = learner_w_grad(x)
            loss = learner_w_grad.criterion(output, y)
            acc = accuracy(output, y)
            learner_w_grad.zero_grad()
            loss.backward()
            scaled_grads = [p.grad.data.view(-1) / batch_size for p in learner_w_grad.parameters()]
            grad = torch.cat(scaled_grads, 0)

            # preprocess grad & loss and metalearner forward
            grad_prep = preprocess_grad_loss(grad)  # [n_learner_params, 2]
            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0))  # [1, 2]
            cI, state = metalearner([loss_prep, grad_prep, grad.unsqueeze(1)], state)
            print("training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}".format(loss, acc, torch.mean(grad)))

    return cI


In [None]:
def main(args):
    """Meta-train (mode 'train') or meta-test (mode 'test') the LSTM meta-learner.

    Args:
        args (dict): configuration. Read here: 'mode', 'image_size', 'bn_eps',
            'bn_momentum', 'n_class', 'n_shot', 'n_eval', 'input_size',
            'hidden_size', 'lr', 'episode', 'grad_clip'; further keys are
            read by the helpers this function calls.
    """
    seed = 2019
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Get data
    train_loader, val_loader, test_loader = prepare_data(args)

    # Set up learner, meta-learner
    learner_w_grad = Learner(args['image_size'], args['bn_eps'], args['bn_momentum'], args['n_class'])
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args['input_size'], args['hidden_size'], learner_w_grad.get_flat_params().size(0))
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up optimizer for the meta-learner
    optim = torch.optim.Adam(metalearner.parameters(), args['lr'])

    if args['mode'] == 'test':
        meta_test(args['episode'], test_loader, learner_w_grad, learner_wo_grad, metalearner, args)
        return

    n_class, n_shot, n_eval = args['n_class'], args['n_shot'], args['n_eval']

    # Labels are the position of each class inside the episode, so there must
    # be one label per CLASS (n_class of them), each repeated once per sample.
    train_target = torch.LongTensor(np.repeat(range(n_class), n_shot))  # [n_class * n_shot]
    test_target = torch.LongTensor(np.repeat(range(n_class), n_eval))   # [n_class * n_eval]

    best_acc = 0.0
    # Meta-training
    for eps, (episode_x, _) in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        image_shape = episode_x.shape[-3:]
        train_input = episode_x[:, :n_shot].reshape(-1, *image_shape)  # [n_class * n_shot, :]
        test_input = episode_x[:, n_shot:].reshape(-1, *image_shape)   # [n_class * n_eval, :]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args['grad_clip'])
        optim.step()

        print (eps)
        # Meta-validation
        if eps % 10 == 0 and eps != 0:
            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args)
            if acc > best_acc:
                best_acc = acc
                print ("* Best accuracy so far *\n")

    print("Done")

In [None]:
if __name__ == '__main__':
    # Settings shared by the training and the testing run.
    common_args = {
        'n_shot': 5,
        'n_eval': 15,
        'n_class': 5,
        'input_size': 4,
        'hidden_size': 20,
        'lr': 1e-3,
        'episode': 50,
        'epoch': 8,
        'batch_size': 16,
        'image_size': 84,
        'grad_clip': 0.25,
        'bn_momentum': 0.95,
        'bn_eps': 1e-3,
        'data': "miniimagenet",
        'data_root': "./miniImagenet/",
        'resume': None,
    }

    # The two runs differ only in the mode and the number of evaluation episodes.
    args_train = dict(common_args, mode='train', episode_val=50)
    args_test = dict(common_args, mode='test', episode_val=10)

    print (" BEGIN TRAINING: ")
    main(args_train)

    # NOTE(review): main() builds a new learner and meta-learner on every call,
    # so this run does not reuse the weights trained by the call above.
    print ("BEGIN TESTING")
    main(args_test)
    
    

 BEGIN TRAINING: 
training loss: 1.668335 acc: 25.000, mean grad: 0.000055
training loss: 1.906144 acc: 11.111, mean grad: -0.000132
training loss: 1.607665 acc: 31.250, mean grad: 0.000069
training loss: 1.632667 acc: 11.111, mean grad: -0.000136
training loss: 1.566477 acc: 31.250, mean grad: 0.000058
training loss: 1.440862 acc: 22.222, mean grad: -0.000116
training loss: 1.531511 acc: 25.000, mean grad: 0.000073
training loss: 1.299537 acc: 44.444, mean grad: -0.000100
training loss: 1.498922 acc: 25.000, mean grad: 0.000054
training loss: 1.192788 acc: 66.667, mean grad: -0.000099
training loss: 1.468537 acc: 18.750, mean grad: 0.000059
training loss: 1.111149 acc: 66.667, mean grad: -0.000090
training loss: 1.440375 acc: 25.000, mean grad: 0.000043
training loss: 1.046557 acc: 66.667, mean grad: -0.000107
training loss: 1.413390 acc: 31.250, mean grad: 0.000042
training loss: 0.993363 acc: 77.778, mean grad: -0.000104
0
