In [None]:
!pip install biotorch



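LeNet on the MNIST dataset trained with Feedback Alignment (FA) / Direct Feedback Alignment (DFA). The cells below train the FA variant (`mode='fa'`) using biotorch.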

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


from biotorch.benchmark.run import Benchmark
from biotorch.module.biomodule import BioModule

In [None]:
# Initial hyperparameters and reproducibility settings.
n_epochs = 100          # number of training epochs handed to the Trainer
batch_size_train = 64   # mini-batch size of the training DataLoader
batch_size_test = 1000  # mini-batch size of the evaluation DataLoader
learning_rate = 0.01    # intended SGD learning rate
momentum = 0.5          # NOTE(review): the optimizer cell hard-codes momentum=0.9, so this value is not applied
log_interval = 10       # NOTE(review): not referenced by any cell shown here; logging is driven by metric_config

random_seed = 1
torch.backends.cudnn.enabled = True  # only matters on CUDA; the Trainer below runs with device='cpu'
# Seed torch's global RNG (weight init, DataLoader shuffling). The trailing ';'
# keeps the returned Generator object from being echoed as cell output.
torch.manual_seed(random_seed);

<torch._C.Generator at 0x78a4af183870>

In [None]:
# LeNet below flattens conv3's output straight into fc1(in_features=120), which only works
# for 32x32 inputs, so the MNIST digits are resized to 32x32 and then normalized with the
# commonly used MNIST mean/std. One shared transform keeps the two splits from drifting apart.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train,
    shuffle=True,
)

# Aggregate evaluation metrics don't depend on sample order, so the evaluation
# loader is not shuffled: this keeps evaluation deterministic across runs.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test,
    shuffle=False,
)

print("Train size: ", len(train_loader.dataset))
print("Test size: ", len(test_loader.dataset))

Train size:  60000
Test size:  10000


In [None]:
class LeNet(nn.Module):
    """LeNet-5-style CNN for 32x32 images.

    Layer attribute names (conv1..conv3, relu1..relu4, pool1, pool2, fc1, fc2) and their
    registration order are kept fixed so that BioModule's layer conversion and the
    state_dict keys stay the same as before.

    Args:
        num_classes: number of output logits (default 10, the MNIST digit classes).
        in_channels: number of input image channels (default 1, grayscale).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        # 32x32 -> conv 5x5 -> 28x28 -> max-pool 2 -> 14x14
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        # 14x14 -> conv 5x5 -> 10x10 -> max-pool 2 -> 5x5
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # 5x5 -> conv 5x5 -> 1x1, so the flattened feature vector has exactly 120 entries
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1)
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch; the layer arithmetic requires shape (N, in_channels, 32, 32).

        Returns:
            Un-normalized logits of shape (N, num_classes).
        """
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = self.relu3(self.conv3(out))
        out = torch.flatten(out, 1)  # (N, 120, 1, 1) -> (N, 120)
        out = self.relu4(self.fc1(out))
        return self.fc2(out)

In [None]:
# Build a freshly initialized LeNet, then let biotorch convert its Conv2d/Linear layers
# to feedback-alignment ('fa') layers; output_dim mirrors the 10 logits produced by fc2.
lenet = LeNet()
lenet_fa = BioModule(
    lenet,
    mode='fa',
    output_dim=10,
)
print(lenet_fa)

Module has been converted to fa mode:

The layer configuration was:  {'type': 'fa', 'options': {'constrain_weights': False, 'gradient_clip': False, 'init': 'xavier'}}
- All the 3 <class 'torch.nn.modules.conv.Conv2d'> layers were converted successfully.
- All the 2 <class 'torch.nn.modules.linear.Linear'> layers were converted successfully.
BioModule(
  (module): LeNet(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (relu1): ReLU()
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (relu2): ReLU()
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
    (relu3): ReLU()
    (fc1): Linear(in_features=120, out_features=84, bias=True)
    (relu4): ReLU()
    (fc2): Linear(in_features=84, out_features=10, bias=True)
  )
)


In [None]:
from torch.optim import SGD

# SGD with momentum over the FA-converted model's parameters.
# - lr reuses the config cell's `learning_rate` (0.01) instead of repeating the literal.
# - weight_decay=1e-2 is the same value as the previous `10e-3`, written unambiguously
#   (10e-3 == 0.01). NOTE(review): double-check that 1e-3 was not the intended value.
# - NOTE(review): momentum=0.9 differs from the config cell's `momentum` (0.5). The logged
#   training loss below diverges to ~1e9 within epoch 0, so confirm which value is intended.
optimizer = SGD(lenet_fa.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-2)

In [None]:
from torch.optim.lr_scheduler import MultiStepLR

# Decay the learning rate 10x at 50% and 75% of training. The milestones are derived
# from n_epochs so the schedule stays proportional if the epoch count changes;
# n_epochs = 100 reproduces the original milestones [50, 75].
# Assumes the Trainer steps the scheduler once per epoch. TODO: confirm in biotorch.
lr_schedule = MultiStepLR(
    optimizer,
    milestones=[n_epochs // 2, (3 * n_epochs) // 4],
    gamma=0.1,
)

In [None]:
from torch.nn import CrossEntropyLoss

# Classification objective applied to the model's raw logits (log-softmax followed by
# negative log-likelihood), averaged over each mini-batch.
loss = CrossEntropyLoss(reduction='mean')

In [None]:
# Metric/logging options handed to biotorch's Trainer via `metrics_config`.
metric_config = {
    'top_k': 5,                 # also report top-5 accuracy (the "Acc@5" column in the training log)
    'display_iterations': 100,  # print a progress line every 100 training iterations
    'layer_alignment': False,   # disabled: no per-layer alignment tracking (TODO confirm exact biotorch semantics)
    'weight_ratio': False,      # disabled: no weight-ratio tracking (TODO confirm exact biotorch semantics)
}

In [None]:
from biotorch.training.trainer import Trainer

# Training driver: trains `lenet_fa` in FA mode on the CPU for n_epochs epochs and validates
# on test_loader; output_dir presumably receives logs/checkpoints (TODO confirm).
# NOTE(review): the MNIST test split doubles as the validation set here, so any model selection
# based on validation metrics leaks into the test accuracy reported at the end.
trainer = Trainer(
    # model and learning rule
    model=lenet_fa,
    mode='fa',
    # optimization
    loss_function=loss,
    optimizer=optimizer,
    lr_scheduler=lr_schedule,
    epochs=n_epochs,
    # data
    train_dataloader=train_loader,
    val_dataloader=test_loader,
    # runtime and bookkeeping
    device='cpu',
    output_dir='./output',
    metrics_config=metric_config,
)


In [None]:
# Run the full training loop configured on the Trainer (epochs=n_epochs, device='cpu').
# NOTE(review): the saved log below is truncated partway through epoch 0.
trainer.run()

Epoch: [0][  0/938]	Time  0.360 ( 0.360)	Data  0.128 ( 0.128)	Loss 2.2958e+00 (2.2958e+00)	Acc@1  14.06 ( 14.06)	Acc@5  51.56 ( 51.56)
Epoch: [0][100/938]	Time  0.036 ( 0.043)	Data  0.018 ( 0.021)	Loss 1.2646e+05 (7.3957e+03)	Acc@1   7.81 ( 17.13)	Acc@5  54.69 ( 59.41)
Epoch: [0][200/938]	Time  0.039 ( 0.041)	Data  0.020 ( 0.021)	Loss 9.4318e+08 (1.6610e+08)	Acc@1   6.25 ( 13.53)	Acc@5  46.88 ( 54.70)
Epoch: [0][300/938]	Time  0.038 ( 0.043)	Data  0.019 ( 0.022)	Loss 2.7773e+09 (1.2701e+09)	Acc@1   6.25 ( 12.66)	Acc@5  75.00 ( 53.52)
Epoch: [0][400/938]	Time  0.040 ( 0.042)	Data  0.020 ( 0.021)	Loss 2.2509e+09 (1.9330e+09)	Acc@1  25.00 ( 13.92)	Acc@5  71.88 ( 56.57)
Epoch: [0][500/938]	Time  0.038 ( 0.042)	Data  0.020 ( 0.021)	Loss 4.4632e+07 (1.7195e+09)	Acc@1  34.38 ( 16.98)	Acc@5  81.25 ( 61.77)
Epoch: [0][600/938]	Time  0.040 ( 0.043)	Data  0.020 ( 0.022)	Loss 5.8710e+06 (1.4359e+09)	Acc@1  32.81 ( 18.41)	Acc@5  81.25 ( 63.82)
Epoch: [0][700/938]	Time  0.038 ( 0.042)	Data  0.020 ( 

In [None]:
from biotorch.training.functions import test

# Final evaluation of the trained FA model on the MNIST test split. top_k is taken from
# metric_config so the reported Acc@k matches the training logs. Judging from the displayed
# output, this returns (top-1 accuracy, presumably the average loss).
test(lenet_fa, loss, test_loader, device='cpu', top_k=metric_config['top_k'])

 * Acc@1 78.900 Acc@5 97.800


(tensor(78.9000), 1.0849393248558044)