<a href="https://colab.research.google.com/github/varunSabnis/pytorch_course_udemy/blob/master/MNIST_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install torch torchvision

In [0]:
import torch
from torchvision import datasets, transforms
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt

In [0]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

In [0]:
# Per-sample preprocessing: resize to 28x28, convert to a tensor, then normalise
# with mean 0.5 and standard deviation 0.5 (one value each, because the image
# has a single channel).
transform = transforms.Compose([
  transforms.Resize((28, 28)),
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split (train=True) and held-out split (train=False), both
# stored under ./data and both passed through the same transform.
training_dataset = datasets.MNIST(
  root="./data", train=True, download=True, transform=transform)
validation_dataset = datasets.MNIST(
  root="./data", train=False, download=True, transform=transform)

# Batches of 100 samples; only the training batches are shuffled.
training_loader = torch.utils.data.DataLoader(
  dataset=training_dataset, batch_size=100, shuffle=True)
validation_loader = torch.utils.data.DataLoader(
  dataset=validation_dataset, batch_size=100, shuffle=False)

In [0]:
def image_convert(im_tensor):
  """Convert a normalised channels-first image tensor into a displayable array.

  Normalisation computed y = (x - mean) / std, so the original pixel values are
  recovered with x = y * std + mean, where mean = std = 0.5 here.

  The mean/std arrays have three entries, so the channels-last image is
  broadcast against them: a 1 x H x W (or 3 x H x W) tensor yields an
  H x W x 3 array. Values are clipped to the range [0, 1].
  """
  std = np.array((0.5, 0.5, 0.5))
  mean = np.array((0.5, 0.5, 0.5))
  # Detached CPU copy, so the caller's tensor is never touched.
  channels_first = im_tensor.detach().cpu().clone().numpy()
  # C x H x W -> H x W x C
  channels_last = np.transpose(channels_first, (1, 2, 0))
  restored = channels_last * std + mean
  return np.clip(restored, 0, 1)

In [0]:
# Take one batch from the training loader and preview its first 20 images,
# each titled with its label.
dataiter = iter(training_loader)
# Use the builtin next(): newer PyTorch releases no longer expose a .next()
# method on the DataLoader iterator, while next() works on every version.
images, labels = next(dataiter)
fig = plt.figure(figsize=(25, 4))
for i in range(20):
  ax = fig.add_subplot(2, 10, i+1)
  # Draw on the subplot just created instead of relying on pyplot's implicit
  # "current axes".
  ax.imshow(image_convert(images[i]))
  ax.set_title(labels[i].item())


In [0]:
class LeNet(nn.Module):
  """LeNet-style CNN: two (conv -> ReLU -> max-pool) stages, then two linear layers.

  The layer-size bookkeeping assumes square inputs and square kernels with no
  padding, so a single side length describes each feature map.
  """

  def compute_out_size(self, num_prev_layer_features, num_filters, kernel_size, stride=1, num_prev_layer_channels=1, pool_size=2):
    """Return (side length, channel count) after one unpadded conv + max-pool stage.

    Args:
      num_prev_layer_features: side length of the square input feature map.
      num_filters: number of output channels of the convolution.
      kernel_size: side length of the square convolution kernel.
      stride: stride of the convolution.
      num_prev_layer_channels: unused; kept for interface compatibility.
      pool_size: window (and stride) of the max-pool that follows the conv.
        Defaults to 2, which matches the previously hard-coded halving.
    """
    # Unpadded convolution: floor((n - k) / stride) + 1 positions per side.
    conv_out = int((num_prev_layer_features - kernel_size) // stride) + 1
    # Pooling with window == stride keeps floor(n / pool_size) positions.
    return (conv_out // pool_size, num_filters)

  def __init__(self, img_dim, filters, kernel_sizes, fc_layers, im_channels=1, pool_size=2, stride=1):
    """Build the layers.

    Args:
      img_dim: side length of the square input image.
      filters: output channels of the two conv layers, e.g. [20, 50].
      kernel_sizes: kernel side lengths of the two conv layers.
      fc_layers: [hidden units, number of outputs] of the linear layers.
      im_channels: number of channels in the input image.
      pool_size: max-pool window and stride used after each conv layer.
      stride: stride of both convolutions.

    Raises:
      ValueError: if the feature map shrinks to nothing before the linear layers.
    """
    # Initialise nn.Module before setting any attribute on the instance.
    super().__init__()
    self.pool_size = pool_size
    self.conv1 = nn.Conv2d(im_channels, filters[0], kernel_sizes[0], stride)
    conv1_out_size, _ = self.compute_out_size(img_dim, filters[0], kernel_sizes[0], stride=stride, pool_size=pool_size)
    self.conv2 = nn.Conv2d(filters[0], filters[1], kernel_sizes[1], stride)
    conv2_out_size, _ = self.compute_out_size(conv1_out_size, filters[1], kernel_sizes[1], stride=stride, pool_size=pool_size)
    if conv1_out_size < 1 or conv2_out_size < 1:
      raise ValueError(
        "img_dim={} is too small for kernel_sizes={}, stride={}, pool_size={}".format(
          img_dim, kernel_sizes, stride, pool_size))
    # Number of values per sample once the last feature map is flattened.
    self.fc_layer_input = conv2_out_size*conv2_out_size*filters[1]
    self.fc1 = nn.Linear(self.fc_layer_input, fc_layers[0])
    self.dropout1 = nn.Dropout(0.5)
    self.fc2 = nn.Linear(fc_layers[0], fc_layers[1])

  def forward(self, x):
    """Map a batch of images to raw (un-normalised) class scores."""
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, self.pool_size, self.pool_size)
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, self.pool_size, self.pool_size)
    # Flatten every sample to one row for the linear layers.
    x = x.view(-1, self.fc_layer_input)
    x = F.relu(self.fc1(x))
    x = self.dropout1(x)
    x = self.fc2(x)
    return x

In [0]:
def get_image_shape(images):
  """Return the first two dimensions of the first image in the batch, after display conversion."""
  first_image = image_convert(images[0])
  rows, cols = first_image.shape[:2]
  return rows, cols

In [0]:
# The network is sized for n x n images and kernels with no extra padding,
# so only one of the two image dimensions is handed to it.
im_shape_x, im_shape_y = get_image_shape(images)
model = LeNet(
  im_shape_x,
  filters=[20, 50],
  kernel_sizes=[5, 5],
  fc_layers=[500, 10],
)
model = model.to(device)
model

In [0]:
# Cross-entropy loss on the model's outputs, optimised with Adam at a learning rate of 1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
  params=model.parameters(),
  lr=1e-4,
)

In [0]:
# Train for a fixed number of epochs. After every epoch, evaluate on the
# validation loader and record the mean loss and the accuracy (in percent)
# for both splits.
epochs = 12
running_loss_history = []
running_correct_history = []
val_running_loss_history = []
val_running_correct_history = []

print("len of training loader {}".format(len(training_loader)))
print("len of validation loader {}".format(len(validation_loader)))
for e in range(epochs):
  running_loss = 0.0
  running_corrects = 0
  validation_running_loss = 0.0
  validation_running_corrects = 0

  # Training mode: the model's dropout layer is active while optimising.
  model.train()
  for inputs, labels in training_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    _, preds = torch.max(outputs, 1)
    # .item() turns the count into a plain Python int, so the history lists
    # hold numbers matplotlib can plot rather than (possibly GPU) tensors.
    running_corrects += torch.sum(preds == labels).item()
    running_loss += loss.item()

  # Evaluation mode: dropout is switched off so validation is deterministic.
  model.eval()
  with torch.no_grad():
    for val_inputs, val_labels in validation_loader:
      val_inputs = val_inputs.to(device)
      val_labels = val_labels.to(device)
      val_outputs = model(val_inputs)
      val_loss = criterion(val_outputs, val_labels)
      _, val_preds = torch.max(val_outputs, 1)
      validation_running_corrects += torch.sum(val_preds == val_labels).item()
      validation_running_loss += val_loss.item()

  # Loss is averaged over batches. Accuracy is a percentage of all samples,
  # so it stays correct for any batch size (dividing the count by the number
  # of batches is only a percentage when every batch holds exactly 100 items).
  epoch_loss = running_loss / len(training_loader)
  acc = 100.0 * running_corrects / len(training_loader.dataset)
  val_epoch_loss = validation_running_loss / len(validation_loader)
  val_acc = 100.0 * validation_running_corrects / len(validation_loader.dataset)

  val_running_loss_history.append(val_epoch_loss)
  running_loss_history.append(epoch_loss)
  running_correct_history.append(acc)
  val_running_correct_history.append(val_acc)

  print("training loss : {:.4f} training accuracy : {:.2f}".format(epoch_loss, acc))
  print("Validation loss : {:.4f} Validation accuracy : {:.2f}".format(val_epoch_loss, val_acc))

In [0]:
# Classify one validation batch and show its first 20 images, each titled
# "prediction;[label]" in green when they agree and red when they differ.
dataiter = iter(validation_loader)
# Use the builtin next(): newer PyTorch releases no longer expose a .next()
# method on the DataLoader iterator, while next() works on every version.
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)
# Inference only: switch dropout off and skip gradient tracking.
model.eval()
with torch.no_grad():
  output = model(images)
_, preds = torch.max(output, 1)

fig = plt.figure(figsize=(25, 4))

for idx in range(20):
  ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
  ax.imshow(image_convert(images[idx]))
  predicted = preds[idx].item()
  actual = labels[idx].item()
  ax.set_title("{};[{}]".format(str(predicted), str(actual)), color=("green" if predicted == actual else "red"))

In [0]:
# Overlay the per-epoch training and validation loss curves on one set of axes.
for history, curve_label in (
  (running_loss_history, "training loss"),
  (val_running_loss_history, "validation loss"),
):
  plt.plot(history, label=curve_label)
plt.legend()

In [0]:
# Per-epoch accuracy curves. The history entries may be 0-d tensors (they
# derive from torch.sum results and can live on the GPU), which matplotlib
# cannot plot directly, so each one is converted to a plain float first.
plt.plot([float(a) for a in running_correct_history], label="training accuracy")
plt.plot([float(a) for a in val_running_correct_history], label="validation accuracy")
# The labels above are only displayed once a legend is drawn.
plt.legend()