# TransitGraphAI
This notebook scrapes the MobilityData GTFS catalog on GitHub, filters feeds, validates the presence of required GTFS files, computes simple size indicators, and downloads a curated batch of feeds for downstream graph construction.  
Note: set a personal GitHub token via `GITHUB_TOKEN` to avoid rate limits.
# Data collection

In [55]:
import pandas as pd
import requests
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
import io
import os
import re
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np

## Scraping of the catalog
Using a personal GitHub API token, we query the `MobilityData/mobility-database-catalogs` repository to retrieve the list of GTFS schedule feeds. To avoid rate limiting, export `GITHUB_TOKEN` and pass it in the `Authorization` header.

In [None]:
headers = {
    "Authorization": "token XXXXXXXXXXXXXXXXX"
}  # personal token removed to keep it private, replace with your own

base_api_url = "https://api.github.com/repos/MobilityData/mobility-database-catalogs/contents/catalogs/sources/gtfs/schedule"
response = requests.get(base_api_url, headers=headers)
file_list = response.json()

records = []

for file in file_list:
    file_url = file.get("download_url")
    if not file_url:
        continue
    r = requests.get(file_url)
    if r.status_code != 200:
        continue
    try:
        data = r.json()
        if data.get("status") == "deprecated":
            continue
        records.append(
            {
                "id": data.get("mdb_source_id"),
                "provider": data.get("provider"),
                "country": data.get("location", {}).get("country_code"),
                "region": data.get("location", {}).get("subdivision_name"),
                "bbox": data.get("location", {}).get("bounding_box"),
                "url": data.get("urls", {}).get("latest"),
            }
        )
    except Exception as e:
        print(f"Error with {file_url}: {e}")

df = pd.DataFrame(records)
df.to_csv("data_collection/gtfs_sources_mobilitydata.csv", index=False)

In [None]:
df = pd.read_csv("data_collection/gtfs_sources_mobilitydata.csv")
df.head()

Unnamed: 0,id,provider,country,region,bbox,url
0,1329,Department of Municipalities and Transport,AE,Abu Dhabi Emirate,"{'minimum_latitude': 22.9433339258299, 'maximu...",https://storage.googleapis.com/storage/v1/b/md...
1,904,Road and Transport Authority (RTA),AE,Dubayy,"{'minimum_latitude': 24.720541, 'maximum_latit...",https://storage.googleapis.com/storage/v1/b/md...
2,2345,Municipality of Tirana,AL,Tirana,"{'minimum_latitude': 41.2830270048811, 'maximu...",https://storage.googleapis.com/storage/v1/b/md...
3,1220,Colectivos Buenos Aires,AR,Buenos Aires,"{'minimum_latitude': -35.182685, 'maximum_lati...",https://storage.googleapis.com/storage/v1/b/md...
4,6,Subterráneos de Buenos Aires (SUBTE),AR,Buenos Aires,"{'minimum_latitude': None, 'maximum_latitude':...",https://storage.googleapis.com/storage/v1/b/md...


We aim to build time-dependent, multimodal graphs. We therefore require feeds to include at least: `stops.txt`, `stop_times.txt`, `trips.txt`, `routes.txt`, `calendar.txt` or `calendar_dates.txt`, and `transfers.txt`. Feeds missing any of these files are excluded from the batch download.

In [None]:
# Reload the list of GTFS sources
df = pd.read_csv("gtfs_sources_mobilitydata.csv")

results = []

# List of files to check for availability in each folder
required_files = {
    "stops.txt",
    "stop_times.txt",
    "trips.txt",
    "routes.txt",
    "calendar.txt",
    "calendar_dates.txt",
    "transfers.txt",
}


def process_feed(row: dict) -> dict:
    """
    Download a GTFS zip from the feed URL and check for required GTFS files.

    Parameters
    ----------
    row : dict
        A record with at least keys "provider" and "url".

    Returns
    -------
    dict
        Summary with:
        - 'provider' (str): provider name
        - 'url' (str): feed URL
        - 'files_found' (list[str]): present required files
        - 'missing_files' (list[str]): required files not found
        - 'complete' (bool): True if all required files are present
        - 'error' (str | None): error message if the request or parsing failed
    Notes
    -----
    Performs an HTTP GET with timeout and inspects the zip file contents in-memory.
    """
    url = row.get("url")
    provider = row.get("provider")
    try:
        r = requests.get(url, timeout=20)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        filenames = set(z.namelist())
        found_files = required_files & filenames

        return {
            "provider": provider,
            "url": url,
            "files_found": list(found_files),
            "missing_files": list(required_files - found_files),
            "complete": required_files.issubset(filenames),
        }
    except Exception as e:
        return {
            "provider": provider,
            "url": url,
            "files_found": [],
            "missing_files": list(required_files),
            "complete": False,
            "error": str(e),
        }


# Parallel processing of all the folders
results = []
with ThreadPoolExecutor(max_workers=10) as executor:
    futures = [executor.submit(process_feed, row) for _, row in df.iterrows()]
    for future in tqdm(
        as_completed(futures), total=len(futures), desc="Analysing GTFS sources"
    ):
        results.append(future.result())

# Export CSV
results_df = pd.DataFrame(results)
results_df.to_csv("data_collection/gtfs_file_check_results.csv", index=False)
results_df.head()

Analyse des GTFS:   0%|          | 0/852 [00:00<?, ?it/s]

Unnamed: 0,provider,url,files_found,missing_files,complete,error
0,Mar Chiquita SRL,https://storage.googleapis.com/storage/v1/b/md...,"[trips.txt, stop_times.txt, calendar_dates.txt...",[transfers.txt],False,
1,Subterráneos de Buenos Aires (SUBTE),https://storage.googleapis.com/storage/v1/b/md...,[],"[calendar_dates.txt, trips.txt, stops.txt, tra...",False,
2,Canberra Metro Operations,https://storage.googleapis.com/storage/v1/b/md...,"[trips.txt, stop_times.txt, calendar_dates.txt...",[transfers.txt],False,
3,Byron Easybus,https://storage.googleapis.com/storage/v1/b/md...,"[trips.txt, transfers.txt, stop_times.txt, cal...",[],True,
4,"Department of Transport, Public Transport",https://storage.googleapis.com/storage/v1/b/md...,"[trips.txt, stop_times.txt, calendar_dates.txt...",[transfers.txt],False,


For a first pass we restrict to Western Europe. This yields a more homogeneous sample in terms of regulatory context and service patterns, while remaining diverse across metropolitan areas. We calculate the number of lines, stops, and trips in each of these networks, in order to homogenize the sample according to the size and complexity of the graphs.

In [None]:
df = pd.read_csv("data_collection/gtfs_file_check_results.csv")
ref = pd.read_csv("data_collection/gtfs_sources_mobilitydata.csv")
df = pd.merge(df, ref[["provider", "country", "region"]], on="provider")
df = df[
    df["complete"]
    & (df["country"].isin(["FR", "IT", "DE", "ES", "BE", "LU", "PT", "NL", "CH"]))
]


def count_lines_in_zip(zip_bytes: bytes, filename: str) -> int | None:
    """
    Count the number of data rows (excluding header) for a CSV-like file inside a GTFS zip.

    Parameters
    ----------
    zip_bytes : bytes
        Raw bytes of a GTFS zip archive.
    filename : str
        Target file name inside the archive (e.g., "stops.txt").

    Returns
    -------
    int | None
        Number of non-header lines if file exists and is readable; otherwise None.

    Notes
    -----
    Opens the file from the in-memory zip and subtracts one line for the header.
    Returns None on missing file or any read/parsing error.
    """
    try:
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as z:
            with z.open(filename) as f:
                # Compte les lignes (en excluant l'en-tête)
                return sum(1 for line in f) - 1
    except KeyError:
        return None
    except Exception:
        return None


def analyze_feed(row: dict) -> dict:
    """
    Download a GTFS zip and compute simple size indicators for key tables.

    Parameters
    ----------
    row : dict
        A record with keys "provider" and "url".

    Returns
    -------
    dict
        Indicators:
        - 'provider' (str), 'url' (str)
        - 'nb_stops' (int | None): number of stops
        - 'nb_routes' (int | None): number of routes
        - 'nb_trips' (int | None): number of trips
        - 'error' (str | None): error message if something failed

    Notes
    -----
    Uses `count_lines_in_zip` on "stops.txt", "routes.txt", and "trips.txt".
    """
    url = row["url"]
    provider = row["provider"]
    try:
        r = requests.get(url, timeout=20)
        content = r.content
        return {
            "provider": provider,
            "url": url,
            "nb_stops": count_lines_in_zip(content, "stops.txt"),
            "nb_routes": count_lines_in_zip(content, "routes.txt"),
            "nb_trips": count_lines_in_zip(content, "trips.txt"),
        }
    except Exception as e:
        return {
            "provider": provider,
            "url": url,
            "nb_stops": None,
            "nb_routes": None,
            "nb_trips": None,
            "error": str(e),
        }


results = []
with ThreadPoolExecutor(max_workers=10) as executor:
    futures = [executor.submit(analyze_feed, row) for _, row in df.iterrows()]
    for future in tqdm(
        as_completed(futures), total=len(futures), desc="Network analysis"
    ):
        results.append(future.result())

stats_df = pd.DataFrame(results)
stats_df.to_csv("data_collection/gtfs_network_sizes.csv", index=False)
stats_df.head()

Analyse réseau:   0%|          | 0/81 [00:00<?, ?it/s]

Unnamed: 0,provider,url,nb_stops,nb_routes,nb_trips,error
0,De Waterbus,https://storage.googleapis.com/storage/v1/b/md...,8.0,2.0,223.0,
1,DeWaterbus,https://storage.googleapis.com/storage/v1/b/md...,13.0,2.0,223.0,
2,Bürgerbus Leupoldsgrün (Landkreis Hof),https://storage.googleapis.com/storage/v1/b/md...,27.0,2.0,6.0,
3,HofBus,https://storage.googleapis.com/storage/v1/b/md...,345.0,39.0,1168.0,
4,naldo Verkehrsverbund,https://storage.googleapis.com/storage/v1/b/md...,6819.0,423.0,31245.0,


In [None]:
stats_df["intensity"] = stats_df["nb_trips"] / stats_df["nb_routes"]
ref = pd.read_csv("data_collection/gtfs_sources_mobilitydata.csv")
stats_df = pd.merge(stats_df, ref[["provider", "country", "region"]], on="provider")

In [36]:
stats_df = stats_df.drop_duplicates(subset="provider")

In [37]:
stats_df.describe()

Unnamed: 0,nb_stops,nb_routes,nb_trips,intensity
count,73.0,73.0,73.0,73.0
mean,11969.794521,646.60274,66093.9,168.259723
std,63798.191345,3404.094327,268152.0,277.475183
min,2.0,1.0,1.0,0.333333
25%,132.0,14.0,457.0,25.9
50%,1218.0,52.0,3755.0,54.125874
75%,3415.0,191.0,26108.0,224.916667
max,541592.0,29084.0,2169492.0,1627.0


The public transport networks in Western Europe range from 2 to >500,000 stops, with a majority of networks with more than 52 routes.  
To balance richness and tractability, we retain feeds whose size indicators fall within a mid-range, and we thus focus on the subset of networks ranging from 100 to 500 routes This avoids toy networks and overly large metropolitan systems in early experiments.

In [None]:
stats_df.loc[
    (stats_df["nb_routes"] >= 100) & (stats_df["nb_routes"] <= 500), "provider"
].unique()

array(['naldo Verkehrsverbund', 'Schweizer Reisen',
       'Autos Castellbisbal, Mohn, Oliveras, Rosanbus, Soler i Sauret, Tusgsal, Barcelona City Tour, Monbus, Avanza',
       'OVA-Aalen, OVA-Bopfingen, Beck+Schubert',
       'Rurtalbahn GmbH, ABELLIO Rail, VIAS GmbH, Aachener Straßenbahn und Energieversorgungs-AG, Rurtalbus GmbH, WestVerkehr GmbH, Staatsbahnen, National Express, ASEAG Netliner',
       'Transports Metropolitans de Barcelona (TMB)', 'Tisséo',
       'Aleop Renfort LR 85, Aléop en Loire-Atlantique, Aléop en Maine-et-Loire, Aléop en Mayenne, Aléop en Sarthe, Aléop en Vendée, Aléop en Vendée et Loire-Atlantique, Aléop express Régionale, Aléop TER, projet Aléop en Loire-Atlantique, Yeu Continent',
       'Réseau Mistral', 'Régie des Transports Métropolitains (RTM)',
       'Transports en Commun Lyonnais (TCL)',
       'Cars Région Auvergne-Rhône-Alpes (Transisère)',
       'Trenitalia Piemonte', 'bodo Verkehrsverbund',
       'Trentino Trasporti Esercizio (TTE)',
       '

This subset contains 17 public transport networks, which can constitute a first solid sample to attempt to make generalizations among a diversity of networks. We record the catalog commit hash and the retrieval timestamp to keep the sample reproducible over time.

In [None]:
provider_zones = {
    "naldo Verkehrsverbund": "naldo Verkehrsverbund",
    "Schweizer Reisen": "Schweizer Reisen",
    "OVA-Aalen, OVA-Bopfingen, Beck+Schubert": "Aalen-Bopfingen",
    "Rurtalbahn GmbH, ABELLIO Rail, VIAS GmbH, Aachener Straßenbahn und Energieversorgungs-AG, Rurtalbus GmbH, WestVerkehr GmbH, Staatsbahnen, National Express, ASEAG Netliner": "Aachen",
    "Transports Metropolitans de Barcelona (TMB)": "Barcelona",
    "Tisséo": "Toulouse",
    "Aleop Renfort LR 85, Aléop en Loire-Atlantique, Aléop en Maine-et-Loire, Aléop en Mayenne, Aléop en Sarthe, Aléop en Vendée, Aléop en Vendée et Loire-Atlantique, Aléop express Régionale, Aléop TER, projet Aléop en Loire-Atlantique, Yeu Continent": "Pays de la Loire",
    "Réseau Mistral": "Toulon",
    "Régie des Transports Métropolitains (RTM)": "Marseille",
    "Transports en Commun Lyonnais (TCL)": "Lyon",
    "Cars Région Auvergne-Rhône-Alpes (Transisère)": "Isère",
    "Trenitalia Piemonte": "Piemonte",
    "bodo Verkehrsverbund": "Bodensee-Oberschwaben",
    "Trentino Trasporti Esercizio (TTE)": "Trentino",
    "Hofmann Omnibusverkehr GmbH": "Hofmann Omnibusverkehr GmbH",
    "Agenzia Mobilità Ambiente Territorio": "Milano",
}
stats_df = stats_df[stats_df["provider"].isin(provider_zones.keys())]
stats_df["provider"] = stats_df["provider"].map(provider_zones)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  stats_df["provider"] = stats_df["provider"].map(provider_zones)


In [54]:
stats_df

Unnamed: 0,provider,url,nb_stops,nb_routes,nb_trips,error,intensity,country,region
4,naldo Verkehrsverbund,https://storage.googleapis.com/storage/v1/b/md...,6819.0,423.0,31245.0,,73.865248,DE,Baden-Württemberg
5,Schweizer Reisen,https://storage.googleapis.com/storage/v1/b/md...,2122.0,104.0,4935.0,,47.451923,DE,
23,Aalen-Bopfingen,https://storage.googleapis.com/storage/v1/b/md...,3322.0,191.0,9248.0,,48.418848,DE,Baden-Württemberg
27,Aachen,https://storage.googleapis.com/storage/v1/b/md...,3193.0,370.0,21148.0,,57.156757,DE,Nordrhein-Westfalen
29,Barcelona,https://storage.googleapis.com/storage/v1/b/md...,3415.0,113.0,56436.0,,499.433628,ES,Barcelona
52,Toulouse,https://storage.googleapis.com/storage/v1/b/md...,5540.0,122.0,31406.0,,257.42623,FR,Occitanie
56,Pays de la Loire,https://storage.googleapis.com/storage/v1/b/md...,9829.0,143.0,7740.0,,54.125874,FR,Pays de la Loire
57,Toulon,https://storage.googleapis.com/storage/v1/b/md...,1958.0,100.0,25661.0,,256.61,FR,Provence-Alpes-Côte-d'Azur
58,Marseille,https://storage.googleapis.com/storage/v1/b/md...,2681.0,130.0,36758.0,,282.753846,FR,Provence-Alpes-Côte-d’Azur
60,Lyon,https://storage.googleapis.com/storage/v1/b/md...,6752.0,467.0,118262.0,,253.237687,FR,Rhône


In [None]:
base_dir = "gtfs_data"
os.makedirs(base_dir, exist_ok=True)


def sanitize(name: str) -> str:
    """
    Convert a provider name into a filesystem-friendly identifier.

    Parameters
    ----------
    name : str
        Original provider name.

    Returns
    -------
    str
        Sanitized name with non-word characters collapsed to underscores and trimmed.
    """
    return re.sub(r"\W+", "_", name.strip())


# Download + extraction
for _, row in tqdm(
    stats_df.iterrows(), total=len(stats_df), desc="Download + extraction"
):
    provider = sanitize(row["provider"])
    url = row["url"]

    provider_dir = os.path.join(base_dir, provider)
    if os.path.exists(provider_dir):
        continue

    try:
        r = requests.get(url, timeout=50)
        z = zipfile.ZipFile(io.BytesIO(r.content))
        z.extractall(provider_dir)
    except Exception as e:
        print(f"Error with {provider} ({url}) : {e}")

Téléchargement + extraction:   0%|          | 0/16 [00:00<?, ?it/s]