In [0]:
from pyspark.sql import functions as F, types as T
from datetime import datetime
from functools   import reduce
from helpers import SILVER, GOLD, TAXI_TYPES, table_exists

### 📅 Criação e população da dimensão de datas (`dim_date`)

Este trecho cria e popula uma tabela de dimensão temporal com datas de **01/01/2023 a 31/12/2023**, contendo:

- `date_id`: inteiro no formato `yyyyMMdd` (ex.: `20230115`);
- Atributos como ano, mês, dia, dia da semana, quarter e flag de fim de semana.

🎯 **Objetivo**: fornecer uma dimensão temporal para facilitar análises agregadas e permitir *joins* eficientes com os dados de corridas.


In [0]:
# Make sure the date dimension exists as a Delta table before loading it.
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {GOLD.TBL_DIM_DATE} (
    date_id      INT    COMMENT 'yyyymmdd',
    date         DATE,
    year         INT,
    month        INT,
    day          INT,
    day_of_week  INT,          
    is_weekend   BOOLEAN,
    quarter      INT
) USING DELTA
""")

# One row per calendar day between 2023-01-01 and 2023-12-31.
calendar_days = spark.sql("""
        SELECT explode(
                 sequence( to_date('2023-01-01')
                         , to_date('2023-12-31')
                         , interval 1 day)
               ) AS date
    """)

# Spark's dayofweek numbers days 1 (Sunday) through 7 (Saturday),
# so the weekend is the set {1, 7}.
weekday_number = F.dayofweek("date")

# Derive every calendar attribute in a single projection; the column order
# (date first, then the derived attributes) is what ends up in the table.
dates = calendar_days.select(
    F.col("date"),
    F.date_format("date", "yyyyMMdd").cast("int").alias("date_id"),
    F.year("date").alias("year"),
    F.month("date").alias("month"),
    F.dayofmonth("date").alias("day"),
    weekday_number.alias("day_of_week"),
    weekday_number.isin(1, 7).alias("is_weekend"),
    F.quarter("date").alias("quarter"),
)

# Full refresh of the dimension on every run.
# NOTE(review): with overwriteSchema enabled the DataFrame schema presumably
# replaces the one declared in the DDL above (including the column comment) —
# confirm that this is intended.
dim_date_writer = (
    dates.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
)
dim_date_writer.saveAsTable(GOLD.TBL_DIM_DATE)

# Report how many rows were loaded.
rows = spark.table(GOLD.TBL_DIM_DATE).count()
print(f"dim_date populada com {rows:,} linhas")


### 🏁 Criação da tabela fato `fact_trips` (camada Gold)

Este trecho:

- Lê os dados da camada Silver para cada tipo de táxi;
- Calcula colunas auxiliares como:
  - `pickup_date_id` e `pickup_time_id` para facilitar análises temporais;
  - `duration_min`, `distance_mi` e `fare` para enriquecer a base;
- Realiza **agregações por tipo de táxi, data e horário (hora e minuto) do embarque**, gerando métricas como:
  - `num_trips`, `total_revenue`, `total_passengers`, `avg_fare`, `revenue_mi`, etc;
- Escreve o resultado final como tabela Delta particionada por `pickup_date_id`.

🎯 **Objetivo**: consolidar uma visão analítica unificada das corridas para suportar dashboards, KPIs e análises diversas.


In [0]:
def _find_column(df, suffix, table_name):
    """Return the first column of ``df`` whose name ends with ``suffix``.

    Args:
        df: DataFrame whose columns are searched.
        suffix: Column-name suffix to look for (e.g. ``"pickup_datetime"``).
        table_name: Name of the source table, used only in the error message.

    Raises:
        ValueError: If no column name ends with ``suffix``. This replaces the
            bare ``StopIteration`` that an unguarded ``next()`` would raise.
    """
    found = next((c for c in df.columns if c.endswith(suffix)), None)
    if found is None:
        raise ValueError(
            f"Coluna com sufixo '{suffix}' não encontrada em {table_name}"
        )
    return found


# Build one aggregated DataFrame per taxi type that has a Silver table.
agg_dfs = []
for tt in TAXI_TYPES:
    src_tbl = f"{SILVER.SCHEMA}.{tt}_tripdata_silver"
    if not spark.catalog.tableExists(src_tbl):
        # Missing Silver tables are skipped on purpose (best effort).
        print(f"Silver ausente → {tt}")
        continue

    df = spark.table(src_tbl)

    # The timestamp columns are located by suffix because their prefix
    # differs between source tables.
    pickup_col  = _find_column(df, "pickup_datetime", src_tbl)
    dropoff_col = _find_column(df, "dropoff_datetime", src_tbl)

    df = (df
          .withColumnRenamed(pickup_col,  "pickup_ts")
          .withColumnRenamed(dropoff_col, "dropoff_ts")
          # yyyyMMdd as an integer, e.g. 20230115.
          .withColumn("pickup_date_id",  F.date_format("pickup_ts", "yyyyMMdd").cast("int"))
          # HHmm as an integer, e.g. 1745 for 17:45.
          .withColumn("pickup_time_id",  (F.hour("pickup_ts")*100 + F.minute("pickup_ts")).cast("int"))
          .withColumn("taxi_type",       F.lit(tt))
          # Trip duration in minutes, computed from whole seconds.
          .withColumn(
            "duration_min",
            F.expr("timestampdiff(SECOND, pickup_ts, dropoff_ts)") / 60.0
           )
          .withColumn("distance_mi",     F.col("trip_distance"))
          .withColumn("fare",            F.col("total_amount"))
    )

    agg = (df.groupBy("taxi_type","pickup_date_id","pickup_time_id")
              .agg(
                  F.count("*").alias("num_trips"),
                  F.sum("passenger_count").alias("total_passengers"),
                  F.sum("fare").alias("total_revenue"),
                  F.sum("distance_mi").alias("distance_mi"),
                  F.sum("duration_min").alias("total_duration"))
              # num_trips comes from count(*) within a group, so it is >= 1
              # and the per-trip averages below cannot divide by zero.
              .withColumn("avg_fare",     F.round(F.col("total_revenue")/F.col("num_trips"),2))
              .withColumn("avg_distance", F.round(F.col("distance_mi")/F.col("num_trips"),2))
              .withColumn("avg_duration", F.round(F.col("total_duration")/F.col("num_trips"),2))
              # The summed distance can be 0; guard the division so it yields
              # NULL instead of raising DIVIDE_BY_ZERO when ANSI mode is on.
              .withColumn(
                  "revenue_mi",
                  F.when(
                      F.col("distance_mi") != 0,
                      F.round(F.col("total_revenue")/F.col("distance_mi"),2),
                  ),
              )
              .withColumn("load_factor",  F.round(F.col("total_passengers")/F.col("num_trips"),2))
          )
    agg_dfs.append(agg)

# reduce() over an empty list raises an opaque TypeError; fail with a clear
# message when no Silver table was available.
if not agg_dfs:
    raise RuntimeError(
        "Nenhuma tabela Silver encontrada; fact_trips não pode ser gerada"
    )

# Stack the per-type aggregates; columns missing on one side are filled with NULL.
all_agg = reduce(lambda a,b: a.unionByName(b, allowMissingColumns=True), agg_dfs)

tgt_tbl = f"{GOLD.SCHEMA}.fact_trips"

# Full refresh of the fact table, partitioned by pickup date.
# NOTE(review): the delta.feature.checkConstraints option presumably enables
# the CHECK-constraints table feature — confirm it is honored as a writer option.
(all_agg.write
    .format("delta")
    .mode("overwrite")
    .partitionBy("pickup_date_id")
    .option("overwriteSchema", "true")
    .option("delta.feature.checkConstraints","supported")
    .saveAsTable(tgt_tbl)
)

print("Gold fact_trips criada —",
      spark.table(tgt_tbl).count(), "linhas")
