<a href="https://colab.research.google.com/github/vainaijr/cavia/blob/master/cavia_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# google drive mount

In [0]:
from google.colab import drive
drive.mount('/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /gdrive


In [0]:
%xmode verbose
%pdb on
from IPython.core.debugger import set_trace
import time
from datetime import datetime
!pip install torchsummaryX

Exception reporting mode: Verbose
Automatic pdb calling has been turned ON
Collecting torchsummaryX
  Downloading https://files.pythonhosted.org/packages/0f/58/e6a19d2cd1784c16b43f3f9f4946fa5f084a1da3e8a758545cc95edb8fc0/torchsummaryX-1.2.0-py3-none-any.whl
Installing collected packages: torchsummaryX
Successfully installed torchsummaryX-1.2.0


# arguments

In [0]:
import argparse
import torch


def parse_args():
    
    parser = argparse.ArgumentParser(description='Fast Context Adaptation via Meta-Learning (CAVIA),'
                                                 'Regression experiments')

    parser.add_argument('--task', type=str, default='sine', help='problem setting: sine or celeba')

    parser.add_argument('--n_iter', type=int, default=50000, help='number of meta-iterations')

    parser.add_argument('--tasks_per_metaupdate', type=int, default=25)

    parser.add_argument('--k_meta_train', type=int, default=10, help='data points in task training set (during meta training, inner loop)')
    parser.add_argument('--k_meta_test', type=int, default=10, help='data points in task test set (during meta training, outer loop)')
    parser.add_argument('--k_shot_eval', type=int, default=10, help='data points in task training set (during evaluation)')

    parser.add_argument('--lr_inner', type=float, default=1.0, help='inner-loop learning rate (task-specific)')
    parser.add_argument('--lr_meta', type=float, default=0.001, help='outer-loop learning rate')

    parser.add_argument('--num_inner_updates', type=int, default=1, help='number of inner-loop updates (during training)')

    parser.add_argument('--num_context_params', type=int, default=5, help='number of context parameters (added at first layer)')
    parser.add_argument('--num_hidden_layers', nargs='+', default=[40, 40])

    parser.add_argument('--first_order', action='store_true', default=False, help='run first-order version')

    parser.add_argument('--maml', action='store_true', default=False, help='run MAML')
    parser.add_argument('--seed', type=int, default=42)

    # commands specific to the CelebA image completion task
    parser.add_argument('--use_ordered_pixels', action='store_true', default=False)

    args = parser.parse_args('')

    # use the GPU if available
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    return args

# cavia_model

In [0]:
import torch
import torch.nn.functional as F
from torch import nn


class CaviaModel(nn.Module):
    """
    Feed-forward neural network with context parameters.
    """

    def __init__(self,
                 n_in,
                 n_out,
                 num_context_params,
                 n_hidden,
                 device
                 ):
        super(CaviaModel, self).__init__()

        self.device = device

        # fully connected layers
        self.fc_layers = nn.ModuleList()
        self.fc_layers.append(nn.Linear(n_in + num_context_params, n_hidden[0]))
        for k in range(len(n_hidden) - 1):
            self.fc_layers.append(nn.Linear(n_hidden[k], n_hidden[k + 1]))
        self.fc_layers.append(nn.Linear(n_hidden[-1], n_out))

        # context parameters (note that these are *not* registered parameters of the model!)
        self.num_context_params = num_context_params
        self.context_params = None
        self.reset_context_params()

    def reset_context_params(self):
        self.context_params = torch.zeros(self.num_context_params).to(self.device)
        self.context_params.requires_grad = True

    def forward(self, x):
        # set_trace()
        
        # x -> 10 x 1
        
        # concatenate input with context parameters
        x = torch.cat((x, self.context_params.expand(x.shape[0], -1)), dim=1)
        
        # x -> 10 x 6
        
        for k in range(len(self.fc_layers) - 1):
            x = F.relu(self.fc_layers[k](x))
        
        # x -> 10 x 40
        
        y = self.fc_layers[-1](x)
        
        # y -> 10 x 1
        
        return y

# cavia

In [0]:
"""
Regression experiment using CAVIA
"""
import copy
import os
import time

import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F
import torch.optim as optim


def run(args, log_interval=5000, rerun=False):
    assert not args.maml

    # see if we already ran this experiment
    code_root = os.path.dirname('/gdrive/My Drive/cavia/regression/' + datetime.now().strftime('%Y-%m-%d_%H_%M'))
    if not os.path.isdir('{}/{}_result_files/{}'.format(code_root, args.task, datetime.now().strftime('%Y-%m-%d_%H_%M'))):
        # os.mkdir('{}/'.format(code_root))
        # os.mkdir('{}/{}_result_files/'.format(code_root, args.task))
        os.mkdir('{}/{}_result_files/{}/'.format(code_root, args.task, datetime.now().strftime('%Y-%m-%d_%H_%M')))
    path = '{}/{}_result_files/{}/'.format(code_root, args.task, datetime.now().strftime('%Y-%m-%d_%H_%M')) + get_path_from_args(args)

    if os.path.exists(path + '.pkl') and not rerun:
        return load_obj(path)

    start_time = time.time()
    set_seed(args.seed)

    # --- initialise everything ---

    # get the task family
    if args.task == 'sine':
        task_family_train = RegressionTasksSinusoidal()
        task_family_valid = RegressionTasksSinusoidal()
        task_family_test = RegressionTasksSinusoidal()
    elif args.task == 'celeba':
        task_family_train = tasks_celebA.CelebADataset('train', device=args.device)
        task_family_valid = tasks_celebA.CelebADataset('valid', device=args.device)
        task_family_test = tasks_celebA.CelebADataset('test', device=args.device)
    else:
        raise NotImplementedError

    # initialise network
    model = CaviaModel(n_in=task_family_train.num_inputs,
                       n_out=task_family_train.num_outputs,
                       num_context_params=args.num_context_params,
                       n_hidden=args.num_hidden_layers,
                       device=args.device
                       ).to(args.device)
    
    # intitialise meta-optimiser
    # (only on shared params - context parameters are *not* registered parameters of the model)
    !pip install adabound
    import adabound
    # meta_optimiser = torch.optim.Adam(model.parameters(), 0.001)
    meta_optimiser = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    
    # meta_optimiser = optim.Adam(model.parameters(), args.lr_meta)
    set_trace()
    # initialise loggers
    logger = Logger()
    logger.best_valid_model = copy.deepcopy(model)

    # --- main training loop ---

    for i_iter in range(args.n_iter):

        # initialise meta-gradient
        meta_gradient = [0 for _ in range(len(model.state_dict()))]
        
        # sample tasks
        target_functions = task_family_train.sample_tasks(args.tasks_per_metaupdate)

        # --- inner loop ---

        for t in range(args.tasks_per_metaupdate):

            # reset private network weights
            model.reset_context_params()

            # get data for current task
            train_inputs = task_family_train.sample_inputs(args.k_meta_train,
                                                           args.use_ordered_pixels).to(args.device)

            for _ in range(args.num_inner_updates):
                
                # forward through model
                train_outputs = model(train_inputs)
                # train_outpus -> 10 x 1
                
                # get targets
                train_targets = target_functions[t](train_inputs)
                # train_targets -> 10 x 1

                # ------------ update on current task ------------

                # compute loss for current task
                task_loss = F.mse_loss(train_outputs, train_targets)

                # compute gradient wrt context params
                task_gradients = \
                    torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]
                # task_gradients -> 5 x 1
                
                # update context params (this will set up the computation graph correctly)
                model.context_params = model.context_params - args.lr_inner * task_gradients
                # model.context_params -> 5 x 1

            # ------------ compute meta-gradient on test loss of current task ------------

            # get test data
            test_inputs = task_family_train.sample_inputs(args.k_meta_test, args.use_ordered_pixels).to(args.device)

            # get outputs after update
            test_outputs = model(test_inputs)

            # get the correct targets
            test_targets = target_functions[t](test_inputs)
            

            # compute loss after updating context (will backprop through inner loop)
            loss_meta = F.mse_loss(test_outputs, test_targets)

            # compute gradient + save for current task
            task_grad = torch.autograd.grad(loss_meta, model.parameters())

            for i in range(len(task_grad)):
                # clip the gradient
                meta_gradient[i] += task_grad[i].detach().clamp_(-10, 10)
                

        # ------------ meta update ------------

        # assign meta-gradient
        for i, param in enumerate(model.parameters()):
            # param -> 40 x 6 -> 40 -> 40 x 40 -> 40 -> 1 x 40 -> 1
            param.grad = meta_gradient[i] / args.tasks_per_metaupdate

        # do update step on shared model
        meta_optimiser.step()

        # reset context params
        model.reset_context_params()

        # ------------ logging ------------

        if i_iter % log_interval == 0:

            # evaluate on training set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_train,
                                              num_updates=args.num_inner_updates)
            logger.train_loss.append(loss_mean)
            logger.train_conf.append(loss_conf)

            # evaluate on test set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_valid,
                                              num_updates=args.num_inner_updates)
            logger.valid_loss.append(loss_mean)
            logger.valid_conf.append(loss_conf)

            # evaluate on validation set
            loss_mean, loss_conf = eval_cavia(args, copy.deepcopy(model), task_family=task_family_test,
                                              num_updates=args.num_inner_updates)
            logger.test_loss.append(loss_mean)
            logger.test_conf.append(loss_conf)

            # save logging results
            save_obj(logger, path)

            # save best model
            if logger.valid_loss[-1] == np.min(logger.valid_loss):
                print('saving best model at iter', i_iter)
                logger.best_valid_model = copy.deepcopy(model)

            # visualise results
            if args.task == 'celeba':
                tasks_celebA.visualise(task_family_train, task_family_test, copy.deepcopy(logger.best_valid_model),
                                       args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
            start_time = time.time()

    return logger


def eval_cavia(args, model, task_family, num_updates, n_tasks=100, return_gradnorm=False):

    # get the task family
    input_range = task_family.get_input_range().to(args.device)

    # logging
    losses = []
    gradnorms = []

    # --- inner loop ---

    for t in range(n_tasks):

        # sample a task
        target_function = task_family.sample_task()

        # reset context parameters
        model.reset_context_params()

        # get data for current task
        curr_inputs = task_family.sample_inputs(args.k_shot_eval, args.use_ordered_pixels).to(args.device)
        curr_targets = target_function(curr_inputs)

        # ------------ update on current task ------------

        for _ in range(1, num_updates + 1):

            # forward pass
            curr_outputs = model(curr_inputs)

            # compute loss for current task
            task_loss = F.mse_loss(curr_outputs, curr_targets)

            # compute gradient wrt context params
            task_gradients = \
                torch.autograd.grad(task_loss, model.context_params, create_graph=not args.first_order)[0]

            # update context params
            if args.first_order:
                model.context_params = model.context_params - args.lr_inner * task_gradients.detach()
            else:
                model.context_params = model.context_params - args.lr_inner * task_gradients

            # keep track of gradient norms
            gradnorms.append(task_gradients[0].norm().item())

        # ------------ logging ------------

        # compute true loss on entire input range
        model.eval()
        losses.append(F.mse_loss(model(input_range), target_function(input_range)).detach().item())
        model.train()

    losses_mean = np.mean(losses)
    losses_conf = st.t.interval(0.95, len(losses) - 1, loc=losses_mean, scale=st.sem(losses))
    if not return_gradnorm:
        return losses_mean, np.mean(np.abs(losses_conf - losses_mean))
    else:
        return losses_mean, np.mean(np.abs(losses_conf - losses_mean)), np.mean(gradnorms)

# logger

In [0]:
import time

import numpy as np


class Logger:

    def __init__(self):
        self.train_loss = []
        self.train_conf = []

        self.valid_loss = []
        self.valid_conf = []

        self.test_loss = []
        self.test_conf = []

        self.best_valid_model = None

    def print_info(self, iter_idx, start_time):
        print(
            'Iter {:<4} - time: {:<5} - [train] loss: {:<6} (+/-{:<6}) - [valid] loss: {:<6} (+/-{:<6}) - [test] loss: {:<6} (+/-{:<6})'.format(
                iter_idx,
                int(time.time() - start_time),
                np.round(self.train_loss[-1], 4),
                np.round(self.train_conf[-1], 4),
                np.round(self.valid_loss[-1], 4),
                np.round(self.valid_conf[-1], 4),
                np.round(self.test_loss[-1], 4),
                np.round(self.test_conf[-1], 4),
            )
        )

# regression_task_sinusoidal

In [0]:
import numpy as np
import torch


class RegressionTasksSinusoidal:
    """
    Same regression task as in Finn et al. 2017 (MAML)
    """

    def __init__(self):
        self.num_inputs = 1
        self.num_outputs = 1

        self.amplitude_range = [0.1, 5.0]
        self.phase_range = [0, np.pi]

        self.input_range = [-5, 5]

    def get_input_range(self, size=100):
        return torch.linspace(self.input_range[0], self.input_range[1], steps=size).unsqueeze(1)

    def sample_inputs(self, batch_size, *args, **kwargs):
        inputs = torch.rand((batch_size, self.num_inputs))
        inputs = inputs * (self.input_range[1] - self.input_range[0]) + self.input_range[0]
        return inputs

    def sample_task(self):
        amplitude = np.random.uniform(self.amplitude_range[0], self.amplitude_range[1])
        phase = np.random.uniform(self.phase_range[0], self.phase_range[1])
        return self.get_target_function(amplitude, phase)

    @staticmethod
    def get_target_function(amplitude, phase):
        def target_function(x):
            if isinstance(x, torch.Tensor):
                return torch.sin(x - phase) * amplitude
            else:
                return np.sin(x - phase) * amplitude

        return target_function

    def sample_tasks(self, num_tasks, return_specs=False):

        amplitude = np.random.uniform(self.amplitude_range[0], self.amplitude_range[1], num_tasks)
        phase = np.random.uniform(self.phase_range[0], self.phase_range[1], num_tasks)

        target_functions = []
        for i in range(num_tasks):
            target_functions.append(self.get_target_function(amplitude[i], phase[i]))

        if return_specs:
            return target_functions, amplitude, phase
        else:
            return target_functions

    def sample_datapoints(self, batch_size):
        """
        Sample random input/output pairs (e.g. for training an orcale)
        :param batch_size:
        :return:
        """

        amplitudes = torch.Tensor(np.random.uniform(self.amplitude_range[0], self.amplitude_range[1], batch_size))
        phases = torch.Tensor(np.random.uniform(self.phase_range[0], self.phase_range[1], batch_size))

        inputs = torch.rand((batch_size, self.num_inputs))
        inputs = inputs * (self.input_range[1] - self.input_range[0]) + self.input_range[0]
        inputs = inputs.view(-1)

        outputs = torch.sin(inputs - phases) * amplitudes
        outputs = outputs.unsqueeze(1)

        return torch.stack((inputs, amplitudes, phases)).t(), outputs

# utils

In [0]:
import hashlib
import os
import pickle
import random

import numpy as np
import torch


def set_seed(seed, cudnn=True):
    """
    Seed everything we can!
    Note that gym environments might need additional seeding (env.seed(seed)),
    and num_workers needs to be set to 1.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # note: the below slows down the code but makes it reproducible
    if (seed is not None) and cudnn:
        torch.backends.cudnn.deterministic = True


def save_obj(obj, name):
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    with open(name + '.pkl', 'rb') as f:
        return pickle.load(f)


def get_path_from_args(args):
    """ Returns a unique hash for an argparse object. """
    args_str = str(args)
    path = hashlib.md5(args_str.encode()).hexdigest()
    return path


def get_base_path():
    p = os.path.dirname(os.path.realpath(__file__))
    if os.path.exists(p):
        return p
    raise RuntimeError('I dont know where I am; please specify a path for saving results.')

# main

In [0]:
if __name__ == '__main__':

    args = parse_args()

    if args.maml:
        logger = maml.run(args, log_interval=100, rerun=True)
    else:
        logger = run(args, log_interval=100, rerun=True)

Collecting adabound
  Downloading https://files.pythonhosted.org/packages/cd/44/0c2c414effb3d9750d780b230dbb67ea48ddc5d9a6d7a9b7e6fcc6bdcff9/adabound-0.0.5-py3-none-any.whl
Installing collected packages: adabound
Successfully installed adabound-0.0.5
> [0;32m<ipython-input-12-65da013eafeb>[0m(64)[0;36mrun[0;34m()[0m
[0;32m     62 [0;31m    [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m    [0;31m# initialise loggers[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 64 [0;31m    [0mlogger[0m [0;34m=[0m [0mLogger[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m    [0mlogger[0m[0;34m.[0m[0mbest_valid_model[0m [0;34m=[0m [0mcopy[0m[0;34m.[0m[0mdeepcopy[0m[0;34m([0m[0mmodel[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m
saving best model at iter 0
Iter 0    - time: 14    - [train] loss: 4.6189 (+/-0.7951) - [valid] loss: 4.28   (+/-0.

# New Section

In [0]:
os.path.realpath(__file__)