### MobileNetV3

In [1]:
%load_ext autoreload
%autoreload 2

In [25]:
import copy
import torch
from network.mobilenet_v3 import mobilenet_v3_large
from network.Resnet import resnet50
from network.utils import IntermediateLayerGetter


In [3]:
# Build the MobileNetV3-Large network with the given `fs_layer` flags.
# Construction prints "Loading ImageNet weights..." (see the recorded output).
# NOTE(review): `fs_layer` presumably selects which stages get the
# instance-norm variant — confirm in network/mobilenet_v3.py.
model_b = mobilenet_v3_large(fs_layer=[1,1,1,0,0])

Loading ImageNet weights...


In [4]:
# Shape smoke test on a dummy all-zero batch.
# Run it in eval mode with autograd off: in training mode the forward pass
# would update the BatchNorm running statistics of the freshly loaded weights
# using meaningless zero input, and would build an unused autograd graph.
# The previous train/eval mode is restored afterwards so the cell leaves the
# model in the state it found it.
was_training = model_b.training
model_b.eval()
try:
    with torch.no_grad():
        o = model_b(torch.zeros(1, 3, 224, 224))
finally:
    model_b.train(was_training)

In [5]:
# Shape of the forward-pass output; the recorded run shows torch.Size([1, 1000]).
o.shape

torch.Size([1, 1000])

In [22]:
# Display the module tree of the feature extractor (a Sequential, per the recorded output).
model_b.features

Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): AdaptiveInstanceNormalization()
    (2): Hardswish()
  )
  (1): InvertedResidual(
    (instance_norm_layer): AdaptiveInstanceNormalization()
    (block): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Conv2dNormActivation(
        (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      )
    )
  )
  (2): InvertedResidual(
    (instance_norm_layer): AdaptiveInstanceNormalization()
    (block): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias

In [56]:
# Shift every `features.<i>.*` entry of the pretrained checkpoint one block
# later (features.1 -> features.2, ..., features.16 -> features.17). Keys under
# features.0 and all keys outside `features` are left untouched.
# NOTE(review): presumably this frees index 1 for a layer the custom backbone
# inserts — confirm against network/mobilenet_v3.py.
#
# SECURITY: torch.load unpickles the file; only use it on a trusted checkpoint.
# map_location='cpu' keeps the load working on hosts that lack the device the
# checkpoint was saved from.
state_dict = torch.load('weights/mobilenet_v3_large-5c1a4163.pth', map_location='cpu')

# Walk from the highest index down so each destination slot has already been
# vacated before keys are written into it.
for i in reversed(range(1, 17)):
    prefix = f'features.{i}.'
    # Anchor the match at the start of the key: a substring test would also
    # pick up keys that merely contain the pattern somewhere in the middle.
    state_dict_filt = {k: v for k, v in state_dict.items() if k.startswith(prefix)}
    print(state_dict_filt.keys())
    for key in state_dict_filt:
        # Slice off the prefix instead of split(): no unpack error if the
        # pattern happens to repeat inside the key.
        post = key[len(prefix):]
        state_dict[f'features.{i+1}.{post}'] = state_dict.pop(key)
        print(f'{key} \t -> features.{i+1}.{post}')

dict_keys(['features.16.0.weight', 'features.16.1.weight', 'features.16.1.bias', 'features.16.1.running_mean', 'features.16.1.running_var', 'features.16.1.num_batches_tracked'])
features.16.0.weight 	 -> features.17.0.weight
features.16.1.weight 	 -> features.17.1.weight
features.16.1.bias 	 -> features.17.1.bias
features.16.1.running_mean 	 -> features.17.1.running_mean
features.16.1.running_var 	 -> features.17.1.running_var
features.16.1.num_batches_tracked 	 -> features.17.1.num_batches_tracked
dict_keys(['features.15.block.0.0.weight', 'features.15.block.0.1.weight', 'features.15.block.0.1.bias', 'features.15.block.0.1.running_mean', 'features.15.block.0.1.running_var', 'features.15.block.0.1.num_batches_tracked', 'features.15.block.1.0.weight', 'features.15.block.1.1.weight', 'features.15.block.1.1.bias', 'features.15.block.1.1.running_mean', 'features.15.block.1.1.running_var', 'features.15.block.1.1.num_batches_tracked', 'features.15.block.2.fc1.weight', 'features.15.block.2.fc

In [26]:
# Pick the two feature taps of the backbone and wrap it so a forward pass
# returns both of them under the keys "low" and "high".
backbone = model_b.features

# Candidate stage boundaries: index 0 (the stem conv, C0), every block flagged
# with `_is_cn` (the strided blocks, C1..Cn-1), and the final block (Cn).
stage_indices = [
    0,
    *(idx for idx, block in enumerate(backbone) if getattr(block, "_is_cn", False)),
    len(backbone) - 1,
]

# Low-level tap: C2 (output_stride = 8). High-level tap: C5 (output_stride = 16).
low_pos, high_pos = stage_indices[-4], stage_indices[-1]
low_channels, high_channels = (backbone[pos].out_channels for pos in (low_pos, high_pos))

backbone = IntermediateLayerGetter(
    backbone,
    return_layers={str(low_pos): "low", str(high_pos): "high"},
)

In [32]:
# Point `backbone` at the raw feature stack again and read the channel count
# of the block at `high_pos`.
# `high_pos` is not defined in this cell; it must already exist in the kernel
# (the In [26] cell assigns it). The recorded run shows 960.
backbone = model_b.features

backbone[high_pos].out_channels

960

### LR-ASPP

In [33]:
import torch
from network.lraspp import lraspp_mobilenet_v3_large

from train import args
import loss

In [34]:
# Build the loss objects from the training `args`.
# get_loss yields two values, unpacked as criterion / criterion_val
# (presumably the training and validation losses — confirm in loss.py).
criterion, criterion_val = loss.get_loss(args)
# Auxiliary loss; handed to the model constructor as `criterion_aux`.
criterion_aux = loss.get_loss_aux(args)

standard cross entropy
standard cross entropy


In [37]:
# Instantiate the LR-ASPP segmentation model on the MobileNetV3-Large backbone,
# wiring in the loss objects and the two settings read from `args`.
model = lraspp_mobilenet_v3_large(
    args=args,
    criterion=criterion,
    criterion_aux=criterion_aux,
    cont_proj_head=args.cont_proj_head,
    wild_cont_dict_size=args.wild_cont_dict_size,
)

Loading ImageNet weights...


In [53]:
# Inference-mode smoke test: switch the model to eval, then push one all-zero
# 224x224 RGB image through it with autograd disabled. The model returns a
# pair, unpacked here as `o` and `f`.
model.eval()
with torch.no_grad():
    o, f = model(torch.zeros((1, 3, 224, 224)))

In [54]:
# Shapes of the two feature maps stored under the 'low' and 'high' keys.
# Recorded run: low = [1, 40, 28, 28], high = [1, 960, 14, 14].
f['low'].shape, f['high'].shape

(torch.Size([1, 40, 28, 28]), torch.Size([1, 960, 14, 14]))