In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, first, sum as _sum, col, when

# Initialize Spark session
spark = SparkSession.builder \
    .appName("Insert into brands table") \
    .getOrCreate()

# Define the schema for the brands table
schema = StructType([
    StructField("category", StringType(), True),
    StructField("brand_name", StringType(), True)
])

# Sample data for the brands table
data = [
    ('chocolates', '5-star'),
    (None, 'dairy milk'),
    (None, 'perk'),
    (None, 'eclair'),
    ('Biscuits', 'Britania'),
    (None, 'good day'),
    (None, 'boost')
]

# Create a DataFrame from the data and schema
df = spark.createDataFrame(data, schema)

# Create the brands table in Spark SQL
df.createOrReplaceTempView("brands")

# Display the contents of the brands table
df.display()


category,brand_name
chocolates,5-star
,dairy milk
,perk
,eclair
Biscuits,Britania
,good day
,boost


In [0]:
%sql
with CTE as 
(
Select 
row_number() over (order by (select null))as rn,
category,Brand_name,
(case when category is null then 0 else 1 end) as m
from brands
), cte2 as 
(Select rn,category,Brand_name,sum(m) over (order by rn)as n from cte)

select rn,first_value(category)over(partition by n order by rn ) as category,Brand_name  from cte2

rn,category,Brand_name
1,chocolates,5-star
2,chocolates,dairy milk
3,chocolates,perk
4,chocolates,eclair
5,Biscuits,Britania
6,Biscuits,good day
7,Biscuits,boost


In [0]:
# Add row_number and m columns
window_spec_row = Window.orderBy(col("brand_name"))
df_with_rn = df.withColumn("rn", row_number().over(window_spec_row)) \
               .withColumn("m", when(col("category").isNull(), 0).otherwise(1))

# CTE equivalent: Step 2
# Add n column using the sum of m over window
window_spec_n = Window.orderBy("rn")
df_with_n = df_with_rn.withColumn("n", _sum("m").over(window_spec_n))

# Final query: select rn, first_value(category) over (partition by n order by rn), and brand_name
window_spec_final = Window.partitionBy("n").orderBy("rn")
df_final = df_with_n.withColumn("category", first("category").over(window_spec_final))
df_final.display()


category,brand_name,rn,m,n
chocolates,5-star,1,1,1
Biscuits,Britania,2,1,2
Biscuits,boost,3,0,2
Biscuits,dairy milk,4,0,2
Biscuits,eclair,5,0,2
Biscuits,good day,6,0,2
Biscuits,perk,7,0,2
