# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch

In [None]:
%run ../utils/__init__.py
# config_logging (and logging) are presumably provided by ../utils/__init__.py, loaded by %run above
config_logging(logging.INFO)

# Function to rename layers

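Simplify legacy layer names (e.g. `classifier.1.weight` → `classifier.weight`) so the published state dict matches the simplified model definition below.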

In [None]:
def rename_layer(name):
    """Return *name* with the legacy ``classifier.1`` component renamed to ``classifier``.

    Only whole dot-separated components are matched, so ``classifier.1.weight``
    becomes ``classifier.weight`` (and ``model.classifier.1.bias`` becomes
    ``model.classifier.bias``), while unrelated names such as
    ``classifier.10.weight`` or ``myclassifier.1.bias`` are left untouched.
    A plain ``str.replace`` would mangle those into ``classifier0.weight`` /
    ``myclassifier.bias``.
    """
    import re  # local import keeps this notebook cell self-contained

    # (?<![^.]) : match starts at the beginning of the name or right after a '.'
    # (?![^.])  : match ends at the end of the name or right before a '.'
    return re.sub(r'(?<![^.])classifier\.1(?![^.])', 'classifier', name)

In [None]:
def test_rename_layers(layers):
    """Preview how ``rename_layer`` transforms a collection of layer names.

    Prints one line per name that ``rename_layer`` changes, formatted as
    ``<old name left-aligned to 42 columns> <new name>``.

    Args:
        layers: iterable of layer / state-dict key names (e.g. a list, or the
            keys of a state dict). One-shot iterables are supported.

    Returns:
        list of ``(old_name, new_name)`` tuples, in input order, for every name
        whose renamed form is not already present in ``layers``.
    """
    # Materialize once so generators work, and build a set for O(1) membership
    # (the original per-item `in` on a list was O(n) each, O(n^2) overall).
    layers = list(layers)
    existing = set(layers)

    new_layers = []
    for layer in layers:
        renamed = rename_layer(layer)
        if renamed not in existing:
            new_layers.append((layer, renamed))
        if renamed != layer:
            print(f'{layer:<42} {renamed}')
    return new_layers

# Publish model

In [None]:
from collections import OrderedDict

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
def publish_model(run_name, task, override=False):
    """Export a run's weights, with simplified layer names, to the public checkpoint folder.

    The checkpoint is written to
    ``WORKSPACE_DIR/public_checkpoints/{task}-{full_name}.pt`` as a plain
    ``OrderedDict`` state dict whose keys have been passed through
    ``rename_layer``.

    ``RunId``, ``load_compiled_model`` and ``WORKSPACE_DIR`` are assumed to be
    provided by the helper modules loaded with ``%run`` earlier in the notebook.

    Args:
        run_name: identifier of the training run (e.g. ``'0402_062551'``).
        task: task name passed to ``RunId``; ``run_id.task`` is the filename prefix.
        override: if False and the target file already exists, skip writing.

    Returns:
        ``(published_fpath, metadata)``: the checkpoint path and the compiled
        model's metadata. The metadata is returned even when publishing is skipped.

    Raises:
        ValueError: if ``rename_layer`` maps two state-dict keys to the same
            name, which would otherwise silently drop one of the tensors.
    """
    import os

    run_id = RunId(run_name, debug=False, task=task)

    name = f'{run_id.task}-{run_id.full_name}.pt'
    published_dir = os.path.join(WORKSPACE_DIR, 'public_checkpoints')
    published_fpath = os.path.join(published_dir, name)

    # Loaded even when skipping, because the metadata is part of the return value.
    compiled_model = load_compiled_model(run_id)

    if not override and os.path.isfile(published_fpath):
        print('Already published')
        return published_fpath, compiled_model.metadata

    # Rename old layers; refuse to continue if two keys collide after renaming.
    new_state_dict = OrderedDict()
    for key, value in compiled_model.model.state_dict().items():
        new_key = rename_layer(key)
        if new_key in new_state_dict:
            raise ValueError(
                f'rename_layer maps more than one layer to {new_key!r} (last: {key!r})'
            )
        new_state_dict[new_key] = value

    # torch.save does not create missing parent directories.
    os.makedirs(published_dir, exist_ok=True)
    torch.save(new_state_dict, published_fpath)

    print('Published to', published_fpath)
    return published_fpath, compiled_model.metadata

In [None]:
# Publish run 0402_062551 for task 'cls-seg'; keeps the checkpoint path and the run metadata.
published_fpath, metadata1 = publish_model('0402_062551', 'cls-seg')

In [None]:
# Publish run 0422_163242 for task 'cls-seg'. NOTE: this reassigns published_fpath from the previous cell.
published_fpath, metadata2 = publish_model('0422_163242', 'cls-seg')

# Test loading

## Simplified definition

In [None]:
import torch
import torch.nn as nn
from torchvision import models

# When True, ImageNetClsSegModel.forward asserts that the segmentation output
# has the same height/width as the input image.
_ASSERT_IN_OUT_IMAGE_SIZE = False

# Number of classification outputs (out_features of the final nn.Linear).
N_CL_DISEASES = 14
# Number of segmentation output channels produced by the segmentator.
N_SEG_LABELS = 4

def get_adaptive_pooling_layer(drop=0):
    """Build a global-average-pooling head: pool to 1x1, flatten, optional dropout.

    Args:
        drop: dropout probability. A trailing ``nn.Dropout(drop)`` is appended
            only when ``drop > 0``.

    Returns:
        ``nn.Sequential`` mapping ``(batch, C, H, W)`` to ``(batch, C)``.
    """
    pool = nn.AdaptiveAvgPool2d((1, 1))
    flatten = nn.Flatten()
    if drop > 0:
        return nn.Sequential(pool, flatten, nn.Dropout(drop))
    return nn.Sequential(pool, flatten)


class ImageNetClsSegModel(nn.Module):
    """DenseNet-121 backbone shared by a classification head and a segmentation head.

    ``forward(x)`` returns ``(classification, segmentation)``:

    * ``classification``: shape ``(batch, N_CL_DISEASES)``. The backbone
      features are globally pooled and passed through a single linear layer
      (raw outputs, no activation is applied here).
    * ``segmentation``: shape ``(batch, N_SEG_LABELS, H', W')``. The backbone
      features are upsampled by two transposed convolutions.

    The submodule attribute names (``features``, ``segmentator``,
    ``cl_reduction``, ``classifier``) define the state-dict keys, so they must
    match the checkpoint being loaded.

    Args:
        freeze: if True, set ``requires_grad = False`` on every backbone parameter.
        dropout_features: dropout probability applied to the pooled features
            before the classifier. No dropout layer is added when it is <= 0.
    """

    def __init__(self, freeze=False, dropout_features=0):
        super().__init__()

        # Channel count of the backbone's feature output; input width of both heads.
        n_feature_channels = 1024

        # Randomly initialised backbone. Trained weights are expected to be loaded
        # afterwards via load_state_dict(), so no pretrained download is needed.
        backbone = models.densenet121(drop_rate=0.3, pretrained=False)
        # Only the convolutional trunk is kept; the backbone's own classifier is unused.
        self.features = backbone.features

        if freeze:
            for param in self.features.parameters():
                param.requires_grad = False

        # Segmentation head. With ConvTranspose2d's output size
        # (in - 1) * stride - 2 * padding + kernel, the first layer maps f -> 2f
        # and the second maps 2f -> 32f, i.e. 32x upsampling overall. Per the
        # original note, the output matches the input size for 256, 512 and 1024
        # inputs; other sizes (such as 200) may not match.
        upsample_2x = nn.ConvTranspose2d(n_feature_channels, 4, 4, 2, padding=1)
        upsample_16x = nn.ConvTranspose2d(4, N_SEG_LABELS, 32, 16, padding=8)
        self.segmentator = nn.Sequential(upsample_2x, upsample_16x)

        # Classification head: global average pool (+ optional dropout), then linear.
        self.cl_reduction = get_adaptive_pooling_layer(drop=dropout_features)
        self.classifier = nn.Linear(n_feature_channels, N_CL_DISEASES)

    def forward(self, x):
        input_hw = x.size()[-2:]

        # (batch, n_features, feat_height, feat_width)
        feats = self.features(x)

        pooled = self.cl_reduction(feats)
        classification = self.classifier(pooled)   # (batch, N_CL_DISEASES)
        segmentation = self.segmentator(feats)      # (batch, N_SEG_LABELS, H', W')

        if _ASSERT_IN_OUT_IMAGE_SIZE:
            output_hw = segmentation.size()[-2:]
            assert input_hw == output_hw, f'Image sizes do not match: in={input_hw} vs out={output_hw}'

        return classification, segmentation

## Load

In [None]:
# Build the simplified architecture; weights are loaded from the checkpoint further below.
# dropout_features=0 adds no Dropout layer, which has no parameters anyway, so the
# expected state-dict keys do not depend on this value.
model = ImageNetClsSegModel(dropout_features=0)

In [None]:
# Hard-coded path to a published checkpoint. The filename appears to follow the
# '{task}-{run full_name}.pt' pattern for run 0422_163242; TODO confirm it matches published_fpath.
fpath = '/mnt/workspace/medical-ai/public_checkpoints/cls-seg-0422_163242_cxr14_densenet-121-cls-seg_drop0.3_dropf0.5_normS_lr3e-05_wd0.01_sch-roc-auc-p2-f0.5-c2_aug1-double__wd.pt'

In [None]:
# Load the tensors onto the CPU and copy them into the model (strict key matching by default).
# NOTE(review): torch.load unpickles the file. For checkpoints from untrusted sources consider
# weights_only=True (PyTorch >= 1.13); confirm the installed version supports it.
model.load_state_dict(torch.load(fpath, map_location='cpu'))
# Switch to inference mode; the result is assigned to `_`, presumably to suppress notebook output.
_ = model.eval()