In [2]:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
  """Two 3x3 conv + ReLU stages followed by a 2x2 / stride-2 max-pool.

  The convolutions use padding=1, so they keep H and W unchanged; the pool
  then halves each spatial dimension (rounding down).

  Args:
    in_channels: number of channels in the input feature map.
    out_channels: number of channels produced by the block.
    hidden_dim: number of channels between the two convolutions.
  """

  def __init__(self, in_channels, out_channels, hidden_dim):
    super().__init__()

    # kernel_size=3 with padding=1 is size-preserving.
    self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1)
    self.relu = nn.ReLU()

    # Downsample by 2 in each spatial dimension.
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    """Apply conv -> ReLU -> conv -> ReLU -> max-pool.

    Args:
      x: feature map whose channel dimension equals `in_channels`
        (typically batched, shape (N, in_channels, H, W)).

    Returns:
      Tensor with `out_channels` channels and spatial size (H // 2, W // 2).
    """
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    return self.pool(x)

In [4]:
class CNN(nn.Module):
  """Three BasicBlocks followed by a three-layer fully connected classifier.

  Each BasicBlock ends in a 2x2 / stride-2 max-pool, so a square input of
  side `image_size` reaches the classifier head as a
  256 x (image_size // 8) x (image_size // 8) feature map. With the default
  `image_size=32`, this gives 256 * 4 * 4 = 4096 features, matching the
  original hard-coded fc1 width.

  Args:
    num_classes: number of output logits.
    image_size: side length of the square, 3-channel input images.
      Defaults to 32.

  Raises:
    ValueError: if `image_size` is too small to survive three 2x pools.
  """

  def __init__(self, num_classes, image_size=32):
    super().__init__()

    if image_size < 8:
      raise ValueError(
          f"image_size must be >= 8 (three 2x2 poolings), got {image_size}")

    self.block1 = BasicBlock(in_channels=3, out_channels=32, hidden_dim=16)
    self.block2 = BasicBlock(in_channels=32, out_channels=128, hidden_dim=64)
    self.block3 = BasicBlock(in_channels=128, out_channels=256, hidden_dim=128)

    # Spatial side after three floor-halving pools.
    feature_side = image_size // 8
    self.fc1 = nn.Linear(in_features=256 * feature_side * feature_side,
                         out_features=2048)
    self.fc2 = nn.Linear(in_features=2048, out_features=256)
    self.fc3 = nn.Linear(in_features=256, out_features=num_classes)

    self.relu = nn.ReLU()

  def forward(self, x):
    """Compute unnormalized class scores (no softmax) for a batch of images.

    Args:
      x: tensor of shape (N, 3, image_size, image_size).

    Returns:
      Tensor of shape (N, num_classes) holding raw logits.
    """
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    # Keep the batch dimension; flatten channels and spatial dims together.
    x = torch.flatten(x, start_dim=1)

    x = self.relu(self.fc1(x))
    x = self.relu(self.fc2(x))
    return self.fc3(x)