In [1]:
import sys
import random
import datetime
import operator
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torch.utils.data import DataLoader
from IPython.display import display, HTML

# CSS injected into the page: sets the width of the notebook container,
# the menubar container and the main toolbar container elements.
_NOTEBOOK_LAYOUT_CSS = """
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>
"""
display(HTML(data=_NOTEBOOK_LAYOUT_CSS))


def print_message(s):
    """Print ``s`` preceded by the current local time in square brackets.

    The timestamp is rendered with the ``"%b %d, %H:%M:%S"`` layout and the
    output stream is flushed as part of the ``print`` call.
    """
    timestamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    line = "[{}] {}".format(timestamp, s)
    print(line, flush=True)

In [2]:
class OurFunc:
    """Hand-written chain of scalar ``tanh(w * x + b)`` layers with manual backprop.

    Each of the ``num_layers`` layers owns one weight (``ws[i]``) and one bias
    (``bs[i]``).  ``forward`` records every intermediate activation so that
    ``backward`` can compute the gradients of the squared error
    ``(y - output) ** 2`` and apply a plain gradient-descent update.
    """

    def __init__(self, num_layers, params = None):
        """Create the chain.

        Args:
            num_layers: number of ``tanh(w * x + b)`` layers.
            params: optional dict with keys ``'ws'`` and ``'bs'``, each a
                sequence of exactly ``num_layers`` numbers.  When omitted,
                weights and biases are drawn from ``random.uniform(-1, 1)``.

        Raises:
            AssertionError: if ``params`` is given with the wrong lengths.
        """
        self.num_layers = num_layers
        if params is None:
            self.ws = [random.uniform(-1, 1) for i in range(num_layers)]
            self.bs = [random.uniform(-1, 1) for i in range(num_layers)]
        else:
            assert len(params['ws']) == num_layers and len(params['bs']) == num_layers
            # Copy the sequences: backward() updates self.ws / self.bs in
            # place, which would otherwise silently rewrite the caller's data.
            self.ws = list(params['ws'])
            self.bs = list(params['bs'])
        # layer_outputs[0] is the input, layer_outputs[i + 1] the output of layer i.
        self.layer_outputs = []

    def forward(self, x):
        """Run ``x`` through every layer and return the final activation.

        All activations (including the input itself) are stored in
        ``self.layer_outputs`` for a subsequent ``backward`` call.
        """
        self.layer_outputs.clear()
        self.layer_outputs.append(x)
        for i in range(self.num_layers):
            x = np.tanh(x * self.ws[i] + self.bs[i])
            self.layer_outputs.append(x)
        return x

    def backward(self, y, lr):
        """Backpropagate the squared error against ``y`` and update the parameters.

        Must be called after ``forward``; it reads the activations recorded
        there and clears them before returning.

        Args:
            y: target value for the most recent ``forward`` input.
            lr: learning rate used for the update ``p -= lr * grad``.

        Returns:
            Dict ``{'ws': [...], 'bs': [...]}`` holding the gradient of the
            loss with respect to each weight and bias (computed with the
            parameter values from before the update).
        """
        outputs = self.layer_outputs
        n = self.num_layers
        grads_w = [0 for i in range(n)]
        grads_b = [0 for i in range(n)]

        # d/d(out) of (y - out)**2, evaluated at the final activation.
        upstream = -2 * (y - outputs[n])
        for i in range(n - 1, -1, -1):
            # tanh'(z) == 1 - tanh(z)**2, and outputs[i + 1] is tanh(z) for layer i.
            pre_act_grad = upstream * (1 - outputs[i + 1]**2)
            grads_w[i] = pre_act_grad * outputs[i]
            grads_b[i] = pre_act_grad
            # Gradient with respect to this layer's input feeds the layer below.
            upstream = pre_act_grad * self.ws[i]

        # Apply the update only after every gradient has been computed, so the
        # backward pass above always sees the pre-update weights.
        for i in range(n):
            self.ws[i] -= lr * grads_w[i]
            self.bs[i] -= lr * grads_b[i]
        self.layer_outputs.clear()
        return {'ws' : grads_w, 'bs' : grads_b}

    def get_params(self):
        """Return ``{'ws': ..., 'bs': ...}`` referencing the internal parameter lists (not copies)."""
        return {'ws' : self.ws, 'bs' : self.bs}

In [3]:
class DNNFunc(torch.nn.Module):
    """PyTorch model made of ``num_layers`` consecutive ``Linear(1, 1)`` + ``Tanh`` pairs."""

    def __init__(self, num_layers):
        """Build the ``nn.Sequential`` stack, stored as ``self.model``."""
        super().__init__()
        modules = []
        for _ in range(num_layers):
            modules.append(nn.Linear(1, 1))
            modules.append(nn.Tanh())
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the whole stack to ``x`` and return the result."""
        return self.model(x)

    def _collect(self, read_scalar):
        """Gather one Python number per trainable weight / bias parameter.

        ``read_scalar`` maps a parameter to the number to record.  Parameters
        whose name ends in ``'weight'`` go to ``'ws'``, those ending in
        ``'bias'`` go to ``'bs'``; anything else (or anything with
        ``requires_grad`` false) is skipped.
        """
        collected = {'ws': [], 'bs': []}
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if name.endswith('weight'):
                collected['ws'].append(read_scalar(param))
            elif name.endswith('bias'):
                collected['bs'].append(read_scalar(param))
        return collected

    def get_params(self):
        """Return the current weights and biases as ``{'ws': [...], 'bs': [...]}`` of Python floats."""
        return self._collect(lambda p: p.data[0].item())

    def get_grads(self):
        """Return the current ``.grad`` of each weight and bias as ``{'ws': [...], 'bs': [...]}``.

        Reads ``param.grad`` directly, so gradients must already be populated.
        """
        return self._collect(lambda p: p.grad.item())

In [4]:
class SimpleDataset(data.Dataset):
    """In-memory dataset of ``(x, true_func.forward(x))`` pairs.

    Inputs are drawn from ``random.uniform(-1, 1)``; both elements of every
    sample are float32 tensors of shape ``(1,)``.
    """

    def __init__(self, true_func, num_samples=500):
        """Generate all samples up front.

        Args:
            true_func: object exposing ``forward(x)`` that maps a Python float
                to the target value for that input.
            num_samples: how many ``(x, y)`` pairs to generate (default 500,
                the value that was previously hard-coded).
        """
        super(SimpleDataset, self).__init__()
        self.num_samples = num_samples
        # Draw every input first, then evaluate the targets, so the sequence
        # of random draws does not interleave with true_func.forward calls.
        xs = [random.uniform(-1, 1) for i in range(self.num_samples)]
        self.samples = [(self._as_tensor(x), self._as_tensor(true_func.forward(x))) for x in xs]

    @staticmethod
    def _as_tensor(value):
        """Wrap a single scalar as a float32 tensor of shape ``(1,)``."""
        return torch.from_numpy(np.asarray([value], dtype=np.float32))

    def __len__(self):
        """Number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` tensor pair stored at position ``idx``."""
        return self.samples[idx]

In [13]:
num_layers = 1
lr = 0.1


def _rounded_params(func):
    """Return ``func.get_params()`` with every value rounded to 3 decimals (for printing)."""
    return {k : [round(x, 3) for x in v] for k, v in func.get_params().items()}


# Target function: its randomly drawn parameters generate the training labels.
true_func = OurFunc(num_layers)
true_params = _rounded_params(true_func)
print('true params =                              ', true_params)

# The manual implementation starts from the PyTorch model's initial
# parameters, so both can be compared step by step.
dnn_func = DNNFunc(num_layers)
our_func = OurFunc(num_layers, params = dnn_func.get_params())

dnn_params = _rounded_params(dnn_func)
our_params = _rounded_params(our_func)
print('[before training]              dnn params = {} our params = {} '.format(dnn_params, our_params))

dataset = SimpleDataset(true_func)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)
criterion = nn.MSELoss()
optimizer = optim.SGD(dnn_func.parameters(), lr=lr)
dnn_func.train()
batch_idx = 0
for batch_idx, batch in enumerate(dataloader, start=1):
    xs, ys = batch[0], batch[1]

    # One SGD step on the PyTorch model.
    optimizer.zero_grad()
    out = dnn_func(xs)
    loss = criterion(out, ys)
    loss.backward()
    optimizer.step()

    # The same step on the hand-written implementation (batch size is 1,
    # so the single input / target can be passed as plain numbers).
    our_func.forward(xs.item())
    our_grads = our_func.backward(ys.item(), lr)

    if batch_idx % 50 != 0:
        continue
    dnn_params = _rounded_params(dnn_func)
    our_params = _rounded_params(our_func)
    print('[after batch {}] loss = {:.3f} dnn params = {} our params = {}'.format(batch_idx, loss.item(), dnn_params, our_params))

true params =                               {'bs': [0.766], 'ws': [-0.316]}
[before training]              dnn params = {'bs': [-0.4], 'ws': [0.366]} our params = {'bs': [-0.4], 'ws': [0.366]} 
[after batch 50] loss = 0.001 dnn params = {'bs': [0.732], 'ws': [-0.207]} our params = {'bs': [0.732], 'ws': [-0.207]}
[after batch 100] loss = 0.000 dnn params = {'bs': [0.746], 'ws': [-0.279]} our params = {'bs': [0.746], 'ws': [-0.279]}
[after batch 150] loss = 0.000 dnn params = {'bs': [0.763], 'ws': [-0.303]} our params = {'bs': [0.763], 'ws': [-0.303]}
[after batch 200] loss = 0.000 dnn params = {'bs': [0.765], 'ws': [-0.311]} our params = {'bs': [0.765], 'ws': [-0.311]}
[after batch 250] loss = 0.000 dnn params = {'bs': [0.766], 'ws': [-0.314]} our params = {'bs': [0.766], 'ws': [-0.314]}
[after batch 300] loss = 0.000 dnn params = {'bs': [0.766], 'ws': [-0.315]} our params = {'bs': [0.766], 'ws': [-0.315]}
[after batch 350] loss = 0.000 dnn params = {'bs': [0.766], 'ws': [-0.316]} our p