In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, count, lit
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Initialize Spark Session very similar to previous collision pipeline
spark = (SparkSession.builder
         .appName("Collision_Road_Analysis")
         .config("spark.executor.memory", "4g")
         .config("spark.executor.cores", "2")
         .config("spark.driver.memory", "4g")
         .getOrCreate())

# Load cleaned data
collision_df = spark.read.csv("clean_collision_records.csv", header=True, inferSchema=True)
victim_df = spark.read.csv("clean_victim_records.csv", header=True, inferSchema=True)

# Join collision and victim data on CASE_ID
combined_df = collision_df.join(victim_df, "CASE_ID", "inner")

# Add binary injury column
combined_df = combined_df.withColumn(
    "INJURY_SEVERITY_BINARY",
    when((col("VICTIM_DEGREE_OF_INJURY").isin(1, 2)), 1).otherwise(0)
)

# Drop unused columns and handle nulls
cols_to_drop = ['REPORTING_DISTRICT']
combined_df = combined_df.drop(*cols_to_drop).dropna()

# Aggregate road data
road_columns = ["PRIMARY_RD", "CHP_BEAT_TYPE", "DISTANCE"]
accident_df = combined_df.groupBy(road_columns).agg(count("*").alias("ACCIDENT_COUNT"))

# Create binary target: HIGH_ACCIDENT (1 = high accident count)
threshold = 5
accident_df = accident_df.withColumn(
    "HIGH_ACCIDENT",
    when(col("ACCIDENT_COUNT") > threshold, 1).otherwise(0)
)

# Handle Class Imbalance Similar to collisions pipeline code 
total_count = accident_df.count()
class_1_count = accident_df.filter(col("HIGH_ACCIDENT") == 1).count()
class_0_count = total_count - class_1_count

weight_1 = total_count / (2 * class_1_count)
weight_0 = total_count / (2 * class_0_count)

# Add weights
accident_df = accident_df.withColumn(
    "weight",
    when(col("HIGH_ACCIDENT") == 1, lit(weight_1)).otherwise(lit(weight_0))
)

# Feature Engineering Pipeline
indexers = [StringIndexer(inputCol=col, outputCol=col + "_index", handleInvalid='skip') for col in road_columns]
assembler = VectorAssembler(inputCols=[col + "_index" for col in road_columns], outputCol="features")

# Logistic Regression Model
logistic_regression = LogisticRegression(
    featuresCol="features", 
    labelCol="HIGH_ACCIDENT", 
    weightCol="weight"
)

# Hyperparameter Tuning
paramGrid = (ParamGridBuilder()
             .addGrid(logistic_regression.regParam, [0.001, 0.01, 0.1])
             .addGrid(logistic_regression.elasticNetParam, [0.0, 0.5, 1.0])
             .addGrid(logistic_regression.maxIter, [10, 50, 100])
             .build())

# Cross-validation pipeline
evaluator = BinaryClassificationEvaluator(labelCol="HIGH_ACCIDENT", metricName="areaUnderROC")
pipeline = Pipeline(stages=indexers + [assembler, logistic_regression])

cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=3
)

# Train-test split
train_df, test_df = accident_df.randomSplit([0.8, 0.2], seed=42)

# Train the model
cv_model = cv.fit(train_df)

# Evaluate on test data
predictions = cv_model.transform(test_df)
auc = evaluator.evaluate(predictions)

# Show Results
print(f"Optimized AUC Score: {auc}")
predictions.select(road_columns + ["ACCIDENT_COUNT", "HIGH_ACCIDENT", "prediction"]).show(10, truncate=False)


24/12/17 13:25:50 WARN Utils: Your hostname, Victors-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.68.65 instead (on interface en0)
24/12/17 13:25:50 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/17 13:25:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/12/17 13:25:51 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
24/12/17 13:26:06 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/12/17 13:26:07 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerati

Optimized AUC Score: 0.6918992129069795


24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/12/17 13:56:52 WARN RowBasedKeyValueBatch: Calling spill() on

+------------+-------------+--------+--------------+-------------+----------+
|PRIMARY_RD  |CHP_BEAT_TYPE|DISTANCE|ACCIDENT_COUNT|HIGH_ACCIDENT|prediction|
+------------+-------------+--------+--------------+-------------+----------+
|1 POWER SPUR|0            |38      |12            |1            |1.0       |
|1 POWER SPUR|0            |160     |2             |0            |1.0       |
|100TH AV    |0            |35      |16            |1            |1.0       |
|100TH AV    |0            |607     |24            |1            |1.0       |
|101ST AV    |0            |39      |12            |1            |0.0       |
|101ST ST    |0            |372     |12            |1            |1.0       |
|103RD AV    |0            |30      |12            |1            |1.0       |
|103RD AV    |0            |40      |2             |0            |1.0       |
|103RD AV    |0            |100     |36            |1            |1.0       |
|103RD ST    |0            |19      |18            |1           

24/12/17 13:56:57 WARN DAGScheduler: Broadcasting large task binary with size 2.3 MiB
                                                                                