In [1]:
import torch
import torch.nn as nn
import numpy as np
import torchvision.datasets as dataset
import torchvision.transforms as transform
from torch.utils.data import DataLoader

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
datasetPath = "./drive/MyDrive/dataset/"
parameterPath = "./drive/MyDrive/parameters"

In [4]:
mnist_train = dataset.MNIST(root = datasetPath,
                            train = True,
                            transform = transform.ToTensor(),
                            download = True)

mnist_test = dataset.MNIST(root = datasetPath,
                           train = False,
                           transform = transform.ToTensor(),
                           download = True)

In [5]:
class LeNet5(nn.Module):
  def __init__(self):
    super(LeNet5, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
    self.fc1 = nn.Linear(in_features=256, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=84)
    self.fc3 = nn.Linear(in_features=84, out_features=10)
    self.relu = nn.ReLU()
    self.maxPool2d = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    y = self.relu(self.conv1(x))
    y = self.maxPool2d(y)
    y = self.relu(self.conv2(y))
    y = self.maxPool2d(y)

    y = y.view(-1, 256)
    y = self.relu(self.fc1(y))
    y = self.relu(self.fc2(y))
    y = self.fc3(y)
    return y

In [6]:
batch_size = 100
learning_rate = 0.1
training_epochs = 15
loss_function = nn.CrossEntropyLoss()
network = LeNet5()
optimizer = torch.optim.SGD(network.parameters(), lr = learning_rate)

data_loader = DataLoader(dataset = mnist_train,
                         batch_size = batch_size,
                         shuffle = True,
                         drop_last = True)

In [7]:
for epoch in range(training_epochs):
  avg_cost = 0
  total_batch = len(data_loader)

  for img, label in data_loader:
    pred = network(img)

    loss = loss_function(pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    avg_cost += loss / total_batch

  print('Epoch: %d Loss = %f' %(epoch+1, avg_cost))
print('Learning finished')

Epoch: 1 Loss = 0.593568
Epoch: 2 Loss = 0.106955
Epoch: 3 Loss = 0.074221
Epoch: 4 Loss = 0.058942
Epoch: 5 Loss = 0.050926
Epoch: 6 Loss = 0.041762
Epoch: 7 Loss = 0.035533
Epoch: 8 Loss = 0.031363
Epoch: 9 Loss = 0.028650
Epoch: 10 Loss = 0.025161
Epoch: 11 Loss = 0.021964
Epoch: 12 Loss = 0.019374
Epoch: 13 Loss = 0.018154
Epoch: 14 Loss = 0.016893
Epoch: 15 Loss = 0.015241
Learning finished


In [8]:
torch.save(network.state_dict(), parameterPath+'LeNet5_mnist.pth')

In [9]:
new_network = LeNet5()
new_network.load_state_dict(torch.load(parameterPath+'LeNet5_mnist.pth'))

<All keys matched successfully>

In [10]:
#test
with torch.no_grad():
    img_test = mnist_test.data.float().unsqueeze(1)
    label_test = mnist_test.targets

    prediction = network(img_test)

    correct_prediction = torch.argmax(prediction, 1) == label_test
    accuracy = correct_prediction.float().mean()
    print("Accuracy:", accuracy.item())

Accuracy: 0.9876999855041504
