Problem Statement:

You are given a dataset of product production records containing the production date (dt), brand name (brand), model name (model), and production cost (production_cost). The goal is to calculate the total production cost for each brand on each date, using a window function that partitions the data by dt (date) and brand and orders the rows by production_cost.

For each record, compute the sum of the production costs within its dt and brand group. This sum is the same for every row that belongs to the same partition.
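To make the expected result concrete, here is a plain-Python sketch that uses no Spark. It mirrors what the window function should produce: one total per (dt, brand) key, attached to every row with that key. The helper name `partition_totals` is hypothetical and used only for illustration.

```python
from collections import defaultdict
from datetime import date

# Same sample records as the notebook: (dt, brand, model, production_cost)
records = [
    (date(2023, 12, 1), "A", "A1", 1000),
    (date(2023, 12, 1), "A", "A2", 1300),
    (date(2023, 12, 1), "B", "B1", 800),
    (date(2023, 12, 2), "A", "A1", 1800),
    (date(2023, 12, 2), "B", "B1", 900),
    (date(2023, 12, 10), "A", "A1", 1400),
    (date(2023, 12, 10), "A", "A1", 1200),
    (date(2023, 12, 10), "C", "C1", 2500),
]


def partition_totals(rows):
    """Attach the total production_cost of each row's (dt, brand) partition."""
    totals = defaultdict(int)
    # First pass: accumulate one total per partition key
    for dt, brand, _model, cost in rows:
        totals[(dt, brand)] += cost
    # Second pass: every row in a partition receives the same total
    return [(dt, brand, model, cost, totals[(dt, brand)]) for dt, brand, model, cost in rows]


for row in partition_totals(records):
    print(row)
```

The totals (2300, 800, 1800, 900, 2600, 2500) match the sum_cost column in the Spark outputs further down.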

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType
from pyspark.sql import Row
from datetime import datetime

# Initialize Spark session
spark = SparkSession.builder.appName("CreateTableAndInsertRecords").getOrCreate()

# Define schema
schema = StructType([
    StructField("dt", DateType(), True),
    StructField("brand", StringType(), True),
    StructField("model", StringType(), True),
    StructField("production_cost", IntegerType(), True)
])

# Data for the table
data = [
    ('2023-12-01', 'A', 'A1', 1000),
    ('2023-12-01', 'A', 'A2', 1300),
    ('2023-12-01', 'B', 'B1', 800),
    ('2023-12-02', 'A', 'A1', 1800),
    ('2023-12-02', 'B', 'B1', 900),
    ('2023-12-10', 'A', 'A1', 1400),
    ('2023-12-10', 'A', 'A1', 1200),
    ('2023-12-10', 'C', 'C1', 2500)
]

# Convert the data to Rows, parsing the date strings into datetime.date values for the DateType column
rows = [Row(datetime.strptime(row[0], '%Y-%m-%d').date(), row[1], row[2], row[3]) for row in data]

# Create DataFrame
df = spark.createDataFrame(rows, schema)

# Show the DataFrame (display() is a Databricks notebook helper; use df.show() in plain PySpark)
df.display()


dt,brand,model,production_cost
2023-12-01,A,A1,1000
2023-12-01,A,A2,1300
2023-12-01,B,B1,800
2023-12-02,A,A1,1800
2023-12-02,B,B1,900
2023-12-10,A,A1,1400
2023-12-10,A,A1,1200
2023-12-10,C,C1,2500


In [0]:
# Register the DataFrame as a temporary table
df.createOrReplaceTempView("prd_tbl")

In [0]:
%sql
select
  *,
  sum(production_cost) over (
    partition by dt, brand
    order by production_cost
    rows between unbounded preceding and unbounded following
  ) as sum_cost
from
  prd_tbl

dt,brand,model,production_cost,sum_cost
2023-12-01,A,A1,1000,2300
2023-12-01,A,A2,1300,2300
2023-12-01,B,B1,800,800
2023-12-02,A,A1,1800,1800
2023-12-02,B,B1,900,900
2023-12-10,A,A1,1200,2600
2023-12-10,A,A1,1400,2600
2023-12-10,C,C1,2500,2500


In [0]:
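from pyspark.sql.window import Window
# Import the functions module under an alias so Spark's sum does not shadow Python's built-in sum
from pyspark.sql import functions as F

# Define the window specification: partition by date and brand, frame spans the whole partition
window_spec = (
    Window.partitionBy("dt", "brand")
    .orderBy("production_cost")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# Calculate the sum of production_cost over the window (same value for every row in a partition)
df_with_sum_cost = df.withColumn("sum_cost", F.sum("production_cost").over(window_spec))

# Show the result (display() is Databricks-specific; use show() elsewhere)
df_with_sum_cost.display()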


dt,brand,model,production_cost,sum_cost
2023-12-01,A,A1,1000,2300
2023-12-01,A,A2,1300,2300
2023-12-01,B,B1,800,800
2023-12-02,A,A1,1800,1800
2023-12-02,B,B1,900,900
2023-12-10,A,A1,1200,2600
2023-12-10,A,A1,1400,2600
2023-12-10,C,C1,2500,2500


Explanation:

Window.partitionBy("dt", "brand"): This creates partitions based on the dt and brand columns, similar to PARTITION BY dt, brand in SQL.

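orderBy("production_cost"): This orders the rows within each partition by production_cost. Because the frame set by rowsBetween spans the entire partition, this ordering does not change the resulting sum.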

rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing): This specifies that the sum should be calculated over all rows within the partition (equivalent to ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING in SQL).

sum("production_cost").over(window_spec): This computes the sum of production_cost over the defined window.

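Because the frame covers the whole partition, this code calculates the total (not a rolling sum) of production_cost for each dt and brand group and adds it as a new column called sum_cost. Had the window used orderBy without an explicit frame, Spark would default to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW and return a running total instead. A window with no orderBy at all defaults to the whole partition, so Window.partitionBy("dt", "brand") on its own gives the same sum_cost.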
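The frame clause is what separates a partition total from a running sum. The plain-Python sketch below simulates three frames over a single partition ordered by production_cost. It is not Spark code, and `frame_sums` and its frame labels are hypothetical names. Ties are broken by Python's sort order here. In Spark, the order of tied rows under a ROWS frame is not guaranteed.

```python
def frame_sums(costs, frame):
    """Sum over a window frame for one partition, ordered by cost (illustrative only).

    frame labels (hypothetical, for this sketch):
      "full"      -> ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
      "rows_run"  -> ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
      "range_run" -> RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
                     (the default frame when ORDER BY is given without a frame)
    """
    ordered = sorted(costs)
    result = []
    for i, cost in enumerate(ordered):
        if frame == "full":
            window = ordered                              # every row in the partition
        elif frame == "rows_run":
            window = ordered[: i + 1]                     # physical rows up to the current one
        elif frame == "range_run":
            window = [c for c in ordered if c <= cost]    # includes peers with an equal key
        else:
            raise ValueError(f"unknown frame: {frame}")
        result.append((cost, sum(window)))
    return result


# Partition (2023-12-10, A) from the sample data
print(frame_sums([1400, 1200], "full"))           # [(1200, 2600), (1400, 2600)]
print(frame_sums([1400, 1200], "rows_run"))       # [(1200, 1200), (1400, 2600)]

# With tied costs, RANGE groups peers while ROWS does not
print(frame_sums([500, 500, 700], "rows_run"))    # [(500, 500), (500, 1000), (700, 1700)]
print(frame_sums([500, 500, 700], "range_run"))   # [(500, 1000), (500, 1000), (700, 1700)]
```

Only the "full" frame reproduces the sum_cost column above. The other two show what the same query would return with a narrower frame.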