Construct and train the model:
a multi-layer perceptron (MLP) neural network implemented and trained on the MNIST data

In [37]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# define the model
class MLPerceptron(nn.Module):
    """Multi-layer perceptron with two ReLU hidden layers and a linear output.

    The layer sizes are constructor parameters so the same class can be reused
    for inputs other than flattened 28x28 images; the defaults reproduce the
    sizes that were previously hard-coded.

    Args:
        input_size: number of features in each (already flattened) sample.
        hidden_size: width of both hidden layers.
        output_size: number of values produced per sample.
    """

    def __init__(self, input_size: int = 28 * 28, hidden_size: int = 128,
                 output_size: int = 10):
        super().__init__()
        # attribute names are kept as-is so printed summaries and state dicts
        # stay compatible with the earlier version of this class
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-activated) outputs of the last layer for `x`.

        `x` must already be flattened so that its last dimension equals the
        `input_size` the model was built with.
        """
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        # no activation on the output layer: the caller decides how to
        # interpret the scores
        x = self.layer3(x)
        return x

# define a transform that converts each image to a tensor and then normalizes
# it with mean 0.5 and standard deviation 0.5
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# download and load the MNIST training data
dataset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)

# split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))  # 80% for training
# the remainder goes to validation, so both sizes always add up to len(dataset)
valid_size = len(dataset) - train_size  # 20% for validation
train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

# create training and validation data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# evaluation metrics do not depend on sample order, so the validation loader
# is not shuffled
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

# build the network first so its weight initialization happens at the same
# point of the script as before
model = MLPerceptron()

# mean-squared-error loss (averaged over all elements) between the network
# outputs and the targets
loss_criterion = nn.MSELoss(reduction='mean')

# plain stochastic gradient descent over every model parameter
optimizer = optim.SGD(params=model.parameters(), lr=0.8)

# show the layer structure
print(model)

# how many full passes over the training data to run
n_epochs = 100

# training: run n_epochs full passes over the training loader
for epoch in range(n_epochs):
    model.train()  # put the network in training mode
    train_loss = 0.0  # running sum of the per-batch losses for this epoch

    for inputs, labels in train_loader:
        # flatten every sample in the batch into one feature vector
        inputs = inputs.reshape(inputs.size(0), -1)
        # one-hot encode the labels as floats so they match the output shape
        labels = nn.functional.one_hot(labels, num_classes=10).to(torch.float32)

        # forward pass and loss for this batch
        outputs = model(inputs)
        loss = loss_criterion(outputs, labels)

        # clear old gradients, backpropagate, then update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate the batch loss as a plain Python number
        train_loss += loss.item()

    # average of the per-batch losses over the epoch
    train_loss /= len(train_loader)

    print('epoch: {} \ttraining loss: {:.6f}'.format(epoch + 1, train_loss))

MLPerceptron(
  (layer1): Linear(in_features=784, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=10, bias=True)
)
epoch: 1 	training loss: 0.027357
epoch: 2 	training loss: 0.013101
epoch: 3 	training loss: 0.010187
epoch: 4 	training loss: 0.008643
epoch: 5 	training loss: 0.007635
epoch: 6 	training loss: 0.006886
epoch: 7 	training loss: 0.006331
epoch: 8 	training loss: 0.005852
epoch: 9 	training loss: 0.005450
epoch: 10 	training loss: 0.005123
epoch: 11 	training loss: 0.004812
epoch: 12 	training loss: 0.004565
epoch: 13 	training loss: 0.004350
epoch: 14 	training loss: 0.004189
epoch: 15 	training loss: 0.003924
epoch: 16 	training loss: 0.003790
epoch: 17 	training loss: 0.003628
epoch: 18 	training loss: 0.003454
epoch: 19 	training loss: 0.003315
epoch: 20 	training loss: 0.003202
epoch: 21 	training loss: 0.003060
epoch: 22 	training loss: 0.002965
epoch: 23 	training loss: 0.0

model validation

In [38]:
# validation: measure classification accuracy on the validation loader
model.eval()  # put the network in evaluation mode
total = 0    # number of samples seen
correct = 0  # number of samples whose predicted class equals the label

with torch.no_grad():  # no gradients are needed while evaluating
    for inputs, labels in valid_loader:
        # flatten every sample in the batch before the forward pass
        inputs = inputs.reshape(inputs.size(0), -1)
        outputs = model(inputs)

        # predicted class = index of the largest output for each sample
        predicted = torch.max(outputs, dim=1).indices

        # update the running totals
        total += labels.shape[0]
        correct += int(predicted.eq(labels).sum())

# percentage of correct predictions
accuracy = correct / total * 100

print('model accuracy: {:.2f}%'.format(accuracy))

model accuracy: 97.43%
