In [1]:
from torch import nn
from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import load_model
from models.get_model import get_model

from monai.networks.blocks.segresnet_block import ResBlock, get_conv_layer, get_upsample_layer
from monai.networks.layers.factories import Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode
from typing import Union, Tuple, List, Dict, Optional
import json

In [2]:
class MobileWrapper(nn.Module):
    """Wrap a segmentation network so that it returns a binary mask in a dict.

    The forward pass averages the input over its channel axis (dim -3), feeds
    the single-channel result to the wrapped model, applies a softmax over
    dim 1 of the model output, keeps channel 1 and thresholds it.

    Args:
        model: Module invoked on the single-channel tensor. Its output needs
            at least two entries along dim 1, because channel 1 is read.
        threshold: Value the channel-1 softmax probability must exceed for an
            element to be True in the mask. Defaults to 0.9, the value that
            used to be hard-coded.
    """

    def __init__(self, model, threshold: float = 0.9):
        super().__init__()
        self.model = model
        # Kept as a plain Python float so the module stays scriptable.
        self.threshold = float(threshold)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``{"out": mask}`` where ``mask`` is a boolean tensor.

        Args:
            x: Input tensor whose third-from-last axis is the channel axis.

        Returns:
            Dictionary with the single key ``"out"``; the value is the
            thresholded channel-1 probability with a size-1 axis at dim 1.
        """
        # Unweighted mean over the channel axis: a simple RGB -> grayscale
        # conversion when the input carries three colour channels.
        x = torch.mean(x, dim=-3, keepdim=True)

        x = self.model(x)

        # Probability of class 1 (not an argmax): softmax over the class axis,
        # select channel 1, then restore the channel axis that indexing removed.
        x = torch.softmax(x, dim=1)
        x = x[:, 1, ...]
        x = x.unsqueeze(1)

        # Strict comparison, so an element equal to the threshold is False.
        x = x > self.threshold

        # Explicitly typed dict: TorchScript needs the annotation to infer
        # the container type of an empty literal.
        res: Dict[str, torch.Tensor] = {}
        res["out"] = x
        return res

In [3]:
# Identifier of the training run; it selects the sub-folder under "runs/".
run_name = "2023-11-18_16-23-20"

run_path = f"runs/{run_name}/"

# Read the training summary inside a context manager so the file handle is
# closed as soon as parsing finishes instead of lingering until garbage
# collection.
with open(run_path + "train_summary.json") as summary_file:
    train_summary = json.load(summary_file)

# Model name and image size are taken from the stored training config.
model_name = train_summary["config"]["MODEL"]
IMAGE_SIZE = train_summary["config"]["IMAGE_SIZE"]

# First CUDA device when available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

# Instantiate the architecture named in the training summary, load the best
# checkpoint of the selected run into it, and wrap it for mobile export.
checkpoint_path = run_path + "best_model.pth"
model = load_model(get_model(model_name, IMAGE_SIZE), checkpoint_path)

model_mobile = MobileWrapper(model)

In [5]:
# Put the wrapper in evaluation mode before running it.
model_mobile.eval()
example = torch.rand(1, 3, 256, 256)

# Smoke test on a random input. Gradients are not needed for this check, so
# the forward pass runs under no_grad to avoid building an autograd graph.
with torch.no_grad():
    out = model_mobile(example)
print(out["out"].shape)

torch.Size([1, 1, 256, 256])


In [6]:
# Script the wrapper, run the mobile optimisation pass over the scripted
# module, and write the result in lite-interpreter format.
model_mobile.eval()

# Random input kept from the original cell; the scripting path below does not
# consume it.
example_shape = (1, 3, 256, 256)
example = torch.rand(example_shape)

traced_script_module = torch.jit.script(model_mobile)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)

lite_model_path = "model.ptl"
traced_script_module_optimized._save_for_lite_interpreter(lite_model_path)


In [7]:
from torch.jit.mobile import (
    _backport_for_mobile,
    _get_model_bytecode_version,
)

# Report the bytecode version of the exported file, write a copy backported
# to version 7, and report the version of that copy.
source_ptl = "model.ptl"
backported_ptl = "model_7.ptl"

print(_get_model_bytecode_version(source_ptl))

_backport_for_mobile(source_ptl, backported_ptl, 7)

print(_get_model_bytecode_version(backported_ptl))

8
7


In [7]:
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# Load DeepLabV3-ResNet50 through torch.hub and switch it to evaluation mode.
model = torch.hub.load(
    'pytorch/vision:v0.11.0',
    'deeplabv3_resnet50',
    pretrained=True,
)
model.eval()

# Script the network once, then derive a mobile-optimised variant from it.
scripted_module = torch.jit.script(model)
optimized_scripted_module = optimize_for_mobile(scripted_module)

# Three exports:
#   1. full JIT model (not compatible with the lite interpreter)
#   2. lite-interpreter model
#   3. optimised lite-interpreter model
# Per the original author's note, (3) runs inference about 60% faster than
# (2), which in turn is about 6% faster than the non-optimised full JIT model.
scripted_module.save("deeplabv3_scripted.pt")
scripted_module._save_for_lite_interpreter("deeplabv3_scripted.ptl")
optimized_scripted_module._save_for_lite_interpreter("deeplabv3_scripted_optimized.ptl")

Using cache found in /home/rikhat.akizhanov/.cache/torch/hub/pytorch_vision_v0.11.0


In [5]:
# Shape check on a random input. Only the output shape is inspected, so the
# forward pass runs under no_grad to skip autograd bookkeeping.
img = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    out = model(img)

print(out["out"].shape)



torch.Size([1, 21, 256, 256])
