### Deploy registered AG model in MLFlow for RT Inference

In [6]:
! pip install fastavro

Collecting fastavro
  Downloading fastavro-1.12.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (5.7 kB)
Downloading fastavro-1.12.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastavro
Successfully installed fastavro-1.12.0


In [33]:
import boto3
import fastavro
import matplotlib
import mlflow
import mlflow.sagemaker as mfs
from mlflow import MlflowClient
from sagemaker.serve import SchemaBuilder, ModelBuilder, Mode
import sagemaker

In [34]:
bucket = 'ag-example-timeseries'
avro_prefix = 'avro-inf-stream'

# Create S3 client
s3 = boto3.client("s3")


mlflow_uri = "arn:aws:sagemaker:us-east-1:543531862107:mlflow-tracking-server/ag-ex-timeseries"  # with sagemaker-mlflow plugin
mlflow_experiment   = "autogluon-timeseries"
region      = sagemaker.Session().boto_region_name
session     = sagemaker.Session()
role        = sagemaker.get_execution_role() 

### Generate dummy data

In [10]:
import io, json, time, uuid, threading, datetime as dt
import numpy as np
import pandas as pd
import boto3
from fastavro import writer, parse_schema

In [11]:
ITEM_IDS        = ["A", "B"]                        # a couple of series to demo
FREQ_SECS       = 5                                 # new file cadence
HORIZON         = 24  

In [13]:
AVRO_SCHEMA = {
    "type": "record",
    "name": "TimePoint",
    "fields": [
        {"name": "item_id",        "type": "string"},
        {"name": "timestamp",      "type": "string"},  # ISO 8601
        {"name": "target",         "type": ["null", "double"], "default": None},
        {"name": "random_feature", "type": ["null", "double"], "default": None},
    ],
}
PARSED_SCHEMA = parse_schema(AVRO_SCHEMA)

_stop_stream = {"flag": False}

def _sine(i, base=10.0, noise=0.3):
    return base + 2.0*np.sin(i/6.0) + np.random.randn()*noise

def write_dummy_avro_loop():
    i = 0
    while not _stop_stream["flag"]:
        now = pd.Timestamp.utcnow().floor("s")
        rows = []
        for item in ITEM_IDS:
            rows.append({
                "item_id": item,
                "timestamp": (now).isoformat(),
                "target": float(_sine(i + hash(item)%7)),      # last observed target (optional at inference)
                "random_feature": float(np.random.randn()),    # example past/known covariate
            })
        buf = io.BytesIO()
        writer(buf, PARSED_SCHEMA, rows)
        buf.seek(0)
        key = f"{avro_prefix}{now.strftime('%Y/%m/%d/%H%M%S')}_{uuid.uuid4().hex[:8]}.avro"
        s3.upload_fileobj(buf, bucket, key)
        print(f"[gen] wrote {len(rows)} rows to s3://{bucket}/{key}")
        i += 1
        time.sleep(FREQ_SECS)

# start writer thread
t = threading.Thread(target=write_dummy_avro_loop, daemon=True)
t.start()



[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220031_fd1c6ddd.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220036_ad5dac74.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220041_bd4b67bd.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220046_ea79e7ed.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220051_4906e4b8.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220056_b5381fe1.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220102_b0facca5.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220107_1c6d9ef8.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220112_7781ba66.avro
[gen] wrote 2 rows to s3://ag-example-timeseries/avro-inf-stream2025/09/04/220117_fe0234f7.avro
[gen] wrote 2 rows to s3://ag-example-ti

In [25]:
# Stop it later with:
_stop_stream["flag"] = True
t.join()

### Inference 

In [37]:
mlflow.set_tracking_uri(mlflow_uri)
client = MlflowClient()
registered_model = client.get_registered_model(name="ag_ex_model")
source_path = registered_model.latest_versions[0].source

In [40]:
# Define IO

sample_input = pd.DataFrame({
    "item_id":        ["A", "A", "A", "B", "B"],
    "timestamp":      pd.to_datetime([
        "2025-01-01 00:00:00",
        "2025-01-01 01:00:00",
        "2025-01-01 02:00:00",
        "2025-01-01 00:00:00",
        "2025-01-01 01:00:00",
    ]),
    "target":         [10.0, 11.2, 12.1,  8.7,  9.1],   # can be None at inference if you only pass history elsewhere
    "random_feature": [0.2, -0.1,  0.3, -0.4,  0.0],    # optional past covariate
})

sample_output = pd.DataFrame({
    "item_id":   ["A", "A", "B", "B"],
    "timestamp": pd.to_datetime([
        "2025-01-01 03:00:00", "2025-01-01 04:00:00",
        "2025-01-01 02:00:00", "2025-01-01 03:00:00",
    ]),
    "mean":      [12.6, 13.0,  9.4,  9.7],
    # you could also add quantiles like "0.1", "0.5", "0.9"
})

# Build schema from sample input/output
ts_schema_builder = SchemaBuilder(
    sample_input=sample_input,
    sample_output=sample_output,
)

In [41]:
# Create model builder with the schema builder.
source_path = f"s3://{bucket}/mlflow/mlflow-artifacts/model"

model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=ts_schema_builder,
    role_arn=role,
    model_metadata={"MLFLOW_MODEL_PATH": source_path,
                    "SAGEMAKER_DEFAULT_ENDPOINT_NAME": 'ag-ex-endpoint'
                   },
)

# Build & deploy endpoint
built_model = model_builder.build()
predictor = built_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.4xlarge",   # swap to ml.g5 if you need GPU
)

print("Deployed SageMaker endpoint:", predictor.endpoint_name)

ModelBuilder: INFO:     ModelBuilder will collect telemetry to help us better understand our user's needs, diagnose issues, and deliver additional features. To opt out of telemetry, please disable via TelemetryOptOut in intelligent defaults. See https://sagemaker.readthedocs.io/en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk for more info.


In [14]:
from fastavro import reader as avro_reader

def read_avro_s3_to_df(bucket: str, key: str) -> pd.DataFrame:
    bio = io.BytesIO()
    s3.download_fileobj(bucket, key, bio)
    bio.seek(0)
    records = list(avro_reader(bio))
    df = pd.DataFrame(records)
    if "timestamp" in df.columns:
        df["timestamp"] = pd.to_datetime(df["timestamp"])
    return df


In [None]:
import boto3, time, io
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output

smr = boto3.client("sagemaker-runtime", region_name=REGION)
s3  = boto3.client("s3", region_name=REGION)

def list_new_objects(bucket: str, prefix: str, seen: set[str]) -> list[str]:
    keys = []
    token = None
    while True:
        kw = dict(Bucket=bucket, Prefix=prefix, MaxKeys=1000)
        if token: kw["ContinuationToken"] = token
        resp = s3.list_objects_v2(**kw)
        for obj in resp.get("Contents", []):
            k = obj["Key"]
            if k.endswith(".avro") and k not in seen:
                keys.append(k)
        token = resp.get("NextContinuationToken")
        if not token:
            break
    return sorted(keys)

def invoke_endpoint_df(df: pd.DataFrame) -> pd.DataFrame:
    # Expect the endpoint to accept JSON records with item_id, timestamp, optional target, optional random_feature
    body = df.to_json(orient="records", date_format="iso")
    resp = smr.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType="application/json",
        Accept="application/json",
        Body=body,
    )
    # Response is JSON; read to pandas
    out = pd.read_json(io.BytesIO(resp["Body"].read()))
    # Ensure timestamp is datetime if present
    if "timestamp" in out.columns:
        out["timestamp"] = pd.to_datetime(out["timestamp"])
    return out

def live_infer_and_plot(loop_seconds=300, poll_every=5):
    """
    Polls for new Avro files every `poll_every` seconds for up to `loop_seconds`,
    invokes the endpoint with all new rows, and live-updates a simple line chart.
    """
    seen = set()
    all_preds = []  # store all returned predictions

    start = time.time()
    while time.time() - start < loop_seconds:
        new_keys = list_new_objects(bucket, avro_prefix, seen)
        if new_keys:
            dfs = []
            for k in new_keys:
                try:
                    df = read_avro_s3_to_df(bucket, k)
                    # minimal schema guard
                    if {"item_id","timestamp"}.issubset(df.columns):
                        dfs.append(df[["item_id","timestamp"] + [c for c in ["target","random_feature"] if c in df.columns]])
                except Exception as e:
                    print(f"[read] {k} failed: {e}")
                finally:
                    seen.add(k)

            if dfs:
                batch = pd.concat(dfs, ignore_index=True)
                try:
                    preds = invoke_endpoint_df(batch)
                    # expected columns: item_id, timestamp, mean (and/or quantiles)
                    all_preds.append(preds)
                except Exception as e:
                    print(f"[invoke] failed: {e}")

        # Plot
        if all_preds:
            pred_df = pd.concat(all_preds, ignore_index=True)
            # Normalize expected columns (AutoGluon usually returns 'mean' or '0.5' for median)
            y_col = "mean"
            if y_col not in pred_df.columns:
                # try a common quantile column name
                qcols = [c for c in pred_df.columns if c.startswith("0.") or c.startswith("quantile")]
                y_col = qcols[0] if qcols else None

            clear_output(wait=True)
            plt.figure(figsize=(9, 4))
            if y_col and {"item_id","timestamp", y_col}.issubset(pred_df.columns):
                for item, g in pred_df.sort_values("timestamp").groupby("item_id"):
                    plt.plot(g["timestamp"], g[y_col], label=str(item))
                plt.title(f"Live forecasts (last {int(time.time()-start)}s)")
                plt.xlabel("timestamp")
                plt.ylabel(y_col if y_col else "prediction")
                plt.legend()
                plt.tight_layout()
                plt.show()
            else:
                print("[plot] Prediction DataFrame does not have the expected columns:",
                      pred_df.columns.tolist())
        else:
            print("[plot] Waiting for first predictions...")

        time.sleep(poll_every)

# Run the live loop for, say, 3 minutes
# Stop the generator with `_stop_stream["flag"] = True` when you’re done.
live_infer_and_plot(loop_seconds=180, poll_every=FREQ_SECS)


[invoke] failed: name 'ENDPOINT_NAME' is not defined
[plot] Waiting for first predictions...
