# How to use

--- 

SG4D100M is a molecular embedding model.

This notebook demonstrates how to deploy [SG4D100M](https://aws.amazon.com/marketplace/pp/prodview-soy64f34ucl4g?sr=0-1&ref_=beagle&applicationId=AWSMPContessa) using Amazon SageMaker and how to run inference against the deployed endpoint.

--- 

In [1]:
!pip install sagemaker polars > /dev/null

In [2]:
import io
import polars as pl
from sagemaker import ModelPackage
import sagemaker as sage
from sagemaker import get_execution_role
from IPython.display import clear_output

## Set Up SageMaker Endpoint

In [4]:
# Deployment configuration.
model_package_arn = "arn:aws:sagemaker:us-east-1:865070037744:model-package/sg4d100m-9-a9336203cfd13fc38a8dd188ac0feaf4"
model_name = "sg4d100m-1"  # also passed as the endpoint name to deploy() below
content_type = "application/jsonlines"
instance_type = "ml.m5.xlarge"

# Session and execution role used to create the SageMaker resources.
sagemaker_session = sage.Session()
role = get_execution_role()

# Wrap the model package and deploy it to a single-instance endpoint.
model = ModelPackage(role=role, model_package_arn=model_package_arn, sagemaker_session=sagemaker_session)
# NOTE(review): `predictor` is not used by later cells in this notebook.
predictor = model.deploy(initial_instance_count=1, instance_type=instance_type, endpoint_name=model_name)

# Clear whatever output the deployment printed to this cell.
clear_output()

## Run inference


### Model Input

In [5]:
# One row per molecule, each given as a SMILES string in the "smiles" column.
example_smiles = [
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "CCNC(=O)CC[C@H](N)C(=O)O",
]
request_df = pl.DataFrame({"smiles": example_smiles})
request_df

smiles
str
"""CN1C=NC2=C1C(=O)N(C(=O)N2C)C"""
"""CCNC(=O)CC[C@H](N)C(=O)O"""


### Inference on the endpoint

In [6]:
# dataframe -> ndjson request body
with io.BytesIO() as request_buf:
    request_df.write_ndjson(request_buf)
    request_body = request_buf.getvalue()

# Inference: send and accept JSON Lines. Reuse the configured `content_type`
# instead of a hardcoded literal so the request format is defined in one place.
sagemaker_runtime_client = sagemaker_session.sagemaker_runtime_client
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=model_name,
    ContentType=content_type,
    Accept=content_type,
    Body=request_body,
)

# ndjson response body -> dataframe
response_body = response["Body"].read()
response_df = pl.read_ndjson(response_body)

# Sanity check: the endpoint should return one row per input molecule.
assert response_df.height == request_df.height, "endpoint returned an unexpected number of rows"

### Model output

In [7]:
# Inspect the parsed response: each row carries the input SMILES alongside its
# `sg4d100m` embedding, returned as a list of floats.
response_df

smiles,sg4d100m
str,list[f64]
"""CN1C=NC2=C1C(=O)N(C(=O)N2C)C""","[-8.437966, 4.945457, … 1.488266]"
"""CCNC(=O)CC[C@H](N)C(=O)O""","[-7.948503, 5.331437, … 0.845603]"


## Clean up SageMaker Endpoint

In [8]:
# Tear down everything created above: the endpoint, then its endpoint config,
# then the model itself. The endpoint config is deleted under the same name as
# the endpoint (assumes deploy() named it that way — confirm if this fails).
session = model.sagemaker_session
for delete_resource in (session.delete_endpoint, session.delete_endpoint_config):
    delete_resource(model_name)
model.delete_model()

clear_output()