## Loading MIRAv2 weights into PyTorch and compiling to TorchScript

Things to note:
- The model was trained on a GPU, so we need to map the weights to the CPU when loading them and then compile the model for CPU.
- Check which versions of torch (and torchvision, if it's used here) you're running in this notebook environment, and make sure they match the versions pinned in the deployment container's Dockerfile.
- You need to know which variant of EfficientNet was used in training (in our case "efficientnet-b3") and the input size (in our case 300x300). See [model options and parameters](https://github.com/microsoft/CameraTraps/blob/ccb5e98095cf81a625bf19129cb3dc97354f6284/classification/efficientnet/utils.py#L452) if you don't know what size images were used in training.
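
One way to do the version check from the second point is to compare the installed packages against the pins programmatically. This is a minimal sketch, assuming you copy the pins by hand from the Dockerfile: `find_version_mismatches` is a hypothetical helper, and the version strings below are placeholders, not this project's actual pins.

```python
from importlib.metadata import PackageNotFoundError, version


def find_version_mismatches(pinned):
    """Return {package: (installed, pinned)} for each package whose installed
    version differs from the pinned one. `installed` is None when the package
    is not installed in this environment."""
    mismatches = {}
    for package, pinned_version in pinned.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        # Compare only the public part, so "1.13.1+cu117" still matches "1.13.1"
        public = installed.split("+")[0] if installed else None
        if public != pinned_version:
            mismatches[package] = (installed, pinned_version)
    return mismatches


# Placeholder pins (assumption): replace with the versions from your Dockerfile
pinned_versions = {"torch": "1.13.1", "torchvision": "0.14.1"}
print(find_version_mismatches(pinned_versions))
```

An empty dict means the notebook environment matches the pins; anything else lists the installed version next to the pinned one.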

In [None]:
import torch
from efficientnet_pytorch import EfficientNet

torch.__version__

In [None]:
ckpt = torch.load("model-weights/ckpt_18.pt", map_location=torch.device("cpu"))
model = EfficientNet.from_pretrained("efficientnet-b3", num_classes=5)
model.load_state_dict(ckpt['model'])

# NOTE: I had originally tried the following (based on examples online) but it didn't work: 

# state_dict = torch.load("model-weights/ckpt_18.pt", map_location=torch.device("cpu"))
# model = EfficientNet.from_pretrained("efficientnet-b3", num_classes=5)
# model.load_state_dict(state_dict)

# This failed because Microsoft's training code nests the state_dict in a dict under a key
# called "model" when saving checkpoints: https://github.com/microsoft/CameraTraps/blob/ccb5e98095cf81a625bf19129cb3dc97354f6284/classification/train_classifier.py#L419
# and then loads it from that key: https://github.com/microsoft/CameraTraps/blob/ccb5e98095cf81a625bf19129cb3dc97354f6284/classification/train_classifier.py#L254

# This is worth mentioning because if you're loading weights from a
# different PyTorch model that was saved differently,
# you may need to update this code slightly
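
If you aren't sure how a checkpoint was saved, a small helper can handle both layouts. This is a sketch only: `extract_state_dict` is a hypothetical name, and the stand-in checkpoints below use plain lists instead of tensors (and an illustrative `epoch` entry) just to show the two shapes.

```python
def extract_state_dict(checkpoint, key="model"):
    """Return the state_dict from a checkpoint that may or may not nest it under `key`."""
    # Nested layout: {"model": state_dict, ...other training state...}
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get(key), dict):
        return checkpoint[key]
    # Flat layout: the checkpoint already is the state_dict
    return checkpoint


# Stand-in checkpoints (assumption: real ones map parameter names to tensors)
nested = {"epoch": 18, "model": {"_fc.weight": [0.1], "_fc.bias": [0.0]}}
flat = {"_fc.weight": [0.1], "_fc.bias": [0.0]}

print(sorted(extract_state_dict(nested)))
print(sorted(extract_state_dict(flat)))
```

With a real checkpoint this would be used as `model.load_state_dict(extract_state_dict(ckpt))`. It assumes no parameter in a flat state_dict is itself named "model".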

In [None]:
model.eval()

In [None]:
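# from https://github.com/microsoft/CameraTraps/blob/ccb5e98095cf81a625bf19129cb3dc97354f6284/classification/evaluate_model.py#L71
img_size = 300  # input resolution used in training (efficientnet-b3)
compiled_path = './model-weights/mira_compiled_cpu.pt'

# The memory-efficient Swish relies on a custom autograd function that doesn't
# export cleanly, so switch to the standard implementation before tracing
model.set_swish(memory_efficient=False)

# Trace with a dummy input of shape (batch, channels, height, width)
ex_img = torch.rand(1, 3, img_size, img_size)
scripted_model = torch.jit.trace(model, (ex_img,))

scripted_model.save(compiled_path)
print('Saved TorchScript compiled model to', compiled_path)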
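
Before shipping the compiled file, it can be worth reloading it and checking that it reproduces the original model's outputs. The sketch below shows the idea on a small stand-in network rather than the real EfficientNet, so it runs without the checkpoint; the layer choices, the temporary file name, and the tolerance are all assumptions made for illustration.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for the EfficientNet model: any nn.Module in eval mode is traced the same way
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),  # 5 output classes, as in the notebook
).eval()

ex_img = torch.rand(1, 3, 300, 300)
traced = torch.jit.trace(model, (ex_img,))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "traced_cpu.pt")
    traced.save(path)

    # torch.jit.load needs only the saved file, not the Python model class
    reloaded = torch.jit.load(path, map_location="cpu")
    reloaded.eval()

    # Compare on a fresh input, not the one used for tracing
    new_img = torch.rand(1, 3, 300, 300)
    with torch.no_grad():
        expected = model(new_img)
        actual = reloaded(new_img)

outputs_match = torch.allclose(expected, actual, atol=1e-5)
print("Output shape:", tuple(actual.shape), "| outputs match:", outputs_match)
```

For the real model, the same comparison would use `model`, `scripted_model`, and `compiled_path` from the cell above.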