In [1]:
import tomllib
import torch
import numpy as np
from src.utils.utils import dataclass_from_dict
from src.model.hyperparameters import Hyperparameters
from __future__ import annotations
import pathlib
import sys
from src.model.autoencoder import Autoencoder
import onnxruntime

sys.path.append(pathlib.Path.cwd().parent)

def load_model_hyperparameters() -> Hyperparameters:
    with (pathlib.Path("/home/paolo/git/spotify-playlist-generator/config/model.toml")).open("rb") as f:
        return dataclass_from_dict(Hyperparameters, tomllib.load(f))
cfg = load_model_hyperparameters()


In [2]:
INPUT_SHAPE: tuple[int, ...] = (1, 2, 256, 1292)

## Save the models

In [3]:
model = Autoencoder.load_from_checkpoint("/home/paolo/git/spotify-playlist-generator/logs/mlruns/341767160074738545/1553b9dee11e43dd83503dd311804a81/artifacts/model/checkpoints/model_checkpoint/model_checkpoint.ckpt", hyperparam=cfg).eval().to("cpu").float()

In [4]:
base_path: pathlib.Path = pathlib.Path.cwd().parents[0]
torchscript_model_path: pathlib.Path = base_path/ "data/models/torchscript_model.pt"
onnx_model_path: pathlib.Path = base_path / "data/models/onnx_model.onnx"
onnx_encoder_model_path: pathlib.Path = base_path / "data/models/onnx_encoder_model.onnx"

In [5]:
if not (base_path / torchscript_model_path).exists():
    torchscript_model = model.to_torchscript()
    torch.jit.save(torchscript_model, base_path / torchscript_model_path)

if not (base_path / onnx_model_path).exists():
    input_sample = torch.randn((1, 2, 256, 1292))
    model.to_onnx(base_path / onnx_model_path, input_sample, export_params=True)

if not (base_path / onnx_encoder_model_path).exists():
    input_sample = torch.randn((1, 2, 256, 1292))
    torch.onnx.export(model.encoder.float(), input_sample, onnx_encoder_model_path)

## Pytorch

In [6]:
x = torch.randn(INPUT_SHAPE, dtype=torch.float32, device="cpu")

In [7]:
%%timeit -r 10 -n 1000

output = model.encoder(x)

19.4 ms ± 687 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


## Torchscript

In [8]:
scripted_module = torch.jit.load(torchscript_model_path)
x = torch.randn(INPUT_SHAPE, dtype=torch.float32, device="cpu")

In [9]:
%%timeit -r 10 -n 1000


output = scripted_module.encoder(x)

19 ms ± 393 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


## ONNX

In [10]:
ort_session = onnxruntime.InferenceSession(onnx_model_path)
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(*INPUT_SHAPE).astype(np.float32)}

In [11]:
%%timeit -r 7 -n 100

ort_outs = ort_session.run(None, ort_inputs)

72.6 ms ± 939 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## ONNX Encoder

In [12]:
ort_session = onnxruntime.InferenceSession(onnx_encoder_model_path)
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(*INPUT_SHAPE).astype(np.float32)}

In [13]:
%%timeit -r 10 -n 1000

ort_outs = ort_session.run(None, ort_inputs)

33.8 ms ± 2.91 ms per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
