# Sigray ML Platform on SageMaker

This notebook demonstrates how to use the Sigray Machine Learning Platform on Amazon SageMaker for 3D image enhancement, from training a model to real-time and batch inference.

## Overview

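We'll cover:
1. Setting up the environment
2. Preparing training data
3. Running a training job
4. Deploying for inference
5. Testing real-time inference
6. Running a batch transform job
7. Monitoring and analyzing results
8. Cleaning up resources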

## 1. Environment Setup

In [None]:
# Install the Sigray ML Platform
!git clone https://github.com/tianzhuqin-argonne/Image_quality_enhancement.git
%cd Image_quality_enhancement
!pip install -e .

In [None]:
# Import required libraries
import sagemaker
import boto3
import numpy as np
import json
from pathlib import Path
from sagemaker.pytorch import PyTorch, PyTorchModel
from sagemaker import get_execution_role

# Import Sigray ML Platform
from src.inference.api import ImageEnhancementAPI
from src.core.config import InferenceConfig
from src.testing.test_fixtures import TestDataFixtures

print("✅ All imports successful!")

In [None]:
# Initialize SageMaker session
sagemaker_session = sagemaker.Session()
role = get_execution_role()
bucket = sagemaker_session.default_bucket()
region = sagemaker_session.boto_region_name

print(f"SageMaker role: {role}")
print(f"S3 bucket: {bucket}")
print(f"Region: {region}")

## 2. Create and Upload Training Data

In [None]:
# Create synthetic training data
import tempfile
import shutil

temp_dir = tempfile.mkdtemp()
print(f"Creating training data in: {temp_dir}")

with TestDataFixtures(temp_dir) as fixtures:
    # Create training dataset
    dataset = fixtures.create_test_dataset(
        num_pairs=10,
        volume_size="small",
        degradation_level="moderate"
    )
    
    print(f"Created {len(dataset)} training pairs")
    
    # Organize data for SageMaker
    training_data_dir = Path(temp_dir) / "training_data"
    input_dir = training_data_dir / "input"
    target_dir = training_data_dir / "target"
    
    input_dir.mkdir(parents=True, exist_ok=True)
    target_dir.mkdir(parents=True, exist_ok=True)
    
    # Save training pairs
    for i, (input_vol, target_vol) in enumerate(dataset):
        input_path = fixtures.save_test_tiff(input_vol, f"input_{i:03d}")
        target_path = fixtures.save_test_tiff(target_vol, f"target_{i:03d}")
        
        # Move to organized structure
        shutil.move(input_path, input_dir / f"input_{i:03d}.tif")
        shutil.move(target_path, target_dir / f"target_{i:03d}.tif")
    
    print(f"Training data organized in: {training_data_dir}")

In [None]:
# Upload training data to S3
training_input = sagemaker_session.upload_data(
    path=str(training_data_dir),
    bucket=bucket,
    key_prefix="sigray-ml-training-data"
)

print(f"Training data uploaded to: {training_input}")

## 3. Run SageMaker Training Job

In [None]:
# Define hyperparameters
hyperparameters = {
    'epochs': 10,  # Reduced for demo
    'batch-size': 4,
    'learning-rate': 1e-4,
    'patch-size': '128,128',  # Smaller for demo
    'overlap': 16,
    'optimizer': 'adam',
    'loss-function': 'mse',
    'device': 'auto',
    'early-stopping-patience': 5,
    'use-augmentation': True
}

print("Hyperparameters:")
for key, value in hyperparameters.items():
    print(f"  {key}: {value}")
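In script mode, SageMaker passes each hyperparameter to the entry point as a command-line argument (for example `--batch-size 4`), so values arrive as strings and a flag such as `use-augmentation` needs explicit boolean parsing. The sketch below shows one way an entry point like `sagemaker_train.py` *might* parse them; the repository's actual parser may differ, and `str2bool` / `parse_patch_size` are hypothetical helpers.

```python
import argparse


def str2bool(value):
    """Interpret common string spellings of a boolean (e.g. 'True', 'false', '1')."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_patch_size(value):
    """Turn a string such as '128,128' into the tuple (128, 128)."""
    return tuple(int(part) for part in value.split(","))


def parse_args(argv=None):
    """Parse the hyperparameters defined above (hypothetical parser)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--patch-size", type=parse_patch_size, default=(128, 128))
    parser.add_argument("--overlap", type=int, default=16)
    parser.add_argument("--optimizer", default="adam")
    parser.add_argument("--loss-function", default="mse")
    parser.add_argument("--device", default="auto")
    parser.add_argument("--early-stopping-patience", type=int, default=5)
    # A plain type=bool would turn the string 'False' into True, hence str2bool
    parser.add_argument("--use-augmentation", type=str2bool, default=False)
    return parser.parse_args(argv)
```

Note that argparse exposes dashed options with underscores, so `--batch-size` becomes `args.batch_size`.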

In [None]:
# Create PyTorch estimator
estimator = PyTorch(
    entry_point='sagemaker_train.py',
    source_dir='sagemaker',
    role=role,
    instance_type='ml.p3.2xlarge',  # GPU instance
    instance_count=1,
    volume_size=30,  # GB
    max_run=3600,    # 1 hour max
    framework_version='1.12.0',
    py_version='py38',
    hyperparameters=hyperparameters,
    environment={
        'PYTHONPATH': '/opt/ml/code'
    },
    # Managed spot training for cost savings (optional). max_wait covers both
    # waiting for spot capacity and training time, so it must be >= max_run.
    use_spot_instances=True,
    max_wait=7200,
    checkpoint_s3_uri=f's3://{bucket}/sigray-ml-checkpoints/'
)
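With `use_spot_instances=True`, a job can be interrupted and restarted. SageMaker syncs the container's local checkpoint directory (by default `/opt/ml/checkpoints`) with `checkpoint_s3_uri`, but resuming from it is up to the training script. The sketch below shows only the resume logic, using small JSON files as stand-ins for real model checkpoints; the `checkpoint_epoch_NNNN.json` naming and the helper functions are assumptions, not the repository's scheme.

```python
import json
import re
from pathlib import Path

# Default local path that SageMaker syncs with checkpoint_s3_uri
CHECKPOINT_DIR = Path("/opt/ml/checkpoints")
# Hypothetical naming scheme: checkpoint_epoch_0003.json
_PATTERN = re.compile(r"checkpoint_epoch_(\d+)\.json")


def save_checkpoint(state, epoch, checkpoint_dir=CHECKPOINT_DIR):
    """Write the training state for a finished epoch."""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    path = checkpoint_dir / f"checkpoint_epoch_{epoch:04d}.json"
    path.write_text(json.dumps({"epoch": epoch, "state": state}))
    return path


def load_latest_checkpoint(checkpoint_dir=CHECKPOINT_DIR):
    """Return the most recent checkpoint dict, or None on a fresh start."""
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.is_dir():
        return None
    candidates = []
    for path in checkpoint_dir.iterdir():
        match = _PATTERN.fullmatch(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    _, latest = max(candidates)
    return json.loads(latest.read_text())


def start_epoch(checkpoint_dir=CHECKPOINT_DIR):
    """Epoch to begin at: 0 on a fresh start, last saved epoch + 1 after an interruption."""
    checkpoint = load_latest_checkpoint(checkpoint_dir)
    return 0 if checkpoint is None else checkpoint["epoch"] + 1
```

Saving once per epoch keeps the amount of lost work after a spot interruption to at most one epoch.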

print("✅ Estimator created")

In [None]:
# Start training job
print("🚀 Starting training job...")

estimator.fit({
    'training': training_input
})

print("✅ Training job completed!")
print(f"Model artifacts: {estimator.model_data}")
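The `'training'` key passed to `fit()` names an input channel. Inside the training container, SageMaker downloads that channel to `/opt/ml/input/data/training` and exposes the path through the `SM_CHANNEL_TRAINING` environment variable; files the script writes to `SM_MODEL_DIR` (normally `/opt/ml/model`) are packaged into the `model.tar.gz` that `estimator.model_data` points to. A minimal sketch of how an entry point could resolve these paths; `resolve_paths` is a hypothetical helper.

```python
import os
from pathlib import Path


def resolve_paths(channel="training", env=None):
    """Locate the input channel and model output directories inside a training container.

    Falls back to SageMaker's conventional paths when the environment variables are
    unset (e.g. when running the script locally). Hypothetical helper.
    """
    env = os.environ if env is None else env
    data_dir = Path(env.get(f"SM_CHANNEL_{channel.upper()}", f"/opt/ml/input/data/{channel}"))
    model_dir = Path(env.get("SM_MODEL_DIR", "/opt/ml/model"))
    return {
        "inputs": data_dir / "input",    # the input/ folder uploaded in section 2
        "targets": data_dir / "target",  # the target/ folder uploaded in section 2
        "model_dir": model_dir,
    }
```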

## 4. Deploy Model for Real-time Inference

In [None]:
# Create PyTorch model from training job
model = PyTorchModel(
    model_data=estimator.model_data,
    role=role,
    entry_point='sagemaker_inference.py',
    source_dir='sagemaker',
    framework_version='1.12.0',
    py_version='py38'
)
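The SageMaker PyTorch serving container looks for handler functions such as `model_fn`, `input_fn`, `predict_fn`, and `output_fn` in the entry point. The sketch below shows what `input_fn` and `output_fn` in `sagemaker_inference.py` *could* look like for the JSON payload built in section 5 (base64 `image_data` plus `shape`). The float32 dtype and the response fields are assumptions; the repository's handler may differ.

```python
import base64
import json

import numpy as np


def input_fn(request_body, request_content_type):
    """Decode a JSON request into a float32 NumPy volume.

    Assumes the client sent raw float32 bytes, base64-encoded, plus the array shape.
    """
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")
    payload = json.loads(request_body)
    raw = base64.b64decode(payload["image_data"])
    # frombuffer returns a read-only view, so copy before any in-place processing
    return np.frombuffer(raw, dtype=np.float32).reshape(payload["shape"]).copy()


def output_fn(prediction, accept):
    """Encode the enhanced volume as JSON with nested lists."""
    if "application/json" not in accept:
        raise ValueError(f"Unsupported accept type: {accept}")
    return json.dumps({
        "success": True,
        "enhanced_array": np.asarray(prediction).tolist(),
    })
```

Because the request carries no dtype field, client and server must agree on float32 out of band; sending a float64 array would silently decode to the wrong values or fail to reshape.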

print("✅ Model created")

In [None]:
# Deploy model as endpoint
print("🚀 Deploying model endpoint...")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.p3.2xlarge',
    endpoint_name='sigray-ml-endpoint'
)

print("✅ Endpoint deployed!")
print(f"Endpoint name: {predictor.endpoint_name}")

## 5. Test Real-time Inference

In [None]:
# Create test data
test_image = np.random.rand(3, 128, 128).astype(np.float32)
print(f"Test image shape: {test_image.shape}")

# Prepare input for endpoint
import base64

# Serialize numpy array
image_bytes = test_image.tobytes()
image_b64 = base64.b64encode(image_bytes).decode('utf-8')

# Create request payload
payload = {
    'image_data': image_b64,
    'shape': test_image.shape,
    'metadata': {
        'source': 'notebook_test',
        'timestamp': '2025-01-14'
    }
}

print("✅ Test payload prepared")
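Base64 inflates the data by roughly 4/3, and real-time endpoints cap the request payload at 6 MB. A quick size check, assuming float32 data encoded the same way as `payload` above, shows whether a volume can go to the endpoint or should go through batch transform instead; `request_size_bytes` and `fits_realtime` are hypothetical helpers.

```python
import base64
import json

import numpy as np

# Real-time endpoint request limit (6 MB); larger volumes belong in batch transform
MAX_REALTIME_PAYLOAD_BYTES = 6 * 1024 * 1024


def request_size_bytes(volume, metadata=None):
    """Size in bytes of a JSON request body built like `payload` above."""
    body = {
        "image_data": base64.b64encode(volume.tobytes()).decode("utf-8"),
        "shape": list(volume.shape),
        "metadata": metadata or {},
    }
    return len(json.dumps(body).encode("utf-8"))


def fits_realtime(volume, metadata=None):
    """True if the encoded request stays within the real-time payload limit."""
    return request_size_bytes(volume, metadata) <= MAX_REALTIME_PAYLOAD_BYTES
```

The 3×128×128 float32 test image is 196,608 raw bytes, or 262,144 base64 characters, well under the limit; a 64×512×512 float32 volume (64 MiB raw) is far over it.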

In [None]:
# Make prediction
print("🔮 Making prediction...")

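# Send the payload as JSON and parse the JSON response. In SageMaker Python SDK v2,
# content_type and accept are read-only and derived from the serializer/deserializer.
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()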

# Make prediction
result = predictor.predict(payload)

print("✅ Prediction completed!")
print(f"Success: {result['success']}")
print(f"Processing time: {result.get('processing_time', 'N/A')} seconds")

if result['success']:
    enhanced_shape = np.array(result['enhanced_array']).shape
    print(f"Enhanced image shape: {enhanced_shape}")
    
    if 'quality_metrics' in result:
        metrics = result['quality_metrics']
        print("Quality metrics:")
        for key, value in metrics.items():
            if key != 'enhanced_array':  # Skip the large array
                print(f"  {key}: {value}")
else:
    print(f"Error: {result.get('error_message', 'Unknown error')}")
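Because the JSON response carries `enhanced_array` as nested lists, it is worth converting it back to a NumPy array and checking its shape before further use. A minimal sketch assuming the response fields used above (`success`, `enhanced_array`, `error_message`) and that enhancement preserves the volume shape; `unpack_result` is hypothetical.

```python
import numpy as np


def unpack_result(result, expected_shape, dtype=np.float32):
    """Return the enhanced volume as an array; raise if the call failed or the shape changed."""
    if not result.get("success"):
        raise RuntimeError(result.get("error_message", "Unknown error"))
    enhanced = np.asarray(result["enhanced_array"], dtype=dtype)
    if enhanced.shape != tuple(expected_shape):
        raise ValueError(f"shape mismatch: got {enhanced.shape}, expected {tuple(expected_shape)}")
    return enhanced
```

For example, `enhanced = unpack_result(result, test_image.shape)`.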

## 6. Batch Transform Job

In [None]:
# Create batch input data
batch_input_dir = Path(temp_dir) / "batch_input"
batch_input_dir.mkdir(exist_ok=True)

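# Create several test volumes. The `with` block in section 2 has ended, so open
# a new fixtures context instead of reusing `fixtures` outside it.
with TestDataFixtures(temp_dir) as fixtures:
    for i in range(3):
        test_vol = fixtures.get_small_test_volume()
        test_path = fixtures.save_test_tiff(test_vol, f"batch_test_{i}")
        shutil.move(test_path, batch_input_dir / f"test_{i}.tif")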

print(f"Created batch input data in: {batch_input_dir}")

In [None]:
# Upload batch input to S3
batch_input_s3 = sagemaker_session.upload_data(
    path=str(batch_input_dir),
    bucket=bucket,
    key_prefix="sigray-ml-batch-input"
)

print(f"Batch input uploaded to: {batch_input_s3}")

In [None]:
# Create batch transformer
transformer = model.transformer(
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    output_path=f's3://{bucket}/sigray-ml-batch-output/',
    accept='application/json'
)

print("✅ Transformer created")

In [None]:
# Start batch transform job
print("🚀 Starting batch transform job...")

transformer.transform(
    data=batch_input_s3,
    content_type='image/tiff',
    split_type='None'
)

print("✅ Batch transform completed!")
print(f"Output location: {transformer.output_path}")
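Batch transform writes one output object per input object under `output_path`, named after the input file with `.out` appended. A small helper, hypothetical and based only on that naming convention, maps each uploaded TIFF to the location of its result:

```python
import posixpath
from urllib.parse import urlparse


def expected_output_uri(input_uri, input_prefix_uri, output_path):
    """Map an input object to the '<name>.out' object batch transform writes for it."""
    in_key = urlparse(input_uri).path.lstrip("/")
    prefix_key = urlparse(input_prefix_uri).path.lstrip("/")
    # Keep any sub-folder structure below the input prefix
    relative = posixpath.relpath(in_key, prefix_key)
    return output_path.rstrip("/") + "/" + relative + ".out"
```

For example, `test_0.tif` uploaded under `sigray-ml-batch-input` should produce `sigray-ml-batch-output/test_0.tif.out`, which can then be fetched with `sagemaker.s3.S3Downloader.download`.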

## 7. Monitor and Analyze Results

In [None]:
# Get training job metrics
import boto3
from datetime import datetime, timedelta

cloudwatch = boto3.client('cloudwatch', region_name=region)

# Get training job name
training_job_name = estimator.latest_training_job.name
print(f"Training job name: {training_job_name}")

# Get CPU utilization metrics. Training-job instance metrics live in the
# '/aws/sagemaker/TrainingJobs' namespace, keyed by a 'Host' dimension of the
# form '<training-job-name>/algo-<n>' (algo-1 for a single-instance job).
try:
    response = cloudwatch.get_metric_statistics(
        Namespace='/aws/sagemaker/TrainingJobs',
        MetricName='CPUUtilization',
        Dimensions=[
            {
                'Name': 'Host',
                'Value': f'{training_job_name}/algo-1'
            }
        ],
        StartTime=datetime.utcnow() - timedelta(hours=2),
        EndTime=datetime.utcnow(),
        Period=300,
        Statistics=['Average']
    )
    
    if response['Datapoints']:
        avg_cpu = sum(dp['Average'] for dp in response['Datapoints']) / len(response['Datapoints'])
        # CPUUtilization is summed across cores, so it can exceed 100%
        print(f"Average CPU utilization: {avg_cpu:.2f}% (summed across cores)")
    else:
        print("No CPU metrics available yet")
        
except Exception as e:
    print(f"Could not retrieve metrics: {e}")
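`get_metric_statistics` does not return datapoints in time order, and training-job `CPUUtilization` is the sum of per-core utilization (0 to 100 × vCPUs). A sketch that orders the datapoints and rescales them to a whole-instance 0 to 100% figure; the vCPU count is an input you supply (an `ml.p3.2xlarge` has 8 vCPUs), and `summarize_cpu` is hypothetical.

```python
def summarize_cpu(datapoints, vcpus):
    """Sort CloudWatch datapoints by time and rescale 'Average' to % of the whole instance.

    Returns (series, mean), where series is a list of (timestamp, percent) pairs and
    mean is None when there are no datapoints.
    """
    ordered = sorted(datapoints, key=lambda dp: dp["Timestamp"])
    series = [(dp["Timestamp"], dp["Average"] / vcpus) for dp in ordered]
    mean = sum(value for _, value in series) / len(series) if series else None
    return series, mean
```

For example, `series, mean = summarize_cpu(response['Datapoints'], vcpus=8)`.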

In [None]:
# List the training job's output artifacts in S3
import boto3
import json

s3 = boto3.client('s3')

# Parse model data S3 path
model_s3_path = estimator.model_data
bucket_name = model_s3_path.split('/')[2]
key_prefix = '/'.join(model_s3_path.split('/')[3:-1])

print(f"Model artifacts location: s3://{bucket_name}/{key_prefix}/")

# List available artifacts
try:
    response = s3.list_objects_v2(
        Bucket=bucket_name,
        Prefix=key_prefix
    )
    
    print("Available artifacts:")
    for obj in response.get('Contents', []):
        print(f"  {obj['Key']}")
        
except Exception as e:
    print(f"Could not list artifacts: {e}")
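The listing above shows only S3 keys; the trained weights and any summary files live inside `model.tar.gz`. A sketch for downloading the archive with boto3's `download_file` and listing its members with `tarfile`. What the archive contains depends on `sagemaker_train.py`, so no particular file names are assumed; both helpers are hypothetical.

```python
import tarfile
from pathlib import Path
from urllib.parse import urlparse


def download_model_archive(s3_client, model_data_uri, dest_dir):
    """Download model.tar.gz (e.g. estimator.model_data) into dest_dir; return the local path."""
    parsed = urlparse(model_data_uri)
    bucket, key = parsed.netloc, parsed.path.lstrip("/")
    local_path = Path(dest_dir) / Path(key).name
    s3_client.download_file(bucket, key, str(local_path))
    return local_path


def list_archive_members(archive_path):
    """Return (name, size_in_bytes) for every regular file in a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as archive:
        return [(m.name, m.size) for m in archive.getmembers() if m.isfile()]
```

For example, `list_archive_members(download_model_archive(s3, estimator.model_data, temp_dir))`.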

## 8. Cleanup Resources

In [None]:
# Delete endpoint to avoid charges
print("🧹 Cleaning up resources...")

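try:
    # delete_endpoint() also removes the endpoint configuration by default;
    # the SageMaker model created by deploy() has to be deleted separately.
    predictor.delete_endpoint()
    print("✅ Endpoint deleted")
    predictor.delete_model()
    print("✅ Model deleted")
except Exception as e:
    print(f"Error deleting endpoint or model: {e}")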

# Clean up local temporary files
try:
    shutil.rmtree(temp_dir)
    print("✅ Temporary files cleaned up")
except Exception as e:
    print(f"Error cleaning up temp files: {e}")

print("🎉 Cleanup completed!")
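Deleting the endpoint does not remove the S3 data this notebook created (training data, checkpoints, batch input and output, model artifacts). A sketch for deleting everything under a prefix with boto3's `list_objects_v2` paginator and `delete_objects`, which accepts at most 1,000 keys per call. `delete_prefix` and `chunked` are hypothetical helpers; double-check each prefix before running it, since deletion is permanent.

```python
def chunked(items, size=1000):
    """Yield successive lists of at most `size` items (delete_objects takes up to 1,000 keys)."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def delete_prefix(s3_client, bucket, prefix):
    """Delete every object under s3://bucket/prefix and return how many were removed."""
    paginator = s3_client.get_paginator("list_objects_v2")
    keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
    for batch in chunked(keys):
        s3_client.delete_objects(
            Bucket=bucket,
            Delete={"Objects": [{"Key": key} for key in batch], "Quiet": True},
        )
    return len(keys)
```

For example, `delete_prefix(boto3.client('s3'), bucket, 'sigray-ml-checkpoints/')` removes the spot-training checkpoints.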

## Summary

In this notebook, we demonstrated:

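1. ✅ **Environment Setup** - Installed the Sigray ML Platform in the notebook environment and initialized a SageMaker session
2. ✅ **Data Preparation** - Generated synthetic input/target volume pairs and uploaded them to S3
3. ✅ **Training Job** - Ran a single-instance GPU training job (`ml.p3.2xlarge`) with managed spot instances and S3 checkpointing
4. ✅ **Real-time Inference** - Deployed the trained model as a SageMaker endpoint and sent it a test request
5. ✅ **Batch Processing** - Enhanced multiple TIFF volumes with a batch transform job
6. ✅ **Monitoring** - Retrieved CPU utilization from CloudWatch and listed the training artifacts in S3
7. ✅ **Cleanup** - Deleted the endpoint and local temporary files to avoid ongoing charges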

### Next Steps

- Scale up with larger datasets and longer training
- Experiment with different hyperparameters
- Set up automated pipelines with SageMaker Pipelines
- Implement A/B testing for model versions
- Add custom metrics and monitoring

### Cost Optimization Tips

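- Use managed spot training (AWS cites savings of up to 90% over on-demand instances), paired with checkpointing via `checkpoint_s3_uri` so interrupted jobs can resume
- Choose appropriate instance types for your workload
- Delete endpoints when not in use; they are billed for every hour they run, even when idle
- Use batch transform for offline or large-scale inference; its instances run only for the duration of the job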
- Monitor usage with AWS Cost Explorer