In [5]:
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.models as models
from torch import Tensor
from torchinfo import summary


In [6]:
from torchvision.models.resnet import BasicBlock

In [23]:
class Unet_ResNet(nn.Module):
    """Shape-preserving residual network used as the "lens" stage.

    Architecture: 1x1 conv stem (+ReLU) -> ``num_blocks`` torchvision
    ``BasicBlock``s -> 1x1 conv head (+ReLU). Every layer uses stride 1 with
    matching padding, so the output has the same spatial size as the input
    (the torchinfo summary shows 28x28 in and 28x28 throughout). Despite the
    name, there is no U-Net encoder/decoder or down-/up-sampling path.

    The defaults reproduce the original hard-coded configuration
    (3 -> 64 channels, 5 blocks, 64 -> 3 channels), so ``Unet_ResNet()``
    builds exactly the same network as before.

    Args:
        in_channels: channels of the input tensor (default 3).
        out_channels: channels of the output tensor (default 3). It must equal
            the input channel count if the output is added back onto the input.
        width: feature channels inside the residual stack (default 64).
        num_blocks: number of ``BasicBlock``s in the stack (default 5); 0 gives
            an empty (identity) stack.

    Raises:
        ValueError: if ``num_blocks`` is negative.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        width: int = 64,
        num_blocks: int = 5,
    ) -> None:
        super().__init__()
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")

        # Stem: a 1x1 conv lifts the image channels to the working width.
        self.conv_input = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

        # BasicBlock(width, width) has stride 1 and no downsample branch, so its
        # identity shortcut matches its output shape and spatial size is preserved.
        self.blocks = nn.Sequential(*(BasicBlock(width, width) for _ in range(num_blocks)))

        # Head: a 1x1 conv projects back to image channels.
        # NOTE(review): the trailing ReLU clamps the output to >= 0, so when this
        # output is added to the input image the "lens" can only increase values.
        # Kept for compatibility with the original; confirm this is intended.
        self.conv_end = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

        # Initialization mirrors torchvision's ResNet:
        # https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Apply stem, residual stack and head.

        Args:
            x: input batch; presumably (N, in_channels, H, W) - TODO confirm.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size as ``x``.
        """
        x = self.conv_input(x)
        x = self.blocks(x)
        x = self.conv_end(x)
        return x

In [24]:
# Build the stand-alone lens network so its submodules and layer shapes can be inspected in the next cells.
unet = Unet_ResNet()

In [25]:
# Top-level submodules of the lens network: stem (conv_input), residual stack (blocks), head (conv_end).
[*unet.children()]

[Sequential(
   (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
   (1): ReLU(inplace=True)
 ),
 Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (1): BasicBlock(
     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (2): BasicBlock(
     (conv1): Co

In [26]:
# torchinfo only needs a dummy input shape (N, C, H, W): one 3-channel 28x28 image.
batch_size = 1
summary(
    model=unet,
    input_size=(batch_size, 3, 28, 28),
)

Layer (type:depth-idx)                   Output Shape              Param #
Unet_ResNet                              --                        --
├─Sequential: 1-1                        [1, 64, 28, 28]           --
│    └─Conv2d: 2-1                       [1, 64, 28, 28]           256
│    └─ReLU: 2-2                         [1, 64, 28, 28]           --
├─Sequential: 1-2                        [1, 64, 28, 28]           --
│    └─BasicBlock: 2-3                   [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-1                  [1, 64, 28, 28]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 28, 28]           128
│    │    └─ReLU: 3-3                    [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-4                  [1, 64, 28, 28]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 28, 28]           128
│    │    └─ReLU: 3-6                    [1, 64, 28, 28]           --
│    └─BasicBlock: 2-4                   [1, 64, 28, 28]           --
│   

In [27]:
class Lens_Network(nn.Module):
    """Learned "lens" followed by a ResNet-50 classifier head.

    The lens (``Unet_ResNet``) produces a same-sized correction that is added to
    the input; the corrected input is then passed through a torchvision ResNet-50
    whose final ``fc`` layer is replaced by a ``num_classes``-way linear layer.

    Args:
        num_classes: output features of the replacement ``fc`` layer (default 4,
            the original hard-coded value).
    """

    def __init__(self, num_classes: int = 4) -> None:
        super().__init__()
        # Lens component (shape-preserving, 3 -> 3 channels with its defaults).
        self.lens = Unet_ResNet()

        # Feature extraction. resnet50() is called without a weights argument.
        res = models.resnet50()
        # Read the head's input width from the existing fc (2048 for ResNet-50)
        # instead of hard-coding it, so it stays correct if the backbone changes.
        res.fc = nn.Linear(in_features=res.fc.in_features, out_features=num_classes, bias=True)
        self.res = res

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the lens, add its output to the input, and classify the result.

        Args:
            x: input batch; presumably (N, 3, H, W) images - TODO confirm.

        Returns:
            ``(lens_output, scores)``: the lens correction (same shape as ``x``)
            and the ``num_classes`` outputs of the ResNet head (no softmax applied).
        """
        lens_output = self.lens(x)
        # Residual connection: the lens perturbs the input rather than replacing it.
        scores = self.res(lens_output + x)
        return lens_output, scores

In [28]:
# Full model (lens + ResNet-50 head). Note: this global `lens` is the whole Lens_Network,
# not its `.lens` submodule (the Unet_ResNet instance).
lens = Lens_Network()