# CNN with MNIST 

Goal: Implement LeNet with PyTorch, but with own Convolution and Subsampling/Pooling Function

In [1]:
from sklearn.metrics import classification_report, confusion_matrix 
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.utils import check_random_state
import numpy as np 
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms
from tqdm.auto import tqdm

In [2]:
# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 128
# percentage of training set to use as validation
valid_size = 0.2
# set the seed
seed = 42
# use a gpu if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(42)

<torch._C.Generator at 0x28b55ae2110>

In [3]:
# create the transformations
normalization = transforms.Normalize((0.1306,), (0.3015,))
test_transforms = transforms.Compose([transforms.ToTensor(), normalization])
# in this case we are using the same transformations for training and testing
train_transforms = test_transforms

# get the training and testing datasets
train_data = datasets.MNIST(root = 'data', train = True, download = True, transform=train_transforms)
valid_data = datasets.MNIST(root = 'data', train = True, download = True, transform=test_transforms)
test_data = datasets.MNIST(root = 'data', train = False, download = True, transform=test_transforms)

# split the training data into train and validation data
indices = np.arange(len(train_data))
train_indices, valid_indices = train_test_split(indices, test_size=valid_size, random_state=seed, stratify=train_data.targets)

# create the data loaders
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
# prepare data loaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size,
                                           sampler = train_sampler, num_workers = num_workers, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size,
                                          sampler = valid_sampler, num_workers = num_workers, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size,
                                         num_workers = num_workers, drop_last=True)

print(f'Number Training Samples: {len(train_loader) * batch_size}')
print(f'Number Validation Samples: {len(valid_loader) * batch_size}')
print(f'Number Test Samples: {len(test_loader) * batch_size}')

Number Training Samples: 48000
Number Validation Samples: 11904
Number Test Samples: 9984


In [5]:
test_transforms.size()

AttributeError: 'Compose' object has no attribute 'size'

### Own Convolutional and Pooling Layer 

In [None]:
class MyWindowOperation(nn.Module):
    def __init__(self, kernel_size, padding=0, stride=1, dilation=1):
        super().__init__()

        # define all dimensions we need
        self.kernel_size = kernel_size if hasattr(kernel_size, '__iter__') else (kernel_size, kernel_size)
        self.dilation = dilation if hasattr(dilation, '__iter__') else (dilation, dilation)
        self.padding = padding if hasattr(padding, '__iter__') else (padding, padding)
        self.stride = stride if hasattr(stride, '__iter__') else (stride, stride)
        
    def forward(self, x):
        raise NotImplemented('This function is not implemented yet. Please implement in subclass')
    
    def calculateWindows(self, x):
        # get the windows/patches along the height
        windows = F.unfold(x, kernel_size=self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride)
        # switch the number of windows with the number of values per window
        windows = windows.transpose(1, 2) # (batch_dim, num_windows, num_values_per_window)

        return windows

    def calculateNewWidth(self, x):
        # the formula to calculate the output size is: ((input - kernel_size + 2 * padding) / stride) + 1
        input_width = x.shape[-2]

        return (
            (input_width + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1)
            // self.stride[0]
        ) + 1

    def calculateNewHeight(self, x):
        input_height = x.shape[-1]

        return (
            (input_height + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1)
            // self.stride[1]
        ) + 1

class MyConv2d(MyWindowOperation):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, dilation=1):
        super().__init__(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)

        self.out_channels = out_channels
        self.in_channels = in_channels
        
        # parameters of conv layer
        self.weights = nn.Parameter(torch.ones(self.out_channels, self.in_channels, *self.kernel_size))
        self.bias = nn.Parameter(torch.zeros((self.out_channels)))

        # init parameters
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.weights)
        nn.init.constant_(self.bias, 0)

    def forward(self, x):
        is_batched = len(x.shape) > 3
        if not is_batched:
            x = x.unsqueeze(0)
            
        batch_size = x.shape[0]
    
        # get the width, heigth of the output
        width = self.calculateNewWidth(x)
        height = self.calculateNewHeight(x)
        # get all the windows on the input
        windows = self.calculateWindows(x) # (batch_size, num_windows, num_values_per_window)

        res = windows.matmul(self.weights.view(self.out_channels, -1).T) + self.bias # (batch_size, num_windows, out_channels)
        res = res.transpose(1, 2) # (batch_size, out_channels, num_windows)
        
        # view the result
        res = res.view(batch_size, self.out_channels, height, width)

        # if the input is not batched remove the batch size of 1 at the first dimension
        if not is_batched:
            res = res.squeeze(0)

        return res 
    

class MyAvgPool2d(MyWindowOperation):
    def __init__(self, kernel_size, padding=0, stride=1, dilation=1, ):
        super().__init__(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)

    def forward(self, x):
        is_batched = len(x.shape) > 3
        if not is_batched:
            x = x.unsqueeze(0)
            
        batch_size = x.shape[0]
        num_channels = x.shape[1]
    
        # get the width, heigth of the output
        width = self.calculateNewWidth(x)
        height = self.calculateNewHeight(x)
        # get all the windows on the input
        windows = self.calculateWindows(x) # (batch_size, num_windows, num_values_per_window)

        # get the average across the values of each window
        res = windows.contiguous().view(batch_size, num_channels, -1, self.kernel_size[0] * self.kernel_size[1]).mean(-1) 
        
        # view the result
        res = res.view(batch_size, -1, height, width)

        # if the input is not batched remove the batch size of 1 at the first dimension
        if not is_batched:
            res = res.squeeze(0)

        return res 

In [None]:
def visualize(img):
    plt.imshow(img.permute(1, 2, 0), cmap=plt.cm.gray)
    plt.show()

In [None]:
# print the test image for visualizing
test_image, test_image_label = train_data[0]
print(f'Label: {5}')
visualize(test_image)

In [None]:
# check that the convolution implementation is correct
my_conv = MyConv2d(in_channels=1, out_channels=6, kernel_size=5)
pt_conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)

my_conv_weight = my_conv.weights.data.numpy()
my_conv_bias = my_conv.bias.data.numpy()
print(f'Shape Bias: {my_conv_bias.shape}')
print(f'Shape Weights: {my_conv_weight.shape}')

# to test whether our calculations are correct assign the same weights
my_conv.weights = pt_conv.weight
my_conv.bias = pt_conv.bias

assert (my_conv.weights == pt_conv.weight).all()
assert (my_conv.bias == pt_conv.bias).all()

print(f'Visualizing first kernel:')
plt.imshow(my_conv.weights.data.numpy()[0, 0, ...])
plt.show()

my_conv_res = my_conv(test_image)
pt_conv_res = pt_conv(test_image)
assert my_conv_res.shape == pt_conv_res.shape

print(f'Abs Error: {(pt_conv_res - my_conv_res).abs().max()}')
print(f'Output Shape: {my_conv_res.shape}\n')

In [None]:
fig, axs = plt.subplots(1, my_conv_res.shape[0] + 1, figsize=(15, 5))
axs[0].imshow(test_image.permute(1, 2, 0), cmap=plt.cm.gray)
for i in range(my_conv_res.shape[0]):
   axs[i+1].imshow(my_conv_res[i:i+1].detach().permute(1, 2, 0), cmap=plt.cm.gray)

In [None]:
# check that the pooling implementation is correct
my_pool = MyAvgPool2d(kernel_size=2, stride=2)
pt_pool = nn.AvgPool2d(kernel_size=2, stride=2)

my_pool_res = my_pool(test_image)
pt_pool_res = pt_pool(test_image)
assert my_conv_res.shape == pt_conv_res.shape

print(f'Abs Error: {(pt_pool_res - my_pool_res).abs().max()}')
print(f'Output Shape: {my_pool_res.shape}\n')

visualize(my_pool_res)

In [None]:
# Define own model = LeCun ConvNet
class ConvNet(torch.nn.Module):
    def __init__(self, own_conv=True, own_pool=True):
        super().__init__()

        # define which convolution to use
        if own_conv:
            print("Using own conv")
            self.conv1 = MyConv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
            self.conv2 = MyConv2d(in_channels=6, out_channels=16, kernel_size=5)
        else:
            print("Using pytorch conv")
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
            self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        # define which pooling layer to use
        if own_pool:
            print("Using own pooling layer")
            self.pool = MyAvgPool2d(kernel_size=2, stride=2)
        else:
            print("Using pytorch pooling layer")
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.output = nn.Linear(84, 10)
        
        
    def forward(self, x):
        x = F.sigmoid(self.conv1(x))
        x = self.pool(x)
        
        x = F.sigmoid(self.conv2(x))
        x = self.pool(x)

        # flatten the convolution results
        x = x.view(-1, 16*5*5)

        x = F.sigmoid(self.fc1(x))
        x = F.sigmoid(self.fc2(x))

        return self.output(x)

### Optimization with torch

In [None]:
def train(network, train_loader, val_loader, test_loader, epochs=200, lr=1e-3):
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    
    num_train_samples = (len(train_loader) * train_loader.batch_size)
    num_val_samples = (len(val_loader) * val_loader.batch_size)
    num_test_samples = (len(test_loader) * test_loader.batch_size)

    best_model_valid_loss = np.Inf
    num_epochs_without_val_loss_reduction = 0
    early_stopping_window = 5

    for epoch in range(epochs):
        running_train_loss = 0
        running_val_loss = 0

        correct_train_preds = 0
        correct_val_preds = 0

        # iterate the training data
        with tqdm(train_loader, desc="Training") as train_epoch_pbar:
            for i, (x, y) in enumerate(train_epoch_pbar):
                x, y = x.to(device), y.to(device)
                
                optimizer.zero_grad()
                logits = network(x)
                loss = loss_fn(logits, y)
                loss.backward()
                optimizer.step()

                running_train_loss += loss * len(x)
                correct_train_preds += (logits.argmax(-1) == y).sum()

                if i % 10 == 0:
                    # we can divide by (i * len(x)) because we are dropping the last batch
                    num_train_samples_so_far = (i + 1) * len(x)
                    train_epoch_pbar.set_postfix(train_loss=running_train_loss.item() / num_train_samples_so_far, accuracy=correct_train_preds.item() / num_train_samples_so_far * 100)

        # iterate the val data
        with tqdm(val_loader, desc="Validating") as val_epoch_pbar:
            for i, (x, y) in enumerate(val_epoch_pbar):
                x, y = x.to(device), y.to(device)
                logits = network(x)
                loss = loss_fn(logits, y)

                running_val_loss += loss * len(x)
                correct_val_preds += (logits.argmax(-1) == y).sum()

                if i % 10 == 0:
                    num_val_samples_so_far = (i + 1) * len(x)
                    val_epoch_pbar.set_postfix(train_loss=running_val_loss.item() / num_val_samples_so_far, accuracy=correct_val_preds.item() / num_val_samples_so_far * 100)

        
        avg_train_loss = running_train_loss.item() / num_train_samples
        train_acc = correct_train_preds.item() / num_train_samples * 100
        avg_val_loss = running_val_loss.item() / num_val_samples
        val_acc = correct_val_preds.item() / num_val_samples * 100

        print(f'Epoch {epoch}: \tAvg Train Loss: {avg_train_loss:.2f} \tTrain Acc: {train_acc:.2f} \tAvg Val Loss: {avg_val_loss:.2f} \tVal Acc: {val_acc:.2f}')

        # perform early stopping if necessary
        if avg_val_loss <= best_model_valid_loss:
            print(f'Validation loss decreased ({best_model_valid_loss:.6f} --> {avg_val_loss:.6f}).  Saving model ...')
            torch.save(network.state_dict(), 'model.pt')
            best_model_valid_loss = avg_val_loss
            num_epochs_without_val_loss_reduction = 0
        else:
            num_epochs_without_val_loss_reduction += 1
        
        if num_epochs_without_val_loss_reduction >= early_stopping_window:
            # if we haven't had a reduction in validation loss for `early_stopping_window` epochs, then stop training
            print(f'No reduction in validation loss for {early_stopping_window} epochs. Stopping training...')
            break


    running_test_loss = 0
    correct_test_preds = 0
    # only after finishing training we are testing our model
    with tqdm(test_loader, desc="Testing") as test_pbar:
        for i, (x, y) in enumerate(test_pbar):
            x, y = x.to(device), y.to(device)
            logits = network(x)
            loss = loss_fn(logits, y)

            running_test_loss += loss * len(x)
            correct_test_preds += (logits.argmax(-1) == y).sum()

            if i % 10 == 0:
                num_test_samples_so_far = (i + 1) * len(x)
                val_epoch_pbar.set_postfix(train_loss=running_test_loss.item() / num_test_samples_so_far, accuracy=correct_test_preds.item() / num_test_samples_so_far * 100)

    avg_test_loss = running_test_loss.item() / num_test_samples
    test_acc = correct_test_preds.item() / num_test_samples * 100
    
    print(f'Test Set: \tAvg Test Loss: {avg_test_loss:.2f} \tTest Acc: {test_acc:.2f}')



In [None]:
model = ConvNet(own_conv=True, own_pool=True).to(device)

In [None]:
train(model, train_loader, valid_loader, test_loader, epochs=200, lr=1e-3)


In [None]:
model.cpu()
filter_output = model.conv1(test_image)
fig, axs = plt.subplots(1, filter_output.shape[0] + 1, figsize=(15, 5))
axs[0].imshow(test_image.permute(1, 2, 0), cmap=plt.cm.gray)
for i in range(filter_output.shape[0]):
    axs[i + 1].imshow(filter_output[i:i + 1].detach().permute(1, 2, 0), cmap=plt.cm.gray)