<a href="https://colab.research.google.com/github/rcbusinesstechlab/realtime-face-recognition/blob/main/ok_unet_se.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [2]:
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel of the input is globally average-pooled to one value, the
    pooled vector goes through a two-layer bottleneck MLP ending in a
    sigmoid, and the input is rescaled channel-wise by the resulting gate.

    Args:
        nb_channels: number of channels of the feature map to gate.
        reduction: bottleneck ratio. The hidden width is
            ``nb_channels // reduction``, clamped to at least 1 so that
            narrow feature maps (``nb_channels < reduction``) still get a
            working gate instead of a zero-width layer.
    """

    def __init__(self, nb_channels, reduction=16):
        super(SqueezeExcitation, self).__init__()
        self.nb_channels=nb_channels
        # Never let the bottleneck collapse to 0 units: with a 0-width hidden
        # layer the gate would depend only on the second layer's bias.
        hidden=max(1, nb_channels // reduction)
        self.avg_pool=nn.AdaptiveAvgPool2d(1)
        self.fc=nn.Sequential(
                nn.Linear(nb_channels, hidden),
                nn.ReLU(inplace=True),
                nn.Linear(hidden, nb_channels),
                nn.Sigmoid())


    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate.

        The pooled tensor is flattened with ``view(-1, nb_channels)``, so the
        gate always carries a leading dimension; the product ``x * y`` then
        relies on broadcasting against ``x``.
        """
        y = self.avg_pool(x).view(-1,self.nb_channels)
        y = self.fc(y).view(-1,self.nb_channels,1,1)
        return x * y


# Smoke test for SqueezeExcitation with 64 channels.
# NOTE: the tensor passed to the first call has no batch dimension; the
# printed shape gains a leading 1 because forward() reshapes the gate with
# view(-1, nb_channels, 1, 1) and the product broadcasts.
print(SqueezeExcitation(64)(torch.rand(64,568,568)).shape)
summary(SqueezeExcitation(64),input_size=(64,568,568))

torch.Size([1, 64, 568, 568])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
 AdaptiveAvgPool2d-1             [-1, 64, 1, 1]               0
            Linear-2                    [-1, 4]             260
              ReLU-3                    [-1, 4]               0
            Linear-4                   [-1, 64]             320
           Sigmoid-5                   [-1, 64]               0
Total params: 580
Trainable params: 580
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 78.77
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 78.77
----------------------------------------------------------------


In [3]:
class ConvBlock(nn.Module):
    """Stack of ``depth`` convolutions followed by optional BN, activation, SE.

    The forward order is: all convolutions, then batch norm (if ``bn``),
    then the activation (if ``act``), then squeeze-and-excitation (if ``se``).
    NOTE(review): the convolutions are stacked with no normalisation or
    activation between them; confirm this is intended.

    Args:
        in_ch: number of input channels.
        in_size: spatial size (one side) of the expected input; only used to
            compute ``out_size``.
        depth: number of stacked ``nn.Conv2d`` layers.
        kernel_size, stride, padding: forwarded to every ``nn.Conv2d``.
        out_ch: number of output channels; defaults to ``2 * in_ch``.
        bn: add ``nn.BatchNorm2d`` after the convolutions.
        se: add ``SqueezeExcitation`` as the last step.
        act: name of a function in ``torch.nn.functional`` (e.g. ``'relu'``),
            or a falsy value to skip the activation.
        act_kwargs: optional keyword arguments passed to the activation.

    Attributes:
        out_ch: number of output channels.
        out_size: spatial size (one side) of the output for an input of
            size ``in_size``.
    """

    def __init__(self,
            in_ch,
            in_size,
            depth=2,
            kernel_size=3,
            stride=1,
            padding=0,
            out_ch=None,
            bn=True,
            se=True,
            act='relu',
            act_kwargs=None):
        super(ConvBlock, self).__init__()
        self.out_ch=out_ch or 2*in_ch
        self._set_post_processes(self.out_ch,bn,se,act,act_kwargs)
        self._set_conv_layers(
            depth,
            in_ch,
            kernel_size,
            stride,
            padding)
        self.out_size=self._conv_out_size(
            in_size,
            depth,
            kernel_size,
            stride,
            padding)


    def forward(self, x):
        """Apply the convolutions, then BN, activation and SE when enabled."""
        x=self.conv_layers(x)
        if self.bn:
            x=self.bn(x)
        if self.act:
            x=self._activation(x)
        if self.se:
            x=self.se(x)
        return x


    @staticmethod
    def _conv_out_size(in_size,depth,kernel_size,stride,padding):
        """Spatial size after ``depth`` convolutions (dilation 1).

        Applies ``(size + 2*padding - kernel_size)//stride + 1`` once per
        layer. For ``stride=1`` this equals
        ``in_size - depth*(kernel_size-1-2*padding)`` and stays an int when
        ``in_size`` is an int.
        """
        size=in_size
        for _ in range(depth):
            size=(size+2*padding-kernel_size)//stride+1
        return size


    def _set_post_processes(self,out_channels,bn,se,act,act_kwargs):
        """Create the optional BN / SE modules and record the activation."""
        # Disabled stages are stored as False so forward() can test them
        # with a plain truthiness check.
        if bn:
            self.bn=nn.BatchNorm2d(out_channels)
        else:
            self.bn=False
        if se:
            self.se=SqueezeExcitation(out_channels)
        else:
            self.se=False
        self.act=act
        # Keep a private copy so blocks never share (or mutate) the caller's
        # dictionary.
        self.act_kwargs=dict(act_kwargs) if act_kwargs else {}


    def _set_conv_layers(
            self,
            depth,
            in_ch,
            kernel_size,
            stride,
            padding):
        """Build ``self.conv_layers``: in_ch -> out_ch, then out_ch -> out_ch."""
        layers=[]
        for index in range(depth):
            # Only the first convolution sees the block's input channels.
            layer_in_ch=in_ch if index==0 else self.out_ch
            layers.append(
                nn.Conv2d(
                    in_channels=layer_in_ch,
                    out_channels=self.out_ch,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding))
        self.conv_layers=nn.Sequential(*layers)


    def _activation(self,x):
        """Apply ``torch.nn.functional.<act>`` to ``x`` with ``act_kwargs``."""
        # The keyword arguments belong to the activation call, not to getattr
        # (getattr accepts no keyword arguments).
        return getattr(F,self.act)(x,**self.act_kwargs)


# Smoke test for ConvBlock: 1-channel 572x572 input, 64 output channels,
# default depth of 2 unpadded 3x3 convolutions. Prints the statically
# computed out_size/out_ch, then the shape of a real forward pass, then a
# per-layer summary of a second, identically configured block.
conv_block=ConvBlock(1,572,out_ch=64)
print(conv_block.out_size,conv_block.out_ch)
print(conv_block(torch.rand(1,1,572,572)).shape)
summary(ConvBlock(1,572,out_ch=64),input_size=(1,572,572))

568.0 64
torch.Size([1, 64, 568, 568])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 570, 570]             640
            Conv2d-2         [-1, 64, 568, 568]          36,928
       BatchNorm2d-3         [-1, 64, 568, 568]             128
 AdaptiveAvgPool2d-4             [-1, 64, 1, 1]               0
            Linear-5                    [-1, 4]             260
              ReLU-6                    [-1, 4]               0
            Linear-7                   [-1, 64]             320
           Sigmoid-8                   [-1, 64]               0
 SqueezeExcitation-9         [-1, 64, 568, 568]               0
Total params: 38,276
Trainable params: 38,276
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 631.24
Params size (MB): 0.15
Estimated Total Size (MB): 632.63
--

In [4]:
class DownBlock(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``ConvBlock``.

    Args:
        in_ch: number of input channels.
        in_size: spatial size (one side) of the expected input.
        out_ch: number of output channels; defaults to ``2 * in_ch``.
        depth: number of convolutions in the inner ``ConvBlock``.
        padding: padding of every convolution in the inner ``ConvBlock``.
        bn, se, act, act_kwargs: forwarded to the inner ``ConvBlock``.

    Attributes:
        out_ch: number of output channels.
        out_size: expected spatial size of the output. The formula assumes
            the inner block keeps ``ConvBlock``'s default 3x3 kernel with
            stride 1 (neither is overridden here).
    """

    def __init__(self,
            in_ch,
            in_size,
            out_ch=None,
            depth=2,
            padding=0,
            bn=True,
            se=True,
            act='relu',
            act_kwargs=None):
        super(DownBlock, self).__init__()
        # Use a private dict: a shared ``{}`` default would be aliased into
        # every block built without explicit ``act_kwargs``.
        act_kwargs={} if act_kwargs is None else dict(act_kwargs)
        # Pooling halves the size; each 3x3 conv then removes 2*(1-padding).
        self.out_size=(in_size//2)-depth*(1-padding)*2
        self.out_ch=out_ch or in_ch*2
        self.down=nn.MaxPool2d(kernel_size=2)
        self.conv_block=ConvBlock(
            in_ch=in_ch,
            out_ch=self.out_ch,
            in_size=in_size//2,
            depth=depth,
            padding=padding,
            bn=bn,
            se=se,
            act=act,
            act_kwargs=act_kwargs)


    def forward(self, x):
        """Downsample ``x`` by 2 and run it through the conv block."""
        x=self.down(x)
        return self.conv_block(x)


# Smoke test for DownBlock: 64-channel 568x568 input with four convolutions
# after the pooling step. Prints the statically computed out_size/out_ch,
# the shape of a real forward pass, and a per-layer summary.
db_out=DownBlock(64,568,depth=4)
print(db_out.out_size,db_out.out_ch)
print(db_out(torch.rand(1,64,568,568)).shape)
summary(db_out,input_size=(64,568,568))

276 128
torch.Size([1, 128, 276, 276])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         MaxPool2d-1         [-1, 64, 284, 284]               0
            Conv2d-2        [-1, 128, 282, 282]          73,856
            Conv2d-3        [-1, 128, 280, 280]         147,584
            Conv2d-4        [-1, 128, 278, 278]         147,584
            Conv2d-5        [-1, 128, 276, 276]         147,584
       BatchNorm2d-6        [-1, 128, 276, 276]             256
 AdaptiveAvgPool2d-7            [-1, 128, 1, 1]               0
            Linear-8                    [-1, 8]           1,032
              ReLU-9                    [-1, 8]               0
           Linear-10                  [-1, 128]           1,152
          Sigmoid-11                  [-1, 128]               0
SqueezeExcitation-12        [-1, 128, 276, 276]               0
        ConvBlock-13        [-1, 128, 276, 276]               0


In [16]:
# file ipython-input-11-b11b25979420
class UpBlock(nn.Module):
    """Decoder stage: 2x upsampling, skip concatenation, then a ``ConvBlock``.

    The upsampled tensor has ``in_ch // 2`` channels; it is concatenated
    (skip first) with a centre-cropped skip tensor and fed to a ``ConvBlock``
    that expects ``in_ch`` input channels, so the skip tensor is expected to
    carry the remaining ``in_ch - in_ch // 2`` channels.

    Args:
        in_ch: number of channels of the tensor coming from below.
        in_size: spatial size (one side) of the tensor coming from below.
        out_ch: number of output channels; defaults to ``in_ch // 2``.
        bilinear: upsample with bilinear interpolation followed by a 1x1
            convolution that halves the channels, instead of a transposed
            convolution.
        crop: stored on the instance as ``self.crop`` only; the crop applied
            in ``forward`` is derived from the actual tensor shapes.
        depth: number of convolutions in the inner ``ConvBlock``.
        padding: padding of every convolution in the inner ``ConvBlock``.
        bn, se, act, act_kwargs: forwarded to the inner ``ConvBlock``.

    Attributes:
        out_ch: number of output channels.
        out_size: expected spatial size of the output. The formula assumes
            the inner block keeps ``ConvBlock``'s default 3x3 kernel with
            stride 1 (neither is overridden here).
    """

    @staticmethod
    def cropping(skip_size,size):
        """Return the per-side margin to remove from ``skip_size`` to reach ``size``."""
        return (skip_size-size)//2


    def __init__(self,
            in_ch,
            in_size,
            out_ch=None,
            bilinear=False,
            crop=None,
            depth=2,
            padding=0,
            bn=True,
            se=True,
            act='relu',
            act_kwargs=None):
        super(UpBlock, self).__init__()
        # Use a private dict: a shared ``{}`` default would be aliased into
        # every block built without explicit ``act_kwargs``.
        act_kwargs={} if act_kwargs is None else dict(act_kwargs)
        self.crop=crop
        self.padding=padding
        up_size=in_size*2
        self.out_size=up_size-depth*(1-padding)*2
        self.out_ch=out_ch or in_ch//2
        if bilinear:
            # Interpolation keeps the channel count, so halve it with a 1x1
            # convolution; otherwise the concatenation in forward() would not
            # match the in_ch channels the conv block expects.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_ch, in_ch//2, kernel_size=1))
        else:
            self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
        # The conv block operates on the upsampled map, so its input size is
        # up_size (not this block's own output size).
        self.conv_block=ConvBlock(
            in_ch,
            up_size,
            out_ch=self.out_ch,
            depth=depth,
            padding=padding,
            bn=bn,
            se=se,
            act=act,
            act_kwargs=act_kwargs)


    def forward(self, x, skip):
        """Upsample ``x``, concatenate the cropped ``skip`` and convolve."""
        x = self.up(x)
        skip = self._crop(skip,x)
        x = torch.cat([skip, x], dim=1)
        x = self.conv_block(x)
        return x

    def _crop(self, skip, x):
        """Centre-crop ``skip`` to the spatial size of ``x``.

        The window is taken as ``start:start+target`` on each axis. This stays
        correct when only one axis needs cropping (a ``0:-0`` slice would be
        empty) and when the size difference is odd (a symmetric ``a:-a`` slice
        would leave one extra pixel). If ``skip`` is smaller than ``x`` it is
        returned as large as possible and the later ``torch.cat`` reports the
        mismatch.
        """
        target_h, target_w = x.size(2), x.size(3)
        top = max(self.cropping(skip.size(2), target_h), 0)
        left = max(self.cropping(skip.size(3), target_w), 0)
        return skip[:, :, top:top+target_h, left:left+target_w]

# Smoke test for UpBlock (note: rebinds the name db_out used by the
# DownBlock cell). The 280x280 skip tensor is larger than the upsampled
# 200x200 tensor, so the forward pass exercises the skip cropping.
db_out=UpBlock(256,100)
print(db_out.out_size,db_out.out_ch)
print(db_out(torch.rand(1,256,100,100),torch.rand(1,128,280,280)).shape)

196 128
torch.Size([1, 128, 196, 196])


In [17]:
class UNet(nn.Module):
    """U-Net assembled from ``ConvBlock``, ``DownBlock`` and ``UpBlock``.

    Layout: an input ``ConvBlock``, ``network_depth`` ``DownBlock`` stages,
    the same number of ``UpBlock`` stages fed with the encoder outputs as
    skip connections, and a final 1x1 convolution to ``out_ch`` channels.

    Args:
        network_depth: number of down (and up) stages; 0 gives just the
            input block followed by the output convolution.
        conv_depth: number of convolutions per block.
        in_size: spatial size (one side) of the expected input.
        in_ch: number of input channels.
        out_ch: number of output channels.
        init_ch: number of channels produced by the input block.
        padding: padding of every convolution in the blocks.
        bn, se, act, act_kwargs: forwarded to every block.

    Attributes:
        out_size: expected spatial size of the output, taken from the last
            decoder block (or from the input block when there is none).
    """

    def __init__(self,
            network_depth=4,
            conv_depth=2,
            in_size=572,
            in_ch=1,
            out_ch=2,
            init_ch=64,
            padding=0,
            bn=True,
            se=True,
            act='relu',
            act_kwargs=None):
        super(UNet, self).__init__()
        # Use a private dict instead of a shared mutable ``{}`` default.
        act_kwargs={} if act_kwargs is None else dict(act_kwargs)
        self.network_depth=network_depth
        self.conv_depth=conv_depth
        self.out_ch=out_ch
        self.padding=padding
        self.input_conv=ConvBlock(
            in_ch=in_ch,
            in_size=in_size,
            out_ch=init_ch,
            depth=self.conv_depth,
            padding=padding,
            bn=bn,
            se=se,
            act=act,
            act_kwargs=act_kwargs)
        down_layers=self._down_layers(
            self.input_conv.out_ch,
            self.input_conv.out_size,
            bn=bn,
            se=se,
            act=act,
            act_kwargs=act_kwargs)
        self.down_blocks=nn.ModuleList(down_layers)
        up_layers=self._up_layers(
            down_layers,
            bn=bn,
            se=se,
            act=act,
            act_kwargs=act_kwargs)
        self.up_blocks=nn.ModuleList(up_layers)
        # With network_depth=0 there is no decoder block; the input block is
        # then the last feature-producing stage.
        last_block=up_layers[-1] if up_layers else self.input_conv
        self.out_size=last_block.out_size
        self.output_conv=self._output_layer(out_ch,in_ch=last_block.out_ch)


    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        x=self.input_conv(x)
        skips=[x]
        for block in self.down_blocks:
            x=block(x)
            skips.append(x)
        # The deepest encoder output is the decoder input, not a skip.
        skips.pop()
        # Decoder stages consume the skips from deepest to shallowest.
        skips=skips[::-1]
        for skip,block in zip(skips,self.up_blocks):
            x=block(x,skip)
        x=self.output_conv(x)
        return x


    def _down_layers(self,in_ch,in_size,bn,se,act,act_kwargs):
        """Build the list of ``network_depth`` chained ``DownBlock`` stages."""
        layers=[]
        for _ in range(self.network_depth):
            layer=DownBlock(
                in_ch,
                in_size,
                depth=self.conv_depth,
                padding=self.padding,
                bn=bn,
                se=se,
                act=act,
                act_kwargs=act_kwargs)
            # Each stage starts from the previous stage's channels and size.
            in_ch=layer.out_ch
            in_size=layer.out_size
            layers.append(layer)
        return layers


    def _up_layers(self,down_layers,bn,se,act,act_kwargs):
        """Build one ``UpBlock`` per skip source, deepest first."""
        # Reversing creates a new list, so the caller's list is not modified.
        down_layers=down_layers[::-1]
        down_layers.append(self.input_conv)
        # The deepest encoder stage feeds the decoder; the remaining stages
        # (ending with the input block) provide the skip connections.
        first=down_layers.pop(0)
        in_ch=first.out_ch
        in_size=first.out_size
        layers=[]
        for down_layer in down_layers:
            # Expected per-side margin between the skip map and the upsampled
            # map, from the statically computed sizes. UpBlock only stores
            # this value; the crop applied at forward time is recomputed from
            # the actual tensor shapes.
            crop=UpBlock.cropping(down_layer.out_size,2*in_size)
            layer=UpBlock(
                in_ch,
                in_size,
                depth=self.conv_depth,
                crop=crop,
                padding=self.padding,
                bn=bn,
                se=se,
                act=act,
                act_kwargs=act_kwargs)
            in_ch=layer.out_ch
            in_size=layer.out_size
            layers.append(layer)
        return layers


    def _output_layer(self,out_ch,in_ch=None):
        """Return the final 1x1 convolution mapping features to ``out_ch``.

        ``in_ch`` defaults to the channel count of the last decoder block
        (or of the input block when there is no decoder), so the layer also
        matches networks built with ``init_ch`` other than 64.
        """
        if in_ch is None:
            if len(self.up_blocks):
                in_ch=self.up_blocks[-1].out_ch
            else:
                in_ch=self.input_conv.out_ch
        return nn.Conv2d(
           in_channels=in_ch,
           out_channels=out_ch,
           kernel_size=1,
           stride=1,
           padding=0)


# Smoke test for UNet: 572x572 input, 4 down/up stages, 2 unpadded
# convolutions per block. Prints the configuration, the statically computed
# output size/channels, the shape of a real forward pass, and a summary.
unet=UNet(in_size=572,network_depth=4,conv_depth=2)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,572,572)).shape)
summary(unet,input_size=(1,572,572))


4 2
388.0 2
torch.Size([1, 2, 388, 388])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 570, 570]             640
            Conv2d-2         [-1, 64, 568, 568]          36,928
       BatchNorm2d-3         [-1, 64, 568, 568]             128
 AdaptiveAvgPool2d-4             [-1, 64, 1, 1]               0
            Linear-5                    [-1, 4]             260
              ReLU-6                    [-1, 4]               0
            Linear-7                   [-1, 64]             320
           Sigmoid-8                   [-1, 64]               0
 SqueezeExcitation-9         [-1, 64, 568, 568]               0
        ConvBlock-10         [-1, 64, 568, 568]               0
        MaxPool2d-11         [-1, 64, 284, 284]               0
           Conv2d-12        [-1, 128, 282, 282]          73,856
           Conv2d-13        [-1, 128, 280, 280]         147,58

In [13]:
# Same check with a shallower network (2 stages) and 4 convolutions per block.
unet=UNet(in_size=572,network_depth=2,conv_depth=4)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,572,572)).shape)
summary(unet,input_size=(1,572,572))

2 4
492.0 2
torch.Size([1, 2, 492, 492])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 570, 570]             640
            Conv2d-2         [-1, 64, 568, 568]          36,928
            Conv2d-3         [-1, 64, 566, 566]          36,928
            Conv2d-4         [-1, 64, 564, 564]          36,928
       BatchNorm2d-5         [-1, 64, 564, 564]             128
 AdaptiveAvgPool2d-6             [-1, 64, 1, 1]               0
            Linear-7                    [-1, 4]             260
              ReLU-8                    [-1, 4]               0
            Linear-9                   [-1, 64]             320
          Sigmoid-10                   [-1, 64]               0
SqueezeExcitation-11         [-1, 64, 564, 564]               0
        ConvBlock-12         [-1, 64, 564, 564]               0
        MaxPool2d-13         [-1, 64, 282, 282]               

In [14]:
# Same check on a 256x256 input with 2 stages and the default conv depth.
SIZE=256
unet=UNet(in_size=SIZE,network_depth=2)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,SIZE,SIZE)).shape)
summary(unet,input_size=(1,SIZE,SIZE))

2 2
216.0 2
torch.Size([1, 2, 216, 216])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 254, 254]             640
            Conv2d-2         [-1, 64, 252, 252]          36,928
       BatchNorm2d-3         [-1, 64, 252, 252]             128
 AdaptiveAvgPool2d-4             [-1, 64, 1, 1]               0
            Linear-5                    [-1, 4]             260
              ReLU-6                    [-1, 4]               0
            Linear-7                   [-1, 64]             320
           Sigmoid-8                   [-1, 64]               0
 SqueezeExcitation-9         [-1, 64, 252, 252]               0
        ConvBlock-10         [-1, 64, 252, 252]               0
        MaxPool2d-11         [-1, 64, 126, 126]               0
           Conv2d-12        [-1, 128, 124, 124]          73,856
           Conv2d-13        [-1, 128, 122, 122]         147,58

In [18]:
# Same check with padding=1 and 5 stages: with padded 3x3 convolutions the
# printed output size equals the input size.
SIZE=256
unet=UNet(in_size=SIZE,network_depth=5,padding=1)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,SIZE,SIZE)).shape)
summary(unet,input_size=(1,SIZE,SIZE))

5 2
256.0 2
torch.Size([1, 2, 256, 256])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
            Conv2d-2         [-1, 64, 256, 256]          36,928
       BatchNorm2d-3         [-1, 64, 256, 256]             128
 AdaptiveAvgPool2d-4             [-1, 64, 1, 1]               0
            Linear-5                    [-1, 4]             260
              ReLU-6                    [-1, 4]               0
            Linear-7                   [-1, 64]             320
           Sigmoid-8                   [-1, 64]               0
 SqueezeExcitation-9         [-1, 64, 256, 256]               0
        ConvBlock-10         [-1, 64, 256, 256]               0
        MaxPool2d-11         [-1, 64, 128, 128]               0
           Conv2d-12        [-1, 128, 128, 128]          73,856
           Conv2d-13        [-1, 128, 128, 128]         147,58

In [19]:
# Same check with both batch norm and squeeze-and-excitation disabled.
SIZE=256
unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=False)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,SIZE,SIZE)).shape)
summary(unet,input_size=(1,SIZE,SIZE))

2 2
216.0 2
torch.Size([1, 2, 216, 216])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 254, 254]             640
            Conv2d-2         [-1, 64, 252, 252]          36,928
         ConvBlock-3         [-1, 64, 252, 252]               0
         MaxPool2d-4         [-1, 64, 126, 126]               0
            Conv2d-5        [-1, 128, 124, 124]          73,856
            Conv2d-6        [-1, 128, 122, 122]         147,584
         ConvBlock-7        [-1, 128, 122, 122]               0
         DownBlock-8        [-1, 128, 122, 122]               0
         MaxPool2d-9          [-1, 128, 61, 61]               0
           Conv2d-10          [-1, 256, 59, 59]         295,168
           Conv2d-11          [-1, 256, 57, 57]         590,080
        ConvBlock-12          [-1, 256, 57, 57]               0
        DownBlock-13          [-1, 256, 57, 57]               

In [20]:
# Same check with batch norm disabled but squeeze-and-excitation kept.
SIZE=256
unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=True)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,SIZE,SIZE)).shape)
summary(unet,input_size=(1,SIZE,SIZE))

2 2
216.0 2
torch.Size([1, 2, 216, 216])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 254, 254]             640
            Conv2d-2         [-1, 64, 252, 252]          36,928
 AdaptiveAvgPool2d-3             [-1, 64, 1, 1]               0
            Linear-4                    [-1, 4]             260
              ReLU-5                    [-1, 4]               0
            Linear-6                   [-1, 64]             320
           Sigmoid-7                   [-1, 64]               0
 SqueezeExcitation-8         [-1, 64, 252, 252]               0
         ConvBlock-9         [-1, 64, 252, 252]               0
        MaxPool2d-10         [-1, 64, 126, 126]               0
           Conv2d-11        [-1, 128, 124, 124]          73,856
           Conv2d-12        [-1, 128, 122, 122]         147,584
AdaptiveAvgPool2d-13            [-1, 128, 1, 1]               

In [None]:
# Same check with the ELU activation (looked up by name in
# torch.nn.functional) and squeeze-and-excitation disabled.
# NOTE(review): the recorded output below prints tuples and an integer size,
# unlike the other cells — presumably from an older run; confirm and re-run.
SIZE=256
unet=UNet(in_size=SIZE,network_depth=2,act='elu',se=False)
print(unet.network_depth,unet.conv_depth)
print(unet.out_size,unet.out_ch)
print(unet(torch.rand(1,1,SIZE,SIZE)).shape)
summary(unet,input_size=(1,SIZE,SIZE))

(2, 2)
(216, 2)
torch.Size([1, 2, 216, 216])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 254, 254]             640
            Conv2d-2         [-1, 64, 252, 252]          36,928
       BatchNorm2d-3         [-1, 64, 252, 252]             128
         ConvBlock-4         [-1, 64, 252, 252]               0
         MaxPool2d-5         [-1, 64, 126, 126]               0
            Conv2d-6        [-1, 128, 124, 124]          73,856
            Conv2d-7        [-1, 128, 122, 122]         147,584
       BatchNorm2d-8        [-1, 128, 122, 122]             256
         ConvBlock-9        [-1, 128, 122, 122]               0
        DownBlock-10        [-1, 128, 122, 122]               0
        MaxPool2d-11          [-1, 128, 61, 61]               0
           Conv2d-12          [-1, 256, 59, 59]         295,168
           Conv2d-13          [-1, 256, 57, 57]         59