In [1]:
import os

from pyspark.sql import SparkSession

# Location of the PostgreSQL JDBC driver jar (needed for spark.read.jdbc).
# Set POSTGRES_JDBC_JAR to point elsewhere instead of editing the notebook;
# the default is the original author's local path.
path_to_jdbc_driver_jar = os.getenv(
    "POSTGRES_JDBC_JAR",
    "/Users/pintoza/Desktop/dev/data-science/taxi-demand-forecast/postgresql-42.7.1.jar",
)

# Increase the memory allocation.
# NOTE: driver memory and spark.jars only take effect when this cell starts a new
# Spark JVM. If a session already exists (e.g. the cell is re-run), getOrCreate()
# returns it and only runtime SQL configs are applied; restart the kernel to change them.
spark = (
    SparkSession.builder
    .appName("Taxi Data Analysis")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .config("spark.jars", path_to_jdbc_driver_jar)
    .getOrCreate()
)

24/01/30 19:52:24 WARN Utils: Your hostname, Zachs-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.1.134 instead (on interface en0)
24/01/30 19:52:24 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/01/30 19:52:25 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/01/30 19:52:28 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
from dotenv import load_dotenv
import os

# Load environment variables from a .env file (if present) into os.environ.
load_dotenv()

# Retrieve connection settings.
database_url = os.getenv("DB_URL")
db_user = os.getenv("DB_USERNAME")
db_password = os.getenv("DB_PASSWORD")

# Fail with a readable message instead of an opaque JDBC error when a setting is unset.
missing_env_vars = [
    name
    for name, value in (
        ("DB_URL", database_url),
        ("DB_USERNAME", db_user),
        ("DB_PASSWORD", db_password),
    )
    if not value
]
if missing_env_vars:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(missing_env_vars)}"
    )

# Database connection properties.
properties = {"user": db_user, "password": db_password, "driver": "org.postgresql.Driver"}

# Test query to load a small subset of data.
# LIMIT without ORDER BY does not pin down which rows Postgres returns, and Spark
# re-executes the JDBC query for every action on an uncached frame. Caching keeps
# the later mode / null-count / write cells working on the rows materialized by show().
query = "(SELECT * FROM taxi_trips LIMIT 100) AS subquery"
df = spark.read.jdbc(url=database_url, table=query, properties=properties).cache()
df.show()

[Stage 0:>                                                          (0 + 1) / 1]

+--------------------+---------------------+---------------+-------------+------------+------------+------------+
|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|total_amount|DOLocationID|PULocationID|
+--------------------+---------------------+---------------+-------------+------------+------------+------------+
|11/14/2018 07:35:...| 11/14/2018 07:41:...|              1|         0.83|         8.8|         230|         142|
|11/14/2018 07:35:...| 11/14/2018 08:00:...|              1|          4.1|       24.13|          65|         137|
|11/14/2018 07:35:...| 11/14/2018 07:38:...|              1|          0.4|         6.8|         234|         164|
|11/14/2018 07:35:...| 11/14/2018 07:44:...|              1|         1.56|        10.8|         236|         237|
|11/14/2018 07:35:...| 11/14/2018 07:49:...|              1|         2.01|        13.3|         262|         237|
|11/14/2018 07:35:...| 11/14/2018 08:27:...|              1|         8.05|        38.8| 

                                                                                

In [3]:
from pyspark.sql.functions import col, max as max_

# Calculate the mode of passenger_count (used to impute its missing values).
# - NULLs are filtered out first: groupBy treats NULL as its own group, so if it
#   were the largest group the "mode" would be None and the coalesce-based
#   imputation would silently leave the NULLs in place.
# - Ties on count are broken by the smaller passenger_count so the result is
#   deterministic across runs.
# The mode is computed over whatever rows df currently holds.
mode_row = (
    df.filter(col("passenger_count").isNotNull())
    .groupBy("passenger_count")
    .count()
    .orderBy(col("count").desc(), col("passenger_count").asc())
    .first()
)
if mode_row is None:
    raise ValueError("passenger_count has no non-null values; cannot compute a mode")
mode_value = mode_row["passenger_count"]
mode_value

                                                                                

1

In [5]:
from pyspark.sql.functions import coalesce, lit

# Replace NULL passenger_count values with mode_value: coalesce keeps the
# existing value when it is present and falls back to the literal otherwise.
passenger_count_filled = coalesce(df["passenger_count"], lit(mode_value))
df = df.withColumn("passenger_count", passenger_count_filled)

In [6]:
from pyspark.sql.functions import col, count, when, isnull

# Build one aggregate per column that counts its NULLs: `when` yields a non-null
# value only on rows where the column is NULL, and `count` skips NULLs, so the
# result is the number of NULLs. Each aggregate keeps the source column's name.
null_count_exprs = []
for column_name in df.columns:
    null_count_exprs.append(count(when(isnull(column_name), column_name)).alias(column_name))
null_counts = df.select(*null_count_exprs)

# Display the single-row table of per-column NULL counts.
null_counts.show()

+--------------------+---------------------+---------------+-------------+------------+------------+------------+
|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|total_amount|DOLocationID|PULocationID|
+--------------------+---------------------+---------------+-------------+------------+------------+------------+
|                   0|                    0|              0|            0|           0|           0|           0|
+--------------------+---------------------+---------------+-------------+------------+------------+------------+


In [7]:
# Print the schema to check the data types.
# The two tpep_*_datetime columns come back from the JDBC read as strings,
# which is why they are converted to timestamps afterwards.
df.printSchema()

root
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: long (nullable = false)
 |-- trip_distance: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- PULocationID: long (nullable = true)


In [8]:
from pyspark.sql.functions import col, count, to_timestamp, when
from pyspark.sql.types import StringType

# Pattern of the raw datetime strings: 12-hour clock with an AM/PM marker
# (the raw values look like "11/14/2018 07:35:...").
DATETIME_FORMAT = 'MM/dd/yyyy hh:mm:ss a'
DATETIME_COLUMNS = ["tpep_pickup_datetime", "tpep_dropoff_datetime"]

# Only convert columns that are still strings, so re-running this cell after the
# conversion is a no-op instead of trying to re-parse timestamp columns.
string_datetime_columns = [
    name for name in DATETIME_COLUMNS if isinstance(df.schema[name].dataType, StringType)
]

if string_datetime_columns:
    # With ANSI mode off (the default), to_timestamp returns NULL for strings that
    # don't match the pattern. Count non-null inputs that fail to parse so bad
    # values are reported instead of silently becoming NULL.
    parse_failures = df.select([
        count(
            when(col(name).isNotNull() & to_timestamp(col(name), DATETIME_FORMAT).isNull(), 1)
        ).alias(name)
        for name in string_datetime_columns
    ]).first().asDict()
    assert not any(parse_failures.values()), (
        f"Unparseable datetime values per column: {parse_failures}"
    )

    # Convert datetime fields to TimestampType.
    for ts_col in string_datetime_columns:
        df = df.withColumn(ts_col, to_timestamp(col(ts_col), DATETIME_FORMAT))

# Show the schema again to verify the changes.
df.printSchema()

root
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = false)
 |-- trip_distance: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- PULocationID: long (nullable = true)


In [10]:
import os

# Save the DataFrame as Parquet under data/processed/.
# Set TAXI_PROCESSED_PATH to write elsewhere; the default is the original
# author's local project path.
processed_output_path = os.getenv(
    "TAXI_PROCESSED_PATH",
    "file:///Users/pintoza/Desktop/dev/data-science/taxi-demand-forecast/data/processed/taxi_data_processed",
)

# mode("overwrite"): the default save mode ("errorifexists") makes any re-run of
# the notebook fail because the output directory already exists.
df.write.mode("overwrite").parquet(processed_output_path)

                                                                                