## Loading deepfaune-ne weights into PyTorch and compiling to Torchscript

Things to note: 
- the model was trained on a GPU so we need to load weights and re-compile to CPU
- it's important to check what version of torchvision (if used here) and torch you're running in this notebook environment & be sure they match the versions pinned in the deployment container's Dockerfile
- it's important to know which architecture/model backbone was used in training (in deepfaune-ne's case it was [dinov2](https://github.com/facebookresearch/dinov2)) and the size of the inputs (182x182).

In [1]:
import torch
# from torchvision.models import efficientnet
import timm

torch.__version__

  from .autonotebook import tqdm as notebook_tqdm


'2.6.0'

In [10]:
# Modified from https://code.usgs.gov/vtcfwru/deepfaune-new-england/-/blob/main/scripts/dfne_model.py
# import os
# import pandas as pd
# # import onnxruntime
# from pathlib import Path
# from time import time
# import glob
# import torch.nn as nn

CLASSES = {
    0: "American Marten",
    1: "Bird sp.",
    2: "Black Bear",
    3: "Bobcat",
    4: "Coyote",
    5: "Domestic Cat",
    6: "Domestic Cow",
    7: "Domestic Dog",
    8: "Fisher",
    9: "Gray Fox",
    10: "Gray Squirrel",
    11: "Human",
    12: "Moose",
    13: "Mouse sp.",
    14: "Opossum",
    15: "Raccoon",
    16: "Red Fox",
    17: "Red Squirrel",
    18: "Skunk",
    19: "Snowshoe Hare",
    20: "White-tailed Deer",
    21: "Wild Boar",
    22: "Wild Turkey",
    23: "no-species"
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _load_model(weights = None):
    """
    Load the DeepFaune NE model weights
    
    Args:
        - weights (str, optional): 
            Path to the model weights. (defaults to DFNE weights)

    Returns:
        - model: model object with loaded weights
    """


    if weights == None:
        weights = "dfne_weights_v1_0.pth"

    model_name = "vit_large_patch14_dinov2.lvd142m"

    classes = {
        0: "American Marten",
        1: "Bird sp.",
        2: "Black Bear",
        3: "Bobcat",
        4: "Coyote",
        5: "Domestic Cat",
        6: "Domestic Cow",
        7: "Domestic Dog",
        8: "Fisher",
        9: "Gray Fox",
        10: "Gray Squirrel",
        11: "Human",
        12: "Moose",
        13: "Mouse sp.",
        14: "Opossum",
        15: "Raccoon",
        16: "Red Fox",
        17: "Red Squirrel",
        18: "Skunk",
        19: "Snowshoe Hare",
        20: "White-tailed Deer",
        21: "Wild Boar",
        22: "Wild Turkey",
        23: "no-species"
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    predictor = timm.create_model(
        model_name, 
        pretrained = False, 
        num_classes = len(classes),
        dynamic_img_size = True
    )
        
    checkpoint = torch.load(
        f = weights,
        map_location = device,
        weights_only = True
    )
    
    predictor.load_state_dict(checkpoint['model_state_dict'])

    return predictor

In [11]:
classifier = _load_model(weights = 'model-weights/dfne_weights_v1_0.pth')

In [12]:
# Save out the whole model for future inference deployment
# https://pytorch.org/tutorials/beginner/saving_loading_models.html

compiled_path = './model-weights/deepfaune-ne_compiled_cpu.pt'

model_scripted = torch.jit.script(classifier) # Export to TorchScript
model_scripted.save(compiled_path) # Save