# Model Deployment with PyTorch

This notebook demonstrates various methods for deploying PyTorch models for production use.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import time
import os

# Seed the PyTorch and NumPy generators so repeated runs are comparable
for seed_everything in (torch.manual_seed, np.random.seed):
    seed_everything(42)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

# Every artifact produced by this notebook is written under this directory
output_dir = "deployment_outputs"
os.makedirs(output_dir, exist_ok=True)

## 1. Training a Simple Model

First, let's train a simple CNN on MNIST that we'll use for deployment examples.

In [None]:
class SimpleConvNet(nn.Module):
    """Two-stage convolutional classifier for single-channel MNIST digits.

    Each 3x3 convolution uses padding 1 (spatial size preserved) and is
    followed by ReLU and a 2x2 max-pool. The classifier head expects
    64 feature maps of 7x7, i.e. a 28x28 input image.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Classifier head
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        # Parameter-free layers reused by the forward pass
        self.pool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the raw class scores (logits), one row of 10 per input image."""
        hidden = F.relu(self.conv1(x))
        hidden = self.pool(hidden)
        hidden = F.relu(self.conv2(hidden))
        hidden = self.pool(hidden)
        # Collapse channels and spatial dims into one feature vector per sample
        flat = hidden.view(hidden.size(0), -1)
        flat = F.relu(self.fc1(flat))
        flat = self.dropout(flat)
        return self.fc2(flat)

In [None]:
# MNIST preprocessing: convert to a tensor, then standardize with a fixed mean/std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits share the same storage location and preprocessing
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Shuffle only the training split; evaluation uses large fixed-order batches
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1000)

In [None]:
# Train the model
model = SimpleConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Training model...")
model.train()
for epoch in range(2):  # Quick training for demo
    running_loss = 0.0
    num_batches = 0  # batches actually processed this epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100:  # Limit for demo (batches 0..100, i.e. 101 batches)
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Average over the number of batches that really ran; the loop above
    # processes 101 batches, so dividing by a hard-coded 100 overstated the loss.
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1):.4f}")

# Save the model
torch.save(model.state_dict(), os.path.join(output_dir, "simple_model.pth"))
print("Model trained and saved!")

In [None]:
# Measure top-1 accuracy over the test loader
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        # Index of the largest score along the class dimension
        predicted = torch.max(logits.data, 1).indices
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

accuracy_pct = 100 * correct / total
print(f'Test Accuracy: {accuracy_pct:.2f}%')

## 2. TorchScript - Tracing and Scripting

TorchScript provides two ways to convert PyTorch models for deployment: tracing and scripting.

### 2.1 Tracing

In [None]:
# Tracing: Records operations as they are executed on the example input.
# Only the operations that actually run are captured, so data-dependent
# Python control flow would be frozen to the path taken here.
model.eval()
example_input = torch.randn(1, 1, 28, 28).to(device)

# Trace the model
traced_model = torch.jit.trace(model, example_input)

# Save traced model
traced_path = os.path.join(output_dir, "traced_model.pt")
traced_model.save(traced_path)
print(f"Traced model saved to {traced_path}")

# Test traced model. This is inference only, so run it under no_grad like the
# other inference cells instead of building an autograd graph for both outputs.
with torch.no_grad():
    output_original = model(example_input)
    output_traced = traced_model(example_input)
print(f"Output difference: {torch.max(torch.abs(output_original - output_traced)).item()}")

### 2.2 Scripting

In [None]:
# Scripting: compiles the module from its Python source instead of a recorded run
scripted_path = os.path.join(output_dir, "scripted_model.pt")
scripted_model = torch.jit.script(model)

# Persist the scripted module next to the other artifacts
scripted_model.save(scripted_path)
print(f"Scripted model saved to {scripted_path}")

# Show the TorchScript generated for forward()
print("\nTorchScript code:", scripted_model.code, sep="\n")

### 2.3 Performance Comparison

In [None]:
# Compare inference speeds
test_input = torch.randn(100, 1, 28, 28).to(device)


def _time_forward_passes(net, inputs, iterations=100):
    """Return the wall-clock seconds taken by `iterations` forward passes.

    Args:
        net: Callable model (eager or TorchScript) to benchmark.
        inputs: Batch passed unchanged to every call.
        iterations: Number of forward passes to run.

    CUDA kernels are launched asynchronously, so when the inputs live on a GPU
    the device is synchronized before the clock starts and before it stops;
    otherwise only the launch overhead would be measured.
    """
    on_gpu = inputs.is_cuda
    with torch.no_grad():
        if on_gpu:
            torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, high-resolution timer
        for _ in range(iterations):
            net(inputs)
        if on_gpu:
            torch.cuda.synchronize()
        return time.perf_counter() - start


# Warm up under the same no_grad conditions as the timed runs so that one-time
# costs (allocator warm-up, TorchScript optimization) are not attributed to them
_time_forward_passes(model, test_input, iterations=10)
_time_forward_passes(traced_model, test_input, iterations=10)

# Time original model
original_time = _time_forward_passes(model, test_input)

# Time traced model
traced_time = _time_forward_passes(traced_model, test_input)

print(f"Original model: {original_time:.4f}s")
print(f"Traced model: {traced_time:.4f}s")
print(f"Speedup: {original_time/traced_time:.2f}x")

## 3. ONNX Export

ONNX (Open Neural Network Exchange) allows models to be deployed across different frameworks.

In [None]:
# Export to ONNX
model.eval()
onnx_path = os.path.join(output_dir, "model.onnx")
dummy_input = torch.randn(1, 1, 28, 28).to(device)

# Mark dimension 0 of the input and output as symbolic (variable batch size)
dynamic_axes = {tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')}

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=dynamic_axes,
    opset_version=11,          # ONNX operator-set (opset) version
    export_params=True,        # store the trained parameters in the file
    do_constant_folding=True,  # fold constant sub-graphs at export time
)

onnx_size_kb = os.path.getsize(onnx_path) / 1024
print(f"Model exported to ONNX: {onnx_path}")
print(f"ONNX file size: {onnx_size_kb:.2f} KB")

In [None]:
# Verify ONNX model (if onnx is installed)
try:
    import onnx
    import onnxruntime as ort
except ImportError:
    # Only a missing package lands here; errors raised while checking or
    # running the model are no longer reported as "not installed".
    print("ONNX/ONNXRuntime not installed. Install with:")
    print("pip install onnx onnxruntime")
else:
    # Check model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model is valid!")

    # Run inference with ONNX Runtime. Name the execution provider explicitly:
    # builds that ship additional providers require the argument.
    ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

    # Prepare input
    input_name = ort_session.get_inputs()[0].name
    test_data = np.random.randn(1, 1, 28, 28).astype(np.float32)

    # Run inference
    ort_outputs = ort_session.run(None, {input_name: test_data})
    print(f"ONNX output shape: {ort_outputs[0].shape}")

    # Compare against the PyTorch model on the same input so the export is
    # verified numerically, not just structurally
    with torch.no_grad():
        torch_output = model(torch.from_numpy(test_data).to(device)).cpu().numpy()
    max_abs_diff = np.abs(torch_output - ort_outputs[0]).max()
    print(f"Max |PyTorch - ONNX Runtime| difference: {max_abs_diff:.6f}")

## 4. Model Quantization

Quantization reduces model size and improves inference speed by using lower precision representations.

In [None]:
# Create a quantization-friendly model
class QuantizableConvNet(nn.Module):
    """Module-only variant of the MNIST CNN, laid out for eager-mode quantization.

    Activations and pooling are ``nn.Module`` attributes instead of functional
    calls, and the forward pass is bracketed by a ``QuantStub`` /
    ``DeQuantStub`` pair. The parameterized layers (conv1, conv2, fc1, fc2)
    use the same names and shapes as ``SimpleConvNet``.
    """

    def __init__(self):
        super().__init__()
        # Entry and exit points of the quantized region
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        # Stage 1: 1 -> 32 channels, then halve the spatial size
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)

        # Stage 2: 32 -> 64 channels, then halve the spatial size again
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)

        # Classifier head
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(0.25)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class scores; the input and output pass through the quant/dequant stubs."""
        out = self.quant(x)

        out = self.pool1(self.relu1(self.conv1(out)))
        out = self.pool2(self.relu2(self.conv2(out)))

        # One flat feature vector per sample for the linear layers
        out = out.reshape(out.size(0), -1)
        out = self.dropout(self.relu3(self.fc1(out)))
        out = self.fc2(out)

        return self.dequant(out)

### 4.1 Dynamic Quantization

In [None]:
# Dynamic quantization: weights are converted to INT8 ahead of time,
# activations are quantized on the fly at inference
model_fp32 = QuantizableConvNet()
# Load the checkpoint onto the CPU regardless of the training device: the
# model being quantized lives on the CPU, as in the static-quantization cell.
model_fp32.load_state_dict(torch.load(os.path.join(output_dir, "simple_model.pth"),
                                      map_location='cpu'))
model_fp32.eval()

# Apply dynamic quantization
model_int8_dynamic = torch.quantization.quantize_dynamic(
    model_fp32,
    {nn.Linear},  # Quantize Linear layers
    dtype=torch.qint8
)

# Compare model sizes
def print_model_size(model, name):
    """Print and return the serialized size of a model's state dict in MB.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        name: Label used in the printed line.

    Returns:
        Size of the saved state dict in megabytes (bytes / 1e6).

    The state dict is written inside a private temporary directory, so an
    existing ``temp.p`` in the working directory is never overwritten and the
    scratch file is removed even if saving fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        size_mb = os.path.getsize(scratch_path) / 1e6
    print(f"{name} size: {size_mb:.2f} MB")
    return size_mb

# Serialized size before and after dynamic quantization
size_fp32 = print_model_size(model_fp32, "FP32 model")
size_int8 = print_model_size(model_int8_dynamic, "INT8 model")
dynamic_reduction_pct = (1 - size_int8/size_fp32)*100
print(f"Size reduction: {dynamic_reduction_pct:.1f}%")

### 4.2 Static Quantization

In [None]:
# Static quantization requires calibration
model_fp32_static = QuantizableConvNet()
model_fp32_static.load_state_dict(torch.load(os.path.join(output_dir, "simple_model.pth"),
                                             map_location='cpu'))  # Must be on CPU
model_fp32_static.eval()

# Pick a quantized backend that this PyTorch build actually provides instead of
# assuming 'fbgemm' (x86); builds without it (e.g. ARM) typically offer 'qnnpack'.
supported_engines = torch.backends.quantized.supported_engines
qengine = next((name for name in ('fbgemm', 'qnnpack') if name in supported_engines), None)
if qengine is None:
    raise RuntimeError(
        f"No supported quantized backend found (available: {supported_engines})"
    )
# The runtime engine must match the qconfig used to prepare the model
torch.backends.quantized.engine = qengine

# Set quantization config
model_fp32_static.qconfig = torch.quantization.get_default_qconfig(qengine)

# Prepare model for quantization (inserts observers that record activation ranges)
model_fp32_prepared = torch.quantization.prepare(model_fp32_static)

# Calibrate with representative data
print("Calibrating model...")
with torch.no_grad():
    for batch_idx, (data, _) in enumerate(test_loader):
        if batch_idx > 10:  # Use a few batches for calibration
            break
        model_fp32_prepared(data)

# Convert to quantized model
model_int8_static = torch.quantization.convert(model_fp32_prepared)
print("Static quantization complete!")

# Compare sizes
size_static = print_model_size(model_int8_static, "Static INT8 model")
print(f"Size reduction vs FP32: {(1 - size_static/size_fp32)*100:.1f}%")

### 4.3 Quantization Performance Comparison

In [None]:
# Compare inference speeds
test_input_cpu = torch.randn(100, 1, 28, 28)  # On CPU for quantized models

model_fp32.eval()

with torch.no_grad():
    # Warm up both models first (as the TorchScript benchmark does) so that
    # one-time setup costs are not charged to whichever model is timed first
    for _ in range(5):
        model_fp32(test_input_cpu)
        model_int8_static(test_input_cpu)

    # Time FP32 model
    start = time.perf_counter()  # monotonic, high-resolution timer
    for _ in range(50):
        _ = model_fp32(test_input_cpu)
    fp32_time = time.perf_counter() - start

    # Time static quantized model
    start = time.perf_counter()
    for _ in range(50):
        _ = model_int8_static(test_input_cpu)
    int8_time = time.perf_counter() - start

print(f"FP32 model: {fp32_time:.4f}s")
print(f"INT8 model: {int8_time:.4f}s")
print(f"Speedup: {fp32_time/int8_time:.2f}x")

# Visualize results: sizes on the left, latencies on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: serialized size of each model variant
size_by_model = {
    'FP32': size_fp32,
    'Dynamic INT8': size_int8,
    'Static INT8': size_static,
}
models = list(size_by_model.keys())
sizes = list(size_by_model.values())
ax1.bar(models, sizes)
ax1.set_ylabel('Size (MB)')
ax1.set_title('Model Size Comparison')

# Right panel: total time of the timed loops above
time_by_model = {'FP32': fp32_time, 'Static INT8': int8_time}
models_time = list(time_by_model.keys())
times = list(time_by_model.values())
ax2.bar(models_time, times)
ax2.set_ylabel('Time (seconds)')
ax2.set_title('Inference Time Comparison (50 iterations)')

plt.tight_layout()
plt.show()

## 5. Model Serving Example

Here's an example of how to create a simple REST API for model serving.

In [None]:
# Create a simple inference function
def create_inference_function(model_path, device='cpu'):
    """
    Build a single-image prediction callable around a saved TorchScript model.

    Args:
        model_path: Path to a TorchScript archive to load with ``torch.jit.load``.
        device: Device to run inference on (default ``'cpu'``). The archive is
            remapped onto this device at load time, so a model that was traced
            on a GPU can still be used when the input tensors are on the CPU.

    Returns:
        A function ``predict(image)`` that takes a PIL image and returns a dict
        with keys ``'prediction'`` (class index), ``'confidence'`` (softmax
        probability of that class) and ``'probabilities'`` (list of per-class
        softmax probabilities).
    """
    device = torch.device(device)

    # Load model onto the requested device
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    # Define preprocessing
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    def predict(image):
        """
        Make prediction on a PIL image.
        """
        # Preprocess. The network has a single input channel, so convert
        # RGB/RGBA/palette images to grayscale first (as the Flask example does).
        tensor = transform(image.convert('L')).unsqueeze(0).to(device)

        # Predict
        with torch.no_grad():
            output = model(tensor)
            probabilities = F.softmax(output, dim=1)
            prediction = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][prediction].item()

        return {
            'prediction': prediction,
            'confidence': confidence,
            'probabilities': probabilities[0].tolist()
        }

    return predict

# Try the inference helper on a synthetic image
from PIL import Image, ImageDraw

# Black 28x28 grayscale canvas with a white "7" drawn on it
img = Image.new('L', (28, 28), 0)
draw = ImageDraw.Draw(img)
draw.text((10, 5), '7', fill=255)

# Wrap the traced model saved earlier and classify the sample
predict_fn = create_inference_function(traced_path)
result = predict_fn(img)

prediction, confidence = result['prediction'], result['confidence']
print(f"Prediction: {prediction}")
print(f"Confidence: {confidence:.2%}")

In [None]:
# Flask API example code. The text below is written verbatim to
# <output_dir>/flask_app.py, next to traced_model.pt.
flask_api_code = '''
# save as flask_app.py, in the same directory as traced_model.pt
import base64
import io
import os

import torch
import torch.nn.functional as F
from flask import Flask, request, jsonify
from PIL import Image
from torchvision import transforms

app = Flask(__name__)

# Load model at startup. The path is resolved relative to this file so the
# server can be started from any working directory, and the weights are mapped
# to the CPU so a model traced on a GPU still loads on a CPU-only host.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "traced_model.pt")
model = torch.jit.load(MODEL_PATH, map_location="cpu")
model.eval()

# Define preprocessing
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

@app.route('/predict', methods=['POST'])
def predict():
    # Validate the request body before touching the model
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('image'), str):
        return jsonify({'error': "Expected a JSON object with a base64-encoded 'image' string"}), 400

    try:
        # Decode and preprocess
        image_bytes = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_bytes)).convert('L')
    except (ValueError, OSError) as e:
        # Malformed base64 or bytes that are not a readable image: client error
        return jsonify({'error': f'Invalid image: {e}'}), 400

    tensor = transform(image).unsqueeze(0)

    # Make prediction
    with torch.no_grad():
        output = model(tensor)
        probabilities = F.softmax(output, dim=1)
        prediction = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][prediction].item()

    return jsonify({
        'prediction': prediction,
        'confidence': confidence,
        'probabilities': probabilities[0].tolist()
    })

@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'healthy'})

if __name__ == '__main__':
    # Flask development server, listening on all interfaces. Use a production
    # WSGI server for real traffic.
    app.run(host='0.0.0.0', port=5000)
'''

# Save the Flask app code
flask_app_path = os.path.join(output_dir, 'flask_app.py')
with open(flask_app_path, 'w') as f:
    f.write(flask_api_code)

# Report the real location of the script so the instructions can be followed as-is
print(f"Flask API code saved to {flask_app_path}")
print("\nTo run the server:")
print("1. pip install flask pillow torch torchvision")
print(f"2. python {flask_app_path}")
print("3. Send POST requests to http://localhost:5000/predict")

## 6. Mobile Deployment

PyTorch Mobile allows models to run on iOS and Android devices.

In [None]:
# Optimize model for mobile
from torch.utils.mobile_optimizer import optimize_for_mobile

# Load the traced model onto the CPU. The trace may have been saved from a GPU
# session; remapping keeps the exported mobile model free of CUDA tensors.
mobile_model = torch.jit.load(traced_path, map_location='cpu')
mobile_model.eval()

# Optimize for mobile
optimized_model = optimize_for_mobile(mobile_model)

# Save for mobile (lite-interpreter format)
mobile_path = os.path.join(output_dir, "model_mobile.ptl")
optimized_model._save_for_lite_interpreter(mobile_path)
print(f"Mobile model saved to {mobile_path}")
print(f"Mobile model size: {os.path.getsize(mobile_path) / 1024:.2f} KB")

In [None]:
# Example Android code (printed for reference; not executed by this notebook)
android_code = '''
// Android (Java) example
import android.graphics.Bitmap;
import android.graphics.Color;
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;

// Load model. Files written by _save_for_lite_interpreter (.ptl) must be
// opened with the lite-interpreter loader.
Module module = LiteModuleLoader.load(assetFilePath(this, "model_mobile.ptl"));

// Prepare input. The model expects a 1x1x28x28 grayscale tensor, whereas
// TensorImageUtils only builds 3-channel RGB tensors, so fill it by hand.
Bitmap bitmap = ... // Your input image
Bitmap scaled = Bitmap.createScaledBitmap(bitmap, 28, 28, true);
int[] pixels = new int[28 * 28];
scaled.getPixels(pixels, 0, 28, 0, 0, 28, 28);
float[] input = new float[28 * 28];
for (int i = 0; i < pixels.length; i++) {
    int p = pixels[i];
    // Luma conversion to [0, 1], then the same mean/std normalization as training
    float gray = (0.299f * Color.red(p) + 0.587f * Color.green(p) + 0.114f * Color.blue(p)) / 255.0f;
    input[i] = (gray - 0.1307f) / 0.3081f;
}
Tensor inputTensor = Tensor.fromBlob(input, new long[]{1, 1, 28, 28});

// Run inference
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();

// Find prediction
int maxIdx = 0;
float maxScore = scores[0];
for (int i = 1; i < scores.length; i++) {
    if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxIdx = i;
    }
}
'''

print("Android integration example:")
print(android_code)

## 7. Deployment Best Practices

Summary of key considerations for production deployment.

In [None]:
# Create a deployment checklist (Markdown text written to the output directory)
deployment_checklist = """
# PyTorch Model Deployment Checklist

## 1. Model Optimization
- [ ] Convert to TorchScript (tracing or scripting)
- [ ] Apply quantization if needed (dynamic or static)
- [ ] Optimize for target hardware (mobile, edge, server)
- [ ] Profile and benchmark performance

## 2. Input/Output Handling
- [ ] Document input format and preprocessing steps
- [ ] Implement input validation
- [ ] Define output format and post-processing
- [ ] Handle edge cases and errors gracefully

## 3. Serving Infrastructure
- [ ] Choose deployment method (REST API, gRPC, batch)
- [ ] Implement health checks and monitoring
- [ ] Set up logging and metrics collection
- [ ] Configure auto-scaling if needed

## 4. Testing
- [ ] Test on representative data
- [ ] Verify numerical accuracy after optimization
- [ ] Load test for expected traffic
- [ ] Test on target deployment hardware

## 5. Security
- [ ] Validate and sanitize all inputs
- [ ] Use HTTPS for API endpoints
- [ ] Implement authentication if needed
- [ ] Set up rate limiting

## 6. Maintenance
- [ ] Version your models
- [ ] Plan for model updates
- [ ] Monitor model performance over time
- [ ] Set up A/B testing framework
"""

# Save checklist
with open(os.path.join(output_dir, 'deployment_checklist.md'), 'w') as f:
    f.write(deployment_checklist)

print("Deployment checklist saved!")
print("\nKey deployment formats created:")
# os.listdir returns entries in arbitrary order; sort them so the listing is
# the same on every run and platform
for filename in sorted(os.listdir(output_dir)):
    size = os.path.getsize(os.path.join(output_dir, filename)) / 1024
    print(f"- {filename}: {size:.2f} KB")

## Summary

In this notebook, we've covered:

1. **TorchScript**: Converting models using tracing and scripting for production deployment
2. **ONNX Export**: Exporting models for cross-framework compatibility
3. **Quantization**: Reducing model size and improving speed with INT8 quantization
4. **Model Serving**: Creating REST APIs for model inference
5. **Mobile Deployment**: Optimizing models for iOS and Android
6. **Best Practices**: Key considerations for production deployment

Each method has its trade-offs:
- **TorchScript**: Best for PyTorch-native deployment with good performance
- **ONNX**: Best for cross-framework compatibility
- **Quantization**: Best for edge devices and mobile deployment
- **Mobile**: Purpose-built for iOS/Android, using the lightweight lite-interpreter runtime

Choose the right approach based on your deployment target and requirements!