In [None]:
# ============================================================
# Build OD-level annual statistics for Complete Trips
# (population-level, NOT sampled)
# ============================================================

import glob
import json
from collections import defaultdict
from datetime import datetime, timezone

import numpy as np
import pandas as pd

# =========================
# CONFIG (keep in sync with the sample script)
# =========================

BASE_DIR = "C:/Users/rli04/Villanova University/Complete-trip-coordinate - Documents/General"
PARQUET_DIR = BASE_DIR + "/Salt_Lake/delivery"

# Tract IDs of the OD pair to summarise. Other tracts used in this project:
#   49035114000  center
#   49035980000  airport
#   49035110106  ski
#   49035101402  U of U
ORIG_TRACT = "49035114000"  # center
DEST_TRACT = "49035980000"  # airport

# Month tokens as they appear in the monthly delivery folder names.
MONTHS = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()

OUTPUT_STATS_JSON = "{}_to_{}.stats.json".format(ORIG_TRACT, DEST_TRACT)

# =========================
# REQUIRED COLUMNS
# =========================
# Only these parquet columns are read, to keep memory down.

USE_COLS = [
    "linked_trip_id",
    "travel_mode",
    "local_datetime_start",
    "local_datetime_end",
    "geohash7_orig",
    "geohash7_dest",
]

# =========================
# 1️⃣ LOAD YEARLY DATA
# =========================
# Read every monthly parquet partition for 2020 (only USE_COLS), parse the leg
# timestamps and drop legs whose end is not strictly after their start. Because
# errors="coerce" turns unparseable timestamps into NaT and any comparison with
# NaT is False, legs with missing/invalid timestamps are dropped as well.

MONTHLY_DFS = []

for m in MONTHS:
    print(f"Loading month: {m}")
    # Sorted so the row order (and therefore tie-breaking downstream) is
    # deterministic across runs and file systems.
    files = sorted(glob.glob(f"{PARQUET_DIR}/Salt_Lake-{m}-2020/*.snappy.parquet"))
    if not files:
        # Surface gaps instead of silently producing a partial-year result.
        print(f"  WARNING: no parquet files found for {m}; month skipped")
        continue

    df_m = pd.concat(
        (pd.read_parquet(f, columns=USE_COLS) for f in files),
        ignore_index=True,
    )

    for col in ("local_datetime_start", "local_datetime_end"):
        df_m[col] = pd.to_datetime(df_m[col], errors="coerce")

    MONTHLY_DFS.append(
        df_m[df_m["local_datetime_end"] > df_m["local_datetime_start"]]
    )

if not MONTHLY_DFS:
    raise FileNotFoundError(
        f"No parquet files found under {PARQUET_DIR} for any month in MONTHS"
    )

# ignore_index=True gives df a unique RangeIndex, which later steps rely on.
df = pd.concat(MONTHLY_DFS, ignore_index=True)
# The monthly frames are fully copied into df; release them so the full year
# is not held in memory twice.
MONTHLY_DFS.clear()
print("Total rows loaded:", len(df))

# =========================
# 2️⃣ OD FILTER
# =========================
# Select linked trips by their *overall* OD: the origin of the trip's first leg
# and the destination of its last leg (legs ordered by local_datetime_start).
# Every leg of a selected trip is kept. Filtering each leg on its own origin
# AND destination would discard the intermediate legs of a multi-leg trip
# (A -> X -> B), which breaks the transfer and mode statistics computed later.
#
# NOTE(review): only the first 5 characters of the tract IDs are compared
# (the state + county part of an 11-digit tract GEOID), so this selects at
# county level, not tract level. In addition, the geohash7_* columns presumably
# hold geohash strings rather than tract GEOIDs; if so, a geohash -> tract
# crosswalk is needed for a meaningful OD match. TODO confirm against the data.

leg_starts = df.groupby("linked_trip_id")["local_datetime_start"]
# Row labels of each trip's earliest / latest leg. df.loc below relies on df
# having a unique index (concat(..., ignore_index=True) at load time).
# On exact start-time ties idxmin/idxmax pick the first tied row.
first_leg_labels = leg_starts.idxmin()
last_leg_labels = leg_starts.idxmax()

trip_orig = pd.Series(
    df.loc[first_leg_labels.to_numpy(), "geohash7_orig"].to_numpy(),
    index=first_leg_labels.index,
)
trip_dest = pd.Series(
    df.loc[last_leg_labels.to_numpy(), "geohash7_dest"].to_numpy(),
    index=last_leg_labels.index,
)

# na=False: a trip whose first-leg origin or last-leg destination is missing
# never matches (replaces the old per-leg notna() filter).
od_match = (
    trip_orig.str.startswith(ORIG_TRACT[:5], na=False)
    & trip_dest.str.startswith(DEST_TRACT[:5], na=False)
)
od_trip_ids = od_match.index[od_match.to_numpy(dtype=bool)]

df = df[df["linked_trip_id"].isin(od_trip_ids)]

print("After OD filter, rows:", len(df))
if df.empty:
    print("WARNING: no linked trips match the OD pair; check the OD matching logic.")

# =========================
# 3️⃣ COMPLETE TRIP FILTER
# =========================
# A "complete" linked trip uses at least two distinct travel modes, after
# normalising each label as str -> lower-case -> strip. Any trip with two or
# more distinct modes is automatically not car-only, so no separate car check
# is needed.

normalized_modes = pd.Series(
    [str(mode).lower().strip() for mode in df["travel_mode"]],
    index=df.index,
    dtype=object,
)
distinct_mode_count = normalized_modes.groupby(df["linked_trip_id"]).nunique()

valid_linked_ids = distinct_mode_count.index[distinct_mode_count.to_numpy() >= 2]

df = df[df["linked_trip_id"].isin(valid_linked_ids)]

print("Valid linked trips:", df["linked_trip_id"].nunique())

# =========================
# 4️⃣ COMPUTE LEG DURATION
# =========================
# Per-leg travel time in minutes (end - start). At this point `df` is a
# boolean-filtered slice of the loaded data. Writing a new column into it
# directly is chained assignment and triggers SettingWithCopyWarning;
# `assign` returns a new frame instead.

df = df.assign(
    duration_min=(
        df["local_datetime_end"] - df["local_datetime_start"]
    ).dt.total_seconds() / 60
)

# =========================
# 5️⃣ BUILD YEARLY linked_trip OBJECTS
# =========================
# Collapse the leg rows into one summary dict per linked trip:
#   duration  - sum of the legs' duration_min values (minutes)
#   transfers - number of legs minus one
#   modes     - set of normalised (str, lower-cased, stripped) travel_mode values
# NOTE(review): summing leg durations excludes any waiting time between legs;
# if door-to-door time is wanted, use (last leg end - first leg start).

legs_by_trip = defaultdict(list)
for leg in df.itertuples():
    legs_by_trip[leg.linked_trip_id].append(leg)


def _summarize_linked_trip(legs):
    """Return the duration / transfers / modes summary for one trip's legs."""
    ordered = sorted(legs, key=lambda leg: leg.local_datetime_start)
    return {
        "duration": sum(leg.duration_min or 0 for leg in ordered),
        "transfers": len(ordered) - 1,
        "modes": {str(leg.travel_mode).lower().strip() for leg in ordered},
    }


# Single-leg trips have no transfers and are skipped.
linked_trips_year = [
    _summarize_linked_trip(legs)
    for legs in legs_by_trip.values()
    if len(legs) >= 2
]

print("Final linked trips (year):", len(linked_trips_year))

# =========================
# 6️⃣ AGGREGATE STATISTICS
# =========================
# Column-wise views of the per-trip summaries, used by the statistics below.

durations, transfers = (
    np.array([trip[field] for trip in linked_trips_year])
    for field in ("duration", "transfers")
)
mode_sets = [trip["modes"] for trip in linked_trips_year]

def pct(arr, q):
    """Return the ``q``-th percentile of ``arr`` as a built-in ``float``.

    Thin wrapper around ``np.percentile`` (its default interpolation) so the
    result can be written directly to JSON.
    """
    percentile_value = np.percentile(arr, q)
    return float(percentile_value)

# Build the OD summary payload. With zero qualifying linked trips, min()/max()
# and np.percentile on an empty array raise, and the mode shares would divide
# by zero. In that case the numeric fields are written as null instead.
n_linked_trips = len(durations)
tracked_modes = ("car", "bus", "rail", "walk")

if n_linked_trips == 0:
    print("WARNING: no complete linked trips for this OD pair; statistics set to null.")
    duration_stats = dict.fromkeys(("min", "p25", "median", "p75", "max"))
    transfer_stats = dict.fromkeys(("avg", "p75", "max"))
    mode_involvement = dict.fromkeys(tracked_modes)
else:
    duration_stats = {
        "min": float(durations.min()),
        "p25": pct(durations, 25),
        "median": pct(durations, 50),
        "p75": pct(durations, 75),
        "max": float(durations.max()),
    }
    transfer_stats = {
        "avg": float(transfers.mean()),
        # int() truncates the interpolated percentile toward zero.
        "p75": int(pct(transfers, 75)),
        "max": int(transfers.max()),
    }
    # Fraction of linked trips whose mode set contains each mode.
    mode_involvement = {
        mode: float(sum(mode in modes for modes in mode_sets) / len(mode_sets))
        for mode in tracked_modes
    }

stats = {
    "schema": "nova.complete_trip.od_stats.v1",
    # datetime.utcnow() is deprecated (Python 3.12); keep the same "...Z" format.
    "generated_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    "od": {
        "origin": ORIG_TRACT,
        "destination": DEST_TRACT
    },
    "coverage": {
        "temporal": "year-2020",
        "spatial": "Salt Lake 6-county"
    },
    "counts": {
        "linked_trips": int(n_linked_trips)
    },
    "trip_duration_min": duration_stats,
    "transfers": transfer_stats,
    "mode_involvement": mode_involvement,
}

# =========================
# 7️⃣ OUTPUT JSON
# =========================
# Serialise before opening the file. allow_nan=False makes json raise
# ValueError on NaN/Infinity. json.dump would already have opened and partly
# written the file by then, leaving a truncated JSON behind; serialising first
# means a failure leaves the target file untouched.

stats_text = json.dumps(stats, indent=2, allow_nan=False)
with open(OUTPUT_STATS_JSON, "w", encoding="utf-8") as f:
    f.write(stats_text)

print("Stats JSON written to:", OUTPUT_STATS_JSON)


Loading month: Jan
Loading month: Feb
Loading month: Mar
Loading month: Apr
Loading month: May
Loading month: Jun
Loading month: Jul
Loading month: Aug
Loading month: Sep
Loading month: Oct
Loading month: Nov
Loading month: Dec
Total rows loaded: 28762190
