In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [9]:
# Training Loop 

def train(net, X, Y, loss_f, opt, iters=100, pp=10):
    """Run full-batch gradient steps on ``net`` and track how many were taken.

    Every iteration feeds the whole of ``X`` through ``net``, scores the
    prediction against ``Y`` with ``loss_f``, back-propagates and lets ``opt``
    update the parameters. Training stops early once the loss is exactly 0.

    Args:
        net: Module to optimise. It must expose an ``iters_trained_for``
            attribute, which is incremented by the number of iterations
            actually executed (not by ``iters`` when stopping early).
        X: Input passed to ``net`` on every iteration.
        Y: Target passed to ``loss_f`` together with the prediction.
        loss_f: Callable ``loss_f(prediction, Y)`` returning a scalar loss.
        opt: Optimiser providing ``step()`` and ``zero_grad()``.
        iters: Maximum number of iterations to run.
        pp: Print the loss every ``pp`` iterations; ``0`` or ``None`` disables
            the progress output.

    Returns:
        The loss computed in the last executed iteration, or ``None`` when
        ``net`` lacks ``iters_trained_for`` or when no iteration ran.
    """
    # Guard clause: keep the original "print and return None" contract.
    if not hasattr(net, 'iters_trained_for'):
        print("Error! Make sure the network has an 'iters_trained_for' attribute")
        return None

    net.train(True)

    # Discard gradients left on the parameters by earlier code so that they
    # cannot leak into the first optimiser step.
    opt.zero_grad()

    l = None          # stays None if the loop body never executes (iters <= 0)
    steps_done = 0    # iterations actually performed, for the counter below

    for i in range(iters):
        y = net(X)
        l = loss_f(y, Y)
        l.backward()

        opt.step()
        opt.zero_grad()
        steps_done = i + 1

        # `pp and ...` avoids a ZeroDivisionError when printing is disabled.
        if pp and steps_done % pp == 0:
            print(f'Loss at epoch {steps_done}: {l}\n')

        # Nothing left to learn once the loss is exactly zero.
        if l == 0:
            break

    # Count only what was really run; an early break must not be billed as a
    # full `iters` iterations.
    net.iters_trained_for += steps_done
    return l

In [44]:
class SadNet(nn.Module):
    """A single bias-free 10 -> 10 linear layer with no skip connection.

    The instance also carries the ``iters_trained_for`` counter that the
    ``train`` helper checks for and increments.
    """

    def __init__(self):
        super(SadNet, self).__init__()

        # The only learnable parameter: a 10x10 weight matrix, no bias term.
        self.lin = nn.Linear(in_features=10, out_features=10, bias=False)

        # Running total of training iterations; starts at zero.
        self.iters_trained_for = 0

    def forward(self, X):
        """Return the linear projection of ``X``."""
        projected = self.lin(X)
        return projected

In [89]:
# Seed the global RNG first: both the data drawn below and the initial weights
# of the network depend on it, in this order.
torch.manual_seed(118)

# 20 samples drawn from a standard normal, each of shape (1, 10).
X = torch.randn(20, 1, 10)
Y = X  # the target is the very same tensor: the task is to learn the identity map

sadNet = SadNet()
loss_f = nn.MSELoss()
opt = optim.Adadelta(sadNet.parameters(), lr=0.1)

In [96]:
# Train for up to 1000 iterations, printing the loss every 100.
# This cell is not idempotent: sadNet and opt keep their state between runs, so
# executing it again resumes training from the current weights and increases
# sadNet.iters_trained_for further.
loss = train(sadNet, X, Y, loss_f, opt, iters=1000, pp=100)

Loss at epoch 100: 4.583590885315347e-12

Loss at epoch 200: 4.3782573053152696e-12

Loss at epoch 300: 4.289868373724692e-12

Loss at epoch 400: 4.210805315540966e-12

Loss at epoch 500: 4.106787392793576e-12

Loss at epoch 600: 4.080192780864245e-12

Loss at epoch 700: 4.067870606333512e-12

Loss at epoch 800: 4.078000524071479e-12

Loss at epoch 900: 4.0424356570478714e-12

Loss at epoch 1000: 4.061495497559298e-12



In [97]:
# Cumulative iteration counter maintained by train(); it keeps growing across
# repeated executions of the training cell.
sadNet.iters_trained_for

6000

In [77]:
class MicroResNet(nn.Module):
    """A bias-free 10 -> 10 linear layer wrapped in a skip connection.

    The output is ``lin(X) + X``, so an all-zero weight matrix makes the
    module an identity map. The instance also carries the
    ``iters_trained_for`` counter that the ``train`` helper checks for and
    increments.
    """

    def __init__(self):
        super(MicroResNet, self).__init__()

        # Learnable residual branch: a 10x10 weight matrix, no bias term.
        self.lin = nn.Linear(in_features=10, out_features=10, bias=False)

        # Running total of training iterations; starts at zero.
        self.iters_trained_for = 0

    def forward(self, X):
        """Return the residual branch output added to the unchanged input."""
        residual = self.lin(X)
        return residual + X

In [78]:
# Re-seed the global RNG. No data is drawn in this cell, so here the seed only
# affects the weight initialisation of the network created below.
torch.manual_seed(118)

# The X and Y tensors defined in the SadNet setup cell are reused for this
# experiment; the disabled lines below would have created a separate copy.
# X1 = torch.randn(20, 1, 10)
# Y1 = X1 

mResNet = MicroResNet()
loss_f = nn.MSELoss()
# NOTE(review): this run uses Adam while the SadNet run uses Adadelta, so the two
# loss traces differ in optimiser as well as in architecture.
opt1 = optim.Adam(mResNet.parameters(), lr=0.1)

In [83]:
# Train the residual variant on the same X and Y for up to 1000 iterations,
# printing the loss every 100. As with the SadNet training cell, re-running
# this cell continues from the current weights and optimiser state.
loss1 = train(mResNet, X, Y, loss_f, opt1, iters=1000, pp=100)

Loss at epoch 100: 0.0006777478265576065

Loss at epoch 200: 0.0001802691404009238

Loss at epoch 300: 0.0008175352704711258

Loss at epoch 400: 1.369751407764852e-05

Loss at epoch 500: 8.7250693468377e-05

Loss at epoch 600: 0.0003115302824880928

Loss at epoch 700: 0.00036252086283639073

Loss at epoch 800: 2.193046748288907e-05

Loss at epoch 900: 9.196352266371832e-07

Loss at epoch 1000: 0.00045749536366201937

