# Import Yellow Taxi Data - Part 1

Import yellow taxi trip data into VAST S3 and VastDB.

The source schema changes over time, so the table schema must evolve with it:
- currently, the only evolution applied during loading is adding new columns
- data type changes are handled before the Parquet data is loaded into VastDB
- column renames are handled before the Parquet data is loaded into VastDB

In [None]:
! pip3 install --quiet vastdb

In [None]:
import os
from io import StringIO
from urllib.parse import urlparse

import boto3
from botocore.exceptions import NoCredentialsError
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc
from pyarrow import csv as pa_csv
import requests

# Custom imports for VASTDB
import vastdb

In [None]:
# --- VastDB connection and target location (from the environment; None if unset) ---
VASTDB_ENDPOINT = os.environ.get("VASTDB_ENDPOINT")
VASTDB_ACCESS_KEY = os.environ.get("VASTDB_ACCESS_KEY")
VASTDB_SECRET_KEY = os.environ.get("VASTDB_SECRET_KEY")

# NOTE(review): the TWITTER_INGEST prefix looks like a leftover from another
# notebook; here these values locate the yellow-taxi table.
VASTDB_TWITTER_INGEST_BUCKET = os.environ.get("VASTDB_TWITTER_INGEST_BUCKET")
VASTDB_TWITTER_INGEST_SCHEMA = os.environ.get("VASTDB_TWITTER_INGEST_SCHEMA")

# --- S3 endpoint and bucket that receive the raw Parquet files ---
S3_ENDPOINT = os.environ.get("S3A_ENDPOINT")
S3_ACCESS_KEY = os.environ.get("S3A_ACCESS_KEY")
S3_SECRET_KEY = os.environ.get("S3A_SECRET_KEY")
S3_BUCKET = os.environ.get("S3A_BUCKET")

###### SET THIS ######
VASTDB_TWITTER_INGEST_TABLE = 'YELLOW_TRIP_DATA'
###### SET THIS ######

In [None]:
def _last4(secret):
    """Return the last four characters of ``secret``, or '<unset>' when it is None.

    Environment lookups yield None for unset variables; slicing None would raise
    TypeError, so the banner reports the variable as unset instead.
    """
    return "<unset>" if secret is None else secret[-4:]


# Print the effective configuration, showing only the tail of each credential.
print(f"""
---
VASTDB_ENDPOINT={VASTDB_ENDPOINT}
VASTDB_ACCESS_KEY={_last4(VASTDB_ACCESS_KEY)}
VASTDB_SECRET_KEY=****{_last4(VASTDB_SECRET_KEY)}
VASTDB_TWITTER_INGEST_BUCKET={VASTDB_TWITTER_INGEST_BUCKET}
VASTDB_TWITTER_INGEST_SCHEMA={VASTDB_TWITTER_INGEST_SCHEMA}
VASTDB_TWITTER_INGEST_TABLE={VASTDB_TWITTER_INGEST_TABLE}
---
S3_ENDPOINT={S3_ENDPOINT}
S3_ACCESS_KEY={_last4(S3_ACCESS_KEY)}
S3_SECRET_KEY=****{_last4(S3_SECRET_KEY)}
S3_BUCKET={S3_BUCKET}
""")

In [None]:
def read_parquet(file_path):
    """Load a Parquet file into a pyarrow Table.

    Args:
        file_path: Path (or anything ``pq.read_table`` accepts) of the file.

    Raises:
        RuntimeError: Wrapping any exception raised while reading; the original
            exception is chained as ``__cause__``.
    """
    try:
        arrow_table = pq.read_table(file_path)
    except Exception as exc:
        raise RuntimeError(f"Error reading Parquet file: {exc}") from exc
    return arrow_table

def connect_to_vastdb(endpoint, access_key, secret_key):
    """Open a VastDB session and announce success on stdout.

    Args:
        endpoint: VastDB endpoint URL.
        access_key: Access key passed to ``vastdb.connect`` as ``access``.
        secret_key: Secret key passed to ``vastdb.connect`` as ``secret``.

    Returns:
        The session object returned by ``vastdb.connect``.

    Raises:
        RuntimeError: Wrapping any exception raised by ``vastdb.connect``.
    """
    try:
        vast_session = vastdb.connect(endpoint=endpoint, access=access_key, secret=secret_key)
    except Exception as exc:
        raise RuntimeError(f"Failed to connect to VastDB: {exc}") from exc
    print("Connected to VastDB")
    return vast_session

def write_to_vastdb(session, bucket_name, schema_name, table_name, pa_table):
    """Insert an Arrow table into a VastDB table, creating or widening it first.

    Inside a single transaction on ``session``:
      1. the schema is created if it does not exist;
      2. the table is created with ``pa_table.schema`` if it does not exist;
      3. every column named in ``pa_table`` but absent from the table is added;
      4. the rows are inserted.

    If the insert raises, the error and a schema diff are printed and the
    exception is re-raised to the caller.
    """
    with session.transaction() as tx:
        bucket = tx.bucket(bucket_name)

        db_schema = bucket.schema(schema_name, fail_if_missing=False)
        if not db_schema:
            db_schema = bucket.create_schema(schema_name)

        table = db_schema.table(table_name, fail_if_missing=False)
        if not table:
            table = db_schema.create_table(table_name, pa_table.schema)

        # Schema evolution: only additive changes are applied here.
        for new_column in get_columns_to_add(table.arrow_schema, pa_table.schema):
            table.add_column(new_column)

        try:
            table.insert(pa_table)
        except Exception as err:
            print(f"Error during table.insert: {err}")
            # Print what differs between table and incoming data to aid debugging.
            perform_schema_diff(table.arrow_schema, pa_table.schema)
            raise
        else:
            print(f"Inserted parquet into {table}")

def perform_schema_diff(existing_schema, new_schema):
    """Print how ``new_schema`` differs from ``existing_schema``.

    Fields are paired by case-insensitive name. Four groups are reported:
    fields only in the new schema, fields only in the existing schema, paired
    fields whose types differ, and paired fields whose names differ only in
    letter case. The last group matters because ``get_columns_to_add`` compares
    names case-sensitively, so e.g. ``airport_fee`` vs ``Airport_fee`` are
    different columns there even though they pair up here.

    Args:
        existing_schema: Iterable of fields (objects with ``.name`` and ``.type``)
            describing the table as it exists.
        new_schema: Iterable of fields describing the incoming data.

    Returns:
        None; all output goes to stdout.
    """
    existing_fields = {field.name.lower(): field for field in existing_schema}
    new_fields = {field.name.lower(): field for field in new_schema}

    print("\nSchema Differences:")

    # Fields in the incoming data that the existing schema lacks.
    missing_in_existing = [field for name, field in new_fields.items() if name not in existing_fields]
    if missing_in_existing:
        print("Fields missing in the existing schema:")
        for field in missing_in_existing:
            print(f"  - {field.name} (new): {field.type}")
    else:
        print("No fields are missing in the existing schema.")

    # Fields in the existing schema that the incoming data does not carry.
    extra_in_existing = [field for name, field in existing_fields.items() if name not in new_fields]
    if extra_in_existing:
        print("Fields present in the existing schema but not in the new schema:")
        for field in extra_in_existing:
            print(f"  - {field.name} (existing): {field.type}")
    else:
        print("No extra fields in the existing schema.")

    shared_names = [name for name in new_fields if name in existing_fields]

    # Paired fields whose data types disagree.
    type_mismatches = [
        (existing_fields[name].name, existing_fields[name].type, new_fields[name].type)
        for name in shared_names
        if existing_fields[name].type != new_fields[name].type
    ]
    if type_mismatches:
        print("Type mismatches:")
        for name, existing_type, new_type in type_mismatches:
            print(f"  - {name}: (existing) {existing_type} -> (new) {new_type}")
    else:
        print("No type mismatches found.")

    # Paired fields whose names differ only in letter case; the case-insensitive
    # pairing above would otherwise hide these entirely.
    case_mismatches = [
        (existing_fields[name].name, new_fields[name].name)
        for name in shared_names
        if existing_fields[name].name != new_fields[name].name
    ]
    if case_mismatches:
        print("Column name case mismatches:")
        for existing_name, new_name in case_mismatches:
            print(f"  - (existing) {existing_name} vs (new) {new_name}")
    else:
        print("No column name case mismatches found.")

def import_to_vastdb(session, bucket_name, schema_name, table_name, files_to_import):
    """Load Parquet files into a VastDB table via the VastDB import API.

    The schema is created if missing. If the table already exists,
    ``table.import_files`` is called with ``files_to_import``; otherwise the
    table is created from the files with ``vastdb.util.create_table_from_files``.

    Args:
        session: Session returned by ``connect_to_vastdb``.
        bucket_name: VastDB bucket holding the schema.
        schema_name: Schema holding the table.
        table_name: Target table.
        files_to_import: List of file paths as understood by VastDB.
            NOTE(review): presumably "/<bucket>/<key>" — confirm with the SDK docs.
    """
    with session.transaction() as tx:
        bucket = tx.bucket(bucket_name)
        db_schema = bucket.schema(schema_name, fail_if_missing=False)
        if not db_schema:
            db_schema = bucket.create_schema(schema_name)

        existing_table = db_schema.table(table_name, fail_if_missing=False)
        if not existing_table:
            # No table yet: let VastDB derive its schema from the files.
            vastdb.util.create_table_from_files(
                schema=db_schema,
                table_name=table_name,
                parquet_files=files_to_import,
            )
            return
        existing_table.import_files(files_to_import=files_to_import)

def get_columns_to_add(existing_schema, desired_schema):
    """Return single-field schemas for columns in ``desired_schema`` but not in ``existing_schema``.

    Names are compared case-sensitively. The result follows the column order
    of ``desired_schema``, which makes the order of added columns deterministic.
    The original version iterated a set, so the order varied between runs.
    Duplicate names in ``desired_schema`` are reported once.

    Args:
        existing_schema: Schema of the existing table (anything with ``.names``).
        desired_schema: Schema of the incoming data (``.names`` and ``.field(name)``).

    Returns:
        A list of ``pa.schema`` objects, each holding one field (name and type
        only), suitable for ``table.add_column``.
    """
    existing_fields = set(existing_schema.names)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return [
        pa.schema([pa.field(name, desired_schema.field(name).type)])
        for name in dict.fromkeys(desired_schema.names)
        if name not in existing_fields
    ]


def query_vastdb(session, bucket_name, schema_name, table_name):
    """Read all rows of an existing VastDB table into a pyarrow Table.

    This is a read-only helper. Unlike the write helpers, it never creates the
    schema or the table. The original version tried to create a missing table
    from an undefined local ``pa_table``.

    Args:
        session: Session returned by ``connect_to_vastdb``.
        bucket_name: VastDB bucket holding the schema.
        schema_name: Schema holding the table.
        table_name: Table to read.

    Returns:
        The full contents of the table, as returned by ``select().read_all()``.

    Raises:
        RuntimeError: If the schema or the table does not exist.
    """
    with session.transaction() as tx:
        bucket = tx.bucket(bucket_name)

        schema = bucket.schema(schema_name, fail_if_missing=False)
        if not schema:
            raise RuntimeError(f"Schema {schema_name!r} not found in bucket {bucket_name!r}")

        table = schema.table(table_name, fail_if_missing=False)
        if not table:
            raise RuntimeError(f"Table {table_name!r} not found in schema {schema_name!r}")

        # Materialize the result before the transaction closes.
        # NOTE(review): presumably the reader is only valid inside the transaction.
        return table.select().read_all()

In [None]:
def download_file(url, timeout=60):
    """Download ``url`` into the current directory unless the file is already there.

    The local name is the last path component of the URL. Data is streamed into
    ``<name>.part`` and renamed to ``<name>`` only after the whole body has been
    written. An interrupted download therefore never leaves a truncated file
    that a later run would skip as "already exists".

    Args:
        url: HTTP(S) URL of the file.
        timeout: Seconds passed to ``requests.get`` as its timeout. Previously no
            timeout was set, so a stalled server could hang the loop forever.

    Returns:
        The local file name if the file exists afterwards, or ``None`` if the
        server answered with a non-200 status (a message is printed in that case).
    """
    file_name = os.path.basename(urlparse(url).path)

    if os.path.exists(file_name):
        print(f"{file_name} already exists. Skipping download.")
        return file_name

    part_name = f"{file_name}.part"
    # Using the response as a context manager releases the pooled connection,
    # which a stream=True response otherwise keeps open.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download {file_name}. Status code: {response.status_code}")
            return None
        # A leftover .part from an earlier failed attempt is simply overwritten.
        with open(part_name, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

    os.replace(part_name, file_name)
    print(f"Downloaded {file_name}")
    return file_name


def upload_to_s3(file_path, bucket_name, s3_key):
    """Upload a local file to S3 through the module-level ``s3_client``.

    Failures are printed rather than raised, so a failed upload does not stop
    the ingest loop. The boolean result lets callers detect the failure.
    Previously the function always returned None.

    Args:
        file_path: Local file to upload.
        bucket_name: Destination S3 bucket.
        s3_key: Destination object key.

    Returns:
        True if the upload succeeded, False otherwise.
    """
    try:
        s3_client.upload_file(file_path, bucket_name, s3_key)
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except NoCredentialsError:
        print("Credentials not available.")
    except Exception as e:
        print(f"Error uploading file to S3: {e}")
    else:
        print(f"File uploaded to s3://{bucket_name}/{s3_key}")
        return True
    return False

def delete_file(file_path):
    """Remove ``file_path`` from disk, reporting the outcome on stdout.

    Any error (e.g. the file is missing) is printed instead of raised, so the
    caller always continues.
    """
    try:
        os.remove(file_path)
    except Exception as err:
        print(f"Error deleting {file_path}: {err}")
        return
    print(f"Deleted {file_path}")


In [None]:
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc

def _replace_column(pa_table, column, values, new_name=None):
    """Return ``pa_table`` with ``column`` replaced by ``values`` at the same position.

    The column keeps its name unless ``new_name`` is given.
    """
    return pa_table.set_column(
        pa_table.column_names.index(column),
        column if new_name is None else new_name,
        values,
    )


def process_parquet(file_path):
    """Load a yellow-taxi Parquet file and normalize it to the VastDB table layout.

    Each step applies only when the column is present in the file:
      - tpep_pickup_datetime / tpep_dropoff_datetime: cast to timestamp('ns')
        (per the original comment, the source is timestamp[us]).
      - airport_fee: NULLs replaced with 0.0, column renamed to Airport_fee.
      - passenger_count: cast to int64.
      - RatecodeID: cast to int32.
      - store_and_fwd_flag: cast to string (e.g. from large_string).

    Args:
        file_path: Path of the Parquet file to read.

    Returns:
        The transformed pyarrow Table.
    """
    pa_table = pq.read_table(file_path)

    for column in ('tpep_pickup_datetime', 'tpep_dropoff_datetime'):
        if column in pa_table.column_names:
            pa_table = _replace_column(pa_table, column, pc.cast(pa_table[column], pa.timestamp('ns')))

    # Renaming the lowercase spelling maps it onto the same table column.
    # get_columns_to_add compares names case-sensitively and would otherwise
    # add a second, differently-cased column.
    # NOTE(review): filling NULLs with 0.0 conflates "unknown" with "no fee";
    # confirm this is intended.
    if 'airport_fee' in pa_table.column_names:
        fee = pa_table['airport_fee']
        filled = pc.if_else(pc.is_null(fee), pa.scalar(0.0, pa.float64()), fee)
        pa_table = _replace_column(pa_table, 'airport_fee', filled, new_name='Airport_fee')

    # Remaining type normalizations, applied in the original order.
    for column, target_type in (
        ('passenger_count', pa.int64()),
        ('RatecodeID', pa.int32()),
        ('store_and_fwd_flag', pa.string()),
    ):
        if column in pa_table.column_names:
            pa_table = _replace_column(pa_table, column, pc.cast(pa_table[column], target_type))

    return pa_table

In [None]:
# VastDB session used by write_to_vastdb in the ingest loop.
session = connect_to_vastdb(
    endpoint=VASTDB_ENDPOINT,
    access_key=VASTDB_ACCESS_KEY,
    secret_key=VASTDB_SECRET_KEY,
)

# Module-level S3 client pointed at the VAST S3 endpoint; upload_to_s3 reads
# this global directly.
s3_client = boto3.client(
    's3',
    endpoint_url=S3_ENDPOINT,
    aws_access_key_id=S3_ACCESS_KEY,
    aws_secret_access_key=S3_SECRET_KEY,
    region_name='vast',
)

In [None]:
# Download and process all files from 2019-01 to 2024-08
#
# Per month: download the Parquet file, normalize it in memory (process_parquet),
# upload the ORIGINAL downloaded file to S3, insert the NORMALIZED table into
# VastDB, then delete the local copy.
#
# NOTE(review): write_to_vastdb only inserts. Re-running this cell after a
# partial failure re-inserts the months that already loaded (duplicate rows).
# Consider tracking completed months before re-running.
for year in range(2019, 2025):
    for month in range(1, 13):
        if year == 2024 and month > 8:
            break  # Only process until August 2024
        
        # Format month to always have two digits
        month_str = f"{month:02d}"
        file_url = f"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_{year}-{month_str}.parquet"
        
        # Download the file. download_file only prints on a non-200 response, so a
        # failed download presumably surfaces as a read error in process_parquet.
        download_file(file_url)
        
        # Process the downloaded file (must match the basename download_file saved)
        file_name = f"yellow_tripdata_{year}-{month_str}.parquet"
        pa_table = process_parquet(file_name)

        # Upload the file to S3. Upload errors are only printed, so the loop
        # continues even if this upload failed.
        s3_key = f"yellow_tripdata/yellow_tripdata_{year}-{month_str}.parquet"
        upload_to_s3(file_name, S3_BUCKET, s3_key)
        
        # Write the processed data to VASTDB (custom logic). Insert errors are
        # re-raised and stop the loop.
        write_to_vastdb(session=session,
                        bucket_name=VASTDB_TWITTER_INGEST_BUCKET, 
                        schema_name=VASTDB_TWITTER_INGEST_SCHEMA, 
                        table_name=VASTDB_TWITTER_INGEST_TABLE, 
                        pa_table=pa_table)

        # Alternative path: import the uploaded S3 object directly instead of
        # inserting from memory.
        # NOTE(review): this path omits S3_BUCKET; VastDB import paths probably
        # need the bucket as the first component. Confirm before enabling.
        # import_to_vastdb(
        #     session=session,
        #     bucket_name=VASTDB_TWITTER_INGEST_BUCKET, 
        #     schema_name=VASTDB_TWITTER_INGEST_SCHEMA, 
        #     table_name=VASTDB_TWITTER_INGEST_TABLE, 
        #     files_to_import=[f"/{s3_key}"]
        # )
        
        # Delete the file after processing. This is only reached if the VastDB
        # write succeeded. Otherwise the local file is kept, and download_file
        # skips re-downloading it on the next run.
        delete_file(file_name)


## Check Parquet file for non-compliance

In [None]:
# ! pip3 install --upgrade --quiet git+https://github.com/snowch/vastdb_parq_schema_file.git --use-pep517

In [None]:
## This can take a long time because the second check verifies the data size of every row
# ! parquet_checker yellow_tripdata_2019-01.parquet