# PSPNet 

## Set up
1. Download the pretrained ResNet-50 weights. This is the initial model; the code below builds the PSPNet architecture on top of it.
2. Unzip the CamVid data.
3. Check the Python version.

In [1]:
# Download the ResNet-50 weights directly into initmodel/, where resnet50() loads them from.
# `mkdir -p` keeps this cell re-runnable (plain `mkdir` fails if the directory already exists).
# TLS certificate checking is left enabled: the file is later deserialized with torch.load,
# so it must not be fetched over an unverified connection.
!mkdir -p initmodel
!wget -O "initmodel/resnet50_v2.pth" 'https://docs.google.com/uc?export=download&id=1w5pRmLJXvmQQA5PtCbHhZc_uC4o0YbmA'

--2023-08-07 20:25:22--  https://docs.google.com/uc?export=download&id=1w5pRmLJXvmQQA5PtCbHhZc_uC4o0YbmA
Resolving docs.google.com (docs.google.com)... 2607:f8b0:4002:c09::64, 2607:f8b0:4002:c09::8b, 2607:f8b0:4002:c09::8a, ...
Connecting to docs.google.com (docs.google.com)|2607:f8b0:4002:c09::64|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-14-c8-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/nm67gdbkm1gipj06ljdqeq111nd1bije/1691465100000/15543419006824780045/*/1w5pRmLJXvmQQA5PtCbHhZc_uC4o0YbmA?e=download&uuid=edf5bc75-a33d-4c90-a661-b416db5916e2 [following]
--2023-08-07 20:25:27--  https://doc-14-c8-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/nm67gdbkm1gipj06ljdqeq111nd1bije/1691465100000/15543419006824780045/*/1w5pRmLJXvmQQA5PtCbHhZc_uC4o0YbmA?e=download&uuid=edf5bc75-a33d-4c90-a661-b416db5916e2
Resolving doc-14-c8-docs.googleusercontent.com (doc-14-c8-docs.googleusercontent.co

In [2]:
# Extract the CamVid labels into Camvid/ (producing Camvid/semseg11/).
# -o overwrites existing files instead of stopping at an interactive prompt when the cell is re-run;
# -d replaces the `cd` dance (each `!` command runs in its own subshell anyway).
!unzip -o Camvid/camvid_semseg11.zip -d Camvid

Archive:  camvid_semseg11.zip
   creating: semseg11/
  inflating: semseg11/0016E5_08085_L.png  
  inflating: semseg11/0006R0_f01080_L.png  
  inflating: semseg11/0016E5_00480_L.png  
  inflating: semseg11/0016E5_01350_L.png  
  inflating: semseg11/0016E5_01560_L.png  
  inflating: semseg11/0016E5_01440_L.png  
  inflating: semseg11/0016E5_06150_L.png  
  inflating: semseg11/0016E5_08045_L.png  
  inflating: semseg11/0016E5_00390_L.png  
  inflating: semseg11/0016E5_07680_L.png  
  inflating: semseg11/0016E5_08139_L.png  
  inflating: semseg11/0016E5_08280_L.png  
  inflating: semseg11/0016E5_01110_L.png  
 extracting: semseg11/0016E5_08550_L.png  
  inflating: semseg11/0006R0_f02850_L.png  
  inflating: semseg11/0016E5_05940_L.png  
  inflating: semseg11/0016E5_07230_L.png  
  inflating: semseg11/0016E5_00720_L.png  
  inflating: semseg11/0006R0_f01320_L.png  
  inflating: semseg11/0016E5_08061_L.png  
  inflating: semseg11/0016E5_08019_L.png  
  inflating: semseg11/0016E5_04980_L.png 

In [4]:
import platform

# Report the version of the interpreter running this notebook kernel;
# `!python3 --version` would report whichever python3 is first on PATH instead.
print(f"Python {platform.python_version()}")

Python 3.8.10


In [5]:
from types import SimpleNamespace

In [2]:
import os
import torch
from torch import nn

# Build Resnet50


In [3]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1):
    """Create a bias-free 3x3 convolution with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides
    downsample by that factor.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

In [4]:
class ResNet(nn.Module):
    """ResNet backbone with an optional deep-base stem.

    Args:
        block: Union[BasicBlock, Bottleneck] residual block class. It must define
            a class attribute ``expansion`` (a stage outputs
            ``planes * expansion`` channels) and be constructible as
            ``block(inplanes, planes, stride, downsample)``.
        layers: List[int] number of blocks in each of the four stages.
        num_classes: output size of the final fully connected layer.
        deep_base: if True, the stem is three 3x3 convolutions (3->64->64->128
            channels); otherwise a single 7x7 stride-2 convolution (3->64).
    """

    def __init__(
        self, block, layers, num_classes: int = 1000, deep_base: bool = True
    ) -> None:
        super(ResNet, self).__init__()
        self.deep_base = deep_base
        if not self.deep_base:
            self.inplanes = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
        else:
            self.inplanes = 128
            self.conv1 = conv3x3(3, 64, stride=2)
            self.bn1 = nn.BatchNorm2d(64)
            self.conv2 = conv3x3(64, 64)
            self.bn2 = nn.BatchNorm2d(64)
            self.conv3 = conv3x3(64, 128)
            self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, planes=64,  blocks=layers[0])
        self.layer2 = self._make_layer(block, planes=128, blocks=layers[1], stride=2)
        self.layer3 = self._make_layer(block, planes=256, blocks=layers[2], stride=2)
        self.layer4 = self._make_layer(block, planes=512, blocks=layers[3], stride=2)
        # Fixed 7x7 window: matches the 7x7 layer4 map produced by 224x224 inputs.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        """Build one ResNet stage made of ``blocks`` residual blocks.

        Args:
            block: residual block class (see the class docstring).
            planes: base channel width of the stage; the stage outputs
                ``planes * block.expansion`` channels.
            blocks: number of times the block is repeated sequentially.
            stride: stride passed to the first block. The remaining blocks are
                built with the block's own default stride.

        Returns:
            An ``nn.Sequential`` of the blocks. As a side effect,
            ``self.inplanes`` is updated to the stage's output channels so the
            next stage chains onto it.
        """
        downsample = None
        # The shortcut needs a projection whenever the first block changes the
        # spatial size (stride) or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full classifier: stem -> layer1..layer4 -> avgpool -> fc.

        NOTE: ``avgpool`` is a fixed 7x7 window with stride 1, so the result
        has shape (N, num_classes) only when layer4 outputs a 7x7 map. With the
        default strides (total stride 32) that means 224x224 inputs.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        if self.deep_base:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [5]:
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    The block emits ``planes * expansion`` channels. ``stride`` is applied by
    the middle 3x3 convolution. When the shortcut must change shape, the caller
    supplies ``downsample``, which is applied to the block input before the
    residual addition.
    """

    expansion = 4

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample = None):
        super().__init__()
        widened = planes * self.expansion
        # Attribute names become state_dict keys (pretrained weights are loaded
        # by name), and registration order fixes parameter init order.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)

In [6]:
def resnet50(
    pretrained: bool = False,
    *,
    model_path: str = "./initmodel/resnet50_v2.pth",
    **kwargs,
):
    """Construct a ResNet-50 (Bottleneck blocks, stage depths [3, 4, 6, 3]).

    Args:
        pretrained: if True, load the weights stored at ``model_path``.
        model_path: checkpoint file containing a state dict. Defaults to the
            file the setup cell downloads into ``./initmodel/``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``deep_base``).

    Returns:
        The constructed ``ResNet`` model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # The freshly built model lives on the CPU. Mapping the checkpoint there
        # too lets GPU-saved weights load on CPU-only machines.
        state_dict = torch.load(model_path, map_location="cpu")
        # strict=False tolerates extra/absent keys. A silent mismatch (e.g. a
        # "module."-prefixed checkpoint) would leave layers randomly
        # initialised, so report missing keys.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            import warnings

            warnings.warn(
                f"{len(incompatible.missing_keys)} parameter(s) were not found in "
                f"{model_path} and keep their random initialisation, e.g. "
                f"{incompatible.missing_keys[:3]}"
            )
    return model

## Pyramid Pooling
### About Pyramid Pooling
The pyramid pooling module provides a hierarchical global prior that contains information at different scales and varies among different sub-regions.
The pyramid pooling module fuses features under four different pyramid scales.
The coarsest level (highlighted in red in the PSPNet paper's figure) is global pooling, which generates a single-bin output. The following pyramid levels separate the feature map into different sub-regions and form pooled representations for different locations. The outputs of the different levels in the pyramid pooling module are feature maps of varied sizes. To maintain the weight of the global feature, a 1×1 convolution layer is applied after each pyramid level to reduce the dimension of the context representation to 1/N of the original, where N is the number of pyramid levels.
### Algorithm for each pooling level
1. Use nn.AdaptiveAvgPool2d(bin) to split the feature map into (bin x bin) sub-regions.
2. Average-pool all entries inside each sub-region.
3. Apply a convolutional layer with a 1x1 kernel, followed by 2D batch norm.
4. Apply ReLU to the output.
### Algorithm for forwarding
1. Upsample each level's output to the input's spatial size and append it to a list.
2. Concatenate the input feature map and the upsampled outputs along the channel dimension.

In [7]:
class PPM(nn.Module):
    """Pyramid Pooling Module.

    For each bin size ``b`` the input is average-pooled to a ``b x b`` grid,
    projected to ``reduction_dim`` channels with 1x1 conv + BatchNorm + ReLU,
    upsampled back to the input's spatial size, and concatenated (after the
    input itself) along the channel dimension.
    """

    def __init__(self, in_dim: int, reduction_dim: int, bins) -> None:
        """
        Args:
            in_dim: number of channels of the input feature map.
            reduction_dim: number of channels each pyramid branch produces.
            bins: iterable of grid sizes, e.g. (1, 2, 3, 6); one branch is
                created per entry, pooling to a (bin x bin) grid.
        """
        super(PPM, self).__init__()
        self.reduction_dim = reduction_dim
        self.features = nn.ModuleList(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(bin_size),
                nn.Conv2d(in_dim, reduction_dim, 1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(),
            )
            for bin_size in bins
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape (N, in_dim, H, W)

        Returns:
            out: tensor of shape (N, C, H, W) where
                C = in_dim + len(bins) * reduction_dim
        """
        size = tuple(x.shape[-2:])
        outputs = [x]
        for branch in self.features:
            # Bilinear upsampling of the pooled context, as in the PSPNet paper
            # (nn.Upsample's default nearest mode produces blocky priors).
            outputs.append(
                nn.functional.interpolate(
                    branch(x), size=size, mode="bilinear", align_corners=True
                )
            )
        # One concatenation instead of re-copying the growing tensor per level.
        return torch.cat(outputs, dim=1)

## Build PSPNet
The final feature map size is 1/8 of the input image.
### 1) Set up the model with ResNet50
The model has five stages (layer0 to layer4) taken from the pretrained deep-base ResNet-50. layer0 is the stem: conv1 -> bn1 -> relu -> conv2 -> bn2 -> relu -> conv3 -> bn3 -> relu -> maxpool.


In [9]:
class PSPNet(nn.Module):
    def __init__(
        self,
        layers: int = 50,
        bins=(1, 2, 3, 6),
        dropout: float = 0.1,
        num_classes: int = 2,
        zoom_factor: int = 8,
        use_ppm: bool = True,
        criterion=None,
        pretrained: bool = True,
        deep_base: bool = True,
    ) -> None:
        """Build PSPNet on top of a dilated ResNet-50 backbone.

        Args:
            layers: depth of the ResNet backbone; only 50 is supported.
            bins: grid sizes for the PPM, e.g. (1, 2, 3) creates (1x1), (2x2)
                and (3x3) pooling grids. ``2048`` must be divisible by
                ``len(bins)``.
            dropout: dropout probability, stored as ``self.dropout``
                (presumably consumed by ``__create_classifier`` — confirm).
            num_classes: number of segmentation classes (must be > 1).
            zoom_factor: one of 1, 2, 4, 8. NOTE(review): presumably the factor
                by which forward upsamples the 1/8-resolution logits (8 gives
                full input resolution); confirm once forward is written.
            use_ppm: whether to build the Pyramid Pooling Module after layer4.
                ``self.ppm`` only exists when this is True.
            criterion: loss module. Defaults to a new
                ``nn.CrossEntropyLoss(ignore_index=255)`` per instance.
            pretrained: whether the ResNet-50 backbone loads pretrained weights
                (via ``resnet50(pretrained=...)``).
            deep_base: whether the backbone uses the three-3x3-conv stem (True)
                or the single 7x7-conv stem (False).

        Raises:
            AssertionError: if any of the configuration checks below fail.
        """
        super(PSPNet, self).__init__()
        assert layers == 50
        assert 2048 % len(bins) == 0
        assert num_classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.dropout = dropout
        self.zoom_factor = zoom_factor
        self.use_ppm = use_ppm
        # Build the loss per instance rather than sharing one default-argument object.
        self.criterion = (
            criterion if criterion is not None else nn.CrossEntropyLoss(ignore_index=255)
        )

        # Initialize the ResNet backbone and take over its stages.
        resnet = resnet50(pretrained=pretrained, deep_base=deep_base)
        if deep_base:
            stem = [
                resnet.conv1, resnet.bn1, resnet.relu,
                resnet.conv2, resnet.bn2, resnet.relu,
                resnet.conv3, resnet.bn3, resnet.relu,
            ]
        else:
            # The shallow stem has no conv2/conv3/bn2/bn3.
            stem = [resnet.conv1, resnet.bn1, resnet.relu]
        self.layer0 = nn.Sequential(*stem, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.__replace_conv_with_dilated_conv()

        # layer4 emits 2048 channels. Each PPM branch reduces to
        # 2048 / len(bins) channels, so concatenating the branches with the
        # input doubles the width fed to the classifier.
        fea_dim = 2048
        if use_ppm:
            self.ppm = PPM(in_dim=fea_dim, reduction_dim=fea_dim // len(bins), bins=bins)
            fea_dim *= 2

        self.cls = self.__create_classifier(
            in_feats=fea_dim, out_feats=512, num_classes=num_classes)
        # The auxiliary head reads layer3, which outputs 1024 channels.
        self.aux = self.__create_classifier(
            in_feats=1024, out_feats=256, num_classes=num_classes)

    def __replace_conv_with_dilated_conv(self):
        for name, module in self.layer3.named_modules():
            if '.conv2' in name:
                module.stride = (1,1)
                module.padding = (2,2)
                module.dilation = (2,2)
            elif '.downsample.0' in name:
                module.stride = 1
        for name, module in self.layer4.named_modules():
            if '.conv2' in name:
                module.stride = (1,1)
                module.padding = (4,4)
                module.dilation = (4,4)
            elif '.downsample.0' in name:
                module.stride = (1,1)
        print("dilated layer 3 and 4: ")
        print("layer3: ")
        for module in self.layer3.named_modules():
            print(module)
        print("layer4: ")
        for module in self.layer4.named_modules():
            print(module)
    # Todo: forward method 