In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import rlway.osrd.infra as infr

In [3]:
# Case directory to analyse -- change this to load another scenario.
CASE_DIR = '../cases/2cvg'

# Load the case's JSON documents. Later cells index `infra` by 'routes',
# 'signals', 'detectors', 'track_sections' and 'operational_points', `sim` by
# 'train_schedules', and `res` per train by 'base_simulations'.
# NOTE(review): presumably (infrastructure, simulation input, simulation
# results) -- TODO confirm against rlway.osrd.infra.read_jsons_in_dir.
infra, sim, res = infr.read_jsons_in_dir(CASE_DIR)

In [4]:
# Draw the loaded infrastructure with rlway's helper (presumably a matplotlib
# figure of tracks, detectors and signals -- see rlway.osrd.infra.draw_infra).
infr.draw_infra(infra)

In [5]:
import networkx as nx

def _route_endpoints(route_id):
    """Split a route id of the form 'rt.<entry>-><exit>' into (entry, exit).

    Every 'rt.' substring is removed, the rest is split on '->', and the
    first two pieces are returned (any further pieces are ignored).
    """
    parts = route_id.replace('rt.', '').split('->')
    return parts[0], parts[1]


# Route graph: one directed edge entry -> exit per route.
# NOTE(review): node ids are presumably detector / buffer-stop ids -- TODO
# confirm against the OSRD infra schema.
G = nx.DiGraph()
G.add_edges_from(_route_endpoints(route['id']) for route in infra['routes'])

# A node reached by more than one route is a convergence point; its
# predecessors are the convergence entry detectors. dict.fromkeys drops
# duplicates (one node can precede several convergence points) while keeping
# first-seen order.
convergence_entry_detectors = list(dict.fromkeys(
    predecessor
    for node in G
    if G.in_degree(node) > 1
    for predecessor in G.predecessors(node)
))

# Signal ids grouped by the detector they are linked to, in infra order, so
# each detector is looked up once instead of rescanning every signal.
signals_by_detector = {}
for sig in infra['signals']:
    signals_by_detector.setdefault(sig.get('linked_detector'), []).append(sig['id'])

# Signals linked to a convergence entry detector, in detector order.
cvg_signals = [
    signal_id
    for detector in convergence_entry_detectors
    for signal_id in signals_by_detector.get(detector, [])
]

cvg_signals

['SA1', 'SA2', 'S4', 'SC1']

In [6]:
def _build_track_sections_elements(infra, cvg_signals):
    """Index detectors, signals and station parts by track section.

    Returns ``{track_id: {(element_id, kind): position}}`` with one entry per
    track in ``infra['track_sections']``. ``kind`` is 'detector', 'signal',
    'cvg_signal' (a signal whose id is in ``cvg_signals``) or 'station'.
    Each inner dict is ordered by increasing position. ``sorted`` is stable,
    so elements at the same position keep insertion order (detectors, then
    signals, then stations). A station with two parts on the same track keeps
    only the last part's position (same key).
    """
    cvg_signal_ids = set(cvg_signals)  # O(1) membership instead of list scans
    elements_by_track = {track['id']: {} for track in infra['track_sections']}

    for detector in infra['detectors']:
        elements_by_track[detector['track']][(detector['id'], 'detector')] = detector['position']

    for signal in infra['signals']:
        kind = 'cvg_signal' if signal['id'] in cvg_signal_ids else 'signal'
        elements_by_track[signal['track']][(signal['id'], kind)] = signal['position']

    for station in infra['operational_points']:
        for part in station['parts']:
            elements_by_track[part['track']][(station['id'], 'station')] = part['position']

    return {
        track: dict(sorted(elements.items(), key=lambda item: item[1]))
        for track, elements in elements_by_track.items()
    }


track_sections_elements = _build_track_sections_elements(infra, cvg_signals)

In [7]:
# For each train (in `res` order): the track sections its head passes through
# in the first base simulation, deduplicated while keeping first-visit order.
trajs = [
    list(dict.fromkeys(
        head_position["track_section"]
        for head_position in train_result['base_simulations'][0]['head_positions']
    ))
    for train_result in res
]

In [8]:
# Train to inspect (index into `trajs`).
train = 0

# Stations and convergence signals along this train's path: tracks in
# trajectory order, and elements within a track by increasing position.
# NOTE(review): per-track order is ascending position; if the train runs a
# track in the decreasing direction the within-track order is reversed
# relative to travel -- TODO confirm direction handling.
[
    element_id
    for track in trajs[train]
    for element_id, kind in track_sections_elements[track]
    if kind in ('station', 'cvg_signal')
]

['Station_WEST', 'SA1', 'S4', 'Station_EAST']

In [9]:
def draw_digraph(G, ax=None):
    """Draw a directed acyclic graph with one layer per topological generation.

    Node positions come from ``nx.multipartite_layout``, keyed on each node's
    index in ``nx.topological_generations``. Nodes are drawn as squares with
    labels.

    Parameters
    ----------
    G : networkx.DiGraph
        Graph to draw. It is left unmodified; layers are stored on a copy.
        Must be acyclic, because ``topological_generations`` raises
        ``nx.NetworkXUnfeasible`` on a cycle.
    ax : matplotlib Axes, optional
        Axes to draw into. ``None`` (default) lets networkx use the current axes.
    """
    # multipartite_layout reads the layer from a node attribute; set it on a
    # copy so the caller's graph does not silently gain a "layer" attribute.
    layered = G.copy()
    for layer, nodes in enumerate(nx.topological_generations(layered)):
        for node in nodes:
            layered.nodes[node]["layer"] = layer

    pos = nx.multipartite_layout(layered, subset_key="layer")
    nx.draw_networkx(layered, pos, node_shape='s', ax=ax)

In [10]:
from typing import Dict, Any


def num_trains(sim: Dict) -> int:
    """Return how many entries ``sim['train_schedules']`` contains."""
    schedules = sim['train_schedules']
    return len(schedules)