<a href="https://colab.research.google.com/github/zlqhem/mlapi/blob/main/aws-ecr-torchscript/aws_ecr_torchscript.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TorchScript: serving a model from S3 through an AWS Lambda handler

## load a model from S3

In [None]:
%%capture
!pip install boto3 python-dotenv

In [None]:
# AWS credential setup: load variables from a dotenv file into the process environment.
import dotenv
# NOTE(review): the file presumably defines the credential variables boto3 reads
# (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY) — confirm the names used in the file.
# The path lives on Google Drive — assumes Drive is already mounted at /content/drive; TODO confirm.
env_file = "/content/drive/MyDrive/w2/mlapi/aws.env"
dotenv.load_dotenv(env_file)

True

In [None]:
try:
    import unzip_requirements
except ImportError:
    pass

import json
from io import BytesIO
import time
import os
import base64

import boto3
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F

# Shared S3 client used by the helper functions and cells below.
# NOTE(review): no credentials are passed here, so boto3 must find them on its own —
# presumably from the environment variables loaded from the dotenv file; confirm.
s3 = boto3.client("s3")
# Bucket name and object key of the exported TorchScript model.
bucket = "soltware.test"
key = "v1/best.torchscript"

def download_model(s3, bucket, key):
    """Download the S3 object ``key`` from ``bucket`` into the working directory.

    The local file is named after the last path component of ``key``; that name
    is printed before the download starts. Nothing is returned.
    """
    local_name = os.path.split(key)[1]
    print('file_name', local_name)
    s3.download_file(bucket, key, local_name)


def load_model(s3, bucket, object_key=None):
    """Fetch a TorchScript archive from S3 and load it on the CPU in eval mode.

    Args:
        s3: client exposing ``get_object(Bucket=..., Key=...)`` (e.g. a boto3 S3 client).
        bucket: name of the bucket holding the model.
        object_key: key of the model object inside ``bucket``. When omitted, the
            module-level ``key`` is used, so existing ``load_model(s3, bucket)``
            calls keep working.

    Returns:
        The module returned by ``torch.jit.load``, switched to eval mode.
    """
    if object_key is None:
        # Backward-compatible default: the notebook-level model key.
        object_key = key
    response = s3.get_object(Bucket=bucket, Key=object_key)

    # The whole object is read into memory and wrapped in BytesIO because
    # torch.jit.load expects a file-like object it can seek in.
    buffer = BytesIO(response["Body"].read())
    return torch.jit.load(buffer, map_location=torch.device('cpu')).eval()

In [None]:
# Save a local copy of the model file in the working directory.
download_model(s3, bucket, key)

file_name best.torchscript


In [None]:
!ls -al best.torchscript

-rw-r--r-- 1 root root 103561564 Feb  4 16:14 best.torchscript


In [None]:
!date

Sun Feb  4 04:14:51 PM UTC 2024


In [None]:
# Load the model for inference. load_model reads the object from S3 again via
# get_object; it does not use the local file written by download_model.
model = load_model(s3, bucket)

In [None]:
model

RecursiveScriptModule(original_name=DetectionModel)

## lambda handler

In [None]:
# Display label for every class index: position i holds the name of class i.
# NOTE(review): the order presumably mirrors the class indices used when the
# model was trained — confirm against the training configuration.
_CLASS_LABELS = (
    'Tomato Healthy',
    'Strawberry Healthy',
    'Lettuce Healthy',
    'Strawberry Ashy Mold',
    'Strawberry White Powdery Mildew',
    'Lettuce Bacterial Head Rot',
    'Lettuce Bacterial Wilt',
    'Tomato Leaf Mold',
    'Tomato Yellow Leaf Curl Virus',
)
# Kept as a numpy array so it can be indexed with an array of class indices.
classes = np.asarray(_CLASS_LABELS)

def _bad_request(message):
    """Build a 400 response whose JSON body carries ``message`` under "error"."""
    return {
        'statusCode': 400,
        'body': json.dumps({'error': message})
    }


def lambda_handler(event, context):
    """Lambda entry point: classify the base64 image sent in the request body.

    ``event["body"]`` must be JSON text describing an object with an ``"image"``
    field that holds the base64 image (a data-URL prefix is tolerated). The image
    is turned into a tensor by ``input_fn_stream`` and handed, together with the
    module-level ``model``, to ``predict``.

    Args:
        event: invocation event; only its ``"body"`` entry is read.
        context: Lambda context object; unused.

    Returns:
        A dict with ``statusCode`` and a JSON-encoded ``body``: 200 with the
        prediction, or 400 with an ``error`` message when the request is
        malformed. Scheduled warm-up events are not special-cased, so an event
        without a body is answered with 400.
    """
    try:
        data = json.loads(event["body"])
    except (KeyError, TypeError, ValueError):
        # Missing body, non-text body, or text that is not valid JSON.
        return _bad_request('request body must be a JSON document')
    if not isinstance(data, dict) or not isinstance(data.get("image"), str):
        return _bad_request('request body must contain a base64 "image" string')
    print("data keys:", data.keys())

    try:
        img_tensor = input_fn_stream(data["image"])
    except (ValueError, OSError):
        # Bad base64 surfaces as a ValueError subclass; unreadable image data is
        # reported by Pillow as an OSError subclass.
        return _bad_request('"image" could not be decoded')

    response = predict(img_tensor, model)
    return {
        'statusCode': 200,
        'body': json.dumps(response)
    }

def input_fn_stream(image, size=(640, 640)):
    """Decode a base64 (optionally data-URL) image string into a model input tensor.

    Args:
        image: base64 text. Everything up to and including the first "," is
            treated as a data-URL header and dropped.
        size: ``(width, height)`` the image is resized to; defaults to 640x640.

    Returns:
        A float32 tensor of shape ``(1, 3, height, width)`` with channel values
        scaled from 0-255 down to 0-1.
    """
    # str.find returns -1 when there is no comma, so the slice keeps the whole string.
    payload = image[image.find(",") + 1:]
    # Surplus "=" characters let b64decode accept input whose padding was stripped.
    raw = base64.b64decode(payload + "===")

    # Convert to RGB before resizing: this drops any alpha channel (not needed for
    # YOLOv8, see https://dev.to/andreygermanov/how-to-create-yolov8-based-object-detection-web-service-using-python-julia-nodejs-javascript-go-and-rust-4o8e#prepare_the_input)
    # and avoids resizing palette images, which Pillow only resamples with
    # nearest-neighbour.
    rgb = Image.open(BytesIO(raw)).convert("RGB").resize(size)

    pixels = np.array(rgb)  # (height, width, 3)
    # Channels first, add the batch axis, then scale to 0-1.
    batch = pixels.transpose(2, 0, 1)[np.newaxis] / 255.0
    return torch.tensor(batch, dtype=torch.float32)

def predict(img_tensor, model):
    """Run ``model`` on ``img_tensor`` and report the top class and its confidence.

    Two output layouts are supported:

    * 3-D ``(batch, 4 + num_classes, num_candidates)`` — the detection layout the
      exported model in this notebook produces (13 x 8400 per image for the 9
      classes). The first 4 rows of each candidate are box coordinates and are
      skipped; the answer is the class with the highest score over all candidates.
      NOTE(review): the class scores are used as-is as confidences, which assumes
      the export already maps them into 0-1 — confirm for the deployed model.
    * 2-D ``(batch, num_classes)`` — plain classification logits; softmax is
      applied and the most probable class is reported.

    Args:
        img_tensor: input batch passed to the model unchanged.
        model: callable returning one of the tensor layouts above.

    Returns:
        Dict with ``'class'`` and ``'confidence'``, each the ``str()`` of a numpy
        array holding one entry per image in the batch.
    """
    with torch.no_grad():  # inference only; no autograd graph needed
        predict_values = model(img_tensor)
    print(predict_values[0].shape)

    if predict_values.dim() == 3:
        # Skip the box rows so coordinates are never mistaken for class evidence,
        # then keep each class's best score across all candidates.
        class_scores = predict_values[:, 4:, :]
        preds = class_scores.max(dim=2).values
    else:
        preds = F.softmax(predict_values, dim=1)

    conf_score, indx = torch.max(preds, dim=1)
    conf_score = conf_score.cpu().numpy()
    indx = indx.cpu().numpy()
    predict_class = classes[indx]
    response = {}
    response['class'] = str(predict_class)
    response['confidence'] = str(conf_score)
    return response


## Use the deployed API


In [None]:
# Sample image used to exercise the handler.
# NOTE(review): the path is on Google Drive — assumes Drive is mounted at /content/drive; confirm.
path = "/content/drive/MyDrive/w2/flutter/strawberry-healthy.png"

# Read the raw file bytes and base64-encode them into text so they can travel inside JSON.
with open(path, "rb") as image_file:
  encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

# Length of the base64 text in characters.
print (len(encoded_string))

# Placeholder: the endpoint of the deployed API has not been filled in yet.
url = "TBD"

# Request payload in the shape lambda_handler reads: an object with an "image" field.
data =  {
    "image": encoded_string
}

844988


In [None]:
# Call the handler in-process with an event shaped like the HTTP request:
# "body" carries the payload as JSON text.
body = json.dumps(data)
res = lambda_handler({"body": body}, {})
print(res)
# The handler's own "body" is JSON text as well; decode it to get the prediction dict.
# Note that `body` is rebound here from the request text to the response text.
body = res["body"]
json_data = json.loads(body)
print(json_data)

data keys: dict_keys(['image'])
torch.Size([13, 8400])
predict_values[0] tensor([[1.5857e+01, 3.6524e+01, 4.8894e+01,  ..., 5.0013e+02, 5.2872e+02,
         5.8240e+02],
        [9.9368e+00, 1.2355e+01, 1.6714e+01,  ..., 5.7213e+02, 5.7526e+02,
         5.7146e+02],
        [3.2870e+01, 7.3306e+01, 9.7783e+01,  ..., 2.7847e+02, 2.2293e+02,
         1.1958e+02],
        ...,
        [1.5544e-09, 3.8068e-10, 2.1305e-09,  ..., 3.0667e-10, 4.3925e-09,
         1.2591e-08],
        [1.7935e-09, 3.4876e-10, 1.7459e-09,  ..., 2.0698e-10, 3.2835e-09,
         1.1948e-08],
        [4.9995e-09, 2.0060e-09, 1.2173e-08,  ..., 3.2706e-10, 3.1460e-09,
         8.3934e-09]])
{'statusCode': 200, 'body': '{"class": "[[\'Lettuce Healthy\' \'Lettuce Healthy\' \'Lettuce Healthy\' ...\\n  \'Strawberry Healthy\' \'Strawberry Healthy\' \'Tomato Healthy\']]", "confidence": "[[0.9999974  1.         1.         ... 1.         1.         0.99998224]]"}'}
{'class': "[['Lettuce Healthy' 'Lettuce Healthy' 'Lettuce H

In [None]:
# TBD: `url` is still the placeholder "TBD", so this request cannot succeed until
# the real endpoint of the deployed API is filled in.
import requests
# POST the same payload the local handler call used, encoded as JSON.
response = requests.post(url, json=data)

## references

* https://aws.amazon.com/ko/blogs/machine-learning/using-container-images-to-run-pytorch-models-in-aws-lambda/
* https://github.com/ahmedbesbes/cartoonify/tree/main