In [2]:
## Evan Sultanik Graph Gen for parent stop data
from dataclasses import dataclass
from io import BytesIO
import itertools
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Iterator, Optional
import urllib.request
from zipfile import ZipFile

from tqdm import tqdm
from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from typing import Iterator

import networkx as nx


def parse_csv(stream: Iterable[str]) -> Iterator[dict[str, str]]:
    columns: list[str] | None = None
    for line in stream:
        if columns is None:
            columns = line.strip().split(",")
        else:
            yield dict(zip(columns, line.strip().split(",")))


@dataclass(frozen=True, slots=True)
class Stop:
    stop_id: str
    name: str
    lat: float
    lon: float
    location_type: Optional[int]
    parent_station: Optional[str]

    @classmethod
    def parse(cls, stream: Iterable[str]) -> Iterator["Stop"]:
        for stop in parse_csv(stream):
            try:
                location_type: int | None = int(stop["location_type"])
            except ValueError:
                location_type = None
            parent_station: str | None = stop["parent_station"].strip()
            if not parent_station:
                parent_station = None
            yield cls(
                stop_id=stop["stop_id"].strip(),
                name=stop["stop_name"].strip(),
                lat=float(stop["stop_lat"]),
                lon=float(stop["stop_lon"]),
                location_type=location_type,
                parent_station=parent_station
            )

    def __hash__(self):
        return hash(self.stop_id)

    def __eq__(self, other):
        return isinstance(other, Stop) and self.stop_id == other.stop_id


@dataclass(frozen=True, slots=True, unsafe_hash=True)
class Transfer:
    from_stop: str
    to_stop: str
    min_transfer_time: int

    @classmethod
    def parse(cls, stream: Iterable[str]) -> Iterator["Transfer"]:
        for transfer in parse_csv(stream):
            yield cls(
                from_stop=transfer["from_stop_id"].strip(),
                to_stop=transfer["to_stop_id"].strip(),
                min_transfer_time=int(transfer["min_transfer_time"]),
            )


@dataclass(frozen=True, slots=True)
class Route:
    route_id: str
    route_short_name: str
    route_long_name: str

    def __hash__(self):
        return hash(self.route_id)

    def __str__(self):
        return f"{self.route_long_name} ({self.route_short_name})"

    @classmethod
    def parse(cls, stream: Iterable[str]) -> Iterator["Route"]:
        for route in parse_csv(stream):
            yield cls(
                route_id=route["route_id"].strip(),
                route_short_name=route["route_short_name"].strip(),
                route_long_name=route["route_long_name"].strip(),
            )


@dataclass(frozen=True, slots=True, unsafe_hash=True)
class Trip:
    trip_id: str
    route: Route
    direction_id: str

    @classmethod
    def parse(cls, stream: Iterable[str], routes_by_id: dict[str, Route]) -> Iterator["Trip"]:
        for trip in parse_csv(stream):
            route_id = trip["route_id"]
            if route_id not in routes_by_id:
                raise ValueError(f"Unknown route_id {route_id!r}")
            yield cls(
                trip_id=trip["trip_id"].strip(),
                route=routes_by_id[route_id],
                direction_id=trip["direction_id"].strip(),
            )

    def __str__(self):
        return f"Trip {self.trip_id} on route {self.route_id} in direction {self.direction_id}"


class Edge:
    def __init__(self, from_id: str, to_id: str, duration: float, intermediate_stops: Iterable[str] = ()):
        self.from_id: str = from_id
        self.to_id: str = to_id
        self.duration: float = duration
        self.intermediate_stops: [str] = list(intermediate_stops)

    def __hash__(self):
        return hash((self.from_id, self.to_id))

    def __eq__(self, other):
        return isinstance(other, Edge) and other.from_id == self.from_id and other.to_id == self.to_id


class DoTransfer(Edge):
    pass


class Feed:
    def __init__(self, stops: Iterable[Stop], transfers: Iterable[Transfer], edges: Iterable[Edge],
                 routes_by_stop: dict[Stop, set[Route]]):
        all_stops: dict[str: Stop] = {
            stop.stop_id: stop
            for stop in stops
        }
        self.stop_equivalents: dict[str: str] = {}
        for stop_id, stop in all_stops.items():
            while stop.parent_station is not None:
                stop = all_stops[stop.parent_station]
            self.stop_equivalents[stop_id] = stop.stop_id
        self.stops: dict[str, Stop] = {
            stop.stop_id: stop
            for stop in all_stops.values()
            if stop.parent_station is None
        }
        self.routes_by_stop: dict[Stop, set[Route]] = dict(routes_by_stop)

        self.transfers: dict[str: dict[str: Transfer]] = {}
        for transfer in transfers:
            from_stop = self.stop_equivalents[transfer.from_stop]
            to_stop = self.stop_equivalents[transfer.to_stop]
            if from_stop not in self.transfers:
                self.transfers[from_stop] = {to_stop: transfer}
            else:
                self.transfers[from_stop][to_stop] = transfer

        self.edges: {str: {str: Edge}} = {}
        for edge in edges:
            from_stop = self.stop_equivalents[edge.from_id]
            to_stop = self.stop_equivalents[edge.to_id]
            edge.from_id = from_stop
            edge.to_id = to_stop
            edge.intermediate_stops = [self.stop_equivalents[s] for s in edge.intermediate_stops]
            if from_stop not in self.edges:
                self.edges[from_stop] = {to_stop: edge}
            else:
                self.edges[from_stop][to_stop] = edge

        # Remove any stops that have no neighbors
        islanded_stops = {
            stop
            for stop in self.stops.keys()
            if sum(1 for _ in self.neighbors(stop)) == 0
        }
        for stop in islanded_stops:
            del self.stops[stop]

        self._leaves: Optional[set[str]] = None
        self._leaf_branch_length: dict[str: int] = {}
        self._shortest_path_lengths: Optional[dict[str: dict[str: float]]] = None

    @property
    def leaves(self) -> set[str]:
        if self._leaves is None:
            self._leaves = {
                stop
                for stop in self.stops
                if sum(1 for _ in self.neighbors(stop)) == 1
            }
        return self._leaves

    @property
    def leaf_branch_length(self) -> dict[str: int]:
        if not self.leaves or self._leaf_branch_length:
            return self._leaf_branch_length
        for leaf in self.leaves:
            node: str = leaf
            history = {leaf}
            while True:
                successors = {n.to_id for n in self.neighbors(node)} - history
                if len(successors) != 1:
                    break
                node = next(iter(successors))
                history.add(node)
            self._leaf_branch_length[leaf] = self.shortest_path_lengths[node][leaf]
        return self._leaf_branch_length

    @property
    def shortest_path_lengths(self) -> dict[str: dict[str: float]]:
        if self._shortest_path_lengths is not None:
            return self._shortest_path_lengths
        self._shortest_path_lengths = {}
        apsp_path = Path("shortest_path_lengths.txt")
        if apsp_path.exists():
            with open(apsp_path, "r") as f:
                for line in f:
                    from_stop, to_stop, distance = (s.strip() for s in line.strip().split(","))
                    distance = float(distance)
                    if from_stop not in self._shortest_path_lengths:
                        self._shortest_path_lengths[from_stop] = {to_stop: distance}
                    else:
                        self._shortest_path_lengths[from_stop][to_stop] = distance
        else:
            ordered_stops = list(self.stops.keys())
            self._shortest_path_lengths: dict[str: dict[str: float]] = {
                from_stop: {
                    to_stop: self.distance(from_stop, to_stop)
                    for to_stop in ordered_stops
                }
                for from_stop in ordered_stops
            }
            for k in tqdm(ordered_stops, leave=False, unit="stops", desc="calculating shortest paths"):
                for i in ordered_stops:
                    for j in ordered_stops:
                        dist = self._shortest_path_lengths[i][k] + self._shortest_path_lengths[k][j]
                        if self._shortest_path_lengths[i][j] > dist:
                            self._shortest_path_lengths[i][j] = dist
            with open(apsp_path, "w") as f:
                for from_stop, lengths in self._shortest_path_lengths.items():
                    for to_stop, distance in lengths.items():
                        f.write(f"{from_stop},{to_stop},{distance}\n")
        for from_node, distances in self._shortest_path_lengths.items():
            for to_node, distance in distances.items():
                if distance >= float('inf'):
                    raise ValueError(f"There is no path from {from_node} to {to_node}!")
        return self._shortest_path_lengths

    def distance(self, from_stop: str, to_stop: str) -> float:
        if from_stop == to_stop:
            return 0
        d = float('inf')
        if from_stop in self.edges and to_stop in self.edges[from_stop]:
            d = self.edges[from_stop][to_stop].duration
        if from_stop in self.transfers and to_stop in self.transfers[from_stop]:
            d = min(d, self.transfers[from_stop][to_stop].min_transfer_time)
        return d

    def neighbors(self, from_stop: str) -> Iterator[Edge]:
        if from_stop in self.edges:
            yield from self.edges[from_stop].values()
        if from_stop in self.transfers:
            for to_stop, transfer in self.transfers[from_stop].items():
                if from_stop != to_stop:
                    yield DoTransfer(from_stop, to_stop, transfer.min_transfer_time)

    @classmethod
    def load_or_download(cls, download_url: str, path: Optional[Path] = None) -> "Feed":
        if path is None:
            path = Path.cwd()
        if path.exists() and (path.is_file() or path.is_dir() and (path / "stops.txt").exists()):
            return cls.load(path, cache_dir=path)
        else:
            return cls.load(download_url, cache_dir=path)

    @classmethod
    def load(cls, path_or_url: Path | str | ZipFile, cache_dir: Optional[Path] = None) -> "Feed":
        if isinstance(path_or_url, str):
            if Path(path_or_url).exists():
                return cls.load(Path(path_or_url))
            else:
                # this is a URL
                response = urllib.request.urlopen(path_or_url)
                return cls.load(ZipFile(BytesIO(response.read())), cache_dir=cache_dir)
        elif isinstance(path_or_url, ZipFile):
            if cache_dir is None:
                with TemporaryDirectory() as d:
                    path_or_url.extractall(d)
                    return cls.load(d)
            else:
                path_or_url.extractall(cache_dir)
                return cls.load(cache_dir)
        if not isinstance(path_or_url, Path):
            raise ValueError(f"argument must be a Path, str, or ZipFile, not {path_or_url!r}")
        if path_or_url.is_file():
            return cls.load(ZipFile(path_or_url), cache_dir=cache_dir)

        with open(path_or_url / "routes.txt") as f:
            routes_by_id = {
                route.route_id: route
                for route in Route.parse(f)
            }

        with open(path_or_url / "trips.txt") as f:
            trips_by_id = {
                trip.trip_id: trip
                for trip in Trip.parse(f, routes_by_id=routes_by_id)
            }

        # this is a directory
        edges: dict[tuple[str, str], list[int]] = {}
        trips: dict[str, list[str]] = {}
        trips_by_stop: dict[str, set[Trip]] = {}
        with open(path_or_url / "stop_times.txt") as f:
            last_trip_id = ""
            last_seq = -1
            last_stop_id = -1
            last_arrival_time = 0
            for line in parse_csv(f):
                trip_id = line["trip_id"].strip()
                if trip_id not in trips_by_id:
                    raise ValueError(f"Unknown trip_id: {trip_id!r}, it is not in trips.txt!")
                trip = trips_by_id[trip_id]
                stop_id = line["stop_id"].strip()
                if stop_id not in trips_by_stop:
                    trips_by_stop[stop_id] = {trip}
                else:
                    trips_by_stop[stop_id].add(trip)
                arrival_time = line["arrival_time"].strip()
                departure_time = line["departure_time"].strip()
                stop_sequence = int(line["stop_sequence"].strip())
                arrival_hour, arrival_min, arrival_sec = map(int, arrival_time.split(":"))
                arrival_time = arrival_hour * 60 * 60 + arrival_min * 60 + arrival_sec
                if trip_id == last_trip_id and stop_sequence == last_seq + 1:
                    while arrival_time < last_arrival_time:
                        arrival_time += 24 * 60 * 60
                    edge = (last_stop_id, stop_id)
                    if edge not in edges:
                        edges[edge] = [arrival_time - last_arrival_time]
                    else:
                        edges[edge].append(arrival_time - last_arrival_time)
                last_trip_id = trip_id
                last_seq = stop_sequence
                last_stop_id = stop_id
                last_arrival_time = arrival_time
                if trip_id not in trips:
                    trips[trip_id] = [stop_id]
                else:
                    trips[trip_id].append(stop_id)
        edges: list[Edge] = [
            Edge(from_id, to_id, sum(times) / len(times))
            for (from_id, to_id), times in edges.items()
        ]
        # is there any trip where we pass through a station (e.g., an express train?)
        for edge in tqdm(edges, desc="Finding intermediate stops...", unit="edges", leave=False):
            intermediates: set[tuple[str, ...]] = set()
            for trip in trips_by_stop[edge.from_id] | trips_by_stop[edge.to_id]:
                trip_stops = trips[trip.trip_id]
                try:
                    from_id_index = trip_stops.index(edge.from_id)
                    to_id_index = trip_stops.index(edge.to_id)
                except ValueError:
                    continue
                intermediate_stops = tuple(trip_stops[from_id_index+1:to_id_index])
                if intermediate_stops:
                    intermediates.add(intermediate_stops)
            if len(intermediates) > 1:
                # TODO: Find a better way to do this!
                intermediates = {tuple(itertools.chain(*intermediates))}
            if intermediates:
                edge.intermediate_stops = list(next(iter(intermediates)))
        with open(path_or_url / "stops.txt") as f, open(path_or_url / "transfers.txt") as t:
            stops = list(Stop.parse(f))
            return cls(stops=stops, transfers=Transfer.parse(t), edges=edges,
                       routes_by_stop={
                           stop: {trip.route for trip in trips_by_stop[stop.stop_id]}
                           for stop in stops
                           if stop.stop_id in trips_by_stop
                       })


In [3]:
feed = Feed.load("gtfs_subway.zip")


                                                                                      

In [15]:
import pandas as pd
def make_graph(feed: Feed, directed: bool = False) -> nx.Graph:
    if directed:
        graph = nx.DiGraph()
    else:
        graph = nx.Graph()
    for node in feed.stops.keys():
        graph.add_node(node)
    for node in feed.stops.keys():
        graph.add_weighted_edges_from((
            (neighbor.from_id, neighbor.to_id, neighbor.duration)
            for neighbor in feed.neighbors(node)
            if neighbor.from_id in feed.stops and neighbor.to_id in feed.stops
        ))
    return graph

G = make_graph(feed)
el = list(G.edges(data=True))
for i in range(len(el)):
	q = (el[i][0], el[i][1])
	d = el[i][2]['weight']
	el[i] = [q,d]
el

[[('101', '103'), 208.3422459893048],
 [('103', '104'), 91.17647058823529],
 [('104', '106'), 90.0],
 [('106', '107'), 90.0],
 [('107', '108'), 89.68253968253968],
 [('108', '109'), 90.0],
 [('109', '110'), 72.01058201058201],
 [('110', '111'), 90.0],
 [('111', '112'), 120.21164021164022],
 [('112', '113'), 96.24338624338624],
 [('112', 'A09'), 180],
 [('113', '114'), 120.0],
 [('114', '115'), 90.0],
 [('115', '116'), 112.3292469352014],
 [('116', '117'), 141.59369527145358],
 [('117', '118'), 60.0],
 [('118', '119'), 82.3292469352014],
 [('119', '120'), 90.05253940455341],
 [('120', '121'), 109.7456279809221],
 [('120', '227'), 262.80898876404495],
 [('120', '123'), 174.91935483870967],
 [('121', '122'), 73.68839427662957],
 [('122', '123'), 106.31160572337043],
 [('123', '124'), 63.43402225755167],
 [('123', '127'), 238.1451612903226],
 [('124', '125'), 90.0],
 [('125', '126'), 120.0],
 [('125', 'A24'), 180],
 [('126', '127'), 112.41653418124007],
 [('127', '128'), 79.3718042366691],

In [None]:
import json

def serialize(obj):
    if not isinstance(obj, (str, int, float, bool)):
        return list(obj)
    else: return obj

with open('sultanik_edges.json', 'w') as f:
     json.dump(el,f, default=serialize)
    