### AlexNet

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
from dataset import mnist
from util import (
    get_torch_size_string,
    print_model_parameters,
    print_model_layers,
    model_train,
    model_eval,
    model_test
)
np.set_printoptions(precision=3)
th.set_printoptions(precision=3)
%matplotlib inline
%config InlineBackend.figure_format='retina'
print ("PyTorch version:[%s]."%(th.__version__))

PyTorch version:[2.0.1].


### Hyperparameters

In [2]:
device = 'cpu' # cpu / mps
# Seed torch's global RNG so the random input batch and the weights drawn when
# the Lazy* layers materialize are identical on every Restart & Run All.
SEED = 0
th.manual_seed(SEED)
print ("Ready.")

Ready.


### Define AlexNet

In [3]:
class AlexNetClass(nn.Module):
    """AlexNet-style CNN built from lazy layers.

    The input channel count and the flattened feature size are not fixed here;
    the ``Lazy*`` layers infer them on the first forward pass.

    Args:
        num_classes: Width of the final linear layer (number of output
            logits). Defaults to 1000, the original ImageNet head; pass e.g.
            10 for a 10-class dataset.
        dropout: Drop probability of the two Dropout layers in the classifier
            head. Defaults to 0.5.

    Attributes:
        net: ``nn.Sequential`` holding every layer in execution order.
        layer_names: One ``"<lowercased class name>_<NN>"`` label per layer of
            ``net``. They are computed at construction time, so lazy layers
            keep their ``lazy...`` names even after they materialize.
    """
    def __init__(self, num_classes=1000, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            # Feature extractor: five convolutions, three max-pools.
            nn.LazyConv2d(96, kernel_size=11, stride=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LazyConv2d(256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LazyConv2d(384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.LazyConv2d(384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.LazyConv2d(256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Classifier head: two dropout-regularized hidden layers.
            nn.Flatten(),
            nn.LazyLinear(4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.LazyLinear(4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.LazyLinear(num_classes),
        )
        self.layer_names = [
            "%s_%02d" % (type(layer).__name__.lower(), l_idx)
            for l_idx, layer in enumerate(self.net)
        ]

    def forward(self, x):
        """Forward propagate ``x`` through every layer of ``self.net``.

        Args:
            x: Input batch; presumably (batch, channels, height, width), since
                the first layer is a 2D convolution.

        Returns:
            A ``(final_output, intermediate_output_list)`` tuple.
            ``intermediate_output_list[i]`` is the output of ``self.net[i]``,
            so its last element is ``final_output`` itself. Every activation is
            kept alive, which costs memory on large batches.
        """
        intermediate_output_list = []
        for layer in self.net:
            x = layer(x)
            intermediate_output_list.append(x)
        final_output = x
        return final_output, intermediate_output_list

print ("Ready.")

Ready.


In [4]:
# Place the model on the configured device so its parameters live alongside the
# inputs, which are moved with `.to(device)`. Otherwise, a non-CPU device such as
# 'mps' raises a device-mismatch error on the first forward pass. Lazy modules
# may be moved before their parameters are initialized; they materialize on
# that device.
alexnet = AlexNetClass().to(device)
print ("Ready.")

Ready.




### Print model layers

In [5]:
# Dummy batch: 16 samples of shape 3x224x224, moved to the configured device.
x_torch = th.randn(16, 3, 224, 224).to(device=device)
# Print each layer's output size. Judging by the per-layer sizes it reports,
# print_model_layers runs a forward pass. That pass is presumably also what
# materializes the Lazy* layers' parameter shapes.
print_model_layers(alexnet, x_torch)

batch_size:[16]
[  ] layer:[          input] size:[  16x3x224x224]
[ 0] layer:[  lazyconv2d_00] size:[   16x96x54x54] numel:[   4478976]
[ 1] layer:[        relu_01] size:[   16x96x54x54] numel:[   4478976]
[ 2] layer:[   maxpool2d_02] size:[   16x96x26x26] numel:[   1038336]
[ 3] layer:[  lazyconv2d_03] size:[  16x256x26x26] numel:[   2768896]
[ 4] layer:[        relu_04] size:[  16x256x26x26] numel:[   2768896]
[ 5] layer:[   maxpool2d_05] size:[  16x256x12x12] numel:[    589824]
[ 6] layer:[  lazyconv2d_06] size:[  16x384x12x12] numel:[    884736]
[ 7] layer:[        relu_07] size:[  16x384x12x12] numel:[    884736]
[ 8] layer:[  lazyconv2d_08] size:[  16x384x12x12] numel:[    884736]
[ 9] layer:[        relu_09] size:[  16x384x12x12] numel:[    884736]
[10] layer:[  lazyconv2d_10] size:[  16x256x12x12] numel:[    589824]
[11] layer:[        relu_11] size:[  16x256x12x12] numel:[    589824]
[12] layer:[   maxpool2d_12] size:[    16x256x5x5] numel:[    102400]
[13] layer:[     flatte

### Print model parameters

In [6]:
# List every named parameter with its shape and element count. The Lazy* layers
# only get concrete shapes (e.g. the 4096x6400 first linear weight) after their
# first forward pass, presumably the print_model_layers call in the cell above.
# Run this cell after that one.
print_model_parameters(alexnet)

[ 0] parameter:[               net.0.weight] shape:[  96x3x11x11] numel:[     34848]
[ 1] parameter:[                 net.0.bias] shape:[          96] numel:[        96]
[ 2] parameter:[               net.3.weight] shape:[  256x96x5x5] numel:[    614400]
[ 3] parameter:[                 net.3.bias] shape:[         256] numel:[       256]
[ 4] parameter:[               net.6.weight] shape:[ 384x256x3x3] numel:[    884736]
[ 5] parameter:[                 net.6.bias] shape:[         384] numel:[       384]
[ 6] parameter:[               net.8.weight] shape:[ 384x384x3x3] numel:[   1327104]
[ 7] parameter:[                 net.8.bias] shape:[         384] numel:[       384]
[ 8] parameter:[              net.10.weight] shape:[ 256x384x3x3] numel:[    884736]
[ 9] parameter:[                net.10.bias] shape:[         256] numel:[       256]
[10] parameter:[              net.14.weight] shape:[   4096x6400] numel:[  26214400]
[11] parameter:[                net.14.bias] shape:[        4096]