In [1]:
from dataclasses import dataclass, field
from typing import Optional, List, Set, Callable, Dict, Deque
from collections import deque
from enum import Enum


# ============================================================================
# Passenger Classes
# ============================================================================

@dataclass
class Passenger:
    """Individual passenger with journey tracking."""
    id: int
    origin_id: str
    dest_id: str
    speed_mps: float = 1.4  # meters per second, reflects mobility
    created_at: float = 0.0  # when generated (minutes)
    queued_at: Optional[float] = None  # when joined station queue
    boarded_at: Optional[float] = None  # when entered train
    alighted_at: Optional[float] = None  # when exited train
    tags: Set[str] = field(default_factory=set)  # e.g., {"wheelchair"}
    notes: Optional[str] = None
    current_line: Optional[str] = None  # which line queue they're in
    route: List[str] = field(default_factory=list)  # planned station path

    @property
    def wait_time(self) -> Optional[float]:
        """Time spent waiting at platform (minutes)."""
        if self.boarded_at is None or self.queued_at is None:
            return None
        return self.boarded_at - self.queued_at

    @property
    def in_vehicle_time(self) -> Optional[float]:
        """Time spent on train (minutes)."""
        if self.alighted_at is None or self.boarded_at is None:
            return None
        return self.alighted_at - self.boarded_at

    @property
    def total_journey_time(self) -> Optional[float]:
        """Total time from creation to arrival (minutes)."""
        if self.alighted_at is None:
            return None
        return self.alighted_at - self.created_at


@dataclass
class PassengerProfile:
    """Defines characteristics for a passenger type."""
    name: str = "default"
    speed_mps: float = 1.4  # walking speed
    boarding_time: float = 2.0  # seconds to board
    proportion: float = 1.0  # proportion of total passengers


# ============================================================================
# Infrastructure Classes
# ============================================================================

class Station:
    """Station with passenger queues."""
    def __init__(self, station_id: str, transfer_time: float = 2.0):
        self.id = station_id
        self.transfer_time = transfer_time  # minutes to transfer between lines
        # Queue per line at this station
        self.queues: Dict[str, Deque[int]] = {}  # line_id -> deque of passenger_ids
        self.lines: Set[str] = set()  # which lines stop here

    def add_line(self, line_id: str):
        """Register a line that stops at this station."""
        self.lines.add(line_id)
        if line_id not in self.queues:
            self.queues[line_id] = deque()

    def add_passenger_to_queue(self, passenger_id: int, line_id: str):
        """Add passenger to queue for specific line."""
        if line_id not in self.queues:
            self.queues[line_id] = deque()
        self.queues[line_id].append(passenger_id)

    def pop_for_boarding(self, line_id: str, max_n: int) -> List[int]:
        """Pop up to max_n passengers FIFO from the queue for a line."""
        q = self.queues.get(line_id)
        if not q:
            return []
        out = []
        for _ in range(min(max_n, len(q))):
            out.append(q.popleft())
        return out

    def get_queue_length(self, line_id: str) -> int:
        """Get number of passengers waiting for a line."""
        q = self.queues.get(line_id)
        return len(q) if q is not None else 0

    def is_transfer_station(self) -> bool:
        """Check if this is a transfer point between lines."""
        return len(self.lines) > 1


class Line:
    """Train line with stops and travel times."""
    def __init__(self, name: str, stops: List[str], travel_times: List[float], fleet_size: int = 0):
        self.name = name
        self.stops = stops  # ordered list of station IDs
        self.fleet_size = fleet_size  # optional: number of trains assigned to this line
        if len(travel_times) != len(stops) - 1:
            raise ValueError(
                f"Line {name}: need {len(stops)-1} travel times for {len(stops)} stops"
            )
        self.travel_times = travel_times  # minutes between consecutive stops

    def get_travel_time(self, from_station: str, to_station: str) -> Optional[float]:
        """Get travel time between two stations on this line (forward direction)."""
        try:
            from_idx = self.stops.index(from_station)
            to_idx = self.stops.index(to_station)
            if from_idx >= to_idx:
                return None  # wrong direction or same station
            return sum(self.travel_times[from_idx:to_idx])
        except (ValueError, IndexError):
            return None

    def get_next_stop(self, current_station: str) -> Optional[str]:
        """Get the next station after current one."""
        try:
            idx = self.stops.index(current_station)
            if idx < len(self.stops) - 1:
                return self.stops[idx + 1]
        except ValueError:
            pass
        return None

    def get_stop_index(self, station: str) -> Optional[int]:
        """Get index of station in stops list."""
        try:
            return self.stops.index(station)
        except ValueError:
            return None


# ============================================================================
# Rolling Stock Classes
# ============================================================================

class TrainState(Enum):
    """Train operational state."""
    AT_STATION = "at_station"
    IN_TRANSIT = "in_transit"


class Train:
    """Train operating on a line."""
    def __init__(self, train_id: str, line: str, capacity: int, schedule: Optional[List[float]] = None):
        self.id = train_id
        self.line = line  # line ID this train operates on
        self.capacity = capacity
        self.schedule = schedule or []  # departure times from first station (minutes)
        self.onboard: List[int] = []  # passenger IDs currently on train
        self.state = TrainState.AT_STATION
        self.current_station: Optional[str] = None
        self.next_station: Optional[str] = None
        self.arrival_time: Optional[float] = None  # when arriving at next station

    @property
    def occupancy(self) -> int:
        """Current number of passengers."""
        return len(self.onboard)

    @property
    def occupancy_rate(self) -> float:
        """Occupancy as percentage of capacity."""
        return (self.occupancy / self.capacity * 100) if self.capacity > 0 else 0.0

    @property
    def available_capacity(self) -> int:
        """Remaining space on train."""
        return max(0, self.capacity - self.occupancy)

    def board_passengers(self, passenger_ids: List[int]) -> List[int]:
        """Board passengers up to capacity. Returns list of boarded passenger IDs."""
        space = self.available_capacity
        to_board = passenger_ids[:space]
        if to_board:
            self.onboard.extend(to_board)
        return to_board

    def alight_passengers(self, station_id: str, passengers: Dict[int, Passenger]) -> List[int]:
        """Remove passengers whose destination is this station; return list of alighting IDs."""
        alighting = []
        remaining = []
        for pid in self.onboard:
            if passengers[pid].dest_id == station_id:
                alighting.append(pid)
            else:
                remaining.append(pid)
        self.onboard = remaining
        return alighting


# ============================================================================
# Demand & Schedule Classes
# ============================================================================

@dataclass
class PassengerDemand:
    """Defines passenger demand between two stations."""
    origin: str
    destination: str
    rate: Callable[[float], float]  # function: time -> passengers per hour
    pattern: Optional[str] = None  # optional preset pattern name

    def get_demand(self, time: float) -> float:
        """Get passenger arrival rate at given time (passengers/hour)."""
        if self.pattern:
            return self._get_pattern_demand(time)
        return self.rate(time)

    def _get_pattern_demand(self, time: float) -> float:
        """Get demand from preset pattern."""
        patterns = {
            "rush_hour": lambda t: 100 if (7 <= t < 9 or 17 <= t < 19) else 20,
            "constant": lambda t: 50,
            "evening_peak": lambda t: 80 if 17 <= t < 20 else 30,
        }
        if self.pattern in patterns:
            return patterns[self.pattern](time)
        return self.rate(time)


@dataclass
class TrainSchedule:
    """Defines when trains run on a line."""
    line: str
    capacity: int
    frequency: Optional[float] = None  # minutes between trains
    departures: Optional[List[float]] = None  # explicit departure times
    start_time: float = 0.0
    end_time: float = 120.0

    def get_departure_times(self) -> List[float]:
        """Generate list of departure times."""
        if self.departures:
            return sorted(self.departures)
        if self.frequency:
            times = []
            current = self.start_time
            while current <= self.end_time:
                times.append(current)
                current += self.frequency
            return times
        raise ValueError("Must specify either frequency or departures")

    def __post_init__(self):
        """Validate schedule parameters."""
        if self.frequency is None and self.departures is None:
            raise ValueError("Must specify either frequency or departures")
        if self.frequency is not None and self.departures is not None:
            raise ValueError("Cannot specify both frequency and departures")


# ============================================================================
# Network Classes
# ============================================================================

class Network:
    """Rail network graph with stations and lines."""
    def __init__(self, transfer_time: float = 2.0):
        self.stations: Dict[str, Station] = {}
        self.lines: Dict[str, Line] = {}
        self.default_transfer_time = transfer_time

    def add_station(self, station_id: str, transfer_time: Optional[float] = None):
        """Add a station to the network."""
        if station_id not in self.stations:
            tt = transfer_time if transfer_time is not None else self.default_transfer_time
            self.stations[station_id] = Station(station_id, transfer_time=tt)

    def add_line(self, line: Line):
        """Add a line and auto-create stations."""
        self.lines[line.name] = line
        # Auto-create stations from line definition
        for station_id in line.stops:
            self.add_station(station_id)
            self.stations[station_id].add_line(line.name)

    def get_station(self, station_id: str) -> Optional[Station]:
        """Get station by ID."""
        return self.stations.get(station_id)

    def get_line(self, line_id: str) -> Optional[Line]:
        """Get line by ID."""
        return self.lines.get(line_id)

    def find_route(self, origin: str, destination: str) -> Optional[List[tuple]]:
        """
        Find route between stations.
        Returns list of (line_id, from_station, to_station) tuples.
        Simple implementation: direct line or one transfer.
        """
        # Direct line
        for line in self.lines.values():
            if origin in line.stops and destination in line.stops:
                orig_idx = line.stops.index(origin)
                dest_idx = line.stops.index(destination)
                if orig_idx < dest_idx:
                    return [(line.name, origin, destination)]
        # One-transfer
        for line1 in self.lines.values():
            if origin not in line1.stops:
                continue
            for transfer_station in line1.stops:
                if transfer_station == origin:
                    continue
                for line2 in self.lines.values():
                    if line2.name == line1.name:
                        continue
                    if transfer_station in line2.stops and destination in line2.stops:
                        orig_idx = line1.stops.index(origin)
                        transfer_idx = line1.stops.index(transfer_station)
                        transfer_idx2 = line2.stops.index(transfer_station)
                        dest_idx = line2.stops.index(destination)
                        if orig_idx < transfer_idx and transfer_idx2 < dest_idx:
                            return [
                                (line1.name, origin, transfer_station),
                                (line2.name, transfer_station, destination),
                            ]
        return None  # No route found

In [3]:
import random
passenger1 = Passenger(id=1, origin_id="Central", dest_id="Wynyard")
passenger2 = Passenger(id=2, origin_id="Town Hall", dest_id="Redfern", speed_mps=1.2)

T1 = Line(
    name="T1",
    stops=["Central", "Town Hall", "Wynyard", "Circular Quay"],
    travel_times=[2.0, 3.0, 4.0],
    fleet_size=8 # Should generate 8 train objects for this line
)
passenger_list = [Passenger(id=i, origin_id=random.choice(T1.stops), dest_id=random.choice(T1.stops)) for i in range(3, 10)]

In [None]:
railnet = Network(transfer_time=3.0)
railnet.add_line(T1)
print(railnet.stations)
print

{'Central': <__main__.Station at 0x1e0df156710>,
 'Town Hall': <__main__.Station at 0x1e0df1320d0>,
 'Wynyard': <__main__.Station at 0x1e0df17e010>,
 'Circular Quay': <__main__.Station at 0x1e0df17c490>}

In [18]:
import random

# Define a simple line and network
line = Line(
    name="T1",
    stops=["Central", "Town Hall", "Wynyard", "Circular Quay"],
    travel_times=[2.0, 3.0, 4.0],
    fleet_size=1
)
net = Network(transfer_time=2.0)
net.add_line(line)

# Create a few passengers with forward trips only (origin index < dest index)
rng = random.Random(42)
passengers = {}
pid = 1
for _ in range(8):
    oi = rng.randrange(0, len(line.stops) - 1)
    di = rng.randrange(oi + 1, len(line.stops))
    p = Passenger(id=pid, origin_id=line.stops[oi], dest_id=line.stops[di], created_at=0.0, queued_at=0.0)
    passengers[pid] = p
    net.stations[p.origin_id].add_passenger_to_queue(p.id, line.name)
    pid += 1

# One train starting at first station
train = Train(train_id="T1-001", line=line.name, capacity=4)
current_time = 0.0
station_idx = 0
train.current_station = line.stops[station_idx]

print("Start simulation")
while True:
    station_name = line.stops[station_idx]
    station = net.get_station(station_name)
    print(f"\nTime {current_time:.1f} min | Train at {station_name}")

    # 1) Alight
    alighting = train.alight_passengers(station_name, passengers)
    for pid in alighting:
        passengers[pid].alighted_at = current_time
    if alighting:
        print(f"  Alighted: {alighting}")

    # 2) Board
    to_board_ids = station.pop_for_boarding(line.name, train.available_capacity)
    boarded_ids = train.board_passengers(to_board_ids)
    for pid in boarded_ids:
        passengers[pid].boarded_at = current_time
    if boarded_ids:
        print(f"  Boarded:  {boarded_ids} (occ={train.occupancy}/{train.capacity})")
    else:
        print(f"  No boarding (occ={train.occupancy}/{train.capacity})")

    # 3) If at last station, end; else move to next station
    if station_idx == len(line.stops) - 1:
        print("Reached terminal. Ending.")
        break

    # Move to next stop
    next_idx = station_idx + 1
    travel_time = line.travel_times[station_idx]
    current_time += travel_time
    station_idx = next_idx

# Report simple metrics
served = [p for p in passengers.values() if p.alighted_at is not None]
left_onboard = train.onboard[:]  # any not alighted by terminal (should be none if all dests <= terminal)
waiting_remaining = sum(st.get_queue_length(line.name) for st in net.stations.values())

print(f"\nSummary:")
print(f"  Passengers created: {len(passengers)}")
print(f"  Passengers served:  {len(served)}")
print(f"  Left onboard:       {left_onboard}")
print(f"  Still waiting:      {waiting_remaining}")

for p in sorted(served, key=lambda x: x.id):
    print(f"    P{p.id}: wait={p.wait_time:.1f} min, in-vehicle={p.in_vehicle_time:.1f} min, total={p.total_journey_time:.1f} min")

Start simulation

Time 0.0 min | Train at Central
  Boarded:  [2, 4, 8] (occ=3/4)

Time 2.0 min | Train at Town Hall
  Alighted: [4, 8]
  Boarded:  [3] (occ=2/4)

Time 5.0 min | Train at Wynyard
  Alighted: [3]
  Boarded:  [1, 5, 6] (occ=4/4)

Time 9.0 min | Train at Circular Quay
  Alighted: [2, 1, 5, 6]
  No boarding (occ=0/4)
Reached terminal. Ending.

Summary:
  Passengers created: 8
  Passengers served:  7
  Left onboard:       []
  Still waiting:      1
    P1: wait=5.0 min, in-vehicle=4.0 min, total=9.0 min
    P2: wait=0.0 min, in-vehicle=9.0 min, total=9.0 min
    P3: wait=2.0 min, in-vehicle=3.0 min, total=5.0 min
    P4: wait=0.0 min, in-vehicle=2.0 min, total=2.0 min
    P5: wait=5.0 min, in-vehicle=4.0 min, total=9.0 min
    P6: wait=5.0 min, in-vehicle=4.0 min, total=9.0 min
    P8: wait=0.0 min, in-vehicle=2.0 min, total=2.0 min


In [3]:
import json
with open("./get_tfnsw_station/stations.json", mode="r+", encoding='utf-8') as f:
    station_dict = json.load(f)
    station_dict.pop('Station', None)

{'Allawah': ['T4'],
 'Arncliffe': ['T4'],
 'Artarmon': ['T1', 'T9'],
 'Ashfield': ['T2', 'T3'],
 'Asquith': ['T1'],
 'Auburn': ['T1', 'T2'],
 'Banksia': ['T4'],
 'Bankstown': ['T6'],
 'Bardwell Park': ['T8'],
 'Beecroft': ['T9'],
 'Berala': ['T3', 'T6'],
 'Berowra': ['T1'],
 'Beverly Hills': ['T8'],
 'Bexley North': ['T8'],
 'Birrong': ['T6'],
 'Blacktown': ['T1', 'T5'],
 'Bondi Junction': ['T4'],
 'Burwood': ['T2', 'T3', 'T9'],
 'Cabramatta': ['T2', 'T3', 'T5'],
 'Campbelltown': ['T8'],
 'Canley Vale': ['T2', 'T5'],
 'Caringbah': ['T4'],
 'Carlton': ['T4'],
 'Carramar': ['T3'],
 'Casula': ['T2', 'T5'],
 'Central': ['T1', 'T2', 'T3', 'T4', 'T7', 'T8', 'T9'],
 'Chatswood': ['T1'],
 'Cheltenham': ['T9'],
 'Chester Hill': ['T3'],
 'Circular Quay': ['T2', 'T3', 'T8'],
 'Clarendon': ['T1', 'T5'],
 'Clyde': ['T1', 'T2'],
 'Como': ['T4'],
 'Concord West': ['T9'],
 'Cronulla': ['T4'],
 'Croydon': ['T2', 'T3'],
 'Denistone': ['T9'],
 'Domestic Airport': ['T8'],
 'Doonside': ['T1'],
 'East Hills