In [0]:
# 03_gold_spc_charts.py
#
# SPC Charts (P-chart, Xbar & R)
#

import dlt
from pyspark.sql.functions import (
    col, avg, sum as spark_sum, lit, round as spark_round,
    row_number, sqrt, max as spark_max, min as spark_min, when,
    unix_timestamp, from_unixtime, to_timestamp, count
)
from pyspark.sql.window import Window

# -----------------------------------------------------------
# CONFIG
# -----------------------------------------------------------
# Subgroup sizes: rows are ordered by event_time within each partition and
# chunked into consecutive runs of this many rows (see the row_number /
# subgroup_id logic in the chart builders below).
SAMPLE_SIZE_DEFECT = 30   # P-chart: inspector events per subgroup, per line_id
SAMPLE_SIZE_TEMP = 10     # Xbar-R chart: readings per subgroup, per (line_id, machine_type)

# ===========================================================
# 1️⃣ P-CHART (Per Assembly Line)
# ===========================================================
@dlt.table(
    name="03_gold.inspector_pchart",
    comment="P-chart per assembly line showing defect proportion, control limits and subgroup time"
)
def inspector_pchart():
    """Build a per-assembly-line P-chart from inspector events.

    Events of each ``line_id`` are ordered by ``event_time`` and chunked into
    consecutive subgroups of ``SAMPLE_SIZE_DEFECT`` events (the last subgroup
    of a line may hold fewer events). For every subgroup:

        p_i   = total_defective / total_produced
        n_i   = total_produced            (units behind p_i)
        UCL_i = min(1, p_bar + 3 * sqrt(p_bar * (1 - p_bar) / n_i))
        LCL_i = max(0, p_bar - 3 * sqrt(p_bar * (1 - p_bar) / n_i))

    where ``p_bar`` is the line's pooled defect proportion. Using the actual
    number of produced units as n (rather than the number of events) gives
    correct limits when events carry more than one unit and when the last
    subgroup is partial.

    Returns:
        DataFrame with one row per (line_id, subgroup_id): subgroup_time,
        p_i, p_bar, UCL, LCL (rounded to 4 decimals), total_defective and
        total_produced. p_i / UCL / LCL are null when a subgroup (or its line)
        produced no units.
    """
    df = dlt.read("02_silver.inspector_enriched")

    # Order each line's events chronologically and assign row/subgroup ids
    w = Window.partitionBy("line_id").orderBy("event_time")
    df = df.withColumn("row_id", row_number().over(w))
    df = df.withColumn("subgroup_id", ((col("row_id") - 1) / SAMPLE_SIZE_DEFECT).cast("int"))

    # Aggregate per subgroup. subgroup_time is the mean event epoch converted
    # back to a timestamp.
    df_grouped = (
        df.groupBy("line_id", "subgroup_id")
          .agg(
              spark_sum("defective_count").alias("total_defective"),
              spark_sum("produced_count").alias("total_produced"),
              avg(unix_timestamp(col("event_time"))).alias("avg_event_epoch")
          )
          .withColumn("p_i", when(col("total_produced") > 0, col("total_defective") / col("total_produced")).otherwise(None))
          .withColumn("subgroup_time", to_timestamp(from_unixtime(col("avg_event_epoch"))))
    )

    # Pooled defect proportion per line; guarded so a line with no produced
    # units yields null instead of a division by zero.
    stats = (
        df_grouped.groupBy("line_id")
                  .agg(
                      spark_sum("total_defective").alias("line_defective"),
                      spark_sum("total_produced").alias("line_produced")
                  )
                  .withColumn(
                      "p_bar",
                      when(col("line_produced") > 0, col("line_defective") / col("line_produced"))
                  )
                  .select("line_id", "p_bar")
    )

    # Binomial standard error for each subgroup's own sample size n_i.
    sigma = when(
        col("total_produced") > 0,
        sqrt((col("p_bar") * (1 - col("p_bar"))) / col("total_produced"))
    )

    df_final = (
        df_grouped.join(stats, on="line_id", how="left")
        .withColumn("UCL", col("p_bar") + 3 * sigma)
        .withColumn("LCL", col("p_bar") - 3 * sigma)
        # A proportion lives in [0, 1], so clamp both limits to that range.
        .withColumn("UCL", when(col("UCL") > 1, 1).otherwise(col("UCL")))
        .withColumn("LCL", when(col("LCL") < 0, 0).otherwise(col("LCL")))
        .select(
            "line_id",
            "subgroup_id",
            "subgroup_time",
            spark_round("p_i", 4).alias("p_i"),
            spark_round("p_bar", 4).alias("p_bar"),
            spark_round("UCL", 4).alias("UCL"),
            spark_round("LCL", 4).alias("LCL"),
            "total_defective",
            "total_produced"
        )
    )

    return df_final

# ===========================================================
# 2️⃣ XBAR-R CHART (Per Machine Type per Line)
# ===========================================================
# Shewhart Xbar-R chart constants (A2, D3, D4) keyed by subgroup size n, as in
# standard SPC tables. They derive from the d2/d3 range factors:
#   A2 = 3 / (d2 * sqrt(n)),  D3 = max(0, 1 - 3*d3/d2),  D4 = 1 + 3*d3/d2
XBAR_R_CONSTANTS = {
    2: (1.880, 0.000, 3.267),
    3: (1.023, 0.000, 2.574),
    4: (0.729, 0.000, 2.282),
    5: (0.577, 0.000, 2.114),
    6: (0.483, 0.000, 2.004),
    7: (0.419, 0.076, 1.924),
    8: (0.373, 0.136, 1.864),
    9: (0.337, 0.184, 1.816),
    10: (0.308, 0.223, 1.777),
    11: (0.285, 0.256, 1.744),
    12: (0.266, 0.283, 1.717),
    13: (0.249, 0.307, 1.693),
    14: (0.235, 0.328, 1.672),
    15: (0.223, 0.347, 1.653),
}


def xbar_r_constants(n):
    """Return the ``(A2, D3, D4)`` Xbar-R constants for subgroup size ``n``.

    Args:
        n: number of observations per subgroup.

    Returns:
        Tuple ``(A2, D3, D4)`` of floats.

    Raises:
        ValueError: if ``n`` has no entry in ``XBAR_R_CONSTANTS``.
    """
    try:
        return XBAR_R_CONSTANTS[n]
    except KeyError:
        raise ValueError(
            f"No Xbar-R constants tabulated for subgroup size n={n}; "
            f"supported sizes are {min(XBAR_R_CONSTANTS)}..{max(XBAR_R_CONSTANTS)}"
        ) from None


def compute_xbar_r(df, sample_size=None):
    """Compute Xbar and R chart points and limits per (line_id, machine_type).

    Rows are ordered by ``event_time`` within each (line_id, machine_type)
    and chunked into consecutive subgroups of ``sample_size`` rows. Subgroups
    that do not hold exactly ``sample_size`` non-null ``temperature_c``
    readings are dropped, e.g. the trailing still-filling subgroup. The
    A2/D3/D4 constants, and therefore the limits, are only valid for that
    exact n, and a short subgroup's range would bias Rbar downward.

    Args:
        df: DataFrame with ``line_id``, ``machine_type``, ``event_time`` and
            ``temperature_c`` columns.
        sample_size: subgroup size n; ``None`` means ``SAMPLE_SIZE_TEMP``.

    Returns:
        DataFrame with one row per complete subgroup: line_id, machine_type,
        subgroup_id, subgroup_time, xbar, R, xbarbar, Rbar, UCLx, LCLx, UCLr,
        LCLr (numeric values rounded to 3 decimals).

    Raises:
        ValueError: if no constants are tabulated for the subgroup size.
    """
    n = SAMPLE_SIZE_TEMP if sample_size is None else sample_size
    # Resolve constants up front so an unsupported n fails before any Spark work.
    A2, D3, D4 = xbar_r_constants(n)

    # Order by event_time within each (line_id, machine_type) and assign subgroup ids
    w = Window.partitionBy("line_id", "machine_type").orderBy("event_time")
    df = df.withColumn("row_id", row_number().over(w))
    df = df.withColumn("subgroup_id", ((col("row_id") - 1) / n).cast("int"))

    # Subgroup stats; subgroup_time is the mean event epoch as a timestamp.
    subgroups = (
        df.groupBy("line_id", "machine_type", "subgroup_id")
          .agg(
              count("temperature_c").alias("n_readings"),
              avg("temperature_c").alias("xbar"),
              (spark_max("temperature_c") - spark_min("temperature_c")).alias("R"),
              avg(unix_timestamp(col("event_time"))).alias("avg_event_epoch")
          )
          # Keep only subgroups with exactly n readings (see docstring).
          .filter(col("n_readings") == n)
          .withColumn("subgroup_time", to_timestamp(from_unixtime(col("avg_event_epoch"))))
    )

    # Grand averages per (line, machine) over complete subgroups only
    stats = (
        subgroups.groupBy("line_id", "machine_type")
                 .agg(
                     avg("xbar").alias("xbarbar"),
                     avg("R").alias("Rbar")
                 )
    )

    df_joined = subgroups.join(stats, on=["line_id", "machine_type"], how="left")

    return (
        df_joined.select(
            "line_id",
            "machine_type",
            "subgroup_id",
            "subgroup_time",
            spark_round("xbar", 3).alias("xbar"),
            spark_round("R", 3).alias("R"),
            spark_round("xbarbar", 3).alias("xbarbar"),
            spark_round("Rbar", 3).alias("Rbar"),
            spark_round((col("xbarbar") + A2 * col("Rbar")), 3).alias("UCLx"),
            spark_round((col("xbarbar") - A2 * col("Rbar")), 3).alias("LCLx"),
            spark_round((D4 * col("Rbar")), 3).alias("UCLr"),
            spark_round((D3 * col("Rbar")), 3).alias("LCLr")
        )
    )

@dlt.table(
    name="03_gold.temperature_xbar_r_chart",
    comment="Xbar and R chart per machine type and assembly line (with subgroup_time)"
)
def temperature_xbar_r_chart():
    """Publish the temperature Xbar-R chart for drill-cutter and polisher machines.

    Both silver sources are stacked into one DataFrame, matching columns by
    name, and handed to ``compute_xbar_r``.
    """
    drillcutter = dlt.read("02_silver.drillcutter_enriched")
    polisher = dlt.read("02_silver.polisher_enriched")

    # Column-name alignment (not position) relies on the two silver tables
    # sharing the same normalized schema.
    readings = drillcutter.unionByName(polisher)
    return compute_xbar_r(readings)
