In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Simple CNN Model

In [6]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    With the default arguments the network maps a batch of single-channel
    32x32 images, shape ``(N, 1, 32, 32)``, to ``(N, 10)`` raw scores, one per
    numerical digit (no softmax is applied).

    Spatial size through the feature extractor for a 32x32 input:
    32 -> conv 3x3 -> 30 -> pool 2x2 -> 15 -> conv 3x3 -> 13 -> pool 2x2 -> 6.
    That is where the ``16 * 6 * 6`` input width of ``fc1`` comes from, so
    other input resolutions do not match ``fc1``.

    Args:
        in_channels: number of channels of the input images (default 1,
            i.e. grayscale).
        num_classes: number of output scores (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Feature extractor: both convolutions use a 3x3 kernel, stride 1, no padding.
        self.conv1 = nn.Conv2d(in_channels, 6, 3)   # in_channels -> 6 feature maps
        self.conv2 = nn.Conv2d(6, 16, 3)            # 6 -> 16 feature maps
        # Classifier head: 16 maps of 6x6 are flattened to 576 features.
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)       # one raw score per class

    def num_flat_features(self, x):
        """Return the number of elements per sample in ``x``.

        This is the product of every dimension of ``x`` except the first
        (batch) one. ``forward`` no longer needs it; it is kept so existing
        callers of this helper keep working.
        """
        num_features = 1
        for dim in x.size()[1:]:
            num_features *= dim

        return num_features

    def forward(self, x):
        """Compute raw class scores for a batch of images ``x``.

        Args:
            x: tensor of shape ``(N, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        # Linearize the feature maps for the feed-forward layers, keeping the
        # batch dimension. Unlike ``view``, ``flatten`` copies when the memory
        # layout cannot be viewed directly (e.g. a non-contiguous tensor).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


### Test the forward pass

In [16]:
# Build the network and print its layer summary.
model = LeNet()
print(model)

# A batch holding one random single-channel (grayscale) image of 32 x 32 px.
# NOTE: the name `input` shadows the Python builtin; it is kept unchanged
# because other cells may refer to it.
input = torch.rand((1, 1, 32, 32))
print("Input shape: {}".format(input.size()))

# Calling the module runs forward() internally; forward() is never called
# explicitly.
output = model(input)
print("Output: {}".format(output))


LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Input shape: torch.Size([1, 1, 32, 32])
Output: tensor([[-0.0161,  0.0063, -0.0261, -0.1041,  0.0186, -0.0553,  0.1102,  0.0265,
          0.0832, -0.0244]], grad_fn=<AddmmBackward0>)


### Training on a real dataset

In [24]:
%matplotlib inline

import torchvision
import torchvision.transforms as transforms

In [25]:
# Preprocessing applied to every loaded image:
#   1. ToTensor()  - convert the loaded image to an equivalent tensor
#   2. Normalize() - per channel, subtract the mean 0.5 and divide by the std 0.5
# NOTE(review): three mean/std values imply 3-channel images, while LeNet's
# conv1 is built for 1 input channel - confirm against the dataset being loaded.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)