In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from PIL.Image import Image as PilImage
import textwrap, os
import sagemaker
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.serializers import JSONSerializer, IdentitySerializer
from sagemaker.deserializers import JSONDeserializer

%matplotlib inline

In [None]:
# Resolve the bucket that the rest of the notebook uploads the model archive to.
# NOTE(review): default_bucket() presumably creates the bucket when it does not
# exist yet — confirm against the SageMaker SDK version in use.
sagemaker_session = sagemaker.session.Session()
s3_bucket_name = sagemaker_session.default_bucket()
s3_bucket_name

In [None]:
%%writefile build_model_tar.sh
#!/bin/bash
# Download the CLIP RN50 checkpoint, package it as a gzipped tarball and
# upload it to s3://<bucket>/models/clip/model.tar.gz.
#
# Usage: bash build_model_tar.sh <s3-bucket-name>

# Stop on the first failing command, on unset variables and on pipeline
# failures, so a broken download is never packaged and uploaded.
set -euo pipefail

BUCKET_NAME="${1:?usage: $0 <s3-bucket-name>}"
MODEL_NAME=RN50.pt
MODEL_NAME_URL=https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt

BUILD_ROOT=/tmp/model_path
# The archive lives next to (not inside) BUILD_ROOT so tar never has to deal
# with the archive being part of the directory it is reading.
ARCHIVE_PATH="${BUILD_ROOT}.tar.gz"
S3_PATH="s3://${BUCKET_NAME}/models/clip/model.tar.gz"

# Start from a clean slate so stale files never end up in the archive.
rm -rf "$BUILD_ROOT" "$ARCHIVE_PATH"
mkdir -p "$BUILD_ROOT"

# --fail: an HTTP error aborts the script instead of saving the error page
#         under the model's file name.
# --location: follow redirects issued by the download host.
curl --fail --location --output "$BUILD_ROOT/$MODEL_NAME" "$MODEL_NAME_URL"

# -C keeps the archive layout relative to BUILD_ROOT (./RN50.pt).
tar -czvf "$ARCHIVE_PATH" -C "$BUILD_ROOT" .
aws s3 cp "$ARCHIVE_PATH" "$S3_PATH"

In [None]:
# Run the packaging script written above; the bucket name is its only argument.
!bash build_model_tar.sh {s3_bucket_name}

In [None]:
# Check that the model archive is present at the S3 key the script uploads to.
!aws s3 ls s3://{s3_bucket_name}/models/clip/model.tar.gz

In [None]:
# Directory for the inference script and its requirements; it is later passed
# to PyTorchModel as source_dir.
!mkdir -p code

In [None]:
%%writefile code/clip_inference.py

import io
import torch
import clip
from PIL import Image
import json
import logging
import sys
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor

# Runtime configuration, overridable through environment variables.
MODEL_NAME = os.getenv("MODEL_NAME", "RN50.pt")
# ENCODE_TYPE selects the encoder used by predict_fn: "IMAGE" or "TEXT".
ENCODE_TYPE = os.getenv("ENCODE_TYPE", "TEXT")

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module logger: DEBUG level, written to stdout.
_stdout_handler = logging.StreamHandler(sys.stdout)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(_stdout_handler)


# Model loading hook: build the CLIP model from the checkpoint in model_dir.
def model_fn(model_dir):
    """Load the CLIP checkpoint named MODEL_NAME from ``model_dir``.

    Args:
        model_dir: Directory that contains the checkpoint file.

    Returns:
        dict with the loaded model under ``"model_obj"`` and the image
        preprocessing callable returned by ``clip.load`` under
        ``"preprocess_fn"``.
    """
    checkpoint_path = os.path.join(model_dir, MODEL_NAME)
    clip_model, image_preprocess = clip.load(checkpoint_path, device=device)
    return {"model_obj": clip_model, "preprocess_fn": image_preprocess}


def load_from_bytearray(request_body):
    """Decode raw image bytes into a PIL image.

    The previous body returned an undefined name and raised NameError on
    every call.

    Args:
        request_body: Encoded image content as ``bytes`` or ``bytearray``.

    Returns:
        The image object produced by ``PIL.Image.open``.
    """
    # Image.open needs a file-like object, so wrap the buffer in BytesIO.
    image = Image.open(io.BytesIO(request_body))
    return image

# Request decoding hook: turn the raw request body into model input.
def input_fn(request_body, request_content_type):
    """Decode the request body according to its content type.

    Args:
        request_body: Raw request payload (JSON text/bytes or image bytes).
        request_content_type: ``"application/json"`` or
            ``"application/x-image"``.

    Returns:
        For JSON, the value stored under the ``"inputs"`` key; for images,
        the object returned by ``PIL.Image.open``.

    Raises:
        AssertionError: If the content type is neither of the two above.
    """
    supported_types = ("application/json", "application/x-image")
    assert request_content_type in supported_types, f"{request_content_type} is an unknown type."

    if request_content_type == "application/x-image":
        payload = Image.open(io.BytesIO(request_body))
    elif request_content_type == "application/json":
        payload = json.loads(request_body)["inputs"]
    return payload


# Inference hook: encode the decoded input with the text or image encoder.
def predict_fn(input_object, model):
    """Encode ``input_object`` with the encoder selected by ENCODE_TYPE.

    Args:
        input_object: Output of ``input_fn``; passed to ``clip.tokenize`` in
            TEXT mode and to the preprocessing callable in IMAGE mode.
        model: dict from ``model_fn`` with ``"model_obj"`` and
            ``"preprocess_fn"`` entries.

    Returns:
        The result of ``encode_text`` or ``encode_image``, computed with
        gradient tracking disabled.

    Raises:
        AssertionError: If ENCODE_TYPE is neither ``"TEXT"`` nor ``"IMAGE"``.
    """
    encoder = model["model_obj"]
    # Only used for images, but looked up unconditionally.
    image_transform = model["preprocess_fn"]
    assert ENCODE_TYPE in ("TEXT", "IMAGE"), f"{ENCODE_TYPE} is an unknown encode type."

    wants_text = ENCODE_TYPE == "TEXT"
    wants_image = ENCODE_TYPE == "IMAGE"

    # Build the input batch on the target device.
    if wants_text:
        batch = clip.tokenize(input_object).to(device)
    elif wants_image:
        # unsqueeze(0) adds a leading dimension to the preprocessed image.
        batch = image_transform(input_object).unsqueeze(0).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        if wants_text:
            embedding = encoder.encode_text(batch)
        elif wants_image:
            embedding = encoder.encode_image(batch)
    return embedding


# Response encoding hook: turn the prediction tensor into a JSON string.
def output_fn(predictions, content_type):
    """Serialize ``predictions`` as JSON.

    Args:
        predictions: Tensor returned by ``predict_fn``.
        content_type: Requested response type; must be ``"application/json"``.

    Returns:
        JSON text of the tensor converted to nested Python lists.

    Raises:
        AssertionError: If ``content_type`` is not ``"application/json"``.
    """
    assert content_type == "application/json"
    # Move to host memory before converting to plain Python lists.
    host_array = predictions.cpu().numpy()
    return json.dumps(host_array.tolist())

In [None]:
%%writefile code/requirements.txt
# Dependencies for code/clip_inference.py.
# NOTE(review): presumably installed by the SageMaker PyTorch container because
# this file sits in source_dir — confirm for the container version in use.
# NOTE(review): nothing here is version-pinned and CLIP is installed from the
# repository's default branch, so rebuilds may pick up different code.
ftfy
regex
tqdm
git+https://github.com/openai/CLIP.git

In [None]:
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role, Session


role = get_execution_role()

# Settings common to both models: same inference script, same model archive,
# same framework/Python versions. Only the name and ENCODE_TYPE differ.
shared_params = {
    "entry_point": "clip_inference.py",
    "source_dir": "code",
    "role": role,
    "model_data": f"s3://{s3_bucket_name}/models/clip/model.tar.gz",
    "framework_version": "1.9.0",
    "py_version": "py38",
}

# Image encoder: clip_inference.py reads ENCODE_TYPE=IMAGE from its environment.
clip_image_model = PyTorchModel(
    name="clip-image-model",
    env={"MODEL_NAME": "RN50.pt", "ENCODE_TYPE": "IMAGE"},
    **shared_params,
)

# Text encoder: same code and weights, ENCODE_TYPE=TEXT.
clip_text_model = PyTorchModel(
    name="clip-text-model",
    env={"MODEL_NAME": "RN50.pt", "ENCODE_TYPE": "TEXT"},
    **shared_params,
)

In [None]:
# Memory size requested for each serverless endpoint.
SERVERLESS_MEMORY_MB = 6144

# Arguments that are identical for both deployments.
# NOTE(review): instance_type / initial_instance_count are passed together with
# a serverless config — confirm how the SDK version in use treats them.
endpoint_settings = {
    "instance_type": "ml.c5.xlarge",
    "initial_instance_count": 1,
    "wait": True,
}

# Text endpoint: JSON request in, JSON response out.
text_predictor = clip_text_model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(memory_size_in_mb=SERVERLESS_MEMORY_MB),
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
    **endpoint_settings,
)

# Image endpoint: raw bytes sent as application/x-image, JSON response out.
image_predictor = clip_image_model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(memory_size_in_mb=SERVERLESS_MEMORY_MB),
    serializer=IdentitySerializer(content_type="application/x-image"),
    deserializer=JSONDeserializer(),
    **endpoint_settings,
)

def encode_image(file_name):
    """Send an image file to the image endpoint and return the first result.

    Args:
        file_name: Path of the image file to read in binary mode.

    Returns:
        The first element of the endpoint's response.
    """
    with open(file_name, "rb") as image_file:
        raw_bytes = bytearray(image_file.read())
    response = image_predictor.predict(raw_bytes)
    return response[0]

def encode_name(item_name):
    """Embed ``item_name`` through the text endpoint.

    The name is wrapped in the prompt ``"this is a <item_name>"`` and sent as
    a single-element ``"inputs"`` list.

    Args:
        item_name: Text to place into the prompt.

    Returns:
        The first element of the endpoint's response.
    """
    prompt = f"this is a {item_name}"
    response = text_predictor.predict({"inputs": [prompt]})
    return response[0]

In [None]:
# Embed a sample word through the text endpoint (encode_name wraps it in the
# "this is a ..." prompt) and print the length of the returned embedding.
text = 'fishes'
text_emb = encode_name(text)
print(len(text_emb))

In [None]:
# Embed a sample image through the image endpoint and print the length of the
# returned embedding. The path is relative, so the file must be present in the
# notebook's working directory.
image = 'test-image-clip.jpeg'
image_emb = encode_image(image)
print(len(image_emb))

In [None]:
import numpy as np

# Cosine similarity between the text and image embeddings:
# dot(a, b) / (|a| * |b|).
text_emb = np.array(text_emb)
image_emb = np.array(image_emb)

dot_product = text_emb.dot(image_emb)
norm_product = np.linalg.norm(text_emb) * np.linalg.norm(image_emb)
cos_sim = dot_product / norm_product
print(cos_sim)