# TorchServe Local Development and Testing

## Overview

This notebook demonstrates how to run **TorchServe locally** for development and testing of PyTorch model deployments. This is an essential workflow for validating model serving behavior before deploying to production environments.

### What This Notebook Does

- Downloads a pre-built Model Archive (.mar) file from Google Cloud Storage
- Installs TorchServe and its dependencies locally
- Starts a local TorchServe instance with the model
- Tests predictions via the REST API
- Demonstrates management and metrics endpoints
- Shows proper shutdown procedures

### Prerequisites

Before running this notebook, you should have:
- Completed the `../pytorch-autoencoder.ipynb` notebook to create and export the model archive
- A `.mar` file stored in Google Cloud Storage
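- A Python 3.8+ environment
- A Java runtime on the `PATH` (TorchServe's frontend is a Java process; the required JDK version depends on the TorchServe release)
- An authenticated `gcloud` CLI with a default project set (the setup cell reads the project ID from it)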
- Sufficient local compute resources (CPU/GPU)

### What is TorchServe?

[TorchServe](https://pytorch.org/serve/) is PyTorch's official model serving framework that provides:
- Production-ready REST and gRPC APIs for model inference
- Multi-model serving with dynamic model loading
- Built-in metrics and logging
- Custom preprocessing and postprocessing via handlers
- Support for batching and versioning

### Benefits of Local Testing

Running TorchServe locally allows you to:
- **Validate** model behavior before cloud deployment
- **Debug** custom handlers and preprocessing logic
- **Test** prediction endpoints with sample data
- **Iterate** quickly without cloud deployment overhead
- **Understand** TorchServe configuration and APIs

### Architecture

```
┌─────────────────────────────────────────┐
│         Local Development               │
│                                         │
│  ┌───────────────────────────────────┐ │
│  │      TorchServe Instance          │ │
│  │                                   │ │
│  │  ┌─────────────────────────────┐ │ │
│  │  │   Inference API :8080       │ │ │
│  │  │   - POST /predictions       │ │ │
│  │  └─────────────────────────────┘ │ │
│  │                                   │ │
│  │  ┌─────────────────────────────┐ │ │
│  │  │   Management API :8081      │ │ │
│  │  │   - GET /models             │ │ │
│  │  └─────────────────────────────┘ │ │
│  │                                   │ │
│  │  ┌─────────────────────────────┐ │ │
│  │  │   Metrics API :8082         │ │ │
│  │  │   - GET /metrics            │ │ │
│  │  └─────────────────────────────┘ │ │
│  │                                   │ │
│  │  ┌─────────────────────────────┐ │ │
│  │  │   Model Store               │ │ │
│  │  │   - model.mar               │ │ │
│  │  └─────────────────────────────┘ │ │
│  └───────────────────────────────────┘ │
│                                         │
│  ┌───────────────────────────────────┐ │
│  │   Jupyter Notebook Client         │ │
│  │   - HTTP requests to APIs         │ │
│  └───────────────────────────────────┘ │
└─────────────────────────────────────────┘
```

### What You Will Learn

1. How to install and configure TorchServe locally
2. How to load model archives into TorchServe
3. How to make predictions via the REST API
4. How to use management APIs to inspect models
5. How to access metrics and monitoring data
6. Best practices for local model serving development

---
## Environment Setup

In [None]:
PROJECT_ID = !gcloud config get-value project
PROJECT_ID = PROJECT_ID[0]
print(f"Project ID: {PROJECT_ID}")

In [None]:
SERIES = "frameworks"
EXPERIMENT = "pytorch-autoencoder"
REGION = "us-central1"

# Model archive settings
MODEL_NAME = "pytorch_autoencoder"
MAR_FILE = "model.mar"
GCS_PATH = f"gs://{PROJECT_ID}/{SERIES}/{EXPERIMENT}/{MAR_FILE}"

# TorchServe settings
INFERENCE_PORT = 8080
MANAGEMENT_PORT = 8081
METRICS_PORT = 8082

print(f"Series: {SERIES}")
print(f"Experiment: {EXPERIMENT}")
print(f"Model Archive: {GCS_PATH}")

In [None]:
# Create working directory for model store
import os

MODEL_STORE_DIR = "./model_store"
os.makedirs(MODEL_STORE_DIR, exist_ok=True)
print(f"Model store directory: {MODEL_STORE_DIR}")

---
## Python Setup

In [None]:
import json
import subprocess
import time
import requests
import numpy as np
from google.cloud import storage

In [None]:
# API endpoints
INFERENCE_URL = f"http://localhost:{INFERENCE_PORT}/predictions/{MODEL_NAME}"
MANAGEMENT_URL = f"http://localhost:{MANAGEMENT_PORT}"
METRICS_URL = f"http://localhost:{METRICS_PORT}/metrics"

print(f"Inference URL: {INFERENCE_URL}")
print(f"Management URL: {MANAGEMENT_URL}")
print(f"Metrics URL: {METRICS_URL}")

---
## Download Model Archive

Download the pre-built `.mar` file from Google Cloud Storage to the local model store.
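
The download step needs the bucket name and the object path from a `gs://` URI. A minimal sketch of that split, with validation for malformed input, might look like the following. The helper name `split_gcs_uri` and the bucket name in the example are hypothetical and not part of the notebook.

```python
from typing import Tuple


def split_gcs_uri(uri: str) -> Tuple[str, str]:
    """Split 'gs://bucket/path/to/object' into (bucket, object_path)."""
    prefix = "gs://"
    if not uri.startswith(prefix):
        raise ValueError(f"Not a gs:// URI: {uri!r}")
    # Everything up to the first '/' is the bucket; the rest is the object path.
    bucket, _, object_path = uri[len(prefix):].partition("/")
    if not bucket or not object_path:
        raise ValueError(f"URI must name both a bucket and an object: {uri!r}")
    return bucket, object_path


# Example with a made-up bucket name.
print(split_gcs_uri("gs://my-project/frameworks/pytorch-autoencoder/model.mar"))
```

Failing early on a malformed URI gives a clearer error than a later "bucket not found" from the storage client.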

In [None]:
# Download model archive from GCS
storage_client = storage.Client(project=PROJECT_ID)

# Parse GCS path
gcs_path_parts = GCS_PATH.replace("gs://", "").split("/")
bucket_name = gcs_path_parts[0]
blob_path = "/".join(gcs_path_parts[1:])

# Download file
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_path)
local_mar_path = os.path.join(MODEL_STORE_DIR, MAR_FILE)

print(f"Downloading {GCS_PATH} to {local_mar_path}...")
blob.download_to_filename(local_mar_path)
print("Downloaded successfully!")
print(f"File size: {os.path.getsize(local_mar_path) / 1024:.2f} KB")
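
Before starting the server, it can help to confirm that the download is a valid archive. A `.mar` file built by `torch-model-archiver` in its default format is a ZIP archive that carries a `MAR-INF/MANIFEST.json` entry. The sketch below lists the archive contents and loads the manifest. The helper name `read_mar_manifest` is hypothetical, and the manifest's fields are not assumed here because they depend on how the archive was built.

```python
import json
import zipfile


def read_mar_manifest(mar_path):
    """Return (file_names, manifest) for a .mar archive.

    manifest is None when MAR-INF/MANIFEST.json is absent.
    Raises zipfile.BadZipFile if the file is not a ZIP archive.
    """
    with zipfile.ZipFile(mar_path) as archive:
        names = archive.namelist()
        manifest = None
        if "MAR-INF/MANIFEST.json" in names:
            with archive.open("MAR-INF/MANIFEST.json") as fh:
                manifest = json.load(fh)
    return names, manifest


# Usage in the notebook (path taken from the download cell):
# names, manifest = read_mar_manifest(local_mar_path)
# print(names)
# print(json.dumps(manifest, indent=2))
```

A truncated or corrupted download surfaces here as `zipfile.BadZipFile`, which is easier to diagnose than a model load failure inside TorchServe.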

### Install TorchServe

Install TorchServe and the model archiver tool.

In [None]:
# Install TorchServe and dependencies
!pip install -q torchserve torch-model-archiver torch-workflow-archiver

# Verify installation
!torchserve --version

---
## Start TorchServe

Start a local TorchServe instance with the downloaded model archive.

**Command options:**
- `--start`: Start the server
- `--model-store`: Directory containing model archives
- `--models`: Model to load at startup (format: `model_name=archive.mar`)
- `--ncs`: No config snapshot (simplifies local development)

The server will expose three APIs:
- **Inference API** (port 8080): For predictions
- **Management API** (port 8081): For model management
- **Metrics API** (port 8082): For monitoring metrics
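
The three ports above are TorchServe's defaults. If one of them is already in use locally, the addresses can be overridden in a `config.properties` file through the `inference_address`, `management_address`, and `metrics_address` keys. The sketch below only renders such a file. The helper name `render_ts_config` is hypothetical, and the notebook itself does not use a config file.

```python
def render_ts_config(inference_port=8080, management_port=8081,
                     metrics_port=8082, host="127.0.0.1"):
    """Render config.properties text that binds the three TorchServe APIs."""
    lines = [
        f"inference_address=http://{host}:{inference_port}",
        f"management_address=http://{host}:{management_port}",
        f"metrics_address=http://{host}:{metrics_port}",
    ]
    return "\n".join(lines) + "\n"


# Example: move only the inference API to port 9080.
print(render_ts_config(inference_port=9080))
```

To apply it, write the text to `config.properties` and add `--ts-config config.properties` to the start command. The `INFERENCE_PORT`, `MANAGEMENT_PORT`, and `METRICS_PORT` variables would then need to match.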

In [None]:
# Start TorchServe
start_command = [
    "torchserve",
    "--start",
    "--model-store", MODEL_STORE_DIR,
    "--models", f"{MODEL_NAME}={MAR_FILE}",
    "--ncs"
]

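print(f"Starting TorchServe with command: {' '.join(start_command)}")
# The launcher returns once the server process has been spawned. Output goes to
# a log file rather than to pipes: the background server inherits these handles,
# and capture_output=True can block until the server itself exits.
with open("torchserve_start.log", "w") as log_file:
    result = subprocess.run(start_command, stdout=log_file, stderr=subprocess.STDOUT)
print(f"Launcher exit code: {result.returncode} (output in torchserve_start.log)")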

# Wait for server to start
print("\nWaiting for TorchServe to start...")
time.sleep(10)

# Check if server is running (the timeout keeps the cell from hanging)
try:
    response = requests.get(f"{MANAGEMENT_URL}/models", timeout=5)
    if response.status_code == 200:
        print("TorchServe started successfully!")
        print(f"Loaded models: {response.json()}")
    else:
        print(f"Server responded with status {response.status_code}")
except requests.exceptions.RequestException as e:
    print(f"Error checking server status: {e}")
    print("Server may still be starting up...")
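
A fixed sleep is either too long or too short, depending on the machine and the model size. One alternative is to poll TorchServe's health check, `GET /ping` on the inference port, until it answers or a deadline passes. The sketch below uses only the standard library. The helper names `ping` and `wait_until_ready` and the default timings are assumptions.

```python
import time
import urllib.error
import urllib.request


def ping(url, timeout=2.0):
    """Return True if GET url answers HTTP 200, False on any connection error."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(probe, max_wait=60.0, interval=1.0):
    """Call probe() until it returns True or max_wait seconds have elapsed."""
    deadline = time.monotonic() + max_wait
    while True:
        if probe():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


# Usage in the notebook, after issuing the start command:
# ready = wait_until_ready(lambda: ping(f"http://localhost:{INFERENCE_PORT}/ping"))
# print("TorchServe is ready" if ready else "TorchServe did not become ready in time")
```

Passing the probe as a callable keeps the retry loop independent of the endpoint, so the same loop can wait for a specific model by probing the management API instead.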

### Test Predictions

Make test predictions using the inference API. We'll send sample transaction data with 30 features.

In [None]:
# Create sample input data (30 features)
# Random standard-normal values stand in for a real transaction
sample_transaction = np.random.randn(30).astype(np.float32)

print("Sample transaction data:")
print(sample_transaction)
print(f"\nShape: {sample_transaction.shape}")
print(f"Data type: {sample_transaction.dtype}")

# Make prediction request
print(f"\nSending prediction request to {INFERENCE_URL}...")

# Prepare request data
data = json.dumps({"instances": sample_transaction.tolist()})

headers = {"Content-Type": "application/json"}

try:
    response = requests.post(INFERENCE_URL, data=data, headers=headers, timeout=30)
    
    print(f"\nResponse status: {response.status_code}")
    
    if response.status_code == 200:
        prediction = response.json()
        print("\nPrediction successful!")
        print(f"Response: {json.dumps(prediction, indent=2)}")
        
        # Parse and display reconstruction
        if isinstance(prediction, dict) and "predictions" in prediction:
            reconstructed = np.array(prediction["predictions"])
            print(f"\nReconstructed shape: {reconstructed.shape}")
            
            # Calculate reconstruction error
            mse = np.mean((sample_transaction - reconstructed.flatten())**2)
            print(f"Reconstruction MSE: {mse:.6f}")
    else:
        print(f"Error: {response.text}")
        
except Exception as e:
    print(f"Error making prediction: {e}")
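
The reconstruction error reported above is the mean squared error between the input and the autoencoder's output, $\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}(x_i - \hat{x}_i)^2$. A plain-Python sketch of the same calculation with an explicit length check is shown below. The function name `reconstruction_mse` is hypothetical. The check is there because NumPy broadcasting can, in some cases such as a length-1 output, hide a mismatch between the request and the response.

```python
def reconstruction_mse(original, reconstructed):
    """Mean squared error between two equal-length sequences of numbers."""
    if len(original) != len(reconstructed):
        raise ValueError(
            f"Length mismatch: {len(original)} inputs vs {len(reconstructed)} outputs"
        )
    if len(original) == 0:
        raise ValueError("Cannot compute MSE of empty sequences")
    squared_errors = ((x - y) ** 2 for x, y in zip(original, reconstructed))
    return sum(squared_errors) / len(original)


# Worked example: squared errors are 0, 0 and 4, so the result is 4 / 3.
print(reconstruction_mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```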

### Management API

Explore the management API to inspect loaded models and their details.

In [None]:
# List all models
print("Listing all models...")
response = requests.get(f"{MANAGEMENT_URL}/models")
print(f"Status: {response.status_code}")
print(f"Models: {json.dumps(response.json(), indent=2)}")

# Get detailed model information
print(f"\n\nGetting details for model '{MODEL_NAME}'...")
response = requests.get(f"{MANAGEMENT_URL}/models/{MODEL_NAME}")
print(f"Status: {response.status_code}")
model_info = response.json()
print(f"Model info: {json.dumps(model_info, indent=2)}")

# Get details for all registered versions of the model
# (worker status is already part of the per-model response above)
print(f"\n\nGetting all versions of model '{MODEL_NAME}'...")
response = requests.get(f"{MANAGEMENT_URL}/models/{MODEL_NAME}/all")
print(f"Status: {response.status_code}")
versions_info = response.json()
print(f"All versions: {json.dumps(versions_info, indent=2)}")
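
The describe-model response is verbose, so a compact summary of worker state can be easier to read. The sketch below assumes the documented response shape: a JSON array with one entry per model version, each with `modelName`, `modelVersion`, and a `workers` list whose items have a `status` field. The helper name `summarize_workers` is hypothetical and the sample payload is made up for illustration.

```python
def summarize_workers(describe_response):
    """Map 'modelName:modelVersion' to the list of worker statuses."""
    summary = {}
    for entry in describe_response:
        key = f"{entry.get('modelName', '?')}:{entry.get('modelVersion', '?')}"
        # Missing fields fall back to placeholders rather than raising.
        summary[key] = [w.get("status", "UNKNOWN") for w in entry.get("workers", [])]
    return summary


# Made-up sample in the assumed shape, not real server output.
sample = [
    {
        "modelName": "pytorch_autoencoder",
        "modelVersion": "1.0",
        "workers": [{"id": "9000", "status": "READY"}],
    }
]
print(summarize_workers(sample))
```

In the notebook this would be called as `summarize_workers(model_info)`. A version whose list is empty or contains no `READY` worker is not yet able to serve predictions.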

### Metrics

Access the metrics endpoint to view performance and usage statistics.

In [None]:
# Get metrics
print("Fetching metrics...")
response = requests.get(METRICS_URL)
print(f"Status: {response.status_code}")

if response.status_code == 200:
    metrics = response.text
    print("\nMetrics (Prometheus format):")
    print("="*80)
    
    # Display relevant metrics
    metric_lines = metrics.splitlines()
    for line in metric_lines:
        # Keep model-specific lines and TorchServe's ts_-prefixed metrics
        if MODEL_NAME in line or 'ts_' in line:
            print(line)
    
    print("="*80)
    print(f"\nTotal metrics lines: {len(metric_lines)}")
else:
    print(f"Error fetching metrics: {response.text}")
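
The metrics endpoint returns plain text in the Prometheus exposition format: lines that start with `#` are `HELP` or `TYPE` comments, and every other line is a series name (with optional labels) followed by a value. A minimal parsing sketch is shown below. It assumes samples carry no trailing timestamp, the helper name `parse_prometheus_text` is hypothetical, and the sample text is made up for illustration.

```python
def parse_prometheus_text(text):
    """Return a list of (series, value) pairs from Prometheus exposition text."""
    samples = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and HELP/TYPE comments
        # The value is the last space-separated token (no timestamp assumed).
        series, _, value = line.rpartition(" ")
        if not series:
            continue
        try:
            samples.append((series.strip(), float(value)))
        except ValueError:
            continue  # not a numeric sample; ignore the line
    return samples


# Made-up sample text, not real server output.
sample_text = """\
# HELP ts_inference_requests_total Total number of inference requests.
# TYPE ts_inference_requests_total counter
ts_inference_requests_total{model_name="pytorch_autoencoder",model_version="default"} 3.0
"""
for series, value in parse_prometheus_text(sample_text):
    print(series, value)
```

In the notebook, `parse_prometheus_text(response.text)` would turn the raw text into values that can be compared before and after sending predictions.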

### Stop TorchServe

Gracefully shutdown the TorchServe instance.

In [None]:
# Stop TorchServe
print("Stopping TorchServe...")
result = subprocess.run(["torchserve", "--stop"], capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("Stderr:", result.stderr)

# Wait for shutdown
time.sleep(5)

# Verify server is stopped
try:
    response = requests.get(f"{MANAGEMENT_URL}/models", timeout=2)
    print("Warning: Server is still responding and may not have stopped")
except requests.exceptions.RequestException:
    print("Confirmed: Server is not responding (stopped successfully)")

---
## Summary and Next Steps

### What You Accomplished

In this notebook, you successfully:

1. **Downloaded** a model archive (.mar) from Google Cloud Storage
2. **Installed** TorchServe and its dependencies locally
3. **Started** a local TorchServe instance with your model
4. **Made predictions** via the REST inference API
5. **Explored** the management API to inspect model details
6. **Accessed** metrics for monitoring model performance
7. **Stopped** TorchServe gracefully

### Key Takeaways

- **TorchServe provides three APIs**: inference (8080), management (8081), and metrics (8082)
- **Model archives (.mar)** package the model, handler, and dependencies together
- **Local testing** enables rapid iteration before cloud deployment
- **REST APIs** make it easy to integrate with any client application
- **Metrics** help you understand model performance and usage patterns

### Next Steps

Now that you've tested TorchServe locally, you can:

1. **Deploy to Vertex AI**: Use the model archive in a custom container deployment
   - See the `torchserve-vertex.ipynb` notebook for Vertex AI deployment
   
2. **Customize the handler**: Modify preprocessing/postprocessing logic
   - Edit the handler in your model training notebook
   - Rebuild the .mar file and test locally
   
3. **Performance testing**: Load test your model locally
   - Use tools like ApacheBench (`ab`) or `locust` for load testing
   - Adjust worker counts and batch sizes
   
4. **Production deployment**: Deploy to production environments
   - Use Vertex AI Prediction for managed serving
   - Or deploy to GKE for full control

### Additional Resources

- [TorchServe Documentation](https://pytorch.org/serve/)
- [TorchServe REST API](https://pytorch.org/serve/rest_api.html)
- [Custom Handlers Guide](https://pytorch.org/serve/custom_service.html)
- [Vertex AI Custom Container Deployment](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements)