In [1]:
# Smoke-test cell: prints a fixed greeting and defines no names.
print("Hello")

Hello


In [12]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import sum  # NOTE: shadows Python's built-in sum(); prefer F.sum
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType, DateType

# Step 1: Initialize Spark session
# getOrCreate() hands back the active session if one already exists,
# otherwise it starts a new one under the application name given here.
spark = (
    SparkSession.builder
    .appName("SchemaFilteringAggregationExample")
    .getOrCreate()
)

In [13]:
# Step 2: Define schema
# Explicit column names and types; every column is declared nullable.
# NOTE(review): "date" is stored as a plain string rather than a DateType
# column -- confirm that is intended before doing any date arithmetic on it.
schema = (
    StructType()
    .add("id", IntegerType(), True)
    .add("category", StringType(), True)
    .add("amount", DoubleType(), True)
    .add("date", StringType(), True)
)

# Step 3: Sample data (you can replace this with spark.read.csv(...))
# One tuple per transaction, in schema column order: (id, category, amount, date).
data = [
    (1, "Electronics", 1500.0, "2024-06-01"),
    (2, "Clothing", 300.0, "2024-06-02"),
    (3, "Electronics", 700.0, "2024-06-02"),
    (4, "Grocery", 200.0, "2024-06-03"),
    (5, "Clothing", 1200.0, "2024-06-03"),
]

# Step 4: Create DataFrame with enforced schema
# Passing the schema explicitly means column types are not inferred from the rows.
df = spark.createDataFrame(data, schema)


In [14]:
# Step 5: Filter transactions with amount >= MIN_AMOUNT
# The threshold is named so it can be tuned in one place instead of being
# buried inside the filter expression.
MIN_AMOUNT = 500
filtered_df = df.filter(F.col("amount") >= MIN_AMOUNT)

# Step 6: Aggregate total sales per category
# F.sum is Spark's column aggregate. Referencing it through the module alias
# makes it unambiguous that this is not Python's built-in sum().
aggregated_df = (
    filtered_df
    .groupBy("category")
    .agg(F.sum("amount").alias("total_sales"))
)

# Step 7: Show results
aggregated_df.show()

# Optional: Stop Spark session
# After this call the DataFrames above can no longer be evaluated; re-run the
# notebook from the session-initialisation cell to continue working.
spark.stop()

+-----------+-----------+
|   category|total_sales|
+-----------+-----------+
|Electronics|     2200.0|
|   Clothing|     1200.0|
+-----------+-----------+

