# Traffic Collision Data Analysis

# Objective

In this case study, you will be working on California Traffic Collision Data Analysis using Apache Spark, a powerful distributed computing framework designed for big data processing. This assignment aims to provide hands-on experience in analyzing large-scale traffic collision datasets using PySpark and AWS services. You will apply data analytics techniques to clean, transform, and explore crash data, drawing meaningful insights to support traffic safety and urban planning. Beyond understanding how big data tools optimize performance on a single machine and across clusters, you will develop a structured approach to analyzing crash trends, identifying high-risk locations, and evaluating contributing factors to traffic incidents. Additionally, you will utilize AWS S3 to store the processed data efficiently after the ETL process, enabling scalable storage and easy retrieval for further analysis.


# Business Value:

Traffic collisions pose significant risks to public safety, requiring continuous monitoring and analysis to enhance road safety measures. Government agencies, city planners, and policymakers must leverage data-driven insights to improve infrastructure, optimize traffic management, and implement preventive measures.

In this assignment, you will analyze California traffic collision data to uncover patterns related to accident severity, location-based risks, and key contributing factors. With Apache Spark's ability to handle large datasets efficiently and AWS S3's scalable storage, transportation authorities can process vast amounts of crash data in real time, enabling faster and more informed decision-making.

As an analyst examining traffic safety trends, your task is to analyze historical crash data to derive actionable insights that can drive policy improvements and safety interventions. Your analysis will help identify high-risk areas, categorize accidents by severity and contributing factors, and store the processed data in an AWS S3 bucket for scalable and long-term storage.

By leveraging big data analytics and cloud-based storage, urban planners and traffic authorities can enhance road safety strategies, reduce accident rates, and improve public transportation planning.


# Dataset Overview

The dataset used in this analysis consists of California traffic collision data obtained from the Statewide Integrated Traffic Records System (SWITRS). It includes detailed records of traffic incidents across California, covering various attributes such as location, severity, involved parties, and contributing factors. The dataset has been preprocessed and transformed using PySpark to facilitate large-scale analysis. By leveraging Apache Spark, we ensure efficient data handling, enabling deeper insights into traffic patterns, accident trends, and potential safety improvements.

The dataset is a .sqlite file contains detailed information about traffic collisions across California and is structured into four primary tables:
- `collisions` table contains information about the collision, where it happened, what vehicles were involved.

- `parties` table contains information about the groups people involved in the collision including age, sex, and sobriety.

- `victims` table contains information about the injuries of specific people involved in the collision.

- `locations` table contains information about the geographical location and details of road intersections.

# Assignment Tasks

<ol>
    <li>
        <strong>Data Preparation</strong></br>
        The dataset consists of structured tables containing traffic collision data. Before conducting any analysis, it is essential to ensure that the data is properly formatted and structured for efficient processing.</br>
        Check for data consistency and ensure all columns are correctly formatted.</br>
        Apply sampling techniques if needed to extract a representative subset for analysis.</br>
        Structure and prepare the data for further processing and analysis.</br>
    </br>
    <li>
        <strong>Data Cleaning</strong></br>
            2.1 <strong>Fixing Columns:</strong> Ensure all columns are properly named and formatted.</br>
            2.2 <strong>Handling Missing Values:</strong> Decide on an approach to handle missing data (e.g., imputation or removal). Mention the approach in your report.</br>
            2.3 <strong>Handling Outliers:</strong> Identify outliers in the dataset and explain why they are considered outliers. It is not necessary to remove them for this task, but mention your approach for handling them.</br>
    </br>
    <li>
        <strong>Exploratory Data Analysis</strong></br>
        Finding Patterns and analyze the dataset and find patterns based on the following points:
                <ul>
                3.1 Classify variables into categorical and numerical types.</br>
                3.2 Analyze the distribution of collision severity.</br>
                3.3 Examine weather conditions during collisions.</br>
                3.4 Analyze the distribution of victim ages.</br>
                3.5 Study the relationship between collision severity and the number of victims.</br>
                3.6 Analyze the correlation between weather conditions and collision severity.</br>
                3.7 Visualize the impact of lighting conditions on collision severity.</br>
                3.8 Extract and analyze weekday-wise collision trends.</br>
                3.9 Assess the number of collisions occurring on different days of the week.</br>
                3.10 Study spatial distribution of collisions by county.</br>
                3.11 Generate a scatter plot to analyze collision locations geographically.</br>
                3.12 Extract and analyze collision trends over time, including yearly, monthly, and hourly trends.</br>
</ul>
</br>
<li>
<strong>ETL Querying</strong><br>
Write PySpark SQL queries for the following:</br>
<ol>
    4.1. Load the processed dataset as CSV files in S3 bucket.</br>
    4.2. Identify the top 5 counties with the highest number of collisions.</br>
    4.3. Identify the month with the highest number of collisions.</br>
    4.4. Determine the most common weather condition during collisions.</br>
    4.5. Calculate the percentage of collisions that resulted in fatalities.</br>
    4.6. Find the most dangerous time of day for collisions.</br>
    4.7. Identify the top 5 road surface conditions with the highest collision frequency.</br>
    4.8. Analyze lighting conditions that contribute to the highest number of collisions.</br>
</ol>
</br>
<li>
<strong>Conclusion</strong></br>
Provide final insights and recommendations based on the analysis:
    <ul>
        5.1 Recommendations to improve road safety by identifying high-risk locations and peak accident times for infrastructure improvements.</br>
        5.2 Suggestions to optimize traffic management by analyzing trends in collision severity, weather conditions, and lighting to improve road design and traffic signal timing.</br>
        5.3 Propose data-driven policy changes to enhance pedestrian and cyclist safety based on collision trends involving vulnerable road users.</br>
        5.4 Identify potential high-risk zones for proactive intervention by examining geographic collision density and historical accident data.</br>
        5.5 Assess the impact of environmental factors such as weather, road surface conditions, and lighting on accident frequency and severity.</br>
        5.6 Develop predictive models to anticipate collision hotspots and support proactive safety measures.</br>
        </ul>
        Conclude the analysis by summarizing key findings and business implications.</br>
        Explain the results of univariate, segmented univariate, and bivariate analyses in real-world traffic safety and policy terms.</br>
        Include visualizations and summarize the most important results in the report. Insights should explain why each variable is important and how they can influence traffic safety policies and urban planning.</br>
        </ul>
        </br>
    <li>
    <strong>Visualization Integration [Optional]</strong>
    <p>Enhance the project by incorporating a visualization component that connects the processed data stored in an S3 bucket to a business intelligence tool such as Tableau or Power BI. This involves setting up the connection between the S3 bucket and the chosen visualization tool, importing the processed dataset for analysis and visualization, creating interactive dashboards to explore key trends and insights and ensuring data updates are reflected dynamically in the visualization tool.<br>
</br>
</ol>

In [1]:
#from google.colab import drive
#drive.mount('/content/drive')

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
0,application_1750136211518_0001,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Install Required Libraries

### As pip install fails from JupyterLab, following installation steps were done
#### 1. Create a shell file as below and upload to S3.
#### 2. Point Bootstrap actions to this S3 while creating EMR cluster.

In [2]:
#!/bin/bash

# Install pip for Python 3 (Amazon Linux 2023 compatible)
# sudo dnf install -y python3-pip

# Upgrade pip and install numpy
# sudo python3 -m pip install --upgrade pip
# sudo python3 -m pip install numpy
# sudo python3 -m pip install Pillow
# sudo python3 -m pip install --quiet matplotlib seaborn
# sudo python3 -m pip install boto3
# sudo python3 -m pip install IPython

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
# Import the necessary libraries
import sqlite3
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, to_date, max, length
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, BooleanType, \
     DoubleType, LongType, DateType, TimestampType, DecimalType

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings(action='ignore')

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Start the Spark Context

In [4]:
sc

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<SparkContext master=yarn appName=livy-session-0>

# **1. Data Preparation** <font color = red>[5 marks]</font> <br>

The dataset consists of structured tables containing traffic collision data.

Before conducting any analysis, it is essential to ensure that the data is properly formatted and structured for efficient processing.

Check for data consistency and ensure all columns are correctly formatted.

### 1.1 Read Victims Sample CSV File

In [5]:
# Write code to load the data and check the schema
victims_df = spark.read.load("/user/livy/sample_victims.csv", format = "csv", header = "true", inferSchema = "true")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
print("Victims Schema as read by PySpark:\n")
victims_df.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Victims Schema as read by PySpark:

root
 |-- id: integer (nullable = true)
 |-- case_id: double (nullable = true)
 |-- party_number: integer (nullable = true)
 |-- victim_role: string (nullable = true)
 |-- victim_sex: string (nullable = true)
 |-- victim_age: double (nullable = true)
 |-- victim_degree_of_injury: string (nullable = true)
 |-- victim_seating_position: string (nullable = true)
 |-- victim_safety_equipment_1: string (nullable = true)
 |-- victim_safety_equipment_2: string (nullable = true)
 |-- victim_ejected: string (nullable = true)

In [7]:
victims_df.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+-----------+------------+-----------+----------+----------+-----------------------+-----------------------+-------------------------+-------------------------+--------------+
|     id|    case_id|party_number|victim_role|victim_sex|victim_age|victim_degree_of_injury|victim_seating_position|victim_safety_equipment_1|victim_safety_equipment_2|victim_ejected|
+-------+-----------+------------+-----------+----------+----------+-----------------------+-----------------------+-------------------------+-------------------------+--------------+
|2166324|  6324923.0|           1|     driver|      male|      62.0|      complaint of pain|                 driver|     air bag not deployed|     lap/shoulder harn...|   not ejected|
| 779470|   761473.0|           1|  passenger|      male|      25.0|              no injury|       passenger seat 3|     lap/shoulder harn...|                     NULL|   not ejected|
| 742636|   723123.0|           1|     driver|      NULL|      31.0|      compla

#### Inspect the colums 'case_id' and 'victim_age'

In [8]:
victims_df.select("case_id").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|             case_id|
+--------------------+
|           8474612.0|
|           7130714.0|
|           4714551.0|
|           5866798.0|
|           1304870.0|
|            990175.0|
|           5368293.0|
|           6233401.0|
|3.711010106104204...|
|            470926.0|
|           4845793.0|
|           8044503.0|
|           7071072.0|
|3.404010717121100...|
|         9.0017906E7|
|           8574868.0|
|           5392691.0|
|           1567280.0|
|           6304776.0|
|           2504632.0|
+--------------------+
only showing top 20 rows

In [9]:
victims_df.select("victim_age").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+
|victim_age|
+----------+
|      NULL|
|      29.0|
|     107.0|
|      47.0|
|       1.0|
|      94.0|
|     122.0|
|     103.0|
|     106.0|
|      63.0|
|      82.0|
|      66.0|
|      61.0|
|      46.0|
|      28.0|
|      13.0|
|      26.0|
|      76.0|
|      69.0|
|      98.0|
+----------+
only showing top 20 rows

#### Some of the data types in Victims schema need to change

- 'case_id': double --> long int
- 'victim_age': double --> int

#### Define schema with 'case_id' and 'victim_age' as String and after Spark read, cast to the target data type explicitly. This is intermediate steps is intriduced as a direct schema read into target data type is resulting into NULL values in the table.

In [10]:
victimSchema = StructType([StructField('id', IntegerType(),True),
                        StructField('case_id', StringType(),True),
                        StructField('party_number', IntegerType(),True),
                        StructField('victim_role', StringType(),True),
                        StructField('victim_sex', StringType(),True),
                        StructField('victim_age', StringType(),True),
                        StructField('victim_degree_of_injury', StringType(),True),
                        StructField('victim_seating_position', StringType(),True),
                        StructField('victim_safety_equipment_1', StringType(),True),
                        StructField('victim_safety_equipment_2', StringType(),True),
                        StructField('victim_ejected', StringType(),True),
                        ])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [11]:
victims_df2 = spark.read.load("/user/livy/sample_victims.csv", format = "csv", header = "true", schema = victimSchema)
victims_df2.select(max(length(col('case_id')))).show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|max(length(case_id))|
+--------------------+
|                  22|
+--------------------+

#### 'case_id' is an integer of max 22 digits. Define a large decimal type with 22 digits and 0 scale and use it for casting

In [12]:
large_case_id_decimal_type = DecimalType(22, 0)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Now validate the schema and sample records

In [13]:
victims_df1 = spark.read.load("/user/livy/sample_victims.csv", format = "csv", header = "true", schema = victimSchema)

victims_df1 = victims_df1.withColumn("case_id", col("case_id").cast(large_case_id_decimal_type))

victims_df1 = victims_df1.withColumn("victim_age", col("victim_age").cast("int"))
print("Victims Schema after modification:\n")
victims_df1.printSchema()
print("\nSample data from Victims Schema:\n")
victims_df1.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Victims Schema after modification:

root
 |-- id: integer (nullable = true)
 |-- case_id: decimal(22,0) (nullable = true)
 |-- party_number: integer (nullable = true)
 |-- victim_role: string (nullable = true)
 |-- victim_sex: string (nullable = true)
 |-- victim_age: integer (nullable = true)
 |-- victim_degree_of_injury: string (nullable = true)
 |-- victim_seating_position: string (nullable = true)
 |-- victim_safety_equipment_1: string (nullable = true)
 |-- victim_safety_equipment_2: string (nullable = true)
 |-- victim_ejected: string (nullable = true)


Sample data from Victims Schema:

+-------+--------+------------+-----------+----------+----------+-----------------------+-----------------------+-------------------------+-------------------------+--------------+
|     id| case_id|party_number|victim_role|victim_sex|victim_age|victim_degree_of_injury|victim_seating_position|victim_safety_equipment_1|victim_safety_equipment_2|victim_ejected|
+-------+--------+------------+------

### 1.2 Read Parties Sample CSV File

In [14]:
parties_df = spark.read.load("/user/livy/sample_parties.csv", format = "csv", header = "true", inferSchema = "true")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [15]:
print("Parties Schema as read by PySpark:\n")
parties_df.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Parties Schema as read by PySpark:

root
 |-- id: integer (nullable = true)
 |-- case_id: double (nullable = true)
 |-- party_number: integer (nullable = true)
 |-- party_type: string (nullable = true)
 |-- at_fault: integer (nullable = true)
 |-- party_sex: string (nullable = true)
 |-- party_age: double (nullable = true)
 |-- party_sobriety: string (nullable = true)
 |-- direction_of_travel: string (nullable = true)
 |-- party_safety_equipment_1: string (nullable = true)
 |-- party_safety_equipment_2: string (nullable = true)
 |-- financial_responsibility: string (nullable = true)
 |-- cellphone_in_use: double (nullable = true)
 |-- cellphone_use_type: string (nullable = true)
 |-- other_associate_factor_1: string (nullable = true)
 |-- party_number_killed: integer (nullable = true)
 |-- party_number_injured: integer (nullable = true)
 |-- movement_preceding_collision: string (nullable = true)
 |-- vehicle_year: double (nullable = true)
 |-- vehicle_make: string (nullable = true)
 |-

In [16]:
parties_df.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------+--------------------+------------+----------+--------+---------+---------+--------------------+-------------------+------------------------+------------------------+------------------------+----------------+--------------------+------------------------+-------------------+--------------------+----------------------------+------------+------------+----------------------+-----------------------+----------------------+----------+
|      id|             case_id|party_number|party_type|at_fault|party_sex|party_age|      party_sobriety|direction_of_travel|party_safety_equipment_1|party_safety_equipment_2|financial_responsibility|cellphone_in_use|  cellphone_use_type|other_associate_factor_1|party_number_killed|party_number_injured|movement_preceding_collision|vehicle_year|vehicle_make|statewide_vehicle_type|chp_vehicle_type_towing|chp_vehicle_type_towed|party_race|
+--------+--------------------+------------+----------+--------+---------+---------+--------------------+-------------

#### Inspect the colums 'case_id', 'party_age', 'vehicle_year' and cellphone_in_use

In [17]:
parties_df.select("case_id").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|             case_id|
+--------------------+
|           6577619.0|
|           5304948.0|
|           6936449.0|
|           8017444.0|
|           5059621.0|
|           5621639.0|
|           3333832.0|
|           2265243.0|
|         9.1389171E7|
|           4393203.0|
|           3179263.0|
|           1403433.0|
|3.700010128125601...|
|           3249766.0|
|           5163354.0|
|           3300431.0|
|           2299152.0|
|           8931817.0|
|           3436879.0|
|           3204338.0|
+--------------------+
only showing top 20 rows

In [18]:
parties_df.select("party_age").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+
|party_age|
+---------+
|     NULL|
|     29.0|
|     47.0|
|      1.0|
|     94.0|
|    114.0|
|    103.0|
|    106.0|
|     63.0|
|     82.0|
|     66.0|
|     61.0|
|     46.0|
|     28.0|
|     13.0|
|     26.0|
|     76.0|
|     69.0|
|     98.0|
|     85.0|
+---------+
only showing top 20 rows

In [19]:
parties_df.select("vehicle_year").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+------------+
|vehicle_year|
+------------+
|        NULL|
|      1909.0|
|      2010.0|
|      1932.0|
|      1996.0|
|      1969.0|
|      1998.0|
|      1985.0|
|      1944.0|
|      2017.0|
|      2200.0|
|      1948.0|
|      1994.0|
|      1921.0|
|      1926.0|
|      1920.0|
|      2006.0|
|      1947.0|
|      1972.0|
|      2015.0|
+------------+
only showing top 20 rows

In [20]:
parties_df.select("cellphone_in_use").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------------+
|cellphone_in_use|
+----------------+
|            NULL|
|             1.0|
|             0.0|
+----------------+

#### Some of the data types in Party schema need to change

- 'case_id': double --> long int as case id is of 19 digit number with no precision
- 'party_age': double --> int
- 'vehicle_year': double --> int
- 'cellphone_in_use': double --> int

#### Define schema with 'case_id', 'party_age', 'vehicle_year' and  'cellphone_in_use' as String and after Spark read, cast to the target data type explicitly. This is intermediate steps is intriduced as a direct schema read into target data type is resulting into NULL values in the table.

In [21]:
partiesSchema = StructType([StructField('id', IntegerType(),True),
                        StructField('case_id', StringType(),True),
                        StructField('party_number', IntegerType(),True),
                        StructField('party_type', StringType(),True),
                        StructField('at_fault', IntegerType(),True),
                        StructField('party_sex', StringType(),True),
                        StructField('party_age', StringType(),True),
                        StructField('party_sobriety', StringType(),True),
                        StructField('direction_of_travel', StringType(),True),
                        StructField('party_safety_equipment_1', StringType(),True),
                        StructField('party_safety_equipment_2', StringType(),True),
                        StructField('financial_responsibility', StringType(),True),
                        StructField('cellphone_in_use', StringType(),True),
                        StructField('cellphone_use_type', StringType(),True),
                        StructField('other_associate_factor_1', StringType(),True),
                        StructField('party_number_killed', IntegerType(),True),
                        StructField('party_number_injured', IntegerType(),True),
                        StructField('movement_preceding_collision', StringType(),True),
                        StructField('vehicle_year', StringType(),True),
                        StructField('vehicle_make', StringType(),True),
                        StructField('statewide_vehicle_type', StringType(),True),
                        StructField('chp_vehicle_type_towing', StringType(),True),
                        StructField('chp_vehicle_type_towed', StringType(),True),
                        StructField('party_race', StringType(),True),
                        ])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [22]:
parties_df2 = spark.read.load("/user/livy/sample_parties.csv", format = "csv", header = "true", schema = partiesSchema)
parties_df2.select(max(length(col('case_id')))).show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|max(length(case_id))|
+--------------------+
|                  22|
+--------------------+

#### 'case_id' is an integer of max 22 digits. Define a large decimal type with 22 digits and 0 scale and use it for casting

#### Now validate the schema and sample records

In [23]:
parties_df1 = spark.read.load("/user/livy/sample_parties.csv", format = "csv", header = "true", schema = partiesSchema)

parties_df1 = parties_df1.withColumn("case_id", col("case_id").cast(large_case_id_decimal_type))

parties_df1 = parties_df1.withColumn("party_age", col("party_age").cast("int"))
parties_df1 = parties_df1.withColumn("cellphone_in_use", col("cellphone_in_use").cast("int"))
parties_df1 = parties_df1.withColumn("vehicle_year", col("vehicle_year").cast("int"))
print("Parties Schema after modification:\n")
parties_df1.printSchema()
print("\nSample data from Parties Schema:\n")
parties_df1.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Parties Schema after modification:

root
 |-- id: integer (nullable = true)
 |-- case_id: decimal(22,0) (nullable = true)
 |-- party_number: integer (nullable = true)
 |-- party_type: string (nullable = true)
 |-- at_fault: integer (nullable = true)
 |-- party_sex: string (nullable = true)
 |-- party_age: integer (nullable = true)
 |-- party_sobriety: string (nullable = true)
 |-- direction_of_travel: string (nullable = true)
 |-- party_safety_equipment_1: string (nullable = true)
 |-- party_safety_equipment_2: string (nullable = true)
 |-- financial_responsibility: string (nullable = true)
 |-- cellphone_in_use: integer (nullable = true)
 |-- cellphone_use_type: string (nullable = true)
 |-- other_associate_factor_1: string (nullable = true)
 |-- party_number_killed: integer (nullable = true)
 |-- party_number_injured: integer (nullable = true)
 |-- movement_preceding_collision: string (nullable = true)
 |-- vehicle_year: integer (nullable = true)
 |-- vehicle_make: string (nullable =

In [24]:
parties_df1.select("case_id").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+
|            case_id|
+-------------------+
|9252010105233514000|
|             295388|
|            9089584|
|            7059163|
|           90217428|
|            2544466|
|           90616884|
|            2258927|
|            9020118|
|            1282340|
|           91031051|
|            1035951|
|             556275|
|            2383501|
|            5100942|
|            8649659|
|            2241991|
|            5912077|
|            4287744|
|            6414430|
+-------------------+
only showing top 20 rows

In [25]:
parties_df1.select("party_age").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+
|party_age|
+---------+
|     NULL|
|       65|
|       66|
|       67|
|       68|
|       69|
|       70|
|       71|
|       72|
|       73|
|       74|
|       75|
|       76|
|       77|
|       78|
|       79|
|       80|
|       81|
|       82|
|       83|
+---------+
only showing top 20 rows

In [26]:
parties_df1.select("vehicle_year").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+------------+
|vehicle_year|
+------------+
|        NULL|
|        2003|
|        2004|
|        2005|
|        2006|
|        2007|
|        2008|
|        2009|
|        2010|
|        2011|
|        2012|
|        2013|
|        2014|
|        2015|
|        2016|
|        2017|
|        2018|
|        2019|
|        2020|
|        2021|
+------------+
only showing top 20 rows

In [27]:
parties_df1.select("cellphone_in_use").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------------+
|cellphone_in_use|
+----------------+
|            NULL|
|               1|
|               0|
+----------------+

### 1.3 Read Collisions Sample CSV File

In [28]:
collisions_df = spark.read.load("/user/livy/sample_collisions.csv", format = "csv", header = "true", inferSchema = "true")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [29]:
print("Collissions Schema as read by PySpark:\n")
collisions_df.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Collissions Schema as read by PySpark:

root
 |-- case_id: double (nullable = true)
 |-- jurisdiction: double (nullable = true)
 |-- officer_id: string (nullable = true)
 |-- reporting_district: string (nullable = true)
 |-- chp_shift: string (nullable = true)
 |-- population: string (nullable = true)
 |-- county_city_location: integer (nullable = true)
 |-- county_location: string (nullable = true)
 |-- special_condition: double (nullable = true)
 |-- beat_type: string (nullable = true)
 |-- chp_beat_type: string (nullable = true)
 |-- chp_beat_class: string (nullable = true)
 |-- beat_number: string (nullable = true)
 |-- primary_road: string (nullable = true)
 |-- secondary_road: string (nullable = true)
 |-- distance: double (nullable = true)
 |-- direction: string (nullable = true)
 |-- intersection: double (nullable = true)
 |-- weather_1: string (nullable = true)
 |-- state_highway_indicator: double (nullable = true)
 |-- caltrans_county: string (nullable = true)
 |-- caltrans_d

In [30]:
collisions_df.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+------------+----------+------------------+--------------+----------------+--------------------+---------------+-----------------+-------------------+----------------+--------------+-----------+------------------+--------------+--------+---------+------------+---------+-----------------------+---------------+-----------------+-----------+--------+-------------+---------------+--------+--------------------+--------------+---------------+-----------+------------------------+----------------------+-------------+------------------------+---------------+-----------------+---------------------------+--------------------+------------+----------------+--------------------+--------------+-------------+--------------------+-----------------+--------------------+---------------+--------------------+-------------------------------+-------------------------+-------------------+--------------------------+------------------------------+-----------------------+------------------------+---

#### Inspect the colums 'case_id', 'jurisdiction', 'special_condition', 'distance', 'intersection','state_highway_indicator','caltrans_district','state_route', 'postmile','tow_away','killed_victims','injured_victims','party_count','pcf_violation', 'chp_road_type','not_private_property','motorcyclist_injured_count'

In [31]:
collisions_df.select("case_id").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+
|    case_id|
+-----------+
|  1363511.0|
|  1391137.0|
|  5507105.0|
|  6059078.0|
|  5221786.0|
|  5129252.0|
|9.1427419E7|
|  1638103.0|
|  1850676.0|
|  2255158.0|
|  8869107.0|
|  6618123.0|
|  7201000.0|
|  7185424.0|
|  2042443.0|
|  3986562.0|
|  6936449.0|
|  1681106.0|
|9.1339482E7|
|  8556401.0|
+-----------+
only showing top 20 rows

In [32]:
collisions_df.select("jurisdiction").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+------------+
|jurisdiction|
+------------+
|        NULL|
|      9870.0|
|       702.0|
|      9265.0|
|      4115.0|
|      4105.0|
|      2105.0|
|      5406.0|
|      1012.0|
|      3015.0|
|      5702.0|
|      3700.0|
|      4708.0|
|      4119.0|
|      9420.0|
|      9685.0|
|      1909.0|
|       403.0|
|      4307.0|
|      9590.0|
+------------+
only showing top 20 rows

In [33]:
collisions_df.select("special_condition").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------------+
|special_condition|
+-----------------+
|             NULL|
|              1.0|
|              0.0|
+-----------------+

In [34]:
collisions_df.select("distance").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------+
|distance|
+--------+
|   769.0|
|   692.0|
|   934.0|
|   720.0|
|  3690.0|
|  3432.0|
|   576.0|
| 24394.0|
|  1765.0|
|   486.0|
|  3753.0|
| 12830.0|
|  3440.0|
|   702.0|
|  7370.0|
|   758.0|
|  1226.0|
|  4105.0|
|   389.0|
|  1363.0|
+--------+
only showing top 20 rows

In [35]:
collisions_df.select("intersection").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+------------+
|intersection|
+------------+
|        NULL|
|         1.0|
|         0.0|
+------------+

In [36]:
collisions_df.select("state_highway_indicator").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------------------+
|state_highway_indicator|
+-----------------------+
|                   NULL|
|                    1.0|
|                    0.0|
+-----------------------+

In [37]:
collisions_df.select("caltrans_district").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------------+
|caltrans_district|
+-----------------+
|             NULL|
|              1.0|
|              6.0|
|              5.0|
|              2.0|
|              4.0|
|             10.0|
|              8.0|
|              0.0|
|              7.0|
|             11.0|
|              3.0|
|              9.0|
|             12.0|
+-----------------+

In [38]:
collisions_df.select("state_route").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+
|state_route|
+-----------+
|       NULL|
|      201.0|
|      180.0|
|      905.0|
|      155.0|
|       29.0|
|      107.0|
|       47.0|
|      275.0|
|      178.0|
|      780.0|
|      237.0|
|      177.0|
|        1.0|
|       94.0|
|      114.0|
|      149.0|
|      137.0|
|      245.0|
|      165.0|
+-----------+
only showing top 20 rows

In [39]:
collisions_df.select("postmile").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------+
|postmile|
+--------+
|    NULL|
|    26.7|
|  24.478|
|   8.397|
|    13.4|
|   38.61|
|    15.5|
|    14.9|
|  24.524|
|   12.32|
|   30.49|
|  10.802|
|  16.066|
|   76.46|
|  102.62|
|   76.98|
|  52.002|
|   41.89|
|   85.86|
|  20.813|
+--------+
only showing top 20 rows

In [40]:
collisions_df.select("tow_away").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------+
|tow_away|
+--------+
|    NULL|
|     1.0|
|     0.0|
+--------+

In [41]:
collisions_df.select("killed_victims").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------+
|killed_victims|
+--------------+
|          NULL|
|           0.0|
+--------------+

In [42]:
collisions_df.select("injured_victims").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------------+
|injured_victims|
+---------------+
|           NULL|
|            1.0|
|           13.0|
|           26.0|
|           27.0|
|            6.0|
|            5.0|
|            2.0|
|            4.0|
|           14.0|
|           10.0|
|            8.0|
|            0.0|
|            7.0|
|           18.0|
|           11.0|
|            3.0|
|            9.0|
|           16.0|
|           12.0|
+---------------+
only showing top 20 rows

In [43]:
collisions_df.select("party_count").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+
|party_count|
+-----------+
|        1.0|
|       13.0|
|        6.0|
|        5.0|
|       19.0|
|        2.0|
|        4.0|
|       10.0|
|        8.0|
|        7.0|
|       11.0|
|        3.0|
|        9.0|
|       12.0|
|       20.0|
|       16.0|
|       14.0|
+-----------+

In [44]:
collisions_df.select("pcf_violation").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+
|pcf_violation|
+-------------+
|         NULL|
|      21651.0|
|      21701.0|
|      26101.0|
|      22353.0|
|      14601.0|
|      21401.0|
|      21657.0|
|      22153.0|
|      21081.0|
|      23105.0|
|      12500.0|
|      27151.0|
|      21952.0|
|      21206.0|
|      35401.0|
|      21707.0|
|      23152.0|
|      21654.0|
|      21810.0|
+-------------+
only showing top 20 rows

In [45]:
collisions_df.select("chp_road_type").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+
|chp_road_type|
+-------------+
|          1.0|
|          6.0|
|          5.0|
|          2.0|
|          4.0|
|          0.0|
|          7.0|
|          3.0|
+-------------+

In [46]:
collisions_df.select("not_private_property").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|not_private_property|
+--------------------+
|                 1.0|
+--------------------+

In [47]:
collisions_df.select("motorcyclist_injured_count").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------------+
|motorcyclist_injured_count|
+--------------------------+
|                       1.0|
|                       5.0|
|                       2.0|
|                       4.0|
|                       0.0|
|                       3.0|
|                       6.0|
+--------------------------+

In [48]:
collisionsSchema = StructType([
    StructField('case_id', StringType(), True),
    StructField('jurisdiction', StringType(), True),
    StructField('officer_id', StringType(), True),
    StructField('reporting_district', StringType(), True),
    StructField('chp_shift', StringType(), True),
    StructField('population', StringType(), True),
    StructField('county_city_location', IntegerType(), True),
    StructField('county_location', StringType(), True),
    StructField('special_condition', StringType(), True),
    StructField('beat_type', StringType(), True),
    StructField('chp_beat_type', StringType(), True),
    StructField('chp_beat_class', StringType(), True),
    StructField('beat_number', StringType(), True),
    StructField('primary_road', StringType(), True),
    StructField('secondary_road', StringType(), True),
    StructField('distance', StringType(), True),
    StructField('direction', StringType(), True),
    StructField('intersection', StringType(), True),
    StructField('weather_1', StringType(), True),
    StructField('state_highway_indicator', StringType(), True),
    StructField('caltrans_county', StringType(), True),
    StructField('caltrans_district', StringType(), True),
    StructField('state_route', StringType(), True),
    StructField('postmile', StringType(), True),
    StructField('location_type', StringType(), True),
    StructField('side_of_highway', StringType(), True),
    StructField('tow_away', StringType(), True),
    StructField('collision_severity', StringType(), True),
    StructField('killed_victims', StringType(), True),
    StructField('injured_victims', StringType(), True),
    StructField('party_count', StringType(), True),
    StructField('primary_collision_factor', StringType(), True),
    StructField('pcf_violation_category', StringType(), True),
    StructField('pcf_violation', StringType(), True),
    StructField('pcf_violation_subsection', StringType(), True),
    StructField('hit_and_run', StringType(), True),
    StructField('type_of_collision', StringType(), True),
    StructField('motor_vehicle_involved_with', StringType(), True),
    StructField('pedestrian_action', StringType(), True),
    StructField('road_surface', StringType(), True),
    StructField('road_condition_1', StringType(), True),
    StructField('lighting', StringType(), True),
    StructField('control_device', StringType(), True),
    StructField('chp_road_type', StringType(), True),
    StructField('pedestrian_collision', StringType(), True),
    StructField('bicycle_collision', StringType(), True),
    StructField('motorcycle_collision', StringType(), True),
    StructField('truck_collision', StringType(), True),
    StructField('not_private_property', StringType(), True),
    StructField('statewide_vehicle_type_at_fault', StringType(), True),
    StructField('chp_vehicle_type_at_fault', StringType(), True),
    StructField('severe_injury_count', StringType(), True),
    StructField('other_visible_injury_count', StringType(), True),
    StructField('complaint_of_pain_injury_count', StringType(), True),
    StructField('pedestrian_killed_count', StringType(), True),
    StructField('pedestrian_injured_count', StringType(), True),
    StructField('bicyclist_killed_count', StringType(), True),
    StructField('bicyclist_injured_count', StringType(), True),
    StructField('motorcyclist_killed_count', StringType(), True),
    StructField('motorcyclist_injured_count', StringType(), True),
    StructField('latitude', DoubleType(), True),
    StructField('longitude', DoubleType(), True),
    StructField('collision_date', DateType(), True),
    StructField('collision_time', TimestampType(), True),
    StructField('process_date', DateType(), True),
])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### 'case_id' is an integer of max 22 digits. Define a large decimal type with 22 digits and 0 scale and use it for casting

#### Now validate the schema and sample records

In [49]:
collisions_df1 = spark.read.load("/user/livy/sample_collisions.csv", format = "csv", header = "true", schema = collisionsSchema)

collisions_df1 = collisions_df1.withColumn("case_id", col("case_id").cast(large_case_id_decimal_type))

collisions_df1 = collisions_df1 \
    .withColumn("jurisdiction", col("jurisdiction").cast("int"))
print("collisions Schema after modification:\n")
collisions_df1.printSchema()
print("\nSample data from collisions Schema:\n")
collisions_df1.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

collisions Schema after modification:

root
 |-- case_id: decimal(22,0) (nullable = true)
 |-- jurisdiction: integer (nullable = true)
 |-- officer_id: string (nullable = true)
 |-- reporting_district: string (nullable = true)
 |-- chp_shift: string (nullable = true)
 |-- population: string (nullable = true)
 |-- county_city_location: integer (nullable = true)
 |-- county_location: string (nullable = true)
 |-- special_condition: string (nullable = true)
 |-- beat_type: string (nullable = true)
 |-- chp_beat_type: string (nullable = true)
 |-- chp_beat_class: string (nullable = true)
 |-- beat_number: string (nullable = true)
 |-- primary_road: string (nullable = true)
 |-- secondary_road: string (nullable = true)
 |-- distance: string (nullable = true)
 |-- direction: string (nullable = true)
 |-- intersection: string (nullable = true)
 |-- weather_1: string (nullable = true)
 |-- state_highway_indicator: string (nullable = true)
 |-- caltrans_county: string (nullable = true)
 |-- cal

### 1.4 Read Case ID Sample CSV File

In [50]:
case_ids_df = spark.read.load("/user/livy/sample_case_ids.csv", format = "csv", header = "true", inferSchema = "true")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [51]:
print("Case ID Schema as read by PySpark:\n")
case_ids_df.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Case ID Schema as read by PySpark:

root
 |-- case_id: double (nullable = true)
 |-- db_year: integer (nullable = true)

In [52]:
case_ids_df.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+-------+
|    case_id|db_year|
+-----------+-------+
|9.0017156E7|   2021|
|  4078685.0|   2021|
|9.0588783E7|   2021|
|  3351919.0|   2018|
|   632208.0|   2018|
+-----------+-------+
only showing top 5 rows

In [53]:
caseid_Schema = StructType([
    StructField('case_id', StringType(), True),
    StructField('db_year', IntegerType(), True),
    ])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [54]:
case_id_df2 = spark.read.load("/user/livy/sample_case_ids.csv", format = "csv", header = "true", schema = caseid_Schema)
case_id_df2.select(max(length(col('case_id')))).show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|max(length(case_id))|
+--------------------+
|                  22|
+--------------------+

#### 'case_id' is an integer of max 22 digits. Define a large decimal type with 22 digits and 0 scale and use it for casting

In [55]:
case_ids_df1 = spark.read.load("/user/livy/sample_case_ids.csv", format = "csv", header = "true", schema = caseid_Schema)

case_ids_df1 = case_ids_df.withColumn("case_id", col("case_id").cast(large_case_id_decimal_type))

print("Case ID Schema after modification:\n")
case_ids_df1.printSchema()
print("\nSample data from Case ID Schema:\n")
case_ids_df1.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Case ID Schema after modification:

root
 |-- case_id: decimal(22,0) (nullable = true)
 |-- db_year: integer (nullable = true)


Sample data from Case ID Schema:

+--------+-------+
| case_id|db_year|
+--------+-------+
|90017156|   2021|
| 4078685|   2021|
|90588783|   2021|
| 3351919|   2018|
|  632208|   2018|
+--------+-------+
only showing top 5 rows

In [56]:
case_ids_df1.select("case_id").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+
|            case_id|
+-------------------+
|             377546|
|            2788420|
|            2454913|
|            2746663|
|            8294809|
|            4397716|
|            8093229|
|           90905548|
|9250011005145016000|
|            8940650|
|           90828955|
|            5822855|
|            4064285|
|9550011230140017000|
|           90627693|
|            5051908|
|            2730591|
|            1379014|
|            8792143|
|            9000725|
+-------------------+
only showing top 20 rows

# **2. Data Cleaning** <font color = red>[20 marks]</font> <br>


In [57]:
#Display Schema & Sample Data

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## **2.1 Missing Values** <font color = red>[10 marks]</font> <br>


In [58]:
#Check for Missing Values
from pyspark.sql.functions import col, sum, when

# Function to count missing values for each column
def find_missing_vals(df, table):
    print(f"\nMissing values in table {table}:")
    df.select([sum(when(col(c).isNull() | (col(c) == ''), 1).otherwise(0)).alias(c) for c in df.columns]).show()

# find missing values in each table
find_missing_vals(case_ids_df1, "case_ids")
find_missing_vals(collisions_df1, "collisions")
find_missing_vals(parties_df1, "parties")
find_missing_vals(victims_df1, "victims")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Missing values in table case_ids:
+-------+-------+
|case_id|db_year|
+-------+-------+
|      0|      0|
+-------+-------+


Missing values in table collisions:
+-------+------------+----------+------------------+---------+----------+--------------------+---------------+-----------------+---------+-------------+--------------+-----------+------------+--------------+--------+---------+------------+---------+-----------------------+---------------+-----------------+-----------+--------+-------------+---------------+--------+------------------+--------------+---------------+-----------+------------------------+----------------------+-------------+------------------------+-----------+-----------------+---------------------------+-----------------+------------+----------------+--------+--------------+-------------+--------------------+-----------------+--------------------+---------------+--------------------+-------------------------------+-------------------------+-------------------+--

In [59]:
#Drop Sparse Columns
from pyspark.sql.functions import col, sum, when

def drop_sparse_cols(df, table, threshold=0.5):
    table_count = df.count()
    print(f"\nCount of records in {table}: {table_count}")

    # Obtain count of missing values in each column
    miss_cnt = df.select([
        sum(when(col(c).isNull() | (col(c) == ''), 1)).alias(c) for c in df.columns
    ]).collect()[0].asDict()

    # Handle None (treat as 0)
    sparse_cols = [
        c_name for c_name, miss in miss_cnt.items()
        if (miss or 0) / table_count > threshold
    ]

    print(f"Sparse Cols to drop in table {table} (>{threshold*100}% miss): {sparse_cols}")

    # Drop sparse columns
    return df.drop(*sparse_cols)

# Apply to each DataFrame with a 50% threshold
case_ids_df1 = drop_sparse_cols(case_ids_df1, "case_ids")
collisions_df1 = drop_sparse_cols(collisions_df1, "collisions")
parties_df1 = drop_sparse_cols(parties_df1, "parties")
victims_df1 = drop_sparse_cols(victims_df1, "victims")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Count of records in case_ids: 942433
Sparse Cols to drop in table case_ids (>50.0% miss): []

Count of records in collisions: 935791
Sparse Cols to drop in table collisions (>50.0% miss): ['reporting_district', 'caltrans_county', 'caltrans_district', 'state_route', 'postmile', 'location_type', 'side_of_highway', 'pcf_violation_subsection', 'latitude', 'longitude']

Count of records in parties: 1866917
Sparse Cols to drop in table parties (>50.0% miss): []

Count of records in victims: 963933
Sparse Cols to drop in table victims (>50.0% miss): []

In [60]:
#Convert Data Types

from pyspark.sql.functions import col, to_date
from pyspark.sql.types import IntegerType, DoubleType

# Step 1: Convert known date columns in PySpark
def convert_date_col(df, col_name, date_format="yyyy-MM-dd"):
    if col_name in df.columns:
        df = df.withColumn(col_name, to_date(col(col_name), date_format))
        print(f"Converted column '{col_name}' to date.")
    return df

# Apply to relevant columns
collisions_df1 = convert_date_col(collisions_df1, 'collision_date')
collisions_df1 = convert_date_col(collisions_df1, 'process_date')


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Converted column 'collision_date' to date.
Converted column 'process_date' to date.

In [61]:
#Handle Missing Values

from pyspark.sql.functions import col, min, to_date
from pyspark.sql.types import IntegerType, FloatType, StringType, DateType

def impute_missing_vals(df, table):
    print(f"\nHandling missing values for Table {table}:")

    # Replace missing (int and double) values to 0
    num_cols = [c_name for c_name, dtype in df.dtypes if dtype in ['int', 'double', 'long', 'decimal(22,0)']]
    for c_name in num_cols:
        df = df.fillna({c_name: 0})
    print(f"Replaced missing numbers with 0.")

    # Replace NULL string values with 'Unknown'
    obj_cols = [c_name for c_name, dtype in df.dtypes if dtype == 'string']
    for c_name in obj_cols:
        df = df.fillna({c_name: 'Unknown'})
    print(f"Replaced NULL columns with 'Unknown'.")

    # Fill datetime columns with the earliest available date
    date_cols = [c_name for c_name, dtype in df.dtypes if dtype == 'timestamp']
    for c_name in date_cols:
        earliest_date = df.select(min(col(c_name))).collect()[0][0]  # Get the earliest date
        if earliest_date:
            df = df.withColumn(c_name, when(col(c_name).isNull(), earliest_date).otherwise(col(c_name)))
            print(f"Replaced missing values in '{c_name}' with earliest date: {earliest_date}")

    return df

# Apply to all datasets
case_ids_df1 = impute_missing_vals(case_ids_df1, "caseids__df")
collisions_df1 = impute_missing_vals(collisions_df1, "collisions_df")
parties_df1 = impute_missing_vals(parties_df1, "parties_df")
victims_df1 = impute_missing_vals(victims_df1, "victims_df")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Handling missing values for Table caseids__df:
Replaced missing numbers with 0.
Replaced NULL columns with 'Unknown'.

Handling missing values for Table collisions_df:
Replaced missing numbers with 0.
Replaced NULL columns with 'Unknown'.
Replaced missing values in 'collision_time' with earliest date: 2025-06-17 00:00:00

Handling missing values for Table parties_df:
Replaced missing numbers with 0.
Replaced NULL columns with 'Unknown'.

Handling missing values for Table victims_df:
Replaced missing numbers with 0.
Replaced NULL columns with 'Unknown'.

## **2.2 Fixing Columns** <font color = red>[5 marks]</font> <br>


In [62]:
#Remove Duplicates

# Function to remove duplicates and print the count of duplicates removed
def remove_dupes(df, table):

    df_cleaned = df.dropDuplicates()
    print(f"Removed {df.count() - df_cleaned.count()} duplicate rrecords from {table}.")
    return df_cleaned

case_ids_df1 = remove_dupes(case_ids_df1, "case_df")
collisions_df1 = remove_dupes(collisions_df1, "collisions_df")
parties_df1 = remove_dupes(parties_df1, "parties_df")
victims_df1 = remove_dupes(victims_df1, "victims_df")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Removed 35 duplicate rrecords from case_df.
Removed 0 duplicate rrecords from collisions_df.
Removed 0 duplicate rrecords from parties_df.
Removed 0 duplicate rrecords from victims_df.

In [63]:
from pyspark.sql import functions as fn

# List of numerical columns to check for outliers
def num_cols(df, table):
    # Filter numeric columns by checking data type of each column
    num_cols_list = [c_name for c_name, dtype in df.dtypes if dtype in ['int', 'double', 'long', 'decimal(22,0)'] and 'id' not in c_name.lower()]
    print(f"\n{table} - Numeric columns:")
    for c in num_cols_list:
        
        print(f" - {c}")
    return num_cols_list

# List of numerical columns to check for outliers
case_ids_num_cols = num_cols(case_ids_df1, "case_df")
collisions_num_cols = num_cols(collisions_df1, "collisions_df")
parties_num_cols = num_cols(parties_df1, "parties_df")
victims_num_cols = num_cols(victims_df1, "victims_df")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


case_df - Numeric columns:
 - db_year

collisions_df - Numeric columns:
 - jurisdiction
 - county_city_location

parties_df - Numeric columns:
 - party_number
 - at_fault
 - party_age
 - cellphone_in_use
 - party_number_killed
 - party_number_injured
 - vehicle_year

victims_df - Numeric columns:
 - party_number
 - victim_age

In [64]:
from pyspark.sql import functions as fn

#Detect Outliers using IQR
def detect_outliers_iqr(df, table, cols):
    print(f"\nDetecting outliers using IQR in {table} :")

    for c in cols:
         
        qntiles = df.approxQuantile(c, [0.25, 0.75], 0.05)  # Relative error tolerance is 0.05
        Q1 = qntiles[0]      # Calculate Q1 (25th percentile)
        Q3 = qntiles[1]      # Calculate Q3 (75th percentile)
        IQR = Q3 - Q1
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR

        # Detect outliers: values below the lower bound or above the upper bound
        outliers_df = df.filter((fn.col(c) < lower_bound) | (fn.col(c) > upper_bound))
        
        # Count outliers only if there are any
        outliers_count = outliers_df.count()
        if outliers_count > 0:
            print(f"{c}: {outliers_count} outliers")


# Apply outlier detection to each DataFrame
detect_outliers_iqr(case_ids_df1, "case_df", case_ids_num_cols)
detect_outliers_iqr(collisions_df1, "collisions_df", collisions_num_cols)
detect_outliers_iqr(parties_df1, "parties_df", parties_num_cols)
detect_outliers_iqr(victims_df1, "victims_df", victims_num_cols)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Detecting outliers using IQR in case_df :

Detecting outliers using IQR in collisions_df :

Detecting outliers using IQR in parties_df :
party_number: 35411 outliers
party_age: 19180 outliers
cellphone_in_use: 26406 outliers
party_number_killed: 6889 outliers
party_number_injured: 437989 outliers
vehicle_year: 193038 outliers

Detecting outliers using IQR in victims_df :
party_number: 12976 outliers
victim_age: 17284 outliers

## **2.3 Outlier Analysis** <font color = red>[5 marks]</font> <br>


In [65]:
#Remove Outliers

from pyspark.sql import functions as fn

def remove_outliers_iqr(df, table, num_cols, ex_cols=None):
    print(f"\nRemoving outliers from {table}...")
    clean_df = df
    total_removed = 0

    for c in num_cols:
        # Calculate Q1 and Q3
        q1, q3 = clean_df.approxQuantile(c, [0.25, 0.75], 0.05)
        iqr = q3 - q1
        lower = q1 - 1.5 * iqr
        upper = q3 + 1.5 * iqr

        before_count = clean_df.count()
        clean_df = clean_df.filter((fn.col(c) >= lower) & (fn.col(c) <= upper))
        after_count = clean_df.count()

        removed = before_count - after_count
        total_removed += removed

        if removed > 0:
            print(f" - {removed} rows removed based on column '{c}'")

    if total_removed == 0:
        print(" - No outliers removed.")

    return clean_df

# Define columns to exclude from outlier removal
ex_case_id = ['case_id']
ex_collison = ['case_id', 'id']
ex_party = ['case_id', 'id', 'party_number']
ex_victim = ['case_id', 'id', 'party_number']

# Apply the outlier removal to all datasets
case_ids_df1 = remove_outliers_iqr(case_ids_df1, "case_df", case_ids_num_cols, ex_case_id)
collisions_df1 = remove_outliers_iqr(collisions_df1, "collisions_df",collisions_num_cols,  ex_collison)
parties_df1 = remove_outliers_iqr(parties_df1, "parties_df", parties_num_cols , ex_party)
victims_df1 = remove_outliers_iqr(victims_df1, "victims_df", victims_num_cols,  ex_victim)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Removing outliers from case_df...
 - No outliers removed.

Removing outliers from collisions_df...
 - No outliers removed.

Removing outliers from parties_df...
 - 35411 rows removed based on column 'party_number'
 - 13536 rows removed based on column 'party_age'
 - 25923 rows removed based on column 'cellphone_in_use'
 - 6547 rows removed based on column 'party_number_killed'
 - 420961 rows removed based on column 'party_number_injured'
 - 135958 rows removed based on column 'vehicle_year'

Removing outliers from victims_df...
 - 12976 rows removed based on column 'party_number'
 - 15125 rows removed based on column 'victim_age'

In [66]:
collisions_df1.select("collision_severity").distinct().show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+
|  collision_severity|
+--------------------+
|property damage only|
|                pain|
|       severe injury|
|        other injury|
+--------------------+

# **3. Exploratory Data Analysis** <font color = red>[65 marks]</font> <br>


## **3.1.1. Data Preparation** <font color = red>[5 marks]</font> <br>

Q: Classify variables into categorical and numerical.

In [67]:
# Encode Categorical Variables

# String Indexing for Categorical Columns
from pyspark.sql.types import DoubleType
from pyspark.sql import functions as fn


def classify_cat_num_columns(df):
    cat_cols = [c for c, t in df.dtypes if t == 'string']
    num_cols = [c for c, t in df.dtypes if t in ['int', 'double', 'long', 'decimal(22,0)']]
    return cat_cols, num_cols

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [68]:
# table collision
cat_cols, num_cols = classify_cat_num_columns(collisions_df1)
print("Collision Table:\n")
print("Categorical Columns:", cat_cols)
print("\nNumerical Columns:", num_cols)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Collision Table:

Categorical Columns: ['officer_id', 'chp_shift', 'population', 'county_location', 'special_condition', 'beat_type', 'chp_beat_type', 'chp_beat_class', 'beat_number', 'primary_road', 'secondary_road', 'distance', 'direction', 'intersection', 'weather_1', 'state_highway_indicator', 'tow_away', 'collision_severity', 'killed_victims', 'injured_victims', 'party_count', 'primary_collision_factor', 'pcf_violation_category', 'pcf_violation', 'hit_and_run', 'type_of_collision', 'motor_vehicle_involved_with', 'pedestrian_action', 'road_surface', 'road_condition_1', 'lighting', 'control_device', 'chp_road_type', 'pedestrian_collision', 'bicycle_collision', 'motorcycle_collision', 'truck_collision', 'not_private_property', 'statewide_vehicle_type_at_fault', 'chp_vehicle_type_at_fault', 'severe_injury_count', 'other_visible_injury_count', 'complaint_of_pain_injury_count', 'pedestrian_killed_count', 'pedestrian_injured_count', 'bicyclist_killed_count', 'bicyclist_injured_count', 'm

In [69]:
# table case_id
cat_cols, num_cols = classify_cat_num_columns(case_ids_df1)
print("Case ID Table:\n")
print("Categorical Columns:", cat_cols)
print("\nNumerical Columns:", num_cols)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Case ID Table:

Categorical Columns: []

Numerical Columns: ['case_id', 'db_year']

In [70]:
# table victims
cat_cols, num_cols = classify_cat_num_columns(victims_df1)
print("Victims Table:\n")
print("Categorical Columns:", cat_cols)
print("\n Numerical Columns:", num_cols)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Victims Table:

Categorical Columns: ['victim_role', 'victim_sex', 'victim_degree_of_injury', 'victim_seating_position', 'victim_safety_equipment_1', 'victim_safety_equipment_2', 'victim_ejected']

 Numerical Columns: ['id', 'case_id', 'party_number', 'victim_age']

In [71]:
# table parties
cat_cols, num_cols = classify_cat_num_columns(parties_df1)
print("Parties Table:\n")
print("Categorical Columns:", cat_cols)
print("\nNumerical Columns:", num_cols)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Parties Table:

Categorical Columns: ['party_type', 'party_sex', 'party_sobriety', 'direction_of_travel', 'party_safety_equipment_1', 'party_safety_equipment_2', 'financial_responsibility', 'cellphone_use_type', 'other_associate_factor_1', 'movement_preceding_collision', 'vehicle_make', 'statewide_vehicle_type', 'chp_vehicle_type_towing', 'chp_vehicle_type_towed', 'party_race']

Numerical Columns: ['id', 'case_id', 'party_number', 'at_fault', 'party_age', 'cellphone_in_use', 'party_number_killed', 'party_number_injured', 'vehicle_year']

In [72]:
# Encode Categorical Variables
from pyspark.sql import functions as fn
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

#Encode Categorical Columns using StringIndexer
def encode_cat_cols(df, cat_cols, imp_cols=None):
    indexers = []
    for c in cat_cols:
        if imp_cols and c in imp_cols:
            print(f"Encoding important column: {c}")
        indexer = StringIndexer(inputCol=c, outputCol=c + "_indexed")
        indexers.append(indexer)
        
     # Create a pipeline to apply all StringIndexers
    pipeline = Pipeline(stages=indexers)
    
    # Fit and transform the data
    df_fit_transform = pipeline.fit(df).transform(df)
    
    print(f"Encoded {len(cat_cols)} categorical columns.")
    return df_fit_transform

cat_cols, num_cols = classify_cat_num_columns(collisions_df1)

# Ensure 'collision_severity' is included as an important column
if 'collision_severity' not in cat_cols:
    cat_cols.append('collision_severity')

# Apply encoding to collisions_df
collisions_df_fit_transform = encode_cat_cols(collisions_df1, cat_cols, imp_cols=['collision_severity'])

# Show the transformed dataframe
collisions_df_fit_transform.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Encoding important column: collision_severity
Encoded 49 categorical columns.
+-------------------+------------+----------+--------------+----------------+--------------------+---------------+-----------------+-----------------+-------------+--------------+-----------+--------------+--------------+--------+---------+------------+---------+-----------------------+--------+--------------------+--------------+---------------+-----------+------------------------+----------------------+-------------+---------------+-----------------+---------------------------+--------------------+------------+----------------+--------------------+--------------+-------------+--------------------+-----------------+--------------------+---------------+--------------------+-------------------------------+-------------------------+-------------------+--------------------------+------------------------------+-----------------------+------------------------+----------------------+-----------------------+--------

In [73]:
# Reordering & Renaming Columns

# ----------------------------
# 1. Reorder columns in collisions_df
# ----------------------------
collisions_cols = collisions_df1.columns

# Preferred front columns
bring_fwd = ['case_id', 'collision_date', 'process_date', 'collision_severity', 'injured_victims', 'killed_victims']
bring_fwd = [c for c in bring_fwd if c in collisions_cols]  # Keep only existing columns

# Get remaining columns
rest = [c for c in collisions_cols if c not in bring_fwd]

# Reorder
collisions_df1 = collisions_df1.select(bring_fwd + rest)

# ----------------------------
# 2. Rename selected columns in victims_df
# ----------------------------
rename_map = {
    'victim_sex': 'gender',
    'victim_age': 'age',
    'victim_degree_of_injury': 'injury_severity',
    'victim_role': 'role'
}

for old_c, new_c in rename_map.items():
    if old_c in victims_df1.columns:
        victims_df1 = victims_df1.withColumnRenamed(old_c, new_c)

# ----------------------------
# 3. Print schemas
# ----------------------------
print("\nUpdated Schema after Re-Order and Rename of columns: collisions_df1")
collisions_df1.printSchema()

print("\nUpdated Schema after Re-Order and Rename of columns: victims_df1")
victims_df1.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Updated Schema after Re-Order and Rename of columns: collisions_df1
root
 |-- case_id: decimal(22,0) (nullable = false)
 |-- collision_date: date (nullable = true)
 |-- process_date: date (nullable = true)
 |-- collision_severity: string (nullable = false)
 |-- injured_victims: string (nullable = false)
 |-- killed_victims: string (nullable = false)
 |-- jurisdiction: integer (nullable = false)
 |-- officer_id: string (nullable = false)
 |-- chp_shift: string (nullable = false)
 |-- population: string (nullable = false)
 |-- county_city_location: integer (nullable = false)
 |-- county_location: string (nullable = false)
 |-- special_condition: string (nullable = false)
 |-- beat_type: string (nullable = false)
 |-- chp_beat_type: string (nullable = false)
 |-- chp_beat_class: string (nullable = false)
 |-- beat_number: string (nullable = false)
 |-- primary_road: string (nullable = false)
 |-- secondary_road: string (nullable = false)
 |-- distance: string (nullable = false)
 |-- dire

In [74]:
# Final cleaned data
def final_df_shape(df, name):
    num_rows = df.count()
    num_cols = len(df.columns)
    print(f"{name} ==> Rows: {num_rows}, Columns: {num_cols}")

# Print final shapes for each cleaned DataFrame
final_df_shape(case_ids_df1 , "case_df")
final_df_shape(collisions_df1, "collisions_df")
final_df_shape(parties_df1, "parties_df")
final_df_shape(victims_df1, "victims_df")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

case_df ==> Rows: 942398, Columns: 2
collisions_df ==> Rows: 935791, Columns: 55
parties_df ==> Rows: 1228581, Columns: 24
victims_df ==> Rows: 935832, Columns: 11

Loading the Final Cleaned Dataset into S3 Bucket

## **3.1.2. Analyze the distribution of collision severity.** <font color = red>[5 marks]</font> <br>

Q: Analyze the distribution of collision severity.

In [75]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import boto3
# Univariate Analysis

# Collision Severity Distribution
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend

import matplotlib.pyplot as plt
import seaborn as sns

# Convert to Pandas
severity_cnt_df = collisions_df1.groupBy("collision_severity").count().orderBy("count", ascending=False)
severity_cnt_pd = severity_cnt_df.toPandas()

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(
    data=severity_cnt_pd,
    x='collision_severity',
    y='count',
    hue='collision_severity',
    legend=False,  # Hide redundant legend
    palette='viridis'
)
plt.title("Collision Severity Distribution")
plt.xlabel("Collision Severity")
plt.ylabel("Count")
plt.xticks(rotation=45)
plt.tight_layout()
#plt.show()

# Save plot locally on EMR
local_path = "/tmp/collision_severity_dist.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/collision_severity_dist.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/collision_severity_dist.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/collision_severity_dist.png

In [76]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import boto3
# Univariate Analysis

# Collision Severity Distribution
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend

import matplotlib.pyplot as plt
import seaborn as sns

# Convert to Pandas
severity_cnt_df = collisions_df.groupBy("collision_severity").count().orderBy("count", ascending=False)
severity_cnt_pd = severity_cnt_df.toPandas()

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(
    data=severity_cnt_pd,
    x='collision_severity',
    y='count',
    hue='collision_severity',
    legend=False,  # Hide redundant legend
    palette='viridis'
)
plt.title("Collision Severity Distribution")
plt.xlabel("Collision Severity")
plt.ylabel("Count")
plt.xticks(rotation=45)
plt.tight_layout()
#plt.show()

# Save plot locally on EMR
local_path = "/tmp/collision_severity_dist2.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/collision_severity_dist2.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/collision_severity_dist2.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/collision_severity_dist2.png

## **3.1.3. Weather conditions during collisions.** <font color = red>[5 marks]</font> <br>

Q: Examine weather conditions during collisions.

In [77]:
# Weather Conditions During Collisions
weather_counts_df = collisions_df1.groupBy("weather_1") \
                                 .count() \
                                 .orderBy("count", ascending=False)

print("Weather Condition Counts (Spark):")
weather_counts_df.show(truncate=False)

# Convert to Pandas
weather_counts_pd = weather_counts_df.toPandas()

# Plot
plt.figure(figsize=(12, 6))
sns.barplot(data=weather_counts_pd,
            x='weather_1',
            y='count',
            hue='weather_1',
            palette='Set2',
            legend=False)

plt.title("Weather Conditions During Collisions")
plt.xlabel("Weather Condition")
plt.ylabel("Count")
plt.xticks(rotation=45)
plt.tight_layout()
#plt.show()

# Save plot locally on EMR
local_path = "/tmp/weather_conditions.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/weather_conditions.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Weather Condition Counts (Spark):
+---------+------+
|weather_1|count |
+---------+------+
|clear    |769927|
|cloudy   |122155|
|raining  |32341 |
|Unknown  |4626  |
|fog      |3901  |
|snowing  |1370  |
|other    |1081  |
|wind     |390   |
+---------+------+

? Plot saved locally at: /tmp/weather_conditions.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/weather_conditions.png

## **3.1.4. Victime Age Distribution.** <font color = red>[5 marks]</font> <br>

Q: Analyze the distribution of victim ages.

In [78]:
victims_df1.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- id: integer (nullable = false)
 |-- case_id: decimal(22,0) (nullable = false)
 |-- party_number: integer (nullable = false)
 |-- role: string (nullable = false)
 |-- gender: string (nullable = false)
 |-- age: integer (nullable = false)
 |-- injury_severity: string (nullable = false)
 |-- victim_seating_position: string (nullable = false)
 |-- victim_safety_equipment_1: string (nullable = false)
 |-- victim_safety_equipment_2: string (nullable = false)
 |-- victim_ejected: string (nullable = false)

In [79]:
# Distribution of Victim Ages
victim_ages_df = victims_df1.select("age").dropna()
victim_ages_df.show(truncate=False)
# Convert to Pandas
victim_ages_pd = victim_ages_df.toPandas()

# Plot
plt.figure(figsize=(10, 6))
sns.histplot(victim_ages_pd["age"], bins=30, kde=True, color='skyblue')

plt.title("Distribution of Victim Ages")
plt.xlabel("Age")
plt.ylabel("Frequency")
plt.grid(True)
plt.tight_layout()
#plt.show()

# Save plot locally on EMR
local_path = "/tmp/victim_age_distribution.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/victim_age_distribution.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+
|age|
+---+
|0  |
|19 |
|39 |
|64 |
|36 |
|14 |
|19 |
|17 |
|28 |
|57 |
|22 |
|32 |
|34 |
|22 |
|20 |
|18 |
|22 |
|62 |
|46 |
|43 |
+---+
only showing top 20 rows

? Plot saved locally at: /tmp/victim_age_distribution.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/victim_age_distribution.png

## **3.1.5. Collision Severity vs Number of Victims.** <font color = red>[5 marks]</font> <br>

Q: Study the relationship between collision severity and the number of victims.

In [80]:
# Bivariate Analysis

# Collision Severity vs. Number of Victims
victims_by_severity_df = collisions_df1.groupBy("collision_severity").agg(
    fn.sum("injured_victims").alias("total_injured_victims"),
    fn.sum("killed_victims").alias("total_killed_victims")
)
victims_by_severity_df = victims_by_severity_df.withColumn(
    "total_victims", victims_by_severity_df["total_injured_victims"] + victims_by_severity_df["total_killed_victims"]
)

# Convert to Pandas
victims_by_severity_pd = victims_by_severity_df.toPandas()

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(data=victims_by_severity_pd, x='collision_severity', y='total_victims', palette='viridis')

# Step 5: Title and labels
plt.title("Collision Severity vs. Total Number of Victims")
plt.xlabel("Collision Severity")
plt.ylabel("Total Number of Victims")
plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()

# Step 6: Show the plot
# plt.show()

# Save plot locally on EMR
local_path = "/tmp/Collision_Severity_vs_Total_Victims.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Collision_Severity_vs_Total_Victims.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Collision_Severity_vs_Total_Victims.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Collision_Severity_vs_Total_Victims.png

In [81]:
# Bivariate Analysis

# Collision Severity vs. Number of Victims
victims_by_severity_df = collisions_df.groupBy("collision_severity").agg(
    fn.sum("injured_victims").alias("total_injured_victims"),
    fn.sum("killed_victims").alias("total_killed_victims")
)
victims_by_severity_df = victims_by_severity_df.withColumn(
    "total_victims", victims_by_severity_df["total_injured_victims"] + victims_by_severity_df["total_killed_victims"]
)

# Convert to Pandas
victims_by_severity_pd = victims_by_severity_df.toPandas()

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(data=victims_by_severity_pd, x='collision_severity', y='total_victims', palette='viridis')

# Step 5: Title and labels
plt.title("Collision Severity vs. Total Number of Victims")
plt.xlabel("Collision Severity")
plt.ylabel("Total Number of Victims")
plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()

# Step 6: Show the plot
# plt.show()

# Save plot locally on EMR
local_path = "/tmp/Collision_Severity_vs_Total_Victims2.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Collision_Severity_vs_Total_Victims2.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Collision_Severity_vs_Total_Victims2.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Collision_Severity_vs_Total_Victims2.png

## **3.1.6. Weather Conditions vs Collision Severity.** <font color = red>[5 marks]</font> <br>

Q: Analyze the correlation between weather conditions and collision severity.

In [82]:
# Weather vs. Collision Severity
weather_vs_severity_df = collisions_df1.groupBy("weather_1", "collision_severity").count()

# Convert to Pandas
weather_vs_severity_pd = weather_vs_severity_df.toPandas()

# Plot
plt.figure(figsize=(12, 6))
sns.barplot(data=weather_vs_severity_pd, x='weather_1', y='count', hue='collision_severity', palette='viridis')

# Step 4: Title and labels
plt.title("Weather Condition vs. Collision Severity")
plt.xlabel("Weather Condition")
plt.ylabel("Number of Collisions")
plt.xticks(rotation=45)
plt.legend(title="Collision Severity", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

# Step 5: Show the plot
# plt.show()

# Save plot locally on EMR
local_path = "/tmp/Weather_Condition_vs_Collision_Severity.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Weather_Condition_vs_Collision_Severity.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Weather_Condition_vs_Collision_Severity.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Weather_Condition_vs_Collision_Severity.png

In [83]:
# Weather vs. Collision Severity
weather_vs_severity_df = collisions_df.groupBy("weather_1", "collision_severity").count()

# Convert to Pandas
weather_vs_severity_pd = weather_vs_severity_df.toPandas()

# Plot
plt.figure(figsize=(12, 6))
sns.barplot(data=weather_vs_severity_pd, x='weather_1', y='count', hue='collision_severity', palette='viridis')

# Step 4: Title and labels
plt.title("Weather Condition vs. Collision Severity")
plt.xlabel("Weather Condition")
plt.ylabel("Number of Collisions")
plt.xticks(rotation=45)
plt.legend(title="Collision Severity", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

# Step 5: Show the plot
# plt.show()

# Save plot locally on EMR
local_path = "/tmp/Weather_Condition_vs_Collision_Severity2.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Weather_Condition_vs_Collision_Severity2.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Weather_Condition_vs_Collision_Severity2.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Weather_Condition_vs_Collision_Severity2.png

## **3.1.7. Lighting conditions vs Collision Severity.** <font color = red>[5 marks]</font> <br>

Q:Visualize the impact of lighting conditions on collision severity.

In [84]:
# Lighting Conditions vs. Collision Severity
lighting_severity_df = collisions_df1.groupBy("lighting", "collision_severity").count()

# Convert to Pandas
lighting_severity_pd = lighting_severity_df.toPandas()

# Plot
plt.figure(figsize=(12, 6))
sns.barplot(
    data=lighting_severity_pd,
    x="lighting",
    y="count",
    hue="collision_severity",
    palette="viridis"
)

# Aesthetics
plt.title("Lighting Conditions vs. Collision Severity")
plt.xlabel("Lighting Condition")
plt.ylabel("Number of Collisions")
plt.xticks(rotation=45)
plt.legend(title="Collision Severity", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

#plt.show()

# Save plot locally on EMR
local_path = "/tmp/Lighting_Condition_vs_Collision_Severity.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Lighting_Condition_vs_Collision_Severity.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Lighting_Condition_vs_Collision_Severity.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Lighting_Condition_vs_Collision_Severity.png

In [85]:
# Lighting Conditions vs. Collision Severity
lighting_severity_df = collisions_df.groupBy("lighting", "collision_severity").count()

# Convert to Pandas
lighting_severity_pd = lighting_severity_df.toPandas()

# Plot
plt.figure(figsize=(12, 6))
sns.barplot(
    data=lighting_severity_pd,
    x="lighting",
    y="count",
    hue="collision_severity",
    palette="viridis"
)

# Aesthetics
plt.title("Lighting Conditions vs. Collision Severity")
plt.xlabel("Lighting Condition")
plt.ylabel("Number of Collisions")
plt.xticks(rotation=45)
plt.legend(title="Collision Severity", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

#plt.show()

# Save plot locally on EMR
local_path = "/tmp/Lighting_Condition_vs_Collision_Severity2.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Lighting_Condition_vs_Collision_Severity2.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Lighting_Condition_vs_Collision_Severity2.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Lighting_Condition_vs_Collision_Severity2.png

## **3.1.8. Weekday-Wise Collision Trends.** <font color = red>[7 marks]</font> <br>

Q: Extract and analyze weekday-wise collision trends.

In [86]:
# Extract the weekday
from pyspark.sql.functions import col, to_date, date_format

# Step 1: Ensure date is in date format
collisions_df1 = collisions_df1.withColumn("collision_date", to_date("collision_date"))

# Step 2: Extract weekday name
collisions_df1 = collisions_df1.withColumn("weekday", date_format("collision_date", "EEEE"))

# Step 3: Group by weekday and count
weekday_df = collisions_df1.groupBy("weekday").count()

# Step 4: Convert to Pandas
weekday_pd = weekday_df.toPandas()

# Step 5: Order weekdays correctly
weekday_order = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
weekday_pd["weekday"] = pd.Categorical(weekday_pd["weekday"], categories=weekday_order, ordered=True)
weekday_pd = weekday_pd.sort_values("weekday")

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(data=weekday_pd, x="weekday", y="count", palette="crest")

plt.title("Number of Collisions per Weekday")
plt.xlabel("Day of the Week")
plt.ylabel("Number of Collisions")
plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()
#plt.show()

# Save plot locally on EMR
local_path = "/tmp/Num_Collisions_per_Weekday.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Num_Collisions_per_Weekday.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Num_Collisions_per_Weekday.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Num_Collisions_per_Weekday.png

## **3.1.9. Spatial Distribution of Collisions.** <font color = red>[7 marks]</font> <br>

Q: Study spatial distribution of collisions by county.

In [87]:
# Spatial Analysis
from pyspark.sql.functions import col

# Collision Density by County
# Step 1: Ensure the columns exist
collisions_df_copy = spark.read.option("header", True).csv("/user/livy/sample_collisions.csv")

collisions_df_copy = (
    collisions_df_copy
    .withColumn("latitude", col("latitude").cast(DoubleType()))
    .withColumn("longitude", col("longitude").cast(DoubleType()))
)

# Step 2: Filter out invalid/missing coordinates in PySpark
map_df = collisions_df_copy.filter(
    (col("latitude").isNotNull()) & (col("longitude").isNotNull()) &
    (col("latitude") != 0) & (col("longitude") != 0)
)
map_pd = map_df.select("latitude", "longitude").toPandas()

# Plot the map
plt.figure(figsize=(10, 6))
plt.scatter(map_pd["longitude"], map_pd["latitude"], alpha=0.2, s=10, c='blue')

plt.title("Spatial Distribution of Collisions")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.grid(True)
plt.tight_layout()
#plt.show()

# Optional: Print number of plotted points
print(f"\nTotal valid coordinates plotted: {len(map_pd)}")


# Save plot locally on EMR
local_path = "/tmp/Spatial_Dist_Collision.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Spatial_Dist_Collision.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Total valid coordinates plotted: 266742
? Plot saved locally at: /tmp/Spatial_Dist_Collision.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Spatial_Dist_Collision.png

## **3.1.10. Collision Analysis by Geography.** <font color = red>[6 marks]</font> <br>

Q: Generate a scatter plot to analyze collision locations geographically.

In [88]:
# Scatter Plot of Collision Locations
collisions_df_copy = (
    collisions_df_copy
    .withColumn("latitude", col("latitude").cast(DoubleType()))
    .withColumn("longitude", col("longitude").cast(DoubleType()))
)


# Convert PySpark DataFrame to Pandas, handling potential missing values
map_df = collisions_df_copy.filter(
    (col("latitude").isNotNull()) & (col("longitude").isNotNull()) &
    (col("latitude") != 0) & (col("longitude") != 0)
)
map_pd = map_df.select("latitude", "longitude").toPandas()

# Convert to numeric (if needed) and handle invalid data

# Plot the scatter plot
plt.figure(figsize=(10, 6))
plt.scatter(map_pd["longitude"], map_pd["latitude"], alpha=0.2, s=10, c='blue')

plt.title("Scatter Plot of Collision Locations")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.grid(True)
plt.tight_layout()
# plt.show()

# Step 5: Optional count of valid records
print(f"\nTotal valid coordinates plotted: {len(map_pd)}")


# Save plot locally on EMR
local_path = "/tmp/Scatter_Plot_Collision_Locations.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Scatter_Plot_Collision_Locations.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…


Total valid coordinates plotted: 266742
? Plot saved locally at: /tmp/Scatter_Plot_Collision_Locations.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Scatter_Plot_Collision_Locations.png

## **3.1.11. Collision Trends Over Time.** <font color = red>[10 marks]</font> <br>

Extract and analyzing collision trends over time.

In [89]:
from pyspark.sql.functions import year, month, hour, to_timestamp, col

# Extract year and month from collision_date
collisions_df1 = collisions_df1.withColumn("collision_date", to_date(col("collision_date")))

collisions_df1 = collisions_df1.withColumn("year", year("collision_date"))
collisions_df1 = collisions_df1.withColumn("month", month("collision_date"))

# Plot
date_stats_pd = collisions_df1.select("year", "month").dropna().toPandas()

# Plotting collisions by month
plt.figure(figsize=(10, 6))
sns.countplot(data=date_stats_pd, x="month", palette="Blues")

plt.title("Collisions by Month")
plt.xlabel("Month")
plt.ylabel("Count")
plt.grid(True)
plt.tight_layout()
#plt.show()



# Save plot locally on EMR
local_path = "/tmp/Collision_BY_Month.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Collision_BY_Month.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Collision_BY_Month.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Collision_BY_Month.png

Q: Analyze yearly, monthly and hourly trends in collisions.

In [90]:
# Yearly Trend of Collisions
collisions_df1 = collisions_df1.withColumn("collision_datetime", to_timestamp("collision_date"))
collisions_df1 = (
    collisions_df1
    .withColumn("year", year("collision_datetime"))
    .withColumn("month", month("collision_datetime"))
    .withColumn("hour", hour(to_timestamp("collision_time", "HHmm")))
)
yearly_df = collisions_df1.groupBy("year").count().orderBy("year")
yearly_pd = yearly_df.toPandas()

# Plot
plt.figure(figsize=(10, 5))
sns.lineplot(data=yearly_pd, x="year", y="count", marker="o", color="steelblue")
plt.title("Yearly Trend of Collisions")
plt.xlabel("Year")
plt.ylabel("Number of Collisions")
plt.grid(True)
plt.tight_layout()
# plt.show()

# Save plot locally on EMR
local_path = "/tmp/Yearly_Trend_Collisions.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Yearly_Trend_Collisions.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Yearly_Trend_Collisions.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Yearly_Trend_Collisions.png

In [91]:
# Monthly Trend of Collisions
monthly_df = collisions_df1.groupBy("month").count().orderBy("month")
monthly_pd = monthly_df.toPandas()

# Plot
plt.figure(figsize=(10, 5))
sns.barplot(data=monthly_pd, x="month", y="count", palette="viridis")
plt.title("Monthly Trend of Collisions")
plt.xlabel("Month")
plt.ylabel("Number of Collisions")
plt.grid(True)
plt.tight_layout()
# plt.show()
# Save plot locally on EMR
local_path = "/tmp/Monthly_Trend_Collisions.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Monthly_Trend_Collisions.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Monthly_Trend_Collisions.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Monthly_Trend_Collisions.png

In [92]:
# Hourly Trend of Collisions
from pyspark.sql.functions import col, hour, to_timestamp

hourly_df = collisions_df1.withColumn("collision_time", to_timestamp("collision_time", "HH:mm:ss"))
hourly_df = hourly_df.withColumn("hour", hour("collision_time"))
hourly_df = hourly_df.groupBy("hour").count().orderBy("hour")
hourly_pd = hourly_df.toPandas()

# Plot
plt.figure(figsize=(10, 5))
sns.lineplot(data=hourly_pd, x="hour", y="count", marker="o", color="orange")
plt.title("Hourly Trend of Collisions")
plt.xlabel("Hour of Day")
plt.ylabel("Number of Collisions")
plt.xticks(range(0, 24))
plt.grid(True)
plt.tight_layout()
# plt.show()
# Save plot locally on EMR
local_path = "/tmp/Hourly_Trend_Collisions.png"
plt.savefig(local_path)
print(f"✅ Plot saved locally at: {local_path}")

# Upload to S3
bucket_name = "trafficds71"
s3_key = "plots/Hourly_Trend_Collisions.png"  # You can customize folder/key structure

# Upload
s3 = boto3.client("s3")
s3.upload_file(local_path, bucket_name, s3_key)

# Generate viewable URL (if public or accessible)
s3_url = f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
print(f"📊 Plot uploaded to: {s3_url}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

? Plot saved locally at: /tmp/Hourly_Trend_Collisions.png
? Plot uploaded to: https://trafficds71.s3.amazonaws.com/plots/Hourly_Trend_Collisions.png

# **4. ETL Querying** <font color = red>[35 marks]</font> <br>

## **4.1 Loading the Dataset** <font color = red>[1 marks]</font> <br>

Q: Load the processed dataset as CSV files in S3 bucket.

In [93]:
# Write your query here
# Loding sample_case_ids
case_ids_df1.coalesce(1) \
  .write \
  .option("header", "true") \
  .mode("overwrite") \
  .csv("s3://trafficds71/temp_output/")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [94]:
import boto3

s3 = boto3.client('s3')
bucket = 'trafficds71'
prefix = 'temp_output/'

# List objects to find the actual CSV file
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

for obj in response.get('Contents', []):
    key = obj['Key']
    if key.endswith(".csv"):
        print(f"Renaming {key} to output/sample_case_ids.csv")

        # Copy with new name
        s3.copy_object(
            Bucket=bucket,
            CopySource={'Bucket': bucket, 'Key': key},
            Key='output/sample_case_ids.csv'
        )

        # Optionally delete original part file
        s3.delete_object(Bucket=bucket, Key=key)
        break

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Renaming temp_output/part-00000-ca223808-ab45-47a1-95fa-58d507e86653-c000.csv to output/sample_case_ids.csv
{'ResponseMetadata': {'RequestId': 'C6JEFWFENDEA7QYP', 'HostId': 'nJsti4b/xRtHyaD2tlJFV0qvoATuh0myABq0vRKhWHt3Ln/wTZcfbYyCYUw8g5Oft/VBi+VnsL0=', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amz-id-2': 'nJsti4b/xRtHyaD2tlJFV0qvoATuh0myABq0vRKhWHt3Ln/wTZcfbYyCYUw8g5Oft/VBi+VnsL0=', 'x-amz-request-id': 'C6JEFWFENDEA7QYP', 'date': 'Tue, 17 Jun 2025 05:30:03 GMT', 'x-amz-server-side-encryption': 'AES256', 'content-type': 'application/xml', 'content-length': '275', 'server': 'AmazonS3'}, 'RetryAttempts': 0}, 'ServerSideEncryption': 'AES256', 'CopyObjectResult': {'ETag': '"3e0531c0343e1e80270d2ffed8d2b164"', 'LastModified': datetime.datetime(2025, 6, 17, 5, 30, 3, tzinfo=tzlocal()), 'ChecksumCRC64NVME': 'mVUbjxJU1AM='}}
{'ResponseMetadata': {'RequestId': 'N719RW01Z7FBW7T2', 'HostId': 'PoFV3MxhOdEh6+3+CKNkszQw/iO6R15zMwdSR5m/LQsIGHlEX+HuRTKJ9R6OJXZ6xvzpQR7dqo0=', 'HTTPStatusCode': 204, 'HTT

In [95]:
# Write your query here
# Loding sample_collisions
collisions_df1.coalesce(1) \
  .write \
  .option("header", "true") \
  .mode("overwrite") \
  .csv("s3://trafficds71/temp_output/")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [96]:
import boto3

s3 = boto3.client('s3')
bucket = 'trafficds71'
prefix = 'temp_output/'

# List objects to find the actual CSV file
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

for obj in response.get('Contents', []):
    key = obj['Key']
    if key.endswith(".csv"):
        print(f"Renaming {key} to output/sample_collisions.csv")

        # Copy with new name
        s3.copy_object(
            Bucket=bucket,
            CopySource={'Bucket': bucket, 'Key': key},
            Key='output/sample_collisions.csv'
        )

        # Optionally delete original part file
        s3.delete_object(Bucket=bucket, Key=key)
        break

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Renaming temp_output/part-00000-66be7f5b-5931-41f2-a3c0-aedcc1363a1a-c000.csv to output/sample_collisions.csv
{'ResponseMetadata': {'RequestId': 'T0X5XKMSESJ6GKWA', 'HostId': 'fH0Jr1tzQit7gfI6T6MN+fkjxcoGy8DdrxI9pZPZ4E1JGIgpic57IS/ZDDbGbMfAmTM7xXein1I=', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amz-id-2': 'fH0Jr1tzQit7gfI6T6MN+fkjxcoGy8DdrxI9pZPZ4E1JGIgpic57IS/ZDDbGbMfAmTM7xXein1I=', 'x-amz-request-id': 'T0X5XKMSESJ6GKWA', 'date': 'Tue, 17 Jun 2025 05:30:38 GMT', 'x-amz-server-side-encryption': 'AES256', 'content-type': 'application/xml', 'content-length': '275', 'server': 'AmazonS3'}, 'RetryAttempts': 0}, 'ServerSideEncryption': 'AES256', 'CopyObjectResult': {'ETag': '"81e8d72ce1be69152a4298c8a0fdb15f"', 'LastModified': datetime.datetime(2025, 6, 17, 5, 30, 38, tzinfo=tzlocal()), 'ChecksumCRC64NVME': '+QyFWSgbg2A='}}
{'ResponseMetadata': {'RequestId': '977DS0B5K3JBYNZ8', 'HostId': 'HefPZPGnv+ADVh50t/qa1+LS1E6iMiZtaXZJmjUhdd8WX0UZVYPiHCUVzKQ2PlOgwJsnGEQVlnA=', 'HTTPStatusCode': 204, '

In [97]:
# Write your query here
# Loding sample_parties
parties_df1.coalesce(1) \
  .write \
  .option("header", "true") \
  .mode("overwrite") \
  .csv("s3://trafficds71/temp_output/")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [98]:
import boto3

s3 = boto3.client('s3')
bucket = 'trafficds71'
prefix = 'temp_output/'

# List objects to find the actual CSV file
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

for obj in response.get('Contents', []):
    key = obj['Key']
    if key.endswith(".csv"):
        print(f"Renaming {key} to output/sample_parties.csv")

        # Copy with new name
        s3.copy_object(
            Bucket=bucket,
            CopySource={'Bucket': bucket, 'Key': key},
            Key='output/sample_parties.csv'
        )

        # Optionally delete original part file
        s3.delete_object(Bucket=bucket, Key=key)
        break

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Renaming temp_output/part-00000-ea443920-ff5a-4e33-b343-7794a0e16aec-c000.csv to output/sample_parties.csv
{'ResponseMetadata': {'RequestId': '4M2YKS28YDP4X7MQ', 'HostId': 't4phz2U5qNLudPwH9kS1M417ppQ7yEOZmYU4+UEsdYxrJ6tz36RbqH6yni/Nly3NZxow060o4CU=', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amz-id-2': 't4phz2U5qNLudPwH9kS1M417ppQ7yEOZmYU4+UEsdYxrJ6tz36RbqH6yni/Nly3NZxow060o4CU=', 'x-amz-request-id': '4M2YKS28YDP4X7MQ', 'date': 'Tue, 17 Jun 2025 05:31:02 GMT', 'x-amz-server-side-encryption': 'AES256', 'content-type': 'application/xml', 'content-length': '275', 'server': 'AmazonS3'}, 'RetryAttempts': 0}, 'ServerSideEncryption': 'AES256', 'CopyObjectResult': {'ETag': '"ecbadc9c4f8fab169dda8d0cfbf1b3f8"', 'LastModified': datetime.datetime(2025, 6, 17, 5, 31, 2, tzinfo=tzlocal()), 'ChecksumCRC64NVME': 'rir81dE2gA0='}}
{'ResponseMetadata': {'RequestId': 'F6BNJ6C9Z0Q913NA', 'HostId': 'sSOfduhXVWhybyrQDXag9a6E4azBi7/HZqsfmDR+tMMrHpwIe8RrU5vuixmJguZaytWCur7CxMw=', 'HTTPStatusCode': 204, 'HTTP

In [99]:
# Write your query here
# Loding sample_victims
victims_df1.coalesce(1) \
  .write \
  .option("header", "true") \
  .mode("overwrite") \
  .csv("s3://trafficds71/temp_output/")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [100]:
import boto3

s3 = boto3.client('s3')
bucket = 'trafficds71'
prefix = 'temp_output/'

# List objects to find the actual CSV file
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

for obj in response.get('Contents', []):
    key = obj['Key']
    if key.endswith(".csv"):
        print(f"Renaming {key} to output/sample_victims.csv")

        # Copy with new name
        s3.copy_object(
            Bucket=bucket,
            CopySource={'Bucket': bucket, 'Key': key},
            Key='output/sample_victims.csv'
        )

        # Optionally delete original part file
        s3.delete_object(Bucket=bucket, Key=key)
        break

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Renaming temp_output/part-00000-c34c80d2-b3fe-43fa-a3a7-b6a16b8e1259-c000.csv to output/sample_victims.csv
{'ResponseMetadata': {'RequestId': '96E3304G73S1GVCJ', 'HostId': 'u9hwysQLSNmp1MPZczba0yeEK+hmxYQIPfva0tGeZYRYxDtfjUIsPK87+zjuJ1W+gMomWX8bDTM3k3vyUj7gmv9jGNa0FUyW', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amz-id-2': 'u9hwysQLSNmp1MPZczba0yeEK+hmxYQIPfva0tGeZYRYxDtfjUIsPK87+zjuJ1W+gMomWX8bDTM3k3vyUj7gmv9jGNa0FUyW', 'x-amz-request-id': '96E3304G73S1GVCJ', 'date': 'Tue, 17 Jun 2025 05:31:17 GMT', 'x-amz-server-side-encryption': 'AES256', 'content-type': 'application/xml', 'content-length': '275', 'server': 'AmazonS3'}, 'RetryAttempts': 0}, 'ServerSideEncryption': 'AES256', 'CopyObjectResult': {'ETag': '"994a462afc0d6c8e860ce1a625a53707"', 'LastModified': datetime.datetime(2025, 6, 17, 5, 31, 17, tzinfo=tzlocal()), 'ChecksumCRC64NVME': 'dVADEQkFA9I='}}
{'ResponseMetadata': {'RequestId': 'PF2YEYR56FT6VQDE', 'HostId': 'fSc+ou/ngoUesWR6xLBlKmod+frybFSdJxAcXptuS1wR9sQMwNFhJmlbQlHS8KWbff

## **4.2. Top 5 Counties** <font color = red>[4 marks]</font> <br>

Q: Identify the top 5 counties with the highest number of collisions.

In [101]:
# Query: Identify the top 5 counties with the most collisions
from pyspark.sql.functions import col, desc, month, count

top5_counties = collisions_df1.groupBy("county_location") \
    .count() \
    .orderBy(desc("count")) \
    .limit(5)

top5_counties.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------------+------+
|county_location| count|
+---------------+------+
|    los angeles|284100|
|         orange| 72042|
| san bernardino| 56737|
|      san diego| 53105|
|      riverside| 48686|
+---------------+------+

## **4.3. Month with Highest Collisions** <font color = red>[5 marks]</font> <br>

Q. Identify the month with the highest number of collisions.

In [102]:
# Query: Find the month with the highest number of collisions
# Extract month with highest collisions
month_with_high_collisions = collisions_df1 \
    .withColumn("collision_month", month("collision_date")) \
    .groupBy("collision_month") \
    .agg(count("*").alias("collision_count")) \
    .orderBy(desc("collision_count")) \
    .limit(1)

month_with_high_collisions.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------------+---------------+
|collision_month|collision_count|
+---------------+---------------+
|             10|          83274|
+---------------+---------------+

## **4.4. Weather Conditions with Highest Collisions.** <font color = red>[5 marks]</font> <br>

Q. Determine the most common weather condition during collisions.

In [103]:
# Query: Find the most common weather condition during collisions

weather_with_most_col = collisions_df1 \
    .groupBy("weather_1") \
    .agg(count("*").alias("collision_count")) \
    .orderBy(desc("collision_count")) \
    .limit(1)

weather_with_most_col .show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+---------------+
|weather_1|collision_count|
+---------+---------------+
|    clear|         769927|
+---------+---------------+

## **4.5. Fatal Collisions.** <font color = red>[5 marks]</font> <br>

Q. Calculate the percentage of collisions that resulted in fatalities.

In [104]:
# Query: Determine the percentage of collisions that resulted in fatalities
from pyspark.sql.functions import col, count, when
# Total number of collisions
total_col = collisions_df1.count()

# Count of Collisions with fatalities
fatal_col = collisions_df1.filter(col("killed_victims") > 0).count()

# percentage Calculation
fatal_percent = (fatal_col / total_col) * 100

print(f"Percentage of collisions with fatalities: {fatal_percent:.2f}%")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Percentage of collisions with fatalities: 0.00%

In [105]:
fatal_col

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

0

## **4.6. Dangerous Time for Collisions.** <font color = red>[5 marks]</font> <br>

Q. Find the most dangerous time of day for collisions.

In [106]:
# Query: Find the most dangerous time of day for collisions
from pyspark.sql.functions import hour, count, desc

# Extract hour and rsepective count collisions
dangerous_hour = collisions_df1 \
    .withColumn("collision_hour", hour("collision_time")) \
    .groupBy("collision_hour") \
    .agg(count("*").alias("collision_count")) \
    .orderBy(desc("collision_count")) \
    .limit(1)

dangerous_hour.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------+---------------+
|collision_hour|collision_count|
+--------------+---------------+
|            17|          73255|
+--------------+---------------+

## **4.7. Road Surface Conditions.** <font color = red>[5 marks]</font> <br>

Q. Identify the top 5 road surface conditions with the highest collision frequency.

In [107]:
collisions_df1.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+--------------+------------+--------------------+---------------+--------------+------------+----------+--------------+----------------+--------------------+---------------+-----------------+-----------------+-------------+--------------+-----------+--------------+--------------+--------+---------+------------+---------+-----------------------+--------+-----------+------------------------+----------------------+-------------+---------------+-----------------+---------------------------+--------------------+------------+----------------+--------------------+--------------+-------------+--------------------+-----------------+--------------------+---------------+--------------------+-------------------------------+-------------------------+-------------------+--------------------------+------------------------------+-----------------------+------------------------+----------------------+-----------------------+-------------------------+--------------------------+-----

In [108]:
# Query: List the top 5 road types with the highest collision frequency
top5_road_types = collisions_df1 \
    .groupBy("road_surface") \
    .agg(count("*").alias("collision_count")) \
    .orderBy(desc("collision_count")) \
    .limit(5)

top5_road_types.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+------------+---------------+
|road_surface|collision_count|
+------------+---------------+
|         dry|         845637|
|         wet|          76833|
|     Unknown|           8141|
|       snowy|           4122|
|    slippery|           1048|
+------------+---------------+

## **4.8. Lighting Conditions.** <font color = red>[5 marks]</font> <br>

Q. Analyze lighting conditions that contribute to the highest number of collisions.

In [109]:
# Query: Find the top 3 lighting conditions that lead to the most collisions

top3_lighting_conditions = collisions_df1 \
    .groupBy("lighting") \
    .agg(count("*").alias("collision_count")) \
    .orderBy(desc("collision_count")) \
    .limit(3)

top3_lighting_conditions.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+---------------+
|            lighting|collision_count|
+--------------------+---------------+
|            daylight|         625574|
|dark with street ...|         195867|
|dark with no stre...|          74634|
+--------------------+---------------+

# 5. Conclusion <font color = red>[10 marks]</font> <br>

Write your conclusion.

# Conclusion
This project effectively leveraged Apache Spark and AWS S3 to process and analyze large-scale California traffic collision data. Through detailed ETL operations and rich visual analysis, we uncovered crucial patterns in accident severity, timing, locations, and contributing factors.

These insights can help traffic authorities, city planners, and policymakers:

## 📊Key Insights from Collision Data

### 1. Collision Severity
- **Property damage only** is the most common, with over **500,000** cases.
- **Pain-related injuries** are the second highest, over **200,000** cases.

### 2. Weather Conditions
- Most collisions occurred in **clear weather**, likely due to high traffic.
- **Cloudy** and **raining** conditions also show significant collision counts.

### 3. Weather vs Severity
- Majority of severe collisions happen during **clear weather**.
- **Cloudy and rainy** conditions also contribute to higher severity.

### 4. Victim Age Distribution
- Highest collision involvement is in the **16–24 age group**.
- Notably high involvement also seen in **0–4 age group**.

### 5. Lighting Conditions
- Most collisions occur during **daylight**.
- **Dark with street lights** is the second most common lighting condition.

### 6. Hourly Collision Trend
- Peak hours: **15:00–17:00**, likely due to office traffic.
- Collisions drop post-midnight and rise sharply from **08:00** onward.

### 7. Daily Trend
- **Friday** sees the most collisions; **Sunday** the least.

### 8. Monthly Trend
- Highest collisions in **March**, **October**, and **December**, likely due to holidays and festive traffic.

### 9. Yearly Trend
- Decrease from **2000–2013**, spike till **2016**, then a drop.
- **Post-2019 decline** aligns with COVID-19 pandemic and reduced mobility.

### 10. Geographical Trends
- High collision density in **Los Angeles**, **San Francisco Bay Area**, and **Central Valley**.
- Sparse data in **eastern deserts** and **northern rural regions** due to low population density.


## ✅ Key Recommendations for Policy Makers & Traffic Authorities

1. **Focus on High-Risk Areas & Time Slots**  
   - Improve infrastructure and traffic control in regions with high collision density (e.g., Los Angeles, Bay Area).  
   - Increase monitoring and enforcement during peak hours (15:00–17:00) and high-risk days (Fridays).

2. **Implement Targeted Safety Campaigns**  
   - Launch awareness programs for vulnerable age groups (16–24 and 0–4 years).  
   - Enforce child safety seat usage and promote defensive driving education.

3. **Leverage Technology & Seasonal Planning**  
   - Use predictive analytics and smart traffic systems for weather-based and real-time alerts.  
   - Plan special safety measures during festive months (March, October, December) when collisions spike.


# 6. Visualization Integration using Tableau/ PowerBI <font color = red>[Optional]</font> <br>