This is one of the SQL questions recently asked in a Trelleborg interview.
Given a sales table, find the periodic (per-month) sales.

To solve this question, we use the LAG function.


In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, col, when

# Obtain the Spark session (getOrCreate reuses one if it already exists)
spark = (
    SparkSession.builder
    .appName("Sales Data")
    .getOrCreate()
)

# Rows of the sales table as (month, ytd_sales, monthnum) tuples
data_sales = [
    ('jan', 15, 1),
    ('feb', 22, 2),
    ('mar', 35, 3),
    ('apr', 45, 4),
    ('may', 60, 5),
]

# Column layout of the sales table; every column is declared nullable
schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in [
        ("month", StringType()),
        ("ytd_sales", IntegerType()),
        ("monthnum", IntegerType()),
    ]
])

# Build the sales DataFrame from the rows and the explicit schema
df_sales = spark.createDataFrame(data_sales, schema=schema)

# Render the DataFrame in the notebook
df_sales.display()


month,ytd_sales,monthnum
jan,15,1
feb,22,2
mar,35,3
apr,45,4
may,60,5


In [0]:
# Register the sales DataFrame as the temp view "sales" so the %sql cell
# can refer to it in its FROM clause.
df_sales.createOrReplaceTempView("sales")

In [0]:
%sql
-- ytd_sales is treated as a running total, so the sales for one month are
-- the difference between its ytd_sales and the previous month's ytd_sales.
with cte as (
  select
    monthnum,  -- carried through only so the final result can be sorted
    ytd_sales,
    ytd_sales - lag(ytd_sales) over (
      order by
        monthnum
    ) as lg
  from
    sales
)
select
  ytd_sales,
  -- The first month has no previous row, so lg is NULL there; its periodic
  -- sales are simply its own ytd_sales.
  case
    when lg is null then ytd_sales
    else lg
  end
from
  cte
-- The result has no month column, so row position is the only link between
-- a value and its month: sort explicitly instead of relying on the order
-- the window happens to produce.
order by
  monthnum

ytd_sales,CASE WHEN (lg IS NULL) THEN ytd_sales ELSE lg END
15,15
22,7
35,13
45,10
60,15


In [0]:
# Order rows by month number so that lag() sees the previous month's row
window_spec = Window.orderBy("monthnum")

# Add a column 'lg' holding the previous month's ytd_sales
# (null for the first month, which has no previous row)
df_with_lag = df_sales.withColumn("lg", lag("ytd_sales").over(window_spec))

# Periodic sales = current ytd_sales minus the previous one; for the first
# row (lg is null) the periodic sales are the ytd_sales value itself
df_final = df_with_lag.withColumn(
    "final_sales",
    when(col("lg").isNull(), col("ytd_sales")).otherwise(col("ytd_sales") - col("lg"))
)

# The displayed columns carry no month label, so row position is the only
# link between a value and its month: sort by monthnum explicitly before
# projecting instead of relying on the order the window happens to produce.
df_final.orderBy("monthnum").select("ytd_sales", "final_sales").display()

ytd_sales,final_sales
15,15
22,7
35,13
45,10
60,15
