In [1]:
# Reference: Ray Serve PyTorch tutorial - https://docs.ray.io/en/latest/serve/tutorials/pytorch.html

In [2]:
import os 
import time
import requests
from io import BytesIO
from PIL import Image

import ray
from ray import serve

import torch
import torchvision
from torchvision.transforms import ToTensor, Resize,Compose, ToPILImage
from ray_cluster_control import start_ray_cluster, stop_ray_cluster

In [None]:
# Bring up the Ray cluster named "kubecon-2022".
# start_ray_cluster is idempotent, so re-running this cell is safe.
start_ray_cluster("kubecon-2022")

In [None]:
@serve.deployment(route_prefix="/image_predict",
                  name="pet_image",
                  ray_actor_options={"num_gpus": 1})
class ImageModel:
    """Ray Serve deployment that classifies one image per request with a TorchScript model.

    Exposed under the ``/image_predict`` route with the deployment name
    ``pet_image``; the actor options request one GPU. The request body must be
    the raw bytes of an image file that PIL can open.
    """

    def __init__(self):
        """Load the scripted model onto the GPU and build the preprocessing pipeline."""
        # Single source of truth for the device, used for both loading and inference.
        self.device = torch.device('cuda:0')
        # Relative path: presumably resolved against the `working_dir` runtime_env
        # given to ray.init elsewhere in this notebook - TODO confirm.
        self.model = torch.jit.load("best_model_scripted.pt",
                                    map_location=self.device)
        # Inference only: make sure dropout / batch-norm layers (if any) run in
        # eval mode. This is a no-op when the model was exported in eval mode.
        self.model.eval()
        self.preprocessor = Compose([Resize((64, 64)),
                                     ToTensor()])

    async def __call__(self, starlette_request):
        """Predict the class of the image contained in the request body.

        Args:
            starlette_request: incoming request; its awaited ``body()`` must
                yield the encoded image bytes.

        Returns:
            int: index of the highest-scoring entry in the model output for
            the image.
        """
        image_payload_bytes = await starlette_request.body()
        # ToTensor emits one channel per image band, so normalise to 3-channel
        # RGB first; grayscale, RGBA or palette images would otherwise not fit
        # the (1, 3, 64, 64) input built below.
        pil_image = Image.open(BytesIO(image_payload_bytes)).convert("RGB")

        # Add the batch dimension: (3, 64, 64) -> (1, 3, 64, 64).
        input_tensor = self.preprocessor(pil_image).unsqueeze(0)
        with torch.no_grad():
            output_tensor = self.model(input_tensor.to(self.device))

        return int(torch.argmax(output_tensor[0]))

In [None]:
# Connect to the `ray://` endpoint of the head service, whose name is derived
# from the RAY_CLUSTER_NAME environment variable (KeyError if it is unset).
# The runtime_env sets models/ as the working_dir for this connection.
ray.init(
    f"ray://{os.environ['RAY_CLUSTER_NAME']}-ray-head:10001",
    _metrics_export_port=8080,
    runtime_env=dict(working_dir="models/"),
)

In [None]:
# Find the address of the Ray head node: the node whose hostname contains "head".
nodes = ray.nodes()
head_addresses = [node["NodeManagerAddress"]
                  for node in nodes
                  if 'head' in node["NodeManagerHostname"]]
if not head_addresses:
    # Fail loudly instead of hitting a NameError on `host` (fresh kernel) or
    # silently reusing a stale `host` left over from an earlier run.
    raise RuntimeError(
        "No Ray node with 'head' in its hostname was found; "
        f"hostnames seen: {[node['NodeManagerHostname'] for node in nodes]}"
    )
# Keep the last match, which is what the original loop ended up with.
host = head_addresses[-1]
print(host)

In [None]:
# Start Ray Serve with its HTTP options pointing at the head node address in `host`.
serve.start(http_options=dict(host=host))

In [None]:
# Deploy the ImageModel deployment (name "pet_image", route "/image_predict").
ImageModel.deploy()

In [None]:
# Read one sample image from the dataset as raw bytes.
# `b` is the request payload that the prediction cells send to the endpoint.
with open("data/oxford-iiit-pet/images/Abyssinian_22.jpg", "rb") as image:
    f = image.read()
b = bytearray(f)

# Last expression of the cell: show the image inline as a visual sanity check.
Image.open(BytesIO(b))

In [None]:
# Query the deployment directly on the head node (port 8000 is presumably
# Serve's default HTTP port, since serve.start only sets the host - confirm).
# The timeout keeps the cell from hanging forever if the endpoint is unreachable.
resp = requests.post(f"http://{host}:8000/image_predict", data=b, timeout=30)
# Surface HTTP errors (4xx/5xx) directly instead of a confusing JSON decode failure.
resp.raise_for_status()
print(f"Predicted Class: {resp.json()}")

In [None]:
# Query the same route through the externally exposed endpoint, taken from the
# SERVING_ENDPOINT environment variable (KeyError if it is unset).
external_route = f"http://{os.environ['SERVING_ENDPOINT']}/image_predict"
# The timeout keeps the cell from hanging forever if the endpoint is unreachable.
resp = requests.post(external_route, data=b, timeout=30)
# Surface HTTP errors (4xx/5xx) directly instead of a confusing JSON decode failure.
resp.raise_for_status()
print(f"Predicted Class: {resp.json()}")

In [None]:
# Uncomment to shut down the Ray cluster once you are done with the notebook.
# stop_ray_cluster()