In [1]:
import math
import torch

### Create some fake data to fit

In [2]:
batch_size = 400
input_dim = 200
output_dim = 10

x = torch.randn(batch_size, input_dim).cuda()
y = torch.randn(batch_size, output_dim).cuda()

### Construct residual-MLP module

In [3]:
block_depth = 2
num_blocks = 10
width = 100

from module.atomic import Identity, Linear, ScaledReLU

# create residual block
residue = Linear(width, width) @ ScaledReLU()
residue = residue ** block_depth
block = (1-1/num_blocks) * Identity() + 1/num_blocks * residue

# create whole network
first_layer = Linear(width, input_dim)
last_layer = Linear(output_dim, width)
net = last_layer @ block ** num_blocks @ first_layer

print(net)

Module of mass 22 and sensitivity 1.0.


### Run modular gradient descent

In [4]:
init_lr = 0.5
beta = 0.9
wd = 0.01
steps = 1000

net.initialize(device="cuda")

from tqdm.auto import trange
for i in (pbar := trange(steps)):
    out = net.forward(x)
    loss = (out-y).square().mean()
    loss.backward()
        
    net.update(init_lr * (1 - i / steps), beta, wd)
                
    pbar.set_description(f"loss: {loss.item():.4f}")

print(f"Final loss: {loss.item()}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Final loss: 1.1355193407780462e-07
