In [1]:
import os
from time import time
from urllib.parse import quote_plus

import pandas as pd
import pyarrow
import pyarrow.parquet as pq
from sqlalchemy import create_engine, text

In [2]:
# Source file (uncomment the download on first use):
# !wget https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-06.parquet

# Load the whole file into a DataFrame and preview the first rows
parquet_path = 'yellow_tripdata_2024-06.parquet'
ytdata = pd.read_parquet(parquet_path, engine='pyarrow')
ytdata.head()

Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,Airport_fee
0,1,2024-06-01 00:03:46,2024-06-01 00:31:23,1.0,12.5,1.0,N,138,195,1,48.5,7.75,0.5,11.55,0.0,1.0,69.3,0.0,1.75
1,2,2024-06-01 00:55:22,2024-06-01 01:08:24,1.0,4.34,1.0,N,138,7,1,20.5,6.0,0.5,8.4,0.0,1.0,38.15,0.0,1.75
2,1,2024-06-01 00:23:53,2024-06-01 00:32:35,1.0,1.3,1.0,N,166,41,1,10.0,1.0,0.5,3.1,0.0,1.0,15.6,0.0,0.0
3,1,2024-06-01 00:32:24,2024-06-01 00:40:06,1.0,1.2,1.0,N,148,114,1,8.6,3.5,0.5,0.2,0.0,1.0,13.8,2.5,0.0
4,1,2024-06-01 00:51:38,2024-06-01 00:58:17,1.0,1.0,1.0,N,148,249,1,7.2,3.5,0.5,2.0,0.0,1.0,14.2,2.5,0.0


In [3]:
# Summarise the frame: number of entries plus each column's name and dtype
ytdata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3539193 entries, 0 to 3539192
Data columns (total 19 columns):
 #   Column                 Dtype         
---  ------                 -----         
 0   VendorID               int32         
 1   tpep_pickup_datetime   datetime64[us]
 2   tpep_dropoff_datetime  datetime64[us]
 3   passenger_count        float64       
 4   trip_distance          float64       
 5   RatecodeID             float64       
 6   store_and_fwd_flag     object        
 7   PULocationID           int32         
 8   DOLocationID           int32         
 9   payment_type           int64         
 10  fare_amount            float64       
 11  extra                  float64       
 12  mta_tax                float64       
 13  tip_amount             float64       
 14  tolls_amount           float64       
 15  improvement_surcharge  float64       
 16  total_amount           float64       
 17  congestion_surcharge   float64       
 18  Airport_fee            float64       

In [4]:
# connect to database and upload this data
# Every setting can be overridden through an environment variable so that real
# credentials never have to be saved in the notebook; the fallbacks are the
# local development values.
db_user = os.environ.get('DB_USER', 'root')
db_password = os.environ.get('DB_PASSWORD', 'root')
db_host = os.environ.get('DB_HOST', 'localhost')
db_port = int(os.environ.get('DB_PORT', '5431'))
db_name = os.environ.get('DB_NAME', 'ny_taxi')

# URL-escape user and password: characters such as '@', ':' or '/' would
# otherwise be parsed as URL delimiters and corrupt the connection string.
connection_string = (
    f'postgresql://{quote_plus(db_user)}:{quote_plus(db_password)}'
    f'@{db_host}:{db_port}/{db_name}'
)
engine = create_engine(connection_string)

# Long-lived connection reused by the query cells below
db_connection = engine.connect()
db_connection

<sqlalchemy.engine.base.Connection at 0x300ad89e0>

In [13]:
try:
    # execute query
    print(db_connection.execute(text("SELECT count(*) FROM yellow_taxi_trips")).scalar())
except Exception:
    # A failed statement leaves the connection's open transaction unusable
    # until it is rolled back, so reset it whenever the connection can still
    # accept a rollback. (The previous check only rolled back when the
    # connection was closed/invalidated, which skipped the ordinary
    # failed-query case.)
    if not db_connection.closed:
        db_connection.rollback()
    # bare raise re-raises the active exception with its original traceback
    raise

0


In [14]:
# get list of tables in the public schema
tables = db_connection.execute(text("SELECT table_name FROM information_schema.tables WHERE table_schema='public'"))
table_names = [table[0] for table in tables.fetchall()]

# get row count for the tables
# Table names cannot be bound as query parameters, so each one is embedded as a
# double-quoted identifier (embedded quotes doubled) and qualified with the
# schema it was listed from. Unquoted, a mixed-case or reserved-word table name
# would fail to resolve.
row_counts = {
    table: db_connection.execute(
        text('SELECT COUNT(*) FROM public."{}"'.format(table.replace('"', '""')))
    ).scalar()
    for table in table_names
}
row_counts

{'yellow_taxi_trips': 0, 'orders': 5}

In [15]:
# Render the CREATE TABLE statement pandas derives from the frame (not necessary)
schema_ddl = pd.io.sql.get_schema(ytdata, name='yellow_taxi_trips')
print(schema_ddl)

CREATE TABLE "yellow_taxi_trips" (
"VendorID" INTEGER,
  "tpep_pickup_datetime" TIMESTAMP,
  "tpep_dropoff_datetime" TIMESTAMP,
  "passenger_count" REAL,
  "trip_distance" REAL,
  "RatecodeID" REAL,
  "store_and_fwd_flag" TEXT,
  "PULocationID" INTEGER,
  "DOLocationID" INTEGER,
  "payment_type" INTEGER,
  "fare_amount" REAL,
  "extra" REAL,
  "mta_tax" REAL,
  "tip_amount" REAL,
  "tolls_amount" REAL,
  "improvement_surcharge" REAL,
  "total_amount" REAL,
  "congestion_surcharge" REAL,
  "Airport_fee" REAL
)


In [8]:
# create an empty table - yellow_taxi_trips
# head(0) keeps the columns and dtypes but no rows, so only the table
# definition is written (replacing any existing table of that name).
# index=False keeps the DataFrame index out of the table; with the default
# index=True the table would gain an extra "index" column that the later
# appends (which use index=False) never populate.
ytdata.head(0).to_sql(con=engine, name='yellow_taxi_trips', if_exists='replace', index=False)

In [16]:
# get list of tables in the public schema
tables = db_connection.execute(text("SELECT table_name FROM information_schema.tables WHERE table_schema='public'"))
table_names = [table[0] for table in tables.fetchall()]

# get row count for the tables
# Table names cannot be bound as query parameters, so each one is embedded as a
# double-quoted identifier (embedded quotes doubled) and qualified with the
# schema it was listed from. Unquoted, a mixed-case or reserved-word table name
# would fail to resolve.
row_counts = {
    table: db_connection.execute(
        text('SELECT COUNT(*) FROM public."{}"'.format(table.replace('"', '""')))
    ).scalar()
    for table in table_names
}
row_counts

{'yellow_taxi_trips': 0, 'orders': 5}

In [17]:
# ingest data into table
# NOTE: rows are appended, so re-running this cell without recreating the
# table first will insert the same trips again.
start_time = time()

parquet_file = pq.ParquetFile('yellow_tripdata_2024-06.parquet')

# Number of rows pandas writes per batch; without it each row group is
# written in a single batch.
chunksize = 100000

# Read one parquet row group at a time so the whole file is never converted
# to pandas at once; the engine from the connection cell is reused rather than
# building a second one.
num_row_groups = parquet_file.num_row_groups
for i in range(num_row_groups):
    df_chunk = parquet_file.read_row_group(i).to_pandas()
    df_chunk.to_sql(
        name='yellow_taxi_trips',
        con=engine,
        if_exists='append',
        index=False,
        chunksize=chunksize,
    )
    print(f"Inserted row group {i + 1} of {num_row_groups}")

end_time = time()
print(f"Ingestion completed. Total time: {end_time - start_time:.2f} seconds")

# Verify the data was inserted
row_count = db_connection.execute(text("SELECT COUNT(*) FROM yellow_taxi_trips")).scalar()
print(f"Total rows in yellow_taxi_trips: {row_count}")

Inserted row group 1 of 4
Inserted row group 2 of 4
Inserted row group 3 of 4
Inserted row group 4 of 4
Ingestion completed. Total time: 203.10 seconds
Total rows in yellow_taxi_trips: 3539193
