In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
from PIL.Image import Image as PilImage

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive; the DEXTR-PyTorch
# sources and the experiment output directories used later in this notebook live there.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
!pip install yacs
from yacs.config import CfgNode as CN

def _make_hrnet_config(width):
    """Build the yacs ``CN`` describing an HRNet-W<width> backbone.

    All HRNet-W variants here share one topology and differ only in the width of
    the highest-resolution branch: in stages 2-4, branch ``k`` (0-based) has
    ``width * 2**k`` channels.

    Args:
        width: channel count of the highest-resolution branch (18, 24, 32, 48, ...).

    Returns:
        A ``CN`` with ``FINAL_CONV_KERNEL`` and ``STAGE1``..``STAGE4`` sub-nodes,
        each stage holding NUM_MODULES, NUM_BRANCHES, NUM_BLOCKS, NUM_CHANNELS,
        BLOCK and FUSE_METHOD.
    """
    cfg = CN()
    cfg.FINAL_CONV_KERNEL = 1

    # Stage 1 is a single-branch Bottleneck stage whose width (64) does not
    # depend on the variant.
    cfg.STAGE1 = CN()
    cfg.STAGE1.NUM_MODULES = 1
    cfg.STAGE1.NUM_BRANCHES = 1
    cfg.STAGE1.NUM_BLOCKS = [4]
    cfg.STAGE1.NUM_CHANNELS = [64]
    cfg.STAGE1.BLOCK = 'BOTTLENECK'
    cfg.STAGE1.FUSE_METHOD = 'SUM'

    # Stage s (s = 2, 3, 4) has s parallel BASIC branches with 4 blocks each;
    # every additional branch doubles the channel count.
    # Tuples are (stage index, NUM_MODULES).
    for stage_idx, num_modules in ((2, 1), (3, 4), (4, 3)):
        stage = CN()
        stage.NUM_MODULES = num_modules
        stage.NUM_BRANCHES = stage_idx
        stage.NUM_BLOCKS = [4] * stage_idx
        stage.NUM_CHANNELS = [width * 2 ** k for k in range(stage_idx)]
        stage.BLOCK = 'BASIC'
        stage.FUSE_METHOD = 'SUM'
        setattr(cfg, 'STAGE{}'.format(stage_idx), stage)

    return cfg


# configs for HRNet48 / HRNet32 / HRNet24 / HRNet18
HRNET_48 = _make_hrnet_config(48)
HRNET_32 = _make_hrnet_config(32)
HRNET_24 = _make_hrnet_config(24)
HRNET_18 = _make_hrnet_config(18)

# Registry of the available backbones by name (HRNet24 is the one the model uses).
MODEL_CONFIGS = {
    'hrnet18': HRNET_18,
    'hrnet24': HRNET_24,
    'hrnet32': HRNET_32,
    'hrnet48': HRNET_48,
}

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting yacs
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Installing collected packages: yacs
Successfully installed yacs-0.1.8


In [None]:
# Shared hyper-parameters for every BatchNorm / interpolate call in the model.
BN_MOMENTUM = 0.1
ALIGN_CORNERS = None

# BatchNorm2d / BatchNorm2d_class are referenced throughout the model code, so
# they must exist on every PyTorch version. Previously they were defined only on
# the non-0.x branch, which made 0.x PyTorch fail with NameError.
BatchNorm2d_class = BatchNorm2d = torch.nn.BatchNorm2d
# In-place ReLU is disabled for 0.x PyTorch versions (kept from the original logic).
relu_inplace = not torch.__version__.startswith('0')

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; with ``stride=2`` it becomes
    ``ceil(size / 2)``.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two-conv residual block: 3x3 conv-BN-ReLU, 3x3 conv-BN, add skip, ReLU.

    ``expansion`` is 1, so the block outputs ``planes`` channels. If
    ``downsample`` is given it is applied to the input to form the skip branch
    (needed when the stride or channel count changes); otherwise the input is
    added unchanged.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Sub-modules are registered in the same order as before, so state_dict
        # keys and the printed module tree are unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=relu_inplace)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        identity = x if self.downsample is None else self.downsample(x)
        # Out-of-place add: the skip tensor must not be overwritten.
        return self.relu(features + identity)


class Bottleneck(nn.Module):
    """Three-conv residual block: 1x1 reduce, 3x3 (carries the stride), 1x1 expand.

    The block outputs ``planes * expansion`` (= 4 * planes) channels. If
    ``downsample`` is given it is applied to the input to form the skip branch;
    otherwise the input is added unchanged.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        # Registration order matches the original: conv1, bn1, conv2, bn2,
        # conv3, bn3, relu, downsample.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=relu_inplace)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.relu(self.bn2(self.conv2(features)))
        features = self.bn3(self.conv3(features))
        identity = x if self.downsample is None else self.downsample(x)
        # Out-of-place add: the skip tensor must not be overwritten.
        return self.relu(features + identity)


class HighResolutionModule(nn.Module):
    """One HRNet exchange unit: per-branch residual stacks followed by fusion.

    Each of the ``num_branches`` parallel branches runs its own stack of
    ``blocks``. Fusion then builds, for each output branch ``i``, the sum over
    all branches ``j`` of branch ``j``'s features mapped to branch ``i``:
      * ``j == i``: identity;
      * ``j > i``: 1x1 conv + BN, then bilinear upsampling to branch ``i``'s size;
      * ``j < i``: ``i - j`` stride-2 3x3 convs (+BN, ReLU between them).
    Each fused sum goes through ReLU.

    Args:
        num_branches: number of parallel branches.
        blocks: residual block class (``BasicBlock`` or ``Bottleneck``).
        num_blocks: per-branch number of blocks.
        num_inchannels: per-branch input channel counts. The list is copied,
            so the caller's list is not modified.
        num_channels: per-branch block widths (``planes``).
        fuse_method: stored only; fusion is always an element-wise sum.
        multi_scale_output: if False, only output branch 0 is produced.

    Raises:
        ValueError: if ``num_blocks``, ``num_channels`` or ``num_inchannels``
            do not have ``num_branches`` entries.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        # Copy: _make_one_branch updates these per-branch widths, and those
        # updates must not leak into the caller's list. get_num_inchannels()
        # exposes the updated copy.
        self.num_inchannels = list(num_inchannels)
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=relu_inplace)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        """Raise ValueError unless every per-branch list has num_branches entries."""
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        """Build the residual stack of one branch.

        A 1x1 conv + BN projection is put on the first block's skip path when
        the stride is not 1 or the input width differs from the block's output
        width. Afterwards ``self.num_inchannels[branch_index]`` holds the
        branch's output width.
        """
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Return a ModuleList with one residual stack per branch."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the cross-branch mapping layers.

        Returns None for a single branch. Otherwise returns a ModuleList whose
        entry ``[i][j]`` maps branch ``j`` onto branch ``i``'s width (None when
        ``i == j``). Only output branch 0 is built when multi_scale_output is
        False.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution branch: match channels with a 1x1 conv;
                    # forward() then upsamples to branch i's size.
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution branch: i - j stride-2 3x3 convs. Only
                    # the last one changes the width (and has no ReLU).
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=relu_inplace)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the per-branch output widths (the input widths for the next module)."""
        return self.num_inchannels

    def forward(self, x):
        """Run every branch, then fuse across branches.

        Args:
            x: indexable sequence of per-branch feature maps. Branch ``j > i``
               is treated as lower resolution: it is upsampled to ``x[i]``'s
               spatial size.

        Returns:
            A new list of fused feature maps: one per branch, or one when
            multi_scale_output is False.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Build a new list instead of overwriting the caller's entries in place.
        x = [self.branches[i](x[i]) for i in range(self.num_branches)]

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=ALIGN_CORNERS)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


# Maps the BLOCK name used in the HRNet stage configs to its residual block class.
blocks_dict = dict(
    BASIC=BasicBlock,
    BOTTLENECK=Bottleneck,
)


class HighResolutionNet(nn.Module):
    """HRNet segmentation network whose output passes through a sigmoid.

    Pipeline:
      1. Stem: two stride-2 3x3 convs.
      2. Stage 1: a single-branch residual layer.
      3. Stages 2-4: multi-branch HighResolutionModule stacks joined by
         transition layers that add one lower-resolution branch each.
      4. All branches are upsampled to the highest resolution and concatenated.
      5. Head: 1x1 conv-BN-ReLU, then a final conv, then sigmoid.

    Args:
        num_classes: output channels of the final head conv.
        final_conv_kernel: kernel size of the final head conv (padding 1 when
            it is 3, else 0).
        cfg: config with STAGE1..STAGE4 nodes (NUM_MODULES, NUM_BRANCHES,
            NUM_BLOCKS, NUM_CHANNELS, BLOCK, FUSE_METHOD). When None, defaults
            to ``HRNET_24``, the previously hard-coded choice.
        in_channels: channels of the input tensor. Defaults to 4 (RGB plus the
            extreme-point heatmap in this notebook).
    """

    def __init__(self, num_classes, final_conv_kernel, cfg=None, in_channels=4):
        super(HighResolutionNet, self).__init__()
        if cfg is None:
            cfg = HRNET_24

        # stem net: two stride-2 convs reduce the input to 1/4 resolution
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=relu_inplace)

        # stage 1: single residual layer on top of the 64-channel stem
        self.stage1_cfg = cfg['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion*num_channels

        # stages 2-4: each transition feeds the next multi-branch stage
        self.stage2_cfg = cfg['STAGE2']
        num_channels = self._stage_channels(self.stage2_cfg)
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg['STAGE3']
        num_channels = self._stage_channels(self.stage3_cfg)
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg['STAGE4']
        num_channels = self._stage_channels(self.stage4_cfg)
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        # The head sees every stage-4 branch concatenated along channels.
        # Plain int(sum(...)) replaces np.int, which was deprecated in
        # NumPy 1.20 and removed in 1.24.
        last_inp_channels = int(sum(pre_stage_channels))

        self.last_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=last_inp_channels,
                kernel_size=1,
                stride=1,
                padding=0),
            BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=relu_inplace),
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=num_classes,
                kernel_size=final_conv_kernel,
                stride=1,
                padding=1 if final_conv_kernel == 3 else 0)
        )

    @staticmethod
    def _stage_channels(stage_cfg):
        """Per-branch output widths of a stage: NUM_CHANNELS scaled by the block expansion."""
        block = blocks_dict[stage_cfg['BLOCK']]
        return [c * block.expansion for c in stage_cfg['NUM_CHANNELS']]

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        """Map the previous stage's branches onto the next stage's branches.

        Existing branches get a 3x3 conv-BN-ReLU only when their width changes
        (None otherwise). Each new branch is derived from the last previous
        branch through stride-2 3x3 conv-BN-ReLU layers.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=relu_inplace)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=relu_inplace)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks.

        The first block gets a 1x1 conv + BN projection on its skip path when
        the stride or width changes.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        """Build NUM_MODULES HighResolutionModules.

        Returns:
            ``(nn.Sequential of modules, per-branch output widths)``.
        """
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        """Return the sigmoid map at the stage-4 highest-branch resolution.

        The stem makes that resolution 1/4 of the input, assuming the
        transitions keep branch 0 at full stage resolution.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                if i < self.stage2_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                if i < self.stage3_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Upsample every lower-resolution branch to branch 0's size and concatenate
        # (works for any number of stage-4 branches, not just 4).
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        upsampled = [x[0]] + [
            F.interpolate(branch, size=(x0_h, x0_w), mode='bilinear',
                          align_corners=ALIGN_CORNERS)
            for branch in x[1:]
        ]
        x = torch.cat(upsampled, 1)

        x = self.last_layer(x)
        # NOTE(review): the head output is squashed with sigmoid here, but the
        # training cell passes it to DEXTR's class_balanced_cross_entropy_loss,
        # which upstream appears to expect raw logits. Confirm which is intended.
        x = torch.sigmoid(x)

        return x

    def init_weights(self, pretrained='',):
        """Re-initialize weights, then optionally load a checkpoint on top.

        Initialization:
          * every ``nn.Conv2d`` weight is drawn from N(0, 0.001**2); conv
            biases are left untouched;
          * every BatchNorm gets weight 1 and bias 0.

        If ``pretrained`` is a path to an existing file, the checkpoint entries
        whose keys exist in this model's state_dict are loaded; other keys are
        ignored.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, BatchNorm2d_class):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if pretrained and os.path.isfile(pretrained):
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
            # runtimes; load_state_dict then copies into the model's tensors.
            # torch.load unpickles the file, so only load trusted checkpoints.
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict}
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

def get_seg_model(num_classes, final_conv_kernel, **kwargs):
    """Construct and return the HRNet segmentation model.

    ``num_classes`` and ``final_conv_kernel`` configure the final head conv.
    Any extra keyword arguments are forwarded unchanged to
    ``HighResolutionNet``. ``init_weights`` is not called, so the weights keep
    their default initialization and no pretrained checkpoint is loaded.
    """
    return HighResolutionNet(num_classes, final_conv_kernel, **kwargs)

In [None]:
import socket
import timeit
from datetime import datetime
import scipy.misc as sm
from collections import OrderedDict
import glob
import sys
import imageio

# PyTorch includes
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.functional import interpolate

# Tensorboard include
from torch.utils.tensorboard import SummaryWriter

sys.path.append("/content/drive/My Drive/DEXTR-PyTorch")

from dataloaders.combine_dbs import CombineDBs as combine_dbs
import dataloaders.pascal as pascal
import dataloaders.sbd as sbd
from dataloaders import custom_transforms as tr
# from networks/models.seg_hrnet import get_seg_model
from layers.loss import class_balanced_cross_entropy_loss
from dataloaders.helpers import *

import networks.deeplab_resnet as resnet

!pip install torchmetrics
from torchmetrics import JaccardIndex

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.9.1-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 7.9 MB/s 
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.9.1


In [None]:
# Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
gpu_id = 0
# Respect the CPU-mode switch: previously gpu_id=-1 still produced the invalid
# device "cuda:-1" whenever CUDA was available.
use_cuda = gpu_id >= 0 and torch.cuda.is_available()
device = torch.device("cuda:" + str(gpu_id) if use_cuda else "cpu")
if use_cuda:
    print('Using GPU: {} '.format(gpu_id))

# Setting parameters
use_sbd = False
nEpochs = 10  # Number of epochs for training
resume_epoch = 0  # Default is 0, change if want to resume

p = OrderedDict()  # Parameters to include in report
classifier = 'psp'  # Head classifier to use
p['trainBatch'] = 16  # Training batch size
testBatch = 5  # Testing batch size
useTest = 1  # See evolution of the test set when training?
nTestInterval = 10  # Run on test set every nTestInterval epochs
snapshot = 10  # Store a model every snapshot epochs
relax_crop = 50  # Enlarge the bounding box by relax_crop pixels
nInputChannels = 4  # Number of input channels (RGB + heatmap of extreme points)
zero_pad_crop = True  # Insert zero padding when cropping the image
p['nAveGrad'] = 1  # Average the gradient of several iterations
p['lr'] = 1e-4  # Learning rate
p['wd'] = 0.0005  # Weight decay
p['momentum'] = 0.9  # Momentum
IoU = JaccardIndex(num_classes=2, threshold=0.8).to(device)


def _next_run_id(root):
    """Return 1 + the largest numeric suffix among ``root/run_<n>`` entries, or 0 if none.

    Suffixes are compared as integers. A lexicographic sort would put 'run_10'
    before 'run_2', so the next id could collide with an existing run.
    Non-numeric suffixes are ignored.
    """
    ids = []
    for path in glob.glob(os.path.join(root, 'run_*')):
        suffix = os.path.basename(path)[len('run_'):]
        if suffix.isdigit():
            ids.append(int(suffix))
    return max(ids) + 1 if ids else 0


# Results and model directories (a new directory is generated for every run)
# save_dir_root = '/content/drive/MyDrive/DEXTR-PyTorch/experiment_HRNet32/'
save_dir_root = '/content/drive/MyDrive/DEXTR-PyTorch/experiment_HRNet24_Sobel/'
exp_name = 'Experiment_HRNet_Sobel_10Epochs'
# exp_name = 'Resnet'
# NOTE(review): when resuming (resume_epoch != 0) weights are always read from run_0.
run_id = _next_run_id(save_dir_root) if resume_epoch == 0 else 0
save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
os.makedirs(os.path.join(save_dir, 'models'), exist_ok=True)

# Network definition
modelName = 'dextr_pascal'
# net = resnet.resnet50(1, pretrained=False, nInputChannels=nInputChannels, classifier=classifier)
# net = resnet.resnet101(1, pretrained=False, nInputChannels=nInputChannels, classifier=classifier)
net = get_seg_model(num_classes=1, final_conv_kernel=1)
if resume_epoch == 0:
    # get_seg_model loads no pretrained weights, so the previous
    # "Initializing from pretrained Deeplab-v2 model" message was misleading.
    print("Training HRNet from scratch (no pretrained weights loaded)")
else:
    resume_path = os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')
    print("Initializing weights from: {}".format(resume_path))
    net.load_state_dict(
        torch.load(resume_path, map_location=lambda storage, loc: storage))
# Materialize the parameters as a list. A bare generator is exhausted by the
# first optimizer built from it, so re-running the training cell without this
# one would silently hand the optimizer no parameters.
train_params = [{'params': list(net.parameters()), 'lr': p['lr']}]

net.to(device)

Using GPU: 0 
Initializing from pretrained Deeplab-v2 model


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


HighResolutionNet(
  (conv1): Conv2d(4, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inpla

In [None]:
if __name__ == '__main__':
    # Training the network
    if resume_epoch != nEpochs:
        # Logging into Tensorboard
        log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
        writer = SummaryWriter(log_dir=log_dir)

        # Use the following optimizer
        optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
        p['optimizer'] = str(optimizer)

        # Preparation of the data loaders
        composed_transforms_tr = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
            tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
            tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
            tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'),
            tr.ToImage(norm_elem='extreme_points'),
            tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
            tr.ToTensor()])
        composed_transforms_ts = transforms.Compose([
            tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
            tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
            tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
            tr.ToImage(norm_elem='extreme_points'),
            tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
            tr.ToTensor()])

        voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr, download=True)
        voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, download=True)

        if use_sbd:
            sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True)
            db_train = combine_dbs([voc_train, sbd], excluded=[voc_val])
        else:
            db_train = voc_train

        p['dataset_train'] = str(db_train)
        p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms]
        p['dataset_test'] = str(db_train)
        p['transformations_test'] = [str(tran) for tran in composed_transforms_ts.transforms]

        trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2)
        testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=2)

        generate_param_report(os.path.join(save_dir, exp_name + '.txt'), p)

        # Train variables
        num_img_tr = len(trainloader)
        num_img_ts = len(testloader)
        running_loss_tr = 0.0
        running_loss_ts = 0.0
        running_iou = []
        aveGrad = 0
        print("Training Network")
        # Main Training and Testing Loop
        for epoch in range(resume_epoch, nEpochs):
            start_time = timeit.default_timer()

            net.train()
            for ii, sample_batched in enumerate(trainloader):

                inputs, gts = sample_batched['concat'], sample_batched['crop_gt']

                # Forward-Backward of the mini-batch
                inputs.requires_grad_()
                inputs, gts = inputs.to(device), gts.to(device)

                output = net.forward(inputs)
                output = interpolate(output, size=(512, 512), mode='bilinear', align_corners=True)

                # Compute the losses, side outputs and fuse
                loss = class_balanced_cross_entropy_loss(output, gts, size_average=False, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                if ii % num_img_tr == num_img_tr - 1:
                    running_loss_tr = running_loss_tr / num_img_tr
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
                    print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0]))
                    print('Loss: %f' % running_loss_tr)
                    # print('IoU: %f' % (sum(running_iou) / len(running_iou)))
                    running_loss_tr = 0
                    # running_iou = []
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time)+"\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)
                    optimizer.step()
                    optimizer.zero_grad()
                    aveGrad = 0

            # Save the model
            if (epoch % snapshot) == snapshot - 1 and epoch != 0:
                torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))

            # One testing epoch
            if useTest and epoch % nTestInterval == (nTestInterval - 1):
                net.eval()
                with torch.no_grad():
                    for ii, sample_batched in enumerate(testloader):
                        inputs, gts = sample_batched['concat'], sample_batched['crop_gt']

                        # Forward pass of the mini-batch
                        inputs, gts = inputs.to(device), gts.to(device)

                        output = net.forward(inputs)
                        output = interpolate(output, size=(512, 512), mode='bilinear', align_corners=True)

                        # Compute the losses, side outputs and fuse
                        loss = class_balanced_cross_entropy_loss(output, gts, size_average=False)
                        running_loss_ts += loss.item()

                        # Print stuff
                        if ii % num_img_ts == num_img_ts - 1:
                            running_loss_ts = running_loss_ts / num_img_ts
                            print('[Epoch: %d, numImages: %5d]' % (epoch, ii*testBatch+inputs.data.shape[0]))
                            writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
                            print('Loss: %f' % running_loss_ts)
                            running_loss_ts = 0

        writer.close()

    # Generate result of the validation images
    net.eval()
    composed_transforms_ts = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])
    db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)

    save_dir_res = os.path.join(save_dir, 'Results')
    if not os.path.exists(save_dir_res):
        os.makedirs(save_dir_res)

/path/to/PASCAL/VOC2012/VOCtrainval_11-May-2012.tar does not exist
Downloading http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar to /path/to/PASCAL/VOC2012/VOCtrainval_11-May-2012.tar
>> /path/to/PASCAL/VOC2012/VOCtrainval_11-May-2012.tar 100.0%Extracting tar file
Done!
Preprocessing of PASCAL VOC dataset, this will take long, but it will be done only once.
Preprocessing finished
Number of images: 1464
Number of objects: 3507
Files already downloaded and verified
Preprocessing of PASCAL VOC dataset, this will take long, but it will be done only once.
Preprocessing finished
Number of images: 1449
Number of objects: 3427
Training Network
[Epoch: 0, numImages:  3507]
Loss: 44596.759988
Execution time: 440.28115791999994

[Epoch: 1, numImages:  3507]
Loss: 39816.437251
Execution time: 299.62312605700004

[Epoch: 2, numImages:  3507]
Loss: 39258.728942
Execution time: 253.19538387600005

[Epoch: 3, numImages:  3507]
Loss: 38973.633887
Execution time: 255.13941835299

In [None]:
print('Testing Network')
# Put the network in inference mode explicitly so this cell does not depend on
# an earlier cell having called `net.eval()` (keeps the cell re-runnable alone).
net.eval()
with torch.no_grad():
    # Main Testing Loop: predict each validation object and save its full-image mask.
    for ii, sample_batched in enumerate(testloader):

        inputs, gts, metas = sample_batched['concat'], sample_batched['gt'], sample_batched['meta']

        # Forward of the mini-batch, upsampled back to the 512x512 crop resolution
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        outputs = interpolate(outputs, size=(512, 512), mode='bilinear', align_corners=True)

        # torch.sigmoid is numerically stable, unlike 1 / (1 + np.exp(-x)), which
        # overflows (with a RuntimeWarning) for very negative float32 logits.
        # Convert the whole batch once instead of once per element.
        probs = torch.sigmoid(outputs).cpu().numpy()

        for jj in range(inputs.size(0)):
            # (C, H, W) -> (H, W, C), then drop singleton axes
            pred = np.squeeze(np.transpose(probs[jj, :, :, :], (1, 2, 0)))
            gt = tens2image(gts[jj, :, :, :])
            # Paste the crop-level prediction back into a mask the size of `gt`
            bbox = get_bbox(gt, pad=relax_crop, zero_pad=zero_pad_crop)
            result = crop2fullmask(pred, bbox, gt, zero_pad=zero_pad_crop, relax=relax_crop)

            # Save as <image>-<object>.png; index metas with jj, not ii
            imageio.imwrite(os.path.join(save_dir_res, metas['image'][jj] + '-' + metas['object'][jj] + '.png'), result)

Testing Network




In [None]:
from evaluation.eval import eval_one_result

# Ground-truth dataset for the evaluator (no transforms applied).
dataset = pascal.VOCSegmentation(transform=None, retname=True, download=True)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

# Second part: score the masks saved by the testing cell at several mask thresholds.
print("starting second part")
exp_root_dir = save_dir_root
method = f'run_{run_id}'
results_folder = os.path.join(exp_root_dir, method, 'Results')
mask_threshold = [0.1 * i for i in range(1, 10)]
results = []

eval_dir = os.path.join(exp_root_dir, 'eval_results')
os.makedirs(eval_dir, exist_ok=True)
filename = os.path.join(eval_dir, method.replace('/', '-') + '.txt')

print(f"Evaluating method: {method}")
for m in mask_threshold:
    # Format the threshold so float noise (e.g. 0.30000000000000004) is not printed.
    print(f"Threshold: {m:.2f}")
    jaccards = eval_one_result(dataloader, results_folder, mask_thres=m)
    val = jaccards["all_jaccards"].mean()
    results.append(val)

    print(f"Result for {method:<80}: {100 * val:.1f}\n")

# Store one "Threshold <t> : <mean jaccard>" line per evaluated threshold.
with open(filename, 'w') as f:
    for m, val in zip(mask_threshold, results):
        f.write(f"Threshold {m:.2f} : {val!s}\n")

Files already downloaded and verified
Number of images: 1449
Number of objects: 3427
starting second part
Evaluating method: run_0
Threshold: 0.1
Evaluating: 0 of 3427 objects
Evaluating: 500 of 3427 objects
Evaluating: 1000 of 3427 objects
Evaluating: 1500 of 3427 objects
Evaluating: 2000 of 3427 objects
Evaluating: 2500 of 3427 objects
Evaluating: 3000 of 3427 objects
Result for run_0                                                                           : 63.0

Threshold: 0.2
Evaluating: 0 of 3427 objects
Evaluating: 500 of 3427 objects
Evaluating: 1000 of 3427 objects
Evaluating: 1500 of 3427 objects
Evaluating: 2000 of 3427 objects
Evaluating: 2500 of 3427 objects
Evaluating: 3000 of 3427 objects
Result for run_0                                                                           : 63.0

Threshold: 0.30000000000000004
Evaluating: 0 of 3427 objects
Evaluating: 500 of 3427 objects
Evaluating: 1000 of 3427 objects
Evaluating: 1500 of 3427 objects
Evaluating: 2000 of 3427 obj