In [None]:
import random
import time
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as rt
import pandas as pd
from tqdm.notebook import tqdm

In [None]:
# Candidate images from a local dataset folder. Path.rglob yields files in
# filesystem order, so sort to make the processing order (and therefore the
# `{ix}.png` output names written by the inference loop) reproducible.
# NOTE(review): `images_list` is reassigned from test.csv in the following cell.
images_path = Path("/home/simone/workspace/fogna/datasets/ompi/SJZ4_HangClip/bad/")
images_list = sorted(images_path.rglob("*.tiff"))

In [None]:
# Evaluation split: the image paths are taken from the `filename` column of the
# model's test.csv.
test_csv_path = "/path/to/model/determined_lamarr_d66238bc/test.csv"
test_df = pd.read_csv(test_csv_path)
images_list = test_df["filename"].values

In [None]:
# Model under test, loaded with the CPU execution provider only.
onnx_model_path = (
    "/path/to/model/determined_lamarr_d66238bc/debug/onnx/determined_lamarr_d66238bc.onnx"
)
session = rt.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
)

In [None]:
# Per-image figures are written here, relative to the notebook's working directory.
output_path = Path(".") / "determined_lamarr_d66238bc"
output_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Run the ONNX model over every image, time each `session.run` call, and save a
# side-by-side figure (input | input + heatmap overlay) per image.
all_max_positions = []  # (row, col) of the heatmap maximum, one entry per image
elapsed_times = []  # seconds spent inside session.run, one entry per image
subs = False  # True -> run on a random subset instead of the whole list
subsample_size = 500
subsample_seed = 0  # fixed so a subsampled run is reproducible
input_shape = (224, 224)  # cv2.resize dsize, i.e. (width, height)

images_to_process = images_list
if subs:
    # random.sample requires a Sequence (a numpy array is not one) and raises
    # if k exceeds the population, so convert and clamp. Sampling into a new
    # name keeps this cell idempotent: re-running it does not subsample the
    # previous subsample.
    population = list(images_list)
    images_to_process = random.Random(subsample_seed).sample(
        population, k=min(subsample_size, len(population))
    )

# Wrap the list itself (not the enumerate object) so the bar knows the total.
for ix, img_fn in enumerate(tqdm(images_to_process)):
    img = cv2.imread(str(img_fn), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None instead of raising on missing/unreadable files.
        raise FileNotFoundError(f"Could not read image: {img_fn}")
    img = cv2.resize(img, input_shape)

    # Model input keeps cv2's native BGR channel order, as in the original pipeline.
    inp = np.array(img, dtype=np.float32)
    inputs = {"inputs": inp[np.newaxis, ...]}  # prepend a batch dimension of 1
    s = time.perf_counter()  # monotonic, high-resolution clock for short intervals
    outputs = session.run(None, inputs)
    e = time.perf_counter()
    elapsed_times.append(e - s)

    # Output 0 is argmax'ed as class scores: index 0 -> "Bad", otherwise "Good".
    pred = "Bad" if np.argmax(outputs[0][0]) == 0 else "Good"

    # Output 1 is treated as a heatmap; take the single batch item.
    heatmap = outputs[1][0, :]
    min_h = np.min(heatmap)
    max_h = np.max(heatmap)
    # First occurrence (C order) of the maximum, kept as a (row, col) pair.
    peak_index = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    max_pos = tuple(int(i) for i in peak_index[:2])
    all_max_positions.append(max_pos)

    # cv2 loads BGR, but matplotlib interprets 3-channel arrays as RGB.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Stretch the heatmap over the whole image so the overlay lines up even when
    # its resolution differs from the image (identical to the default extent
    # when the sizes already match).
    extent = (-0.5, img.shape[1] - 0.5, img.shape[0] - 0.5, -0.5)

    fig, ax = plt.subplots(1, 2, figsize=(16, 9))
    ax[0].imshow(img_rgb)
    ax[1].imshow(img_rgb)
    ax[1].imshow(heatmap, cmap="jet", alpha=0.5, extent=extent)
    fig.suptitle(
        f"Prediction: {pred}; score: {outputs[0][0]}, min_heatmap: {min_h}, max_heatmap: {max_h}"
    )
    fig.savefig(str(output_path / f"{ix}.png"))
    plt.close(fig)  # release each figure; many open figures exhaust memory

In [None]:
# Latency summary over every session.run call timed in the inference loop.
inference_times = np.asarray(elapsed_times)
print(f"Mean inference time: {inference_times.mean()}")
print(f"Std inference time: {inference_times.std()}")

In [None]:
# Spatial distribution of heatmap peaks: one dot per image at its (col, row)
# peak, with the y-axis flipped so the plot matches image coordinates.
# reshape(-1, 2) keeps column indexing valid even when the list is empty.
max_positions = np.asarray(all_max_positions, dtype=float).reshape(-1, 2)
fig_peaks, ax_peaks = plt.subplots()
ax_peaks.plot(max_positions[:, 1], max_positions[:, 0], "bo", alpha=0.3)
ax_peaks.invert_yaxis()
ax_peaks.set(title="Heatmap peak locations", xlabel="column", ylabel="row");

In [None]:
# Heatmap peak coordinates against image index. Row and column are drawn as
# separately styled, labelled series; plotting the (row, col) tuples directly
# with a single "bo" format draws both in identical blue circles.
peak_coords = np.asarray(all_max_positions, dtype=float).reshape(-1, 2)
fig_idx, ax_idx = plt.subplots()
ax_idx.plot(peak_coords[:, 0], "bo", label="row")
ax_idx.plot(peak_coords[:, 1], "rs", label="column")
ax_idx.set(
    title="Heatmap peak coordinates per image",
    xlabel="image index",
    ylabel="heatmap index",
)
ax_idx.legend();

In [None]:
# Second model, also restricted to the CPU execution provider; it is timed
# against the same input in the next cell.
onnx_model_path2 = "/path/to/model/fervent_brattain_74b0b6c8/onnx/fervent_brattain_74b0b6c8.onnx"
session2 = rt.InferenceSession(
    onnx_model_path2, providers=["CPUExecutionProvider"]
)

In [None]:
# Latency check for the second model. A single time.time() sample is noisy and
# includes first-call overhead, so do one warm-up call and then average several
# perf_counter timings.
# NOTE(review): reuses `inputs` left over from the last iteration of the
# inference loop (hidden state) and assumes session2's input is also named
# "inputs" -- confirm both.
n_timing_runs = 10
session2.run(None, inputs)  # warm-up, excluded from the measurement
run_times = []
for _ in range(n_timing_runs):
    s = time.perf_counter()
    outputs = session2.run(None, inputs)
    e = time.perf_counter()
    run_times.append(e - s)
print(np.mean(run_times))