In [35]:
import torch.nn as nn
import torch

In [36]:
class BasicBlock(nn.Module) :
    expansion = 1
    def __init__(self, in_channels, mid_channels, out_channels, stride = 1) :
        super().__init__()
        
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size = 3, stride = stride, padding=1, bias = False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels * BasicBlock.expansion, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(out_channels*BasicBlock.expansion),
        )
        self.shortcut = nn.Sequential()
        
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x) :
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


In [None]:
block1_1 = BasicBlock(130, 256, 256)
block1_2 = BasicBlock(256, 256, 256)
conv1 = nn.Conv2d(in_channels = 256, out_channels = 256,  kernel_size = 3, stride = 1, padding = 1)
maxpool = nn.MaxPool2d(kernel_size = (2,2), stride=(1,2), padding= (1,0))
block2_1 = BasicBlock(256, 512, 256)
# block2_2 = BasicBlock(512, 256, 256)

In [None]:
x = block1_1(samp)
print(x.shape)
x = block1_2(x)
print(x.shape)
x = conv1(x)
print(x.shape)
x = maxpool(x)
print(x.shape)
x = block2_1(x)
print(x.shape)

In [None]:
samp = torch.rand((130,16,64)).unsqueeze(0)

resblock2(resblock1(samp)).shape

In [None]:
x.shape

In [37]:
class SimpleResNet(nn.Module) :
    def __init__(self, block, num_blocks, init_weights = True) :
        super().__init__()
        self.in_channels = 130
        self.block1 = self._make_layer(block, 256, 256, num_blocks[0])
        self.conv1 = nn.Conv2d(self.in_channels, out_channels = self.in_channels, kernel_size= 3, padding = 1)
        self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = (2,1), padding = (0,1))
        self.block2 = self._make_layer(block, 256, 512, num_blocks[1])
        self.conv2 = nn.Conv2d(self.in_channels, out_channels = self.in_channels, kernel_size = 3, padding = 1)
        self.block3 = self._make_layer(block, 512, 512, num_blocks[2])
        self.conv3 = nn.Conv2d(self.in_channels, out_channels = self.in_channels, kernel_size = 2, stride = (2,1), padding = (0,1))
        self.conv4 = nn.Conv2d(self.in_channels, out_channels = self.in_channels, kernel_size = 2, stride = (1,1), padding = (0,0))
        self.avgpool = nn.AvgPool2d(kernel_size = (3,1), stride = (2,1), padding = (0,0))
        
        if init_weights :
            self._initialize_weights()
    
    def _make_layer(self, block, mid_channels, out_channels, num_blocks, stride = 1) :
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides :
            layers.append(block(self.in_channels, mid_channels, out_channels, stride))
            self.in_channels = out_channels
            
        return nn.Sequential(*layers)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    
    def forward(self, x) :
        output = self.block1(x)
        output = self.conv1(output)
        output = self.maxpool(output)
        output = self.block2(output)
        output = self.conv2(output)
        output = self.block3(output)
        output = self.conv3(output)
        output = self.conv4(output)
        output = self.avgpool(output)
        return output
    


In [38]:
samp_net = SimpleResNet(BasicBlock, [2,5,3])

In [40]:
samp = torch.randn(([53, 130, 16, 64]))

samp_net(samp).shape

torch.Size([53, 512, 1, 65])

In [None]:
x = samp_net.block1(samp)
x = samp_net.conv1(x)

x = samp_net.maxpool(x)
x = samp_net.block2(x)

In [None]:
resnet_list = list(resnet.children())

In [None]:
i = 0 
for bottleneck in resnet_list :
    if isinstance(bottleneck, nn.Sequential) :
        i+= 1
        print(bottleneck)
print(i)

## STD

In [4]:
import torch
import torch.nn as nn
import torchvision
resnet = torchvision.models.resnet.resnet50(pretrained=True)


class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv -> BN -> ReLU
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    This is the middle layer of the UNet which just consists of some
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels == None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels == None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    DEPTH = 6

    def __init__(self, n_classes=4):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))

        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=True):
        pre_pools = dict()
        pre_pools[f"layer_0"] = x
        x = self.input_block(x)
        pre_pools[f"layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            if i == UNetWithResnet50Encoder.DEPTH - 1 :
                output_feature_map = x
            x = block(x, pre_pools[key])
        
        x = self.out(x)
        del pre_pools
        if with_output_feature_map:
            return x, output_feature_map
        else:
            return x

model = UNetWithResnet50Encoder().cuda()
inp = torch.rand((2, 3, 512, 512)).cuda()
out = model(inp)

In [None]:
down_blocks = []

for bottleneck in resnet_list :
    if isinstance(bottleneck, nn.Sequential) :
        down_blocks.append(bottleneck)
        
for i, block in enumerate(down_blocks, 2) :
    if i == (UNetWithResnet50Encoder.DEPTH - 1):
        continue
    print(block)

In [None]:
list(resnet.children())

In [5]:
class Residual_Block(nn.Module): 
    def __init__(self, in_dim, mid_dim, out_dim): 
        super(Residual_Block,self).__init__() # Residual Block 
        self.residual_block = nn.Sequential( nn.Conv2d(in_dim, mid_dim, kernel_size=3, padding=1), 
                                            nn.ReLU(), 
                                            nn.Conv2d(mid_dim, out_dim, kernel_size=3, padding=1), ) 
        self.relu = nn.ReLU() 
    
    def forward(self, x): 
        out = self.residual_block(x) # F(x) out = out + x # F(x) + x 
        out = self.relu(out) 
        return out


In [None]:
rblock = Residual_Block(256, 256, 256)

In [6]:
class UpBlock(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels == None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels == None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


In [7]:
class BottleNeck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels*BottleNeck.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels*BottleNeck.expansion)
            )
            
    def forward(self, x):
        x = self.residual_function(x) + self.shortcut(x)
        x = self.relu(x)
        return x
    

class Bridge(nn.Module):
    """
    This is the middle layer of the UNet which just consists of some
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)

In [8]:
class UpBlock(nn.Module) :
    def __init__(self, input_channel, output_channel, upsampling_method = 'conv_transpose') :
        super().__init__()
        if upsampling_method == "conv_transpose":
            self.UpSample = nn.ConvTranspose2d(input_channel, input_channel, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.UpSample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(input_channel, input_channel, kernel_size=1, stride=1)
            )

        self.UpConv_Block = nn.Sequential(
                                            nn.Conv2d(input_channel*2, output_channel*2, kernel_size = 1, stride = 1),
                                            nn.BatchNorm2d(output_channel*2),
                                            nn.ReLU(),
                                            nn.Conv2d(output_channel*2, output_channel, kernel_size = 3, stride = 1, padding = 1),
                                            nn.BatchNorm2d(output_channel),
                                            nn.ReLU()            
                                            )
    
    def forward(self, up_x, down_x, return_output = False) :
        x = self.UpSample(up_x)
        x = torch.cat([x, down_x], 1)
        if return_output :
            output_feature = x
        x = self.UpConv_Block(x)
        if return_output :
            return x, output_feature
        else :
            return x
        
        

In [None]:
bottleneck_1 = BottleNeck(3, 16)

bottleneck_1(torch.rand(1,3,128,128)).shape

In [32]:
class ResNetUNet(nn.Module) :
    def __init__(self, input_channel, n_classes) :
        super().__init__()
        self.input_channel = input_channel
        self.down_block1 = BottleNeck(self.input_channel, 16)
        self.down_block2 = BottleNeck(64, 64)
        self.down_block3 = BottleNeck(256, 128)
        self.down_block4 = BottleNeck(512, 256)
        self.down_block5 = BottleNeck(1024, 512)
        self.bridge = nn.ConvTranspose2d(2048, 1024, kernel_size=1, stride=1)
        
        self.up_block1 = UpBlock(1024, 512)
        self.up_block2 = UpBlock(512, 256)
        self.up_block3 = UpBlock(256, 64)
        self.up_block4 = UpBlock(64, 32)
        self.last_layer = nn.Sequential(
                                       nn.Conv2d(32,32, kernel_size= 3, stride = 1, padding = 1),
                                       nn.Conv2d(32,32, kernel_size= 3, stride = 1, padding = 1),
                                       nn.Conv2d(32,16, kernel_size= 3, stride = 1, padding = 1),
                                       nn.Conv2d(16,16, kernel_size= 1),
                                       nn.Conv2d(16,n_classes, kernel_size = 1, stride = 1)
                                      )
        
    def forward(self, x) :
        pre_pools = dict()
        x = self.down_block1(x)
        print(x.shape)
        pre_pools[f"layer_1"] = x
        x = self.down_block2(x)
        print(x.shape)
        pre_pools[f"layer_2"] = x
        x = self.down_block3(x)
        print(x.shape)
        pre_pools[f"layer_3"] = x
        x = self.down_block4(x) 
        print(x.shape)
        pre_pools[f"layer_4"] = x
        x = self.down_block5(x) 
        print(x.shape)
        x = self.bridge(x)
        print(x.shape)
        x = self.up_block1(x, pre_pools['layer_4'])
        print(x.shape)
        x = self.up_block2(x, pre_pools['layer_3'])
        print(x.shape)
        x = self.up_block3(x, pre_pools['layer_2'])
        print(x.shape)
        x, output_feature = self.up_block4(x, pre_pools['layer_1'], return_output = True)
        print(x.shape)
        x = self.last_layer(x)
             
        return x, output_feature
    
   

In [33]:

sampnet = ResNetUNet(3, 4)

x, output_feature = sampnet(torch.rand(1,3,512,512))

torch.Size([1, 64, 256, 256])
torch.Size([1, 256, 128, 128])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 32, 32])
torch.Size([1, 2048, 16, 16])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 512, 32, 32])
torch.Size([1, 256, 64, 64])
torch.Size([1, 64, 128, 128])
torch.Size([1, 32, 256, 256])


In [34]:
x.shape

torch.Size([1, 4, 256, 256])

In [None]:

torch.cat([x[:,:2,:,:], output_feature], 1).shape

In [None]:
sampnet

In [24]:
from ResUnet import CRAFT
import torch
net = CRAFT(input_channel = 3, n_classes = 4)

In [26]:
768/32

24.0

In [29]:
samp = torch.randn((1,3,768,768))

x, output_feature = net(samp)

In [30]:
x.shape

torch.Size([1, 4, 384, 384])

In [31]:
output_feature.shape

torch.Size([1, 128, 384, 384])