# OpenAI-Compatible Batch Inference Demo

This notebook demonstrates how to use the batch inference API with OpenAI-compatible endpoints.

In [None]:
!pip install requests
!pip install -r ../requirements.dev

## 1. Import Libraries and Setup

In [None]:
import json
import logging
import os
import time

import requests
from IPython.display import JSON, display

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base URL of the batch inference API. Set BATCH_API_BASE_URL to target a server other
# than the default local one. A trailing slash is stripped because every request in this
# notebook appends a path such as "/v1/batches".
BASE_URL = os.environ.get("BATCH_API_BASE_URL", "http://localhost:8000").rstrip("/")

## 2. Start API Server

In [None]:
import atexit
import os
import subprocess
import sys
import time
from pathlib import Path

import requests
from IPython.display import display, Markdown

# Repository root containing api/main.py. Override with the PROJECT_ROOT environment
# variable. The default (parent of the notebook's working directory) assumes the notebook
# lives one level below the repo root, as the `../requirements.dev` install implies.
PROJECT_ROOT = Path(os.environ.get("PROJECT_ROOT", Path.cwd().parent)).resolve()
# File that receives the server's combined stdout/stderr (relative to the notebook).
SERVER_LOG = Path("api_server.log")
# Seconds to wait for the server to start accepting HTTP connections.
SERVER_STARTUP_TIMEOUT = 30.0


def start_api_server(project_root=None, log_path=None):
    """Launch ``api/main.py`` as a background process with ENVIRONMENT=DEV and no GPU.

    Args:
        project_root: directory containing ``api/main.py``; defaults to PROJECT_ROOT.
        log_path: file that receives the server's stdout and stderr; defaults to SERVER_LOG.

    Returns:
        The ``subprocess.Popen`` handle of the server process.

    Raises:
        FileNotFoundError: if ``api/main.py`` does not exist under the project root.
    """
    env = os.environ.copy()
    env["ENVIRONMENT"] = "DEV"
    env["GPU_AVAILABLE"] = "false"

    root = Path(project_root) if project_root is not None else PROJECT_ROOT
    script = root / "api" / "main.py"
    if not script.is_file():
        raise FileNotFoundError(f"{script} not found; set PROJECT_ROOT to the repository root")

    # Send output to a file rather than subprocess.PIPE. Nothing in this notebook reads the
    # pipes, so once the OS pipe buffer filled up, the server would block on its next write.
    log_file = open(log_path or SERVER_LOG, "w")
    try:
        return subprocess.Popen(
            [sys.executable, "api/main.py"],  # same interpreter/environment as the kernel
            cwd=root,
            env=env,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    finally:
        # The child process holds its own copy of the descriptor; the parent's can be closed.
        log_file.close()


def wait_for_server(process, base_url=BASE_URL, timeout=SERVER_STARTUP_TIMEOUT):
    """Block until ``base_url`` answers any HTTP request (even a 404 means it is up).

    Raises:
        RuntimeError: if the server process exits before becoming reachable.
        TimeoutError: if it is still unreachable after ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if process.poll() is not None:
            raise RuntimeError(
                f"API server exited with code {process.returncode}; "
                f"check the server log (default: {SERVER_LOG.resolve()})"
            )
        try:
            requests.get(base_url, timeout=1)
            return
        except requests.exceptions.RequestException:
            time.sleep(0.5)
    raise TimeoutError(f"API server not reachable at {base_url} after {timeout}s")


# Reuse a server started by an earlier run of this cell instead of spawning a second
# process that would compete for the same port.
existing_server = globals().get("server_process")
if existing_server is not None and existing_server.poll() is None:
    server_process = existing_server
else:
    print("Starting API server...")
    server_process = start_api_server()
    atexit.register(server_process.terminate)  # stop the server when the kernel shuts down
wait_for_server(server_process)
display(Markdown(f"Server started on {BASE_URL}"))

## 3. Load Sample Data

In [None]:
# sample_batch.jsonl holds one JSON object per non-blank line, each with a "prompt" field.
# Decode as UTF-8 explicitly rather than relying on the platform's default encoding.
with open('sample_batch.jsonl', 'r', encoding='utf-8') as f:
    prompts = [json.loads(line)['prompt'] for line in f if line.strip()]

# An empty batch would only fail later, at submission time, with a less obvious error.
assert prompts, "sample_batch.jsonl contains no prompts"

logger.info(f"Loaded {len(prompts)} sample prompts:")
for i, prompt in enumerate(prompts, 1):
    logger.info(f"{i}. {prompt}")

## 4. Submit Initial Batch Job

In [None]:
# One input item per prompt; the sampling parameters apply to the whole batch.
batch_request = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "input": [{"prompt": m} for m in prompts],
    "max_tokens": 100,
    "temperature": 0.7
}

logger.info("Submitting batch job...")
response = requests.post(
    f"{BASE_URL}/v1/batches",
    json=batch_request,
    timeout=30,  # requests has no default timeout; never let the kernel hang on a stuck server
)

if response.status_code != 200:
    logger.error(f"Failed to create batch: {response.status_code}")
    logger.error(response.text)
    # Stop here: the following cells all need batch_id. Continuing would either raise a
    # NameError or silently reuse the id left over from a previous run.
    raise RuntimeError(f"Batch submission failed with HTTP {response.status_code}")

data = response.json()
batch_id = data["id"]
logger.info(f"Batch created with ID: {batch_id}")
logger.info(f"Status: {data['status']}")
logger.info(f"Created at: {data['created_at']}")

## 5. Monitor Job Progress

In [None]:
# Statuses after which polling stops. "completed" and "failed" are what this notebook
# originally waited for. "expired" and "cancelled" are terminal states in the OpenAI batch
# API this server mirrors. NOTE(review): confirm the server actually emits them.
TERMINAL_STATUSES = {"completed", "failed", "expired", "cancelled"}


def check_job_status(batch_id, timeout=10):
    """Return the batch object for ``batch_id`` as a dict, or None on a non-200 reply.

    ``timeout`` (seconds) bounds the request. ``requests`` has no default timeout, so
    without one an unresponsive server would hang the notebook indefinitely.
    """
    response = requests.get(f"{BASE_URL}/v1/batches/{batch_id}", timeout=timeout)
    if response.status_code != 200:
        logger.error(f"Failed to get status: {response.status_code}")
        return None
    return response.json()


def wait_for_batch(batch_id, max_checks=10, poll_interval=2.0):
    """Poll ``batch_id`` until it reaches a terminal status or ``max_checks`` run out.

    Returns the last successfully fetched batch object, or None if every check failed.
    Logs a warning when no terminal status was observed.
    """
    last_seen = None
    for attempt in range(1, max_checks + 1):
        status_data = check_job_status(batch_id)
        if status_data:
            last_seen = status_data
            status = status_data["status"]
            logger.info(f"Check {attempt}: {status}")
            if status in TERMINAL_STATUSES:
                return status_data
        time.sleep(poll_interval)
    logger.warning("Timeout or error checking status")
    return last_seen


logger.info("Monitoring job progress...")
status_data = wait_for_batch(batch_id)

## 6. Retrieve Results

In [None]:
logger.info("Retrieving final results...")
response = requests.get(f"{BASE_URL}/v1/batches/{batch_id}/results", timeout=30)

if response.status_code == 200:
    data = response.json()
    results = data.get("data", [])
    logger.info(f"Retrieved {len(results)} results:")
    display(JSON(results))

    # The full list is rendered above; log only the first few results field by field.
    for i, result in enumerate(results[:3], 1):
        prompt = result.get("prompt", "")
        response_text = result.get("response", "")
        tokens = result.get("tokens", 0)
        logger.info(f"--- Result {i} ---")
        logger.info(f"Prompt: {prompt}")
        logger.info(f"Response: {response_text}")
        logger.info(f"Tokens: {tokens}")
else:
    logger.error(f"Failed to get results: {response.status_code}")
    logger.error(response.text)  # the server's error body explains why results are unavailable

## 7. Debug Queue Issues

In [None]:
# Check job status to see if it is still processing
response = requests.get(f"{BASE_URL}/v1/batches/{batch_id}")
if response.status_code == 200:
    data = response.json()
    print(f"Job ID: {data['id']}")
    print(f"Status: {data['status']}")
    print(f"Created at: {data['created_at']}")
    print(f"Completed at: {data.get('completed_at', 'Not completed yet')}")
else:
    print(f"Failed to get status: {response.status_code}")

In [None]:
# Check debug endpoints to see what is happening: background worker state first, then GPU pools.
response = requests.get(f"{BASE_URL}/debug/worker", timeout=10)
if response.status_code == 200:
    debug_info = response.json()
    print(f"Worker running: {debug_info['worker_running']}")
    print(f"Queue depth: {debug_info['queue_depth']}")
    print(f"Worker thread alive: {debug_info['worker_thread_alive']}")
else:
    # Include the status code: a 500 (endpoint broken) and a 404 (endpoint absent) need different fixes.
    print(f"Debug endpoint not available (HTTP {response.status_code})")

# Check GPU pool status
response = requests.get(f"{BASE_URL}/debug/gpu-pools", timeout=10)
if response.status_code == 200:
    pool_info = response.json()
    print(f"GPU pools: {pool_info}")
else:
    print(f"GPU debug endpoint not available (HTTP {response.status_code})")

In [None]:
import time

# Give the worker a few seconds, then check whether the queue has drained.
time.sleep(5)
response = requests.get(f"{BASE_URL}/debug/worker", timeout=10)
if response.status_code == 200:
    debug_info = response.json()
    print(f"Queue depth after 5 seconds: {debug_info['queue_depth']}")
    print(f"Worker running: {debug_info['worker_running']}")
    print(f"Worker thread alive: {debug_info['worker_thread_alive']}")
else:
    print(f"Debug endpoint not available (HTTP {response.status_code})")

In [None]:
import json
import os

# Snapshot the worker queue depth, submit a one-prompt batch, then snapshot again to see
# whether submission actually enqueues work.
response = requests.get(f"{BASE_URL}/debug/worker", timeout=10)
if response.status_code == 200:
    debug_info = response.json()
    print(f"Queue depth BEFORE: {debug_info['queue_depth']}")

# Submit a new job
batch_request = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct", 
    "input": [{"prompt": "Debug test prompt"}],
    "max_tokens": 50,
    "temperature": 0.7
}

print("Submitting new batch...")
# Track this submission separately. If it fails, the job-file check below is skipped
# instead of inspecting the batch_id left over from the earlier sections.
new_batch_id = None
response = requests.post(f"{BASE_URL}/v1/batches", json=batch_request, timeout=30)
if response.status_code == 200:
    data = response.json()
    new_batch_id = batch_id = data["id"]
    print(f"Submitted batch ID: {batch_id}")
else:
    print(f"Failed to submit: {response.status_code}")
    print(response.text)

# Check queue depth immediately after submission
response = requests.get(f"{BASE_URL}/debug/worker", timeout=10)
if response.status_code == 200:
    debug_info = response.json()
    print(f"Queue depth AFTER: {debug_info['queue_depth']}")

# Check if a job file was created for the new batch.
# NOTE(review): assumes the server runs on this machine and persists jobs to /tmp/job_<id>.json — confirm.
if new_batch_id is not None:
    job_file = f"/tmp/job_{new_batch_id}.json"
    job_file_exists = os.path.exists(job_file)
    print(f"Job file exists: {job_file_exists}")

    if job_file_exists:
        with open(job_file, 'r') as f:
            job_data = json.load(f)
        print(f"Job file status: {job_data.get('status', 'No status')}")
        print(f"Job file created at: {job_data.get('created_at', 'No timestamp')}")

In [None]:
import time

# After another pause, re-check the worker and dump the raw queue contents.
time.sleep(5)
response = requests.get(f"{BASE_URL}/debug/worker", timeout=10)
if response.status_code == 200:
    debug_info = response.json()
    print(f"Queue depth after 5 seconds: {debug_info['queue_depth']}")
    print(f"Worker running: {debug_info['worker_running']}")
    print(f"Worker thread alive: {debug_info['worker_thread_alive']}")
else:
    print(f"Debug endpoint not available (HTTP {response.status_code})")

response = requests.get(f"{BASE_URL}/debug/queue-contents", timeout=10)
if response.status_code == 200:
    queue_data = response.json()
    print(f"Queue contents: {queue_data}")
else:
    print(f"Queue contents endpoint not available (HTTP {response.status_code})")

In [None]:
import time

# End-to-end check after the queueing fix: submit a one-prompt batch, poll it until it
# finishes (or a deadline passes), then fetch and preview its results.
test_batch_request = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "input": [{"prompt": "Test after queueing fix"}],
    "max_tokens": 50,
    "temperature": 0.7
}
# "expired"/"cancelled" follow the OpenAI batch API. NOTE(review): confirm the server uses them.
_TEST_TERMINAL_STATUSES = {"completed", "failed", "expired", "cancelled"}


def _poll_batch_status(batch_id, max_wait=60.0, interval=2.0):
    """Poll ``batch_id`` until its status is terminal or ``max_wait`` seconds have passed.

    Returns the last observed status string, or None if a status request fails.
    """
    deadline = time.monotonic() + max_wait
    while True:
        resp = requests.get(f"{BASE_URL}/v1/batches/{batch_id}", timeout=10)
        if resp.status_code != 200:
            print(f"Failed to check final status: {resp.status_code}")
            return None
        status = resp.json()["status"]
        if status in _TEST_TERMINAL_STATUSES or time.monotonic() >= deadline:
            return status
        time.sleep(interval)


print("\n=== Testing Fixed Queueing ===")
# Stays None if submission fails, so the checks below are skipped. Otherwise they would
# raise NameError or silently reuse an id from an earlier run.
test_batch_id = None
response = requests.post(f"{BASE_URL}/v1/batches", json=test_batch_request, timeout=30)
if response.status_code == 200:
    data = response.json()
    test_batch_id = data["id"]
    print(f"Test batch created: {test_batch_id}")
    print(f"Status: {data['status']}")
else:
    print(f"Failed to create test batch: {response.status_code}")
    print(response.text)

if test_batch_id is not None:
    # Poll instead of a fixed sleep: returns as soon as the job finishes, and tolerates
    # slower runs up to the deadline.
    final_status = _poll_batch_status(test_batch_id)
    if final_status is not None:
        print(f"Final status: {final_status}")

    if final_status == "completed":
        response = requests.get(f"{BASE_URL}/v1/batches/{test_batch_id}/results", timeout=30)
        if response.status_code == 200:
            results = response.json().get("data", [])
            print(f" Successfully retrieved {len(results)} results!")
            for i, result in enumerate(results[:2], 1):
                response_text = result.get("response", "")
                print(f"Result {i}: {response_text[:50]}...")
        else:
            print(f"Results endpoint returned: {response.status_code}")
    elif final_status in _TEST_TERMINAL_STATUSES:
        print(f"Job ended with status: {final_status}")
    elif final_status is not None:
        print(f"Job still '{final_status}' when the polling deadline passed")

print("\n=== Test Complete ===")

## 8. Summary

This notebook demonstrates the complete OpenAI-compatible batch inference workflow:

### **Setup Steps:**
1. **Start API Server** - Launch FastAPI server in DEV mode
2. **Load Sample Data** - Read test prompts from JSONL file
3. **Submit Batch Job** - Create batch via POST request
4. **Monitor Progress** - Poll job status until completion
5. **Retrieve Results** - Get processed results when complete

### **Debugging Tools:**
- **Worker Debug** - Check `/debug/worker` for queue depth and worker status
- **GPU Pool Debug** - Check `/debug/gpu-pools` for resource allocation
- **Queue Contents** - Check `/debug/queue-contents` to see actual queue state
- **Job Status** - Direct status checks for individual batches

### **PoC Features** (provided by the server; most are not exercised directly in this notebook):
- **Thread-safe operations** - Concurrent request handling
- **GPU resource management** - Spot/dedicated pool allocation
- **SLA-aware scheduling** - Priority-based job processing
- **Error handling** - Proper retries and fallback behavior
- **OpenAI compatibility** - Standard batch API endpoints