
We have a dataset with daily stock prices for various stocks. For each stock, you need to flag rows where the price has increased, decreased, or remained the same compared to the previous day. The first day for each stock should have a NULL value in the new column, as there is no previous day to compare.


In [0]:
from pyspark.sql.functions import *
from pyspark.sql.functions import lead, lag, to_timestamp
from pyspark.sql.window import Window

# Sample daily closing prices: (stockid, date, price).
# Prices are numeric (not strings) so that both the DataFrame arithmetic and
# the SQL comparisons below are numeric rather than lexicographic
# (as strings, "99" > "100").
# Dates are ISO-8601 strings, so ordering them as strings is chronological.
data = [
    ("A", "2024-01-01", 100),
    ("A", "2024-01-02", 105),
    ("A", "2024-01-03", 104),
    ("B", "2024-01-01", 200),
    ("B", "2024-01-02", 200),
    ("B", "2024-01-03", 201),
]

schema = ["stockid", "date", "price"]

df = spark.createDataFrame(data, schema)
df.display()

stockid,date,price
A,2024-01-01,100
A,2024-01-02,105
A,2024-01-03,104
B,2024-01-01,200
B,2024-01-02,200
B,2024-01-03,201


In [0]:
# Compare each day's price with the same stock's previous day.
# Rows are ordered by the ISO-formatted date string within each stock.
myWindow = Window.partitionBy("stockid").orderBy("date")

# Cast explicitly: `price` may arrive as a string, and relying on Spark's
# implicit string-to-number coercion for arithmetic is fragile.
price_num = col("price").cast("double")

# previousprice is NULL on the first row of each stock (nothing to lag from).
df1 = df.withColumn("previousprice", lag(price_num).over(myWindow))
df2 = df1.withColumn("pricediff", price_num - col("previousprice"))

# Any positive/negative difference counts as UP/DOWN (not just >= 1).
# When pricediff is NULL (first day per stock) no branch matches, so `when`
# yields a real NULL, as the task requires.
df3 = df2.withColumn(
    "price_change",
    when(col("pricediff") > 0, "UP")
    .when(col("pricediff") < 0, "DOWN")
    .when(col("pricediff") == 0, "SAME"),
)

# Order only the final result; ordering intermediates has no lasting effect.
df4 = df3.select("stockid", "date", "price", "price_change").orderBy(
    col("stockid"), col("date")
)
df4.display()

stockid,date,price,price_change
A,2024-01-01,100,
A,2024-01-02,105,UP
A,2024-01-03,104,DOWN
B,2024-01-01,200,
B,2024-01-02,200,SAME
B,2024-01-03,201,UP


In [0]:
# Register `df` as the session-scoped temp view "stock_data" so the %sql cell
# below can query it by name (replacing any existing view with that name).
df.createOrReplaceTempView("stock_data")

In [0]:
%sql
-- For each stock, flag whether the price went UP, DOWN or stayed the SAME
-- compared with the previous day; the first day of each stock gets NULL.
WITH cte AS (
    SELECT
        stockid,
        date,
        price,
        -- Cast so comparisons are numeric even if price is stored as a string
        -- (string comparison is lexicographic: '99' > '100').
        CAST(price AS DOUBLE) AS price_num,
        LAG(CAST(price AS DOUBLE)) OVER (PARTITION BY stockid ORDER BY date) AS lprice
    FROM stock_data
)
SELECT
    stockid,
    date,
    price,
    CASE
        -- First day for a stock: no previous price, so return a real NULL
        -- (not the string 'NULL').
        WHEN lprice IS NULL THEN NULL
        WHEN price_num > lprice THEN 'UP'
        WHEN price_num < lprice THEN 'DOWN'
        ELSE 'SAME'
    END AS price_change
FROM cte
ORDER BY stockid, date;


stockid,date,price,price_change
A,2024-01-01,100,
A,2024-01-02,105,UP
A,2024-01-03,104,DOWN
B,2024-01-01,200,
B,2024-01-02,200,SAME
B,2024-01-03,201,UP


Explanation:

Temporary View:
 
We use createOrReplaceTempView("stock_data") to make the DataFrame accessible via SQL queries.
LAG Window Function: Both versions use LAG to fetch the previous day's price, partitioned by stockid and ordered by date. The resulting column is called previousprice in the DataFrame version and lprice in the SQL query.
Price Difference: In the DataFrame version, we compute the price difference (pricediff) by subtracting previousprice from the current price.
Classification: We categorize the price change (price_change) with chained when() calls in the DataFrame version and with CASE WHEN in SQL:
"UP" if the price went up,
"DOWN" if the price went down,
"SAME" if the price is unchanged,
NULL for the first day of each stock, since there is no previous price to compare with.