# The Museum of Ideas that Didn't Quite Work Out

In [None]:
# Lift heatmap: for every consolidated weather condition, how much more (or
# less) often each severity level occurs than it does across all accidents.
# The output here was a mix of intuitive and counter-intuitive

print("--- Heatmap: Lift of Accident Severity by Consolidated Weather Condition ---")

# 1. Global share of each Severity level, i.e. P(Severity)
global_severity_counts_spark = df_categorized_weather.groupBy("Severity").count()
global_severity_counts_pd = global_severity_counts_spark.toPandas()
total_accidents = global_severity_counts_pd['count'].sum()
global_severity_proportions = (
    global_severity_counts_pd.set_index('Severity')['count'] / total_accidents
)

print("\nGlobal Severity Proportions:")
print(global_severity_proportions)

# 2. Accident counts for every (Consolidated_Weather, Severity) pair
severity_weather_counts = (
    df_categorized_weather
    .groupBy("Consolidated_Weather", "Severity")
    .count()
    .orderBy("Consolidated_Weather", "Severity")
)
severity_weather_pd = severity_weather_counts.toPandas()

# 3. Wide layout: one row per weather condition, one column per severity level
pivot_table = severity_weather_pd.pivot_table(
    values='count',
    index='Consolidated_Weather',
    columns='Severity',
    fill_value=0,
)

# 4. P(Severity | Weather): divide each row by its own total (kept as a
#    fraction, not a percentage, because the lift ratio needs fractions)
pivot_table_conditional_prob = pivot_table.div(pivot_table.sum(axis=1), axis=0)

# 5. Lift = P(Severity | Weather) / P(Severity), computed column by column
lift_table = pivot_table_conditional_prob.copy()
for severity_level, global_share in global_severity_proportions.items():
    # A severity that is absent from the pivot, or whose global share is not
    # positive, cannot produce a meaningful ratio; its column is set to 0.
    if severity_level not in lift_table.columns or not (global_share > 0):
        lift_table[severity_level] = 0
        continue
    lift_table[severity_level] = lift_table[severity_level] / global_share

# Guarantee the four severity columns exist (zero-filled when absent), then
# put them in 1..4 order so the plot is laid out consistently.
ordered_columns = [1, 2, 3, 4]
for col_name in ordered_columns:
    if col_name not in lift_table.columns:
        lift_table[col_name] = 0
lift_table = lift_table[ordered_columns]

# Rows with the highest Severity 4 lift come first; Severity 3 lift breaks ties.
lift_table_sorted = lift_table.sort_values(by=[4, 3], ascending=False)

# Draw the heatmap
plt.figure(figsize=(10, 8))  # sized for roughly 17 rows x 4 columns
sns.heatmap(
    lift_table_sorted,
    cmap="viridis",
    annot=True,        # print the lift value inside every cell
    fmt=".1f",         # one decimal place
    linewidths=0.5,    # thin separators between cells
    cbar_kws={'label': 'Lift Score (Relative Risk)'},
)

plt.title('Lift of Accident Severity by Consolidated Weather Condition')
plt.xlabel('Accident Severity')
plt.ylabel('Consolidated Weather Condition')
plt.yticks(rotation=0)  # horizontal row labels
plt.tight_layout()
plt.show()

print("\nLift Table (sorted by Severity 4, then Severity 3 lift):")
print(lift_table_sorted)

### Roundabout Investigation

In [None]:
# We looked for examples of "Before and afters" on roundabouts and found... 0
#
# This cell checks the quality of the Start_Lat / Start_Lng coordinates: their
# type, how many decimal places they carry, summary statistics, null counts and
# how many fall outside a rough continental-US bounding box.

print("Original Start_Lat Dtype:", df.schema["Start_Lat"].dataType)
print("Original Start_Lng Dtype:", df.schema["Start_Lng"].dataType)

# Using .cast() directly is the most straightforward for numeric conversion in PySpark
df = df.withColumn("Start_Lat", col("Start_Lat").cast(DoubleType()))
df = df.withColumn("Start_Lng", col("Start_Lng").cast(DoubleType()))

print("\nAfter conversion (if needed):")
print("New Start_Lat Dtype:", df.schema["Start_Lat"].dataType)
print("New Start_Lng Dtype:", df.schema["Start_Lng"].dataType)

# Assess decimal place precision.
# The distribution is computed on a projection, so no scratch columns are ever
# added to (and then dropped from) `df`.
# NOTE(review): the digits are counted on the string form of the DOUBLE value,
# so trailing zeros from the raw data are not visible here, and values rendered
# in scientific notation would be miscounted - presumably irrelevant for US
# coordinates; confirm if the data contains values near zero.
for coord_name in ("Start_Lat", "Start_Lng"):
    coord_as_text = col(coord_name).cast(StringType())
    dot_position = instr(coord_as_text, ".")  # 1-based; 0 when there is no "."
    decimals_name = f"{coord_name}_decimals"
    decimal_places = (
        when(dot_position > 0, length(coord_as_text) - dot_position)
        # A non-null value without a "." has zero decimal places.
        .when(col(coord_name).isNotNull(), 0)
        # No .otherwise(): a missing coordinate stays null, so it shows up as
        # its own bucket instead of being counted as "0 decimals".
    )

    print(f"\n{coord_name} Decimal Place Distribution (All Records):")
    (
        df.select(decimal_places.alias(decimals_name))
        .groupBy(decimals_name)
        .count()
        .orderBy(decimals_name)
        .show(truncate=False)
    )

print("\nStart_Lat Descriptive Statistics:")
df.select(col("Start_Lat")).describe().show()

print("\nStart_Lng Descriptive Statistics:")
df.select(col("Start_Lng")).describe().show()

# Check for null values
print(f"\nNull Start_Lat values: {df.filter(col('Start_Lat').isNull()).count()}")
print(f"Null Start_Lng values: {df.filter(col('Start_Lng').isNull()).count()}")

# Check for values that fall outside typical US bounds (e.g., beyond 24-50 Lat, -66 to -125 Lng)
lat_outliers = df.filter((col("Start_Lat") < 24) | (col("Start_Lat") > 50)).count()
lng_outliers = df.filter((col("Start_Lng") < -125) | (col("Start_Lng") > -66)).count()

print(f"Start_Lat outliers (outside 24-50): {lat_outliers}")
print(f"Start_Lng outliers (outside -125 to -66): {lng_outliers}")

In [None]:
# Tally how many accident records carry each value of the Roundabout flag.
roundabout_tally = df.groupBy('Roundabout').count()
roundabout_tally.show()

In [None]:
# Search for locations whose accident records include both Roundabout == True
# and Roundabout == False, as a proxy for an intersection that was converted.
print("--- Schema of DataFrame 'df' ---")
df.printSchema()

print("\n--- Distribution of 'Roundabout' column ---")
df.groupBy("Roundabout").count().show()

# Coordinates must be numeric before they can be rounded.
for coord_name in ("Start_Lat", "Start_Lng"):
    df = df.withColumn(coord_name, col(coord_name).cast(DoubleType()))

# A record lacking either coordinate cannot be assigned to a location group.
has_both_coords = col("Start_Lat").isNotNull() & col("Start_Lng").isNotNull()
df_filtered_coords = df.filter(has_both_coords)

# Number of decimal places kept when bucketing nearby accidents together.
ROUNDING_PRECISION = 4

# 1. Location key = (rounded latitude, rounded longitude).
# NOTE(review): `round` is called on Columns, so it is presumably the Spark SQL
# function imported elsewhere rather than the Python builtin - confirm.
df_with_location_key = (
    df_filtered_coords
    .withColumn("rounded_lat", round(col("Start_Lat"), ROUNDING_PRECISION))
    .withColumn("rounded_lng", round(col("Start_Lng"), ROUNDING_PRECISION))
)

# --- Debugging the potential agg issue ---
print("\n--- Schema of DataFrame 'df_with_location_key' before agg ---")
df_with_location_key.printSchema()

# 2. Per location key, count the records with and without the Roundabout flag,
#    and keep only keys where both counts are positive.
location_keys = ["rounded_lat", "rounded_lng"]
flagged_as_roundabout = when(col("Roundabout") == True, 1).otherwise(0)
flagged_as_other = when(col("Roundabout") == False, 1).otherwise(0)

per_location_counts = df_with_location_key.groupBy(*location_keys).agg(
    spark_sum(flagged_as_roundabout).alias("roundabout_accident_count"),
    spark_sum(flagged_as_other).alias("non_roundabout_accident_count"),
)
potential_changed_locations = per_location_counts.filter(
    (col("roundabout_accident_count") > 0) & (col("non_roundabout_accident_count") > 0)
)

# 3. Pull every accident record that sits at one of the retained location keys
#    (only the key columns of the right-hand side take part in the join).
changed_intersections_accidents = df_with_location_key.join(
    potential_changed_locations.select(*location_keys),
    on=location_keys,
    how="inner",
)

# 4. Report how many accidents and how many distinct locations were found.
total_accidents_at_changed_intersections = changed_intersections_accidents.count()

print(f"\nTotal count of accidents at potentially 'changed' intersections (based on {ROUNDING_PRECISION} decimal rounding for location matching): {total_accidents_at_changed_intersections}")

num_unique_changed_locations = potential_changed_locations.count()
print(f"Number of unique potentially 'changed' intersection locations identified: {num_unique_changed_locations}")

## Description Investigation - Trucks, Intoxication

We looked at the description field and found it was extremely limited. The "Descriptions" are more like radio announcements about which lanes are closed; no real accident info.

In [None]:
# Count, and show a sample of, accident descriptions that mention intoxication.
drunk_words = "drunk|intoxicated|sobriety|alcohol|DUI"

# Anchor every keyword at the start of a word. Without the anchor the
# case-insensitive search matches keywords buried inside other words
# (for example "DUI" inside "conduit"). No trailing anchor is used, so
# inflected forms such as "drunken" or "alcoholic" still match.
drunk_pattern = rf"(?i)\b(?:{drunk_words})"

# Build the filtered frame once and reuse it for both the count and the sample.
drunk_descriptions = df.filter(
    col("Description").isNotNull() &
    col("Description").rlike(drunk_pattern)
)

drunk_intoxicated_accidents_count = drunk_descriptions.count()

print(f"\nTotal count of accidents involving {drunk_words} in the description: {drunk_intoxicated_accidents_count}")

print(f"\nSample of descriptions involving {drunk_words} (first 10):")
drunk_descriptions.select("Description").show(10, truncate=False)

In [None]:
# Count, and show a sample of, accident descriptions that mention trucks.
# The result now lives in its own variable instead of overwriting the
# intoxication count.

# Match "truck" / "trucks" as whole words only. A bare substring search also
# hits "struck", and presumably place names such as "Truckee".
truck_pattern = r"(?i)\btrucks?\b"

# Build the filtered frame once and reuse it for both the count and the sample.
truck_descriptions = df.filter(
    col("Description").isNotNull() &
    col("Description").rlike(truck_pattern)
)

truck_accidents_count = truck_descriptions.count()

print(f"\nTotal count of accidents involving trucks in the description: {truck_accidents_count}")

print("\nSample of descriptions involving trucks (first 10):")
truck_descriptions.select("Description").show(10, truncate=False)

## DST Investigation

Our methods may be flawed: we found nothing close to a statistically significant increase in accidents in the days after the DST change compared with the days before. (In some years, accidents even dropped substantially.)

In [None]:
# Compare accident counts for Monday-Wednesday of the week before the spring
# DST change with Monday-Wednesday of the week after it, year by year.
df = df.withColumn("Date", F.to_date("Start_Time"))

# Hardcoded DST start dates for each year (the windows below treat each one
# as the Sunday of the change)
dst_dates = {
    2016: "2016-03-13",
    2017: "2017-03-12",
    2018: "2018-03-11",
    2019: "2019-03-10",
    2020: "2020-03-08",
    2021: "2021-03-14",
    2022: "2022-03-13",
    2023: "2023-03-12",
}

# Collect results for each year
results = []
for year, dst_date in sorted(dst_dates.items()):  # Ensure correct year order
    # Turn the literal into a real DATE and shift it with date_sub/date_add.
    # Subtracting an INTERVAL from a *string* literal leaves the result type to
    # Spark's implicit coercion rules, so the comparison against the DATE
    # column would not be guaranteed to be a date-to-date comparison.
    dst_day = F.to_date(F.lit(dst_date))

    # Sunday - 6 .. Sunday - 4  => Monday..Wednesday before the change
    before_dst = df.filter(
        F.col("Date").between(F.date_sub(dst_day, 6), F.date_sub(dst_day, 4))
    )
    # Sunday + 1 .. Sunday + 3  => Monday..Wednesday after the change
    after_dst = df.filter(
        F.col("Date").between(F.date_add(dst_day, 1), F.date_add(dst_day, 3))
    )

    before_count = before_dst.count()
    after_count = after_dst.count()

    results.append((year, before_count, after_count))

# Print results in ascending order
for year, before, after in results:
    print(f"{year}: Accidents Monday-Wednesday BEFORE DST: {before}, AFTER DST: {after}")



In [None]:
# Sanity check on intervals: apply each day offset used in the DST windows to
# one known date and display the shifted values side by side.

anchor_date = F.col("TestDate")

# Single-row frame holding the test date, converted to a DATE column
test_df = (
    spark.createDataFrame([(datetime(2023, 3, 12),)], ["TestDate"])
    .withColumn("TestDate", anchor_date.cast(DateType()))
)

# Output column name -> shifted expression, in display order
interval_shifts = [
    ("Minus_6_Days", anchor_date - F.expr("INTERVAL 6 DAY")),
    ("Minus_4_Days", anchor_date - F.expr("INTERVAL 4 DAY")),
    ("Plus_1_Day", anchor_date + F.expr("INTERVAL 1 DAY")),
    ("Plus_3_Days", anchor_date + F.expr("INTERVAL 3 DAY")),
]
for shifted_name, shifted_value in interval_shifts:
    test_df = test_df.withColumn(shifted_name, shifted_value)

# Show results
test_df.show()

In [None]:
# No one likes Mondays!
# Compare the accident count on the Monday before the spring DST change with
# the count on the Monday right after it, year by year.

df = df.withColumn("Date", F.to_date("Start_Time"))

# Hardcoded Daylight Savings time dates (checked by calendar sampling)
dst_dates = {
    2016: "2016-03-13",
    2017: "2017-03-12",
    2018: "2018-03-11",
    2019: "2019-03-10",
    2020: "2020-03-08",
    2021: "2021-03-14",
    2022: "2022-03-13",
    2023: "2023-03-12",
}

# Collect results for each year
results = []
for year, dst_date in sorted(dst_dates.items()):
    # Use a real DATE and date_sub/date_add so the equality test compares two
    # dates. A string literal minus an INTERVAL leaves the result type to
    # Spark's implicit coercion rules, which is fragile for an equality match.
    dst_day = F.to_date(F.lit(dst_date))

    # Sunday - 6 => the Monday before; Sunday + 1 => the Monday after
    monday_before = df.filter(F.col("Date") == F.date_sub(dst_day, 6))
    monday_after = df.filter(F.col("Date") == F.date_add(dst_day, 1))

    before_count = monday_before.count()
    after_count = monday_after.count()

    results.append((year, before_count, after_count))

# Print results in ascending order
for year, before, after in results:
    print(f"{year}: Accidents on Monday BEFORE DST: {before}, AFTER DST: {after}")

## Lighting: intermediate steps

In [None]:
# Per-county night-time severity metrics for a few highways, written out as
# one CSV per highway.
selected_highways = ["I-95", "I-5", "I-10"]  # quick sample

# Keep only accidents on the sampled highways
on_selected_highway = F.col("Street_minus_dir").isin(selected_highways)
df_selected = df.filter(on_selected_highway)

# Building blocks for the county-level aggregation
is_night = F.col("Sunrise_Sunset") == "Night"
night_severity_sum = F.sum(F.when(is_night, F.col("Severity_Score")).otherwise(0))
total_severity_sum = F.sum("Severity_Score")
night_accident_tally = F.count(F.when(is_night, True))  # counts only night rows

county_severity = df_selected.groupBy("Street_minus_dir", "State", "County").agg(
    night_severity_sum.alias("Night_Severity"),
    total_severity_sum.alias("Total_Severity"),
    night_accident_tally.alias("Night_Accident_Count"),
)
# Share of each county's total severity score that comes from night accidents
county_severity = county_severity.withColumn(
    "Night_Severity_Ratio",
    F.col("Night_Severity") / F.col("Total_Severity"),
)

# Make sure the output directory exists
output_dir = "created_data"
os.makedirs(output_dir, exist_ok=True)

# One CSV per highway
for highway in selected_highways:
    rows_for_highway = county_severity.filter(F.col("Street_minus_dir") == highway)
    highway_df = rows_for_highway.toPandas()
    highway_df.to_csv(f"{output_dir}/{highway}_severity.csv", index=False)

print("CSV exports completed.")

In [None]:


# Export the coordinates of every accident on the sampled highways, one CSV
# per highway, for mapping.
selected_highways = ["I-95", "I-5", "I-10"]

# Keep only accidents on the sampled highways
df_selected = df.filter(F.col("Street_minus_dir").isin(selected_highways))

# Columns needed for mapping; records without a full coordinate pair are dropped
mapping_columns = ["Street_minus_dir", "State", "County", "Severity", "Start_Lat", "Start_Lng"]
has_coordinates = F.col("Start_Lat").isNotNull() & F.col("Start_Lng").isNotNull()
accident_locations = df_selected.select(*mapping_columns).filter(has_coordinates)

# Make sure the output directory exists
output_dir = "created_data"
os.makedirs(output_dir, exist_ok=True)

# One CSV per highway
for highway in selected_highways:
    rows_for_highway = accident_locations.filter(F.col("Street_minus_dir") == highway)
    highway_df = rows_for_highway.toPandas()
    highway_df.to_csv(f"{output_dir}/{highway}_accidents.csv", index=False)

print("CSV exports completed.")


In [None]:
# "That belongs in a museum!"
# Export the distinct accident locations, rounded to two decimals, for a
# single highway.
selected_highway = "I-5"  # Modify as needed

# Keep only accidents on the chosen highway
on_chosen_highway = F.col("Street_minus_dir") == selected_highway
df_selected = df.filter(on_chosen_highway)

# Round the coordinates to two decimals, then collapse identical pairs
latitude = F.round(F.col("Start_Lat"), 2).alias("Latitude")
longitude = F.round(F.col("Start_Lng"), 2).alias("Longitude")
rounded_locations = df_selected.select(latitude, longitude).dropDuplicates()

# Make sure the output directory exists
output_dir = "created_data"
os.makedirs(output_dir, exist_ok=True)

# Write the distinct locations to CSV
unique_locations_pd = rounded_locations.toPandas()
unique_locations_pd.to_csv(f"{output_dir}/{selected_highway}_unique_locations.csv", index=False)

print(f"CSV export completed for {selected_highway}.")

### Failed covid investigation

In [None]:
# Plot how accidents are distributed over the calendar year (all years pooled).
print("\n--- Analyzing Accident Frequency by Calendar Day ---")

# Accident count for each day-of-year value
df_doy_accidents = df.withColumn("Day_of_year", dayofyear(col("Start_Time")))
daily_counts_spark = df_doy_accidents.groupBy("Day_of_year").count().orderBy("Day_of_year")

daily_counts_pd = daily_counts_spark.toPandas()

# dayofyear() is 1-based: 1 is Jan 1 and leap years reach 366. The scaffold
# must span 1..366; range(365) would add a phantom day 0 and, through the left
# merge, silently drop the counts for days 365 and 366.
all_days = pd.DataFrame({'Day_of_year': range(1, 367)})
# Left merge so days without any accident still appear, with a count of 0
daily_counts_pd = pd.merge(all_days, daily_counts_pd, on='Day_of_year', how='left').fillna(0)

plt.figure(figsize=(12, 6))
# Each day contributes its accident count as the histogram weight
plt.hist(daily_counts_pd['Day_of_year'], weights=daily_counts_pd['count'], bins=60, color='skyblue', edgecolor='black')

plt.title('Number of Accidents by Day of Year')
plt.xlabel('Day of Year, Starting Jan 1')
plt.ylabel('Number of Accidents')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

## Clearance time investigation

Somehow it took six months to clear an accident? The clearance-time field is full of data-entry errors and is unusable.

In [None]:
# List the timing columns of every record whose Clear_Time is above 24
# (per the original note, that means a clearance of more than 24 hours).
exceeds_one_day = col("Clear_Time") > 24
outlier_df = df.select("Start_Time", "End_Time", "Clear_Time").where(exceeds_one_day)

# Display the matches without truncating the timestamp columns
outlier_df.show(truncate=False)