# 📘 Notebook 3: Batch Inference

## Weekly Batch Inference for Wafer Yield Prediction

This notebook demonstrates **production-ready batch inference** using ML Jobs.

**What this does:**
1. Loads feature-engineered data from an ML Dataset (feature store)
2. Runs batch inference as an ML Job on a GPU compute pool
3. Saves predictions to a table, along with inference metadata (timestamp and model version)
4. Can be scheduled weekly (Snowflake Tasks, Airflow, cron)

**Key Pattern:** ML Jobs automatically manage compute pool lifecycle - no manual start/stop needed.

---

In [None]:
# ============================================================================
# SETUP
# ============================================================================

from snowflake.snowpark.context import get_active_session
from snowflake.ml.jobs import remote
from snowflake.ml import dataset

session = get_active_session()

# Set context
session.sql("USE DATABASE WAFER_YIELD_DEMO").collect()
session.sql("USE SCHEMA RAW_DATA").collect()

print("✅ Setup complete")
print(f"   Database: {session.get_current_database()}")
print(f"   Schema: {session.get_current_schema()}")

In [None]:
# ============================================================================
# PREPARE INFERENCE DATA FROM FEATURE STORE
# ============================================================================

# Load ML Dataset (feature-engineered data from Notebook 01)
print("📦 Loading ML Dataset (feature store)...")
ds = dataset.load_dataset(
    session, 
    "WAFER_YIELD_DEMO.RAW_DATA.WAFER_YIELD_TRAINING_DATASET", 
    version="v1"
)
df = ds.read.to_snowpark_dataframe()

print(f"✅ Loaded dataset: {df.count()} rows, {len(df.columns)} columns")

# Materialize to table for ML Job access
# (ML Jobs can't access ML Datasets directly due to permissions)
print("\n📊 Creating inference input table...")
df.write.mode("overwrite").save_as_table("WAFER_INFERENCE_INPUT")

print("✅ Created: WAFER_INFERENCE_INPUT")
print("   Contains feature-engineered data for inference")

---

## Define Batch Inference Job

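The `@remote()` decorator submits the function below as an ML Job on the named compute pool instead of running it in the notebook. Calling the decorated function returns a job handle; the function's return value is retrieved later with `job.result()`.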

**Key points:**
- Compute pool auto-resumes when job starts
- Compute pool auto-suspends after job completes
- No manual pool management needed
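
To make the calling convention concrete, here is a toy analogy in plain Python. It is not how `snowflake.ml.jobs` is implemented, and `toy_remote`, `ToyJob`, and `EXAMPLE_POOL` are hypothetical names invented for this sketch. It only illustrates the pattern this notebook relies on: calling a decorated function returns a job handle right away, and the return value is fetched later with `.result()`.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from functools import wraps

# A single background worker stands in for the compute pool (assumption for the toy).
_executor = ThreadPoolExecutor(max_workers=1)


class ToyJob:
    """Hypothetical job handle wrapping a Future; mimics status/result only."""

    def __init__(self, future: Future):
        self._future = future

    @property
    def status(self) -> str:
        if not self._future.done():
            return "RUNNING"
        return "FAILED" if self._future.exception() else "DONE"

    def result(self):
        # Blocks until the background work has finished, then returns its value.
        return self._future.result()


def toy_remote(pool_name: str):
    """Toy decorator factory: pool_name is accepted but unused in this sketch."""

    def decorator(func):
        @wraps(func)
        def submit(*args, **kwargs):
            # Submit the call instead of executing it inline; hand back a handle.
            return ToyJob(_executor.submit(func, *args, **kwargs))

        return submit

    return decorator


@toy_remote("EXAMPLE_POOL")
def count_rows(n: int) -> int:
    return n


job = count_rows(3)   # returns a ToyJob, not the integer
print(job.result())   # prints 3
```

In the real job, the work runs on the compute pool rather than in a local thread, which is why the function body below creates its own session and imports its dependencies internally.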

---

In [None]:
# ============================================================================
# BATCH INFERENCE JOB
# ============================================================================

@remote("WAFER_TRAINING_POOL", stage_name="inference_payload")
def run_weekly_inference(
    database: str,
    schema: str,
    model_name: str,
    input_table: str,
    output_table: str
):
    """
    Weekly batch inference job.
    
    Runs as ML Job - compute pool lifecycle managed automatically.
    Perfect for scheduling with Snowflake Tasks or Airflow.
    """
    import torch
    from snowflake.snowpark import Session
    from snowflake.ml.registry import Registry
    
    session = Session.builder.getOrCreate()
    
    session.sql(f"USE DATABASE {database}").collect()
    session.sql(f"USE SCHEMA {schema}").collect()
    print(f"📋 Context: {database}.{schema}")
    
    registry = Registry(session=session)
    mv = registry.get_model(model_name).default
    print(f"✅ Model: {model_name} v{mv.version_name}")
    
    input_df = session.table(input_table)
    print(f"✅ Data: {input_table} ({input_df.count()} rows)")
    
    # Drop the identifier and label columns; every remaining column is a model feature
    exclude_cols = {'WAFER_ID', 'YIELD_GOOD', 'YIELD_SCORE', 'DOMINANT_DEFECT_TYPE'}
    feature_cols = [c for c in input_df.columns if c.upper() not in exclude_cols]
    
    print(f"📊 Running inference on {len(feature_cols)} features...")
    
    model_obj = mv.load()
    model_obj.eval()
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_obj = model_obj.to(device)
    print(f"✅ Model loaded to {device}")
    
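    input_pandas = input_df.select(feature_cols).to_pandas()
    # Cast explicitly to float32 so non-float column types (e.g. NUMBER) convert cleanly
    input_tensor = torch.tensor(input_pandas.to_numpy(dtype="float32"), device=device)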
    
    with torch.no_grad():
        predictions_tensor = model_obj(input_tensor)
    
    predictions = predictions_tensor.cpu().numpy()
    
    from snowflake.snowpark import Row
    from datetime import datetime
    
    timestamp = datetime.now()
    
    rows = [
        Row(
            OUTPUT_FEATURE_0=float(pred[0]),
            INFERENCE_TIMESTAMP=timestamp,
            MODEL_VERSION=mv.version_name
        )
        for pred in predictions
    ]
    
    predictions_df = session.create_dataframe(rows)
    predictions_df.write.mode("overwrite").save_as_table(output_table)
    
    result_count = session.table(output_table).count()
    print(f"✅ Saved {result_count} predictions to {output_table}")
    
    return result_count

print("✅ Inference job function defined")
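
The function above converts the whole input table into a single tensor and writes predictions without a wafer identifier. For larger tables, or when predictions need to be joined back to wafers, a chunked variant that carries `WAFER_ID` along may be preferable. The sketch below shows only the batching and bookkeeping in plain Python. `iter_batches`, `predict_in_batches`, and the `predict_fn` callback are hypothetical names; in the job, `predict_fn` would wrap the `torch.no_grad()` forward pass.

```python
from datetime import datetime, timezone
from typing import Callable, Iterator, List, Sequence


def iter_batches(n_rows: int, batch_size: int) -> Iterator[slice]:
    """Yield consecutive slices that cover range(n_rows) in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, n_rows, batch_size):
        yield slice(start, min(start + batch_size, n_rows))


def predict_in_batches(
    ids: Sequence,
    features: Sequence,
    predict_fn: Callable[[Sequence], Sequence[float]],
    model_version: str,
    batch_size: int = 1024,
) -> List[dict]:
    """Score rows chunk by chunk and keep each prediction next to its ID."""
    if len(ids) != len(features):
        raise ValueError("ids and features must have the same length")
    timestamp = datetime.now(timezone.utc)  # one timestamp for the whole run
    records = []
    for window in iter_batches(len(features), batch_size):
        batch_preds = predict_fn(features[window])
        for wafer_id, pred in zip(ids[window], batch_preds):
            records.append(
                {
                    "WAFER_ID": wafer_id,
                    "OUTPUT_FEATURE_0": float(pred),
                    "INFERENCE_TIMESTAMP": timestamp,
                    "MODEL_VERSION": model_version,
                }
            )
    return records


# Stand-in model for demonstration: doubles the first feature of each row.
demo = predict_in_batches(
    ids=["W1", "W2", "W3"],
    features=[[1.0], [2.0], [3.0]],
    predict_fn=lambda rows: [row[0] * 2 for row in rows],
    model_version="V1",
    batch_size=2,
)
print([(r["WAFER_ID"], r["OUTPUT_FEATURE_0"]) for r in demo])
```

The IDs and features stay aligned only if both come from the same materialized frame, so they should be taken from one `to_pandas()` result rather than from two separate queries.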

---

## Run Batch Inference

Submit the job and wait for completion.
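
If an unattended run should not block forever, the wait can be bounded. The helper below is a minimal sketch, assuming the job handle exposes a `status` attribute (as used in the next cell) and that terminal states are reported as the strings in `TERMINAL_STATUSES`. Those status names and the helper `wait_for_terminal_status` are assumptions for illustration, so check them against the `snowflake.ml.jobs` version you have installed.

```python
import time

# Assumed terminal status strings; verify against your snowflake.ml.jobs version.
TERMINAL_STATUSES = {"DONE", "FAILED", "CANCELLED", "INTERNAL_ERROR"}


def wait_for_terminal_status(
    job,
    timeout_s: float = 1800.0,
    poll_s: float = 10.0,
    sleep=time.sleep,
    clock=time.monotonic,
) -> str:
    """Poll job.status until it is terminal, or raise TimeoutError at the deadline."""
    deadline = clock() + timeout_s
    while True:
        status = str(job.status)
        if status in TERMINAL_STATUSES:
            return status
        if clock() >= deadline:
            raise TimeoutError(
                f"job still {status} after {timeout_s} seconds"
            )
        sleep(poll_s)
```

In this notebook it could be called as `wait_for_terminal_status(job)` ahead of `job.result()`, which the results cell uses to wait for completion.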

---

In [None]:
# ============================================================================
# SUBMIT INFERENCE JOB
# ============================================================================

print("🚀 Submitting batch inference job...")

# Submit the job
job = run_weekly_inference(
    database=session.get_current_database(),
    schema=session.get_current_schema(),
    model_name="WAFER_YIELD_DDP_MODEL",
    input_table="WAFER_INFERENCE_INPUT",
    output_table="WAFER_YIELD_PREDICTIONS"
)

print(f"\n✅ Job submitted: {job.id}")
print(f"📊 Status: {job.status}")
print("\n⏳ Waiting for completion (first run may take 3-5 min for pool startup)...")

In [None]:
# ============================================================================
# GET RESULTS
# ============================================================================

# Wait for job to finish
result = job.result()

print("\n✅ Job complete!")
print(f"   Generated {result} predictions")
print(f"   Saved to: WAFER_YIELD_PREDICTIONS")

# Show job logs
print("\n📋 Job logs:")
print("=" * 60)
job.show_logs()
print("=" * 60)

In [None]:
# ============================================================================
# ANALYZE PREDICTIONS
# ============================================================================

print("📊 Prediction Results")
print("=" * 60)

predictions = session.table("WAFER_YIELD_PREDICTIONS")

# Sample predictions
print("\n1️⃣ Sample predictions:")
predictions.select(
    "OUTPUT_FEATURE_0",
    "INFERENCE_TIMESTAMP",
    "MODEL_VERSION"
).show(10)

# Statistics
total = predictions.count()
model_version = predictions.select('MODEL_VERSION').first()[0]
timestamp = predictions.select('INFERENCE_TIMESTAMP').first()[0]

print("\n2️⃣ Summary:")
print(f"   Total predictions: {total}")
print(f"   Model version: {model_version}")
print(f"   Timestamp: {timestamp}")

# Distribution
print("\n3️⃣ Prediction distribution:")
predictions.select("OUTPUT_FEATURE_0").describe().show()

print("\n✅ Inference complete and verified!")

---

## How to Schedule Weekly

This inference job can be scheduled to run automatically:

### Option 1: Snowflake Tasks

```sql
CREATE TASK WEEKLY_INFERENCE_TASK
    WAREHOUSE = COMPUTE_WH
    SCHEDULE = 'USING CRON 0 2 * * 1 UTC'  -- Mondays at 02:00 UTC
AS
    -- RUN_INFERENCE_PROCEDURE is a placeholder; it is not defined in this notebook
    CALL RUN_INFERENCE_PROCEDURE();

ALTER TASK WEEKLY_INFERENCE_TASK RESUME;
```
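
The cron expression `0 2 * * 1` in the task reads as minute 0, hour 2, any day of month, any month, weekday 1 (Monday). The checker below is a small sketch for sanity-testing such an expression before deploying it. `cron_matches` is a hypothetical helper that understands only `*` and single integers, not ranges, lists, or step values, and it is unrelated to Snowflake's own scheduler.

```python
from datetime import datetime


def _field_matches(field: str, value: int) -> bool:
    """A field matches if it is the wildcard or equals the value exactly."""
    return field == "*" or int(field) == value


def cron_matches(expr: str, moment: datetime) -> bool:
    """Return True if a simple five-field cron expression fires at `moment`."""
    minute, hour, day_of_month, month, day_of_week = expr.split()
    time_ok = (
        _field_matches(minute, moment.minute)
        and _field_matches(hour, moment.hour)
        and _field_matches(month, moment.month)
    )
    dom_ok = _field_matches(day_of_month, moment.day)
    # Cron numbers weekdays with Sunday = 0; isoweekday() has Sunday = 7.
    dow_ok = _field_matches(day_of_week, moment.isoweekday() % 7)
    if day_of_month != "*" and day_of_week != "*":
        # Classic cron fires when either restricted day field matches.
        day_ok = dom_ok or dow_ok
    else:
        day_ok = dom_ok and dow_ok
    return time_ok and day_ok


# 2024-01-01 was a Monday.
print(cron_matches("0 2 * * 1", datetime(2024, 1, 1, 2, 0)))  # True
print(cron_matches("0 2 * * 1", datetime(2024, 1, 2, 2, 0)))  # False (Tuesday)
```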

### Option 2: Python Script + Cron

```python
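# weekly_inference.py
from snowflake.snowpark import Session
from snowflake.ml.jobs import remote

# Assumes a default Snowflake connection is configured for this environment
session = Session.builder.getOrCreate()

@remote("WAFER_TRAINING_POOL", stage_name="inference_payload")
def run_weekly_inference(database, schema, model_name, input_table, output_table):
    # Body is identical to run_weekly_inference defined earlier in this notebook
    ...

# Same arguments as the notebook submission above
job = run_weekly_inference(
    database="WAFER_YIELD_DEMO",
    schema="RAW_DATA",
    model_name="WAFER_YIELD_DDP_MODEL",
    input_table="WAFER_INFERENCE_INPUT",
    output_table="WAFER_YIELD_PREDICTIONS",
)
result = job.result()  # blocks until the job finishes
print(f"Completed: {result} predictions")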
```

```bash
# Cron: Every Monday at 2 AM
0 2 * * 1 python weekly_inference.py
```

### Option 3: Airflow DAG

```python
# dags/wafer_inference.py
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime

def run_inference():
    # ... ML Job code ...
    return job.result()

with DAG(
    'wafer_weekly_inference',
    schedule='0 2 * * 1',
    start_date=datetime(2024, 1, 1),
    catchup=False
) as dag:
    PythonOperator(
        task_id='batch_inference',
        python_callable=run_inference
    )
```

**See Notebook 04 for detailed orchestration examples.**

---

---

## Summary

### What We Did

1. ✅ Loaded feature-engineered data from ML Dataset (feature store)
2. ✅ Materialized to table for ML Job access
3. ✅ Defined inference function with `@remote()` decorator
4. ✅ Submitted job to GPU compute pool
5. ✅ Compute pool auto-started, ran inference, auto-stopped
6. ✅ Saved predictions with metadata

### Key Takeaways

| Aspect | Implementation |
|--------|----------------|
| **Data Source** | ML Dataset (feature store) from Notebook 01 |
| **Compute** | GPU compute pool (auto-managed by ML Jobs) |
| **Pattern** | `@remote()` decorator for ML Jobs |
| **Scheduling** | Ready for Tasks, Airflow, or cron |
| **Cost** | Pool only runs during job execution |

### Next Steps

**Notebook 04: Production Orchestration**
- Schedule weekly inference with Tasks
- Build Airflow DAGs for complex workflows  
- Add monitoring and alerting
- Configure retries and error handling

---