In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import numpy as np

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, random_split

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

learning_rate = 0.001
training_epochs = 10
batch_size = 100

mnist_train = datasets.MNIST(root='MNIST_data/', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root='MNIST_data/', train=False, transform=transforms.ToTensor(), download=True)
data_size = len(mnist_train)
train_size = int(data_size * 0.8)
validation_size = int(data_size * 0.2)

train_dataset, validation_dataset = random_split(mnist_train, [train_size, validation_size])

dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=20, shuffle=True, drop_last=True)
test_dataloader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 77879596.82it/s]


Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 75709808.64it/s]


Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 39950156.24it/s]


Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3475105.58it/s]

Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw






In [3]:
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    # First layer
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
    # Second layer
    self.layer2 = nn.Sequential(
        nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

    # Fully Connected Layer
    self.fc = nn.Linear(7*7*64, 10, bias=True)

    # init weights
    nn.init.xavier_uniform_(self.fc.weight)

  def forward(self,x):
    out = self.layer1(x)
    out = self.layer2(out)
    out = out.view(out.size(0), -1)
    out.size(1)
    out = self.fc(out)
    return out

  

In [4]:
class EarlyStopping:
  def __init__(self, patience=3, verbose=False, delta=0, path='checkpoint.pt'):
    self.patience = patience
    self.verbose = verbose
    self.counter = 0
    self.best_score = None
    self.early_stop = False
    self.val_loss_min = np.Inf
    self.delta = delta
    self.path = path

  def __call__(self, val_loss, model):

    score = -val_loss
    if self.best_score is None:
      self.best_score = score
      self.save_checkpoint(val_loss, model)
    elif score < self.best_score + self.delta:
      self.counter +=1
      print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
      if self.counter >= self.patience:
        self.early_stop = True
    else:
      self.best_score = score
      self.save_checkpoint(val_loss, model)
      self.counter = 0

  def save_checkpoint(self, val_loss, model):
    if self.verbose:
      print(f"Validation loss decreased({self.val_loss_min:.6f} --> {val_loss: .6f}). Saving model..")
    torch.save(model.state_dict(), self.path)
    self.val_loss_min = val_loss

In [6]:
model = CNN().to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [9]:
train_losses = []
valid_losses = []
avg_train_losses = []
avg_valid_losses = []
early_stopping = EarlyStopping(patience = 3, verbose=True)

for epoch in range(training_epochs):
  avg_cost = 0
  model.train()
  for X, Y in dataloader:
    X = X.to(device)
    Y = Y.to(device)

    optimizer.zero_grad()
    hypothesis = model(X)
    cost = criterion(hypothesis, Y)
    cost.backward()
    optimizer.step()
    avg_cost += cost / len(dataloader)

  print('[Epoch: {:>2}] cost = {:>.9}'.format(epoch+1, avg_cost))

  model.eval()
  with torch.no_grad():
    correct = 0
    total = 0
    for x, y in validation_dataloader:
      output = model(x)
      loss = criterion(output, y)
      valid_losses.append(loss.item())

      _,predict = torch.max(output, 1)

      total += y.size(0)
      correct += (predict==y).sum()

    train_loss = np.average(train_losses)
    valid_loss = np.average(valid_losses)
    avg_train_losses.append(train_loss)
    avg_valid_losses.append(valid_loss)

    train_losses = []
    valid_losses = []

    early_stopping(valid_loss, model)

    if early_stopping.early_stop:
      print("Early stopping")
      break
  model.load_state_dict(torch.load('checkpoint.pt'))
  
  print("Validation Accuracy : {}".format(correct/total))

[Epoch:  1] cost = 0.0281698089
Validation loss decreased(inf -->  0.044137). Saving model..
Validation Accuracy : 0.9865000247955322
[Epoch:  2] cost = 0.0234814044
EarlyStopping counter: 1 out of 3
Validation Accuracy : 0.9863333106040955
[Epoch:  3] cost = 0.0242797583
EarlyStopping counter: 2 out of 3
Validation Accuracy : 0.9862499833106995
[Epoch:  4] cost = 0.0250254609
EarlyStopping counter: 3 out of 3
Early stopping
Best Validation Accuracy : 0.9857500195503235


In [8]:
with torch.no_grad():
    X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())




Accuracy: 0.9871000051498413
