# Optimized PySpark ML pipeline for classification (Iris)

In [3]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
from pyspark import StorageLevel
import time

# ============================================================
# CONFIGURATION
# ============================================================
DATA_PATH = "../datasets/iris.csv"
MAX_LR_ITERATIONS = 10

# ============================================================
# START TOTAL BENCHMARK TIMING
# ============================================================
# Reference point for the end-to-end figure reported in the summary.
TOTAL_BENCHMARK_START = time.time()

print(
    "=" * 70,
    "PYSPARK OPTIMIZED BENCHMARK - IRIS DATASET",
    "=" * 70,
    "Optimizations: Kryo Serializer, Explicit Schema, Caching, Tuned Partitions",
    sep="\n",
)

# ------------------------------------------------------------
# INITIALIZE SPARK SESSION (OPTIMIZED)
# ------------------------------------------------------------
print("\nInitializing Spark Session (with optimizations)...")
init_start = time.time()

# Local session using every available core, Kryo serialization and a fixed
# partition count of 24 for both shuffles and RDD parallelism.
# NOTE(review): getOrCreate() hands back an already-running session if one
# exists in this process (e.g. when the notebook cell is re-run); static
# settings such as spark.driver.memory may then not take effect, and the
# measured init time will not include JVM start-up — confirm when comparing runs.
spark = (
    SparkSession.builder
    .appName("IrisClassificationOptimized")
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryoserializer.buffer.max", "512m")
    .config("spark.sql.shuffle.partitions", "24")
    .config("spark.default.parallelism", "24")
    .getOrCreate()
)

# Silence INFO/WARN chatter so only the benchmark's own output is shown.
spark.sparkContext.setLogLevel("ERROR")

init_time = time.time() - init_start
print(f"   Initialization time: {init_time:.2f}s")

# ------------------------------------------------------------
# LOAD DATA WITH EXPLICIT SCHEMA
# ------------------------------------------------------------
print(f"\nLoading data with explicit schema from {DATA_PATH}...")
load_start = time.time()

# Declaring the schema up front avoids a type-inference pass over the file:
# four nullable double measurements followed by the nullable string label.
schema = StructType(
    [
        StructField(measurement, DoubleType(), True)
        for measurement in (
            "sepal length (cm)",
            "sepal width (cm)",
            "petal length (cm)",
            "petal width (cm)",
        )
    ]
    + [StructField("species", StringType(), True)]
)

df = spark.read.csv(DATA_PATH, header=True, schema=schema)

# count() is an action, so the timed interval includes actually reading the file.
total_records = df.count()
load_time = time.time() - load_start

print(f"   Loaded {total_records} records in {load_time:.2f}s")

# ------------------------------------------------------------
# PREPROCESS AND CACHE DATA
# ------------------------------------------------------------
print("\nPreprocessing & caching data...")
prep_start = time.time()

# Schema order: every column but the last is a feature, the last is the label.
*feature_columns, label_column = df.columns

# Encode the string label as a numeric "label" column.
indexer = StringIndexer(inputCol=label_column, outputCol="label")
df_indexed = indexer.fit(df).transform(df)

# Pack the feature columns into one "features" vector and keep only what the
# models need.
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(df_indexed).select("features", "label")

# 70/30 split with a fixed seed (repeatable for the same input/partitioning).
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

# OPTIMIZATION: keep both splits around for the model runs; the training set
# is explicitly allowed to spill to disk.
train_data.persist(StorageLevel.MEMORY_AND_DISK)
test_data.cache()

# Persisting is lazy — these counts are the actions that fill the caches.
train_count, test_count = train_data.count(), test_data.count()

prep_time = time.time() - prep_start

print(
    f"   Training size: {train_count}, Test size: {test_count}",
    f"   Preprocessing + Caching time: {prep_time:.2f}s",
    sep="\n",
)

# ------------------------------------------------------------
# TRAIN AND EVALUATE MODELS
# ------------------------------------------------------------
print("\nTraining and evaluating models...")

# Two evaluators over the same label/prediction columns; they differ only in
# which metric they compute.
accuracy_evaluator, f1_evaluator = (
    MulticlassClassificationEvaluator(
        labelCol="label",
        predictionCol="prediction",
        metricName=metric_name,
    )
    for metric_name in ("accuracy", "f1")
)

def train_and_evaluate(model, name, total_train_records, *, train_df=None, test_df=None):
    """Fit ``model``, score it on the test split and print a short report.

    Args:
        model: an unfitted Spark ML estimator (anything with ``fit``) whose
            fitted model writes a ``prediction`` column read by the evaluators.
        name: label used in the printed report.
        total_train_records: number of training rows; divided by the elapsed
            time to produce the throughput figure.
        train_df: DataFrame to fit on; defaults to the module-level ``train_data``.
        test_df: DataFrame to score; defaults to the module-level ``test_data``.

    Returns:
        ``(accuracy, f1_score, time_taken, throughput)``. ``time_taken`` covers
        fit + predict + both evaluations, so ``throughput`` is training rows
        divided by that whole interval (0 if no time elapsed), not pure
        training speed.
    """
    train_df = train_data if train_df is None else train_df
    test_df = test_data if test_df is None else test_df

    print(f"\n   Training: {name}...")
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    start = time.perf_counter()

    model_fit = model.fit(train_df)
    predictions = model_fit.transform(test_df)

    # Cache predictions so the count and both evaluators share one computation.
    # Release the cache even if an evaluation raises.
    predictions.cache()
    try:
        predictions.count()  # action: forces inference to run inside the timed window
        accuracy = accuracy_evaluator.evaluate(predictions)
        f1_score = f1_evaluator.evaluate(predictions)
        time_taken = time.perf_counter() - start
    finally:
        predictions.unpersist()

    throughput = total_train_records / time_taken if time_taken > 0 else 0

    print(f"      [{name}]")
    print(f"      Accuracy:   {accuracy:.4f}")
    print(f"      F1 Score:   {f1_score:.4f}")
    print(f"      Time:       {time_taken:.2f}s")
    print(f"      Throughput: {throughput:,.0f} records/s")

    return accuracy, f1_score, time_taken, throughput

# Run each classifier in turn; every run contributes one summary row.
# training_time spans both runs, including their printed reports.
training_start = time.time()
results = []

# Logistic Regression
lr_model = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=MAX_LR_ITERATIONS,
)
lr_acc, lr_f1, lr_time, lr_tput = train_and_evaluate(
    lr_model, "Logistic Regression", train_count
)
results.append(
    {
        "Model": "Logistic Regression",
        "Accuracy": lr_acc,
        "F1 Score": lr_f1,
        "Time": lr_time,
        "Throughput": lr_tput,
    }
)

# Decision Tree (library-default hyperparameters)
dt_model = DecisionTreeClassifier(featuresCol="features", labelCol="label")
dt_acc, dt_f1, dt_time, dt_tput = train_and_evaluate(
    dt_model, "Decision Tree", train_count
)
results.append(
    {
        "Model": "Decision Tree",
        "Accuracy": dt_acc,
        "F1 Score": dt_f1,
        "Time": dt_time,
        "Throughput": dt_tput,
    }
)

training_time = time.time() - training_start

# ------------------------------------------------------------
# CLEANUP
# ------------------------------------------------------------
print("\nCleaning up...")
cleanup_start = time.time()

# Release the cached splits. NOTE(review): unpersist() is called with its
# default arguments, which in PySpark is non-blocking — if so, cleanup_time
# only measures issuing the request, not the eviction itself.
train_data.unpersist()
test_data.unpersist()

cleanup_time = time.time() - cleanup_start
print(f"   Cleanup time: {cleanup_time:.2f}s")

# ============================================================
# END TOTAL BENCHMARK TIMING
# ============================================================
# Everything from the first banner up to here, excluding spark.stop() below.
TOTAL_BENCHMARK_END = time.time()
TOTAL_TIME = TOTAL_BENCHMARK_END - TOTAL_BENCHMARK_START

# ============================================================
# RESULTS SUMMARY
# ============================================================
print("\n" + "=" * 70)
print("TIMING BREAKDOWN")
print("=" * 70)
print(f"Spark Initialization:   {init_time:>8.2f}s")
print(f"Data Loading:           {load_time:>8.2f}s")
print(f"Preprocessing+Cache:    {prep_time:>8.2f}s")
print(f"Training (both models): {training_time:>8.2f}s")
print(f"Cleanup:                {cleanup_time:>8.2f}s")
print("-" * 70)
print(f"TOTAL END-TO-END TIME:  {TOTAL_TIME:>8.2f}s")
print("=" * 70)

print("\n" + "=" * 70)
print("--- Summary of OPTIMIZED Benchmark (Iris) ---")
print("=" * 70)
# Each row uses the same column widths as this header so the table lines up;
# the time unit lives in the header, so rows print the bare number.
print(f"{'Model':<20} | {'Acc':<6} | {'F1':<6} | {'Time (s)':<9} | {'Throughput (rec/s)':<20}")
print("-" * 70)
for res in results:
    print(
        f"{res['Model']:<20} | {res['Accuracy']:<6.4f} | {res['F1 Score']:<6.4f} | "
        f"{res['Time']:<9.2f} | {res['Throughput']:,.0f}"
    )

print("\n" + "=" * 70)
print("COMPARISON METRICS")
print("=" * 70)
print(f"Total Job Time:         {TOTAL_TIME:.2f}s")
print(f"Training Records:       {train_count}")
print(f"Test Records:           {test_count}")
print("=" * 70)

# Shut the session down so the next run starts from a fresh SparkContext.
spark.stop()

PYSPARK OPTIMIZED BENCHMARK - IRIS DATASET
Optimizations: Kryo Serializer, Explicit Schema, Caching, Tuned Partitions

Initializing Spark Session (with optimizations)...
   Initialization time: 0.15s

Loading data with explicit schema from ../datasets/iris.csv...
   Loaded 150 records in 0.10s

Preprocessing & caching data...
   Training size: 104, Test size: 46
   Preprocessing + Caching time: 0.30s

Training and evaluating models...

   Training: Logistic Regression...
      [Logistic Regression]
      Accuracy:   0.9783
      F1 Score:   0.9785
      Time:       0.49s
      Throughput: 211 records/s

   Training: Decision Tree...
      [Decision Tree]
      Accuracy:   0.9783
      F1 Score:   0.9785
      Time:       0.37s
      Throughput: 281 records/s

Cleaning up...
   Cleanup time: 0.00s

TIMING BREAKDOWN
Spark Initialization:       0.15s
Data Loading:               0.10s
Preprocessing+Cache:        0.30s
Training (both models):     0.87s
Cleanup:                    0.00s
----