In [None]:
import requests
import pyarrow.parquet as pq
import pyarrow as pa

# Explicit Arrow schema applied when reading the yellow-taxi trip parquet files.
# Each entry is a (column name, Arrow type) pair; pa.schema turns the pairs into fields.
_TAXI_SCHEMA_RAW = pa.schema(
    [
        # identifiers and timestamps
        ("VendorID", pa.int64()),
        ("tpep_pickup_datetime", pa.timestamp("us")),
        ("tpep_dropoff_datetime", pa.timestamp("us")),
        # trip description
        ("passenger_count", pa.float64()),
        ("trip_distance", pa.float64()),
        ("RatecodeID", pa.float64()),
        ("store_and_fwd_flag", pa.string()),
        ("PULocationID", pa.int64()),
        ("DOLocationID", pa.int64()),
        ("payment_type", pa.int64()),
        # monetary amounts
        ("fare_amount", pa.float64()),
        ("extra", pa.float64()),
        ("mta_tax", pa.float64()),
        ("tip_amount", pa.float64()),
        ("tolls_amount", pa.float64()),
        ("improvement_surcharge", pa.float64()),
        ("total_amount", pa.float64()),
        ("congestion_surcharge", pa.float64()),
        ("airport_fee", pa.float64()),
    ]
)

# Location of the 2015-01 yellow-taxi trip file in the nyc-tlc S3 bucket.
url = "https://s3.amazonaws.com/nyc-tlc/tripdata/yellow_tripdata_2015-01.parquet"

# One-off download of `url`, left commented out (presumably so that re-running the
# notebook does not fetch the file again — confirm).
# NOTE(review): the pq.write_table line targets "yellow_tripdata_2022-03.parquet" although
# `url` points at 2015-01; only the final open(...).write(...) line produces
# "yellow_tripdata_2015-01.parquet". Confirm which output name is intended before re-enabling.
# response = requests.get(url)
# table = pq.read_table(pa.py_buffer(response.content), schema=_TAXI_SCHEMA_RAW)
# pq.write_table(table, "yellow_tripdata_2022-03.parquet")
# open("yellow_tripdata_2015-01.parquet", "wb").write(response.content)

In [None]:
from urllib.parse import urlparse, urljoin
from pathlib import Path

# Two candidate hosts for the trip data. They get separate names so the S3 template is
# not overwritten before it is used (previously both were assigned to `base`, which left
# the `{}` placeholder unreachable and turned the .format() call into a no-op).
s3_url_template = "https://s3.amazonaws.com/nyc-tlc/trip+data/yellow_tripdata_{}.parquet"
# `base` keeps the value it ended up with before: the Azure Open Data storage root.
base = "https://azureopendatastorage.blob.core.windows.net/"

# Fill in the year-month key ("2015-01"; the earlier "20115-01" was a typo) and show the URL.
s3_url = s3_url_template.format("2015-01")
s3_url

#### generate test data

In [None]:
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
from numpy.random import default_rng

# Read the 2015-01 trip file from the working directory, passing the explicit raw taxi
# schema to the reader instead of letting it be inferred from the file.
# NOTE(review): the file must already exist locally; the download code that would create
# it is commented out elsewhere in this notebook.
path = "yellow_tripdata_2015-01.parquet"
table = pq.read_table(path, schema=_TAXI_SCHEMA_RAW)
# Trailing expression: display the loaded table's schema as the cell output.
table.schema

In [None]:
# Prepend a constant int64 `year` column to every row.
# The year is parsed from `path` ("yellow_tripdata_YYYY-MM.parquet") instead of being
# hard-coded: the previous literal 2022 did not match the 2015-01 file being loaded.
# int() raises ValueError if `path` stops following that naming pattern.
source_year = int(path.rsplit("_", 1)[-1][:4])

# Guard against re-running the cell, which would otherwise prepend a second `year` column.
if "year" not in table.column_names:
    table = table.add_column(0, pa.field("year", pa.int64()), [[source_year] * table.num_rows])

# Trailing expression: display the updated schema as the cell output.
table.schema

In [None]:
# Down-sample the table to a small fixture and write it into the example's test data.
# A fixed seed makes the written file identical on every Restart & Run All; an unseeded
# generator would produce a different fixture each time.
sample_seed = 42
# Never ask for more rows than exist: choice(..., replace=False) raises ValueError otherwise.
sample_size = min(100, table.num_rows)

rng = default_rng(sample_seed)
# Distinct row indices drawn without replacement, in the random order returned by choice().
rows = rng.choice(table.num_rows, size=sample_size, replace=False)
table = table.take(rows)
pq.write_table(table, "../examples/model-training/tests/data/taxi/2015-01.parquet")

In [None]:
# Partition key for this batch and the raw-name -> snake_case column renames.
partition_key = "2015-01-01"
_RENAME_MAP = {
    "VendorID": "vendor_id",
    "PULocationID": "pu_location_id",
    "DOLocationID": "do_location_id",
}

# Reload the sampled fixture using the explicit raw schema.
table = pq.read_table(
    "../examples/model-training/tests/data/taxi/2015-01.parquet",
    schema=_TAXI_SCHEMA_RAW,
)

# Columns listed in _RENAME_MAP get their new name; all others keep their current one.
columns = [_RENAME_MAP.get(name, name) for name in table.column_names]
table = table.rename_columns(columns)

# Split "YYYY-MM-DD" into its year and month parts and broadcast each to every row.
year_text, month_text = partition_key.split("-")[:2]
n_rows = table.shape[0]
table = table.add_column(0, pa.field("year", pa.int64()), [[int(year_text)] * n_rows])
table = table.add_column(1, pa.field("month", pa.int64()), [[int(month_text)] * n_rows])

In [None]:
from flight_fusion import FusionServiceClient, ClientOptions, AssetKey
# Connect to the fusion service at localhost:50051.
# NOTE(review): presumably a locally running development instance — confirm.
client_options = ClientOptions(host="localhost", port=50051)
ffc = FusionServiceClient(client_options)

# Dataset client for the asset addressed by the key path ["taxi", "partitioned2"].
asset_key = AssetKey(["taxi", "partitioned2"])
fds = ffc.get_dataset_client(asset_key)

# Send the prepared table to the service, partitioned by the year and month columns.
fds.write_into(table, partition_by=["year", "month"])

In [None]:
# Load the dataset back through the dataset client.
# NOTE(review): presumably returns the rows written via fds.write_into — confirm the return type.
df = fds.load()

In [None]:
# Display the shape of the loaded data as the cell output
# (presumably (rows, columns) — confirm against the type returned by fds.load()).
df.shape