In [14]:
import pathlib
import sys

import lightning
import torch
from dotenv import dotenv_values
import onnxruntime
import numpy as np
from wild_boar_detection.utils import Hyperparameters, dataclass_from_dict

sys.path.append(str(pathlib.Path.cwd()))
torch.set_float32_matmul_precision("medium")
cfg: Hyperparameters | dict[str, int | float | str | bool] = dataclass_from_dict(Hyperparameters, dotenv_values())
lightning.seed_everything(cfg.SEED, workers=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EXPERIMENT_PATH = pathlib.Path("../data/logs/mlruns/669180362677009476").joinpath("0f5b21808f684dd6b5595b01f9b197e3").resolve() 
ARTIFACTS_PATH = EXPERIMENT_PATH / "artifacts/model"

MODEL_PATH = ARTIFACTS_PATH / "checkpoints/model_checkpoint/model_checkpoint.ckpt"
ONNX_MODEL_PATH = ARTIFACTS_PATH / "model.onnx"
ONNX_CHECKPOINT_MODEL_PATH = ARTIFACTS_PATH / "checkpoint_model.onnx"

Seed set to 666


In [15]:
def predict_model(model_path: str | pathlib.Path, input_sample: np.ndarray) -> np.ndarray:
    ort_session = onnxruntime.InferenceSession(model_path)
    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: input_sample}
    return ort_session.run(None, ort_inputs)

In [16]:
input_sample = np.random.rand(1, cfg.BASE_CHANNEL_SIZE, cfg.INPUT_SIZE, cfg.INPUT_SIZE).astype(np.float32)
input_sample.dtype

dtype('float32')

In [17]:
checkpoint_output = predict_model(model_path=ONNX_CHECKPOINT_MODEL_PATH, input_sample=input_sample)

In [18]:
model_output = predict_model(model_path=ONNX_MODEL_PATH, input_sample=input_sample)

In [19]:
checkpoint_output

[array([[-1.2903147]], dtype=float32)]

In [20]:
model_output

[array([[-3.565104]], dtype=float32)]