# Making a PyTorch model compatible with PyTorch Mobile

In [1]:
# Imports and library version check for the export notebook.
# NOTE(review): matplotlib, torchinfo.summary, time, os, pandas and
# InterpolationMode are not used in the cells shown here — confirm they are
# needed elsewhere before removing them.
import torch
import torchvision

print(torch.__version__) # requires torch >= 1.12
print(torchvision.__version__) # requires torchvision >= 0.13 (ResNet50_Weights API)

import matplotlib.pyplot as plt

from torch import nn
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from torchinfo import summary

import time
import os
from PIL import Image
import pandas as pd

2.1.1
0.16.1


### Custom wrapper class for the model
* uses `torch.nn.Sequential` for the transforms this time, because `transforms.Compose` cannot be compiled with TorchScript
* everything needs to run on the CPU now

In [11]:
class mobile_model(torch.nn.Module):
    """ResNet-50 six-class classifier with built-in preprocessing, for PyTorch Mobile.

    The wrapper bundles the image preprocessing (resize to 256, center crop
    to 224, normalization with the standard ImageNet mean/std) together with
    the network. A TorchScript / lite-interpreter export of this module
    therefore takes a float image tensor and returns logits directly.
    Everything lives on the CPU and the module is put in eval mode.

    Args:
        model_path: Path to a ``state_dict`` saved from a ResNet-50 whose
            ``fc`` was replaced by ``Dropout(p=0.3) -> Linear(in, 6)``.
            If ``None``, no checkpoint is loaded. The torchvision
            ImageNet-pretrained backbone is used and the 6-way head keeps its
            random initialization, which is only useful for exercising the
            export pipeline.
    """

    def __init__(self, model_path=None):
        super().__init__()

        # load_state_dict is strict, so a checkpoint overwrites every
        # parameter and buffer. Only download the ImageNet weights when no
        # checkpoint will replace them.
        backbone_weights = (
            torchvision.models.ResNet50_Weights.DEFAULT if model_path is None else None
        )
        self.model = torchvision.models.resnet50(weights=backbone_weights)

        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features=num_in_features, out_features=6),
        )

        if model_path is not None:
            state_dict = torch.load(model_path, map_location=torch.device("cpu"))
            self.model.load_state_dict(state_dict)

        # Inference-only wrapper: freeze *all* parameters, including the new
        # head. Otherwise eager calls build an autograd graph.
        for param in self.model.parameters():
            param.requires_grad = False

        self.model.to("cpu")

        # torch.nn.Sequential (not transforms.Compose) so TorchScript can
        # compile the preprocessing together with the network.
        self.transforms = torch.nn.Sequential(
            transforms.Resize([256], antialias=True),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        )

        # Put the wrapper itself (not only the inner model) in eval mode.
        # This disables Dropout and makes `self.training` report False.
        self.eval()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Preprocess an image tensor and return class logits.

        Args:
            tensor: Float image of shape ``(C, H, W)`` for a single image, as
                produced by ``transforms.ToTensor()``, or ``(N, C, H, W)``
                for a batch. Expected to be 3-channel RGB to match the
                normalization statistics.

        Returns:
            Logits of shape ``(N, 6)``, where ``N == 1`` for a single image.
        """
        img = self.transforms(tensor)
        # Add the batch dimension only for a single image; batched input
        # already has one.
        if img.dim() == 3:
            img = img.unsqueeze(0)
        return self.model(img.to("cpu"))

### Optimizing for mobile

In [12]:
from pathlib import Path

# Location of the fine-tuned ResNet-50 weights, loaded as a state_dict by mobile_model.
MODEL_PATH = Path("models") / "resnet50_model.pth"
MODEL_PATH

WindowsPath('models/resnet50_model.pth')

In [13]:
# Build the CPU inference wrapper: ResNet-50 backbone plus a 6-way head, with
# weights loaded from MODEL_PATH.
resnet50_model = mobile_model(MODEL_PATH)

In [14]:
from torch.utils.mobile_optimizer import optimize_for_mobile

# Compile the whole wrapper (preprocessing + network) to TorchScript.
scripted_module = torch.jit.script(resnet50_model)
# Apply PyTorch's mobile-specific graph optimizations to the scripted module.
optimized_scripted_module = optimize_for_mobile(scripted_module)
# Save in the lite-interpreter (.ptl) format used by PyTorch Mobile.
# The relative path means the file is written to the current working directory.
optimized_scripted_module._save_for_lite_interpreter("resnet50_model_lite.ptl")

In [15]:
# Reload the exported file so its output can be compared with the eager model.
# NOTE(review): this uses the full TorchScript loader. To exercise the exact
# mobile runtime path, consider torch.jit.mobile._load_for_lite_interpreter.
resnet50_lite = torch.jit.load("resnet50_model_lite.ptl")

In [16]:
# Sample image used to compare the eager and exported models.
img_path = Path("data/aluminum-cans-1.jpg")

In [17]:
# Load the sample image and turn it into a float (C, H, W) tensor via ToTensor.
# Force 3-channel RGB: grayscale or RGBA files would otherwise yield 1 or 4
# channels, which do not match the model's 3-channel normalization.
# The `with` block closes the file handle once the pixels have been decoded.
with Image.open(img_path) as raw_img:
    pil_img = raw_img.convert("RGB")
tensor = transforms.ToTensor()(pil_img)

### Check that the eager and lite models give the same results

In [18]:
# Eager-mode logits for the sample image (the reference for the lite model).
resnet50_model(tensor)

tensor([[-5.0872, -2.5763, -0.9848, -5.5661, -3.4356, -4.3329]],
       grad_fn=<AddmmBackward0>)

In [19]:
lite_logits = resnet50_lite(tensor)
# Compare numerically instead of by eye at 4 printed decimals.
# optimize_for_mobile rewrites the graph, which can introduce tiny float
# differences, so use a small explicit tolerance.
torch.testing.assert_close(
    lite_logits, resnet50_model(tensor).detach(), rtol=1e-4, atol=1e-4
)
lite_logits

tensor([[-5.0872, -2.5763, -0.9848, -5.5661, -3.4356, -4.3329]])