In [1]:
import boto3
import io
import numpy as np
import sagemaker
from sklearn.datasets import load_iris
from sagemaker import ModelPackage
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import CSVDeserializer
from sagemaker import get_execution_role

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


### Settings

In [9]:
model_package_group_name = 'Proj1' # model package group name. The latest model from this group will be used
output_path = "s3://sagemaker-bucket-ds/OUTPUT/IRIS/" # Path to output data
input_path = "s3://sagemaker-bucket-ds/INPUT/IRIS/" # Path to the file with inpt data
bucket_name = "sagemaker-bucket-ds"
input_file_name = "INPUT/IRIS/iris.csv"
input_path = f"s3://{bucket_name}/{input_file_name}"

### Create clients

In [10]:
sagemaker_client = boto3.client('sagemaker')
s3_client = boto3.client('s3')
role = get_execution_role()
sagemaker_session = sagemaker.Session()

### Create input data
In normal case this data would be already available

In [11]:
iris = load_iris()
X = iris.data

# Convert the NumPy array to a CSV string
csv_buffer = io.StringIO()
np.savetxt(csv_buffer, X, delimiter=',', fmt='%.6f')
s3_client.put_object(Bucket=bucket_name, Key=input_file_name, Body=csv_buffer.getvalue())

{'ResponseMetadata': {'RequestId': 'P85RW1ZV7PP79W1B',
  'HostId': 'Q2ORiaHitBofv2kRRiBjoJGoDRd1dH41M4va99t6aEv0Nwo7eTvaPYi/HN9JDVrZyHtiL+sVKY8=',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amz-id-2': 'Q2ORiaHitBofv2kRRiBjoJGoDRd1dH41M4va99t6aEv0Nwo7eTvaPYi/HN9JDVrZyHtiL+sVKY8=',
   'x-amz-request-id': 'P85RW1ZV7PP79W1B',
   'date': 'Wed, 19 Jun 2024 15:41:38 GMT',
   'x-amz-server-side-encryption': 'AES256',
   'etag': '"0cb2e05024525139c7065e36ffcefcf1"',
   'server': 'AmazonS3',
   'content-length': '0'},
  'RetryAttempts': 0},
 'ETag': '"0cb2e05024525139c7065e36ffcefcf1"',
 'ServerSideEncryption': 'AES256'}

### Get ARN of model package

In [12]:
# List the model packages in the specified model package group
list_model_packages_response = sagemaker_client.list_model_packages(
    ModelPackageGroupName=model_package_group_name
)

# Print the list of model packages
model_packages = list_model_packages_response['ModelPackageSummaryList']
for model_package in model_packages:
    print("Model Package ARN: ", model_package['ModelPackageArn'])

# Get the ARN of the latest model package (if there are multiple)
model_package_arn = model_packages[0]['ModelPackageArn']  # Adjust as necessary
print("Latest Model Package ARN: ", model_package_arn)

Model Package ARN:  arn:aws:sagemaker:eu-west-1:211125740051:model-package/Proj1/1
Latest Model Package ARN:  arn:aws:sagemaker:eu-west-1:211125740051:model-package/Proj1/1


### Retrieve the model

In [13]:
# Create the model from the model package
model = ModelPackage(
    role=role,
    model_package_arn=model_package_arn,
    sagemaker_session=sagemaker_session
)

### Create transformer

In [16]:
transformer = model.transformer(
    instance_count=1, 
    instance_type='ml.m5.large',
    assemble_with='Line',
    output_path=output_path)

INFO:sagemaker:Creating model with name: Proj1-2024-06-19-15-42-37-860


In [17]:
transformer.transform(data=input_path,
                      split_type='Line',
                       content_type='text/csv',
                       wait=True)

INFO:sagemaker:Creating transform job with name: Proj1-2024-06-19-15-42-45-073


................................[34m2024-06-19 15:48:04,783 INFO - sagemaker-containers - No GPUs detected (normal if no gpus installed)[0m
[34m2024-06-19 15:48:04,786 INFO - sagemaker-containers - No GPUs detected (normal if no gpus installed)[0m
[34m2024-06-19 15:48:04,787 INFO - sagemaker-containers - nginx config: [0m
[34mworker_processes auto;[0m
[34mdaemon off;[0m
[34mpid /tmp/nginx.pid;[0m
[34merror_log  /dev/stderr;[0m
[34mworker_rlimit_nofile 4096;[0m
[34mevents {
  worker_connections 2048;[0m
[34m}[0m
[34mhttp {
  include /etc/nginx/mime.types;
  default_type application/octet-stream;
  access_log /dev/stdout combined;
  upstream gunicorn {
    server unix:/tmp/gunicorn.sock;
  }
  server {
    listen 8080 deferred;
    client_max_body_size 0;
    keepalive_timeout 3;
    location ~ ^/(ping|invocations|execution-parameters) {
      proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
      proxy_set_header Host $http_host;
      proxy_redirect off

KeyboardInterrupt: 

In [29]:
# The ModelName should be derived from the model package ARN or set manually
model_name = model_package_arn.split('/')[-1]
model.name = model_name

### Create an endpoint configuration

In [30]:
endpoint_config_name = 'Iris-endpoint-config'
endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            'VariantName': 'AllTraffic',
            'ModelName': model.name,
            'InstanceType': 'ml.m5.large',
            'InitialInstanceCount': 1
        }
    ]
)
print("Endpoint Config Arn: ", endpoint_config_response['EndpointConfigArn'])

ClientError: An error occurred (ValidationException) when calling the CreateEndpointConfig operation: Could not find model "1".

In [None]:
# Create an endpoint
endpoint_name = 'model-endpoint'
endpoint_response = sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name
)
print("Endpoint Arn: ", endpoint_response['EndpointArn'])

# Wait for the endpoint to be in service
sagemaker_client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
print(f"Endpoint {endpoint_name} is in service!")

# Perform inference
predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=CSVSerializer(),
    deserializer=CSVDeserializer()
)

# Example data for inference (in CSV format)
input_data = "1.0, 2.0, 3.0, 4.0"  # Replace with actual data

# Perform the prediction
prediction = predictor.predict(input_data)
print("Prediction: ", prediction)