In [1]:
import html
import sagemaker, boto3, json
from IPython.display import display
from sagemaker import get_execution_role

# Account-specific IAM role passed to the JumpStart model further down; swap in
# your own role ARN when running this notebook against a different AWS account.
aws_role = "arn:aws:iam::654654375075:role/service-role/AmazonSageMaker-ExecutionRole-20240502T152077"

# SageMaker session used for the deployment, and the default bucket it reports.
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Region of the default boto3 session; reused for the model and for the name of
# the bucket holding the sample images.
aws_region = boto3.Session().region_name

sagemaker.config INFO - Not applying SDK defaults from location: /app/etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sprince0031/.var/app/com.visualstudio.code/config/sagemaker/config.yaml


In [2]:
# JumpStart model to deploy. "*" is a wildcard version identifier; the SDK log
# printed on deployment suggests pinning a concrete version for stable results.
model_id = "pytorch-ic-resnet18"
model_version = "*"


In [4]:
# Show which default bucket the SageMaker session resolved to.
print(bucket)

sagemaker-us-east-1-654654375075


In [5]:
from sagemaker.jumpstart.model import JumpStartModel

# Build the JumpStart model from the id/version selected earlier in the notebook.
# `model_version` is passed explicitly so that changing it (for example pinning
# a concrete release instead of the "*" wildcard) actually takes effect.
my_model = JumpStartModel(
    model_id=model_id,
    model_version=model_version,
    role=aws_role,
    sagemaker_session=sess,
    region=aws_region,
)

# Deploy the model to an endpoint and keep the predictor bound to it; the
# cleanup cell at the end of the notebook uses this predictor to tear it down.
base_model_predictor = my_model.deploy()

Using model 'pytorch-ic-resnet18' with wildcard version identifier '*'. You can pin to version '3.0.0' for more stable results. Note that models may have different input/output signatures after a major version upgrade.


-------!

In [6]:
# Request the "verbose" JSON response type. The inference loop later in the
# notebook reads the "labels" and "probabilities" fields of the response;
# presumably only the verbose form carries them -- TODO confirm.
base_model_predictor.accept = "application/json;verbose"

In [7]:
# Bucket and key prefix that the sample images are downloaded from.
# The bucket name is region-specific: it is built from aws_region.
s3_bucket = f"jumpstart-cache-prod-{aws_region}"
key_prefix = "inference-notebook-assets"


def download_from_s3(images):
    """Download sample images from S3 to local files.

    Args:
        images: Mapping of local filename -> object name. Each object is read
            from ``s3_bucket`` under ``key_prefix`` and written to the local
            path given by its key in the mapping.
    """
    # Create the client once instead of once per downloaded file.
    s3_client = boto3.client("s3")
    for filename, image_key in images.items():
        s3_client.download_file(s3_bucket, f"{key_prefix}/{image_key}", filename)


# Local filename -> object name of each sample image under key_prefix.
images = {"img1.jpg": "cat.jpg", "img2.jpg": "dog.jpg"}
# Fetch the sample images so the inference loop can read them from disk.
download_from_s3(images)

In [8]:
from IPython.core.display import HTML


def predict_top_k_labels(probabilities, labels, k):
    """Return the labels of the ``k`` highest-probability classes as one string.

    Args:
        probabilities: Indexable sequence of scores, one per class.
        labels: Sequence of label strings, indexed the same way as
            ``probabilities``.
        k: Number of leading entries to keep after ranking.

    Returns:
        The selected labels joined with ", ", ordered from highest to lowest
        probability. Equal probabilities keep their original index order.
    """
    ranked_positions = list(range(len(probabilities)))
    # Stable descending sort, so ties stay in ascending index order.
    ranked_positions.sort(key=lambda position: probabilities[position], reverse=True)
    return ", ".join(labels[position] for position in ranked_positions[:k])


# Classify every downloaded image and render it next to its top-5 labels.
for image_filename in images:
    with open(image_filename, "rb") as file:
        img = file.read()
    query_response = base_model_predictor.predict(img)
    labels, probabilities = query_response["labels"], query_response["probabilities"]
    top5_class_labels = predict_top_k_labels(probabilities, labels, 5)
    # Quote the attribute values and escape the interpolated text so that a
    # filename or label containing spaces or HTML metacharacters cannot break
    # the generated markup.
    safe_filename = html.escape(image_filename, quote=True)
    safe_labels = html.escape(top5_class_labels)
    display(
        HTML(
            f'<img src="{safe_filename}" alt="{safe_filename}" align="left" style="width: 250px;"/>'
            f"<figcaption>Top-5 predictions: {safe_labels} </figcaption>"
        )
    )

In [9]:
# Delete the SageMaker endpoint and the attached resources.
# The model is deleted first (same order as before); the endpoint deletion sits
# in `finally` so it is still attempted if deleting the model raises, instead of
# leaving the endpoint running.
try:
    base_model_predictor.delete_model()
finally:
    base_model_predictor.delete_endpoint()