!pip install timm --quiet
!pip install fastai --quiet

In [13]:
import torch
import timm
import torch.nn as nn
import torch.nn.functional as F

In [14]:
for o in [torch, timm]:
    print(f"{o.__name__} \t {o.__version__}")

torch 	 1.10.1+cu113
timm 	 0.5.4


Create two models from different architecture families:
- ResNet
- EfficientNetV2

In [15]:
resnet = timm.create_model(model_name='resnet50',pretrained=True)
effnet = timm.create_model(model_name='tf_efficientnetv2_b0',pretrained=True)

Sometimes we want to replace a model's head with something more elaborate than a single Linear layer. Before PyTorch's lazy modules, we had to inspect each model to find how many input features its head receives, and only then build a new head.

In [16]:
resnet.fc, effnet.classifier

(Linear(in_features=2048, out_features=1000, bias=True),
 Linear(in_features=1280, out_features=1000, bias=True))

In [17]:
resnet.fc.in_features, effnet.classifier.in_features

(2048, 1280)

In [18]:
def create_head_old(in_features, out_features, hidden_features=512):
    """Build a classification head whose input width is given explicitly.

    The head is ``Linear -> ReLU -> BatchNorm1d -> Dropout -> Linear``; both
    linear layers are created without a bias term.

    Args:
        in_features: Number of input features. It must match the width of the
            features fed to the head, e.g. ``resnet.fc.in_features``.
        out_features: Number of outputs produced by the head.
        hidden_features: Width of the hidden layer. Defaults to 512, the value
            that was previously hard-coded.

    Returns:
        nn.Sequential: The assembled head.
    """
    head = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=hidden_features, bias=False),
        nn.ReLU(inplace=True),
        # Normalizes the hidden activations; its size must equal hidden_features.
        nn.BatchNorm1d(hidden_features),
        nn.Dropout(),  # default p=0.5
        nn.Linear(in_features=hidden_features, out_features=out_features, bias=False))
    return head

Let's replace the heads of the ResNet and EfficientNet models.

In [19]:
resnet.fc = create_head_old(2048,1)
effnet.classifier = create_head_old(1280,1)

In [20]:
resnet.fc, effnet.classifier

(Sequential(
   (0): Linear(in_features=2048, out_features=512, bias=False)
   (1): ReLU(inplace=True)
   (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): Dropout(p=0.5, inplace=False)
   (4): Linear(in_features=512, out_features=1, bias=False)
 ),
 Sequential(
   (0): Linear(in_features=1280, out_features=512, bias=False)
   (1): ReLU(inplace=True)
   (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): Dropout(p=0.5, inplace=False)
   (4): Linear(in_features=512, out_features=1, bias=False)
 ))

In [21]:
resnet = timm.create_model(model_name='resnet50',pretrained=True)
effnet = timm.create_model(model_name='tf_efficientnetv2_b0',pretrained=True)

def create_head_new(out_features, hidden_features=512):
    """Build a classification head whose input width is inferred lazily.

    The first layer is ``nn.LazyLinear``, so the number of input features does
    not need to be known up front. The layer reports ``in_features=0`` until
    the first forward pass, which sets it from the input.

    Args:
        out_features: Number of outputs produced by the head.
        hidden_features: Width of the hidden layer. Defaults to 512, the value
            that was previously hard-coded.

    Returns:
        nn.Sequential: ``LazyLinear -> ReLU -> BatchNorm1d -> Dropout -> Linear``,
        with both linear layers bias-free.
    """
    head = nn.Sequential(
                nn.LazyLinear(hidden_features, bias=False),
                nn.ReLU(inplace=True),
                # Must match the hidden width produced by the lazy layer above.
                nn.BatchNorm1d(hidden_features),
                nn.Dropout(),  # default p=0.5
                nn.Linear(in_features=hidden_features,
                          out_features=out_features,
                          bias=False))
    return head

In [22]:
resnet.fc = create_head_new(1)
effnet.classifier = create_head_new(1)



In [23]:
resnet.fc

Sequential(
  (0): LazyLinear(in_features=0, out_features=512, bias=False)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=1, bias=False)
)

In [26]:
dummy_tensor = torch.randn((2,3,224,224))
_ = resnet(dummy_tensor)
resnet.fc

Sequential(
  (0): Linear(in_features=2048, out_features=512, bias=False)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=1, bias=False)
)

That example is simple enough that the benefit of lazy modules may not be obvious. Where they really help is in simplifying a more complex architecture such as a UNet.

Let's take a UNet model from one of my earlier projects and see how lazy modules can be used in it.

## Unet

In [27]:
class Encoder(nn.Module):
    """Thin wrapper around a timm backbone created with ``features_only=True``.

    The backbone is built without pretrained weights. Calling the module
    returns whatever the timm feature extractor returns. The UNet decoders in
    this notebook unpack that result into five feature maps.
    """

    def __init__(self, model_name='resnext50_32x4d'):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=False, features_only=True)
        self.encoder = backbone

    def forward(self, x):
        features = self.encoder(x)
        return features

In [28]:
enc = Encoder('resnet50')

In [29]:
dummy_batch = torch.rand((2,3,224,224))

In [30]:
def conv_block(in_feat,out_feat):
    """Return a ``Conv2d(3x3, stride 1, padding 1, no bias) -> BatchNorm2d -> ReLU`` stack.

    With padding 1 and stride 1, the spatial size is preserved; only the
    channel count changes from ``in_feat`` to ``out_feat``.
    """
    layers = [
        nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_feat),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

conv_block(3,64)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

In [31]:
class UnetBlock(nn.Module):
    """One UNet decoder stage: bilinear 2x upsampling followed by two conv blocks.

    Args:
        in_channels: Channels of the tensor passed to ``forward``.
        chanels: Width of the intermediate conv block. The misspelled name is
            kept for backward compatibility.
        out_channels: Channels of the returned tensor.
    """

    def __init__(self,in_channels,chanels,out_channels):
        super().__init__()
        self.conv1 = conv_block(in_channels, chanels)
        self.conv2 = conv_block(chanels, out_channels)

    def forward(self,x):
        # Double the spatial size first, then refine with the two conv blocks.
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv2(self.conv1(upsampled))

In [32]:
class UnetDecoder(nn.Module):
    """UNet decoder with channel counts fixed at construction time.

    The decoder expects five feature maps ``(e1, e2, e3, e4, e5)``; ``e1`` is
    ignored. The input widths are derived from ``expansion``:
    ``e5`` must have ``512*expansion`` channels, ``e4`` ``256*expansion``,
    ``e3`` ``128*expansion`` and ``e2`` ``64*expansion``. A backbone with a
    different layout fails with a channel-mismatch ``RuntimeError`` in the
    first convolution.

    Each stage upsamples 2x before its convolutions. Each skip feature must
    therefore have the spatial size of the previous stage's upsampled output
    so that the concatenation is valid.

    Args:
        fs: Channel width of every decoder stage output.
        expansion: Channel multiplier describing the encoder's layout.
        n_out: Channels of the final 1x1 convolution (the logits).
    """

    def __init__(self, fs=32, expansion=4,n_out=1):
        super().__init__()
        deep_ch = 512*expansion                 # channels of e5
        dec5_in = deep_ch + (256*expansion)     # center output concatenated with e5
        mid_ch = 512
        # Module construction order is kept identical to preserve parameter order.
        self.center = nn.Sequential(conv_block(deep_ch, deep_ch),
                                    conv_block(deep_ch, deep_ch//2))
        self.decoder5 = UnetBlock(dec5_in, mid_ch, fs)
        self.decoder4 = UnetBlock(256*expansion+fs, 256, fs)
        self.decoder3 = UnetBlock(128*expansion+fs, 128, fs)
        self.decoder2 = UnetBlock(64*expansion+fs, 64, fs)
        self.decoder1 = UnetBlock(fs, fs, fs)
        self.logit = nn.Sequential(conv_block(fs, fs//2),
                                   conv_block(fs//2, fs//2),
                                   nn.Conv2d(fs//2, n_out, kernel_size=1))

    def forward(self, feats):
        _, e2, e3, e4, e5 = feats  # e.g. 64 256 512 1024 2048 channels for resnet50
        x = self.center(e5)
        # Each stage fuses the running decoder output with its skip connection.
        for stage, skip in ((self.decoder5, e5), (self.decoder4, e4),
                            (self.decoder3, e3), (self.decoder2, e2)):
            x = stage(torch.cat([x, skip], 1))
        x = self.decoder1(x)
        return self.logit(x)

In [33]:
class Unet(nn.Module):
    """UNet assembled from a timm ``Encoder`` and a fixed-width ``UnetDecoder``.

    The decoder's channel counts are fixed by ``expansion``. The defaults work
    with the default ``'resnext50_32x4d'`` backbone. Other backbones raise a
    channel-mismatch ``RuntimeError`` unless ``expansion`` matches their
    feature layout.

    Args:
        fs: Decoder stage width.
        expansion: Channel multiplier of the encoder's feature layout.
        model_name: timm model name used for the encoder.
        n_out: Number of output channels.
    """

    def __init__(self, fs=32, expansion=4, model_name='resnext50_32x4d',n_out=1):
        super().__init__()
        # Encoder is built before the decoder, as in the original ordering.
        self.encoder = Encoder(model_name)
        self.decoder = UnetDecoder(fs=fs, expansion=expansion, n_out=n_out)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [34]:
dummy_batch = torch.rand((2,3,224,224))
unet = Unet()
unet(dummy_batch).shape

torch.Size([2, 1, 224, 224])

## With LazyModules

In [35]:
def lazy_conv_block(out_feat):
    """Return a ``LazyConv2d(3x3, stride 1, padding 1, no bias) -> BatchNorm2d -> ReLU`` stack.

    The convolution's input channel count is left unset and is inferred from
    the first tensor passed through the block. Only the output width is
    needed here.
    """
    layers = [
        nn.LazyConv2d(out_feat, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_feat),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

class LazyUnetBlock(nn.Module):
    """Lazy UNet decoder stage: bilinear 2x upsampling and two lazy conv blocks.

    Input channels are inferred on the first forward pass, so only the
    intermediate and output widths are given.

    Args:
        chanels: Width of the intermediate conv block. The misspelled name is
            kept for backward compatibility.
        out_channels: Channels of the returned tensor.
    """

    def __init__(self,chanels,out_channels):
        super().__init__()
        self.conv1 = lazy_conv_block(chanels)
        self.conv2 = lazy_conv_block(out_channels)

    def forward(self,x):
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv2(self.conv1(upsampled))

class LazyUnetDecoder(nn.Module):
    """UNet decoder whose input channel counts are inferred lazily.

    Only output widths are specified, so the decoder adapts to the channel
    layout of whatever encoder feeds it. Like ``UnetDecoder``, it expects five
    feature maps ``(e1, e2, e3, e4, e5)``, ignores ``e1``, and upsamples 2x
    per stage. The skip features therefore still need spatially compatible
    sizes.

    Args:
        fs: Channel width of every decoder stage output.
        expansion: Only scales the width of the center block
            (``512*expansion`` then half of that).
        n_out: Channels of the final 1x1 convolution (the logits).
    """

    def __init__(self, fs=32, expansion=4,n_out=1):
        super().__init__()
        base = 512
        wide = base*expansion
        # Construction order is kept identical to preserve parameter order.
        self.center = nn.Sequential(lazy_conv_block(wide), lazy_conv_block(wide//2))
        self.decoder5 = LazyUnetBlock(base, fs)
        self.decoder4 = LazyUnetBlock(base//2, fs)
        self.decoder3 = LazyUnetBlock(base//4, fs)
        self.decoder2 = LazyUnetBlock(base//8, fs)
        self.decoder1 = LazyUnetBlock(fs, fs)
        self.logit = nn.Sequential(lazy_conv_block(fs//2),
                                   lazy_conv_block(fs//2),
                                   nn.Conv2d(fs//2, n_out, kernel_size=1))

    def forward(self, feats):
        _, e2, e3, e4, e5 = feats  # e.g. 64 256 512 1024 2048 channels for resnet50
        x = self.center(e5)
        # Each stage fuses the running decoder output with its skip connection.
        for stage, skip in ((self.decoder5, e5), (self.decoder4, e4),
                            (self.decoder3, e3), (self.decoder2, e2)):
            x = stage(torch.cat([x, skip], 1))
        x = self.decoder1(x)
        return self.logit(x)

class LazyUnet(nn.Module):
    """UNet assembled from a timm ``Encoder`` and a ``LazyUnetDecoder``.

    Because the decoder infers its input widths, the same defaults work across
    backbones with different channel layouts.

    NOTE: lazy layers have no materialized weights until the first forward
    pass. Per PyTorch's lazy-module guidance, run one dummy batch (on the
    target device/dtype) before constructing an optimizer.

    Args:
        fs: Decoder stage width.
        expansion: Scales the width of the decoder's center block.
        model_name: timm model name used for the encoder.
        n_out: Number of output channels.
    """

    def __init__(self, fs=32, expansion=4, model_name='resnext50_32x4d',n_out=1):
        super().__init__()
        # Encoder is built before the decoder, as in the original ordering.
        self.encoder = Encoder(model_name)
        self.decoder = LazyUnetDecoder(fs=fs, expansion=expansion, n_out=n_out)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [36]:

unet2 = LazyUnet()

In [37]:
unet2.decoder

LazyUnetDecoder(
  (center): Sequential(
    (0): Sequential(
      (0): LazyConv2d(0, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): LazyConv2d(0, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (decoder5): LazyUnetBlock(
    (conv1): Sequential(
      (0): LazyConv2d(0, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): LazyConv2d(0, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
  

In [38]:
unet2(dummy_batch)

tensor([[[[-0.4978, -0.7660, -0.3222,  ..., -0.0819, -0.1124, -0.2113],
          [-0.1036, -0.4169, -0.1280,  ...,  0.0572,  0.1489, -0.5093],
          [-0.1304, -0.3900, -0.2854,  ..., -0.2363, -0.1311, -0.5610],
          ...,
          [-0.3159, -0.5664, -0.5392,  ..., -0.3108, -0.0590, -0.6070],
          [-0.6864, -0.4726, -0.3104,  ..., -0.4545, -0.0874, -0.3718],
          [-0.2720, -0.1937, -0.3484,  ..., -0.2758, -0.2565, -0.5056]]],


        [[[-0.1785, -0.3436, -0.4608,  ..., -0.1940, -0.1168, -0.1306],
          [-0.2678, -0.2984, -0.0957,  ..., -0.4747, -0.0239, -0.1774],
          [-0.4147, -0.7018, -0.6035,  ..., -0.3336, -0.0444, -0.5637],
          ...,
          [-0.1437, -0.3104, -0.6257,  ..., -0.0048, -0.0482, -0.0139],
          [-0.2603, -0.2087, -0.1792,  ...,  0.0129,  0.0717, -0.2716],
          [-0.2750, -0.2776, -0.2603,  ..., -0.0762, -0.0597, -0.6284]]]],
       grad_fn=<MkldnnConvolutionBackward0>)

In [39]:
unet2.decoder

LazyUnetDecoder(
  (center): Sequential(
    (0): Sequential(
      (0): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (decoder5): LazyUnetBlock(
    (conv1): Sequential(
      (0): Conv2d(3072, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
 

In [40]:
Unet(model_name='resnet18')(dummy_batch).shape

RuntimeError: Given groups=1, weight of size [2048, 2048, 3, 3], expected input[2, 512, 7, 7] to have 2048 channels, but got 512 channels instead

In [41]:
LazyUnet(model_name='resnet18')(dummy_batch).shape

torch.Size([2, 1, 224, 224])

In [42]:
Unet(model_name='resnet18')(dummy_batch).shape

RuntimeError: Given groups=1, weight of size [2048, 2048, 3, 3], expected input[2, 512, 7, 7] to have 2048 channels, but got 512 channels instead

In [43]:
# Unet(model_name='efficientnet_em')(dummy_batch).shape
LazyUnet(model_name='efficientnet_em')(dummy_batch).shape

torch.Size([2, 1, 224, 224])

In [44]:
timm.list_models(filter='*eff*',pretrained=True)

['efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b1_pruned',
 'efficientnet_b2',
 'efficientnet_b2_pruned',
 'efficientnet_b3',
 'efficientnet_b3_pruned',
 'efficientnet_b4',
 'efficientnet_el',
 'efficientnet_el_pruned',
 'efficientnet_em',
 'efficientnet_es',
 'efficientnet_es_pruned',
 'efficientnet_lite0',
 'efficientnetv2_rw_m',
 'efficientnetv2_rw_s',
 'efficientnetv2_rw_t',
 'gc_efficientnetv2_rw_t',
 'tf_efficientnet_b0',
 'tf_efficientnet_b0_ap',
 'tf_efficientnet_b0_ns',
 'tf_efficientnet_b1',
 'tf_efficientnet_b1_ap',
 'tf_efficientnet_b1_ns',
 'tf_efficientnet_b2',
 'tf_efficientnet_b2_ap',
 'tf_efficientnet_b2_ns',
 'tf_efficientnet_b3',
 'tf_efficientnet_b3_ap',
 'tf_efficientnet_b3_ns',
 'tf_efficientnet_b4',
 'tf_efficientnet_b4_ap',
 'tf_efficientnet_b4_ns',
 'tf_efficientnet_b5',
 'tf_efficientnet_b5_ap',
 'tf_efficientnet_b5_ns',
 'tf_efficientnet_b6',
 'tf_efficientnet_b6_ap',
 'tf_efficientnet_b6_ns',
 'tf_efficientnet_b7',
 'tf_efficientnet_b7_ap',
 'tf_effi