In [None]:
from torch.utils.data import DataLoader
from src.data.data_loader import load_final, JerseyNumberDataset
from src.data.data_handling import balancer

In [None]:
# Location of the final evaluation images, relative to the notebook's working directory.
path = "data/final/images"

# Load the samples and pass them through the project balancer in one step.
# NOTE(review): `max_0=0` presumably limits how many class-0 samples are kept —
# confirm against `src.data.data_handling.balancer`.
test_data = balancer(load_final(path), max_0=0)
print("total test : ", len(test_data))

In [None]:
# --- Evaluation configuration -------------------------------------------------
cut = "topbottom"          # cropping mode forwarded to JerseyNumberDataset
image_size = (224, 224)    # image size forwarded to JerseyNumberDataset
batch_size = 64            # samples per DataLoader batch
workers = 1                # DataLoader worker processes

# --- Dataset and loader -------------------------------------------------------
test_dataset = JerseyNumberDataset(test_data, image_size=image_size, cut=cut)

# shuffle=True means every pass over the loader yields the samples in a
# different order, so downstream visualisations show a different selection
# on each run.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
)

In [None]:
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
import numpy as np
import matplotlib.pyplot as plt

def _load_trt_engine(engine_path):
    """Read a serialized TensorRT engine from ``engine_path`` and deserialize it.

    Raises:
        RuntimeError: if deserialization does not produce an engine object.
    """
    trt_logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f, trt.Runtime(trt_logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        # Fail here with a clear message instead of an AttributeError later on.
        raise RuntimeError(f"Could not deserialize TensorRT engine: {engine_path}")
    return engine


def _plot_prediction_grid(images, preds, labels, save_path, cols=10):
    """Draw the images in a ``cols``-wide grid, annotate each with its true and
    predicted label, save the figure to ``save_path`` and display it."""
    num_display = len(images)
    # Size the grid from what was actually collected; keep at least one row so
    # the figure height is never zero.
    rows = max(1, (num_display + cols - 1) // cols)
    fig = plt.figure(figsize=(cols * 3.2, rows * 3.2))

    for idx, (image, pred, label) in enumerate(zip(images, preds, labels), start=1):
        # permute(1, 2, 0) moves the first axis last (channels-first -> channels-last
        # for imshow); values are clipped to [0, 1] for display.
        pixels = image.permute(1, 2, 0).numpy().clip(0, 1)

        ax = fig.add_subplot(rows, cols, idx)
        ax.imshow(pixels)
        ax.axis("off")

        # Green caption for a correct prediction, red for a wrong one.
        color = "green" if pred == label else "red"
        title = f"True: {label}\nPred: {pred}"
        ax.text(
            5, 20,
            title,
            fontsize=10,
            color=color,
            bbox=dict(facecolor="black", alpha=0.6, pad=3)
        )

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()
    # Release the figure so repeated calls in one kernel do not accumulate figures.
    plt.close(fig)


def grid_trt(engine_path, test_loader, num_images=100, input_shape=(3,224,224), device='cuda',
             num_classes=100, save_path="grida.png"):
    """Run a TensorRT engine on samples from ``test_loader`` and plot a prediction grid.

    Images are fed to the engine one at a time (batch size 1) as float32. The
    predicted class is the argmax over the engine output. The first
    ``num_images`` samples drawn from the loader are shown in a 10-column grid,
    each captioned with its true and predicted label.

    Args:
        engine_path: path to a serialized TensorRT engine file.
        test_loader: iterable of batches; each batch is indexed with ``"x"``
            (image tensors) and ``"y"`` (label tensors).
        num_images: maximum number of samples to run and display.
        input_shape: per-image shape expected by the engine, without the batch axis.
        device: unused; kept so existing callers keep working.
        num_classes: number of values in the engine output per image.
        save_path: file the figure is written to.

    Raises:
        RuntimeError: if the engine cannot be deserialized or inference fails.
        ValueError: if an image's shape does not match ``input_shape``.

    NOTE(review): assumes the engine has exactly one float32 input followed by
    one float32 output (binding order ``[input, output]``) — confirm for the
    engine in use.
    """
    engine = _load_trt_engine(engine_path)
    context = engine.create_execution_context()

    batch_size = 1
    expected_shape = (batch_size, *tuple(input_shape))
    itemsize = np.dtype(np.float32).itemsize
    input_nbytes = int(np.prod(expected_shape)) * itemsize
    output_nbytes = batch_size * int(num_classes) * itemsize
    output = np.empty((batch_size, num_classes), dtype=np.float32)

    images_to_show = []
    preds_to_show = []
    labels_to_show = []

    d_input = cuda.mem_alloc(input_nbytes)
    d_output = None
    try:
        d_output = cuda.mem_alloc(output_nbytes)
        # The device addresses never change, so build the binding list once.
        bindings = [int(d_input), int(d_output)]

        for batch in test_loader:
            images = batch["x"]
            labels = batch["y"]

            # Only take as many samples from this batch as are still needed.
            remaining = num_images - len(images_to_show)
            for i in range(min(images.size(0), remaining)):
                img = images[i].numpy().astype(np.float32)
                img = np.ascontiguousarray(img[np.newaxis, ...])

                # The device buffer has a fixed size; copying a differently
                # shaped image would read or write past it.
                if img.shape != expected_shape:
                    raise ValueError(
                        f"Image shape {img.shape} does not match expected {expected_shape}"
                    )

                cuda.memcpy_htod(d_input, img)
                if not context.execute_v2(bindings=bindings):
                    raise RuntimeError("TensorRT inference failed (execute_v2 returned False)")
                cuda.memcpy_dtoh(output, d_output)

                images_to_show.append(images[i])
                preds_to_show.append(int(np.argmax(output[0])))
                labels_to_show.append(labels[i].item())

            if len(images_to_show) >= num_images:
                break
    finally:
        # Release GPU memory deterministically, even if inference raised.
        d_input.free()
        if d_output is not None:
            d_output.free()

    _plot_prediction_grid(images_to_show, preds_to_show, labels_to_show, save_path, cols=10)


In [None]:
# Visualise predictions of the FP16 ViT TensorRT engine on the test loader.
grid_trt(
    engine_path="results/trt_models/vit_model_fp16.trt",
    test_loader=test_loader,
)