# Tutorial 5.1: Inception, ResNet and DenseNet

In this tutorial, we will implement and discuss variants of modern CNN architectures.

We use PyTorch Lightning for the first time here.

In [3]:
## Standard libraries
import os
import json
import math
import numpy as np 
import scipy.linalg
import random

## Imports for plotting
import matplotlib.pyplot as plt
# %matplotlib inline 
# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('svg', 'pdf') # For export
# from matplotlib.colors import to_rgb
# import matplotlib
# matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
# from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateLogger, ModelCheckpoint

# Tensorboard extension (for visualization purposes later)
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10).
# Relative path: it is resolved against the working directory the notebook runs in.
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved.
# train_model looks for "<model_name>.ckpt" directly inside this folder and also
# uses "<CHECKPOINT_PATH>/<model_name>" as the Lightning root directory.
CHECKPOINT_PATH = "../saved_models/tutorial5"

# Setting the seed (the same call is repeated before each random_split further
# down so that both dataset splits are drawn from an identical random state)
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility.
# NOTE: the flag must be spelled `deterministic`. Assigning to a misspelled
# attribute is not reported as an error and leaves the real flag at its default.
torch.backends.cudnn.deterministic = True
# Disable the cuDNN auto-tuner so that the chosen kernels do not vary between runs
torch.backends.cudnn.benchmark = False

# Use the first CUDA device when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

Throughout this tutorial, we use the CIFAR10 dataset (it can be swapped for another dataset if necessary).

In [7]:
# Transformations applied on each image => make them a tensor and normalize
# Evaluation pipeline: turn each image into a tensor, then normalize it
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Training pipeline: same as above, preceded by light augmentation
# (random horizontal flip and a random resized crop back to 32x32),
# since the networks are powerful enough to overfit otherwise.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# The official training images are wrapped twice: once with augmentation (used for
# training) and once without (used for validation).
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)

# Re-seeding before each split makes both calls draw the same 45000/5000 partition,
# so taking the first part of one and the second part of the other gives
# non-overlapping training and validation subsets.
pl.seed_everything(42)
train_set, _ = data.random_split(train_dataset, [45000, 5000])
pl.seed_everything(42)
_, val_set = data.random_split(val_dataset, [45000, 5000])

# Held-out test images, evaluated without augmentation
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# General-purpose data loaders used throughout the notebook.
# Only the training loader shuffles and drops the last incomplete batch.
train_loader = data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)
val_loader = data.DataLoader(
    val_set,
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)
test_loader = data.DataLoader(
    test_set,
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## PyTorch Lightning

In [14]:
# Lookup table from an activation name to the corresponding torch.nn module class.
# The values are classes, not instances: the model code calls them to create layers.
act_fn_by_name = dict(tanh=nn.Tanh, relu=nn.ReLU, gelu=nn.GELU)

In [10]:
# Registry mapping a model name to the callable (usually the class) that builds it.
# Later cells fill it in, e.g. model_dict["GoogleNet"] = GoogleNet.
model_dict = {}

def create_model(model_name, model_hparams):
    """Instantiate a registered model by name.

    Args:
        model_name: Key to look up in ``model_dict``.
        model_hparams: Dict of keyword arguments forwarded to the registered callable.

    Returns:
        Whatever the registered callable returns for ``**model_hparams``.

    Raises:
        AssertionError: If ``model_name`` has not been registered in ``model_dict``.
    """
    if model_name not in model_dict:
        # Raise explicitly instead of using `assert False`: assert statements are
        # removed when Python runs with -O, which would turn an unknown name into
        # a silent `None` return. The exception type and message are unchanged.
        raise AssertionError("Unknown model name \"%s\". Available models are: %s" % (model_name, str(model_dict.keys())))
    return model_dict[model_name](**model_hparams)

In [11]:
class CIFARTrainer(pl.LightningModule):
    """Lightning module that trains and evaluates a registered classifier.

    The wrapped network is built through ``create_model`` and optimized with a
    cross-entropy loss. Validation and test steps report per-batch accuracy.
    """

    def __init__(self, model_name, model_hparams, lr):
        """
        Args:
            model_name: Name under which the network is registered in ``model_dict``.
            model_hparams: Keyword arguments forwarded to the network constructor.
            lr: Learning rate passed to the Adam optimizer.
        """
        super().__init__()
        # Records the constructor arguments; `lr` is later read via self.hparams.lr
        self.save_hyperparameters()
        # Network to be trained
        self.model = create_model(model_name, model_hparams)
        # Classification loss
        self.loss_module = nn.CrossEntropyLoss()
        # Dummy batch (1 image, 3 channels, 32x32) used for visualizing the graph
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        """Run the wrapped network on a batch of images and return its output."""
        return self.model(imgs)

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (step size 50, decay factor 0.2)."""
        adam = optim.Adam(self.parameters(), lr=self.hparams.lr)
        step_lr = optim.lr_scheduler.StepLR(adam, 50, gamma=0.2)
        return [adam], [step_lr]

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss of one batch and log it as 'train_loss'."""
        imgs, labels = batch
        loss = self.loss_module(self.model(imgs), labels)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def _batch_accuracy(self, batch):
        """Return the fraction of samples in ``batch`` whose argmax prediction matches the label."""
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        return (labels == preds).float().mean()

    def validation_step(self, batch, batch_idx):
        """Log the batch accuracy as 'val_acc'; the same value is passed as ``checkpoint_on``."""
        acc = self._batch_accuracy(batch)
        result = pl.EvalResult(checkpoint_on=acc)
        result.log('val_acc', acc)
        return result

    def test_step(self, batch, batch_idx):
        """Log the batch accuracy as 'test_acc'."""
        acc = self._batch_accuracy(batch)
        result = pl.EvalResult()
        result.log('test_acc', acc)
        return result

In [20]:
def train_model(model_name, **kwargs):
    """Train (or load) a CIFARTrainer for ``model_name`` and evaluate it.

    If ``<CHECKPOINT_PATH>/<model_name>.ckpt`` exists, the module is restored from that
    file and training is skipped; ``kwargs`` are not used in that case. Otherwise a new
    ``CIFARTrainer`` is created and fitted on ``train_loader`` / ``val_loader``.

    Args:
        model_name: Name of the registered model; also determines the checkpoint
            file name and the trainer's root directory.
        **kwargs: Remaining ``CIFARTrainer`` constructor arguments
            (``model_hparams`` and ``lr``).

    Returns:
        Tuple ``(model, result)`` where ``result`` is a dict with the keys ``"test"``
        and ``"val"`` holding what ``trainer.test`` returned for ``test_loader`` and
        ``val_loader`` respectively.
    """
    # Request a GPU only when CUDA is actually available, so the notebook also runs
    # on CPU-only machines (same fallback as the module-level `device`).
    num_gpus = 1 if torch.cuda.is_available() else 0
    # Create a PyTorch Lightning trainer with weight-only checkpointing and
    # per-epoch learning rate logging
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, model_name),
                         checkpoint_callback=ModelCheckpoint(save_weights_only=True, mode="max"),
                         gpus=num_gpus,
                         max_epochs=150,
                         callbacks=[LearningRateLogger("epoch")])
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = CIFARTrainer.load_from_checkpoint(pretrained_filename)
    else:
        model = CIFARTrainer(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
    # Evaluate the module held in `model` (loaded from file or as left by fit) on the
    # validation and test set. Both go through test_step, so both report 'test_acc'.
    val_result = trainer.test(model, test_dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, test_dataloaders=test_loader, verbose=False)
    result = {"test": test_result, "val": val_result}
    return model, result

## Inception

In [39]:
class InceptionBlock(nn.Module):
    """Inception block with four parallel branches concatenated along the channel axis.

    Branches: a 1x1 convolution, a 1x1 reduction followed by a 3x3 convolution, a 1x1
    reduction followed by a 5x5 convolution, and a 3x3 max pooling followed by a 1x1
    convolution. Every convolution is followed by batch normalization and the
    activation. All branches keep the spatial size of the input.
    """

    def __init__(self, c_in, c_red : dict, c_out : dict, act_fn):
        """
        Args:
            c_in: Number of input channels.
            c_red: Dict with keys "3x3" and "5x5" giving the channel count after the
                1x1 reduction in the respective branch.
            c_out: Dict with keys "1x1", "3x3", "5x5" and "max" giving the number of
                output channels of each branch.
            act_fn: Activation constructor (e.g. ``nn.ReLU``), called without arguments.
        """
        super().__init__()
        # 1x1 convolution branch
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(c_in, c_out["1x1"], kernel_size=1),
            nn.BatchNorm2d(c_out["1x1"]),
            act_fn()
        )

        # 3x3 convolution branch (1x1 reduction first)
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(c_in, c_red["3x3"], kernel_size=1),
            nn.BatchNorm2d(c_red["3x3"]),
            act_fn(),
            nn.Conv2d(c_red["3x3"], c_out["3x3"], kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out["3x3"]),
            act_fn()
        )

        # 5x5 convolution branch (1x1 reduction first)
        self.conv_5x5 = nn.Sequential(
            nn.Conv2d(c_in, c_red["5x5"], kernel_size=1),
            # The norm follows the reduction conv, so it must be sized by c_red, not
            # c_out; the two only coincide when both happen to be equal.
            nn.BatchNorm2d(c_red["5x5"]),
            act_fn(),
            nn.Conv2d(c_red["5x5"], c_out["5x5"], kernel_size=5, padding=2),
            nn.BatchNorm2d(c_out["5x5"]),
            act_fn()
        )

        # Max-pool branch (stride 1 and padding 1 keep the spatial size)
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(c_in, c_out["max"], kernel_size=1),
            nn.BatchNorm2d(c_out["max"]),
            act_fn()
        )


    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate their outputs on dim 1.

        The result has ``sum(c_out.values())`` channels.
        """
        x_1x1 = self.conv_1x1(x)
        x_3x3 = self.conv_3x3(x)
        x_5x5 = self.conv_5x5(x)
        x_max = self.max_pool(x)
        x_out = torch.cat([x_1x1, x_3x3, x_5x5, x_max], dim=1)
        return x_out

In [40]:
class GoogleNet(nn.Module):
    """Small Inception-style classifier for 3-channel images.

    Layout: a 3x3 convolution stem, eight InceptionBlocks (128 output channels each)
    interleaved with two stride-2 max poolings, then global average pooling and a
    linear classifier.
    """

    def __init__(self, num_classes=10, act_fn_name="relu", **kwargs):
        """
        Args:
            num_classes: Size of the output layer.
            act_fn_name: Key into ``act_fn_by_name`` selecting the activation.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__()
        self.hparams = {"num_classes": num_classes, "act_fn": act_fn_by_name[act_fn_name]}
        self._create_network()

    def _create_network(self):
        """Build the stem, the inception stack and the classifier head into ``self.net``."""
        act_fn = self.hparams["act_fn"]

        def inception(c_in, c_red_3x3):
            # All blocks share the 5x5 reduction and the output widths
            # (32 + 64 + 16 + 16 = 128 channels after concatenation).
            return InceptionBlock(
                c_in,
                c_red={"3x3": c_red_3x3, "5x5": 16},
                c_out={"1x1": 32, "3x3": 64, "5x5": 16, "max": 16},
                act_fn=act_fn,
            )

        # Stem: lift the 3 input channels to 64 feature maps
        stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            act_fn(),
        )
        # Inception stack; each max pooling halves the spatial resolution
        body = nn.Sequential(
            inception(64, 32),
            inception(128, 64),
            nn.MaxPool2d(3, stride=2, padding=1),
            inception(128, 64),
            inception(128, 64),
            inception(128, 64),
            inception(128, 64),
            nn.MaxPool2d(3, stride=2, padding=1),
            inception(128, 64),
            inception(128, 64),
        )
        # Head: global average pooling followed by the linear classifier
        head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, self.hparams["num_classes"]),
        )
        self.net = nn.Sequential(stem, body, head)

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        return self.net(x)

In [41]:
# Register the class under the name that create_model looks up in model_dict
model_dict["GoogleNet"] = GoogleNet

In [42]:
# Empty model_hparams => GoogleNet is built with its constructor defaults.
# train_model skips training if CHECKPOINT_PATH/GoogleNet.ckpt already exists.
googlenet_model, googlenet_results = train_model(model_name="GoogleNet", model_hparams={}, lr=1e-3)

GPU available: True, used: True
I1010 20:05:33.196611 139734903662400 distributed.py:41] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I1010 20:05:33.198423 139734903662400 distributed.py:41] TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
I1010 20:05:33.199871 139734903662400 distributed.py:41] CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | GoogleNet        | 456 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?              | ?        
I1010 20:05:36.200331 139734903662400 lightning.py:1449] 
  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | GoogleNet        | 456 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?  

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..
I1010 20:06:43.394120 139734903662400 training_loop.py:1136] Saving latest checkpoint..





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…




In [43]:
# Both entries are keyed 'test_acc': train_model evaluates the validation loader
# through trainer.test as well, i.e. via CIFARTrainer.test_step.
print(googlenet_results)

{'test': [{'test_acc': 0.6590366363525391}],
 'val': [{'test_acc': 0.6532832384109497}]}

## ResNet

In [21]:
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers added to a shortcut of the input.

    With ``subsample=True`` the block halves the spatial resolution (stride 2 in the
    first convolution) and doubles the channel count; the shortcut is then strided
    and zero-padded along the channel axis to match.
    """

    def __init__(self, c_in, act_fn, subsample=False):
        """
        Args:
            c_in: Number of input channels.
            act_fn: Activation constructor (e.g. ``nn.ReLU``), called without arguments.
            subsample: If True, output has ``2 * c_in`` channels at half resolution.
        """
        super().__init__()
        if subsample:
            c_out, first_stride = 2 * c_in, 2
        else:
            c_out, first_stride = c_in, 1
        # Residual branch; no activation after the second batch norm because the
        # activation is applied after the addition in forward().
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_in, kernel_size=3, padding=1, stride=first_stride),
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out)
        )
        # Average pooling with kernel size 1 and stride 2 keeps only every second
        # value per spatial axis, i.e. it acts as an identity mapping with stride.
        self.downsample = nn.AvgPool2d(kernel_size=1, stride=2) if subsample else None
        self.act_fn = act_fn()

    def forward(self, x):
        """Return ``act_fn(net(x) + shortcut(x))``."""
        residual = self.net(x)
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
            # Double the channels by appending an all-zero copy of the shortcut
            shortcut = torch.cat([shortcut, torch.zeros_like(shortcut)], dim=1)
        return self.act_fn(residual + shortcut)

In [25]:
class ResNet(nn.Module):
    """ResNet-style classifier built from three stacks of ResNetBlocks.

    The channel width is ``c_hidden`` in the first stack and doubles in each of the
    two following stacks, whose first block subsamples the feature map.
    """

    def __init__(self, num_classes=10, c_hidden=32, num_blocks=3, act_fn_name="relu", **kwargs):
        """
        Args:
            num_classes: Size of the output layer.
            c_hidden: Channel width of the first stack.
            num_blocks: Number of ResNetBlocks per stack (including the subsampling
                block in stacks two and three).
            act_fn_name: Key into ``act_fn_by_name`` selecting the activation.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__()
        self.hparams = {"num_classes": num_classes, "c_hidden": c_hidden, "num_blocks": num_blocks, "act_fn": act_fn_by_name[act_fn_name]}
        self._create_network()
        # kaiming_normal_ only accepts nonlinearity names known to
        # nn.init.calculate_gain and raises ValueError for others (e.g. "gelu",
        # which act_fn_by_name offers). Fall back to the ReLU gain in that case so
        # every selectable activation can be used; supported names behave as before.
        try:
            nn.init.calculate_gain(act_fn_name)
            init_nonlinearity = act_fn_name
        except ValueError:
            init_nonlinearity = "relu"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=init_nonlinearity)


    def _create_network(self):
        """Build stem, residual stacks and classifier head into ``self.net``."""
        c_hidden = self.hparams["c_hidden"]
        # Stem: lift the 3 input channels to c_hidden feature maps
        input_net = nn.Sequential(
            nn.Conv2d(3, c_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_hidden),
            self.hparams["act_fn"]()
        )
        # Stack 1: num_blocks blocks at width c_hidden, full resolution
        stack1 = nn.Sequential(*[ResNetBlock(c_in=c_hidden,
                                             act_fn=self.hparams["act_fn"])
                                     for _ in range(self.hparams["num_blocks"])])
        # Stack 2: one subsampling block (c_hidden -> 2*c_hidden), then
        # num_blocks-1 regular blocks at the new width
        stack2a = ResNetBlock(c_in=c_hidden,
                                   act_fn=self.hparams["act_fn"],
                                   subsample=True)
        stack2b = nn.Sequential(*[ResNetBlock(c_in=2*c_hidden,
                                              act_fn=self.hparams["act_fn"])
                                     for _ in range(self.hparams["num_blocks"]-1)])
        # Stack 3: one subsampling block (2*c_hidden -> 4*c_hidden), then
        # num_blocks-1 regular blocks at the new width
        stack3a = ResNetBlock(c_in=2*c_hidden,
                                   act_fn=self.hparams["act_fn"],
                                   subsample=True)
        stack3b = nn.Sequential(*[ResNetBlock(c_in=4*c_hidden,
                                              act_fn=self.hparams["act_fn"])
                                     for _ in range(self.hparams["num_blocks"]-1)])
        # Head: global average pooling followed by the linear classifier
        output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(4*c_hidden, self.hparams["num_classes"])
        )

        self.net = nn.Sequential(
            input_net,
            stack1,
            stack2a,
            stack2b,
            stack3a,
            stack3b,
            output_net
        )


    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        return self.net(x)

In [26]:
# Register the class under the name that create_model looks up in model_dict
model_dict["ResNet"] = ResNet

In [27]:
# Empty model_hparams => ResNet is built with its constructor defaults.
# train_model skips training if CHECKPOINT_PATH/ResNet.ckpt already exists.
resnet_model, resnet_results = train_model(model_name="ResNet", model_hparams={}, lr=1e-3)

GPU available: True, used: True
I1010 19:31:18.436130 139734903662400 distributed.py:41] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I1010 19:31:18.438191 139734903662400 distributed.py:41] TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
I1010 19:31:18.439631 139734903662400 distributed.py:41] CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | ResNet           | 235 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?              | ?        
I1010 19:31:18.873739 139734903662400 lightning.py:1449] 
  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | ResNet           | 235 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?  

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..
I1010 19:31:23.234957 139734903662400 training_loop.py:1136] Saving latest checkpoint..





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…




In [44]:
# Both entries are keyed 'test_acc': train_model evaluates the validation loader
# through trainer.test as well, i.e. via CIFARTrainer.test_step.
print(resnet_results)

{'test': [{'test_acc': 0.44416800141334534}],
 'val': [{'test_acc': 0.447191447019577}]}

## DenseNet

In [92]:
class DenseLayer(nn.Module):
    """Single DenseNet layer: computes ``growth_rate`` new feature maps and prepends
    them to its input along the channel axis."""

    def __init__(self, c_in, bn_size, growth_rate, act_fn):
        """
        Args:
            c_in: Number of input channels.
            bn_size: Bottleneck size (factor of growth rate); the 1x1 convolution
                outputs ``bn_size * growth_rate`` channels.
            growth_rate: Number of channels produced by the final 3x3 convolution.
            act_fn: Activation constructor (e.g. ``nn.ReLU``), called without arguments.
        """
        super().__init__()
        c_bottleneck = bn_size * growth_rate
        # Pre-activation ordering: norm -> activation -> conv, applied twice
        # (1x1 bottleneck, then 3x3 with padding so the spatial size is kept).
        steps = [
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_bottleneck, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_bottleneck),
            act_fn(),
            nn.Conv2d(c_bottleneck, growth_rate, kernel_size=3, padding=1, bias=False),
        ]
        self.net = nn.Sequential(*steps)

    def forward(self, x):
        """Return ``[net(x), x]`` concatenated on dim 1 (``growth_rate + c_in`` channels)."""
        new_features = self.net(x)
        return torch.cat([new_features, x], dim=1)

In [93]:
class DenseBlock(nn.Module):
    """Sequence of DenseLayers in which every layer sees all previously produced features."""

    def __init__(self, c_in, num_layers, bn_size, growth_rate, act_fn):
        """
        Args:
            c_in: Number of channels entering the block.
            num_layers: Number of DenseLayers to chain.
            bn_size: Bottleneck factor forwarded to every DenseLayer.
            growth_rate: Channels added by every DenseLayer.
            act_fn: Activation constructor forwarded to every DenseLayer.
        """
        super().__init__()
        # Each DenseLayer appends growth_rate channels, so the i-th layer receives
        # c_in + i * growth_rate input channels.
        self.block = nn.Sequential(*(
            DenseLayer(c_in=c_in + idx * growth_rate,
                       bn_size=bn_size,
                       growth_rate=growth_rate,
                       act_fn=act_fn)
            for idx in range(num_layers)
        ))

    def forward(self, x):
        """Run ``x`` through all layers of the block and return the result."""
        return self.block(x)

In [94]:
class ReductionLayer(nn.Module):
    """Layer placed between dense blocks: changes the channel count with a 1x1
    convolution and halves the spatial resolution with 2x2 average pooling."""

    def __init__(self, c_in, c_out, act_fn):
        """
        Args:
            c_in: Number of input channels.
            c_out: Number of output channels of the 1x1 convolution.
            act_fn: Activation constructor (e.g. ``nn.ReLU``), called without arguments.
        """
        super().__init__()
        steps = [
            nn.BatchNorm2d(c_in),
            act_fn(),
            # Bias-free 1x1 convolution that maps c_in to c_out channels
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            # Non-overlapping 2x2 average pooling
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.reduction = nn.Sequential(*steps)

    def forward(self, x):
        """Return the reduced feature map for ``x``."""
        reduced = self.reduction(x)
        return reduced

In [95]:
class DenseNet(nn.Module):
    """DenseNet-style classifier: a conv stem, several DenseBlocks separated by
    ReductionLayers, and a global-average-pooling + linear head."""

    def __init__(self, num_classes=10, num_layers=(4, 4, 4, 4), bn_size=2, growth_rate=16, act_fn_name="relu", **kwargs):
        """
        Args:
            num_classes: Size of the output layer.
            num_layers: Sequence with the number of DenseLayers per DenseBlock; its
                length is the number of blocks.
            bn_size: Bottleneck factor used inside every DenseLayer.
            growth_rate: Channels added by every DenseLayer.
            act_fn_name: Key into ``act_fn_by_name`` selecting the activation.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__()
        # The default is an immutable tuple and the value is copied into a fresh
        # list, so instances never share (or mutate) a common default/caller list.
        self.hparams = {"num_classes": num_classes,
                        "num_layers": list(num_layers),
                        "bn_size": bn_size,
                        "growth_rate": growth_rate,
                        "act_fn": act_fn_by_name[act_fn_name]}
        self._create_network()

    def _create_network(self):
        """Build stem, dense blocks with reductions, and the head into ``self.net``."""
        # The stem starts with as many channels as a DenseLayer bottleneck
        c_hidden = self.hparams["growth_rate"] * self.hparams["bn_size"]
        input_net = nn.Sequential(
            nn.Conv2d(3, c_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_hidden),
            self.hparams["act_fn"]()
        )
        blocks = []
        for block_idx, num_layers in enumerate(self.hparams["num_layers"]):
            blocks.append(
                DenseBlock(c_in=c_hidden,
                           num_layers=num_layers,
                           bn_size=self.hparams["bn_size"],
                           growth_rate=self.hparams["growth_rate"],
                           act_fn=self.hparams["act_fn"])
            )
            # Every DenseLayer in the block added growth_rate channels
            c_hidden = c_hidden + num_layers * self.hparams["growth_rate"]
            # Halve channels and resolution between blocks, but not after the last one
            if block_idx < len(self.hparams["num_layers"])-1:
                blocks.append(
                    ReductionLayer(c_in=c_hidden,
                                   c_out=c_hidden // 2,
                                   act_fn=self.hparams["act_fn"]))
                c_hidden = c_hidden // 2
        blocks = nn.Sequential(*blocks)
        # Head: global average pooling followed by the linear classifier
        output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(c_hidden, self.hparams["num_classes"])
        )
        self.net = nn.Sequential(
            input_net,
            blocks,
            output_net
        )

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        return self.net(x)

In [96]:
# Register the class under the name that create_model looks up in model_dict
model_dict["DenseNet"] = DenseNet

In [97]:
# Empty model_hparams => DenseNet is built with its constructor defaults.
# train_model skips training if CHECKPOINT_PATH/DenseNet.ckpt already exists.
densenet_model, densenet_results = train_model(model_name="DenseNet", model_hparams={}, lr=1e-3)

GPU available: True, used: True
I1010 20:38:49.941073 139734903662400 distributed.py:41] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I1010 20:38:49.943002 139734903662400 distributed.py:41] TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
I1010 20:38:49.944466 139734903662400 distributed.py:41] CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | DenseNet         | 135 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?              | ?        
I1010 20:38:50.764374 139734903662400 lightning.py:1449] 
  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | DenseNet         | 135 K  | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?  

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..
I1010 20:39:47.688781 139734903662400 training_loop.py:1136] Saving latest checkpoint..





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…




In [98]:
# Both entries are keyed 'test_acc': train_model evaluates the validation loader
# through trainer.test as well, i.e. via CIFARTrainer.test_step.
print(densenet_results)

{'test': [{'test_acc': 0.7139729261398315}], 'val': [{'test_acc': 0.7143987417221069}]}
