In [1]:
#@title Import the general-purpose libraries
from meta_cntk_utils import *
from maml_utils import *
from tqdm.notebook import trange
import torch.nn.functional as F
from time import time
from copy import deepcopy
import pickle
import numpy as np

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [5]:
# !pip install numpy scipy matplotlib pandas scikit-image scikit-learn higher torchmeta cupy-cuda111

In [6]:
#@title Import the libraries used by the inner-level (iMAML) code
"""
Meta-learning Omniglot and mini-imagenet experiments with iMAML-GD (see [1] for more details).

The code is quite simple and easy to read thanks to the following two libraries which need both to be installed.
- higher: https://github.com/facebookresearch/higher (used to get stateless version of torch nn.Module-s)
- torchmeta: https://github.com/tristandeleu/pytorch-meta (used for meta-dataset loading and minibatching)


[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019).
    Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124).
    https://arxiv.org/abs/1909.04630
"""
import math
import argparse
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from torchmeta.datasets.helpers import omniglot, miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader

import higher
from maml_utils import np2torch
import hypergrad as hg
import tqdm



In [7]:
class Task:
    """
    Bundles the training and validation objectives of one task.

    The training objective is the cross-entropy on the task's training data
    plus an L2 term pulling the task weights towards the meta-parameters.
    The validation objective is the cross-entropy on the task's test data.
    """

    def __init__(self, reg_param, meta_model, data, batch_size=None):
        """
        Args:
            reg_param: weight of the L2 proximity term in the training loss.
            meta_model: module whose forward pass is reused functionally.
            data: 4-tuple (train_input, train_target, test_input, test_target).
            batch_size: divisor applied to the validation loss; any falsy
                value (None, 0) is replaced by 1.
        """
        model_device = next(meta_model.parameters()).device

        # functional version of meta_model: the weights are passed on each call
        self.fmodel = higher.monkeypatch(meta_model, device=model_device, copy_initial_weights=True)

        self.n_params = len(list(meta_model.parameters()))
        (self.train_input,
         self.train_target,
         self.test_input,
         self.test_target) = data
        self.reg_param = reg_param
        self.batch_size = batch_size if batch_size else 1
        # filled in by val_loss_f
        self.val_loss = None
        self.val_acc = None

    def bias_reg_f(self, bias, params):
        """Return the sum of squared differences between paired tensors of bias and params."""
        penalty = 0
        for anchor, weight in zip(bias, params):
            penalty = penalty + ((anchor - weight) ** 2).sum()
        return penalty

    def train_loss_f(self, params, hparams):
        """Cross-entropy on the training data plus the L2 term centred on hparams."""
        logits = self.fmodel(self.train_input, params=params)
        data_term = F.cross_entropy(logits, self.train_target)
        prox_term = 0.5 * self.reg_param * self.bias_reg_f(hparams, params)
        return data_term + prox_term

    def val_loss_f(self, params, hparams):
        """
        Cross-entropy on the test data divided by batch_size.

        Only params is used; hparams is accepted but ignored. As a side effect
        the scalar loss and the accuracy are cached in self.val_loss and
        self.val_acc.
        """
        logits = self.fmodel(self.test_input, params=params)
        scaled_loss = F.cross_entropy(logits, self.test_target) / self.batch_size
        # keep a plain float so the cached value does not hold on to the graph
        self.val_loss = scaled_loss.item()

        # index of the largest logit for every sample
        predicted = logits.argmax(dim=1, keepdim=True)
        n_correct = predicted.eq(self.test_target.view_as(predicted)).sum().item()
        self.val_acc = n_correct / len(self.test_target)

        return scaled_loss

In [8]:
def inner_loop(hparams, params, optim, n_steps, log_interval, create_graph=False):
    """
    Apply a differentiable optimizer n_steps times and record every iterate.

    Args:
        hparams: meta-parameters passed unchanged to optim at every step.
            A one-shot iterator (such as the generator returned by
            ``Module.parameters()``) is converted to a list first.
        params: initial task parameters, passed to ``optim.get_opt_params``.
        optim: object exposing ``get_opt_params(params)``, callable as
            ``optim(state, hparams, create_graph=...)`` and, when logging is
            on, exposing ``curr_loss`` with an ``.item()`` method.
        n_steps: number of optimizer steps to take.
        log_interval: print the loss every log_interval steps and at the last
            step; a falsy value disables printing.
        create_graph: forwarded to optim on every step.

    Returns:
        A list of n_steps + 1 entries: the initial optimizer state followed by
        the state after each step.
    """
    from collections.abc import Iterator

    # hparams is reused at every step. A generator would be consumed by the
    # first step, leaving later steps with an empty sequence, so any loss term
    # built from hparams would silently vanish after step 0.
    if isinstance(hparams, Iterator):
        hparams = list(hparams)

    params_history = [optim.get_opt_params(params)]

    for t in range(n_steps):
        params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))

        if log_interval and (t % log_interval == 0 or t == n_steps-1):
            print('t={}, Loss: {:.6f}'.format(t, optim.curr_loss.item()))

    return params_history

def get_cnn_omniglot(hidden_size, n_classes):
    """
    Build the four-block convolutional classifier and initialise its weights.

    Each block is Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2) -> BatchNorm2d.
    The first block takes a single input channel; every block outputs
    hidden_size channels. The blocks are followed by Flatten and a linear
    layer mapping hidden_size features to n_classes outputs.

    NOTE(review): the linear layer takes hidden_size inputs, so the feature
    map is assumed to be 1x1 after the four poolings (e.g. 28x28 images) --
    confirm against the data loader.
    """
    def _conv_block(in_ch, out_ch):
        conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        act = nn.ReLU(inplace=True)
        pool = nn.MaxPool2d(2)
        norm = nn.BatchNorm2d(out_ch, momentum=1., affine=True, track_running_stats=True)
        return nn.Sequential(conv, act, pool, norm)

    channel_plan = [(1, hidden_size)] + [(hidden_size, hidden_size)] * 3
    blocks = [_conv_block(in_ch, out_ch) for in_ch, out_ch in channel_plan]
    net = nn.Sequential(*blocks, nn.Flatten(), nn.Linear(hidden_size, n_classes))

    initialize(net)
    return net


def initialize(net):
    """
    Re-initialise the weights of net in place and return net.

    - Conv2d: weights drawn from N(0, sqrt(2 / (kh * kw * out_channels)));
      bias, when present, set to zero.
    - BatchNorm2d: weight set to one, bias set to zero.
    - Linear: weight and bias set to zero.

    Other module types are left untouched.
    """
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            k_h, k_w = module.kernel_size[0], module.kernel_size[1]
            fan = k_h * k_w * module.out_channels
            module.weight.data.normal_(0, math.sqrt(2. / fan))
            if module.bias is not None:
                module.bias.data.zero_()
            continue

        if isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
            continue

        if isinstance(module, nn.Linear):
            module.weight.data.zero_()
            module.bias.data.zero_()

    return net

def get_inner_opt(train_loss):
    """Return a hypergrad gradient-descent optimizer on train_loss with step size 0.1."""
    return hg.GradientDescent(train_loss, step_size=.1)

In [9]:
def train_imaml(meta_model,db,reg_param,hg_mode,K,T,outer_opt,inner_log_interval,notebook=True):
    """
    Run one pass of iMAML meta-training over the meta-batches served by db.

    For every meta-batch the gradients of meta_model are cleared, each task
    is adapted with T inner steps, a hypergradient is computed for it, and
    outer_opt takes a single step.

    Args:
        meta_model: module holding the meta-parameters; switched to train mode.
        db: loader exposing ``x_train``, ``batchsz`` and ``next()``, where
            ``next()`` returns (train_x, train_y, test_x, test_y) batched
            over tasks.
        reg_param: weight of the L2 proximity term in each task's train loss.
        hg_mode: 'CG' or 'fixed_point', selecting the hypergradient routine.
        K: number of iterations of the hypergradient routine.
        T: number of inner-loop steps per task.
        outer_opt: optimizer over ``meta_model.parameters()``.
        inner_log_interval: forwarded to inner_loop as log_interval.
        notebook: unused; kept for interface compatibility.

    Raises:
        ValueError: if hg_mode is not one of the supported modes.
    """
    # Fail early with a clear message; an unknown mode would otherwise leave
    # the task's validation loss unset and produce no hypergradient.
    if hg_mode not in ('CG', 'fixed_point'):
        raise ValueError("hg_mode must be 'CG' or 'fixed_point', got {!r}".format(hg_mode))

    meta_model.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz
    for batch_idx in range(n_train_iter):
        tr_xs, tr_ys, tst_xs, tst_ys = db.next()
        outer_opt.zero_grad()

        for tr_x, tr_y, tst_x, tst_y in zip(tr_xs, tr_ys, tst_xs, tst_ys):
            # single task set up; batch_size makes Task divide each task's
            # validation loss by the number of tasks in the meta-batch
            task = Task(reg_param, meta_model, (tr_x, tr_y, tst_x, tst_y), batch_size=tr_xs.shape[0])
            inner_opt = get_inner_opt(task.train_loss_f)

            # Materialise the meta-parameters once: parameters() returns a
            # generator, which would be exhausted after the first inner step
            # and drop the proximity term from all later steps.
            hparams = list(meta_model.parameters())

            # single task inner loop, started from a detached copy of the meta-parameters
            params = [p.detach().clone().requires_grad_(True) for p in hparams]
            last_param = inner_loop(hparams, params, inner_opt, T, log_interval=inner_log_interval)[-1]

            # single task hypergradient computation
            # NOTE(review): both routines are presumed to accumulate into the
            # .grad of hparams, which is what outer_opt.step() consumes --
            # verify against the hypergrad package.
            if hg_mode == 'CG':
                # This is the approximation used in the paper; CG stands for conjugate gradient
                cg_fp_map = hg.GradientDescent(loss_f=task.train_loss_f, step_size=1.)
                hg.CG(last_param, hparams, K=K, fp_map=cg_fp_map, outer_loss=task.val_loss_f)
            else:
                hg.fixed_point(last_param, hparams, K=K, fp_map=inner_opt,
                               outer_loss=task.val_loss_f)

        # One meta-update per meta-batch. Stepping only after the batch loop
        # would discard every batch but the last, because zero_grad() runs at
        # the top of each iteration.
        outer_opt.step()

In [10]:
def test_imaml(test_tasks, meta_model, n_steps, get_inner_opt, reg_param, log_interval=None):
    """
    Adapt meta_model to every test task and report per-task validation metrics.

    Args:
        test_tasks: 4-tuple (train_x, train_y, test_x, test_y) batched over
            tasks; converted to tensors on the model's device via np2torch.
        meta_model: module holding the meta-parameters.
        n_steps: number of inner-loop steps used to adapt to each task.
        get_inner_opt: callable mapping a train-loss function to an inner optimizer.
        reg_param: weight of the L2 proximity term in each task's train loss.
        log_interval: forwarded to inner_loop.

    Returns:
        (val_losses, val_accs): two numpy arrays with one entry per task.
    """
    # NOTE(review): the model is put in train mode during evaluation as well;
    # presumably so batch-norm uses the statistics of the current batch --
    # confirm this is intended.
    meta_model.train()
    device = next(meta_model.parameters()).device

    val_losses, val_accs = [], []
    tr_xs,tr_ys,tst_xs,tst_ys = test_tasks

    tr_xs, tr_ys, tst_xs, tst_ys = np2torch([tr_xs, tr_ys, tst_xs, tst_ys],device=device,label_long_type=True)

    # Materialise the meta-parameters once: parameters() returns a generator,
    # which would be exhausted after the first inner step and drop the
    # proximity term from all later steps.
    hparams = list(meta_model.parameters())

    for tr_x, tr_y, tst_x, tst_y in zip(tr_xs, tr_ys, tst_xs, tst_ys):

        task = Task(reg_param, meta_model, (tr_x, tr_y, tst_x, tst_y))
        inner_opt = get_inner_opt(task.train_loss_f)

        # adapt from a detached copy so the meta-parameters are left untouched
        params = [p.detach().clone().requires_grad_(True) for p in hparams]
        last_param = inner_loop(hparams, params, inner_opt, n_steps, log_interval=log_interval)[-1]

        # called for its side effect: caches val_loss / val_acc on the task
        task.val_loss_f(last_param, hparams)

        val_losses.append(task.val_loss)
        val_accs.append(task.val_acc)

    return np.array(val_losses), np.array(val_accs)

In [7]:
# Display whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

  return torch._C._cuda_getDeviceCount() > 0


False

In [None]:
# Select the compute device and the matching loader options.
cuda = torch.cuda.is_available()
kwargs = {}
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
seed = 42
device = torch.device("cuda" if cuda else "cpu")

print("Success init", device)

# Seed PyTorch and NumPy with the same value.
torch.random.manual_seed(seed)
np.random.seed(seed)

In [None]:
# ---- iMAML experiment configuration ----
hg_mode = 'CG'
inner_log_interval = None
inner_log_interval_test = None

# task layout: classes per task and samples per class
n_way = 5
ways = n_way
n_shot = 5
shots = n_shot
# NOTE(review): presumably 20 samples per class in total -- confirm
test_shots = 20 - n_shot

n_channels = 64

reg_param = 2

# T: number of inner-loop update steps
# K: number of CG steps
T = 16
K = 5

T_test = T
inner_lr = .1

# model, outer optimizer and inner optimizer settings
meta_model = get_cnn_omniglot(n_channels, ways).to(device)
outer_opt = torch.optim.Adam(params=meta_model.parameters())
inner_opt_class = hg.GradientDescent
inner_opt_kwargs = dict(step_size=inner_lr)
import pickle
def get_inner_opt(train_loss):
    """
    Build the inner optimizer for train_loss from inner_opt_class / inner_opt_kwargs.

    This rebinds the name get_inner_opt defined earlier in the notebook, so
    the step size now follows inner_lr.
    """
    return inner_opt_class(train_loss, **inner_opt_kwargs)

In [None]:
# results collected per method
accs = dict()

# number of characters (i.e., classes) in the training dataset
n_characters = 20

# Every task is made of n_way classes, so the characters must split evenly
# into tasks; the smallest number of tasks is n_characters // n_way.
assert not n_characters % n_way
n_task = n_characters // n_way

# number of channels for MAML CNN
n_channel_maml = 64

In [None]:
#@title Load raw dataset and encode label
from types import SimpleNamespace
import time
# Load the raw tasks and derive the supervised train/test split used for pretraining.
dataset = load_dataset(n_task,True,seed)
train_set,test_set = get_train_data(dataset,n_test_per_class=1)

# Settings of the feature extractor and of its supervised pretraining.
label_encoding_dim = 250
batch_norm = True # batch_norm for the feature extractor
dropout = 0 # Dropout for the feature extractor
pretrain_batch_size = 16
weight_decay = 0
pretrain_epochs = 50
train_data_enlarge_ratio =15
# Construct a randomly initialized CNN as the feature extractor.
net = build_CNN(train_set['n_class'], device, n_channel=label_encoding_dim,batch_norm=batch_norm,dropout=dropout)

if pretrain_epochs > 0:
    # Train the CNN on the training data by supervised learning, in order
    # to obtain a better feature extractor than a random CNN when the
    # training data is relatively large.
    # NOTE(review): when the training data is small, supervised training
    # leads to an overfitted CNN, which is a worse feature extractor than a
    # random CNN -- check that pretraining is appropriate for this data size.
    net,test_accs,test_losses = pretrain(net,train_set,test_set,device,seed=seed,epochs=pretrain_epochs,
                                         weight_decay=weight_decay,batch_size=pretrain_batch_size)

# Encode the labels of dataset using the feature extractor.
encode_labels(dataset,net,device)

# Given n_way*n_task classes of samples, we can use resampling to obtain many tasks that consist of n_way distinct classes
# of samples. The following is the resampling procedure.
# A deep copy is taken before augmentation and kept as orig_dataset.
from copy import deepcopy
orig_dataset = deepcopy(dataset)
if train_data_enlarge_ratio > 1:
    augment_train_data(dataset,enlarge_ratio=train_data_enlarge_ratio,n_way=n_way,n_shot=n_shot,seed=seed)

preprocess_label_embeddings(dataset)

In [None]:
#@title Process data into the few-shot setting
preprocess_label_embeddings(orig_dataset)
tasks = vars(orig_dataset)

# Merge the query and support parts of every training task along axis 1.
X = np.concatenate([tasks['X_qry'],tasks['X_spt']],axis=1)
Y = np.concatenate([tasks['Y_qry'],tasks['Y_spt']],axis=1)
Y_emb = np.concatenate([tasks['Y_qry_emb'],tasks['Y_spt_emb']],axis=1)
new_X =[]
new_Y_emb = []
for x,y,y_emb in zip(X,Y,Y_emb):
    # Sort the sample indices of this task by label and split them into
    # n_way rows. NOTE(review): this assumes every class has the same
    # number of samples in the task, so each row holds one class -- confirm.
    idxes = np.argsort(y).reshape(n_way,-1)
    for i in range(idxes.shape[0]):
        # one new entry per row of idxes
        new_X.append([])
        new_Y_emb.append([])
        for j in range(idxes.shape[1]):
            idx = idxes[i,j]
            new_X[-1].append(x[idx])
            new_Y_emb[-1].append(y_emb[idx])
x_train = remove_padding(np.array(new_X))
# NOTE(review): new_Y_emb is built above but not used in this cell; y_train is passed as None.
y_train = None

# (support X, support Y, query X, query Y) of the held-out test tasks
test_tasks = remove_padding(tasks['test_X_spt']), tasks['test_Y_spt'],remove_padding(tasks['test_X_qry']), tasks['test_Y_qry']

from support.omniglot_loaders_original import OmniglotNShot
n_channel = n_channel_maml
# meta-batch size: 32 (8 for very wide networks), capped at n_characters
batchsz = min(32 if n_channel <= 1024 else 8, n_characters)

# NOTE(review): n_way, k_shot and k_query are hard-coded here (5 / 1 / 19),
# while the configuration cell sets n_shot = 5 and test_shots = 15 --
# confirm that training on 1-shot tasks is intended.
db = OmniglotNShot(root=None,
    batchsz=batchsz,
    n_way=5,
    k_shot=1,
    k_query=19,
    imgsz=28,
    device=device,
    n_train_tasks=None,
    given_x=True,
    x_train=x_train,
    x_test=None,
    y_train = y_train,
)
n_out = n_way
# NOTE(review): this rebinds `net`, which previously held the pretrained feature extractor.
net, meta_opt = build_MAML_model(n_out,device,lr=1e-3,n_channel=n_channel)

In [None]:
# Draw one batch from the loader and report the shape of each part.
Dtr_spx, Dtr_spy, Dtr_qrx, Dtr_qry = db.next()
print(
    f"Batchsize: {batchsz}",
    f"D train support set X shape: {Dtr_spx.shape}",
    f"D train support set Y shape: {Dtr_spy.shape}",
    f"D train query set X shape: {Dtr_qrx.shape}",
    f"D train query set Y shape: {Dtr_qry.shape}",
    sep="\n",
)

In [None]:
#@title iMAML

from tqdm.notebook import trange
import time

# schedule: total epochs and how often the test tasks are evaluated
imaml_epochs = 200
eval_interval = 5
log_interval = 5

test_accs = []
t = trange(imaml_epochs,desc='i-MAML Epoch')
for k in t:
    train_imaml(
        meta_model, db, reg_param, hg_mode, K, T, outer_opt, inner_log_interval,
    )

    # evaluate only on epochs that are a multiple of eval_interval
    if k % eval_interval != 0:
        continue

    test_losses, test_acc = test_imaml(
        test_tasks, meta_model, T_test, get_inner_opt, reg_param, log_interval=None,
    )
    # mean accuracy over the test tasks, in percent
    test_acc = np.mean(test_acc)*100
    test_accs.append(test_acc)
    t.set_postfix(test_acc=test_acc,max_test_acc=np.max(test_accs))

# keep the best evaluated accuracy
accs['i-maml_accs'] = np.max(test_accs)