In [None]:
import torch

# CrossEntropyLoss consumes raw (unnormalized) class scores directly.
criterion = torch.nn.CrossEntropyLoss()

# Two samples (rows), each with a score for three classes (columns).
f = torch.tensor([
    [-1.0, -3.0, 4.0],
    [-3.0, 3.0, -1.0],
])
# Correct class index for each row of `f`.
target = torch.tensor([0, 1])

# Mean loss over the two samples; shown as the cell output.
criterion(f, target)

tensor(2.5141)

In [None]:
from torch import nn
# Same loss as above, split into its two steps: log-softmax over the class
# dimension (dim 1), followed by the negative log-likelihood of the targets.
model = nn.LogSoftmax(dim=1)
criterion = nn.NLLLoss()

log_probs = model(f)
criterion(log_probs, target)

tensor(2.5141)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small convolutional classifier producing 10 raw scores per sample.

    Two (convolution -> max-pool -> ReLU) stages are followed by two linear
    layers. The first linear layer takes 256 features per sample, which is what
    a 1x28x28 input yields after both stages (64 channels of size 2x2).
    No softmax is applied to the output.
    """

    def __init__(self):
        super().__init__()
        # Convolutional stages: 1 -> 32 -> 64 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        # Fully connected head: 256 -> 200 -> 10.
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        """Map a batch `x` to a (batch, 10) tensor of class scores."""
        # Stage 1: convolution, 3x3 max-pooling, ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 3))
        # Stage 2: convolution, 2x2 max-pooling, ReLU.
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), 2))
        # Collapse everything except the batch dimension.
        batch_size = stage2.size(0)
        flat = stage2.view(batch_size, -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [42]:
import torchvision

# Use the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load MNIST and reshape the images to (N, 1, 28, 28) float tensors.
train_set = torchvision.datasets.MNIST(root = 'train/', train = True, download = True)
train_input = train_set.data.view(-1, 1, 28, 28).float()
train_targets = train_set.targets

test_set = torchvision.datasets.MNIST(root = 'test/', train = False, download = True)
test_input = test_set.data.view(-1, 1, 28, 28).float()
test_targets = test_set.targets

# Hyper-parameters. Note that lr is halved at epoch 0 as well (see the schedule
# in the loop), so the first epochs effectively train with lr / 2.
# lambda_l1 is only used by the optional L1 step, which is currently disabled.
lr, nb_epochs, batch_size, lambda_l2, lambda_l1 = 2e-1, 50, 100, 0., 0.00001

model = Net()
criterion = nn.CrossEntropyLoss()

model.to(device)
criterion.to(device)
train_input, train_targets = train_input.to(device), train_targets.to(device)
test_input, test_targets = test_input.to(device), test_targets.to(device)

# Standardize both splits in place using the training-set mean and std.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)

# Build the optimizer once; the schedule below only rewrites its learning
# rate, so any optimizer state is preserved across epochs.
optimizer = torch.optim.SGD(model.parameters(), lr = lr)

for e in range(nb_epochs):
    # Halve the learning rate every 5 epochs, starting with epoch 0.
    if e % 5 == 0:
        lr /= 2
        for group in optimizer.param_groups:
            group['lr'] = lr

    model.train()
    for batch_input, batch_targets in zip(train_input.split(batch_size), train_targets.split(batch_size)):
        output = model(batch_input)
        loss = criterion(output, batch_targets)

        # L2 regularization: add lambda_l2 * ||p||^2 for every parameter.
        # Skipped entirely when the coefficient is zero, since the penalty
        # would then contribute nothing to the loss or its gradient.
        if lambda_l2 != 0:
            for p in model.parameters():
                loss = loss + lambda_l2 * p.pow(2).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Optional L1 regularization (disabled): after the gradient step, move
        # every weight towards zero by at most lambda_l1 without crossing zero.
        # with torch.no_grad():
        #     for p in model.parameters():
        #         p.sub_(p.sign() * p.abs().clamp(max = lambda_l1))

    # End-of-epoch report on the full training and test sets (no gradients).
    # The test split is what the printed "validation" figures refer to.
    model.eval()
    with torch.no_grad():
        train_output = model(train_input)
        train_loss = criterion(train_output, train_targets)

        test_output = model(test_input)
        test_preds = torch.argmax(test_output, dim=1)

        test_loss = criterion(test_output, test_targets)
        # A non-zero difference between prediction and target is an error.
        test_errors = torch.count_nonzero(test_preds - test_targets)
        test_accuracy = 1 - float(test_errors) / test_targets.size(0)

        # Number of parameters that are not exactly zero.
        non_zero_p = sum(int(p.count_nonzero()) for p in model.parameters())

        print(f"Epoch {e}:, training loss = {train_loss}, validation_loss = {test_loss}, validation_accuracy = {test_accuracy:2F}, non_zero_p = {non_zero_p}")

Epoch 0:, training loss = 0.0802069678902626, validation_loss = 0.07143258303403854, validation_accuracy = 0.976700, non_zero_p = 105506
Epoch 1:, training loss = 0.046900298446416855, validation_loss = 0.04377318546175957, validation_accuracy = 0.985700, non_zero_p = 105506
Epoch 2:, training loss = 0.03512101620435715, validation_loss = 0.03788353130221367, validation_accuracy = 0.987700, non_zero_p = 105506
Epoch 3:, training loss = 0.02483223006129265, validation_loss = 0.031930096447467804, validation_accuracy = 0.989200, non_zero_p = 105506
Epoch 4:, training loss = 0.019239705055952072, validation_loss = 0.029932910576462746, validation_accuracy = 0.990200, non_zero_p = 105506
Epoch 5:, training loss = 0.014487617649137974, validation_loss = 0.02779056318104267, validation_accuracy = 0.990800, non_zero_p = 105506
Epoch 6:, training loss = 0.012567013502120972, validation_loss = 0.027384541928768158, validation_accuracy = 0.991100, non_zero_p = 105506
Epoch 7:, training loss = 0.