### MobileNetV3

In [None]:
# IPython autoreload: reload imported modules before executing each cell, so
# edits to the project modules (network.*, loss, train) take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import torch
from network.mobilenet_v3 import mobilenet_v3_large
from network.Resnet import resnet50
from network.utils import IntermediateLayerGetter


In [None]:
# Build the project's MobileNetV3-Large variant. What each `fs_layer` flag
# toggles is defined in network.mobilenet_v3 (not visible here).
model_b = mobilenet_v3_large(fs_layer=[1, 1, 1, 0, 0])

In [None]:
# Probe forward pass on a dummy 1x3x224x224 input, done the same way as the
# LR-ASPP check further down. Eval mode keeps the all-zeros batch from updating
# BatchNorm running statistics (if the model has any), and no_grad avoids
# building an autograd graph for what is only a shape check.
# NOTE(review): if fs_layer behaves differently in train vs. eval mode (not
# visible here) and the train-mode path is what should be exercised, drop the
# `.eval()` call.
with torch.no_grad():
    model_b.eval()
    o = model_b(torch.zeros(1, 3, 224, 224))

In [None]:
# Display the shape of the probe forward pass output.
o.shape

In [None]:
# Display the `features` module stack of model_b, to inspect its block indices.
model_b.features

In [None]:
# Load the pretrained MobileNetV3-Large checkpoint and shift every
# `features.<i>.*` key for i = 1..16 to `features.<i+1>.*`. Keys under
# `features.0.*` are left untouched.
# NOTE(review): presumably model_b has one extra module at features[1]
# compared with the checkpoint layout. Confirm against the `model_b.features`
# printout. The remapped dict is not loaded into any model in this cell.
# map_location='cpu' makes loading independent of the device the tensors were
# saved from.
state_dict = torch.load('weights/mobilenet_v3_large-5c1a4163.pth', map_location='cpu')

# Iterate from the highest index down so that moving block i to i+1 never
# clobbers keys of block i+1 that have not been moved yet.
for i in reversed(range(1, 17)):
    prefix = f'features.{i}.'
    # Prefix match (not a substring test) so only keys that *start* with
    # `features.<i>.` are selected. The remainder is taken by slicing rather
    # than `str.split`, which fails to unpack if the marker occurs twice.
    state_dict_filt = {k: v for k, v in state_dict.items() if k.startswith(prefix)}
    print(state_dict_filt.keys())
    for key in state_dict_filt:
        post = key[len(prefix):]
        state_dict[f'features.{i+1}.{post}'] = state_dict.pop(key)
        print(f'{key} \t -> features.{i+1}.{post}')

In [None]:
backbone = model_b.features
# Stage boundaries, per the original notes:
# - index 0 is C0 (conv1);
# - every block flagged `_is_cn` is a strided block (the C1..Cn-1 locations);
# - the final block is Cn.
stage_indices = [0]
stage_indices.extend(idx for idx, block in enumerate(backbone) if getattr(block, "_is_cn", False))
stage_indices.append(len(backbone) - 1)

# C2 (output_stride = 8) feeds the "low" branch; C5 (output_stride = 16) the "high" one.
low_pos, high_pos = stage_indices[-4], stage_indices[-1]
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels

# Wrap the stack so a forward pass returns the two selected intermediate outputs.
backbone = IntermediateLayerGetter(
    backbone,
    return_layers={str(low_pos): "low", str(high_pos): "high"},
)

In [None]:
# Re-bind `backbone` to the plain feature stack (the previous cell replaced it
# with an IntermediateLayerGetter), then display the `out_channels` of the
# block at `high_pos`.
backbone = model_b.features

backbone[high_pos].out_channels

### LR-ASPP

In [None]:
import torch
from network.lraspp import lraspp_mobilenet_v3_large

from train import args
import loss

In [None]:
# Build the main and validation criteria, then the auxiliary criterion, from
# the training arguments imported from `train`.
criterion, criterion_val = loss.get_loss(args)
criterion_aux = loss.get_loss_aux(args)

In [None]:
# Build the LR-ASPP model via `lraspp_mobilenet_v3_large`, passing the criteria
# built above and the `cont_proj_head` / `wild_cont_dict_size` settings from `args`.
model = lraspp_mobilenet_v3_large(
    args=args,
    criterion=criterion,
    criterion_aux=criterion_aux,
    cont_proj_head=args.cont_proj_head,
    wild_cont_dict_size=args.wild_cont_dict_size,
)

In [None]:
# Probe forward pass in eval mode with autograd disabled. The model returns a
# pair; the second item is indexed by 'low' / 'high' in the next cell.
model.eval()
with torch.no_grad():
    o, f = model(torch.zeros(1, 3, 224, 224))

In [None]:
# Display the shapes of the 'low' and 'high' entries returned by the probe pass.
f['low'].shape, f['high'].shape

## Dataset files

In [None]:
import os
from pathlib import Path
import re
 
subd = Path("/media/data/Datasets/AgriSeg_Dataset/vineyard_real/")


def _list_sample_files(folder):
    """Return one sorted list of files per non-hidden entry of *folder*.

    Entries are visited in sorted name order, because ``Path.iterdir`` yields
    them in arbitrary order. Each inner list holds the non-hidden regular files
    found recursively under that entry, also sorted. Sorting both levels the
    same way for ``images`` and ``masks`` is what lets index ``i`` of the two
    results refer to the same sample.
    """
    return [
        sorted(p for p in entry.glob('**/*') if p.is_file() and not p.name.startswith('.'))
        for entry in sorted(folder.iterdir())
        if not entry.name.startswith('.')
    ]


for subdir in sorted(subd.iterdir()):
    if subdir.is_file() or subdir.name.startswith('.'):
        continue
    print(subdir)
    # NOTE(review): presumably the layout is <subdir>/{images,masks}/Image<i>/<file>,
    # produced by an earlier one-off restructuring script. Confirm.
    # Both names are rebuilt on every iteration, so after the loop they hold the
    # lists of the last subdirectory only; the cells below index into them.
    image_file_names = _list_sample_files(subdir.joinpath('images'))
    mask_file_names = _list_sample_files(subdir.joinpath('masks'))

In [None]:
# Show the image and mask file lists at sample index 10 (last scanned subdirectory).
image_file_names[10], mask_file_names[10]

In [None]:
from PIL import Image

# Open the first image file of sample index 1 for notebook display.
Image.open(image_file_names[1][0])

In [None]:
# Open the first mask file of sample index 1 for notebook display.
Image.open(mask_file_names[1][0])