### Goal
This code example shows how to train a model with an L1 or L2 penalty on all or some of its weights. Regularization is particularly important when training on small datasets, where models are more likely to overfit the training data. Weight regularization can help prevent overfitting and increase the likelihood that the model generalizes to new, unseen data.

In [5]:
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import devtorch
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
class ANNClassifier(devtorch.DevModel):
    """Two-layer fully connected classifier built from bias-free linear layers.

    Args:
        n_in: number of input features after flattening each sample.
        n_hidden: width of the hidden layer.
        n_out: number of output scores (one per class).
    """

    def __init__(self, n_in, n_hidden, n_out):
        super().__init__()
        self.layer1 = nn.Linear(n_in, n_hidden, bias=False)
        self.layer2 = nn.Linear(n_hidden, n_out, bias=False)
        # init_weight is inherited from devtorch.DevModel (it is not defined here).
        self.init_weight(self.layer1.weight, "glorot_uniform")
        self.init_weight(self.layer2.weight, "glorot_uniform")

    def forward(self, x):
        """Map a batch of samples to one score per output unit.

        Every dimension after the batch dimension is flattened, so this accepts
        image batches such as (batch, channels, height, width) as well as
        already-flattened (batch, features) input. The previous
        ``flatten(1, 3)`` only worked for exactly 4-D input.
        """
        x = F.leaky_relu(self.layer1(x.flatten(start_dim=1)))
        # NOTE(review): the output layer is also passed through leaky_relu before
        # F.cross_entropy consumes it as logits; raw logits are the usual choice.
        # Kept as-is to preserve the results reported in this notebook.
        return F.leaky_relu(self.layer2(x))

    def get_params_to_regularize(self):
        """Return the weight tensors the loss function should penalize.

        The trainer's loss queries this list; add any further parameters you
        want regularized.
        """
        return [self.layer1.weight, self.layer2.weight]

In [3]:
# 28 x 28 MNIST images flattened to 784 inputs, a 4000-unit hidden layer, 10 digit classes.
model = ANNClassifier(28 * 28, 4000, 10)

# Convert images to tensors, then normalize with mean 0.1307 / std 0.3081
# (the values commonly used for MNIST).
mnist_root = "../../data"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(mnist_root, train=False, download=True, transform=transform)

# Strength of the weight penalty added to the classification loss.
regulirization_lambda = 10 ** -6

def loss(output, target, model, reg_lambda=None, p=1):
    """Return cross-entropy plus a weight penalty on the model's chosen parameters.

    The penalty is ``reg_lambda * sum(|w| ** p)``, summed over every tensor
    returned by ``model.get_params_to_regularize()``.
    - ``p=1`` (default) gives the L1 norm of each tensor, the same value the
      earlier ``torch.norm(param, p=1)`` produced.
    - ``p=2`` gives the *squared* L2 norm, the conventional L2 / weight-decay
      penalty. ``torch.norm(param, p=2)`` would give the unsquared norm instead.

    Args:
        output: model outputs, passed to ``F.cross_entropy`` as logits.
        target: class indices; cast to ``long`` as ``F.cross_entropy`` requires.
        model: object exposing ``get_params_to_regularize()``.
        reg_lambda: penalty strength. When ``None``, the module-level
            ``regulirization_lambda`` is used.
        p: exponent applied to each absolute weight (1 for L1, 2 for L2).

    Returns:
        Scalar tensor: classification loss plus the weighted penalty.
    """
    # Conditional expression: the global is only looked up when no explicit value is given.
    strength = regulirization_lambda if reg_lambda is None else reg_lambda
    classification_loss = F.cross_entropy(output, target.long())
    regularization_loss = sum(
        param.abs().pow(p).sum() for param in model.get_params_to_regularize()
    )
    return classification_loss + strength * regularization_loss

# Use the GPU when one is present; otherwise fall back to the CPU so the example
# still runs. NOTE(review): assumes devtorch.get_trainer accepts "cpu" as a device.
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = devtorch.get_trainer(loss, model=model, train_dataset=train_dataset, n_epochs=10, batch_size=128, lr=0.001, device=device)
trainer.train()

INFO:trainer:Completed epoch 0 with loss 161.70190712809563 in 7.9183s
INFO:trainer:Completed epoch 1 with loss 66.76015388965607 in 7.8260s
INFO:trainer:Completed epoch 2 with loss 49.29641507565975 in 7.8815s
INFO:trainer:Completed epoch 3 with loss 42.48477016761899 in 7.8142s
INFO:trainer:Completed epoch 4 with loss 40.64281286671758 in 7.8393s
INFO:trainer:Completed epoch 5 with loss 41.481686882674694 in 7.8129s
INFO:trainer:Completed epoch 6 with loss 40.89248041063547 in 7.8094s
INFO:trainer:Completed epoch 7 with loss 36.84841175749898 in 7.8102s
INFO:trainer:Completed epoch 8 with loss 35.3649190030992 in 7.8085s
INFO:trainer:Completed epoch 9 with loss 33.943525440990925 in 7.8092s


In [4]:
def eval_metric(output, target):
    """Count the rows whose highest-scoring index equals the target label.

    Returns a plain Python int, so per-batch counts can be summed directly.
    """
    _, predicted = torch.max(output, dim=1)
    n_correct = predicted.eq(target).sum()
    return n_correct.item()

scores = devtorch.compute_metric(model, test_dataset, eval_metric, batch_size=256)
# Presumably `scores` holds the per-batch correct counts from eval_metric (ints).
# Sum them in Python: going through torch.Tensor(...) rounds via float32 and
# prints noise such as 0.9786999821662903 instead of 0.9787.
accuracy = sum(scores) / len(test_dataset)
print(f"Accuracy = {accuracy}")

Accuracy = 0.9786999821662903
