# Moderation Model Training (Streaming Only)

This notebook trains a moderation classifier on the `moderation.training_data` Iceberg table. It reads the table with PyIceberg batch streaming and fits the model incrementally with `SGDClassifier.partial_fit`.

In this notebook, we:

1. connect to Iceberg
2. capture the current snapshot id
3. train in streaming batches
4. save artifacts locally

In [None]:
import json
from datetime import datetime, timezone
from pathlib import Path

import joblib
import numpy as np
from pyiceberg.catalog import load_catalog
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Model inputs. The order of each list fixes the column layout produced by the
# preprocessing step (numeric block first, then one-hot categorical block), so
# keep it stable between training and serving.
numeric_features = [
    "time_since_upload_seconds",
    "hour_of_day",
    "day_of_week",
    "views_5min",
    "views_1hr",
    "comments_5min",
    "comments_1hr",
    "flags_5min",
    "flags_1hr",
    "total_flags",
    "view_velocity_per_min",
    "comment_to_view_ratio",
    "recent_engagement_score",
    "caption_length",
    "user_image_count",
    "user_age_days",
]

# Treated as strings and one-hot encoded.
categorical_features = [
    "is_weekend",
    "has_caption",
    "category",
]

# Binary training target (cast to int before fitting).
label_column = "label_needs_moderation_24h"

print("Imports ready")


## Load Iceberg table and capture snapshot id

Iceberg gives us reproducibility through snapshots. We record the exact snapshot used for training so we can trace model artifacts back to a precise dataset version.

In [None]:
catalog = load_catalog("gourmetgram")
table = catalog.load_table("moderation.training_data")

# Capture the snapshot that is current at load time so the trained artifacts
# can be traced back to one exact dataset version.
snapshot_id = table.metadata.current_snapshot_id
if snapshot_id is None:
    # A table without any snapshot has no data; fail here instead of silently
    # producing an untrained model further down.
    raise RuntimeError(
        "moderation.training_data has no snapshots yet; there is nothing to train on."
    )

print(f"Using snapshot id: {snapshot_id}")
# Table.name() is the public accessor for the table identifier.
print(f"Table identifier: {table.name()}")


## Streaming training loop

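We stream the table as Arrow record batches instead of loading every row into memory at once. Only one batch at a time is converted to a pandas DataFrame and passed to the model.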

In [None]:
# Read only the columns the model needs, and pin every scan to the snapshot
# captured earlier so the recorded snapshot_id really describes the training data.
scan_columns = tuple(numeric_features + categorical_features + [label_column])


def iter_training_batches():
    """Stream the pinned snapshot as preprocessed pandas batches.

    Yields:
        (X_num, X_cat, y) for every Arrow batch that still has rows after
        dropping rows with a missing label, where X_num holds the numeric
        features with NaN filled by 0, X_cat holds the categorical features
        with NaN filled by "Unknown" and cast to str, and y is the label cast
        to int.
    """
    scan = table.scan(snapshot_id=snapshot_id, selected_fields=scan_columns)
    for batch in scan.to_arrow_batch_reader():
        df_batch = batch.to_pandas()
        # A row without a label cannot be a training target, and astype(int)
        # would raise on the missing value.
        df_batch = df_batch.dropna(subset=[label_column])
        if df_batch.empty:
            continue
        X_num = df_batch[numeric_features].fillna(0)
        X_cat = df_batch[categorical_features].fillna("Unknown").astype(str)
        y = df_batch[label_column].astype(int)
        yield X_num, X_cat, y


# Pass 1: learn preprocessing statistics from the whole snapshot. Fitting on the
# first batch only would bias the scaling and would silently encode every
# category absent from that batch as all-zeros (handle_unknown="ignore").
scaler = StandardScaler()
category_values = {col: set() for col in categorical_features}
category_sample = None  # one real row, used to fit the encoder's input columns
rows_profiled = 0

for X_num, X_cat, _ in iter_training_batches():
    scaler.partial_fit(X_num)
    for col in categorical_features:
        category_values[col].update(X_cat[col].unique())
    category_sample = X_cat.head(1)
    rows_profiled += len(X_num)

if rows_profiled == 0:
    raise RuntimeError(
        f"No labelled rows found in snapshot {snapshot_id}; nothing to train on."
    )

ohe = OneHotEncoder(
    categories=[sorted(category_values[col]) for col in categorical_features],
    handle_unknown="ignore",
    sparse_output=False,
)
# With explicit categories, fit() only needs to learn the input column layout.
ohe.fit(category_sample)

# Pass 2: stream the same snapshot again and train incrementally.
model = SGDClassifier(loss="log_loss", random_state=42)
classes = np.array([0, 1])
total_rows = 0
batch_count = 0

for X_num, X_cat, y in iter_training_batches():
    X_transformed = np.hstack((scaler.transform(X_num), ohe.transform(X_cat)))
    model.partial_fit(X_transformed, y, classes=classes)

    batch_count += 1
    total_rows += len(y)

    if batch_count % 10 == 0:
        print(f"Processed batch {batch_count} | total rows: {total_rows}")

print(f"Training complete | batches: {batch_count} | rows: {total_rows}")


In [None]:
# Persist the fitted model, its preprocessing steps, and provenance metadata.
MODEL_DIR = Path("models")

# Refuse to publish artifacts from a run that processed no rows: the scaler and
# encoder would be unfitted and the model untrained.
if total_rows == 0:
    raise RuntimeError("No training rows were processed; not saving artifacts.")

MODEL_DIR.mkdir(parents=True, exist_ok=True)

artifacts = {
    MODEL_DIR / "model.joblib": model,
    MODEL_DIR / "scaler.joblib": scaler,
    MODEL_DIR / "encoder.joblib": ohe,
}
for path, obj in artifacts.items():
    joblib.dump(obj, path)

metadata = {
    "trained_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    "snapshot_id": str(snapshot_id),
    "rows_seen": int(total_rows),
    "batches_seen": int(batch_count),
    "algorithm": "SGDClassifier(log_loss)",
    # Record the feature contract so inputs can be rebuilt in the same column
    # order that was used at training time.
    "numeric_features": list(numeric_features),
    "categorical_features": list(categorical_features),
    "label_column": label_column,
}

metadata_path = MODEL_DIR / "metadata.json"
metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

print("Saved local artifacts:")
for path in [*artifacts, metadata_path]:
    print(f"- {path.as_posix()}")
print("metadata:", metadata)
