# Deploy SpeciesNet to SageMaker Serverless

This notebook deploys the SpeciesNet model to a SageMaker serverless endpoint.

## Setup

### Import Dependencies

In [None]:
%matplotlib inline

import boto3
import time
import json
import base64
from pathlib import Path
from PIL import Image
import sagemaker
from io import BytesIO

### Initialize AWS Session

In [None]:
# One boto3 session is the single source of credentials/region for this
# notebook; every value below is derived from it so they cannot disagree.
sess = boto3.Session()
sm = sess.client('sagemaker')
region = sess.region_name

# Resolve the caller's account ID through the same session (not the module
# default one) and index the key directly so a malformed response fails here
# instead of producing a "None.dkr.ecr..." registry hostname later.
account = sess.client('sts').get_caller_identity()['Account']

### Get IAM Role

**Note**: Ensure the IAM role has:
- `AmazonS3FullAccess`
- `AmazonSageMakerFullAccess`

In [None]:
# Look up the IAM role through the SageMaker SDK. The returned value is passed
# to create_model as ExecutionRoleArn further down in this notebook.
role = sagemaker.get_execution_role()
print(f"Using role: {role}")

## Build and Push Container

Build the custom container with our inference code and push it to Amazon ECR.

In [None]:
# Create ECR repository if it doesn't exist
registry_name = "speciesnet-sagemaker"
ecr = boto3.client('ecr')

try:
    ecr.create_repository(repositoryName=registry_name)
except ecr.exceptions.RepositoryAlreadyExistsException:
    pass

# Get auth token and login to ECR
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account}.dkr.ecr.{region}.amazonaws.com

# Build container
!docker build -t {registry_name} -f Dockerfile.sagemaker .

# Tag and push to ECR
image_uri = f"{account}.dkr.ecr.{region}.amazonaws.com/{registry_name}:latest"
!docker tag {registry_name} {image_uri}
!docker push {image_uri}

print(f"Container pushed to: {image_uri}")

## Create SageMaker Model

In [None]:
# SageMaker model, endpoint-config and endpoint names may contain only
# alphanumeric characters and hyphens, so the dots of the version string
# (v4.0.0a) are written as hyphens. This value is reused below for the
# endpoint config name and the endpoint name.
model_prefix = "speciesnet-v4-0-0a"

# Check if model already exists.
# list_models returns results one page at a time, so walk every page (narrowed
# server-side with NameContains) and keep only an exact name match.
existing_model = next(
    (
        summary
        for page in sm.get_paginator('list_models').paginate(NameContains=model_prefix)
        for summary in page['Models']
        if summary['ModelName'] == model_prefix
    ),
    None,
)
model_already_created = existing_model is not None

if model_already_created:
    # Reuse the existing model as-is; its summary carries 'ModelArn' too.
    create_model_response = existing_model
else:
    # Create model if it doesn't exist
    create_model_response = sm.create_model(
        ModelName=model_prefix,
        ExecutionRoleArn=role,
        PrimaryContainer={
            "Image": image_uri,
            "Environment": {
                "SAGEMAKER_PROGRAM": "serve.py"
            }
        }
    )

print(f"Model ARN: {create_model_response['ModelArn']}")

## Create Endpoint Configuration

In [None]:
endpoint_config_name = f"{model_prefix}-config"

# Make the cell re-runnable, like the model cell above: create_endpoint_config
# raises if the name is already taken, so look for an exact match first
# (paginated listing, narrowed server-side with NameContains).
existing_endpoint_config = next(
    (
        summary
        for page in sm.get_paginator('list_endpoint_configs').paginate(
            NameContains=endpoint_config_name
        )
        for summary in page['EndpointConfigs']
        if summary['EndpointConfigName'] == endpoint_config_name
    ),
    None,
)

if existing_endpoint_config is not None:
    # An existing config is reused unchanged. To apply different memory or
    # concurrency settings, delete it first (see "Cleanup Resources").
    create_endpoint_config_response = existing_endpoint_config
else:
    create_endpoint_config_response = sm.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[
            {
                "ModelName": model_prefix,
                "VariantName": "AllTraffic",
                "ServerlessConfig": {
                    "MemorySizeInMB": 6144,  # 6GB memory
                    "MaxConcurrency": 8       # Maximum concurrent invocations
                }
            }
        ]
    )

print(f"Endpoint Config ARN: {create_endpoint_config_response['EndpointConfigArn']}")

## Create and Deploy Endpoint

In [None]:
endpoint_name = model_prefix

# create_endpoint raises if the endpoint already exists, so reuse an existing
# endpoint with exactly this name when re-running the notebook.
existing_endpoint = next(
    (
        summary
        for page in sm.get_paginator('list_endpoints').paginate(NameContains=endpoint_name)
        for summary in page['Endpoints']
        if summary['EndpointName'] == endpoint_name
    ),
    None,
)

if existing_endpoint is not None:
    create_endpoint_response = existing_endpoint
else:
    create_endpoint_response = sm.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name
    )

print(f"Endpoint ARN: {create_endpoint_response['EndpointArn']}")

# Statuses that mean "still changing, keep polling". A reused endpoint may be
# in any of these, not only 'Creating'.
in_progress_statuses = {'Creating', 'Updating', 'SystemUpdating', 'RollingBack'}

# Wait for endpoint creation
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print(f"Status: {status}")

while status in in_progress_statuses:
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp['EndpointStatus']
    print(f"Status: {status}")

print(f"Arn: {resp['EndpointArn']}")
print(f"Final Status: {status}")

# Stop here when the deployment did not succeed, so the test cell below is not
# run against an endpoint that cannot serve requests.
if status != 'InService':
    failure_reason = resp.get('FailureReason', 'no failure reason reported')
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status}: {failure_reason}"
    )

## Test the Endpoint

In [None]:
# Open the sample image from the test_data directory and render it inline.
test_image = Image.open("tests/test_data/african_elephants.jpg")
display(test_image)

# Re-encode the image as JPEG in memory, then base64-encode those bytes so
# they can travel inside a JSON document.
jpeg_buffer = BytesIO()
test_image.save(jpeg_buffer, format="JPEG")
jpeg_bytes = jpeg_buffer.getvalue()
encoded_image = base64.b64encode(jpeg_bytes).decode()

# Request document: the encoded image plus a country
# (optional: provides location context).
payload = dict(image_data=encoded_image, country="Kenya")
request_body = json.dumps(payload)

# Send the request to the deployed endpoint through the SageMaker runtime API.
runtime = boto3.client('runtime.sagemaker')
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Body=request_body,
)

# The response body is a stream: read it once, decode it, and parse the JSON.
raw_body = response['Body'].read()
result = json.loads(raw_body.decode())

print("\nPrediction Results:")
print(json.dumps(result, indent=2))

## Cleanup Resources

**Note**: Only run this cell when you want to delete the endpoint and associated resources.

In [None]:
# Uncomment the lines below to delete the endpoint, then its endpoint
# configuration, then the model.
# client = boto3.client('sagemaker')
# client.delete_endpoint(EndpointName=endpoint_name)
# client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
# client.delete_model(ModelName=model_prefix)