In [9]:
# MBTA Live Train Tracking (Notebook-ready)
# ------------------------------------------------------------
# - Works without any non-standard deps (uses requests; pandas optional)
# - Optional API key via env var MBTA_API_KEY (recommended but not required)
# - Modular functions you can reuse across cells
# - Examples at the bottom
#
# Endpoints (MBTA v3):
#   Vehicles (live positions): https://api-v3.mbta.com/vehicles
#   Predictions (arrivals/departures): https://api-v3.mbta.com/predictions
#
# Reference route_type (GTFS):
#   0 = Light rail (e.g., Green Line)
#   1 = Subway/Heavy rail (e.g., Red/Orange/Blue)
#   2 = Commuter rail
#   3 = Bus
#   4 = Ferry
# ------------------------------------------------------------

from __future__ import annotations
import os
import time
import typing as t
import requests

try:
    import pandas as pd  # optional
    _HAS_PANDAS = True
except Exception:
    _HAS_PANDAS = False

MBTA_BASE_URL = "https://api-v3.mbta.com"

class MBTAClient:
    """
    Minimal MBTA v3 API client focused on live train tracking and predictions.
    Provide an API key via env var MBTA_API_KEY or via the constructor.
    """
    def __init__(self, api_key: str | None = None, base_url: str = MBTA_BASE_URL, timeout: int = 15):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.api_key = api_key or os.getenv("MBTA_API_KEY")

    # ---------- internal helpers ----------

    def _headers(self) -> dict:
        h = {"accept": "application/vnd.api+json"}
        if self.api_key:
            # MBTA supports either `x-api-key` or `api_key` query param. Header is cleaner.
            h["x-api-key"] = self.api_key
        return h

    def _get(self, path: str, params: dict | None = None) -> dict:
        url = f"{self.base_url}/{path.lstrip('/')}"
        r = requests.get(url, headers=self._headers(), params=params or {}, timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    # ---------- public API ----------

    def list_routes(
        self,
        route_types: t.Iterable[int] = (0, 1, 2),
        active_only: bool = False,
    ) -> list[dict]:
        params = {
            "filter[type]": ",".join(str(x) for x in route_types),  # <-- was filter[route_type]
            "page[limit]": 100,
        }
        data = self._get("/routes", params)
        routes = data.get("data", [])


        if not active_only:
            return routes

        # Filter down to routes that currently have at least one vehicle
        active_ids = set()
        for rt in route_types:
            v = self.get_vehicles(route_type=rt, limit=1)
            if v:
                active_ids.update({x["attributes"].get("route", None) for x in v})
        return [r for r in routes if r["id"] in active_ids]

    def get_vehicles(
        self,
        route: str | None = None,
        route_type: int | None = None,
        direction_id: int | None = None,
        limit: int = 50,
        include: str = "trip,stop,route",
    ) -> list[dict]:
        """
        Live vehicle positions. For trains, set route_type to 0 (Light), 1 (Subway), or 2 (Commuter rail).
        You can also pass a specific route id (e.g., 'Red', 'Orange', 'CR-Fitchburg').
        """
        params: dict[str, t.Any] = {"page[limit]": limit, "include": include}
        if route:
            params["filter[route]"] = route
        if route_type is not None:
            params["filter[route_type]"] = route_type
        if direction_id is not None:
            params["filter[direction_id]"] = direction_id

        data = self._get("/vehicles", params)
        return data.get("data", [])

    def get_predictions(
        self,
        route: str | None = None,
        stop: str | None = None,
        direction_id: int | None = None,
        route_type: int | None = None,
        limit: int = 50,
        include: str = "trip,stop,vehicle,route",
        sort: str = "arrival_time",
    ) -> list[dict]:
        """
        Predictions (arrivals/departures). Filter by route and/or stop.
        """
        params: dict[str, t.Any] = {"page[limit]": limit, "include": include, "sort": sort}
        if route:
            params["filter[route]"] = route
        if stop:
            params["filter[stop]"] = stop
        if direction_id is not None:
            params["filter[direction_id]"] = direction_id
        if route_type is not None:
            params["filter[route_type]"] = route_type

        data = self._get("/predictions", params)
        return data.get("data", [])

    # ---------- convenience transforms ----------

    @staticmethod
    def vehicles_to_records(vehicles: list[dict]) -> list[dict]:
        """
        Flatten vehicles JSON: each record includes vehicle id, label, bearing, speed, lat/lon, current_status, route, and timestamps.
        """
        out = []
        for v in vehicles:
            attrs = v.get("attributes", {})
            out.append({
                "vehicle_id": v.get("id"),
                "label": attrs.get("label"),
                "bearing_deg": attrs.get("bearing"),
                "speed_mps": attrs.get("speed"),
                "latitude": attrs.get("latitude"),
                "longitude": attrs.get("longitude"),
                "current_status": attrs.get("current_status"),
                "current_stop_sequence": attrs.get("current_stop_sequence"),
                "updated_at": attrs.get("updated_at"),
                "route": attrs.get("route"),
                "direction_id": attrs.get("direction_id"),
                "trip_id": attrs.get("trip_id"),
            })
        return out

    @staticmethod
    def predictions_to_records(predictions: list[dict]) -> list[dict]:
        """
        Flatten predictions JSON to useful columns.
        """
        out = []
        for p in predictions:
            attrs = p.get("attributes", {})
            out.append({
                "prediction_id": p.get("id"),
                "stop_sequence": attrs.get("stop_sequence"),
                "arrival_time": attrs.get("arrival_time"),
                "departure_time": attrs.get("departure_time"),
                "status": attrs.get("status"),
                "direction_id": attrs.get("direction_id"),
                "stop_id": attrs.get("stop"),
                "route_id": attrs.get("route_id"),
                "trip_id": attrs.get("trip_id"),
                "schedule_relationship": attrs.get("schedule_relationship"),
                "updated_at": attrs.get("updated_at"),
            })
        return out

    # ---------- simple polling helper (for “live-ish” tracking) ----------

    def poll_vehicles(
        self,
        route: str,
        interval_sec: float = 10.0,
        iterations: int = 6,
        to_dataframe: bool = True,
        verbose: bool = True,
    ):
        """
        Poll live vehicle positions for a given route every `interval_sec` seconds, for `iterations` times.
        Returns either a pandas DataFrame (if available and to_dataframe=True) or a list of dict snapshots.
        """
        snapshots: list[dict] = []
        for i in range(iterations):
            vehicles = self.get_vehicles(route=route)
            recs = self.vehicles_to_records(vehicles)
            snapshots.append({"t_index": i, "ts": time.time(), "records": recs})
            if verbose:
                print(f"[{i+1}/{iterations}] fetched {len(recs)} vehicles on route {route}")
            if i < iterations - 1:
                time.sleep(interval_sec)

        if _HAS_PANDAS and to_dataframe:
            # Explode into a table with one row per vehicle per poll tick
            rows = []
            for snap in snapshots:
                for r in snap["records"]:
                    rr = dict(r)
                    rr["_t_index"] = snap["t_index"]
                    rr["_ts"] = snap["ts"]
                    rows.append(rr)
            return pd.DataFrame(rows)
        return snapshots
    
def _get(self, path: str, params: dict | None = None) -> dict:
    url = f"{self.base_url}/{path.lstrip('/')}"
    r = requests.get(url, headers=self._headers(), params=params or {}, timeout=self.timeout)
    try:
        r.raise_for_status()
    except requests.HTTPError as e:
        # Surface MBTA’s JSON:API error detail when present
        try:
            err = r.json()
            details = "; ".join(x.get("detail", str(x)) for x in err.get("errors", []))
        except Exception:
            details = r.text
        raise requests.HTTPError(f"{e} | MBTA says: {details}") from None
    return r.json()



In [10]:
client = MBTAClient()
routes = client.list_routes(route_types=(0,1,2))
print([r["id"] for r in routes])

['Red', 'Mattapan', 'Orange', 'Green-B', 'Green-C', 'Green-D', 'Green-E', 'Blue', 'CR-Fairmount', 'CR-NewBedford', 'CR-Fitchburg', 'CR-Worcester', 'CR-Franklin', 'CR-Greenbush', 'CR-Haverhill', 'CR-Kingston', 'CR-Lowell', 'CR-Needham', 'CR-Newburyport', 'CR-Providence', 'CR-Foxboro']


In [18]:
client = MBTAClient()
routes = client.list_routes(route_types=(5,))
print([r["id"] for r in routes])

[]


In [25]:
vehicles = client.get_vehicles(route="Orange")
vehicle_df = pd.DataFrame(MBTAClient.vehicles_to_records(vehicles)) if _HAS_PANDAS else MBTAClient.vehicles_to_records(vehicles)
vehicle_df

Unnamed: 0,vehicle_id,label,bearing_deg,speed_mps,latitude,longitude,current_status,current_stop_sequence,updated_at,route,direction_id,trip_id
0,O-5484ECC8,1458,185,,42.43772,-71.0708,STOPPED_AT,1,2025-08-19T13:44:26-04:00,,0,
1,O-5484EAFF,1442,40,,42.33776,-71.08795,INCOMING_AT,60,2025-08-19T13:44:39-04:00,,1,
2,O-5484EAFE,1518,30,,42.31001,-71.10782,STOPPED_AT,10,2025-08-19T13:44:52-04:00,,1,
3,O-5484EAFC,1516,30,,42.30122,-71.11384,STOPPED_AT,1,2025-08-19T13:40:51-04:00,,1,
4,O-5484EAFB,1428,210,,42.31005,-71.10796,STOPPED_AT,180,2025-08-19T13:44:57-04:00,,0,
5,O-5484EAFA,1454,220,,42.3389,-71.08656,INCOMING_AT,140,2025-08-19T13:44:18-04:00,,0,
6,O-5484EAF9,1438,175,,42.35881,-71.05782,STOPPED_AT,80,2025-08-19T13:44:52-04:00,,0,
7,O-5484EAF7,1414,200,,42.42462,-71.07519,IN_TRANSIT_TO,20,2025-08-19T13:44:55-04:00,,0,
8,O-5484EAF6,1408,20,,42.42273,-71.07593,INCOMING_AT,180,2025-08-19T13:44:44-04:00,,1,
9,O-5484E5C3,1440,150,,42.37736,-71.07505,INCOMING_AT,50,2025-08-19T13:44:44-04:00,,0,


In [20]:
vehicles = client.get_vehicles(route="Red")
vehicle_df = pd.DataFrame(MBTAClient.vehicles_to_records(vehicles)) if _HAS_PANDAS else MBTAClient.vehicles_to_records(vehicles)
print(vehicle_df if _HAS_PANDAS else vehicle_df[:3])

    vehicle_id label  bearing_deg speed_mps  latitude  longitude  \
0   R-5484ECA0  1928          315      None  42.25371  -71.00673   
1   R-5484EC6E  1616          275      None  42.39805  -71.13358   
2   R-5484EB5B  1700            0      None  42.33790  -71.05696   
3   R-5484EB26  1733          125      None  42.39673  -71.12245   
4   R-5484EB20  1718           95      None  42.36243  -71.08581   
5   R-5484EB16  1916          160      None  42.32705  -71.05787   
6   R-5484EB15  1920          260      None  42.30002  -71.06183   
7   R-5484EB13  1832          330      None  42.28458  -71.06377   
8   R-5484EB10  1837          135      None  42.35524  -71.06018   
9   R-5484EB05  1804          200      None  42.31119  -71.05323   
10  R-5484E966  1853          350      None  42.37293  -71.11652   
11  R-5484E7C1  1811          310      None  42.35792  -71.06548   
12  R-5484E58F  1826           85      None  42.30028  -71.05982   
13  R-5484E459  1705          330      None  42.

In [None]:

# -------------------------
# Example usage (uncomment and run these in your notebook):
# -------------------------
# client = MBTAClient()  # or MBTAClient(api_key="YOUR_KEY")
#
# # 1) List rail routes (light, subway, commuter rail)
# routes = client.list_routes(route_types=(0,1,2))
# print([r["id"] for r in routes])
#
# # 2) Live vehicles on the Red Line (subway)
# vehicles = client.get_vehicles(route="Red")
# vehicle_df = pd.DataFrame(MBTAClient.vehicles_to_records(vehicles)) if _HAS_PANDAS else MBTAClient.vehicles_to_records(vehicles)
# print(vehicle_df if _HAS_PANDAS else vehicle_df[:3])
#
# # 3) Next arrivals for a specific route (and optionally a stop)
# preds = client.get_predictions(route="Red", limit=100)
# preds_df = pd.DataFrame(MBTAClient.predictions_to_records(preds)) if _HAS_PANDAS else MBTAClient.predictions_to_records(preds)
# print(preds_df if _HAS_PANDAS else preds_df[:3])
#
# # 4) “Live-ish” polling demo: watch Red Line vehicles for ~1 minute
# polled = client.poll_vehicles(route="Red", interval_sec=10, iterations=6, to_dataframe=True)
# polled.head() if _HAS_PANDAS else polled[:1]
