In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.sql.functions import col, length

# Define data and schema
CollisionRecords = [
    "20160924_CollisionRecords.txt",
    "20170112_CollisionRecords.txt",
    "20180925_CollisionRecords.txt",
    "20201024_CollisionRecords.txt"
]

collision_schema = StructType([
    StructField("CASE_ID", StringType(), True),
    StructField("ACCIDENT_YEAR", IntegerType(), True),
    StructField("PROC_DATE", StringType(), True),
    StructField("JURIS", StringType(), True),
    StructField("COLLISION_DATE", StringType(), True),
    StructField("COLLISION_TIME", StringType(), True),
    StructField("OFFICER_ID", StringType(), True),
    StructField("REPORTING_DISTRICT", StringType(), True),
    StructField("DAY_OF_WEEK", IntegerType(), True),
    StructField("CHP_SHIFT", StringType(), True),
    StructField("POPULATION", IntegerType(), True),
    StructField("CNTY_CITY_LOC", StringType(), True),
    StructField("SPECIAL_COND", StringType(), True),
    StructField("BEAT_TYPE", StringType(), True),
    StructField("CHP_BEAT_TYPE", StringType(), True),
    StructField("CITY_DIVISION_LAPD", StringType(), True),
    StructField("CHP_BEAT_CLASS", StringType(), True),
    StructField("BEAT_NUMBER", StringType(), True),
    StructField("PRIMARY_RD", StringType(), True),
    StructField("SECONDARY_RD", StringType(), True),
    StructField("DISTANCE", IntegerType(), True),
    StructField("DIRECTION", StringType(), True),
    StructField("INTERSECTION", StringType(), True),
    StructField("LATITUDE", DoubleType(), True),
    StructField("LONGITUDE", DoubleType(), True),
])

# Start Spark session
spark = SparkSession.builder.appName("CollisionRecords").getOrCreate()

# Load the data
collision_df = spark.read.csv(CollisionRecords, schema=collision_schema, header=False)

# Drop rows where CASE_ID matches the header
header = collision_df.first().asDict()
collision_df = collision_df.filter(col("CASE_ID") != header["CASE_ID"])

# Step 1: Display non-numerical columns to the user
non_numerical_columns = [f.name for f in collision_schema.fields if isinstance(f.dataType, StringType)]
print("Non-Numerical Columns: ", non_numerical_columns)



# Step 3: Drop columns with all null values
columns_with_all_nulls = [col_name for col_name in collision_df.columns if collision_df.filter(col(col_name).isNotNull()).count() == 0]
collision_df = collision_df.drop(*columns_with_all_nulls)

# Handle nulls in critical columns
collision_df = collision_df.fillna({
    "ACCIDENT_YEAR": 0,
    "DAY_OF_WEEK": 0,
    "POPULATION": 0,
    "DISTANCE": 0
})

# Ensure DAY_OF_WEEK values are valid
collision_df = collision_df.filter((col("DAY_OF_WEEK") >= 1) & (col("DAY_OF_WEEK") <= 7))

# Debug: Check for corrupted data
collision_df.filter(col("CASE_ID").rlike("[^a-zA-Z0-9]")).show(5)
collision_df.filter(length(col("CASE_ID")) > 50).show(5)

# Debug: Reduce dataset for testing
collision_df = collision_df.limit(10)

# Show cleaned data
print("Cleaned Data Info:")
collision_df.printSchema()
collision_df.show(10)


Non-Numerical Columns:  ['CASE_ID', 'PROC_DATE', 'JURIS', 'COLLISION_DATE', 'COLLISION_TIME', 'OFFICER_ID', 'REPORTING_DISTRICT', 'CHP_SHIFT', 'CNTY_CITY_LOC', 'SPECIAL_COND', 'BEAT_TYPE', 'CHP_BEAT_TYPE', 'CITY_DIVISION_LAPD', 'CHP_BEAT_CLASS', 'BEAT_NUMBER', 'PRIMARY_RD', 'SECONDARY_RD', 'DIRECTION', 'INTERSECTION']
+-------+-------------+---------+-----+--------------+--------------+----------+------------------+-----------+----------+-------------+---------+-------------+----------+------------+--------+---------+------------+
|CASE_ID|ACCIDENT_YEAR|PROC_DATE|JURIS|COLLISION_DATE|COLLISION_TIME|OFFICER_ID|REPORTING_DISTRICT|DAY_OF_WEEK|POPULATION|CNTY_CITY_LOC|BEAT_TYPE|CHP_BEAT_TYPE|PRIMARY_RD|SECONDARY_RD|DISTANCE|DIRECTION|INTERSECTION|
+-------+-------------+---------+-----+--------------+--------------+----------+------------------+-----------+----------+-------------+---------+-------------+----------+------------+--------+---------+------------+
| 449906|         2002| 20030

In [2]:


# Fix negative distance values 
collision_df = collision_df.filter(col("DISTANCE") >= 0)
collision_df.describe().show()



+-------+--------------------+-------------+------------------+-----+-----------------+----------------+-----------------+------------------+-----------------+----------+-------------+---------+-------------+------------+------------+-----------------+---------+------------+
|summary|             CASE_ID|ACCIDENT_YEAR|         PROC_DATE|JURIS|   COLLISION_DATE|  COLLISION_TIME|       OFFICER_ID|REPORTING_DISTRICT|      DAY_OF_WEEK|POPULATION|CNTY_CITY_LOC|BEAT_TYPE|CHP_BEAT_TYPE|  PRIMARY_RD|SECONDARY_RD|         DISTANCE|DIRECTION|INTERSECTION|
+-------+--------------------+-------------+------------------+-----+-----------------+----------------+-----------------+------------------+-----------------+----------+-------------+---------+-------------+------------+------------+-----------------+---------+------------+
|  count|                  10|           10|                10|   10|               10|              10|               10|                10|               10|        10|  

In [3]:
# Filter rows where PRIMARY_RD or SECONDARY_RD contains "..."
invalid_rd_cases = collision_df.filter(
    (col("PRIMARY_RD").contains("...")) | (col("SECONDARY_RD").contains("..."))
)

invalid_rd_cases.show(truncate=False)



+-------+-------------+---------+-----+--------------+--------------+----------+------------------+-----------+----------+-------------+---------+-------------+----------+------------+--------+---------+------------+
|CASE_ID|ACCIDENT_YEAR|PROC_DATE|JURIS|COLLISION_DATE|COLLISION_TIME|OFFICER_ID|REPORTING_DISTRICT|DAY_OF_WEEK|POPULATION|CNTY_CITY_LOC|BEAT_TYPE|CHP_BEAT_TYPE|PRIMARY_RD|SECONDARY_RD|DISTANCE|DIRECTION|INTERSECTION|
+-------+-------------+---------+-----+--------------+--------------+----------+------------------+-----------+----------+-------------+---------+-------------+----------+------------+--------+---------+------------+
+-------+-------------+---------+-----+--------------+--------------+----------+------------------+-----------+----------+-------------+---------+-------------+----------+------------+--------+---------+------------+



In [4]:


# remove all "not stated - values" and null
filterCondition = (col(collision_df.columns[0]) != "-")
filterCondition = filterCondition | (col(collision_df.columns[0]) != None)


for c in collision_df.columns[1:]:
    filterCondition = filterCondition | (col(c) != "-")
    filterCondition = filterCondition | (col(c) != None)

filtered_df = collision_df.filter(filterCondition)
filtered_df.head(5)

[Row(CASE_ID='0100010101011401155', ACCIDENT_YEAR=2001, PROC_DATE='20010416', JURIS='0100', COLLISION_DATE='20010101', COLLISION_TIME='0114', OFFICER_ID='1155', REPORTING_DISTRICT='0', DAY_OF_WEEK=1, POPULATION=4, CNTY_CITY_LOC='0198', BEAT_TYPE='0', CHP_BEAT_TYPE='0', PRIMARY_RD='DUBLIN BL', SECONDARY_RD='SCARLETT CT', DISTANCE=267, DIRECTION='W', INTERSECTION='N'),
 Row(CASE_ID='0100010103174503131', ACCIDENT_YEAR=2001, PROC_DATE='20010416', JURIS='0100', COLLISION_DATE='20010103', COLLISION_TIME='1745', OFFICER_ID='3131', REPORTING_DISTRICT='10', DAY_OF_WEEK=3, POPULATION=4, CNTY_CITY_LOC='0198', BEAT_TYPE='0', CHP_BEAT_TYPE='0', PRIMARY_RD='DOUGHERTY RD', SECONDARY_RD='AMADOR VLY BL', DISTANCE=80, DIRECTION='N', INTERSECTION='N'),
 Row(CASE_ID='0100010104134002415', ACCIDENT_YEAR=2001, PROC_DATE='20010608', JURIS='0100', COLLISION_DATE='20010104', COLLISION_TIME='1340', OFFICER_ID='2415', REPORTING_DISTRICT='0', DAY_OF_WEEK=4, POPULATION=4, CNTY_CITY_LOC='0198', BEAT_TYPE='0', CHP_

In [5]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline
encoded_and_index = ["INTERSECTION","SPECIAL_COND","CHP_BEAT_TYPE","BEAT_TYPE","DIRECTION","CHP_BEAT_CLASS","CHP_SHIFT"]

indexers = [StringIndexer(inputCol=c, outputCol=c+"_index") for c in encoded_and_index]
encoders = [OneHotEncoder(inputCol=c+"_index", outputCol=c+"_vec") for c in encoded_and_index]

encoded_df = filtered_df

for col_name in encoded_and_index:
    
    indexer = StringIndexer(inputCol=col_name, outputCol=col_name + "_index")
    print(indexer)
    print(col_name)
    encoded_df = indexer.fit(encoded_df).transform(encoded_df)
   
    encoder = OneHotEncoder(inputCol=col_name + "_index", outputCol=col_name + "_vec")
    encoded_df = encoder.fit(encoded_df).transform(encoded_df)


columns_to_drop = ["CHP_SHIFT", "CITY_DIVISION_LAPD", "SPECIAL_COND", "CITY_DIVISION_LAPD", "CHP_BEAT_CLASS", "BEAT_NUMBER" ] 

encoded_df = encoded_df.drop(*columns_to_drop)


encoded_df.show(truncate=False)

StringIndexer_a05346a8780f
INTERSECTION
StringIndexer_feedb56d0808
SPECIAL_COND


Py4JJavaError: An error occurred while calling o344.fit.
: org.apache.spark.SparkException: Input column SPECIAL_COND does not exist.
	at org.apache.spark.ml.feature.StringIndexerBase.$anonfun$validateAndTransformSchema$2(StringIndexer.scala:128)
	at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
	at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
	at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
	at org.apache.spark.ml.feature.StringIndexerBase.validateAndTransformSchema(StringIndexer.scala:123)
	at org.apache.spark.ml.feature.StringIndexerBase.validateAndTransformSchema$(StringIndexer.scala:115)
	at org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema(StringIndexer.scala:145)
	at org.apache.spark.ml.feature.StringIndexer.transformSchema(StringIndexer.scala:252)
	at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:71)
	at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:237)
	at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:145)
	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
	at java.base/java.lang.reflect.Method.invoke(Method.java:580)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:1570)
