Deploy endpoint

Step 1: Setup

In [None]:
#Setup
import boto3
import sagemaker
import json
import os
import matplotlib
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorchModel
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor
# S3 location of the packaged model artifact (bucket + object key).
bucket = "thibaut-test-inference-cbramod"
key = "cbramod/model.tar.gz"
model_data = "s3://" + bucket + "/" + key

# Region of the default SageMaker session and the execution role
# this notebook is running under.
region = sagemaker.Session().boto_region_name
role = get_execution_role()

# Echo the resolved settings so a misconfiguration is visible before deploying.
for label, value in (
    ("✅ Role:", role),
    ("✅ Region:", region),
    ("✅ Model S3 path:", model_data),
):
    print(label, value)

Step 2: Define the PyTorchModel

In [None]:
# Environment variables passed to the model through the `env` argument.
model_env = {"RAW_RECORDINGS_BUCKET": "idn-dev-raw-recordings-bucket"}

# Describe the model: artifact location, IAM role, inference script and
# the PyTorch / Python versions requested for the serving container.
pytorch_model = PyTorchModel(
    model_data=model_data,
    role=role,
    entry_point="inference.py",
    source_dir=None,  # no separate source directory is supplied
    framework_version="2.0.0",
    py_version="py310",
    env=model_env,
)

Step 3: Define Serverless Config

In [None]:
# Serverless sizing for the endpoint:
#   - 6144 MB (6 GB) of memory
#   - max_concurrency of 2
serverless_config = ServerlessInferenceConfig(memory_size_in_mb=6144, max_concurrency=2)

Step 4: Deploy to SageMaker Serverless Endpoint

In [None]:
# Single source of truth for the endpoint name: the same value is used for
# the deployment and for the confirmation message below, so the two cannot
# drift apart if the name is changed.
endpoint_name = "eeg-serverless-endpoint"

# Create the serverless endpoint from the model definition. The returned
# predictor sends JSON requests and parses JSON responses.
predictor = pytorch_model.deploy(
    endpoint_name=endpoint_name,
    serverless_inference_config=serverless_config,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

print(f"✅ Endpoint deployed: {endpoint_name}")

Step 5: Test the Endpoint with Sample Input

In [None]:
# Attach to the endpoint by name; request payloads are JSON-serialized.
predictor = Predictor(
    endpoint_name="eeg-serverless-endpoint",
    serializer=JSONSerializer(),
)
# NOTE(review): this sets the variable only in the notebook process; the
# model itself is given RAW_RECORDINGS_BUCKET through its `env` argument.
# Confirm whether this assignment is still needed.
os.environ["RAW_RECORDINGS_BUCKET"] = "idn-dev-raw-recordings-bucket"

# Sample request identifying one recording.
sample_input = dict(
    bucket_name="idn-dev-raw-recordings-bucket",
    userId="036d5eb6-e177-475a-bb64-5c09e51062a7",
    deviceId="F9-79-78-54-CA-15",
    recordingId="1707409115815",
    orig_sfreq=250,
)

# Invoke the endpoint.
response = predictor.predict(sample_input)

# No deserializer is configured on this predictor, so the body may come
# back as raw bytes; normalise to text before parsing.
if isinstance(response, bytes):
    response = response.decode()

parsed = json.loads(response)
prediction = parsed["prediction"]
print("✅ Clean Prediction:", prediction)
print(len(prediction))