In [81]:
from tkinter.messagebox import showerror

import numpy as np
from src.utils import read_data

# get frequently used data: pathvia array
PV = read_data("pathvia.pkl", show_timer=False).sort_values(by=["path_id"])
# PV = read_data("pathvia.pkl", show_timer=False).sort_values(by=["path_id", "pv_id"])
PV: np.ndarray = PV[PV["link_type"] == "in_vehicle"][["path_id", "node_id1", "node_id2", "line", "updown"]].values

# get frequently used data: timetable array
TT = read_data("TT", show_timer=False).reset_index().sort_values(by=["LINE_NID", "UPDOWN", "TRAIN_ID", "DEPARTURE_TS"])
TT['ts1'] = TT["DEPARTURE_TS"] - TT["STOP_TIME"]  # make sure the carriage gate is actually open at arrive_ts
TT: np.ndarray = TT[["TRAIN_ID", "STATION_NID", "LINE_NID", "UPDOWN", "ts1", "DEPARTURE_TS"]].values

In [63]:
STA = read_data("STA", show_timer=False)
line = np.random.choice([1, 2, 3, 4, 7, 10], size=1)[0]
o_nid, d_nid = np.random.choice(STA[STA["LINE_NID"] == line].index, 2)
AFC = read_data("AFC", show_timer=False)
ts1, ts2 = AFC[["TS1", "TS2"]].sample(n=1).values.tolist()[0]
if line == 7:
    upd = 1 if np.random.rand() < 0.5 else -1
else:
    upd = 1 if o_nid < d_nid else -1
print(o_nid, d_nid, ts1, ts2, upd)

10337 10326 57090 58454 -1


In [128]:
def find_tt1(_nid1: int, _nid2: int, _ts1: int, _ts2: int, _upd: int) -> list[tuple[int, int, int]]:
    """
    Find timetable array for nid pairs within ts range.
    Board_ts is obtained from departure_ts, alight_ts from arrive_ts.

    :return: [(train_id, board_ts, alight_ts), ...]
    """
    assert _ts1 < _ts2, f"ts1 should be smaller than ts2: {_ts1}, {_ts2}"

    # filter line
    start_idx, end_idx = np.searchsorted(TT[:, 2], [line, line + 1])
    tt = TT[start_idx:end_idx]

    # filter updown
    start_idx, end_idx = np.searchsorted(tt[:, 3], [upd, upd + 1])
    tt = tt[start_idx:end_idx]

    # filter od and ts range
    tt = tt[
        ((tt[:, 1] == _nid1) | (tt[:, 1] == _nid2))  # filter nid
        & (tt[:, 5] > _ts1)  # filter ts1
        & (tt[:, 4] < _ts2)  # filter ts2
        ]

    # find pairs
    res: list[tuple[int]] = []
    for r1, r2 in zip(tt[:-1], tt[1:]):
        if r1[0] == r2[0] and r1[1] == _nid1 and r2[1] == _nid2:
            res.append((r1[0], r1[5], r2[4]))

    return res

def find_tt2(_nid1: int, _nid2: int, _ts1: int, _ts2: int, _upd: int) -> list[tuple[int, int, int]]:
    """
    Find timetable array for nid pairs within ts range.
    Board_ts is obtained from departure_ts, alight_ts from arrive_ts.

    :return: [(train_id, board_ts, alight_ts), ...]
    """
    assert _ts1 < _ts2, f"ts1 should be smaller than ts2: {_ts1}, {_ts2}"

    # filter line
    start_idx, end_idx = np.searchsorted(TT[:, 2], [line, line + 1])
    tt = TT[start_idx:end_idx]

    # filter updown
    start_idx, end_idx = np.searchsorted(tt[:, 3], [upd, upd + 1])
    tt = tt[start_idx:end_idx]

    # filter od and ts range
    tt = tt[
        ((tt[:, 1] == _nid1) | (tt[:, 1] == _nid2))  # filter nid
        & (tt[:, 5] > _ts1)  # filter ts1
        & (tt[:, 4] < _ts2)  # filter ts2
        ]

    # # find pairs
    # res: list[tuple[int]] = []
    # for r1, r2 in zip(tt[:-1], tt[1:]):
    #     if r1[0] == r2[0] and r1[1] == _nid1 and r2[1] == _nid2:
    #         res.append((r1[0], r1[5], r2[4]))

    # Pair finding using vectorized comparison
    train_ids = tt[:-1, 0] == tt[1:, 0]  # check if train_id is the same for consecutive rows
    nid1_condition = tt[:-1, 1] == _nid1  # check if the first station is _nid1
    nid2_condition = tt[1:, 1] == _nid2  # check if the second station is _nid2

    # Combine conditions
    condition = train_ids & nid1_condition & nid2_condition

    # Prepare result using filtered indices
    res = [(tt[i, 0], tt[i, 5], tt[i + 1, 4]) for i in range(len(tt) - 1) if condition[i]]
    return res

In [130]:
%timeit find_tt1(o_nid, d_nid, ts1, ts2+2000, upd)
%timeit find_tt2(o_nid, d_nid, ts1, ts2+2000, upd)

41 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
37.1 µs ± 882 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [134]:
_uid1, _uid2 = 1001, 1130
%timeit base_path_id = int(f"{_uid1}{_uid2}01")
%timeit base_path_id = _uid1*1000000 +_uid2*100 +1

408 ns ± 11.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
133 ns ± 8.71 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [136]:

_uid1*1000000 +_uid2*100 +1, int(f"{_uid1}{_uid2}01")

(1001113001, 1001113001)