In [2]:
import math

from matplotlib import pyplot as plt
%matplotlib qt5

import torch
import torch.nn as nn

from torch.optim.rmsprop import RMSprop
# from torch.optim.optimizer import Optimizer
from torch.optim import Adam

from tropix_linear import TropixLinear



In [9]:
class Module1D(nn.Module):
    """Base class for 1-D modules that can draw themselves on a matplotlib axes.

    Subclasses implement ``forward``; ``plot`` evaluates it on an evenly spaced
    grid of inputs shaped as a column ``(num, 1)``.
    """

    def plot(self, ax, start, stop, num: int):
        """Plot this module's output over ``num`` evenly spaced points in ``[start, stop]``.

        The module is evaluated in eval mode and without gradient tracking. Its
        previous train/eval mode is restored afterwards, even if ``forward`` raises.

        Args:
            ax: axes object; only ``ax.plot(x, y, fmt)`` is called on it.
            start: first grid point.
            stop: last grid point.
            num: number of grid points.
        """
        was_training = self.training

        # eval() switches layers such as Dropout to inference behaviour for the plot.
        self.eval()
        try:
            with torch.no_grad():
                # Column vector (num, 1) so it can be fed to a layer with 1 input feature.
                y = torch.linspace(start, stop, num).view(-1, 1)
                # Call the module itself (not .forward) so registered hooks still run.
                z = self(y)
                ax.plot(y.numpy(), z.detach().numpy(), '-')
        finally:
            # Restore the caller's mode even when forward/plotting fails.
            self.train(was_training)


class TrueFunc(Module1D):
    """Ground-truth target that the network is trained to approximate.

    Elementwise: ``3 * y**2 - sin(5 * y) * (y - 0.2)``.
    (An earlier experiment used ``y ** 3`` as the target instead.)
    """

    dim_in = 1
    dim_out = 1

    def forward(self, y) -> torch.Tensor:
        parabola = 3 * y.pow(2)
        wobble = torch.sin(5 * y) * (y - 0.2)
        return parabola - wobble


class Func(Module1D):
    """MLP approximator built from ``TropixLinear`` layers.

    Layout: linear(dim_in -> hidden), ReLU, linear(hidden -> hidden), ReLU,
    Dropout, linear(hidden -> hidden), ReLU, Dropout, linear(hidden -> dim_out).
    """

    def __init__(self, dim_in, dim_out, hidden=50, dropout=0.5):
        """Build the network.

        Args:
            dim_in: number of input features.
            dim_out: number of output features.
            hidden: width of each hidden layer (default 50).
            dropout: dropout probability after the 2nd and 3rd ReLU (default 0.5).
        """
        super().__init__()

        # NOTE(review): the two positional booleans passed to TropixLinear are
        # defined in tropix_linear (not visible here) -- confirm their meaning.
        self.net = nn.Sequential(
            TropixLinear(dim_in, hidden, True, False),
            nn.ReLU(),
            TropixLinear(hidden, hidden, True, True),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            TropixLinear(hidden, hidden, True, True),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            TropixLinear(hidden, dim_out, True, False),
        )

        # Re-initialise every nn.Linear submodule. TropixLinear layers are only
        # affected if TropixLinear subclasses nn.Linear (not visible here).
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.2)
                # Layers created with bias=False have bias=None; constant_ would fail on it.
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

    def forward(self, y):
        """Run the input through the sequential stack."""
        return self.net(y)

    def switch_to_tropix(self, behave_as_tropix=True):
        """Set ``behave_as_tropix`` on every top-level ``TropixLinear`` in ``self.net``."""
        for layer in self.net:
            if isinstance(layer, TropixLinear):
                layer.behave_as_tropix = behave_as_tropix


In [10]:
# Seed torch's RNG so initialisation, dropout masks and input noise are reproducible.
torch.manual_seed(0)

N_ITERS = 2500      # optimisation steps
N_POINTS = 15       # training inputs sampled per step
PLOT_EVERY = 10     # redraw the live figure every PLOT_EVERY steps
USE_TROPIX = False        # put all TropixLinear layers in tropix mode before training
ALTERNATE_TROPIX = False  # after step 10, flip tropix mode on every step (even = on)

true_func = TrueFunc()
func = Func(TrueFunc.dim_in, TrueFunc.dim_out)
if USE_TROPIX:
    func.switch_to_tropix(True)

optimizer = Adam(func.parameters())

losses = []

fig = plt.figure(figsize=(12, 4), facecolor='white')  # type: plt.Figure
ax = fig.add_subplot(111)

for i in range(N_ITERS):
    y = torch.linspace(-1.0, 1.0, N_POINTS).view(-1, 1)
    # Input jitter whose scale decays from 1 at step 0 towards 0 by the end of the run.
    y = y + torch.randn_like(y) * (1 - math.tanh(10.0 * i / N_ITERS))
    z = true_func(y)

    if ALTERNATE_TROPIX and i > 10:
        func.switch_to_tropix(i % 2 == 0)

    func.train()
    z_ = func(y)
    # mse_loss(input, target): prediction first, ground truth second.
    loss = torch.nn.functional.mse_loss(z_, z)  # type: torch.Tensor
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % PLOT_EVERY == 0:
        ax.cla()
        true_func.plot(ax, -1.0, 1.0, 100)
        func.plot(ax, -1.0, 1.0, 100)
        ax.set_title(f'{i}  loss={losses[-1]:.4f}')

        plt.draw()
        plt.pause(0.05)



In [None]:
# Uncomment to inspect the trained network with its TropixLinear layers in tropix mode:
# func.switch_to_tropix(True)

fig, ax = plt.subplots()

TrueFunc().plot(ax, -1.0, 1.0, 100)
func.plot(ax, -1.0, 1.0, 100)

# Module1D.plot draws unlabeled lines; label them explicitly in plotting order.
ax.legend(ax.lines, ['true function', 'network'])
ax.set_xlabel('y')
ax.set_ylabel('z')
ax.set_title('Trained network vs. true function')

plt.draw()