In [1]:
from typing import Optional

import torch
import torch.nn as nn
from torchvision import models
from torchinfo import summary

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


## ResNet-18

In [2]:
class ResNet(nn.Module):
    """torchvision ResNet whose final ``fc`` layer is replaced by a landmark-regression head.

    The head predicts ``num_landmarks * 2`` values, which ``forward`` reshapes
    into ``(batch, num_landmarks, 2)`` coordinate pairs.

    Args:
        model_name: torchvision model name passed to ``models.get_model``; must be a
            ResNet-style model that exposes an ``fc`` layer (e.g. "resnet18", "resnet50").
        weights: weights spec forwarded to ``models.get_model`` (e.g. "DEFAULT" or
            ``None`` for random initialisation).
        num_landmarks: number of (x, y) points to predict. Defaults to 68.
    """

    def __init__(
        self,
        model_name: str = "resnet18",
        weights: Optional[str] = "DEFAULT",
        num_landmarks: int = 68,
    ):
        super().__init__()
        self.num_landmarks = num_landmarks
        # Use the requested architecture; "resnet18" used to be hard-coded here,
        # silently ignoring `model_name`.
        self.network = models.get_model(name=model_name, weights=weights)
        self.network.fc = nn.Linear(self.network.fc.in_features, num_landmarks * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return predictions shaped (batch, num_landmarks, 2)."""
        x = self.network(x)
        return x.reshape(x.size(0), self.num_landmarks, 2)

In [3]:
# weights=None: every parameter and buffer is overwritten by the strict
# checkpoint load below ("All keys matched"), so downloading the pretrained
# ImageNet weights is wasted work and makes the notebook require network access.
resnet18 = ResNet(weights=None)
summary(
    model=resnet18,
    input_size=[1, 3, 256, 256],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
ResNet (ResNet)                               [1, 3, 256, 256]     [1, 68, 2]           --                   True
├─ResNet (network)                            [1, 3, 256, 256]     [1, 136]             --                   True
│    └─Conv2d (conv1)                         [1, 3, 256, 256]     [1, 64, 128, 128]    9,408                True
│    └─BatchNorm2d (bn1)                      [1, 64, 128, 128]    [1, 64, 128, 128]    128                  True
│    └─ReLU (relu)                            [1, 64, 128, 128]    [1, 64, 128, 128]    --                   --
│    └─MaxPool2d (maxpool)                    [1, 64, 128, 128]    [1, 64, 64, 64]      --                   --
│    └─Sequential (layer1)                    [1, 64, 64, 64]      [1, 64, 64, 64]      --                   True
│    │    └─BasicBlock (0)                    [1, 64, 64, 64]      [1, 64, 64, 64]     

In [4]:
# test input & output shape
random_input = torch.randn([16, 3, 256, 256])
output = resnet18(random_input)
print(f"\nINPUT SHAPE: {random_input.shape}")
print(f"OUTPUT SHAPE: {output.shape}")


INPUT SHAPE: torch.Size([16, 3, 256, 256])
OUTPUT SHAPE: torch.Size([16, 68, 2])


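### Load Lightning weights

The Lightning checkpoint prefixes every parameter key with `net.`; strip that prefix so the keys match the bare model.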

In [5]:
# weights_only=False keeps the current loading behaviour explicitly (and silences the
# torch FutureWarning): Lightning checkpoints may hold non-tensor objects besides
# "state_dict". NOTE: this unpickles arbitrary objects -- only load trusted checkpoints.
checkpoint = torch.load(
    "../ckpts/resnet18.ckpt", map_location=torch.device("cpu"), weights_only=False
)
# Strip only the *leading* "net." prefix; str.replace would also rewrite any
# "net." substring occurring later in a key.
prefix = "net."
new_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}
resnet18.load_state_dict(new_state_dict)

  weights = torch.load("../ckpts/resnet18.ckpt", map_location=torch.device("cpu"))
  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


<All keys matched successfully>

### Save PyTorch weights

In [6]:
# Write the plain nn.Module state dict (no Lightning wrapper) as a standalone .pth file.
torch.save(obj=resnet18.state_dict(), f="../ckpts/resnet18.pth")

## MobileNetV3

In [7]:
class MobileNet(nn.Module):
    """torchvision MobileNetV3 whose classifier is replaced by a landmark-regression head.

    The new head is Linear -> Hardswish -> Dropout -> Linear and predicts
    ``num_landmarks * 2`` values, which ``forward`` reshapes into
    ``(batch, num_landmarks, 2)`` coordinate pairs.

    Args:
        model_name: torchvision model name passed to ``models.get_model``; must be a
            model whose ``classifier[0]`` is a Linear layer (e.g. "mobilenet_v3_large",
            "mobilenet_v3_small").
        weights: weights spec forwarded to ``models.get_model``. Defaults to ``None``
            (random initialisation), matching the previous behaviour.
        num_landmarks: number of (x, y) points to predict. Defaults to 68.
    """

    def __init__(
        self,
        model_name: str = "mobilenet_v3_large",
        weights: Optional[str] = None,
        num_landmarks: int = 68,
    ):
        super().__init__()
        self.num_landmarks = num_landmarks
        self.network = models.get_model(name=model_name, weights=weights)
        # Keep this exact layout (indices 0..3): state_dict keys such as
        # "network.classifier.3.weight" are derived from it, so saved checkpoints
        # only load into a model built with the same structure.
        self.network.classifier = nn.Sequential(
            nn.Linear(self.network.classifier[0].in_features, 512, bias=True),
            nn.Hardswish(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(512, num_landmarks * 2, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return predictions shaped (batch, num_landmarks, 2)."""
        x = self.network(x)
        return x.reshape(x.size(0), self.num_landmarks, 2)

In [8]:
# Build the landmark model and print a per-layer table of shapes, parameter counts
# and trainability for a single 256x256 RGB image.
mobilenetv3 = MobileNet()
summary(
    mobilenetv3,
    input_size=[1, 3, 256, 256],
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type (var_name))                                           Input Shape          Output Shape         Param #              Trainable
MobileNet (MobileNet)                                             [1, 3, 256, 256]     [1, 68, 2]           --                   True
├─MobileNetV3 (network)                                           [1, 3, 256, 256]     [1, 136]             --                   True
│    └─Sequential (features)                                      [1, 3, 256, 256]     [1, 960, 8, 8]       --                   True
│    │    └─Conv2dNormActivation (0)                              [1, 3, 256, 256]     [1, 16, 128, 128]    464                  True
│    │    └─InvertedResidual (1)                                  [1, 16, 128, 128]    [1, 16, 128, 128]    464                  True
│    │    └─InvertedResidual (2)                                  [1, 16, 128, 128]    [1, 24, 64, 64]      3,440                True
│    │    └─InvertedResidual (3)                         

In [9]:
# test input & output shape
random_input = torch.randn([16, 3, 256, 256])
output = mobilenetv3(random_input)
print(f"\nINPUT SHAPE: {random_input.shape}")
print(f"OUTPUT SHAPE: {output.shape}")


INPUT SHAPE: torch.Size([16, 3, 256, 256])
OUTPUT SHAPE: torch.Size([16, 68, 2])


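### Load Lightning weights

The Lightning checkpoint prefixes every parameter key with `net.`; strip that prefix so the keys match the bare model.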

In [10]:
# weights_only=False keeps the current loading behaviour explicitly (and silences the
# torch FutureWarning): Lightning checkpoints may hold non-tensor objects besides
# "state_dict". NOTE: this unpickles arbitrary objects -- only load trusted checkpoints.
checkpoint = torch.load(
    "../ckpts/mobilenetv3.ckpt", map_location=torch.device("cpu"), weights_only=False
)
# Strip only the *leading* "net." prefix; str.replace would also rewrite any
# "net." substring occurring later in a key.
prefix = "net."
new_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}
mobilenetv3.load_state_dict(new_state_dict)

  weights = torch.load("../ckpts/mobilenetv3.ckpt", map_location=torch.device('cpu'))


<All keys matched successfully>

### Save PyTorch weights

In [11]:
# Write the plain nn.Module state dict (no Lightning wrapper) as a standalone .pth file.
torch.save(obj=mobilenetv3.state_dict(), f="../ckpts/mobilenetv3.pth")