In [63]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import random_split, TensorDataset, DataLoader

In [64]:
# Reproducibility and device configuration shared by every cell below.
seed = 123
dtype = torch.float32

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# New tensors and module parameters are created on `device` from here on.
torch.set_default_device(device)

# Seed the global RNG, plus a dedicated generator living on `device` that the
# data split and the loaders draw from.
torch.manual_seed(seed)
generator = torch.Generator(device=device)
generator.manual_seed(seed)

In [65]:
# Download (if not already cached under ./cnn_dataset) and load the MNIST
# train and test splits, converting each image to a tensor.
dataset = torchvision.datasets.MNIST(root='./cnn_dataset',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)
test_dataset = torchvision.datasets.MNIST(root='./cnn_dataset',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# 80/20 train/validation split. The validation size is taken as the remainder
# so the two lengths always sum to len(dataset), which random_split requires;
# truncating both fractions independently can drop a sample.
total_size = len(dataset)
train_size = int(total_size * 0.8)
val_size = total_size - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)

# Single place to tune the mini-batch size used by all three loaders.
batch_size = 32

# The shared `generator` drives the split above and the shuffling of every
# loader below.
# NOTE(review): shuffling is not needed for the validation/test loaders; it is
# kept so the shared generator is consumed exactly as before and earlier runs
# stay comparable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, generator=generator)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, generator=generator)

In [66]:
class LeNet(nn.Module):
    """LeNet-style convolutional classifier assembled from lazily-shaped layers.

    Two conv -> sigmoid -> average-pool stages are followed by three fully
    connected layers. Input channel counts and the flattened feature size are
    not specified here; the ``Lazy*`` layers infer them on the first forward
    pass.

    Args:
        num_classes: number of output units of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutional feature extractor.
        feature_extractor = [
            nn.LazyConv2d(6, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]

        # Fully connected head operating on the flattened feature maps.
        classifier = [
            nn.Flatten(),
            nn.LazyLinear(120),
            nn.Sigmoid(),
            nn.LazyLinear(84),
            nn.Sigmoid(),
            nn.LazyLinear(num_classes),
        ]

        # Kept as one flat Sequential so sub-module indices (and therefore
        # state_dict keys such as "layers.0.weight") are unchanged.
        self.layers = nn.Sequential(*feature_extractor, *classifier)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the final layer's output."""
        scores = self.layers(x)
        return scores

In [67]:
# Number of full passes over the training loader performed by train().
epoch = 20
# train() prints the averaged train/validation losses once every
# `print_per_epoch` epochs.
print_per_epoch = 2

def accuracy(model, test_loader):
    """Return the fraction of samples in ``test_loader`` that ``model`` classifies correctly.

    The model is evaluated in evaluation mode with gradients disabled; the
    training/evaluation mode it had on entry is restored before returning.
    The predicted class of a sample is the index of the largest model output
    along dimension 1.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores.
        test_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: ``correct / total`` over all samples seen, or ``nan`` when the
        loader yields no samples (instead of raising ``ZeroDivisionError``).
    """
    correct = 0
    total = 0

    # Do not rely on the caller having left the model in evaluation mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                predicted = model(inputs).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        # Hand the model back in the mode it arrived in.
        model.train(was_training)

    if total == 0:
        return float('nan')
    return correct / total

def eval(model, loss_fn, val_loader):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and left in that mode on return;
    gradients are disabled while iterating.

    NOTE: this name shadows Python's built-in ``eval`` inside the notebook. It
    is kept because ``train`` calls the function by this name.

    Args:
        model: ``nn.Module`` to evaluate.
        loss_fn: callable ``loss_fn(predictions, labels)`` returning a scalar tensor.
        val_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: sum of the batch losses divided by the number of batches
        iterated, or ``nan`` when the loader yields no batches (instead of
        raising ``ZeroDivisionError``).
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            y_pred = model(inputs)
            running_loss += loss_fn(y_pred, labels).item()
            # Count what was actually iterated rather than relying on len().
            num_batches += 1

    if num_batches == 0:
        return float('nan')
    return running_loss / num_batches


def train(model, loss_fn, optimizer, train_loader, val_loader,
          num_epochs=None, print_every=None):
    """Train ``model`` and periodically print the train/validation losses.

    Each epoch makes one pass over ``train_loader`` (zero gradients, forward,
    loss, backward, optimizer step), then computes the validation loss with
    ``eval``.

    Args:
        model: ``nn.Module`` to optimise in place.
        loss_fn: callable ``loss_fn(predictions, labels)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        train_loader: iterable yielding ``(inputs, labels)`` training batches.
        val_loader: iterable of validation batches, forwarded to ``eval``.
        num_epochs: number of epochs; defaults to the notebook-level ``epoch``.
        print_every: print the losses every this many epochs; defaults to the
            notebook-level ``print_per_epoch``.

    Returns:
        None. Losses are reported on stdout.
    """
    # Fall back to the notebook-level schedule unless the caller overrides it.
    if num_epochs is None:
        num_epochs = epoch
    if print_every is None:
        print_every = print_per_epoch

    for i in range(num_epochs):
        # eval() leaves the model in evaluation mode at the end of every
        # epoch, so training mode must be re-enabled here; setting it once
        # before the loop would only cover the first epoch.
        model.train()

        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            y_pred = model(inputs)
            loss = loss_fn(y_pred, labels)
            running_loss += loss.item()
            num_batches += 1

            loss.backward()

            optimizer.step()

        # Mean of the per-batch losses; nan instead of a ZeroDivisionError
        # when the loader produced no batches.
        train_loss_avg = running_loss / num_batches if num_batches else float('nan')
        val_loss_avg = eval(model, loss_fn, val_loader)

        if (i + 1) % print_every == 0:
            print('LOSS train {} valid {}'.format(train_loss_avg, val_loss_avg))

In [68]:
output_size = 10
learning_rate = 1e-3

model = LeNet(output_size)

# LeNet is built from Lazy* layers whose parameters only get their shapes on
# the first forward pass. Push one training sample through before creating the
# optimizer, as the PyTorch lazy-module documentation recommends, so the
# optimizer is attached to fully initialised parameters.
with torch.no_grad():
    sample_input, _ = train_dataset[0]
    model(sample_input.unsqueeze(0).to(device))

criterion = torch.nn.CrossEntropyLoss()
# Use the learning rate defined above instead of repeating the literal.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

train(model, criterion, optimizer, train_loader, val_loader)
acc = accuracy(model, test_loader)

print('ACC {}'.format(acc))

LOSS train 0.24041613482435545 valid 0.20753994064529738
LOSS train 0.11815076679239671 valid 0.11709857572366794
LOSS train 0.08561528619968643 valid 0.08392030773932735
LOSS train 0.07073091052131106 valid 0.07797321507707239
LOSS train 0.06242269573515902 valid 0.07903533106918137
LOSS train 0.058495379920428 valid 0.07341462381867071
LOSS train 0.05421706466570807 valid 0.06324978337685268
LOSS train 0.052317071638302876 valid 0.06130129466826717
LOSS train 0.05118797958549112 valid 0.054921136902024346
LOSS train 0.04766035318467766 valid 0.07479340862234433
ACC 0.981
