# Distributed Multi-Modal Inference with Snowflake `run_batch`

## MedGemma ECG Image Classification Demo

**Key Takeaway:** This demo showcases how Snowflake's `run_batch` API enables **distributed GPU inference at scale** with just a few lines of code.

### Why This Matters
| Feature | Benefit |
|---------|----------|
| **Single API call** | `run_batch()` handles all distribution, scheduling, and scaling |
| **GPU acceleration** | Automatic GPU provisioning via Snowpark Container Services |
| **Multi-modal** | Process images + text together natively |
| **Enterprise-ready** | Data never leaves Snowflake's secure environment |

### Prerequisites
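- A free Kaggle account and API key, available from [kaggle.com/settings](https://www.kaggle.com/settings)
- Database `MEDGEMMA_DEMO` (schema `PUBLIC`) containing the stage `WHL_FILE` with the `snowflake_ml_python-1.27.0` wheel and the registered model `MEDGEMMA_4B_IT_TEST`
- A GPU compute pool named `MEDGEMMA_COMPUTE_POOL`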

## Setup: Install Dependencies & Configure Kaggle

> Enter your Kaggle credentials below. Get your API key from **Settings → API → Create New Token** at [kaggle.com/settings](https://www.kaggle.com/settings)

In [None]:
import sys
sys.path.insert(0, '/tmp/whl')

from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Install required packages
session.file.get("@MEDGEMMA_DEMO.PUBLIC.WHL_FILE/snowflake_ml_python-1.27.0-py3-none-any.whl", "/tmp/whl/")
!pip install /tmp/whl/snowflake_ml_python-1.27.0-py3-none-any.whl --force-reinstall --quiet
!pip install -q kagglehub

In [None]:
import os

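# === ENTER YOUR KAGGLE CREDENTIALS HERE ===
KAGGLE_USERNAME = "YOUR_KAGGLE_USERNAME"
KAGGLE_KEY = "YOUR_KAGGLE_API_KEY"  # Get from kaggle.com/settings → API → Create New Token
# ==========================================

# kagglehub reads the username and key together from these environment variables
os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME
os.environ["KAGGLE_KEY"] = KAGGLE_KEY
print("Kaggle credentials configured")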
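
Hardcoding a key in a notebook cell makes it easy to leak when the notebook is shared. A minimal sketch of an alternative, assuming the credential may already be set as an environment variable and should otherwise be typed in without echoing; `resolve_credential` is a hypothetical helper, not part of `kagglehub`.

```python
import getpass
import os
from typing import Callable, Mapping


def resolve_credential(
    name: str,
    env: Mapping[str, str] = os.environ,
    prompt: Callable[[str], str] = getpass.getpass,
) -> str:
    """Return a credential from the environment, or ask for it without echoing."""
    value = env.get(name, "").strip()
    if value:
        return value
    # Not set in the environment: fall back to an interactive, non-echoing prompt
    value = prompt(f"{name}: ").strip()
    if not value:
        raise ValueError(f"{name} is required")
    return value
```

With it, `os.environ["KAGGLE_KEY"] = resolve_credential("KAGGLE_KEY")` would replace the hardcoded assignment; whether the `getpass` prompt works depends on the notebook front end.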

In [None]:
import glob
import kagglehub

dataset_path = kagglehub.dataset_download("evilspirit05/ecg-analysis")
print(f"Dataset downloaded to: {dataset_path}")

# Collect every image file (the same extensions the sampling step accepts later)
all_images = []
for ext in ("jpg", "jpeg", "png"):
    all_images += glob.glob(f"{dataset_path}/**/*.{ext}", recursive=True)
print(f"Found {len(all_images)} ECG images")

## Upload Images to Snowflake Stage

> This uploads the ECG images to a Snowflake stage for batch inference

In [None]:
DB_NAME = "MEDGEMMA_DEMO"
SCHEMA_NAME = "PUBLIC"
STAGE_NAME = "ECG_STAGE"
OUTPUT_STAGE_NAME = "ECG_BATCH_OUTPUT_STAGE"

session.sql(f"USE DATABASE {DB_NAME}").collect()
session.sql(f"USE SCHEMA {SCHEMA_NAME}").collect()
session.sql(f"CREATE STAGE IF NOT EXISTS {STAGE_NAME} DIRECTORY = (ENABLE = TRUE)").collect()
session.sql(f"CREATE STAGE IF NOT EXISTS {OUTPUT_STAGE_NAME} DIRECTORY = (ENABLE = TRUE)").collect()

print(f"Using: {DB_NAME}.{SCHEMA_NAME}")
print(f"Stages: {STAGE_NAME}, {OUTPUT_STAGE_NAME}")

In [None]:
stage_path = f"@{DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}"

print(f"Uploading {len(all_images)} images to {stage_path}...")
uploaded = 0
for img_path in all_images:
    rel_path = os.path.relpath(img_path, dataset_path)
    folder_name = os.path.dirname(rel_path)
    target = f"{stage_path}/{folder_name}" if folder_name else stage_path
    session.file.put(img_path, target, auto_compress=False, overwrite=False)
    uploaded += 1
    if uploaded % 100 == 0:
        print(f"  Uploaded {uploaded}/{len(all_images)}...")

session.sql(f"ALTER STAGE {STAGE_NAME} REFRESH").collect()
print(f"Upload complete: {uploaded} images processed (files already on the stage are skipped, not overwritten)")
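
The loop above issues one `PUT` per image. Snowflake's `PUT` accepts `*` and `?` wildcards in the local path, so a sketch that groups files by folder and extension can reduce this to one `session.file.put` call per group while keeping the same folder layout on the stage. `plan_uploads` is a hypothetical helper and assumes POSIX-style paths.

```python
import os
from collections import defaultdict


def plan_uploads(image_paths, dataset_root, stage_path):
    """Group local image files into (local_wildcard, stage_target) pairs.

    One pair per (folder, extension), so each pair maps to a single PUT call.
    """
    groups = defaultdict(set)
    for path in image_paths:
        rel_dir = os.path.dirname(os.path.relpath(path, dataset_root))
        groups[rel_dir].add(os.path.splitext(path)[1])  # keep the extension's case

    plan = []
    for rel_dir in sorted(groups):
        local_dir = os.path.join(dataset_root, rel_dir) if rel_dir else dataset_root
        # Mirror the local folder structure on the stage, as the loop above does
        target = f"{stage_path}/{rel_dir}" if rel_dir else stage_path
        for ext in sorted(groups[rel_dir]):
            plan.append((os.path.join(local_dir, f"*{ext}"), target))
    return plan
```

Each `(local_pattern, target)` pair would then go to `session.file.put(local_pattern, target, auto_compress=False, overwrite=False)`.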

## Step 1: Load Model from Registry

> Load the MedGemma model that was previously registered in the Snowflake Model Registry

In [None]:
from snowflake.ml.registry import Registry

registry = Registry(session=session, database_name=DB_NAME, schema_name=SCHEMA_NAME)
mv = registry.get_model('MEDGEMMA_4B_IT_TEST').version('V_2026_02_06__15_25_17')
print(f"Loaded model: {mv.model_name} v{mv.version_name}")
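
The cell above pins an exact version name. If new versions keep the timestamped `V_YYYY_MM_DD__HH_MM_SS` naming seen here (an assumption about how this model was registered), a helper can pick the newest one instead. `latest_timestamped_version` is hypothetical; it takes plain version names, for example the `version_name` of each entry returned by the model's `versions()` method.

```python
from datetime import datetime

# Assumed naming scheme, matching names like V_2026_02_06__15_25_17
VERSION_FORMAT = "V_%Y_%m_%d__%H_%M_%S"


def latest_timestamped_version(version_names):
    """Return the newest version name that follows VERSION_FORMAT; ignore others."""
    dated = []
    for name in version_names:
        try:
            dated.append((datetime.strptime(name.upper(), VERSION_FORMAT), name))
        except ValueError:
            continue  # e.g. a hand-named version such as "V1"
    if not dated:
        raise ValueError("no timestamped version names found")
    return max(dated)[1]
```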

## Step 2: Prepare Batch Input Data

> Format ECG images as multi-modal prompts for the vision model

In [None]:
stage_location = f"@{DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}"
files_df = session.sql(f"LS {stage_location}")
files_df.show(5)

In [None]:
import json
import random

ECG_CLASSIFICATION_PROMPT = """You are a board-certified cardiologist performing a systematic ECG interpretation. Analyze this 12-lead ECG following the standard clinical approach:

**STEP 1 - RATE AND RHYTHM:**
- Calculate the heart rate. Is it bradycardic (<60), normal (60-100), or tachycardic (>100)?
- Is the rhythm regular or irregular? 
- Is there a P wave before every QRS complex?
- If rhythm is irregular, is it irregularly irregular (suggests atrial fibrillation)?

**STEP 2 - P WAVES:**
- Are P waves present and upright in leads I, II, aVF?
- Is P wave morphology normal or abnormal (peaked, notched, absent)?

**STEP 3 - PR INTERVAL:**
- Is the PR interval normal (120-200ms), prolonged (heart block), or short (pre-excitation)?

**STEP 4 - QRS COMPLEX:**
- Is the QRS narrow (<120ms) or wide?
- Are there pathological Q waves (>40ms wide or >25% of R wave amplitude)?
- Q waves in V1-V4 suggest anterior MI
- Q waves in II, III, aVF suggest inferior MI

**STEP 5 - ST SEGMENT (CRITICAL FOR MI):**
- Is there ST elevation? In which leads?
  - V1-V4 elevation = anterior STEMI
  - II, III, aVF elevation = inferior STEMI
  - I, aVL, V5-V6 elevation = lateral STEMI
- Is there ST depression? This may indicate ischemia or reciprocal changes.

**STEP 6 - T WAVES:**
- Are T waves upright, inverted, or hyperacute (tall, peaked)?
- Hyperacute T waves are an early sign of MI
- Deep symmetric T wave inversion may indicate ischemia or evolved MI

**FINAL CLASSIFICATION:**
Based on your systematic analysis above, classify this ECG as ONE of:
- **MYOCARDIAL_INFARCTION**: If you see ST elevation, hyperacute T waves, or acute ischemic changes
- **POST_MI**: If you see pathological Q waves with no acute ST elevation (indicates old/healed MI)
- **ABNORMAL_HEARTBEAT**: If the rhythm is irregular, there are ectopic beats, or arrhythmia is present
- **NORMAL**: ONLY if rate is normal, rhythm is regular sinus, no Q waves, no ST changes, normal T waves

Provide your complete analysis following each step, then state your final classification."""

files_pandas = files_df.to_pandas()

categories = {
    "normal_ecg_images": [],
    "abnormal_heartbeat_ecg_images": [],
    "myocardial_infarction_ecg_images": [],
    "post_mi_history_ecg_images": []
}

for _, row in files_pandas.iterrows():
    name = row['"name"']
    if "ecg_data_new_version" in name and name.endswith(('.jpg', '.png', '.jpeg')):
        image_path = f"@{DB_NAME}.{SCHEMA_NAME}.{name}"  # don't overwrite `stage_path` from the upload cell
        for cat in categories.keys():
            if cat in name:
                categories[cat].append(image_path)
                break

print("Images per category:")
for cat, files in categories.items():
    print(f"  {cat}: {len(files)}")

NUM_PER_CATEGORY = 25
jpg_files = []
for cat, files in categories.items():
    sampled = random.sample(files, min(NUM_PER_CATEGORY, len(files)))
    jpg_files.extend(sampled)
    
random.shuffle(jpg_files)
print(f"\nSelected {len(jpg_files)} images (up to {NUM_PER_CATEGORY} per category)")
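
`random.sample` draws a different subset on every run, so classification counts are not comparable between runs. A minimal sketch of the same per-category sampling that uses a private, seeded `random.Random`, so a rerun over the same stage listing selects the same images; `stratified_sample` and its default seed are illustrative choices, not part of the original demo.

```python
import random


def stratified_sample(categories, per_category, seed=0):
    """Draw up to `per_category` items from each list, reproducibly, then shuffle."""
    rng = random.Random(seed)  # private RNG, unaffected by other random.* calls
    selected = []
    for cat in sorted(categories):  # fixed category order keeps the draw sequence stable
        files = sorted(categories[cat])  # independent of the order LS returned files in
        selected.extend(rng.sample(files, min(per_category, len(files))))
    rng.shuffle(selected)
    return selected
```

`jpg_files = stratified_sample(categories, NUM_PER_CATEGORY, seed=42)` would stand in for the sampling loop.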

In [None]:
messages_list = []
for jpg_file in jpg_files:
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant specialized in ECG analysis."}]},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": ECG_CLASSIFICATION_PROMPT},
                {"type": "image_url", "image_url": {"url": jpg_file}},
            ],
        },
    ]
    messages_list.append(messages)

schema = ["MESSAGES"]
data = [(json.dumps(m),) for m in messages_list]
input_df = session.create_dataframe(data, schema=schema)
print(f"Created input DataFrame with {input_df.count()} rows")
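
Each row of `MESSAGES` holds an OpenAI-style chat message list serialized to JSON. A small sketch that factors the construction into a function and recovers the image reference from a serialized row, which is handy for spot-checking the payload before submitting a batch; `build_ecg_messages` and `image_url_of` are hypothetical names, and the layout simply mirrors the cell above.

```python
import json


def build_ecg_messages(image_url, prompt, system_text):
    """Build the system + user (text + image) message list for one ECG image."""
    return [
        {"role": "system", "content": [{"type": "text", "text": system_text}]},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        },
    ]


def image_url_of(messages_json):
    """Return the first image reference in a serialized MESSAGES value, or None."""
    for message in json.loads(messages_json):
        for part in message["content"]:
            if part.get("type") == "image_url":
                return part["image_url"]["url"]
    return None
```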

## Step 3: Run Distributed Batch Inference

> **This is the magic**: a single `run_batch()` call distributes inference over all selected images on the GPU compute pool

In [None]:
from snowflake.ml.model import JobSpec, OutputSpec, SaveMode, InputSpec
from snowflake.ml.model.inference_engine import InferenceEngine

output_location = f"@{DB_NAME}.{SCHEMA_NAME}.{OUTPUT_STAGE_NAME}/results/"

job = mv.run_batch(
    compute_pool="MEDGEMMA_COMPUTE_POOL",
    X=input_df,
    input_spec=InputSpec(params={"temperature": 0.2, "max_tokens": 1024}),
    output_spec=OutputSpec(stage_location=output_location, mode=SaveMode.OVERWRITE),
    job_spec=JobSpec(gpu_requests="1"),
    inference_engine_options={
        "engine": InferenceEngine.VLLM,
        "engine_args_override": [
            "--max-model-len=7048",
            "--gpu-memory-utilization=0.9",
        ]
    }
)

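print(f"Batch job started: {job}")
print(f"Processing {len(jpg_files)} images on GPU...")

# Block until the job finishes so the next step reads complete results
job.wait()
print(f"Job status: {job.status}")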
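
vLLM's `--max-model-len` bounds the prompt plus generated tokens for each request, so it must cover the text prompt, the image tokens, and `max_tokens`. A rough back-of-the-envelope check, assuming about 4 characters per text token, 256 tokens per image (the fixed count used by Gemma 3 family vision encoders, which MedGemma builds on; confirm against your model's processor), and a small allowance for chat-template markup. For an exact figure, count with the model's tokenizer instead.

```python
import math


def context_budget(prompt_chars, num_images, max_tokens,
                   chars_per_token=4.0, tokens_per_image=256, overhead_tokens=64):
    """Estimate the tokens one request needs: text prompt + images + generation.

    chars_per_token, tokens_per_image and overhead_tokens are rough assumptions.
    """
    prompt_tokens = math.ceil(prompt_chars / chars_per_token)
    return prompt_tokens + num_images * tokens_per_image + overhead_tokens + max_tokens


def fits(max_model_len, **kwargs):
    """True if the estimated request fits in the engine's context window."""
    return context_budget(**kwargs) <= max_model_len
```

For this notebook the check would be `fits(7048, prompt_chars=len(ECG_CLASSIFICATION_PROMPT), num_images=1, max_tokens=1024)`.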

## Step 4: Analyze Results

> Parse model outputs and visualize classification distribution

In [None]:
results_df = session.read.option("pattern", ".*\\.parquet").parquet(output_location)

print(f"Total results: {results_df.count()}")
results_df.show(5, max_width=500)

In [None]:
import json
import re

results_pandas = results_df.to_pandas()

# Accept "POST_MI", "Post-MI", "myocardial infarction", etc. The leading \b keeps
# "NORMAL" from matching inside "ABNORMAL_HEARTBEAT".
LABEL = r"\b(NORMAL|ABNORMAL[_\s-]HEARTBEAT|MYOCARDIAL[_\s-]INFARCTION|POST[_\s-]MI)\b"
PATTERNS = [
    r"FINAL CLASSIFICATION[\s\S]{0,100}?" + LABEL,  # preferred: the explicit final answer
    r"classification is[:\s]*" + LABEL,
    r"classified as[:\s]*" + LABEL,
    r"the ECG is[:\s]*" + LABEL,
]

def extract_classification(choices_data):
    """Extract the predicted label from an OpenAI-style `choices` payload."""
    try:
        choices = json.loads(choices_data) if isinstance(choices_data, str) else choices_data
        if choices is None or len(choices) == 0:
            return "UNKNOWN"
        content = choices[0].get("message", {}).get("content") or ""
        content = content.replace("*", "")  # strip markdown bold markers

        for pattern in PATTERNS:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                # The model states its conclusion last, so prefer the last match
                return re.sub(r"[\s-]", "_", matches[-1]).upper()
    except (ValueError, TypeError, AttributeError, IndexError, KeyError):
        pass  # malformed or missing model output
    return "UNKNOWN"

# extract_classification parses the `choices` payload; the `id` column holds no model text
results_pandas["CLASSIFICATION"] = results_pandas["choices"].apply(extract_classification)

print("\n" + "=" * 50)
print("         ECG CLASSIFICATION RESULTS")
print("=" * 50)
counts = results_pandas["CLASSIFICATION"].value_counts()
total = len(results_pandas)
for cls, count in counts.items():
    pct = (count / total) * 100
    bar = "█" * int(pct / 5)
    print(f"{cls:25} {count:3} ({pct:5.1f}%) {bar}")
print("=" * 50)
print(f"{'TOTAL':25} {total:3}")
print("=" * 50)
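
Because the images were sampled from folders named after their diagnosis, predictions can also be scored against a folder-derived label. A minimal sketch, assuming each result row can be traced back to its image's stage path (for example, if the batch output carries the input `MESSAGES` column through, which this notebook does not confirm); `FOLDER_TO_LABEL`, `label_from_path`, and `score` are hypothetical names, and the mapping mirrors the `categories` keys used for sampling.

```python
from collections import Counter

# Hypothetical mapping from the dataset's folder names to the prompt's labels
FOLDER_TO_LABEL = {
    "abnormal_heartbeat_ecg_images": "ABNORMAL_HEARTBEAT",
    "myocardial_infarction_ecg_images": "MYOCARDIAL_INFARCTION",
    "post_mi_history_ecg_images": "POST_MI",
    "normal_ecg_images": "NORMAL",
}


def label_from_path(path):
    """Return the ground-truth label encoded in an image path, or None."""
    lowered = path.lower()
    for folder, label in FOLDER_TO_LABEL.items():
        if folder in lowered:
            return label
    return None


def score(pairs):
    """pairs: iterable of (image_path, predicted_label). Returns (accuracy, confusion)."""
    confusion = Counter()
    for path, predicted in pairs:
        truth = label_from_path(path)
        if truth is not None:  # skip rows with no recognizable source folder
            confusion[(truth, predicted)] += 1
    total = sum(confusion.values())
    correct = sum(n for (truth, pred), n in confusion.items() if truth == pred)
    return (correct / total if total else 0.0), confusion
```

Rows whose path matches no known folder are left out of both the accuracy and the confusion counts.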

In [None]:
results_pandas.to_csv("/tmp/ecg_results.csv", index=False)
output_stage = f"@{DB_NAME}.{SCHEMA_NAME}.{OUTPUT_STAGE_NAME}"
session.file.put("/tmp/ecg_results.csv", output_stage, auto_compress=False, overwrite=True)
print(f"Results saved to {output_stage}/ecg_results.csv")

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

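counts = results_pandas["CLASSIFICATION"].value_counts()
# Fixed color per label, so a class keeps its color regardless of its rank
label_colors = {
    "NORMAL": '#2ecc71',
    "MYOCARDIAL_INFARCTION": '#e74c3c',
    "ABNORMAL_HEARTBEAT": '#f39c12',
    "POST_MI": '#9b59b6',
}
colors = [label_colors.get(label, '#95a5a6') for label in counts.index]  # gray for UNKNOWN

axes[0].pie(counts.values, labels=counts.index, autopct='%1.1f%%', colors=colors)
axes[0].set_title('ECG Classification Distribution', fontsize=14, fontweight='bold')

axes[1].barh(counts.index, counts.values, color=colors)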
axes[1].set_xlabel('Count')
axes[1].set_title('Classification Counts', fontsize=14, fontweight='bold')
for i, v in enumerate(counts.values):
    axes[1].text(v + 0.5, i, str(v), va='center')

plt.tight_layout()
plt.show()

## Summary: The Power of `run_batch`

### What We Demonstrated
| Metric | Value |
|--------|-------|
| **Images Processed** | Up to 100 ECG images (25 per category) |
| **Model Size** | 4B parameters (MedGemma) |
| **Input Type** | Multi-modal (image + text prompt) |
| **Inference Code** | One `run_batch()` call |

### Key Benefits for Healthcare AI

1. **Scale Without Complexity** - Process thousands of medical images with a single API call
2. **Secure by Design** - Data stays in Snowflake, no external transfers required
3. **GPU Power On-Demand** - Automatic provisioning via Snowpark Container Services
4. **Production-Ready** - Enterprise governance, logging, and monitoring built in