In [None]:
import torchvision
import torch

# Build ResNet-50 with the IMAGENET1K_V2 pretrained weights and put it in
# inference mode before tracing, so the recorded graph uses eval-mode layer
# behaviour (e.g. BatchNorm running statistics). `.eval()` returns the module.
model = torchvision.models.resnet50(weights='IMAGENET1K_V2').eval()

# Trace with a dummy batch of one 3x224x224 image; the values don't matter,
# only the shapes/ops recorded along the way.
traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))

# Serialize the TorchScript module. NOTE(review): Triton's PyTorch backend
# loads it from <model_repository>/resnet50/<version>/model.pt — confirm the
# file is copied there for the `model_name="resnet50"` requests below.
traced_model.save("model.pt")


In [61]:
from torchvision import transforms
from PIL import Image

def rn50_preprocess(img_path="img1.jpg"):
    """Load an image from disk and preprocess it for ResNet-50 inference.

    Pipeline: resize the shorter side to 256, center-crop to 224x224, convert
    to a float tensor, then normalize each channel with the ImageNet
    mean/std used by the torchvision ResNet-50 weights.

    Args:
        img_path: Path of the image file to load. Defaults to "img1.jpg".

    Returns:
        numpy.ndarray: float32 array of shape (3, 224, 224), channels first.
    """
    # The context manager releases the file handle Image.open keeps. We convert
    # to RGB so grayscale/RGBA/palette images also yield exactly 3 channels.
    # Otherwise the 3-channel Normalize below (and the model) would fail.
    # convert() loads the pixel data, so rgb_img stays valid after the file closes.
    with Image.open(img_path) as img:
        rgb_img = img.convert("RGB")

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return preprocess(rgb_img).numpy()

# Preprocess the sample image (same file as the function's default) and show
# the resulting array shape; expected (3, 224, 224) given the 224 center crop.
transformed_img = rn50_preprocess(img_path="img1.jpg")
transformed_img.shape

(3, 224, 224)

In [62]:
import tritonclient.http as httpclient
import numpy as np

import os

# Triton HTTP endpoint. Defaults to the local server on port 8000 and can be
# pointed elsewhere via the TRITON_URL environment variable.
TRITON_URL = os.environ.get("TRITON_URL", "localhost:8000")
client = httpclient.InferenceServerClient(url=TRITON_URL)


In [63]:
# Replicate the single preprocessed image into a batch to exercise batched
# inference. NOTE(review): the deployed model's max_batch_size (config.pbtxt)
# must be >= BATCH_SIZE — confirm.
BATCH_SIZE = 7
batch_imgs = np.stack([transformed_img] * BATCH_SIZE)
batch_imgs.shape

(7, 3, 224, 224)

In [64]:
# Describe the request input: the whole batch as one FP32 tensor, sent in
# Triton's binary format. NOTE(review): "input__0"/"output__0" must match the
# tensor names in the model's config.pbtxt.
inputs = httpclient.InferInput(name="input__0", shape=batch_imgs.shape, datatype="FP32")
inputs.set_data_from_numpy(batch_imgs, binary_data=True)

# Describe the requested output. Because class_count is set, Triton's
# classification extension returns per-item classification entries
# (presumably "score:index" BYTES strings) rather than raw logits — confirm
# before treating inference_output as numeric.
outputs = httpclient.InferRequestedOutput(name="output__0", binary_data=True, class_count=1000)

# Run the request against the "resnet50" model and pull the output as numpy.
results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs])
inference_output = results.as_numpy("output__0")
inference_output.shape

(7, 1000)

In [65]:
# Close the Triton HTTP client now that inference is done; the client should
# not be used for further requests after this.
client.close()