In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import sys, copy, os, shutil, time
from tqdm.notebook import tqdm

In [58]:
class MNIST_CNN(nn.Module):
    '''
    Small CNN classifier built from 2, 3 or 4 stacked conv + pool modules.

    The first module is a plain 3x3 convolution. Every later module is a
    depthwise 3x3 convolution followed by a pointwise 1x1 convolution. Each
    module ends with max-pooling and then a ReLU, and one linear layer maps
    the flattened features to 10 outputs.

    Total parameter count:
        2 modules: 14110 parameters.
        3 modules: 5130 parameters.
        4 modules: 2750 parameters.

    The linear layer sizes (1280 / 320 / 20) correspond to 20-channel feature
    maps of 8x8 / 4x4 / 1x1, which is what a 1x32x32 input produces.
    NOTE(review): the data pipeline is not visible from this cell - confirm
    that images are resized/padded to 32x32 before they reach the model.

    Args:
        num_modules: number of conv + pool modules; must be 2, 3 or 4.

    Raises:
        ValueError: if num_modules is not 2, 3 or 4.
    '''

    # flattened feature count fed to the output layer, keyed by num_modules
    _FLAT_FEATURES = {2: 1280, 3: 320, 4: 20}

    def __init__(self, num_modules):
        # fail fast: any other depth would leave the model without an output
        # layer and only blow up later, inside forward()
        if num_modules not in self._FLAT_FEATURES:
            raise ValueError(
                f"num_modules must be one of {sorted(self._FLAT_FEATURES)}, got {num_modules!r}"
            )

        super().__init__()

        # record our param
        self.num_modules = num_modules

        # first conv + pool module (always present)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # second module: depthwise 3x3, then pointwise 1x1 down to 20 channels
        if num_modules >= 2:
            self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, groups=32, padding=1)
            self.conv3 = nn.Conv2d(in_channels=32, out_channels=20, kernel_size=1)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # third module ONLY IF num_modules >= 3
        if num_modules >= 3:
            self.conv4 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=3, groups=20, padding=1)
            self.conv5 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=1)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # fourth module ONLY IF num_modules >= 4 (note the larger 4x4 pooling window)
        if num_modules >= 4:
            self.conv6 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=3, groups=20, padding=1)
            self.conv7 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=1)
            self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4, ceil_mode=True)

        # final linear output layer, sized for the chosen depth
        self.linear = nn.Linear(in_features=self._FLAT_FEATURES[num_modules], out_features=10)

    # governs the forward-pass
    def forward(self, x):
        '''
        Run the network on a batch of images.

        Args:
            x: tensor of shape (batch, 1, H, W), or a single unbatched image
               of shape (1, H, W), which is treated as a batch of one.

        Returns:
            Tensor of shape (batch, 10) holding the raw linear-layer outputs.

        Raises:
            RuntimeError: raised by the linear layer when the flattened
                feature count does not match the size it was built for
                (i.e. the input spatial size is not the expected one).
        '''
        # a lone image gets an explicit batch axis so the flatten below keeps
        # producing a (1, features) matrix
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # always do this part
        x = F.relu(self.pool1(self.conv1(x)))

        # optional modules, mirroring the constructor
        if self.num_modules >= 2:
            x = F.relu(self.pool2(self.conv3(self.conv2(x))))
        if self.num_modules >= 3:
            x = F.relu(self.pool3(self.conv5(self.conv4(x))))
        if self.num_modules >= 4:
            x = F.relu(self.pool4(self.conv7(self.conv6(x))))

        # flatten everything except the batch axis; unlike reshape(-1, N) this
        # can never regroup elements across samples, so a wrongly sized input
        # fails loudly in the linear layer instead of yielding a bogus batch
        x = x.flatten(start_dim=1)

        # apply our final output layer
        return self.linear(x)

In [62]:
# total number of scalar entries across all parameter tensors of the 4-module network
sum(map(torch.Tensor.numel, MNIST_CNN(num_modules=4).parameters()))

2750