In [None]:
import torch
import torch.nn as nn

# Passed as the ``affine`` argument to every nn.BatchNorm2d built in this file:
# when True, each BatchNorm layer owns a scale (gamma) and a shift (beta) tensor.
affine_par = True


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus skip.

    Instead of running one expensive convolution at full width (for example
    256 -> 256 channels), the block:

    1. squeezes the channels with a 1x1 convolution (e.g. 256 -> 64),
    2. processes the narrow tensor with a (possibly dilated) 3x3 convolution,
    3. expands back to ``planes * expansion`` channels with a 1x1 convolution,
    4. adds the (optionally projected) input to the result and applies ReLU.

    The parameters of all three BatchNorm layers get ``requires_grad = False``,
    so an optimizer that filters on ``requires_grad`` never updates them.

    Args:
        inplanes: number of input channels.
        planes: number of channels inside the bottleneck; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the first 1x1 convolution (spatial downsampling).
        dilation: dilation of the 3x3 convolution. The padding is set to the
            same value, which keeps the spatial size through that convolution.
        downsample: optional module applied to the input before the residual
            addition; ``None`` means the input is added unchanged.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Derive the output width from the class attribute so it always agrees
        # with code that sizes the skip projection via ``block.expansion``.
        out_planes = planes * self.expansion

        # 1x1 convolution: channel compression (also carries the stride).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        # 3x3 convolution on the compressed tensor; padding == dilation.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        # 1x1 convolution: channel expansion.
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes, affine=affine_par)

        # Freeze the BatchNorm scale/shift so they are not trained.
        for norm in (self.bn1, self.bn2, self.bn3):
            for param in norm.parameters():
                param.requires_grad = False

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + skip(x))`` where ``F`` is the conv/BN stack."""
        # Keep the input for the skip connection.
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Project the input when a downsample module was supplied, so that it
        # can be added to ``out``.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return self.relu(out)


class ClassifierModule(nn.Module):
    """Classification head made of parallel dilated 3x3 convolutions (ASPP-style).

    Every branch maps ``inplanes`` channels to ``num_classes`` channels with its
    own dilation/padding pair; the branch outputs are summed element-wise.

    Args:
        inplanes: number of input channels.
        dilation_series: dilation of each branch (the spacing between kernel
            taps of the atrous convolution).
        padding_series: padding of each branch, paired position-wise with
            ``dilation_series``.
        num_classes: number of output channels (one score map per class).
    """

    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super(ClassifierModule, self).__init__()
        # One convolution per (dilation, padding) pair; all of them are built
        # first and initialised afterwards.
        branches = [
            nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1,
                      padding=pad, dilation=rate, bias=True)
            for rate, pad in zip(dilation_series, padding_series)
        ]
        self.conv2d_list = nn.ModuleList(branches)

        # Small Gaussian initialisation of the branch weights.
        for branch in self.conv2d_list:
            branch.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply every branch to ``x`` and return the sum of their outputs."""
        branches = list(self.conv2d_list)
        fused = branches[0](x)
        # The maps (outputs of the parallel convolutional branches) all have
        # the same size when each padding is chosen to match its dilation, so
        # they can be accumulated directly. Each branch has a different field
        # of view, so the sum mixes local and wider context.
        for branch in branches[1:]:
            fused += branch(x)
        return fused


class ResNetMulti(nn.Module):
    """ResNet modified into a DeepLab-style segmentation network.

    Layout, as built in ``__init__``:

    * stem: 7x7 convolution (stride 2), BatchNorm, ReLU, 3x3 max-pool (stride 2);
    * ``layer1`` / ``layer2``: residual stages, ``layer2`` with stride 2;
    * ``layer3`` / ``layer4``: residual stages with stride 1 and dilation 2 / 4,
      which enlarge the receptive field without further downsampling;
    * ``layer6``: ``ClassifierModule`` with four parallel dilated convolutions
      (dilations 6, 12, 18, 24) acting on 2048 channels.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(inplanes, planes, stride, dilation=..., downsample=...)``.
        layers: number of blocks in each of the four residual stages.
        num_classes: number of output channels of the classifier.
    """

    def __init__(self, block, layers, num_classes):
        super(ResNetMulti, self).__init__()
        # Running count of channels entering the next stage; updated by
        # ``_make_layer``.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        for param in self.bn1.parameters():
            param.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # Up to layer2 the network downsamples. layer3 and layer4 keep the
        # resolution (stride 1) and use dilated kernels instead.
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        # layer6 applies the parallel dilated convolutions (ASPP) that give a
        # multi-scale view of the features.
        self.layer6 = ClassifierModule(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)

        # (Re-)initialise every convolution and BatchNorm layer of the network.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0, 0.01)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one residual stage as an ``nn.Sequential`` of ``blocks`` blocks.

        Args:
            block: residual block class (e.g. ``Bottleneck``).
            planes: channel count inside each block, before expansion.
            blocks: how many residual blocks to stack.
            stride: stride of the first block (spatial downsampling).
            dilation: dilation passed to every block.

        Side effect: sets ``self.inplanes`` to ``planes * block.expansion``.
        """
        out_planes = planes * block.expansion

        # The first block needs a projection on its skip path when the stride
        # or the channel count changes; one is also always built for the
        # dilated stages (dilation 2 or 4).
        downsample = None
        if stride != 1 or self.inplanes != out_planes or dilation in (2, 4):
            projection = nn.Conv2d(self.inplanes, out_planes,
                                   kernel_size=1, stride=stride, bias=False)
            norm = nn.BatchNorm2d(out_planes, affine=affine_par)
            # Freeze the BatchNorm parameters. This is done only when the
            # projection exists; doing it unconditionally would fail on None.
            for param in norm.parameters():
                param.requires_grad = False
            downsample = nn.Sequential(projection, norm)

        stage = [block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)]
        self.inplanes = out_planes
        # Remaining blocks keep the channel count and use the default stride.
        stage.extend(block(self.inplanes, planes, dilation=dilation)
                     for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Compute per-pixel class scores at the spatial size of ``x``.

        Returns:
            ``(scores, None, None)`` while the module is in training mode,
            otherwise just ``scores``.
        """
        _, _, H, W = x.size()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer6(x)

        # The deep features are smaller than the input because of the strided
        # layers, so resize the scores back to the input size.
        x = torch.nn.functional.interpolate(x, size=(H, W), mode='bilinear')

        if self.training:
            return x, None, None
        return x

    def get_1x_lr_params_no_scale(self):
        """Yield the trainable backbone parameters (everything except ``layer6``).

        Parameters with ``requires_grad`` set to False are skipped; the
        constructors in this file freeze the BatchNorm parameters that way.

        Each parameter is yielded exactly once. ``Module.parameters()`` already
        walks sub-modules recursively, so it is called once per top-level
        backbone module instead of once per nested sub-module.
        """
        backbone = (self.conv1, self.bn1, self.layer1,
                    self.layer2, self.layer3, self.layer4)
        for module in backbone:
            for param in module.parameters():
                if param.requires_grad:
                    yield param

    def get_10x_lr_params(self):
        """Yield the parameters of the classification head(s).

        ``layer6`` is always included. ``layer5`` is included only when the
        instance defines a truthy ``multi_level`` attribute; this class does
        not set one itself, so by default only ``layer6`` is used.
        """
        heads = []
        if getattr(self, 'multi_level', False):
            heads.append(self.layer5)
        heads.append(self.layer6)

        for head in heads:
            for param in head.parameters():
                yield param

    def optim_parameters(self, lr):
        """Return two optimizer parameter groups: backbone at ``lr``, head at ``10 * lr``."""
        return [{'params': self.get_1x_lr_params_no_scale(), 'lr': lr},
                {'params': self.get_10x_lr_params(), 'lr': 10 * lr}]


def get_deeplab_v2(num_classes=19, pretrain=True, pretrain_model_path='DeepLab_resnet_pretrained_imagenet.pth'):
    """Build a DeepLabV2-like model (ResNet-101 layout) and optionally load weights.

    Args:
        num_classes: number of segmentation classes (default 19, e.g. Cityscapes).
        pretrain: whether to load pretrained weights from ``pretrain_model_path``.
        pretrain_model_path: path to the checkpoint holding the pretrained
            state dict.

    Returns:
        The ``ResNetMulti`` model, with or without pretrained weights. The
        pretrained weights give the training phase a better starting point.
    """
    model = ResNetMulti(Bottleneck, [3, 4, 23, 3], num_classes)
    if not pretrain:
        return model

    print('Deeplab pretraining loading...')
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    # map_location='cpu' lets a checkpoint saved from a GPU be read on a host
    # without one; load_state_dict below copies the values into the model's
    # own tensors, so the loaded weights end up the same.
    saved_state_dict = torch.load(pretrain_model_path, map_location='cpu')

    new_params = model.state_dict().copy()
    for key, value in saved_state_dict.items():
        # Drop the first dot-separated component of the checkpoint key so the
        # remainder can line up with this model's parameter names.
        new_params[key.partition('.')[2]] = value
    # strict=False: checkpoint entries with no counterpart in the model are
    # ignored instead of raising.
    model.load_state_dict(new_params, strict=False)
    return model