<a href="https://colab.research.google.com/github/pgurazada/fast-fast-ai/blob/master/cnn_official_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [0]:
class Net(nn.Module):
  """LeNet-style CNN with activations and pooling applied functionally.

  Only the learnable layers are registered as sub-modules in ``__init__``;
  ReLU and 2x2 max-pooling are applied inside ``forward`` through
  ``torch.nn.functional``. The final layer emits one raw score per class
  (10 classes) with no output activation.

  ``fc1`` expects 16*5*5 input features, which corresponds to a 1-channel
  32x32 image (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5). This input
  size is assumed, not checked.
  """

  def __init__(self):
    super().__init__()
    # Feature extractor: 1 -> 6 -> 16 channels, each with a 5x5 square kernel.
    self.conv1 = nn.Conv2d(1, 6, 5)
    self.conv2 = nn.Conv2d(6, 16, 5)
    # Classifier head: 16*5*5 -> 120 -> 84 -> 10.
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Return class scores for the image batch ``x`` (first dim = batch)."""
    # Each conv stage: convolution -> ReLU -> 2x2 max-pool.
    for conv in (self.conv1, self.conv2):
      x = F.max_pool2d(F.relu(conv(x)), 2)
    features = x.view(-1, self.num_flat_features(x))
    hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
    return self.fc3(hidden)

  def num_flat_features(self, x):
    """Return the per-sample element count of ``x``.

    This is the product of every dimension except the leading (batch) one.
    """
    return x.size()[1:].numel()

In [0]:
# Instantiate the functional-style network defined above.
net: Net = Net()

In [0]:
# Show the module's textual summary (its registered sub-layers).
print(repr(net))

Below is an alternative, more Keras-like implementation that bundles each layer with the operations that follow it in `nn.Sequential` containers.

In [0]:
class ConvNet(nn.Module):
  """Same LeNet-style architecture as ``Net``, built from ``nn.Sequential`` stages.

  Each stage bundles a layer with the operations applied right after it, so
  ``forward`` is just a chain of stage calls. Like ``Net``, the output is one
  raw score (logit) per class for 10 classes, with no output activation.

  ``fc1`` expects 16*5*5 input features, which corresponds to a 1-channel
  32x32 image (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5). This input
  size is assumed, not checked.
  """

  def __init__(self):

    super(ConvNet, self).__init__()

    # conv -> ReLU -> 2x2 max-pool, mirroring F.max_pool2d(F.relu(conv(x)))
    # in Net. The ReLU was previously missing, making the two models differ.
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 5),
                               nn.ReLU(),
                               nn.MaxPool2d((2, 2)))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5),
                               nn.ReLU(),
                               nn.MaxPool2d((2, 2)))

    self.fc1 = nn.Sequential(nn.Linear(16*5*5, 120),
                             nn.ReLU())

    self.fc2 = nn.Sequential(nn.Linear(120, 84),
                             nn.ReLU())

    # No output activation: return raw logits as Net does, which is what
    # nn.CrossEntropyLoss expects. Kept inside a Sequential so parameter names
    # (fc3.0.weight / fc3.0.bias) are unchanged.
    self.fc3 = nn.Sequential(nn.Linear(84, 10))

  def forward(self, x):
    """Return class logits for the image batch ``x`` (first dim = batch)."""

    out = self.conv1(x)
    out = self.conv2(out)
    # Flatten the conv *feature maps*, not the input: per sample the input
    # has 1*32*32 features but the conv output has 16*5*5, which is what fc1
    # expects. Using `x` here made view() fail or silently mix samples.
    out = out.view(-1, self.num_flat_features(out))
    out = self.fc1(out)
    out = self.fc2(out)
    out = self.fc3(out)

    return out

  def num_flat_features(self, x):
    """Return the product of all dimensions of ``x`` except the first (batch)."""
    size = x.size()[1:]

    num_features = 1

    for s in size:
      num_features *= s

    return num_features

In [0]:
# Instantiate the Sequential-based variant defined above.
conv_net: ConvNet = ConvNet()

In [0]:
# Show the module's textual summary (its registered sub-layers).
print(repr(conv_net))