In [None]:
import subprocess
from typing import Any

import json
import numpy as np
import requests
import torch
from ray import serve
from starlette.requests import Request

In [None]:
class MNISTClassifier:
    """Offline MNIST classifier backed by a TorchScript model.

    On construction the serialized model is copied from S3 with the AWS CLI,
    loaded with ``torch.jit.load``, moved to ``device`` and switched to eval
    mode. Instances are callable and delegate to :meth:`predict`.
    """

    def __init__(self, remote_path: str, local_path: str, device: str):
        """Download the model file and load it onto ``device``.

        Args:
            remote_path: S3 URI of the serialized model.
            local_path: Filesystem path the model file is copied to.
            device: Torch device string used for the model and its inputs.

        Raises:
            subprocess.CalledProcessError: If ``aws s3 cp`` exits non-zero.
            FileNotFoundError: If the ``aws`` executable is not on ``PATH``.
        """
        # Pass an argument list rather than a shell string: paths containing
        # spaces or shell metacharacters are handed to the CLI verbatim and
        # are never interpreted by a shell.
        subprocess.run(
            ["aws", "s3", "cp", remote_path, local_path, "--no-sign-request"],
            check=True,
        )

        self.device = device
        self.model = torch.jit.load(local_path).to(device).eval()

    def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Make the instance callable; equivalent to ``predict(batch)``."""
        return self.predict(batch)

    def predict(self, batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Run the model on ``batch["image"]`` and attach the predictions.

        Args:
            batch: Mapping with an ``"image"`` entry that ``torch.tensor`` can
                convert. The notebook passes arrays shaped ``(N, 1, 28, 28)``.

        Returns:
            The same ``batch`` object (mutated in place) with an added
            ``"predicted_label"`` entry: the argmax of the model output along
            axis 1, as a NumPy array.
        """
        images = torch.tensor(batch["image"]).float().to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(images).cpu().numpy()

        batch["predicted_label"] = np.argmax(logits, axis=1)
        return batch

In [None]:
# Where the downloaded model file is stored locally.
# NOTE(review): hard-coded absolute path; presumably machine-specific — confirm
# it exists on the machine running this notebook.
storage_folder = "/Desktop/Anyscale-ray/ray-anyscale/Notebooks"
model_path = f"{storage_folder}/model.pt"

# Constructing the classifier downloads the model from S3 and loads it on CPU.
classifier = MNISTClassifier(
    remote_path="s3://anyscale-public-materials/ray-ai-libraries/mnist/model/model.pt",
    local_path=model_path,
    device="cpu",
)

In [None]:
# Smoke test: classify one random float32 array shaped (1, 1, 28, 28).
output = classifier(
    {
        "image": np.random.rand(1, 1, 28, 28).astype(np.float32),
    }
)
output["predicted_label"]

In [None]:
@serve.deployment()  # turns the class into a Ray Serve deployment
class OnlineMNISTClassifier:
    """Ray Serve deployment that classifies MNIST images with a TorchScript model.

    On construction the serialized model is copied from S3 with the AWS CLI,
    loaded with ``torch.jit.load``, moved to ``device`` and switched to eval
    mode. HTTP requests go through ``__call__``; ``predict`` can also be
    invoked directly (the notebook does so through a deployment handle).
    """

    def __init__(self, remote_path: str, local_path: str, device: str):
        """Download the model file and load it onto ``device``.

        Args:
            remote_path: S3 URI of the serialized model.
            local_path: Filesystem path the model file is copied to.
            device: Torch device string used for the model and its inputs.

        Raises:
            subprocess.CalledProcessError: If ``aws s3 cp`` exits non-zero.
            FileNotFoundError: If the ``aws`` executable is not on ``PATH``.
        """
        # Pass an argument list rather than a shell string: paths containing
        # spaces or shell metacharacters are handed to the CLI verbatim and
        # are never interpreted by a shell.
        subprocess.run(
            ["aws", "s3", "cp", remote_path, local_path, "--no-sign-request"],
            check=True,
        )

        self.device = device
        self.model = torch.jit.load(local_path).to(device).eval()

    async def __call__(self, request: Request) -> dict[str, Any]:
        """HTTP entry point: parse the JSON body and return predictions.

        Accepts either a JSON object body (``{"image": [...]}``) or a JSON
        string that itself contains the serialized object, which is what the
        client cell in this notebook sends via
        ``requests.post(..., json=json.dumps(...))``.
        """
        payload = await request.json()
        # A double-encoded body decodes to ``str`` and needs a second pass;
        # a normally encoded body is already the batch mapping.
        batch = json.loads(payload) if isinstance(payload, str) else payload
        return await self.predict(batch)

    async def predict(self, batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Run the model on ``batch["image"]`` and attach the predictions.

        Args:
            batch: Mapping with an ``"image"`` entry that ``torch.tensor`` can
                convert (nested lists from JSON or a NumPy array).

        Returns:
            The same ``batch`` object (mutated in place) with an added
            ``"predicted_label"`` entry: the argmax of the model output along
            axis 1, as a NumPy array.
        """
        images = torch.tensor(batch["image"]).float().to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(images).cpu().numpy()

        batch["predicted_label"] = np.argmax(logits, axis=1)
        return batch

In [None]:
# Local destination for the model file (same location as the offline classifier).
model_path = f"{storage_folder}/model.pt"

# Bind the constructor arguments to the deployment; the resulting object is
# what gets handed to `serve.run`.
mnist_app = OnlineMNISTClassifier.bind(
    remote_path="s3://anyscale-public-materials/ray-ai-libraries/mnist/model/model.pt",
    local_path=model_path,
    device="cpu",
)
mnist_app

In [None]:
# Start the application under the name 'mnist_classifier' without blocking the
# notebook, keeping the returned handle for direct calls later on.
mnist_app_handle = serve.run(
    mnist_app,
    name='mnist_classifier',
    blocking=False,
)
mnist_app_handle

In [None]:
# Query the deployment over HTTP with two random images.
images = np.random.rand(2, 1, 28, 28).tolist()

# The payload is serialized here and then passed through `json=` again, so the
# request body is a JSON string wrapping the serialized object.
json_request = json.dumps({"image": images})
response = requests.post(
    "http://localhost:8000/",
    json=json_request,
)
response.json()["predicted_label"]


batch = {"image": np.random.rand(10, 1, 28, 28)}
response = await mnist_app_handle.predict.remote(batch)
response["predicted_label"]

In [None]:
# From the intro/ directory, launch the `mnist_app` object of module `main`
# with the Serve CLI, under the application name `app1`.
# NOTE(review): `--non-blocking` presumably returns control to the notebook
# once the app is deployed — confirm against the Serve CLI docs.
!cd intro/ && serve run main:mnist_app --non-blocking --name app1

In [None]:
# From the intro/ directory, run `serve build` on `main:mnist_app` and write
# its output to config.yaml (`-o`).
!cd intro/ && serve build -o config.yaml main:mnist_app 

In [None]:
# From the intro/ directory, launch Serve from config.yaml instead of an
# import path.
!cd intro/ && serve run config.yaml --non-blocking

In [None]:
# From the intro/ directory, launch `app_builder:build_app` as application
# `app1`, with `device=cpu` given as a trailing key=value argument.
# NOTE(review): presumably the key=value pair is forwarded to `build_app` —
# confirm against app_builder.py and the Serve CLI docs.
!cd intro/ && serve run app_builder:build_app --non-blocking --name app1 device=cpu

In [None]:
# Clean up: delete the model file under `storage_folder`
# (IPython substitutes `{storage_folder}` before running the shell command).
!rm {storage_folder}/model.pt