![tracker](https://us-central1-vertex-ai-mlops-369716.cloudfunctions.net/pixel-tracking?path=statmike%2Fvertex-ai-mlops%2FFramework+Workflows%2FPyTorch%2Fserving&file=dataflow-streaming-runinference.ipynb)
# Dataflow Streaming Inference with RunInference

Real-time anomaly detection using Dataflow streaming with PyTorch RunInference.

## Architecture
```
Pub/Sub Input → Dataflow (RunInference) → Pub/Sub Output + BigQuery
```

---
## Environment Setup

In [None]:
# Inputs for the environment-setup helper that is called a few cells below.
PROJECT_ID = "statmike-mlops-349915"  # Google Cloud project handed to setup_environment
REQ_TYPE = "ALL"  # requirements selector handed to setup_environment
INSTALL_TOOL = "poetry"  # install tool handed to setup_environment

In [None]:
# Location of the requirements file passed to the setup helper.
REQUIREMENTS_URL = (
    "https://raw.githubusercontent.com/statmike/vertex-ai-mlops/refs/heads/main/"
    "Framework%20Workflows/PyTorch/requirements.txt"
)
# Google Cloud service APIs passed to the setup helper.
REQUIRED_APIS = [
    "dataflow.googleapis.com",
    "pubsub.googleapis.com",
    "bigquery.googleapis.com",
    "storage.googleapis.com",
]

In [None]:
import os
import urllib.request

# Download the shared setup helper, import it, and delete the local copy again.
# NOTE(review): this executes code fetched from a branch URL at run time —
# confirm that is acceptable for the environment this notebook runs in.
url = 'https://raw.githubusercontent.com/statmike/vertex-ai-mlops/refs/heads/main/core/notebook-template/python_setup.py'
urllib.request.urlretrieve(url, 'python_setup_local.py')
try:
    import python_setup_local as python_setup
finally:
    # Remove the downloaded file even when the import fails, so a broken
    # download is not left behind in the working directory.
    os.remove('python_setup_local.py')
setup_info = python_setup.setup_environment(PROJECT_ID, REQ_TYPE, REQUIREMENTS_URL, REQUIRED_APIS, INSTALL_TOOL)

---
## Python Setup

In [None]:
import subprocess
import apache_beam as beam
from apache_beam import window
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions
from apache_beam.io.gcp.pubsub import ReadFromPubSub, WriteToPubSub
from apache_beam.io.gcp.bigquery import WriteToBigQuery
import torch
import json
from datetime import datetime

In [None]:
# Take the project from the active gcloud configuration; check=True makes a
# failing gcloud call raise instead of yielding an empty project id.
PROJECT_ID = subprocess.run(
    ["gcloud", "config", "get-value", "project"],
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()

# Naming used to derive every resource path below.
REGION = "us-central1"
SERIES = "frameworks"
EXPERIMENT = "pytorch-autoencoder"

# Cloud Storage: bucket named after the project, and the model artifact inside it.
BUCKET_URI = f"gs://{PROJECT_ID}"
MODEL_PATH = f"{BUCKET_URI}/{SERIES}/{EXPERIMENT}/dataflow/final_model_traced.pt"

# Pub/Sub: subscription the pipeline reads from and topic it writes to.
_pubsub_root = f"projects/{PROJECT_ID}"
INPUT_SUB = f"{_pubsub_root}/subscriptions/{EXPERIMENT}-input-sub"
OUTPUT_TOPIC = f"{_pubsub_root}/topics/{EXPERIMENT}-output"

# BigQuery: dataset name is SERIES with hyphens turned into underscores.
BQ_DATASET = SERIES.replace("-", "_")
BQ_TABLE = f"{EXPERIMENT}_streaming_results"

print(f"Input: {INPUT_SUB}")
print(f"Output: {OUTPUT_TOPIC}")

---
## Create ModelHandler

In [None]:
class PyTorchAutoencoderHandler(PytorchModelHandlerTensor):
    """Model handler that turns autoencoder outputs into plain Python dicts.

    The model is expected to return a mapping with the entries
    "denormalized_MAE" and "encoded", each indexable per batch element.
    """

    def run_inference(self, batch, model, inference_args=None):
        """Run the model on one batch and return one result dict per element.

        Args:
            batch: The examples of one batch; either a sequence of per-element
                tensors (what RunInference passes) or an already stacked tensor.
            model: The loaded PyTorch model, called as ``model(tensor, **inference_args)``.
            inference_args: Optional extra keyword arguments forwarded to the model.

        Returns:
            A list of dicts with "anomaly_score" (float) and "encoded" (list),
            in the same order as the batch.
        """
        inference_args = inference_args or {}
        # RunInference hands over a sequence of individual tensors; the model
        # needs a single batch tensor, so stack them along a new first axis.
        if isinstance(batch, torch.Tensor):
            stacked = batch
        else:
            stacked = torch.stack(list(batch))
        with torch.no_grad():
            predictions = model(stacked, **inference_args)
        scores = predictions["denormalized_MAE"]
        encodings = predictions["encoded"]
        return [
            {"anomaly_score": float(score.item()), "encoded": encoding.tolist()}
            for score, encoding in zip(scores, encodings)
        ]

# MODEL_PATH points at "final_model_traced.pt" and no model class is available
# here, so load the artifact as a TorchScript model. Passing it as
# state_dict_path with model_class=None is rejected by the Beam handler, which
# requires a model class whenever a state dict path is given.
# NOTE(review): presumes the .pt file was saved with torch.jit — confirm.
model_handler = PyTorchAutoencoderHandler(torch_script_model_path=MODEL_PATH, device="cpu")
print("✅ ModelHandler created")

---
## Build Streaming Pipeline

In [None]:
def parse_json(message):
    """Decode a UTF-8 JSON Pub/Sub payload and return its "features" as a float32 tensor."""
    payload = json.loads(message.decode("utf-8"))
    features = payload["features"]
    return torch.tensor(features, dtype=torch.float32)

def format_result(element, window=beam.DoFn.WindowParam):
    """Format one inference result as a row for Pub/Sub and BigQuery.

    Args:
        element: Either the prediction dict itself (keys "anomaly_score" and
            "encoded"), which is what the custom handler in this file returns
            from run_inference, or a two-item (key, prediction) sequence.
        window: The window of the element; Beam injects it through the
            ``beam.DoFn.WindowParam`` default.

    Returns:
        A dict with "instance_id", "anomaly_score", "encoded", "timestamp"
        (formatting time, UTC) and the window bounds as ISO strings.
    """
    # Local import so the function does not depend on a module-level import.
    import hashlib

    if isinstance(element, dict):
        # Un-keyed result: indexing it with [0]/[1] would raise KeyError, and
        # there is no key to identify it, so derive the id from its content.
        prediction = element
        id_source = json.dumps(prediction, sort_keys=True)
    else:
        prediction = element[1]
        id_source = str(element[0])

    # The built-in hash() of a str is salted per Python process, so it gives
    # different ids on different workers and on retries; a digest is stable.
    instance_id = hashlib.sha256(id_source.encode("utf-8")).hexdigest()

    return {
        "instance_id": instance_id,
        "anomaly_score": prediction["anomaly_score"],
        "encoded": prediction["encoded"],
        "timestamp": datetime.utcnow().isoformat(),
        "window_start": window.start.to_utc_datetime().isoformat(),
        "window_end": window.end.to_utc_datetime().isoformat()
    }

def to_json(element):
    """Serialize *element* to JSON text and return it as UTF-8 bytes for Pub/Sub."""
    serialized = json.dumps(element)
    return serialized.encode("utf-8")

# Options for a streaming job on the Dataflow runner. Temp and staging files go
# under BUCKET_URI, and the job name carries a timestamp so each launch gets
# its own name.
options = PipelineOptions([
    f"--project={PROJECT_ID}",
    f"--region={REGION}",
    "--runner=DataflowRunner",
    f"--temp_location={BUCKET_URI}/dataflow/temp",
    f"--staging_location={BUCKET_URI}/dataflow/staging",
    f"--job_name=pytorch-streaming-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
    "--streaming",
    # Boolean switch: pass it bare, like --streaming. Giving a switch an
    # explicit "=True" value is rejected when the options are parsed.
    "--save_main_session",
])

print("✅ Streaming pipeline configured")

### Run Streaming Job

In [None]:
# Build the streaming graph one named step at a time, then submit it.
p = beam.Pipeline(options=options)

raw_messages = p | "Read from Pub/Sub" >> ReadFromPubSub(subscription=INPUT_SUB)
feature_tensors = raw_messages | "Parse JSON" >> beam.Map(parse_json)
windowed_tensors = feature_tensors | "Window (1 min)" >> beam.WindowInto(window.FixedWindows(60))
predictions = windowed_tensors | "RunInference" >> RunInference(model_handler)
results = predictions | "Format results" >> beam.Map(format_result)

# Sink 1: JSON-encoded rows to the output Pub/Sub topic.
json_payloads = results | "To JSON" >> beam.Map(to_json)
_ = json_payloads | "Write to Pub/Sub" >> WriteToPubSub(topic=OUTPUT_TOPIC)

# Sink 2: the same rows appended to the BigQuery results table.
# NOTE(review): no schema is passed, so this presumably relies on the table
# already existing — confirm.
bq_destination = f"{PROJECT_ID}:{BQ_DATASET}.{BQ_TABLE}"
_ = results | "Write to BigQuery" >> WriteToBigQuery(
    table=bq_destination,
    write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
)

# Submit the job; run() is not followed by a wait, so the cell returns here.
result = p.run()
print("\n✅ Streaming job started!")
print(f"Monitor: https://console.cloud.google.com/dataflow/jobs/{REGION}?project={PROJECT_ID}")
print("\n⚠️  Job will run continuously until canceled")

---
## Simulate Streaming Data

In [None]:
import time

from google.cloud import pubsub_v1

publisher = pubsub_v1.PublisherClient()
topic_path = publisher.topic_path(PROJECT_ID, f"{EXPERIMENT}-input")

# Send test messages
for i in range(5):
    message = {"features": [0.1] * 30}  # Dummy transaction
    future = publisher.publish(topic_path, json.dumps(message).encode("utf-8"))
    # publish() only returns a future. Resolve it so a failed publish raises
    # here instead of being dropped silently, and so the line below is only
    # printed for messages that were actually accepted.
    future.result(timeout=60)
    print(f"Published message {i+1}")
    time.sleep(2)

print("\n✅ Sent 5 test messages")

### Monitor Results

In [None]:
from google.cloud import bigquery

# Fetch the ten most recent rows of the streaming results table.
bq = bigquery.Client(project=PROJECT_ID)
results_table = f"{PROJECT_ID}.{BQ_DATASET}.{BQ_TABLE}"
query = f"SELECT * FROM `{results_table}` ORDER BY timestamp DESC LIMIT 10"
query_job = bq.query(query)
df = query_job.to_dataframe()
print(f"Latest {len(df)} results:")
df

---
## Clean Up

⚠️ **Important**: Cancel the streaming job when you are finished. It keeps running, and incurring charges, until it is canceled.

```python
# Cancel job in Cloud Console or use:
# gcloud dataflow jobs cancel JOB_ID --region=us-central1
```

---
## Summary

✅ Built streaming Dataflow pipeline

✅ Real-time RunInference with PyTorch

✅ Windowed processing (1-min windows)

✅ Dual output (Pub/Sub + BigQuery)

### Next: [Vertex Endpoint Integration](./dataflow-vertex-endpoint.ipynb)