<a href="https://colab.research.google.com/github/sheikmohdimran/Experiments_2019/blob/master/Vision/Pytorch_Architechture_Debug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install torchsnooper -q

In [0]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchsnooper

In [0]:
class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2=None):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # for padding issues, see 
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        
        if x2 is not None:
            x = torch.cat([x2, x1], dim=1)
        else:
            x = x1
        x = self.conv(x)
        return x

def get_mesh(batch_size, shape_x, shape_y):
    mg_x, mg_y = np.meshgrid(np.linspace(0, 1, shape_y), np.linspace(0, 1, shape_x))
    mg_x = np.tile(mg_x[None, None, :, :], [batch_size, 1, 1, 1]).astype('float32')
    mg_y = np.tile(mg_y[None, None, :, :], [batch_size, 1, 1, 1]).astype('float32')
    mesh = torch.cat([torch.tensor(mg_x).to(device), torch.tensor(mg_y).to(device)], 1)
    return mesh

In [0]:
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.GroupNorm(16, planes)

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.GroupNorm(16, planes)

        if stride != 1 or inplanes != planes:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, planes, stride), nn.GroupNorm(16, planes))
        else:
            self.downsample = None


    def forward(self, x):
        identity = x

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out, inplace=True)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.GroupNorm(16, planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.GroupNorm(16, planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.GroupNorm(16, planes * self.expansion)

        if stride != 1 or inplanes != planes * self.expansion:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, planes * self.expansion, stride), 
                nn.GroupNorm(16, planes * self.expansion))
        else:
            self.downsample = None

    def forward(self, x):
        identity = x

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
 
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out)

        return out


class ResNetFeatures(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNetFeatures, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.GroupNorm(16, 64)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)


        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):

        layers = []
        layers.append(block(self.inplanes, planes, stride))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)


    def forward(self, x):
        conv1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        conv1 = F.max_pool2d(conv1, 3, stride=2, padding=1)

        x = self.layer1(conv1)
        feats8 = self.layer2(x)
        feats16 = self.layer3(feats8)
        feats32 = self.layer4(feats16)

        return feats8,feats16,feats32



def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetFeatures(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        _load_pretrained(model, model_zoo.load_url(model_urls['resnet18']))
    return model
	
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetFeatures(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_pretrained(model, model_zoo.load_url(model_urls['resnet18']))
    return model

def _load_pretrained(model, pretrained):
    model_dict = model.state_dict()
    pretrained = {k : v for k, v in pretrained.items() if k in model_dict}
    model_dict.update(pretrained)
    model.load_state_dict(model_dict)

In [0]:
base_model=resnet18(pretrained=False)

In [0]:
@torchsnooper.snoop()
class CentResnet(nn.Module):
    '''Mixture of previous classes'''
    def __init__(self, n_classes):
        super(CentResnet, self).__init__()
        self.base_model = base_model
        
        # Lateral layers convert resnet outputs to a common feature size
        self.lat8 = nn.Conv2d(128, 256, 1)
        self.lat16 = nn.Conv2d(256, 256, 1)
        self.lat32 = nn.Conv2d(512, 256, 1)
        self.bn8 = nn.GroupNorm(16, 256)
        self.bn16 = nn.GroupNorm(16, 256)
        self.bn32 = nn.GroupNorm(16, 256)

        self.conv0 = double_conv(5, 64)
        self.conv1 = double_conv(64, 128)
        self.conv2 = double_conv(128, 512)
        self.conv3 = double_conv(512, 1024)
        
        self.mp = nn.MaxPool2d(2)
        
        self.up1 = up(1282 , 512) #+ 1024
        self.up2 = up(512 + 512, 256)
        self.outc = nn.Conv2d(256, n_classes, 1)
        
    
    def forward(self, x):
        batch_size = x.shape[0]
        mesh1 = get_mesh(batch_size, x.shape[2], x.shape[3])
        x0 = torch.cat([x, mesh1], 1)
        x1 = self.mp(self.conv0(x0))
        x2 = self.mp(self.conv1(x1))
        x3 = self.mp(self.conv2(x2))
        x4 = self.mp(self.conv3(x3))
        
        # Run frontend network
        feats8, feats16,feats32 = self.base_model(x)
        lat8 = F.relu(self.bn8(self.lat8(feats8)))
        lat16 = F.relu(self.bn16(self.lat16(feats16)))
        lat32 = F.relu(self.bn32(self.lat32(feats32)))
        
        # Add positional info
        mesh2 = get_mesh(batch_size, lat32.shape[2], lat32.shape[3])
        feats = torch.cat([lat32, mesh2], 1)
        x = self.up1(feats, x4)
        x = self.up2(x, x3)
        x = self.outc(x)
        return x


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CentResnet(8).to(device)

Source path:... <ipython-input-7-6c73de720edf>
Starting var:.. self = REPR FAILED
Starting var:.. n_classes = 8
Starting var:.. __class__ = <class '__main__.CentResnet'>
06:21:56.948217 call         4     def __init__(self, n_classes):
06:21:56.952803 line         5         super(CentResnet, self).__init__()
Modified var:.. self = CentResnet()
06:21:56.953210 line         6         self.base_model = base_model
Modified var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...(16, 512, eps=1e-05, affine=True)      )    )  ))
06:21:56.953739 line         9         self.lat8 = nn.Conv2d(128, 256, 1)
Modified var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...v2d(128, 256, kernel_size=(1, 1), stride=(1, 1)))
06:21:56.957147 line        10         self.lat16 = nn.Conv2d(256, 256, 1)
Modified var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...v2d(256, 256, kernel_size=(1, 1), stride=(1, 1)))
06:21:56.959033 line        11         self.lat32 = nn.Conv2d(512

In [9]:
img_batch = torch.randn((1,3,512,2048))
test = model(img_batch.to(device))

Starting var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...onv2d(256, 8, kernel_size=(1, 1), stride=(1, 1)))
Starting var:.. x = tensor<(1, 3, 512, 2048), float32, cuda:0>
06:22:02.373339 call        28     def forward(self, x):
06:22:02.409489 line        29         batch_size = x.shape[0]
New var:....... batch_size = 1
06:22:02.412173 line        30         mesh1 = get_mesh(batch_size, x.shape[2], x.shape[3])
New var:....... mesh1 = tensor<(1, 2, 512, 2048), float32, cuda:0>
06:22:02.522144 line        31         x0 = torch.cat([x, mesh1], 1)
New var:....... x0 = tensor<(1, 5, 512, 2048), float32, cuda:0>
06:22:02.526005 line        32         x1 = self.mp(self.conv0(x0))
New var:....... x1 = tensor<(1, 64, 256, 1024), float32, cuda:0, grad>
06:22:02.542043 line        33         x2 = self.mp(self.conv1(x1))
New var:....... x2 = tensor<(1, 128, 128, 512), float32, cuda:0, grad>
06:22:02.767264 line        34         x3 = self.mp(self.conv2(x2))
New var:....... x3 = ten

In [10]:
del model
base_model=resnet50(pretrained=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CentResnet(8).to(device)
img_batch = torch.randn((1,3,512,2048))
test = model(img_batch.to(device))

Starting var:.. self = REPR FAILED
Starting var:.. n_classes = 8
Starting var:.. __class__ = <class '__main__.CentResnet'>
06:22:04.265948 call         4     def __init__(self, n_classes):
06:22:04.266648 line         5         super(CentResnet, self).__init__()
Modified var:.. self = CentResnet()
06:22:04.266928 line         6         self.base_model = base_model
Modified var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...16, 2048, eps=1e-05, affine=True)      )    )  ))
06:22:04.267190 line         9         self.lat8 = nn.Conv2d(128, 256, 1)
Modified var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...v2d(128, 256, kernel_size=(1, 1), stride=(1, 1)))
06:22:04.269310 line        10         self.lat16 = nn.Conv2d(256, 256, 1)
Modified var:.. self = CentResnet(  (base_model): ResNetFeatures(    (c...v2d(256, 256, kernel_size=(1, 1), stride=(1, 1)))
06:22:04.271357 line        11         self.lat32 = nn.Conv2d(512, 256, 1)
Modified var:.. self = CentResnet(  (

RuntimeError: ignored

In [0]:
@torchsnooper.snoop()
class CentResnet1(nn.Module):
    '''Mixture of previous classes'''
    def __init__(self, n_classes):
        super(CentResnet1, self).__init__()
        self.base_model = base_model
        
        # Lateral layers convert resnet outputs to a common feature size
        self.lat8 = nn.Conv2d(512, 256, 1)
        self.lat16 = nn.Conv2d(1024, 256, 1)
        self.lat32 = nn.Conv2d(2048, 256, 1)
        self.bn8 = nn.GroupNorm(16, 256)
        self.bn16 = nn.GroupNorm(16, 256)
        self.bn32 = nn.GroupNorm(16, 256)

        self.conv0 = double_conv(5, 64)
        self.conv1 = double_conv(64, 128)
        self.conv2 = double_conv(128, 512)
        self.conv3 = double_conv(512, 1024)
        
        self.mp = nn.MaxPool2d(2)
        
        self.up1 = up(1282 , 512) #+ 1024
        self.up2 = up(512 + 512, 256)
        self.outc = nn.Conv2d(256, n_classes, 1)
        
    
    def forward(self, x):
        batch_size = x.shape[0]
        mesh1 = get_mesh(batch_size, x.shape[2], x.shape[3])
        x0 = torch.cat([x, mesh1], 1)
        x1 = self.mp(self.conv0(x0))
        x2 = self.mp(self.conv1(x1))
        x3 = self.mp(self.conv2(x2))
        x4 = self.mp(self.conv3(x3))
        
        # Run frontend network
        feats8, feats16, feats32 = self.base_model(x)
        lat8 = F.relu(self.bn8(self.lat8(feats8)))
        lat16 = F.relu(self.bn16(self.lat16(feats16)))
        lat32 = F.relu(self.bn32(self.lat32(feats32)))
        
        # Add positional info
        mesh2 = get_mesh(batch_size, lat32.shape[2], lat32.shape[3])
        feats = torch.cat([lat32, mesh2], 1)
        x = self.up1(feats, x4)
        x = self.up2(x, x3)
        x = self.outc(x)
        return x


In [5]:
base_model=resnet50(pretrained=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CentResnet1(8).to(device)
img_batch = torch.randn((1,3,512,2048))
test = model(img_batch.to(device))


Source path:... <ipython-input-4-9ac2feb4a147>
Starting var:.. self = REPR FAILED
Starting var:.. n_classes = 8
Starting var:.. __class__ = <class '__main__.CentResnet1'>
06:22:58.848489 call         4     def __init__(self, n_classes):
06:22:58.849016 line         5         super(CentResnet1, self).__init__()
Modified var:.. self = CentResnet1()
06:22:58.849206 line         6         self.base_model = base_model
Modified var:.. self = CentResnet1(  (base_model): ResNetFeatures(    (...16, 2048, eps=1e-05, affine=True)      )    )  ))
06:22:58.849362 line         9         self.lat8 = nn.Conv2d(512, 256, 1)
Modified var:.. self = CentResnet1(  (base_model): ResNetFeatures(    (...v2d(512, 256, kernel_size=(1, 1), stride=(1, 1)))
06:22:58.852840 line        10         self.lat16 = nn.Conv2d(1024, 256, 1)
Modified var:.. self = CentResnet1(  (base_model): ResNetFeatures(    (...2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)))
06:22:58.856627 line        11         self.lat32 = nn.Conv2d