In [None]:
import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict
import math

In [None]:
class Multi_Layer_Perceptron(nn.Sequential):
    """Sequential stack of ``nn.Linear`` layers with a custom uniform initialisation.

    The layers are registered, in order, as:

    * ``"input"``  : ``input_dim  -> intern_dim``
    * ``"0"`` .. ``str(depth - 1)`` : ``intern_dim -> intern_dim`` (``depth`` layers)
    * ``"output"`` : ``intern_dim -> output_dim``

    Only linear layers are added; no activation module is inserted between them.

    Args:
        input_dim: Number of input features of the first layer.
        intern_dim: Width of every intermediate layer.
        output_dim: Number of output features of the last layer.
        depth: Number of ``intern_dim -> intern_dim`` layers between the
            input and output layers.
        isBiased: Whether every linear layer gets a bias term.
    """

    def __init__(self, input_dim, intern_dim, output_dim, depth = 2, isBiased = False):
        # Local name deliberately not `dict`, so the builtin is not shadowed.
        layers = OrderedDict([("input", nn.Linear(input_dim, intern_dim, bias=isBiased))])
        for i in range(depth):
            layers[str(i)] = nn.Linear(intern_dim, intern_dim, bias=isBiased)
        layers["output"] = nn.Linear(intern_dim, output_dim, bias=isBiased)

        super().__init__(layers)

        self.reset_init_weights_biases() # so that we do not use a default initialization

    def reset_init_weights_biases(self, norm = None):
        """Re-draw every layer's weights (and biases, if present) uniformly in ``[-stdv, stdv)``.

        Args:
            norm: Half-width ``stdv`` of the uniform distribution, shared by all
                layers.  When ``None``, each layer uses
                ``1 / sqrt(layer.weight.size(1))``, i.e. one over the square
                root of the second dimension of its weight matrix.
        """
        for layer in self.children():
            # `is None` (not `== None`) so that a numeric 0 is honoured as a bound.
            stdv = 1. / math.sqrt(layer.weight.size(1)) if norm is None else norm

            # nn.init.uniform_ fills the parameter in place without tracking gradients.
            nn.init.uniform_(layer.weight, -stdv, stdv)
            if layer.bias is not None:
                # The attribute is `bias`; `biases` does not exist on nn.Linear.
                nn.init.uniform_(layer.bias, -stdv, stdv)

In [None]:
def train(model, input_data, output_data, lossFct = nn.MSELoss(), optimizer = None, epochs = 20, init_norm = None, save = True, debug = False, savename='model.pt'):
    """Run a full-batch training loop on ``model``.

    Each epoch performs one forward pass on ``input_data``, evaluates
    ``lossFct`` against ``output_data`` and takes one optimizer step.  The loop
    stops early, after printing the epoch and loss, as soon as the loss is NaN.

    Args:
        model: Module to train.  It must expose ``parameters()`` and
            ``state_dict()``, and ``reset_init_weights_biases(norm)`` when
            ``init_norm`` is given.
        input_data: Input passed to ``model`` on every epoch.
        output_data: Target passed to ``lossFct`` together with the prediction.
        lossFct: Callable ``(prediction, target) -> scalar loss tensor``.
        optimizer: Optimizer to step; when ``None`` a ``torch.optim.SGD`` over
            ``model.parameters()`` with its default settings is created.
        epochs: Number of epochs to run.
        init_norm: When not ``None``, the model is re-initialised with
            ``model.reset_init_weights_biases(init_norm)`` before training.
        save: When true, ``model.state_dict()`` is written to
            ``DIRPATH + savename`` after every completed epoch.  ``DIRPATH`` is
            not defined in this function and must exist as a global.
        debug: Falsy disables progress output.  Otherwise the loss is printed
            every ``epochs // debug`` epochs (at least every epoch), so
            ``debug=True`` reports only the final epoch and ``debug=k``
            reports roughly ``k`` times.
        savename: File name appended to ``DIRPATH`` when saving.

    Returns:
        None.  The model is updated in place.
    """
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters())

    if init_norm is not None:
        model.reset_init_weights_biases(init_norm)

    # Use a whole-number reporting interval.  The previous `epochs / debug`
    # produced a fractional interval whenever `epochs` was not a multiple of
    # `debug`, so `(i + 1) % interval == 0` never held and nothing was printed.
    log_every = max(1, int(epochs // debug)) if debug else None

    for i in range(epochs):
        y_pred = model(input_data)
        loss = lossFct(y_pred, output_data)
        loss_value = loss.item()  # read once; reused for the NaN check and the report

        if math.isnan(loss_value):
            print(f"Epoch: {i+1}   Loss: {loss_value}")
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every is not None and (i+1) % log_every == 0:
            print(f"Epoch: {i+1}   Loss: {loss_value}")

        # NOTE(review): this writes the checkpoint on every epoch; kept as-is so
        # an interrupted run still leaves the latest state on disk.
        if save:
            torch.save(model.state_dict(), DIRPATH+savename)
    

In [1]:
# Execute the external script main_Sam.py from this notebook; the output
# recorded below this cell comes from that run.
# NOTE(review): the definition cells above carry no execution count, so the
# script presumably does not rely on them being run first — confirm.
%run main_Sam.py

Epoch: 1000   Loss: 3.208e+01
Clean observations
Model 1:
   - objective: 2.840e+01
   - weights norm: 9.72
Model 2:
   - objective: 3.208e+01
   - weights norm: 7.41
8.019290924072266
Epoch: 1000   Loss: 3.327e+01
Noisy observations
Model 1:
   - objective: 2.951e+01
   - weights norm: 10.02
Model 2:
   - objective: 3.327e+01
   - weights norm: 6.77
