In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import random
import seaborn as sbs
from torch.autograd import Variable
from tqdm import tqdm_notebook as tqdm
import copy
%matplotlib inline

class Sine_Model(nn.Module):
    """Small MLP mapping a scalar input to a scalar prediction.

    Layer order as implemented in ``forward``:
    Linear(1, 128) -> Dropout(0.3) -> Linear(128, 128) -> tanh -> Dropout(0.3)
    -> tanh -> Linear(128, 1) -> tanh.

    The final tanh bounds every prediction to the open interval (-1, 1).
    """

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(1, 128)
        self.lin2 = nn.Linear(128, 128)
        self.lin3 = nn.Linear(128, 1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Predict one value per input element.

        Args:
            x: tensor of any shape; it is flattened to a column of N scalars.

        Returns:
            Tensor of shape ``(N,)``, including when ``N == 1``.
        """
        x = self.dropout(self.lin1(x.view(-1, 1)))
        x = self.dropout(torch.tanh(self.lin2(x)))
        # Squeeze only the trailing feature axis. A bare ``squeeze()`` would
        # also drop the batch axis for a single sample and return a 0-d tensor
        # that no longer lines up with a ``(1,)``-shaped target.
        return torch.tanh(self.lin3(torch.tanh(x))).squeeze(-1)
    
class SineWaveTask:
    """A single regression task: ``y = a * sin(x + b * pi)``.

    The amplitude ``a`` is drawn uniformly from [0.5, 1.0) and the phase
    factor ``b`` uniformly from [0, 2) when the task is created. Each task
    carries a random training set sampled on [-5, 5] and an evenly spaced
    test set on [-10, 10], both stored as ``torch.Tensor`` pairs.
    """

    def __init__(self, train_size):
        # Amplitude is drawn before phase, and both before the training
        # inputs, so the global NumPy RNG is consumed in a fixed order.
        self.a = np.random.uniform(0.5, 1.0)
        self.b = np.random.uniform(0, 2)
        train_pair = self.training_set(size=train_size)
        test_pair = self.test_set()
        self.train_x, self.train_y = train_pair
        self.test_x, self.test_y = test_pair

    def f(self, x):
        """Evaluate this task's sine wave at ``x`` (scalar or NumPy array)."""
        phase_shift = self.b * np.pi
        return self.a * np.sin(x + phase_shift)

    def training_set(self, size=10):
        """Sample ``size`` inputs uniformly from [-5, 5] and return ``(x, y)`` tensors."""
        inputs = np.random.uniform(-5, 5, size)
        targets = self.f(inputs)
        return torch.Tensor(inputs), torch.Tensor(targets)

    def test_set(self, size=50):
        """Return ``(x, y)`` tensors on ``size`` evenly spaced points in [-10, 10]."""
        grid = np.linspace(-10, 10, size)
        return torch.Tensor(grid), torch.Tensor(self.f(grid))

    def plot(self, *args, **kwargs):
        """Draw the test curve and the training points on the current axes.

        Positional and keyword arguments are accepted but not used.
        """
        test_curve = (self.test_x.numpy(), self.test_y.numpy())
        train_points = (self.train_x.numpy(), self.train_y.numpy())
        plt.plot(*test_curve, 'o-', label='Test', color='blue')
        plt.scatter(*train_points, label='Train', color='black')
def replace_grad(parameter_gradients, parameter_name):
    """Build a tensor backward hook that substitutes a precomputed gradient.

    The returned callable ignores whatever it is called with (autograd passes
    the freshly computed gradient) and returns
    ``parameter_gradients[parameter_name]`` instead. The lookup is performed
    each time the hook runs, not when the hook is created.

    See
    https://pytorch.org/docs/stable/autograd.html?highlight=hook#torch.Tensor.register_hook
    for more info
    """
    def _swap_in_precomputed(_computed_grad):
        precomputed = parameter_gradients[parameter_name]
        return precomputed

    return _swap_in_precomputed

In [None]:
# Seed the RNGs so the sampled tasks (NumPy) and the model initialisation and
# dropout masks (PyTorch) are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N_SOURCE_TASKS = 100
TRAIN_SIZE = 50

source_tasks = [SineWaveTask(train_size=TRAIN_SIZE) for _ in range(N_SOURCE_TASKS)]
target_task = SineWaveTask(train_size=TRAIN_SIZE)

# All source tasks overlaid on one figure.
for task in source_tasks:
    task.plot()
plt.show()

# The held-out target task on its own figure.
target_task.plot()
plt.show()

# Use the first GPU when one is present; otherwise run on the CPU instead of
# raising when tensors are first moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Baseline: joint ("transfer") training. A single model is fitted on the
# training sets of every source task plus the target task, one task per
# optimizer step. The original author marks this baseline as failing.
test_loss = []
model = Sine_Model().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-3)
for epoch in range(1, 11):
    for out_batch in range(128):
        model.train()
        for task in source_tasks + [target_task]:
            optimizer.zero_grad()
            # Inputs are train_x and the regression target is train_y.
            x, y = task.train_x.to(device), task.train_y.to(device)
            pred_y = model(x)
            loss = criterion(pred_y, y)
            loss.backward()
            optimizer.step()

        # Score on the target task's test grid after each pass over all tasks.
        # Dropout is disabled and no autograd graph is recorded.
        model.eval()
        with torch.no_grad():
            x, y = target_task.test_x.to(device), target_task.test_y.to(device)
            pred_y = model(x)
            eval_loss = criterion(pred_y, y)
        test_loss.append(eval_loss.item())

    # Once per epoch: latest test-set prediction against the target task,
    # titled with the latest test loss.
    plt.title(eval_loss.cpu().numpy())
    plt.plot(x.cpu().numpy(), pred_y.cpu().numpy(), '^--', label='predict')
    target_task.plot()
    plt.legend()
    plt.show()
plt.plot(test_loss)
plt.show()

In [None]:
# Baseline: plain supervised learning. A fresh model is trained on the target
# task's training set only and scored on the target task's test grid.
test_loss = []
model = Sine_Model().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-3)

# The data never changes between steps, so move it to the device once.
train_x, train_y = target_task.train_x.to(device), target_task.train_y.to(device)
test_x, test_y = target_task.test_x.to(device), target_task.test_y.to(device)

for epoch in range(1, 11):
    for batch in range(128):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(train_x), train_y)
        loss.backward()
        optimizer.step()

        # Evaluate after every step with dropout off and without recording
        # an autograd graph.
        model.eval()
        with torch.no_grad():
            pred_y = model(test_x)
            eval_loss = criterion(pred_y, test_y)
        test_loss.append(eval_loss.item())

    # Once per epoch: latest test-set prediction against the target task,
    # titled with the latest test loss.
    plt.title(eval_loss.cpu().numpy())
    plt.plot(test_x.cpu().numpy(), pred_y.cpu().numpy(), '^--', label='predict')
    target_task.plot()
    plt.legend()
    plt.show()
plt.plot(test_loss)
plt.show()

In [None]:
# MAML-style meta-learning (first-order: each task is adapted on a deep copy
# of the meta-model, so no gradient flows back through the inner loop).
meta_train_loss = []
meta_test_loss  = []
lr_inner = 0.01      # learning rate of the per-task (inner) Adam optimizer
INNER_STEPS = 4      # number of adaptation steps per task
model     = Sine_Model().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)


def adapt_to_task(meta_model, task, n_steps, lr):
    """Return a copy of ``meta_model`` fine-tuned on ``task``'s training set.

    The copy is trained for ``n_steps`` Adam steps at learning rate ``lr``.
    ``meta_model`` itself is not modified. The returned model is in eval mode
    with its gradients cleared, ready for a test-loss backward pass.
    """
    adapted = copy.deepcopy(meta_model)
    adapted.train()
    # The optimizer must own the *copy's* parameters; binding it to the
    # meta-model would leave the copy un-adapted and step the meta-model.
    inner_optimizer = torch.optim.Adam(adapted.parameters(), lr=lr)
    x, y = task.train_x.to(device), task.train_y.to(device)
    for _ in range(n_steps):
        inner_optimizer.zero_grad()
        inner_loss = criterion(adapted(x), y)
        inner_loss.backward()
        inner_optimizer.step()
    adapted.eval()
    adapted.zero_grad()
    return adapted


for epoch in range(1, 11):
    for meta_batch in range(128):
        # --- Accumulate per-task gradients -------------------------------
        # For each source task: adapt a copy, then take the gradient of the
        # test loss with respect to the adapted parameters.
        task_grads = []
        task_losses = []
        for task in source_tasks:
            adapted = adapt_to_task(model, task, INNER_STEPS, lr_inner)
            x, y = task.test_x.to(device), task.test_y.to(device)
            loss = criterion(adapted(x), y)
            loss.backward()
            task_losses.append(loss.item())
            task_grads.append({name: param.grad.detach().clone()
                               for name, param in adapted.named_parameters()})
        meta_train_loss.append(np.mean(task_losses))

        # --- Evaluate on the held-out target task ------------------------
        adapted = adapt_to_task(model, target_task, INNER_STEPS, lr_inner)
        x, y = target_task.test_x.to(device), target_task.test_y.to(device)
        with torch.no_grad():
            pred_target = adapted(x)
            meta_test_loss.append(criterion(pred_target, y).item())

        # --- Meta-update -------------------------------------------------
        avg_task_grad = {name: torch.stack([grads[name] for grads in task_grads]).mean(dim=0)
                         for name in task_grads[0].keys()}

        # Hooks swap each parameter's computed gradient for the task average,
        # so the forward/backward pass below exists only to trigger them; the
        # data fed through it does not influence the update.
        hooks = [param.register_hook(replace_grad(avg_task_grad, name))
                 for name, param in model.named_parameters()]
        try:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        finally:
            # Always detach the hooks so they cannot leak into later passes.
            for h in hooks:
                h.remove()

    plt.plot(meta_train_loss, label = 'meta_train')
    plt.plot(meta_test_loss,  label = 'meta_test')
    plt.legend()
    plt.show()

    # Show the adapted model's prediction on the target task (eval mode),
    # i.e. the same prediction that meta_test_loss is computed from.
    plt.plot(target_task.test_x.cpu().numpy(), pred_target.cpu().numpy(), '^--', label='predict')
    target_task.plot()
    plt.legend()
    plt.show()