## Creating Dimension and Fact Tables

### Station Dimension Table

In [0]:

# Load the staging stations table
stations = spark.table("shplc_db.default.staging_stations")

# Drop the existing dim_station table if it exists
spark.sql("DROP TABLE IF EXISTS shplc_db.default.dim_station")

# Select unique stations and save them to the dim_station table
stations.dropDuplicates(["station_id"]) \
        .select('station_id', 'name', 'latitude', 'longitude') \
        .write.format("delta") \
        .mode("overwrite") \
        .saveAsTable("shplc_db.default.dim_station")


### Rider Dimension Table


In [0]:
# Load the staging stations table
riders = spark.table("shplc_db.default.staging_riders")

# Drop the existing dim_station table if it exists
spark.sql("DROP TABLE IF EXISTS shplc_db.default.dim_rider")

# Select unique stations and save them to the dim_station table
riders.dropDuplicates(["rider_id"]) \
        .select('rider_id', 'first_name', 'last_name', 'address', 'birthday', 'start_date', 'end_date', 'is_member') \
        .write.format("delta") \
        .mode("overwrite") \
        .saveAsTable("shplc_db.default.dim_rider")

### Date Dimension Table 

In [0]:
from pyspark.sql.functions import col, lit, sequence, explode, year, quarter, month, dayofweek, dayofmonth, date_format
from pyspark.sql.types import DateType

# Load trips data
df_trips = spark.read.format("delta").load('/delta/trips')
df_trips.createOrReplaceTempView("trips_table")

# Get the min and max trip dates
min_date = df_trips.agg({"started_at": "min"}).collect()[0][0]
max_date = df_trips.agg({"started_at": "max"}).collect()[0][0]

# Create a date range
date_range = spark.createDataFrame([(min_date, max_date)], ["start", "end"]) \
    .select(explode(sequence(col("start"), col("end"))).alias("date"))

# Add additional date fields
date_dim = date_range.withColumn("formatted_date", date_format(col("date"), "yyyy-MM-dd")) \
    .withColumn("date_id", col("formatted_date").cast("string")) \
    .withColumn("year", year(col("date"))) \
    .withColumn("quarter", quarter(col("date"))) \
    .withColumn("month", month(col("date"))) \
    .withColumn("day_of_week", dayofweek(col("date"))) \
    .withColumn("day_of_month", dayofmonth(col("date")))

# Save the date dimension table
date_dim.write.format("delta")\
        .mode("overwrite")\
        .saveAsTable("shplc_db.default.dim_date")

### Payments Fact Table

In [0]:
# Load the staging stations table
payments = spark.table("shplc_db.default.staging_payments")

# Drop the existing fact_payment table if it exists
spark.sql("DROP TABLE IF EXISTS shplc_db.default.fact_payment")

# Select unique stations and save them to the fact_payment table
payments = payments.join(spark.table("shplc_db.default.dim_date").alias("d"), 
                          payments.date.cast("string") == col("d.date_id")) \
                  .join(spark.table("shplc_db.default.dim_rider").alias("r"), 
                          payments.rider_id == col("r.rider_id")) \
                  .select('payment_id', 'amount', 'r.rider_id', 'd.date_id') \
                  .dropDuplicates(["payment_id"]) \
                  .write.format("delta") \
                  .mode("overwrite") \
                  .saveAsTable("shplc_db.default.fact_payment")

### Trips Fact Table

In [0]:
# Load trips data
trips = spark.read.format("delta").load('/delta/trips')
trips.createOrReplaceTempView("trips_table")



In [0]:
trips.printSchema

<bound method DataFrame.printSchema of DataFrame[trip_id: string, rideable_type: string, started_at: timestamp, ended_at: timestamp, start_station_id: string, end_station_id: string, rider_id: int]>

In [0]:
# Load trips data
trips = spark.read.format("delta").load('/delta/trips')
trips.createOrReplaceTempView("trips_table")

# Drop the existing fact_payment table if it exists
spark.sql("DROP TABLE IF EXISTS shplc_db.default.fact_trips")

# Select unique stations and save them to the fact_payment table
trips = trips.alias("t").join(
            spark.table("shplc_db.default.dim_date").alias("sd"), 
            col("t.started_at") == col("sd.date")
        ).join(
            spark.table("shplc_db.default.dim_date").alias("ed"), 
            col("t.ended_at") == col("ed.date")  
        ).join(
            spark.table("shplc_db.default.dim_rider").alias("r"), 
            col("t.rider_id") == col("r.rider_id")
        ).select(
            'trip_id', 
            'r.rider_id', 
            'start_station_id', 
            'end_station_id', 
            col('sd.date_id').alias('start_time_id'), 
            col('ed.date_id').alias('end_time_id'), 
            'rideable_type', 
            (col('t.ended_at').cast("long") - col('t.started_at').cast("long")).alias('duration'), 
            (year(col('t.started_at')) - year(col('r.birthday'))).alias('rider_age')
        ).dropDuplicates(["trip_id"])

trips.write.format("delta").mode("overwrite").saveAsTable("shplc_db.default.fact_trips")