In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql import Window
from pyspark.sql.functions import col, row_number, sum as spark_sum

# Start (or reuse) the Spark session that backs this notebook
spark = (
    SparkSession.builder
    .appName("students_example")
    .getOrCreate()
)

# Students schema: two string columns and an integer marks column, all nullable
schema = (
    StructType()
    .add("sname", StringType(), True)
    .add("sid", StringType(), True)
    .add("marks", IntegerType(), True)
)

# Sample rows, each in (sname, sid, marks) order
data = [
    ("A", "X", 75), ("A", "Y", 75), ("A", "Z", 80),
    ("B", "X", 90), ("B", "Y", 91), ("B", "Z", 75),
]

# Build the students DataFrame from the rows using the explicit schema
df_students = spark.createDataFrame(data, schema)

# Render the DataFrame in the notebook
df_students.display()


sname,sid,marks
A,X,75
A,Y,75
A,Z,80
B,X,90
B,Y,91
B,Z,75


In [0]:
# Number of highest marks per student that contribute to the total
TOP_N = 2

# Rank each student's rows from highest to lowest marks. "sid" is a secondary
# sort key: with marks alone, tied rows (e.g. two 75s) could receive different
# row numbers from run to run, making the intermediate DataFrames non-reproducible.
window_spec = Window.partitionBy("sname").orderBy(col("marks").desc(), col("sid").asc())

# Attach the per-student row number as 'marks_rn'
students_with_row_number = df_students.withColumn("marks_rn", row_number().over(window_spec))

# Keep only each student's TOP_N highest-marked rows
filtered_df = students_with_row_number.filter(col("marks_rn") <= TOP_N)

# Sum the retained marks per student
result_df = filtered_df.groupBy("sname").agg(spark_sum("marks").alias("TotalMarks"))

# Show the result
result_df.display()


sname,TotalMarks
A,155
B,181


In [0]:
# Register the students DataFrame as a temp view so it can be queried through spark.sql
df_students.createOrReplaceTempView("students")

In [0]:
# Spark SQL version: number each student's rows by marks (highest first),
# keep the first two rows per student, and sum their marks.
# The view is referenced as `students`, spelled exactly as it was registered,
# so the lookup does not depend on case-insensitive identifier resolution.
result_df = spark.sql("""
WITH student_agg AS (
    SELECT sname, marks, ROW_NUMBER() OVER (PARTITION BY sname ORDER BY marks DESC) AS marks_rn
    FROM students
)
SELECT sname, SUM(marks) AS TotalMarks
FROM student_agg
WHERE marks_rn < 3
GROUP BY sname
""")

# Show the result
result_df.display()

sname,TotalMarks
A,155
B,181
