In [1]:
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
import torchvision.models as models
from torchinfo import summary


In [2]:
from torchvision.models.resnet import BasicBlock, Bottleneck

In [12]:
class Unet_ResNet(nn.Module):
    """Resolution-preserving stack of torchvision ``Bottleneck`` blocks (image in, image out).

    Despite the name there is no U-Net encoder/decoder or skip concatenation:
    every layer is a 1x1 conv or a stride-1 Bottleneck, so spatial size is kept
    and an input of shape ``(N, in_channels, H, W)`` (assumed) yields
    ``(N, out_channels, H, W)``.

    Layout (submodule names are unchanged, so previously saved state_dicts load):
      * ``conv_input`` -- 1x1 conv ``in_channels -> width`` + BatchNorm + ReLU.
      * ``blocks``     -- ``num_blocks`` Bottlenecks. The first projects
        ``width -> width * Bottleneck.expansion`` through a 1x1 conv + BN
        shortcut; the remaining ones keep that channel count.
      * ``conv_end``   -- 1x1 conv (with bias) to ``out_channels`` + ReLU.

    Args:
        in_channels: channels of the input tensor (default 3).
        out_channels: channels of the output tensor (default 3).
        width: ``planes`` of every Bottleneck (default 64, i.e. 256 channels inside ``blocks``).
        num_blocks: total number of Bottleneck blocks, >= 1 (default 5 = 1 projection + 4 identity).

    Raises:
        ValueError: if ``num_blocks < 1``.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3,
                 width: int = 64, num_blocks: int = 5):
        super().__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        # Bottleneck outputs planes * expansion channels (expansion is 4 in torchvision).
        expanded = width * Bottleneck.expansion

        self.conv_input = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        # The first block changes the channel count (width -> expanded), so its identity
        # shortcut needs a 1x1 conv + BN projection to make the residual add shape-compatible.
        downsample = nn.Sequential(
            nn.Conv2d(width, expanded, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(expanded),
        )
        layers = [Bottleneck(width, width, downsample=downsample)]
        layers.extend(Bottleneck(expanded, width) for _ in range(num_blocks - 1))
        self.blocks = nn.Sequential(*layers)

        # NOTE(review): the trailing ReLU clamps the output to >= 0. Lens_Network adds this
        # output onto its input, so the lens can only add non-negative values -- confirm intended.
        self.conv_end = nn.Sequential(
            nn.Conv2d(expanded, out_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

        # Initialisation mirrors torchvision's ResNet:
        # https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Apply stem, Bottleneck stack and output projection; spatial size is unchanged."""
        x = self.conv_input(x)
        x = self.blocks(x)
        return self.conv_end(x)

In [13]:
# Seed the global torch RNG so the randomly initialised weights (and anything later
# computed from them) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
unet = Unet_ResNet()

In [14]:
# Show the model's direct child modules (stem, Bottleneck stack, output head).
[*unet.children()]

[Sequential(
   (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
 ),
 Sequential(
   (0): Bottleneck(
     (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (downsample): Sequential(
       (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
   )
 

In [15]:
batch_size = 5
# Layer-by-layer output shapes and parameter counts for a dummy batch.
summary(
    unet,
    input_size=(batch_size, 3, 28, 28),  # (N, C, H, W)
)

Layer (type:depth-idx)                   Output Shape              Param #
Unet_ResNet                              --                        --
├─Sequential: 1-1                        [5, 64, 28, 28]           --
│    └─Conv2d: 2-1                       [5, 64, 28, 28]           192
│    └─BatchNorm2d: 2-2                  [5, 64, 28, 28]           128
│    └─ReLU: 2-3                         [5, 64, 28, 28]           --
├─Sequential: 1-2                        [5, 256, 28, 28]          --
│    └─Bottleneck: 2-4                   [5, 256, 28, 28]          --
│    │    └─Conv2d: 3-1                  [5, 64, 28, 28]           4,096
│    │    └─BatchNorm2d: 3-2             [5, 64, 28, 28]           128
│    │    └─ReLU: 3-3                    [5, 64, 28, 28]           --
│    │    └─Conv2d: 3-4                  [5, 64, 28, 28]           36,864
│    │    └─BatchNorm2d: 3-5             [5, 64, 28, 28]           128
│    │    └─ReLU: 3-6                    [5, 64, 28, 28]           --
│   

In [27]:
class Lens_Network(nn.Module):
    """A learned residual "lens" placed in front of a ResNet-50 classifier.

    The input goes through ``Unet_ResNet`` (``self.lens``). Its output is added
    back onto the input, and the sum is classified by a torchvision ResNet-50
    (``self.res``). The ResNet-50 is built with no arguments, so it is not
    pretrained. Its final fully connected layer is replaced by one with
    ``num_classes`` outputs.

    Args:
        num_classes: number of logits produced by the classifier head (default 4).
    """

    def __init__(self, num_classes: int = 4):
        super().__init__()
        # Lens component: image-to-image network whose output is added back onto the input.
        self.lens = Unet_ResNet()

        # Feature extraction: ResNet-50 with a new classification head. in_features is read
        # from the head being replaced (2048 for resnet50) rather than hard-coded.
        res = models.resnet50()
        res.fc = nn.Linear(in_features=res.fc.in_features, out_features=num_classes, bias=True)
        self.res = res

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the lens and the classifier.

        Args:
            x: input batch; assumed ``(N, 3, H, W)`` so it matches the lens output
               channel count for the residual add -- TODO confirm against the data loader.

        Returns:
            ``(lens_output, logits)``: the raw lens output (same shape as ``x``) and the
            ``(N, num_classes)`` classifier logits computed on ``x + lens_output``.
        """
        lens_output = self.lens(x)
        # The lens's first op is a Conv2d, which allocates a new tensor, so x is untouched here.
        logits = self.res(lens_output + x)
        return lens_output, logits