In [1]:
# Cell 1: Import required libraries
from pyspark.sql import SparkSession
import pandas as pd
import time
from datetime import datetime
import json
from benchmark_utils import BenchmarkResult


# Confirmation line so the cell output shows the import statements above ran.
print("Libraries imported successfully")

# Cell 2: Initialize Spark session with proper configurations
def init_spark_session():
    """Build (or fetch) the Spark session used by the benchmark.

    The session is named "IcebergBenchmark" and registers an Iceberg
    ``SparkCatalog`` called ``demo`` of type ``rest`` (URI ``http://rest:8181``),
    makes it the default catalog, and points it at the ``s3://warehouse/``
    warehouse through ``S3FileIO`` with the endpoint ``http://minio:9000``.

    Returns:
        The object returned by ``SparkSession.builder...getOrCreate()``.
    """
    print("Initializing Spark session...")

    # Settings are applied in this exact order, one .config() call each.
    iceberg_settings = {
        "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
        "spark.sql.catalog.demo": "org.apache.iceberg.spark.SparkCatalog",
        "spark.sql.catalog.demo.type": "rest",
        "spark.sql.defaultCatalog": "demo",
        "spark.sql.catalog.demo.uri": "http://rest:8181",
        "spark.sql.catalog.demo.warehouse": "s3://warehouse/",
        "spark.sql.catalog.demo.io-impl": "org.apache.iceberg.aws.s3.S3FileIO",
        "spark.sql.catalog.demo.s3.endpoint": "http://minio:9000",
    }

    session_builder = SparkSession.builder.appName("IcebergBenchmark")
    for setting_name, setting_value in iceberg_settings.items():
        session_builder = session_builder.config(setting_name, setting_value)

    # NOTE(review): getOrCreate may hand back an already-running session, in
    # which case only some of the settings above apply — confirm when the
    # kernel is not fresh.
    session = session_builder.getOrCreate()

    print("Spark session initialized")
    return session

# Initialize Spark
spark = init_spark_session()

# Cell 3: Verify existing table
print("\nVerifying existing Iceberg table:")

# Each entry is (heading printed first, SQL statement whose result is shown).
for heading, statement in (
    ("\nTotal records:", "SELECT COUNT(*) as count FROM demo.nyc.taxis"),
    ("\nSample data:", "SELECT * FROM demo.nyc.taxis LIMIT 5"),
):
    print(heading)
    spark.sql(statement).show()


Libraries imported successfully
Initializing Spark session...


24/12/31 14:39:04 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


Spark session initialized

Verifying existing Iceberg table:

Total records:


                                                                                

+--------+
|   count|
+--------+
|77966324|
+--------+


Sample data:


[Stage 3:>                                                          (0 + 1) / 1]

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       2| 2023-10-12 00:11:15|  2023-10-12 00:41:14|            1.0|        10.22|       1.0|                 N|         132|          61|           1|       43.6|  1.0|    0.5|      9.5

                                                                                

In [2]:
# Cell 3: Create namespace if it doesn't exist
def ensure_namespace_exists(session=None):
    """Create the ``demo.nyc`` namespace if it doesn't exist, then list what is visible.

    Issues ``CREATE NAMESPACE IF NOT EXISTS demo.nyc`` and afterwards shows the
    output of ``SHOW CATALOGS`` and ``SHOW NAMESPACES IN demo``.

    Args:
        session: Object exposing ``sql(str)`` whose return value provides
            ``show()``. When omitted, the notebook-level ``spark`` variable is
            used, so the original zero-argument call keeps working.

    Returns:
        None. All results are printed.
    """
    # Only fall back to the notebook global when no session was passed in, so
    # the function matches the other helpers that take the session explicitly.
    active_session = spark if session is None else session

    print("\nEnsuring namespace exists...")

    # Create namespace
    active_session.sql("CREATE NAMESPACE IF NOT EXISTS demo.nyc")

    # Verify catalogs and namespaces
    print("\nAvailable catalogs:")
    active_session.sql("SHOW CATALOGS").show()

    print("\nAvailable namespaces in demo:")
    active_session.sql("SHOW NAMESPACES IN demo").show()

# Runs against the notebook-level `spark` session created in the first cell.
ensure_namespace_exists()


Ensuring namespace exists...

Available catalogs:
+-------------+
|      catalog|
+-------------+
|         demo|
|spark_catalog|
+-------------+


Available namespaces in demo:
+---------+
|namespace|
+---------+
|      nyc|
+---------+



In [3]:
# Cell 4: Run simple query
def run_simple_query(spark):
    """Run and time a simple aggregation query over ``demo.nyc.taxis``.

    The query computes COUNT(*), AVG(trip_distance), MAX(total_amount) and
    MIN(fare_amount) over the whole table. The timed span covers both
    ``spark.sql(...)`` and ``toPandas()``, so it includes collecting the
    result into pandas.

    Args:
        spark: Session object exposing ``sql(str)``; the value it returns must
            provide ``toPandas()``.

    Returns:
        dict: ``{"duration": elapsed seconds (float), "result": the object
        returned by toPandas()}``.
    """
    query = """
        SELECT
            COUNT(*) AS total_records,
            AVG(trip_distance) AS avg_trip_distance,
            MAX(total_amount) AS max_total_amount,
            MIN(fare_amount) AS min_fare_amount
        FROM demo.nyc.taxis
    """

    print("Executing simple aggregation query...")
    print("\nQuery:")
    print(query)

    # perf_counter is monotonic: unlike time.time it cannot jump or go
    # negative when the system clock is adjusted while the query runs.
    start_time = time.perf_counter()
    df = spark.sql(query).toPandas()
    duration = time.perf_counter() - start_time

    print(f"\nQuery completed in {duration:.2f} seconds")
    print("\nResults:")
    print(df)

    return {"duration": duration, "result": df}

# Run simple query; the returned {"duration", "result"} dict is read again by the save-results cell
simple_results = run_simple_query(spark)


Executing simple aggregation query...

Query:

        SELECT
            COUNT(*) AS total_records,
            AVG(trip_distance) AS avg_trip_distance,
            MAX(total_amount) AS max_total_amount,
            MIN(fare_amount) AS min_fare_amount
        FROM demo.nyc.taxis
    





Query completed in 7.73 seconds

Results:
   total_records  avg_trip_distance  max_total_amount  min_fare_amount
0       77966324           5.040317         401095.62     -133391414.0


                                                                                

In [4]:
# Cell 5: Run complex query
def run_complex_query(spark):
    """Run and time a filtered, grouped analysis query over ``demo.nyc.taxis``.

    The query keeps rows with ``trip_distance > 2 AND total_amount > 0``,
    groups by ``payment_type`` and returns the trip count, average total
    amount, maximum tip and the number of rides with more than one passenger,
    ordered by trip count descending. The timed span covers both
    ``spark.sql(...)`` and ``toPandas()``.

    Args:
        spark: Session object exposing ``sql(str)``; the value it returns must
            provide ``toPandas()``.

    Returns:
        dict: ``{"duration": elapsed seconds (float), "result": the object
        returned by toPandas()}``.
    """
    query = """
        SELECT 
            payment_type,
            COUNT(*) as trip_count,
            AVG(total_amount) as avg_total_amount,
            MAX(tip_amount) as max_tip_amount,
            SUM(CASE WHEN passenger_count > 1 THEN 1 ELSE 0 END) as shared_rides
        FROM demo.nyc.taxis
        WHERE trip_distance > 2 AND total_amount > 0
        GROUP BY payment_type
        ORDER BY trip_count DESC
    """

    print("Executing complex analysis query...")
    print("\nQuery:")
    print(query)

    # perf_counter is monotonic: unlike time.time it cannot jump or go
    # negative when the system clock is adjusted while the query runs.
    start_time = time.perf_counter()
    df = spark.sql(query).toPandas()
    duration = time.perf_counter() - start_time

    print(f"\nQuery completed in {duration:.2f} seconds")
    print("\nResults:")
    print(df)

    return {"duration": duration, "result": df}

# Run complex query; the returned {"duration", "result"} dict is read again by the save-results cell
complex_results = run_complex_query(spark)

Executing complex analysis query...

Query:

        SELECT 
            payment_type,
            COUNT(*) as trip_count,
            AVG(total_amount) as avg_total_amount,
            MAX(tip_amount) as max_tip_amount,
            SUM(CASE WHEN passenger_count > 1 THEN 1 ELSE 0 END) as shared_rides
        FROM demo.nyc.taxis
        WHERE trip_distance > 2 AND total_amount > 0
        GROUP BY payment_type
        ORDER BY trip_count DESC
    


                                                                                


Query completed in 9.76 seconds

Results:
   payment_type  trip_count  avg_total_amount  max_tip_amount  shared_rides
0             1    27504582         37.871267         1400.16       6696053
1             2     5922309         33.125444           32.00       1745599
2             0     1603076         35.306151           92.35             0
3             4      134663         42.777958           95.00         32343
4             3       90810         33.755191           90.00         20430


In [5]:
def run_advanced_analysis_spark(spark):
    """Run advanced fare analysis with trend detection and interpolation for Spark.

    The query aggregates fares per (month, payment_type) for pickups from 2022
    onwards, fills a NULL monthly average tip with the midpoint of the
    neighbouring months, computes the month-over-month change of the average
    base fare, and labels each row with a trend category. The timed span
    covers both ``spark.sql(...)`` and ``toPandas()``.

    Rows that have no previous month within their payment_type partition have
    a NULL ``fare_change_pct``; they are labelled 'No Prior Data' rather than
    falling through to the ELSE branch and being reported as 'Strong Decrease'.

    Args:
        spark: Spark session (any object exposing ``sql(str)`` whose result
            provides ``toPandas()``).

    Returns:
        dict | None: ``{"duration": elapsed seconds, "result": dataframe}`` on
        success. On any exception the error is printed and None is returned,
        so callers must check for None before subscripting.
    """
    query = """
    WITH fare_stats AS (
        SELECT 
            date_trunc('month', tpep_pickup_datetime) as month,
            payment_type,
            AVG(fare_amount) as avg_base_fare,
            AVG(total_amount) as avg_total_fare,
            AVG(tip_amount) as avg_tip,
            COUNT(*) as num_rides,
            STDDEV(fare_amount) as fare_stddev
        FROM demo.nyc.taxis
        WHERE fare_amount > 0 
        AND total_amount > 0
        AND YEAR(tpep_pickup_datetime) >= 2022
        GROUP BY 
            date_trunc('month', tpep_pickup_datetime),
            payment_type
    ),
    interpolated_points AS (
        SELECT 
            month,
            payment_type,
            avg_base_fare,
            avg_total_fare,
            CASE 
                WHEN avg_tip IS NULL THEN 
                    LAG(avg_tip, 1) OVER (PARTITION BY payment_type ORDER BY month) +
                    (LEAD(avg_tip, 1) OVER (PARTITION BY payment_type ORDER BY month) -
                     LAG(avg_tip, 1) OVER (PARTITION BY payment_type ORDER BY month)) * 0.5
                ELSE avg_tip
            END as interpolated_tip,
            num_rides,
            fare_stddev,
            avg_base_fare - LAG(avg_base_fare) OVER (PARTITION BY payment_type ORDER BY month) as fare_change,
            100.0 * (avg_base_fare - LAG(avg_base_fare) OVER (PARTITION BY payment_type ORDER BY month)) / 
                NULLIF(LAG(avg_base_fare) OVER (PARTITION BY payment_type ORDER BY month), 0) as fare_change_pct
        FROM fare_stats
    )
    SELECT 
        month,
        payment_type,
        ROUND(avg_base_fare, 2) as avg_base_fare,
        ROUND(avg_total_fare, 2) as avg_total_fare,
        ROUND(interpolated_tip, 2) as interpolated_tip,
        num_rides,
        ROUND(fare_stddev, 2) as fare_stddev,
        ROUND(fare_change, 2) as fare_change,
        ROUND(fare_change_pct, 1) as fare_change_pct,
        CASE 
            WHEN fare_change_pct IS NULL THEN 'No Prior Data'
            WHEN fare_change_pct > 5 THEN 'Strong Increase'
            WHEN fare_change_pct BETWEEN 1 AND 5 THEN 'Moderate Increase'
            WHEN fare_change_pct BETWEEN -1 AND 1 THEN 'Stable'
            WHEN fare_change_pct BETWEEN -5 AND -1 THEN 'Moderate Decrease'
            ELSE 'Strong Decrease'
        END as trend_category
    FROM interpolated_points
    ORDER BY month DESC, payment_type;
    """

    print("Executing advanced fare trend analysis (Spark)...")
    # perf_counter is monotonic, so the interval is immune to wall-clock
    # adjustments that can skew a time.time() difference.
    start_time = time.perf_counter()

    try:
        df = spark.sql(query).toPandas()
        duration = time.perf_counter() - start_time

        print(f"\nAnalysis completed in {duration:.2f} seconds")
        # len(df) is the number of (month, payment_type) result rows.
        print(f"Total records analyzed: {len(df):,}")
        print("\nResults preview (most recent month):")
        print(df.head())

        # Calculate summary statistics
        print("\nSummary Statistics:")
        print(f"Average base fare: ${df['avg_base_fare'].mean():.2f}")
        print(f"Average total fare: ${df['avg_total_fare'].mean():.2f}")
        print(f"Average tip: ${df['interpolated_tip'].mean():.2f}")
        print("\nTrend Distribution:")
        print(df['trend_category'].value_counts())

        return {"duration": duration, "result": df}

    except Exception as e:
        # Deliberate best-effort: report the failure and signal it with None
        # so the notebook can continue with the other benchmark results.
        print(f"Error executing Spark analysis: {e}")
        return None
# Run advanced analysis; the function prints the error and returns None if the query fails,
# so `advanced_results` may be None after this line.
advanced_results = run_advanced_analysis_spark(spark)



Executing advanced fare trend analysis (Spark)...


                                                                                


Analysis completed in 16.22 seconds
Total records analyzed: 125

Results preview (most recent month):
       month  payment_type  avg_base_fare  avg_total_fare  interpolated_tip  \
0 2024-01-01             1          72.96           85.72              8.67   
1 2024-01-01             2           6.50           11.50              0.00   
2 2023-12-01             0          22.90           28.99              1.77   
3 2023-12-01             1          20.02           30.12              4.49   
4 2023-12-01             2          20.10           25.41              0.00   

   num_rides  fare_stddev  fare_change  fare_change_pct   trend_category  
0          5        40.06        52.94            264.4  Strong Increase  
1          1          NaN       -13.60            -67.7  Strong Decrease  
2     176966        14.13         1.11              5.1  Strong Increase  
3    2574324        17.90        -0.04             -0.2           Stable  
4     536489        19.89         0.11         

In [6]:
# Cell 6: Save benchmark results
# Collects the three query results into one BenchmarkResult and writes it to JSON.
# NOTE(review): add_query_result is called positionally as
# (name, duration, record count, dataframe) — confirm against benchmark_utils.
result = BenchmarkResult("Spark", spark.version)

# For simple query:
total_records = simple_results["result"]["total_records"].iloc[0]  # Get the COUNT(*) result
result.add_query_result(
    "simple_aggregation",
    simple_results["duration"],
    total_records,
    simple_results["result"]
)

# For complex query:
total_records = complex_results["result"]["trip_count"].sum()  # Sum all group counts
result.add_query_result(
    "complex_analysis",
    complex_results["duration"],
    total_records,
    complex_results["result"]
)

# Add to benchmark results
# run_advanced_analysis_spark returns None when the query fails; subscripting
# None would raise TypeError here and discard the two results gathered above.
if advanced_results is None:
    print("Skipping advanced_analysis: the query did not produce results")
else:
    result.add_query_result(
        "advanced_analysis",
        advanced_results["duration"],
        advanced_results["result"]["num_rides"].sum(),  # Sum rides over all (month, payment_type) rows
        advanced_results["result"]
    )

# Save results (automatically captures resource usage)
result.save('spark_benchmark_results.json')

Benchmark results saved to spark_benchmark_results.json


In [7]:
# Cell 7: Clean up
# Stop the Spark session now that the benchmark results have been written.
spark.stop()
print("\nSpark session stopped")


Spark session stopped
