In [1]:
import torch.nn as nn
from collections import OrderedDict
############Import required functions##########################################

class MAMLModel(nn.Module):
    """Two-headed regression MLP meta-trained by the MAML cell of this notebook.

    Architecture:
        * ``model``  -- shared trunk: Linear(1, 40) -> ReLU
        * ``modela`` -- head a: Linear(40, 40) -> ReLU -> Linear(40, 1)
        * ``modelb`` -- head b: Linear(40, 40) -> ReLU -> Linear(40, 1)

    Both ``forward`` and ``parameterised`` return ``(out_a, out_b)``. The MAML
    trainer in this notebook fits ``out_a`` to the targets and ``out_b`` to the
    absolute residual of ``out_a``.
    """

    def __init__(self):
        super(MAMLModel, self).__init__()
        # shared feature trunk (input features: 1)
        self.model = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(1, 40)),
            ('relu1', nn.ReLU())
        ]))
        # head a
        self.modela = nn.Sequential(OrderedDict([
            ('l2a', nn.Linear(40, 40)),
            ('relu2a', nn.ReLU()),
            ('l3a', nn.Linear(40, 1))
        ]))
        # head b
        self.modelb = nn.Sequential(OrderedDict([
            ('l2b', nn.Linear(40, 40)),
            ('relu2b', nn.ReLU()),
            ('l3b', nn.Linear(40, 1))
        ]))

    def forward(self, x):
        """Run the module-owned weights on ``x``.

        Args:
            x: tensor whose last dimension is 1.

        Returns:
            tuple ``(out_a, out_b)``, each with last dimension 1.
        """
        # Evaluate the shared trunk once and feed the same features to both heads.
        features = self.model(x)
        return self.modela(features), self.modelb(features)

    def parameterised(self, x, weights):
        """Functional equivalent of ``forward`` using explicit ``weights``.

        This lets the MAML inner loop evaluate adapted (non-Parameter) weights
        without touching the module's own parameters.

        Args:
            x: tensor whose last dimension is 1.
            weights: sequence of 10 tensors in ``self.parameters()`` order:
                ``[l1.weight, l1.bias, l2a.weight, l2a.bias, l3a.weight,
                l3a.bias, l2b.weight, l2b.bias, l3b.weight, l3b.bias]``.

        Returns:
            tuple ``(out_a, out_b)``, each with last dimension 1.
        """
        # shared trunk
        x = nn.functional.linear(x, weights[0], weights[1])
        x = nn.functional.relu(x)
        # head a
        xa = nn.functional.linear(x, weights[2], weights[3])
        xa = nn.functional.relu(xa)
        xa = nn.functional.linear(xa, weights[4], weights[5])
        # head b
        xb = nn.functional.linear(x, weights[6], weights[7])
        xb = nn.functional.relu(xb)
        xb = nn.functional.linear(xb, weights[8], weights[9])
        return xa, xb

In [2]:
"""

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

############Import required functions##########################################
import torch
import torch.nn as nn
############Import required functions##########################################
class MAML:
    """Meta-learner for a two-headed (mean / sigma) regression model.

    For every sampled task:
        1. the meta-parameters are cloned;
        2. the clones are adapted with ``inner_steps`` steps of plain gradient
           descent on a support batch;
        3. the adapted weights are scored on a freshly sampled query batch.

    The query losses of ``tasks_per_meta_batch`` tasks are summed, and one Adam
    step is then taken on the original parameters.

    Interfaces the collaborators must provide:
        * ``model.parameterised(x, weights)`` returning
          ``(mean_prediction, sigma_prediction)``;
        * ``tasks.sample_task()``, whose result provides
          ``sample_data(True, K)`` returning an ``(X, y)`` tensor pair. The
          meaning of the boolean flag is defined by the task class.

    Args:
        model: ``nn.Module`` to meta-train. It is moved to ``cuda:0`` when CUDA
            is available, otherwise to the CPU.
        tasks: task distribution (see above).
        inner_lr: step size of the inner (task-adaptation) gradient steps.
        meta_lr: Adam learning rate for the meta-update.
        K: number of points requested per support / query batch. The losses
            are additionally divided by ``K``.
        inner_steps: number of adaptation steps per task.
        tasks_per_meta_batch: tasks summed into each meta-update (must be >= 1).
        first_order: if True (the historical behaviour of this class), inner
            gradients are treated as constants, i.e. first-order MAML. Set it
            to False to back-propagate through the inner updates (second-order
            MAML, which needs substantially more memory).

    Raises:
        ValueError: if ``tasks_per_meta_batch`` < 1.
    """

    def __init__(self, model, tasks, inner_lr, meta_lr, K=30, inner_steps=1,
                 tasks_per_meta_batch=1000, first_order=True):
        if tasks_per_meta_batch < 1:
            # An empty meta-batch would leave meta_loss as the int 0 and make
            # torch.autograd.grad fail with an obscure error later on.
            raise ValueError(
                "tasks_per_meta_batch must be >= 1, got {}".format(tasks_per_meta_batch))

        # important objects
        self.tasks = tasks
        self.model = model
        # put the model on the GPU if one is available
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # meta-parameters, collected after .to() so they live on self.device
        self.weights = list(model.parameters())
        self.criterion = nn.MSELoss()
        self.meta_optimiser = torch.optim.Adam(self.weights, meta_lr)

        # hyperparameters
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.K = K
        self.inner_steps = inner_steps
        self.tasks_per_meta_batch = tasks_per_meta_batch
        self.first_order = first_order

        # metrics
        self.plot_every = 5
        self.print_every = 5
        self.meta_losses = []
        self.meta_mean_losses = []
        self.meta_sigma_losses = []

    def _sample_batch(self, task):
        """Draw ``K`` points from ``task`` and move them to ``self.device``."""
        X, y = task.sample_data(True, self.K)
        return X.to(self.device), y.to(self.device)

    def _task_losses(self, X, y, weights):
        """Return ``(total, mean_loss, sigma_loss)`` tensors for one batch.

        Both heads are evaluated with a single forward pass under ``weights``.
        """
        pred_mean, pred_sigma = self.model.parameterised(X, weights)
        mean_loss = self.criterion(pred_mean, y) / self.K
        # NOTE(review): the residual target is not detached, so this term also
        # sends gradient into the mean head. Use pred_mean.detach() if the
        # sigma head alone should be fitted to the residual.
        sigma_loss = self.criterion(pred_sigma, torch.abs(y - pred_mean)) / self.K
        return mean_loss + sigma_loss, mean_loss, sigma_loss

    def inner_loop(self, task):
        """Adapt to ``task`` and return its post-adaptation query loss.

        Returns:
            tuple ``(final_loss, mean_loss, sigma_loss)``. ``final_loss`` is a
            tensor still connected to ``self.weights`` (through the clones), so
            it can be differentiated for the meta-update. The other two entries
            are Python floats used for logging.
        """
        # reset the inner model to the current meta-parameters
        temp_weights = [w.clone() for w in self.weights]

        # adapt on a support batch sampled from the task
        X, y = self._sample_batch(task)
        for _ in range(self.inner_steps):
            support_loss, _, _ = self._task_losses(X, y, temp_weights)
            # Without create_graph the inner gradients are constants
            # (first-order MAML); with it, the meta-gradient includes
            # second-order terms.
            grads = torch.autograd.grad(support_loss, temp_weights,
                                        create_graph=not self.first_order)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grads)]

        # sample new data for the meta-update and compute the query loss
        X, y = self._sample_batch(task)
        final_loss, mean_loss, sigma_loss = self._task_losses(X, y, temp_weights)
        return (final_loss, mean_loss.item(), sigma_loss.item())

    def main_loop(self, num_iterations):
        """Run ``num_iterations`` meta-updates and record averaged losses.

        Per-task average losses are accumulated in a running window. That
        window's average is printed every ``print_every`` iterations. Every
        ``plot_every`` iterations it is appended to ``meta_losses``,
        ``meta_mean_losses`` and ``meta_sigma_losses``, and then reset.
        """
        epoch_loss = 0
        epoch_mean_loss = 0
        epoch_sigma_loss = 0
        # Iterations accumulated since the last reset. Dividing by this (not by
        # plot_every) keeps the printed average correct when print_every and
        # plot_every differ.
        window = 0

        for iteration in range(1, num_iterations + 1):

            # compute the meta loss over a batch of tasks
            meta_loss = 0
            meta_mean_loss = 0
            meta_sigma_loss = 0
            for _ in range(self.tasks_per_meta_batch):
                task = self.tasks.sample_task()
                task_loss, task_mean, task_sigma = self.inner_loop(task)
                meta_loss += task_loss
                meta_mean_loss += task_mean
                meta_sigma_loss += task_sigma

            # meta-gradient of the summed loss w.r.t. the meta-parameters
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign the meta-gradient (overwriting any old .grad) and step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            self.meta_optimiser.step()

            # log metrics (per-task averages)
            epoch_loss += meta_loss.item() / self.tasks_per_meta_batch
            epoch_mean_loss += meta_mean_loss / self.tasks_per_meta_batch
            epoch_sigma_loss += meta_sigma_loss / self.tasks_per_meta_batch
            window += 1

            if iteration % self.print_every == 0:
                print("{}/{}. loss: {}".format(iteration, num_iterations, epoch_loss / window))
                print("{}/{}. y_loss: {}".format(iteration, num_iterations, epoch_mean_loss / window))
                print("{}/{}. noise_loss: {}".format(iteration, num_iterations, epoch_sigma_loss / window))

            if iteration % self.plot_every == 0:
                self.meta_losses.append(epoch_loss / window)
                self.meta_mean_losses.append(epoch_mean_loss / window)
                self.meta_sigma_losses.append(epoch_sigma_loss / window)
                epoch_loss = 0
                epoch_mean_loss = 0
                epoch_sigma_loss = 0
                window = 0

In [3]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct  3 19:36:48 2019

@author: debjani
"""
############Import required functions##########################################
import torch
import numpy as np
from network_noisy import MAMLModel
from meta_training_noisy import MAML
from src.sine_tasks_uncertainty import Sine_Task_Distribution
###############################################################################
def main(num_iterations=10000, K=30, save_dir='models'):
    """Meta-train the two-headed sine model and save its weights.

    Args:
        num_iterations: number of meta-iterations to run.
        K: points per support / query batch, passed to ``MAML``.
        save_dir: directory the ``state_dict`` is written to; created if
            missing.

    Returns:
        The trained ``MAML`` instance, so later notebook cells can inspect
        ``meta_losses`` and the other recorded metrics.
    """
    import os

    # sample tasks
    # NOTE(review): the meaning of these positional arguments is defined by
    # Sine_Task_Distribution (src.sine_tasks_uncertainty) -- confirm there.
    tasks = Sine_Task_Distribution(0.1, 5, 0, np.pi, -5, 5, 0, np.pi)
    maml = MAML(MAMLModel(), tasks, inner_lr=0.001, meta_lr=0.001, K=K)
    maml.main_loop(num_iterations=num_iterations)

    # Save the model. The file name encodes K and the iteration count, so the
    # defaults give 'models/model_30_10000.pth'.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'model_{}_{}.pth'.format(K, num_iterations))
    torch.save(maml.model.state_dict(), save_path)
    return maml


if __name__ == '__main__':
    # Keep the trained meta-learner in the notebook namespace for later cells
    # (e.g. plotting maml.meta_losses).
    maml = main()

5/10000. loss: 0.4407075622558594
5/10000. y_loss: 0.2309282738971524
5/10000. noise_loss: 0.20977933550585295
10/10000. loss: 0.39687709350585937
10/10000. y_loss: 0.22014662838004298
10/10000. noise_loss: 0.17673051777255025
15/10000. loss: 0.37053208618164063
15/10000. y_loss: 0.2161868543075427
15/10000. noise_loss: 0.15434524910601322
20/10000. loss: 0.3653948364257813
20/10000. y_loss: 0.22141700086892815
20/10000. noise_loss: 0.1439776853163261
25/10000. loss: 0.3650126037597656
25/10000. y_loss: 0.22675484180497008
25/10000. noise_loss: 0.13825779327200727
30/10000. loss: 0.34609348754882807
30/10000. y_loss: 0.21802871153541345
30/10000. noise_loss: 0.1280648700384423
35/10000. loss: 0.3572537536621094
35/10000. y_loss: 0.22471961892640682
35/10000. noise_loss: 0.1325341299958527
40/10000. loss: 0.34173215332031254
40/10000. y_loss: 0.2168102829143012
40/10000. noise_loss: 0.12492186211533843
45/10000. loss: 0.327935595703125
45/10000. y_loss: 0.20983465414918903
45/10000. noi

360/10000. loss: 0.2902663513183594
360/10000. y_loss: 0.1923353200821206
360/10000. noise_loss: 0.09793103165533394
365/10000. loss: 0.2992430053710937
365/10000. y_loss: 0.19868253155224957
365/10000. noise_loss: 0.10056038821376859
370/10000. loss: 0.2844400207519531
370/10000. y_loss: 0.19005736813335677
370/10000. noise_loss: 0.09438267751801759
375/10000. loss: 0.2949103210449219
375/10000. y_loss: 0.19595385260423645
375/10000. noise_loss: 0.0989565235702321
380/10000. loss: 0.2967346801757812
380/10000. y_loss: 0.1969207351332996
380/10000. noise_loss: 0.09981397429872305
385/10000. loss: 0.3000840881347656
385/10000. y_loss: 0.19865505812466142
385/10000. noise_loss: 0.10142899267598986
390/10000. loss: 0.29992590942382813
390/10000. y_loss: 0.1984351046252996
390/10000. noise_loss: 0.10149067143108695
395/10000. loss: 0.29225361938476563
395/10000. y_loss: 0.19449964554701
395/10000. noise_loss: 0.09775387675873935
400/10000. loss: 0.3041015686035156
400/10000. y_loss: 0.2018

710/10000. loss: 0.29161337280273436
710/10000. y_loss: 0.1939239886483643
710/10000. noise_loss: 0.09768950376342982
715/10000. loss: 0.29093756713867186
715/10000. y_loss: 0.1926042590962723
715/10000. noise_loss: 0.09833325190208853
720/10000. loss: 0.29304679565429687
720/10000. y_loss: 0.19467510025659576
720/10000. noise_loss: 0.09837174670714885
725/10000. loss: 0.3001169860839844
725/10000. y_loss: 0.19860780515070073
725/10000. noise_loss: 0.1015091838981956
730/10000. loss: 0.3035155883789062
730/10000. y_loss: 0.20067188452258705
730/10000. noise_loss: 0.10284371464438738
735/10000. loss: 0.312658203125
735/10000. y_loss: 0.2074966427822597
735/10000. noise_loss: 0.10516165658608079
740/10000. loss: 0.30016772460937496
740/10000. y_loss: 0.19976110856840384
740/10000. noise_loss: 0.10040667170342057
745/10000. loss: 0.2990158142089844
745/10000. y_loss: 0.19808368604788557
745/10000. noise_loss: 0.10093219341430812
750/10000. loss: 0.2906202758789062
750/10000. y_loss: 0.192

KeyboardInterrupt: 

In [5]:
import matplotlib.pyplot as plt

# Requires a trained MAML instance named ``maml`` in the notebook namespace.
# In the recorded run it existed only as a local inside main(), hence the
# NameError shown below.
# One entry is appended to each metric list every ``plot_every`` iterations,
# so entry i corresponds to iteration (i + 1) * plot_every.
iterations = [(i + 1) * maml.plot_every for i in range(len(maml.meta_losses))]

fig, ax = plt.subplots()
ax.plot(iterations, maml.meta_losses, label="total loss")
ax.plot(iterations, maml.meta_mean_losses, label="y (mean) loss")
ax.plot(iterations, maml.meta_sigma_losses, label="noise (sigma) loss")
ax.set(xlabel="meta-iteration", ylabel="average per-task loss",
       title="MAML meta-training losses")
ax.legend()
plt.show()


NameError: name 'maml' is not defined