In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import dotenv
dotenv.load_dotenv()

import os
from datetime import datetime, timedelta
from urllib.parse import quote

import numpy as np
import pandas as pd
from sqlalchemy import create_engine, Table, Column, MetaData, Integer, Computed, DateTime, Numeric, Float
from sqlalchemy.orm import sessionmaker, declarative_base

In [3]:
import torch
from torchmetrics import Accuracy, Precision, Recall, F1Score, ROC, AUROC, PrecisionRecallCurve, AveragePrecision, ConfusionMatrix

In [4]:
from population_stability_index.psi import calculate_psi
from lift_curve import lift_curve

In [5]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, pandas_udf, monotonically_increasing_id, lit, array, PandasUDFType
from pyspark.sql.types import IntegerType, TimestampType, FloatType, StructType, StructField, ArrayType, BooleanType

In [7]:
# Local Spark session running on 4 threads. The configuration turns on case-sensitive
# column resolution and the Arrow-related PySpark options, and registers the PostgreSQL
# JDBC driver jar that the JDBC writes further down rely on.
spark = (SparkSession.builder
         .master("local[4]")
         .config(key="spark.sql.caseSensitive", value=True)
         .config(key="spark.sql.execution.arrow.pyspark.fallback.enabled", value=True)
         .config(key="spark.sql.execution.arrow.pyspark.enabled", value=True)
         # NOTE(review): this key does not look like a documented Spark option - confirm it has any effect.
         .config(key="spark.sql.execution.arrow.pyspark.datetime64.enabled", value=True)
         # Path is relative to the notebook's working directory.
         .config(key="spark.jars", value="./jar/postgresql-42.6.0.jar")
         .getOrCreate())
# Bare expression so the notebook renders the session summary.
spark

# Load and Preprocess Data

In [7]:
# Two JSON Lines files of already-scored examples; later cells read their "score" and
# "label" columns. `baseline_data` is the reference set, `realtime_data` the monitored set.
baseline_data = spark.read.json("./datasets/baseline_data_predicted.jsonl")
realtime_data = spark.read.json("./datasets/realtime_data_predicted.jsonl")

23/06/02 16:36:17 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
23/06/02 16:36:17 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
23/06/02 16:36:17 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
23/06/02 16:36:17 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
23/06/02 16:36:18 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the f

In [8]:
# Roughly one shard per 100 realtime rows. Clamp to at least one shard: with fewer than
# 100 rows the integer division yields 0, and the `id % num_shards` expression that
# derives "shard_id" would then divide by zero.
num_shards = max(1, realtime_data.count() // 100)
# Single reference instant, broadcast so every executor offsets from the same "now".
timestamp_now = spark.sparkContext.broadcast(datetime.now())


@udf(TimestampType())
def generate_timestamp(shard_id):
    """Return the synthetic timestamp for one shard.

    Args:
        shard_id: Integer shard index of the row.

    Returns:
        The broadcast reference time shifted by ``shard_id - num_shards`` hours, so
        shard indices below ``num_shards`` map to hourly slots in the past.
    """
    return timestamp_now.value + timedelta(hours=shard_id - num_shards)

In [9]:
# Every baseline row gets the same timestamp literal (evaluated once on the driver).
baseline_data = baseline_data.withColumn("timestamp", lit(datetime.now()))
# Realtime rows: assign an id, bucket ids into `num_shards` shards, then give each shard
# its own timestamp so that grouping by "timestamp" groups by shard.
realtime_data = realtime_data.withColumn("id", monotonically_increasing_id())
realtime_data = realtime_data.withColumn("shard_id", realtime_data["id"] % num_shards)
realtime_data = realtime_data.withColumn("timestamp", generate_timestamp(realtime_data["shard_id"]))

# Set Up Database

In [16]:
# The password comes from the environment so it never appears in the notebook itself.
POSTGRES_PASSWORD = os.environ["POSTGRES_PASSWORD"]
# Percent-encode the password before embedding it: characters such as "@", "/" or ":"
# would otherwise be parsed as URL delimiters and break the connection string.
POSTGRES_URL = f"postgresql+psycopg://postgres:{quote(POSTGRES_PASSWORD, safe='')}@localhost/postgres"
# SQLAlchemy engine/session factory, used below to create the tables via ORM metadata.
engine = create_engine(POSTGRES_URL)
Session = sessionmaker(bind=engine)
def get_db_options(table):
    """Build the option dict for a Spark JDBC operation on the local ``postgres`` database.

    Args:
        table: Name of the database table, passed through as the ``dbtable`` option.

    Returns:
        dict: Options suitable for ``.options(**get_db_options(table))``.
    """
    return dict(
        url="jdbc:postgresql://localhost:5432/postgres",
        dbtable=table,
        user="postgres",
        password=POSTGRES_PASSWORD,
        driver="org.postgresql.Driver",
        # With mode("overwrite") Spark by default DROPs and re-creates the target table
        # using its own inferred schema, which throws away the primary keys and column
        # definitions created through SQLAlchemy. `truncate` keeps the existing table and
        # only replaces its rows.
        truncate="true",
    )

# Common

In [10]:
@udf(StructType([StructField(f"_{slot}", FloatType()) for slot in range(100)]))
def unpack_thresholds(thresholds: list, values: list) -> tuple:
    """Spread a curve onto a fixed grid of 100 struct fields ``_0`` .. ``_99``.

    Args:
        thresholds: Threshold of each curve point; ``round(threshold * 100)`` selects
            the slot the point is stored in.
        values: Curve value of each point, paired positionally with ``thresholds``
            (pairing stops at the shorter of the two sequences).

    Returns:
        A 100-tuple; slots that received no point stay ``None``.
    """
    slots = [None] * 100
    for threshold, value in zip(thresholds, values):
        slot = round(threshold * 100)
        slots[slot] = value
    return tuple(slots)

@udf(StructType([
    StructField(f"_{i}", FloatType())
    for i in range(100)
]))
def threshold_value_baseline_diff(baseline: list, *values: float) -> tuple:
    """Subtract the baseline curve from one row's per-threshold values.

    Args:
        baseline: Baseline value for each threshold slot, in slot order.
        *values: The row's value for each slot, in the same order.

    Returns:
        Tuple of ``value - baseline`` per slot. A slot is ``None`` when either side is
        missing: the unpacking step leaves unfilled slots as ``None``, and subtracting
        from ``None`` would otherwise raise ``TypeError`` and fail the whole job.
    """
    return tuple(
        None if v is None or b is None else v - b
        for v, b in zip(values, baseline)
    )

@udf(StructType([StructField(f"decile_{rank}", FloatType()) for rank in range(1, 11)]))
def unpack_deciles(deciles: list, values: list) -> tuple:
    """Spread per-decile values onto ten struct fields ``decile_1`` .. ``decile_10``.

    Args:
        deciles: Decile position of each point; ``round(decile * 10) - 1`` selects the
            zero-based slot the point is stored in.
        values: Value of each point, paired positionally with ``deciles``.

    Returns:
        A 10-tuple; slots that received no point stay ``None``.
    """
    slots = [None] * 10
    for decile, value in zip(deciles, values):
        slot = round(decile * 10) - 1
        slots[slot] = value
    return tuple(slots)

@udf(StructType([
    StructField(f"decile_{i}", FloatType())
    for i in range(1, 11)
]))
def decile_value_baseline_diff(baseline: list, *values: float) -> tuple:
    """Subtract the baseline per-decile values from one row's per-decile values.

    Args:
        baseline: Baseline value for each decile slot, in slot order.
        *values: The row's value for each decile slot, in the same order.

    Returns:
        Tuple of ``value - baseline`` per slot. A slot is ``None`` when either side is
        missing: the unpacking step leaves unfilled slots as ``None``, and subtracting
        from ``None`` would otherwise raise ``TypeError`` and fail the whole job.
    """
    return tuple(
        None if v is None or b is None else v - b
        for v, b in zip(values, baseline)
    )



# Accuracy

In [None]:
# Fresh declarative base for this cell's table.
Base = declarative_base()


class AccuracyTable(Base):
    """ORM mapping for ``acc_table``: baseline vs. realtime accuracy per timestamp."""

    __tablename__ = 'acc_table'
    # Name of the metric column this table stores.
    __metricname__ = 'accuracy'

    timestamp = Column(DateTime, primary_key=True)
    baseline = Column(Float)
    realtime = Column(Float)


# Emit CREATE TABLE for every table registered on this base.
Base.metadata.create_all(engine)

@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def accuracy(score: pd.Series, label: pd.Series) -> float:
    """Grouped-aggregate UDF: torchmetrics binary ``Accuracy`` of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The metric value as a Python scalar.
    """
    metric = Accuracy(task="binary")
    preds, target = torch.tensor(score), torch.tensor(label)
    return metric(preds=preds, target=target).item()

# One shared aggregate expression: accuracy over the whole baseline frame, and accuracy
# per "timestamp" group for the realtime frame.
agg_expr = accuracy("score", "label").alias("accuracy")
baseline_acc = baseline_data.agg(agg_expr).toPandas().iloc[0]
realtime_acc = realtime_data.groupby("timestamp").agg(agg_expr)

# Write (timestamp, baseline, realtime) rows to acc_table; the scalar baseline accuracy
# is repeated on every row as a literal.
(realtime_acc
 .select("timestamp",
         lit(baseline_acc["accuracy"]).alias("baseline"),
         realtime_acc["accuracy"].alias("realtime"))
 .write.format('jdbc')
 .options(**get_db_options("acc_table"))
 .mode("overwrite")
 .save())

# Precision

In [None]:
@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def precision(score: pd.Series, label: pd.Series) -> float:
    """Grouped-aggregate UDF: torchmetrics binary ``Precision`` of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The metric value as a Python scalar.
    """
    metric = Precision(task="binary")
    preds, target = torch.tensor(score), torch.tensor(label)
    return metric(preds=preds, target=target).item()

# Recall

In [None]:
@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def recall(score: pd.Series, label: pd.Series) -> float:
    """Grouped-aggregate UDF: torchmetrics binary ``Recall`` of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The metric value as a Python scalar.
    """
    metric = Recall(task="binary")
    preds, target = torch.tensor(score), torch.tensor(label)
    return metric(preds=preds, target=target).item()

# F1 Score

In [None]:

@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def f1score(score: pd.Series, label: pd.Series) -> float:
    """Grouped-aggregate UDF: torchmetrics binary ``F1Score`` of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The metric value as a Python scalar.
    """
    metric = F1Score(task="binary")
    preds, target = torch.tensor(score), torch.tensor(label)
    return metric(preds=preds, target=target).item()

# AUROC

In [None]:

@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def auroc(score: pd.Series, label: pd.Series) -> float:
    """Grouped-aggregate UDF: torchmetrics binary ``AUROC`` of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The metric value as a Python scalar.
    """
    metric = AUROC(task="binary")
    preds, target = torch.tensor(score), torch.tensor(label)
    return metric(preds=preds, target=target).item()

# AUPRC

In [None]:

@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def auprc(score: pd.Series, label: pd.Series) -> float:
    """Grouped-aggregate UDF: torchmetrics binary ``AveragePrecision`` of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The metric value as a Python scalar.
    """
    metric = AveragePrecision(task="binary")
    preds, target = torch.tensor(score), torch.tensor(label)
    return metric(preds=preds, target=target).item()

# Confusion Matrix

In [None]:

@pandas_udf(ArrayType(FloatType()), PandasUDFType.GROUPED_AGG)
def confusion(score: pd.Series, label: pd.Series) -> list[float]:
    """Grouped-aggregate UDF: flattened torchmetrics binary confusion matrix of one group.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        The matrix entries flattened row by row into a flat list. Per the torchmetrics
        documentation the binary matrix is ``[[tn, fp], [fn, tp]]`` - confirm against
        the installed version.
    """
    metric = ConfusionMatrix(task="binary")
    matrix = metric(preds=torch.tensor(score), target=torch.tensor(label))
    return matrix.flatten().tolist()

# Field order must mirror the flattened matrix. torchmetrics lays the binary confusion
# matrix out as [[tn, fp], [fn, tp]] (rows = true class, columns = predicted class), so
# flattening row by row gives tn, fp, fn, tp. Declaring "fn" before "fp" would silently
# swap the false-positive and false-negative counts.
@udf(StructType([
    StructField("tn", FloatType()),
    StructField("fp", FloatType()),
    StructField("fn", FloatType()),
    StructField("tp", FloatType()),
]))
def confusion_struct(confmat: list) -> tuple:
    """Turn a flattened 2x2 confusion matrix into a named struct.

    Args:
        confmat: Four matrix entries in flattened (row-major) order.

    Returns:
        The same four values as a tuple, matched positionally to the struct fields
        ``tn``, ``fp``, ``fn``, ``tp``.
    """
    return tuple(confmat)


# ROC Curve

In [None]:

@pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.GROUPED_AGG)
def roccurve(score: pd.Series, label: pd.Series) -> list[list[float]]:
    """Grouped-aggregate UDF: torchmetrics binary ROC curve of one group.

    The curve is evaluated on fixed thresholds ``0.00, 0.01, ..., 0.99``.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        One inner list per tensor returned by ``ROC`` (a list of lists, matching the
        declared nested array return type).
    """
    roc = ROC(task="binary", thresholds=torch.arange(0, 1, 0.01))(preds=torch.tensor(score), target=torch.tensor(label))
    return [c.tolist() for c in roc]

@udf(StructType([
    StructField(field, ArrayType(FloatType()))
    for field in ("rocfpr", "roctpr", "rocthresh")
]))
def roccurve_struct(roc: list) -> tuple:
    """Name the arrays of a ROC result.

    Args:
        roc: Sequence of arrays, matched positionally to the struct fields
            ``rocfpr``, ``roctpr``, ``rocthresh``.

    Returns:
        The same arrays as a tuple.
    """
    return tuple(roc)

# PR Curve

In [None]:

@pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.GROUPED_AGG)
def prccurve(score: pd.Series, label: pd.Series) -> list[list[float]]:
    """Grouped-aggregate UDF: torchmetrics binary precision-recall curve of one group.

    The curve is evaluated on fixed thresholds ``0.00, 0.01, ..., 0.99``.

    Args:
        score: Predicted scores of the group's rows.
        label: Ground-truth labels of the same rows.

    Returns:
        One inner list per tensor returned by ``PrecisionRecallCurve`` (a list of
        lists, matching the declared nested array return type).
    """
    prc = PrecisionRecallCurve(task="binary", thresholds=torch.arange(0, 1, 0.01))(preds=torch.tensor(score), target=torch.tensor(label))
    return [c.tolist() for c in prc]

@udf(StructType([
    StructField(field, ArrayType(FloatType()))
    for field in ("prcprec", "prcrec", "prcthresh")
]))
def prccurve_struct(roc: list) -> tuple:
    """Name the arrays of a precision-recall result.

    Args:
        roc: Sequence of arrays, matched positionally to the struct fields
            ``prcprec``, ``prcrec``, ``prcthresh``.

    Returns:
        The same arrays as a tuple.
    """
    return tuple(roc)

# Lift Curve

In [None]:

@pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.GROUPED_AGG)
def liftcurve(score: pd.Series, label: pd.Series) -> list[list[float]]:
    """Grouped-aggregate UDF: lift curve of one group, computed in steps of 0.1.

    Args:
        score: Predicted scores of the group's rows (passed as ``y_pred``).
        label: Ground-truth labels of the same rows (passed as ``y_val``).

    Returns:
        One inner list per array returned by ``lift_curve`` (a list of lists, matching
        the declared nested array return type).
        NOTE(review): presumably (deciles, lift values), as that is how the result is
        named downstream - verify against the ``lift_curve`` helper.
    """
    lift = lift_curve(y_val=torch.tensor(label), y_pred=torch.tensor(score), step=0.1)
    return [c.tolist() for c in lift]

@udf(StructType([
    StructField(field, ArrayType(FloatType()))
    for field in ("decile", "lift")
]))
def liftcurve_struct(roc: list) -> tuple:
    """Name the arrays of a lift-curve result.

    Args:
        roc: Sequence of arrays, matched positionally to the struct fields
            ``decile`` and ``lift``.

    Returns:
        The same arrays as a tuple.
    """
    return tuple(roc)

In [None]:
# Fresh declarative base for the two lift tables.
Base = declarative_base()


class LiftTable(Base):
    """ORM mapping for ``lift_table``; the decile columns are attached after the class bodies."""

    __tablename__ = 'lift_table'

    timestamp = Column(DateTime, primary_key=True)


class LiftDiffTable(Base):
    """ORM mapping for ``lift_diff_table``; the decile columns are attached after the class bodies."""

    __tablename__ = 'lift_diff_table'

    timestamp = Column(DateTime, primary_key=True)


# Both tables get the same ten float columns decile_1 .. decile_10.
for decile in range(1, 11):
    setattr(LiftTable, f"decile_{decile}", Column(Float))
    setattr(LiftDiffTable, f"decile_{decile}", Column(Float))

Base.metadata.create_all(engine)

In [None]:
# Requires `baseline_metrics` and `realtime_metrics`, which are assigned in the
# "Transform Data" section - run that section before this cell.

# Baseline lift laid out on the ten decile slots (plain Python call of the UDF's function).
baseline_values = list(unpack_deciles.func(baseline_metrics["decile"], baseline_metrics["lift"]))
# Realtime lift per timestamp, one column per decile. Cached because both the value
# write and the diff computation read it.
value = (realtime_metrics
         .select("timestamp", unpack_deciles("decile", "lift").alias("struct"))
         .select("timestamp", "struct.*")).cache()
try:
    # Per-decile difference between each realtime row and the baseline.
    diff = (value
            .select("timestamp", decile_value_baseline_diff(lit(baseline_values), *[f"decile_{i}" for i in range(1, 11)]).alias("struct"))
            .select("timestamp", "struct.*"))
    (value
     .write.format('jdbc')
     .options(**get_db_options(LiftTable.__tablename__))
     .mode("overwrite")
     .save())
    (diff
     .write.format('jdbc')
     .options(**get_db_options(LiftDiffTable.__tablename__))
     .mode("overwrite")
     .save())
finally:
    # Release the cached frame once both writes are done (or one of them failed);
    # otherwise it stays pinned in executor memory for the rest of the session.
    value.unpersist()

# PSI

In [None]:
# Fresh declarative base for the PSI table.
Base = declarative_base()

class PSITable(Base):
    """ORM mapping for ``psi_table``: one PSI value per timestamp."""
    __tablename__ = "psi_table"
    # Name of the metrics column that feeds this table's "value" column.
    __metricname__ = "psi"

    timestamp = Column(DateTime, primary_key=True)
    value = Column(Float)

# Emit CREATE TABLE for every table registered on this base.
Base.metadata.create_all(engine)

# Baseline scores collected to the driver as a flat NumPy array and broadcast once, so
# every PSI task reads the same reference sample.
baseline_score = spark.sparkContext.broadcast(
    baseline_data.select("score").toPandas().to_numpy().flatten()
)


@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def psi(score: pd.Series) -> float:
    """Grouped-aggregate UDF: ``calculate_psi`` of one group's scores vs. the baseline scores.

    Args:
        score: Predicted scores of the group's rows.

    Returns:
        Whatever ``calculate_psi`` returns for (broadcast baseline scores, group scores).
    """
    reference_scores = baseline_score.value
    group_scores = np.array(score)
    return calculate_psi(reference_scores, group_scores)


# Requires `realtime_metrics` (with its "psi" column), which is assigned in the
# "Transform Data" section - run that section before this write.
# Stores the per-timestamp PSI under the column name "value".
(realtime_metrics
 .select("timestamp", realtime_metrics[PSITable.__metricname__].alias("value"))
 .write.format('jdbc')
 .options(**get_db_options(PSITable.__tablename__))
 .mode("overwrite")
 .save())


# Transform Data

In [11]:
def eval_performance(dataset, compute_psi=False):
    """Compute every monitoring metric per ``timestamp`` group of ``dataset``.

    Args:
        dataset: Spark DataFrame with ``score``, ``label`` and ``timestamp`` columns.
        compute_psi: When true, additionally aggregate the ``psi`` UDF over ``score``.

    Returns:
        Spark DataFrame with one row per timestamp holding the scalar metrics plus the
        expanded fields of the confusion/ROC/PR/lift structs; the raw array columns and
        the intermediate struct columns are dropped.
    """
    score_col, label_col = dataset["score"], dataset["label"]

    # Output column name -> grouped-aggregate UDF, in output order.
    metric_udfs = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1score": f1score,
        "confusion": confusion,
        "roccurve": roccurve,
        "auroc": auroc,
        "prccurve": prccurve,
        "auprc": auprc,
        "liftcurve": liftcurve,
    }
    aggregations = [fn(score_col, label_col).alias(name) for name, fn in metric_udfs.items()]
    if compute_psi:
        aggregations.append(psi(score_col).alias("psi"))

    # Raw array column -> UDF that converts it into a named struct.
    struct_udfs = {
        "confusion": confusion_struct,
        "roccurve": roccurve_struct,
        "prccurve": prccurve_struct,
        "liftcurve": liftcurve_struct,
    }
    metrics = dataset.groupby("timestamp").agg(*aggregations)
    for raw_name, to_struct in struct_udfs.items():
        metrics = metrics.withColumn(f"{raw_name}_struct", to_struct(raw_name))

    # Flatten the structs into top-level columns, then discard the helper columns.
    struct_names = [f"{raw_name}_struct" for raw_name in struct_udfs]
    return (metrics
            .select("*", *(f"{struct_name}.*" for struct_name in struct_names))
            .drop(*struct_udfs)
            .drop(*struct_names))

In [12]:
# Baseline: collected to the driver and reduced to the first row as a pandas Series
# (every baseline row carries the same timestamp literal, so they form one group).
baseline_metrics = eval_performance(baseline_data).toPandas().iloc[0]
# Realtime: stays a Spark DataFrame with one row per timestamp, PSI included.
realtime_metrics = eval_performance(realtime_data, compute_psi=True)

                                                                                

# Push Basic Metrics

In [17]:


# Fresh declarative base for this cell, as in the other table-definition cells. Without
# it the classes attach to whichever `Base` an earlier cell happened to leave behind, and
# re-running the cell fails because the tables are already registered on that base.
Base = declarative_base()


class PrecisionTable(Base):
    """ORM mapping for ``prec_table``: baseline vs. realtime precision per timestamp."""

    __tablename__ = 'prec_table'
    # Name of the metrics column this table stores.
    __metricname__ = 'precision'

    timestamp = Column(DateTime, primary_key=True)
    baseline = Column(Float)
    realtime = Column(Float)


class RecallTable(Base):
    """ORM mapping for ``rec_table``: baseline vs. realtime recall per timestamp."""

    __tablename__ = 'rec_table'
    __metricname__ = 'recall'

    timestamp = Column(DateTime, primary_key=True)
    baseline = Column(Float)
    realtime = Column(Float)


class F1Table(Base):
    """ORM mapping for ``f1_table``: baseline vs. realtime F1 score per timestamp."""

    __tablename__ = 'f1_table'
    __metricname__ = 'f1score'

    timestamp = Column(DateTime, primary_key=True)
    baseline = Column(Float)
    realtime = Column(Float)


class AUROCTable(Base):
    """ORM mapping for ``auroc_table``: baseline vs. realtime AUROC per timestamp."""

    __tablename__ = 'auroc_table'
    __metricname__ = 'auroc'

    timestamp = Column(DateTime, primary_key=True)
    baseline = Column(Float)
    realtime = Column(Float)


class AUPRCTable(Base):
    """ORM mapping for ``auprc_table``: baseline vs. realtime AUPRC per timestamp."""

    __tablename__ = 'auprc_table'
    __metricname__ = 'auprc'

    timestamp = Column(DateTime, primary_key=True)
    baseline = Column(Float)
    realtime = Column(Float)


# Emit CREATE TABLE for the five tables registered on this base.
Base.metadata.create_all(engine)

In [18]:
# For every scalar metric, write one row per realtime timestamp with the (constant)
# baseline value next to the realtime value. The metric column and the target table are
# looked up from each ORM class's __metricname__ / __tablename__.
for table in AccuracyTable, PrecisionTable, RecallTable, F1Table, AUROCTable, AUPRCTable:
    (realtime_metrics
     .select("timestamp",
             lit(baseline_metrics[table.__metricname__]).alias("baseline"),
             realtime_metrics[table.__metricname__].alias("realtime"))
     .write.format('jdbc')
     .options(**get_db_options(table.__tablename__))
     .mode("overwrite")
     .save())

23/06/02 16:37:17 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

# Push ROC and PRC Curves

In [None]:
# Fresh declarative base for the curve tables.
Base = declarative_base()


def make_curve_table_class(type_name, table_name, metric_name, thresh_name, thresh_interval=1):
    """Create an ORM class with one float column per threshold slot.

    Args:
        type_name: Name of the generated class.
        table_name: Database table name (``__tablename__``).
        metric_name: Stored on the class as ``__metricname__``.
        thresh_name: Stored on the class as ``__threshname__``.
        thresh_interval: Step between slot indices; columns ``_0``, ``_{step}``, ...
            below 100 are created.

    Returns:
        A new subclass of ``Base`` with a ``timestamp`` primary key and the slot columns.
    """
    namespace = {
        "__tablename__": table_name,
        "__metricname__": metric_name,
        "__threshname__": thresh_name,
        "timestamp": Column(DateTime, primary_key=True),
    }
    for slot in range(0, 100, thresh_interval):
        namespace[f"_{slot}"] = Column(Float)
    return type(type_name, (Base,), namespace)

# Step between the `_i` slot columns of every curve table.
thresh_interval = 1

# Curve family -> table-name prefix, plus the names of the metrics columns that hold the
# curve values and the matching thresholds.
table_metadata = {
    "ROCFPR": dict(table_prefix="roc_fp", metric_name="rocfpr", thresh_name="rocthresh"),
    "ROCTPR": dict(table_prefix="roc_tp", metric_name="roctpr", thresh_name="rocthresh"),
    "PRCPrecision": dict(table_prefix="prc_prec", metric_name="prcprec", thresh_name="prcthresh"),
    "PRCRecall": dict(table_prefix="prc_rec", metric_name="prcrec", thresh_name="prcthresh"),
}

# For each family: a (value table, diff table) pair of generated ORM classes.
tables = {
    prefix: tuple(
        make_curve_table_class(
            type_name=prefix + type_suffix,
            table_name=meta["table_prefix"] + table_suffix,
            metric_name=meta["metric_name"],
            thresh_name=meta["thresh_name"],
            thresh_interval=thresh_interval,
        )
        for type_suffix, table_suffix in (("Table", "_table"), ("DiffTable", "_diff_table"))
    )
    for prefix, meta in table_metadata.items()
}

Base.metadata.create_all(engine)

In [None]:
# Requires `tables`, `baseline_metrics` and `realtime_metrics` from earlier cells.
for value_table, diff_table in tables.values():
    threshname = value_table.__threshname__
    metricname = value_table.__metricname__
    # Baseline curve laid out on the 100 threshold slots (plain Python call of the UDF's function).
    baseline_values = list(unpack_thresholds.func(baseline_metrics[threshname], baseline_metrics[metricname]))
    # Realtime curve per timestamp, one column per slot. Cached because both the value
    # write and the diff computation read it.
    value = (realtime_metrics
             .select("timestamp", unpack_thresholds(threshname, metricname).alias("struct"))
             .select("timestamp", "struct.*")).cache()
    try:
        # Per-slot difference between each realtime row and the baseline curve.
        diff = (value
                .select("timestamp", threshold_value_baseline_diff(lit(baseline_values), *[f"_{i}" for i in range(100)]).alias("struct"))
                .select("timestamp", "struct.*"))
        (value
         .write.format('jdbc')
         .options(**get_db_options(value_table.__tablename__))
         .mode("overwrite")
         .save())
        (diff
         .write.format('jdbc')
         .options(**get_db_options(diff_table.__tablename__))
         .mode("overwrite")
         .save())
    finally:
        # Release this iteration's cached frame; otherwise every pass through the loop
        # leaves another DataFrame pinned in executor memory.
        value.unpersist()

# Push Model Prediction Scores

In [21]:
# Fresh declarative base for the two score tables.
Base = declarative_base()

class PositiveScoreTable(Base):
    """ORM mapping for ``pos_score_table``: one row per scored example."""
    __tablename__ = 'pos_score_table'

    # Surrogate key: several rows can share the same timestamp.
    id = Column(Integer, primary_key=True, autoincrement=True)
    timestamp = Column(DateTime)
    score = Column(Float)

class NegativeScoreTable(Base):
    """ORM mapping for ``neg_score_table``: one row per scored example."""
    __tablename__ = 'neg_score_table'

    # Surrogate key: several rows can share the same timestamp.
    id = Column(Integer, primary_key=True, autoincrement=True)
    timestamp = Column(DateTime)
    score = Column(Float)

# Emit CREATE TABLE for both tables.
Base.metadata.create_all(engine)

In [22]:
@udf(BooleanType())
def filter_positive(label):
    """Row predicate: true when ``label`` equals 1."""
    is_positive = label == 1
    return is_positive

@udf(BooleanType())
def filter_negative(label):
    """Row predicate: true when ``label`` equals 0."""
    is_negative = label == 0
    return is_negative

In [23]:
# Split the realtime scores by label and store each half in its own table.
# The filters are native column comparisons rather than Python UDFs: Spark can optimise
# them and rows are not shipped to a Python worker just to test equality. Rows with a
# null label match neither filter, as before.
(realtime_data
 .filter(realtime_data["label"] == 1)
 .select("timestamp", "score")
 .write.format('jdbc')
 .options(**get_db_options(PositiveScoreTable.__tablename__))
 .mode("overwrite")
 .save())
(realtime_data
 .filter(realtime_data["label"] == 0)
 .select("timestamp", "score")
 .write.format('jdbc')
 .options(**get_db_options(NegativeScoreTable.__tablename__))
 .mode("overwrite")
 .save())

                                                                                

# Push Confusion Matrices

In [24]:
# Fresh declarative base for the confusion-matrix table.
Base = declarative_base()

class ConfusionTable(Base):
    """ORM mapping for ``confusion_table``: the four confusion-matrix cells per timestamp."""
    __tablename__ = 'confusion_table'

    timestamp = Column(DateTime, primary_key=True)
    tp = Column(Float)
    fp = Column(Float)
    tn = Column(Float)
    fn = Column(Float)

# Emit CREATE TABLE for this table.
Base.metadata.create_all(engine)

In [25]:
# The tp/fp/tn/fn columns are the expanded fields of the confusion struct produced in
# eval_performance; write them per timestamp.
(realtime_metrics
 .select("timestamp", "tp", "fp", "tn", "fn")
 .write.format('jdbc')
 .options(**get_db_options(ConfusionTable.__tablename__))
 .mode("overwrite")
 .save())

                                                                                