In [None]:
#!/usr/bin/env python3
"""
route_schedule_and_animate.py

Builds graph from Book-(Sheet1).csv (uses 'adjacent_connected_node(s)' column),
reads OD requests from source_destination.csv, finds obtuse-turn-only routes,
schedules trains with safety rules, writes feasible routes + occupancy CSV,
and saves an animation (MP4 or GIF).

Author: ChatGPT (GPT-5 Thinking mini)
"""

import os, csv, math, heapq, random
from datetime import datetime, timedelta
from collections import defaultdict, deque

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.animation as animation

# ----------------------------
# USER PARAMETERS (tweak if needed)
# ----------------------------
# --- input / output files ---------------------------------------------------
NODES_CSV = "Book-(Sheet1).csv"    # node table: nodes, x, y, adjacent_connected_node(s)
OD_CSV = "source_destination.csv"  # OD requests: source, destination (header names flexible)
OUT_ROUTES = "feasible_routes.txt"
OUT_OCC = "occupancy_log.csv"
OUT_ANIM_MP4 = "train_schedule.mp4"
OUT_ANIM_GIF = "train_schedule.gif"

# --- physical model & safety rules ------------------------------------------
SPEED_MPS = 8.0                  # constant running speed, metres per second
HEADWAY_S = 10 * 60              # same-direction time headway (10 minutes)
SAFE_DISTANCE_M = 500.0          # alternative same-direction spacing, metres
TRAIN_LENGTH_M = 400.0           # used to time when the tail clears a segment
NODE_FREE_NORMAL_S = 2 * 60      # node hold after tail clears (degree < 3)
NODE_FREE_SWITCH_S = 5 * 60      # node hold after tail clears (degree >= 3)

# --- search / scheduling limits ---------------------------------------------
MAX_PATHS_PER_OD = 30            # cap on routes enumerated per OD pair
MAX_HOPS = 20                    # cap on route length during the search
MAX_TRAINS_CAP = 500             # hard ceiling on scheduled trains
START_TIME = datetime.now().replace(hour=8, minute=0, second=0, microsecond=0)
START_WINDOW_SECONDS = 60 * 60   # trains may depart within the first hour

# --- animation playback -----------------------------------------------------
ANIM_DURATION_SECONDS = 60.0     # playback length, seconds
FPS = 20

# --- reproducibility --------------------------------------------------------
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

DEBUG = True

# ----------------------------
# Utilities
# ----------------------------
def dist_m(p, q):
    """Return the Euclidean distance between points ``p`` and ``q``.

    Only the first two components of each point are used; the result is in the
    same units as the coordinates.
    """
    dx = p[0] - q[0]
    dy = p[1] - q[1]
    return math.hypot(dx, dy)

def angle_between(prev, cur, nxt):
    """Return the turning (deflection) angle at ``cur``, in degrees.

    This is the angle between the incoming vector (cur - prev) and the outgoing
    vector (nxt - cur): 0 means continuing straight on, 180 means reversing.
    If either vector has zero length the turn is undefined and 180.0 is returned.
    """
    in_dx, in_dy = cur[0] - prev[0], cur[1] - prev[1]
    out_dx, out_dy = nxt[0] - cur[0], nxt[1] - cur[1]
    in_len = math.hypot(in_dx, in_dy)
    out_len = math.hypot(out_dx, out_dy)
    if in_len == 0 or out_len == 0:
        return 180.0
    cosine = (in_dx*out_dx + in_dy*out_dy) / (in_len*out_len)
    # clamp against rounding drift just outside acos's domain
    cosine = max(-1.0, min(1.0, cosine))
    return math.degrees(math.acos(cosine))

# ----------------------------
# Load nodes CSV & build graph
# ----------------------------
if not os.path.exists(NODES_CSV):
    raise FileNotFoundError(f"{NODES_CSV} not found. Put it in the working directory.")


def _split_adjacency(raw):
    """Split one adjacency cell into a list of neighbour node ids.

    Quotes are removed first. If the cell contains a comma or a semicolon, both are
    treated as separators (so a mixed cell like "1;2, 3" splits correctly);
    otherwise the cell is split on whitespace. Empty tokens are dropped.
    """
    san = raw.replace('"', '').replace("'", "")
    if ',' in san or ';' in san:
        parts = san.replace(';', ',').split(',')
    else:
        parts = san.split()
    return [t.strip() for t in parts if t.strip()]


nodes = {}
# utf-8-sig transparently drops a byte-order mark (typical of Excel exports) that
# would otherwise be glued onto the first header and hide the 'nodes' column.
with open(NODES_CSV, newline='', encoding='utf-8-sig') as f:
    reader = csv.DictReader(f)
    for r in reader:
        # `or ''` keeps a row with no id column empty instead of becoming the node "None"
        node = str(r.get('nodes') or r.get('id') or r.get('node') or '').strip()
        if node == "":
            continue
        x = float(r.get('x') or r.get('lon') or r.get('X') or 0.0)
        y = float(r.get('y') or r.get('lat') or r.get('Y') or 0.0)
        adj = r.get('adjacent_connected_node(s)') or r.get('adjacent') or r.get('adj') or ""
        nodes[node] = {'x': x, 'y': y, 'adj': adj}

G = nx.Graph()
for n, v in nodes.items():
    G.add_node(n, x=v['x'], y=v['y'])

# Build edges from the adjacency column. Iterate over a snapshot because unknown
# neighbours are added to `nodes` inside the loop; growing a dict while iterating
# over it raises RuntimeError.
for n, v in list(nodes.items()):
    raw = (v.get('adj') or "").strip()
    if not raw:
        continue
    for tok in _split_adjacency(raw):
        if tok == n:
            # a self-loop would count toward G.degree and could make the node look like a switch
            continue
        if tok not in nodes:
            # placeholder node at 0,0 for a neighbour that has no row of its own
            if DEBUG:
                print(f"Node {n} lists unknown neighbour {tok}; adding placeholder at (0, 0)")
            nodes[tok] = {'x': 0.0, 'y': 0.0, 'adj': ""}
            G.add_node(tok, x=0.0, y=0.0)
        if not G.has_edge(n, tok):
            lm = dist_m((nodes[n]['x'], nodes[n]['y']), (nodes[tok]['x'], nodes[tok]['y']))
            G.add_edge(n, tok, length_m=lm)

if DEBUG:
    print(f"Graph constructed: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

# ----------------------------
# Read OD requests
# ----------------------------
if not os.path.exists(OD_CSV):
    raise FileNotFoundError(f"{OD_CSV} not found. Provide source_destination.csv in working directory.")

od_list = []
# utf-8-sig drops a leading byte-order mark that would otherwise corrupt the first header name
with open(OD_CSV, newline='', encoding='utf-8-sig') as f:
    reader = csv.DictReader(f)
    for r in reader:
        # try common header names first
        s = str(r.get('source') or r.get('from') or r.get('src') or r.get('origin') or '').strip()
        d = str(r.get('destination') or r.get('to') or r.get('dst') or r.get('dest') or '').strip()
        if s == "" or d == "":
            # Unknown headers: fall back to the first two named columns. DictReader files
            # surplus fields under the key None and fills missing fields with None, so skip
            # that key and treat None values as blank.
            keys = [k for k in r.keys() if k is not None]
            if len(keys) >= 2:
                s = str(r[keys[0]] or '').strip()
                d = str(r[keys[1]] or '').strip()
        if not (s and d):
            continue
        if s not in G or d not in G:
            if DEBUG:
                print(f"OD pair {s}->{d} references unknown node; skipping.")
            continue
        if s == d:
            # a route with no segments would schedule a train that never moves
            if DEBUG:
                print(f"OD pair {s}->{d} has identical endpoints; skipping.")
            continue
        od_list.append((s, d))
if not od_list:
    raise RuntimeError("No valid OD pairs found in source_destination.csv")

if DEBUG:
    print(f"Loaded {len(od_list)} OD pairs (requests)")

# ----------------------------
# Feasible routes finder (all simple paths with obtuse-turn constraint)
# We'll do a DFS but cap total paths per OD and max hops to avoid explosion.
# ----------------------------
def find_feasible_routes(G, source, target, max_paths=MAX_PATHS_PER_OD, max_hops=MAX_HOPS, angle_min_deg=90.0):
    """Enumerate simple paths from ``source`` to ``target`` that avoid sharp turns.

    Depth-first search over ``G``. A move prev -> cur -> nxt is allowed only when
    the interior angle at ``cur`` is at least ``angle_min_deg``. The interior angle
    is measured between the track back to ``prev`` and the track on to ``nxt``:
    180 deg is straight running, 90 deg a right-angle turn, and anything smaller an
    acute, reversal-like turn. (Comparing the deflection angle instead, where
    0 = straight, would reject every straight-through move.)

    Args:
        G: graph whose nodes carry 'x'/'y' and whose edges carry 'length_m'.
        source: start node id.
        target: end node id.
        max_paths: stop once this many routes have been found.
        max_hops: maximum number of edges in a route.
        angle_min_deg: minimum interior angle required at every intermediate node.

    Returns:
        A list of routes, each a list of node ids from ``source`` to ``target``.
    """
    def xy(n):
        attrs = G.nodes[n]
        return attrs['x'], attrs['y']

    def turn_ok(prev, cur, nxt):
        # Interior angle at cur between the rays cur->prev and cur->nxt.
        cx, cy = xy(cur)
        px, py = xy(prev)
        qx, qy = xy(nxt)
        back_x, back_y = px - cx, py - cy
        fwd_x, fwd_y = qx - cx, qy - cy
        back_len = math.hypot(back_x, back_y)
        fwd_len = math.hypot(fwd_x, fwd_y)
        if back_len == 0 or fwd_len == 0:
            # Coincident nodes (e.g. placeholders at the origin): the turn is
            # undefined, so it is allowed, as it was before.
            return True
        cosine = max(-1.0, min(1.0, (back_x * fwd_x + back_y * fwd_y) / (back_len * fwd_len)))
        return math.degrees(math.acos(cosine)) + 1e-9 >= angle_min_deg

    routes = []
    stack = [(source, [source])]
    while stack and len(routes) < max_paths:
        cur, path = stack.pop()
        if cur == target:
            routes.append(path)
            continue
        if len(path) - 1 >= max_hops:
            continue  # path already has max_hops edges; extending would exceed the cap
        # The stack is LIFO: push the longest edges first so the shortest is explored next.
        neighbours = sorted(G.neighbors(cur), key=lambda nb: G.edges[cur, nb]['length_m'], reverse=True)
        for nb in neighbours:
            if nb in path:
                continue
            if len(path) >= 2 and not turn_ok(path[-2], cur, nb):
                continue
            stack.append((nb, path + [nb]))
    return routes

# compute feasible routes for each OD
feasible_routes = {}
for od in od_list:
    found = find_feasible_routes(G, od[0], od[1], max_paths=MAX_PATHS_PER_OD, max_hops=MAX_HOPS, angle_min_deg=90.0)
    feasible_routes[od] = found
    if DEBUG:
        print(f"OD {od[0]}->{od[1]}: found {len(found)} feasible routes")

# Human-readable route listing: a header per OD pair, one indented line per route,
# and a blank separator line.
report_lines = []
for (src, dst), found in feasible_routes.items():
    report_lines.append(f"{src} -> {dst} : {len(found)} routes\n")
    report_lines.extend("  " + " -> ".join(route) + "\n" for route in found)
    report_lines.append("\n")
with open(OUT_ROUTES, 'w', encoding='utf-8') as f:
    f.writelines(report_lines)
if DEBUG:
    print(f"Wrote feasible routes to {OUT_ROUTES}")

# ----------------------------
# Scheduling helpers & occupancy structures
# ----------------------------
def canonical_seg(u, v):
    """Return the direction-independent key for segment u-v (endpoints in ascending order)."""
    return (v, u) if v < u else (u, v)

def _empty_lanes():
    """Fresh per-direction reservation lists (+1 = canonical u->v order, -1 = reverse)."""
    return {+1: [], -1: []}


# canonical segment key -> {+1: [(entry_front, tail_clear), ...], -1: [...]}
segment_occupancy = defaultdict(_empty_lanes)
# node -> [(busy_start, busy_end), ...]
node_occupancy = defaultdict(list)

# helper: compute node free buffer depending on degree
def node_free_buffer(node):
    """Seconds a node stays reserved after the tail clears.

    Nodes with graph degree >= 3 (junctions/switches) get NODE_FREE_SWITCH_S;
    all others get NODE_FREE_NORMAL_S.
    """
    is_switch = G.degree(node) >= 3
    if is_switch:
        return NODE_FREE_SWITCH_S
    return NODE_FREE_NORMAL_S

# check segment placement
def can_place_segment(seg, direction, entry_front, tail_clear):
    """Check whether a train may occupy ``seg`` from ``entry_front`` until ``tail_clear``.

    Same direction: an existing train that entered earlier is fine if the new train
    follows by at least HEADWAY_S, or if the front-to-front spacing
    (SPEED_MPS * gap) is at least SAFE_DISTANCE_M. If the new train would enter
    before an existing same-direction train, it must be fully clear
    (tail_clear <= that train's entry).

    Opposite direction: [entry_front, tail_clear] must not overlap any existing
    opposite-direction interval (touching endpoints are allowed).
    """
    lanes = segment_occupancy[seg]

    for prior_entry, _prior_clear in lanes[direction]:
        gap = (entry_front - prior_entry).total_seconds()
        if gap < 0:
            # new train would lead an already-reserved one: require full clearance
            if tail_clear > prior_entry:
                return False
            continue
        headway_ok = gap >= HEADWAY_S
        spacing_ok = SPEED_MPS * gap >= SAFE_DISTANCE_M
        if not (headway_ok or spacing_ok):
            return False

    for opp_entry, opp_clear in lanes[-direction]:
        if tail_clear > opp_entry and entry_front < opp_clear:
            return False

    return True

def can_place_node(node, arrival_center, tail_clear):
    """Check that ``node`` is free from ``arrival_center`` until tail clear plus its buffer.

    The candidate busy interval is [arrival_center, tail_clear + node_free_buffer(node)];
    it may touch, but not overlap, any interval already reserved at the node.
    """
    release = tail_clear + timedelta(seconds=node_free_buffer(node))
    return all(release <= held_from or arrival_center >= held_until
               for held_from, held_until in node_occupancy[node])

def add_segment_occupancy(seg, direction, entry_front, tail_clear):
    """Record that ``seg`` is occupied in ``direction`` from ``entry_front`` to ``tail_clear``."""
    lane = segment_occupancy[seg][direction]
    lane.append((entry_front, tail_clear))

def add_node_occupancy(node, arrival_center, tail_clear):
    """Reserve ``node`` from ``arrival_center`` until ``tail_clear`` plus the node's free buffer."""
    release = tail_clear + timedelta(seconds=node_free_buffer(node))
    node_occupancy[node].append((arrival_center, release))

# ----------------------------
# Scheduling algorithm (greedy, round-robin over OD list until no new starts fit into 1 hour)
# ----------------------------
# Each scheduled train: {'od': (s, d), 'route': [nodes], 'start_time': datetime,
#                        'edge_times': [(u, v, entry_front, tail_clear, dir_sign), ...]}
scheduled = []

# Per-route precomputation keyed by (source, destination, route_index):
# node path, (u, v, length_m) edges, total length, and per-edge travel seconds.
route_infos = {}
for (src, dst), paths in feasible_routes.items():
    for route_idx, path in enumerate(paths):
        edge_list = [(a, b, G.edges[a, b]['length_m']) for a, b in zip(path, path[1:])]
        total_m = 0.0
        for _, _, seg_len in edge_list:
            total_m += seg_len
        route_infos[(src, dst, route_idx)] = {
            'path': path,
            'edges': edge_list,
            'length_m': total_m,
            'durations': [seg_len / SPEED_MPS for _, _, seg_len in edge_list],
        }

# helper to check & place a train on a given route starting at t0 (if feasible and start within window)
def try_place_train(route_key, start_t0, allow_start_window_end):
    """Try to reserve one train on route ``route_key`` departing at ``start_t0``.

    The train runs at SPEED_MPS over every edge of the route. Each segment is
    held from the front's entry until the tail clears (TRAIN_LENGTH_M behind the
    front). Each node reached is held from the front's arrival until tail clear
    plus its node buffer, and the source node is held from ``start_t0``.

    Returns:
        (True, edge_times) after reserving every segment and node, where
        edge_times = [(u, v, entry_front, tail_clear, dir_sign), ...]; or
        (False, None), reserving nothing, if any check fails or the departure is
        later than ``allow_start_window_end``.
    """
    info = route_infos[route_key]
    origin = info['path'][0]
    tail_lag = timedelta(seconds=(TRAIN_LENGTH_M / SPEED_MPS))

    # Forward pass: timetable for every leg plus the front's arrival at each next node.
    legs = []      # (u, v, entry_front, exit_front, tail_clear, dir_sign)
    arrivals = []  # (node, arrival_center, tail_clear)
    clock = start_t0
    for (u, v, _length), secs in zip(info['edges'], info['durations']):
        sign = +1 if (u, v) == canonical_seg(u, v) else -1
        exit_front = clock + timedelta(seconds=secs)
        clear = exit_front + tail_lag
        legs.append((u, v, clock, exit_front, clear, sign))
        arrivals.append((v, exit_front, clear))
        clock = exit_front

    # Feasibility checks: origin node, then segments in order, then downstream nodes.
    origin_clear = start_t0 + tail_lag
    if not can_place_node(origin, start_t0, origin_clear):
        return False, None
    for u, v, entry, _exit_front, clear, sign in legs:
        if u == origin and entry > allow_start_window_end:
            return False, None
        if not can_place_segment(canonical_seg(u, v), sign, entry, clear):
            return False, None
    if not all(can_place_node(node, arrived, clear) for node, arrived, clear in arrivals):
        return False, None

    # Everything fits: commit the reservations.
    add_node_occupancy(origin, start_t0, origin_clear)
    for u, v, entry, _exit_front, clear, sign in legs:
        add_segment_occupancy(canonical_seg(u, v), sign, entry, clear)
    for node, arrived, clear in arrivals:
        add_node_occupancy(node, arrived, clear)

    return True, [(u, v, entry, clear, sign) for u, v, entry, _exit_front, clear, sign in legs]

# round-robin scheduling until no new train placed in a full cycle
# Greedy scheduling: repeatedly sweep all (OD, route) combinations, placing at most one
# train per route per sweep at the earliest feasible start (scanned in 60 s steps from
# START_TIME up to start_window_end), until a sweep places nothing or the cap is reached.
start_window_end = START_TIME + timedelta(seconds=START_WINDOW_SECONDS)
time_step = timedelta(seconds=60)

# All (source, destination, route_index) keys, shortest route first.
od_route_keys = [(s, d, idx) for (s, d) in od_list for idx in range(len(feasible_routes[(s, d)]))]
od_route_keys.sort(key=lambda k: route_infos[k]['length_m'])

# Reservations only ever accumulate, so a start time that is infeasible now stays
# infeasible. Each route can therefore resume its scan at its last successful start,
# and a route with no feasible start left in the window can be skipped for good.
# This yields exactly the same schedule as rescanning from START_TIME every time.
resume_from = {}
exhausted = set()

placed_any = True
while placed_any and len(scheduled) < MAX_TRAINS_CAP:
    placed_any = False
    for route_key in od_route_keys:
        if len(scheduled) >= MAX_TRAINS_CAP:
            break
        if route_key in exhausted:
            continue
        s, d, idx = route_key
        t_try = resume_from.get(route_key, START_TIME)
        while t_try <= start_window_end:
            ok, edge_times = try_place_train(route_key, t_try, start_window_end)
            if ok:
                scheduled.append({'od': (s, d), 'route': route_infos[route_key]['path'], 'start_time': t_try, 'edge_times': edge_times})
                placed_any = True
                resume_from[route_key] = t_try
                if DEBUG:
                    print(f"Placed train #{len(scheduled)} {s}->{d} start={t_try} route_hops={len(route_infos[route_key]['path'])} len={route_infos[route_key]['length_m']:.1f}m")
                break
            t_try += time_step
        else:
            # no feasible start anywhere in the window; it can never become feasible again
            exhausted.add(route_key)
    if DEBUG:
        print(f"Iteration complete: scheduled_count={len(scheduled)}")

if DEBUG:
    print(f"Scheduling finished; total trains scheduled: {len(scheduled)}")

# ----------------------------
# Build per-second occupancy log for animation & CSV
# ----------------------------
if not scheduled:
    print("No trains scheduled. Exiting.")
    # exit() is a site-module convenience; under Jupyter/IPython it does not stop the
    # running cell (execution previously carried on past this point). SystemExit does.
    raise SystemExit(0)

# Simulation end = latest tail-clear over all trains (trains may finish after the start window).
sim_end = START_TIME
for tr in scheduled:
    # default guards a route with no edges
    last_tail = max((et[3] for et in tr['edge_times']), default=tr['start_time'])
    sim_end = max(sim_end, last_tail)
total_seconds = int(max(1, (sim_end - START_TIME).total_seconds()))
timestamps = [START_TIME + timedelta(seconds=i) for i in range(total_seconds+1)]

# helper to get front position for a train at time t
def train_front_position(train, t):
    """Locate the front of ``train`` at time ``t``.

    Returns:
        None before the train departs.
        (x, y, 'finished', None, None, None) at the final node once the last
        tail has cleared.
        Otherwise (x, y, 'moving', (u, v), entry_front, tail_clear) for the edge the
        front is currently on, with (x, y) linearly interpolated along it. Between
        the front reaching the final node and the tail clearing, the last edge is
        reported with the front at its end.
    """
    if t < train['start_time']:
        return None
    edge_times = train['edge_times']
    # default guards a route with no edges
    last_tail = max((et[3] for et in edge_times), default=train['start_time'])
    if t >= last_tail:
        last_node = train['route'][-1]
        return (G.nodes[last_node]['x'], G.nodes[last_node]['y'], 'finished', None, None, None)

    def on_edge(u, v, frac, entry_front, tail_clear):
        x_u, y_u = G.nodes[u]['x'], G.nodes[u]['y']
        x_v, y_v = G.nodes[v]['x'], G.nodes[v]['y']
        x = x_u + (x_v - x_u) * frac
        y = y_u + (y_v - y_u) * frac
        return (x, y, 'moving', (u, v), entry_front, tail_clear)

    # Edges are contiguous in time (each entry equals the previous exit), so the
    # front is on the first edge whose exit it has not yet passed. Selecting by
    # tail_clear instead would park the front at each node until the tail cleared.
    current = None
    for (u, v, entry_front, tail_clear, _dir_sign) in edge_times:
        if t < entry_front:
            break
        current = (u, v, entry_front, tail_clear)
        travel_sec = G.edges[u, v]['length_m'] / SPEED_MPS
        exit_front = entry_front + timedelta(seconds=travel_sec)
        if t <= exit_front:
            denom = (exit_front - entry_front).total_seconds()
            frac = (t - entry_front).total_seconds() / denom if denom > 0 else 1.0
            return on_edge(u, v, frac, entry_front, tail_clear)
    if current is None:
        return None
    # front has reached the end of the last entered edge; tail still clearing
    u, v, entry_front, tail_clear = current
    return on_edge(u, v, 1.0, entry_front, tail_clear)

# build records
# One record per (second, active train) with the train's front position and state.
records = []
for t in timestamps:
    for train_id, train in enumerate(scheduled, start=1):
        meta = train_front_position(train, t)
        if meta is None:
            continue
        # exit_t (not `exit`) so the builtin is not shadowed for later cells / re-runs
        x, y, state, _edge_id, entry_t, exit_t = meta
        records.append({
            'timestamp': t.isoformat(),
            'train_id': train_id,
            'x_m': x,
            'y_m': y,
            'node': train['route'][-1] if state == 'finished' else '',
            'state': state,
            'entry_time': entry_t.isoformat() if entry_t is not None else '',
            'exit_time': exit_t.isoformat() if exit_t is not None else '',
        })

# write occupancy CSV
fieldnames = ['timestamp','train_id','x_m','y_m','node','state','entry_time','exit_time']
with open(OUT_OCC, 'w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(records)
if DEBUG:
    print(f"Wrote occupancy log to {OUT_OCC} (records: {len(records)})")

print(f"Maximum number of trains scheduled safely (within 1 hour starts): {len(scheduled)}")

# ----------------------------
# Prepare animation (compress sim timeline to playback duration)
# ----------------------------
# plotting coordinates are original x,y
# Node coordinates, reused for both the axis bounds and network drawing.
pos = {n: (G.nodes[n]['x'], G.nodes[n]['y']) for n in G.nodes}
x_vals = [xy[0] for xy in pos.values()]
y_vals = [xy[1] for xy in pos.values()]
xmin, xmax = min(x_vals), max(x_vals)
ymin, ymax = min(y_vals), max(y_vals)
# 5% margin around the network, never less than 1 unit
xpad = max(1.0, 0.05 * (xmax - xmin + 1.0))
ypad = max(1.0, 0.05 * (ymax - ymin + 1.0))

fig, ax = plt.subplots(figsize=(12, 6))
ax.set(
    xlim=(xmin - xpad, xmax + xpad),
    ylim=(ymin - ypad, ymax + ypad),
    aspect='equal',
    title="Scheduled Trains Animation",
)

# static background: nodes, then track segments, then node labels
nx.draw_networkx_nodes(G, pos=pos, node_size=80, node_color='lightgray', ax=ax)
nx.draw_networkx_edges(G, pos=pos, ax=ax, edge_color='gray')
nx.draw_networkx_labels(G, pos=pos, font_size=8, ax=ax)

# build frames
# The whole simulated span is compressed into ANIM_DURATION_SECONDS of playback at FPS.
FRAMES = max(1, int(ANIM_DURATION_SECONDS * FPS))
INTERVAL_MS = int(1000 / FPS)

sim_start = timestamps[0]
sim_end_dt = timestamps[-1]
sim_total_seconds = (sim_end_dt - sim_start).total_seconds() if sim_end_dt > sim_start else 1.0
# Guard the divisor: with a single frame, FRAMES - 1 would be zero.
_frame_span = max(FRAMES - 1, 1)
frame_sim_times = [sim_start + timedelta(seconds=(i / _frame_span) * sim_total_seconds) for i in range(FRAMES)]

# index records by second for quick lookup
records_by_dt = defaultdict(list)
for r in records:
    dt = datetime.fromisoformat(r['timestamp']).replace(microsecond=0)
    records_by_dt[dt].append(r)


def find_frame_records(sim_t):
    """Return the records logged for the whole second containing ``sim_t``.

    Falls back to the previous, then the next, second; returns [] if neither has records.
    """
    k = sim_t.replace(microsecond=0)
    if k in records_by_dt:
        return records_by_dt[k]
    for d in (k - timedelta(seconds=1), k + timedelta(seconds=1)):
        if d in records_by_dt:
            return records_by_dt[d]
    return []


# assign colors to trains (tab20 resampled to one entry per train)
train_ids = sorted({r['train_id'] for r in records})
n_colors = max(1, len(train_ids))
cmap = plt.get_cmap('tab20', n_colors)
train_colors = {tid: cmap(i % n_colors) for i, tid in enumerate(train_ids)}

scatter = ax.scatter([], [], s=100, zorder=5)
labels = {tid: ax.text(0,0,"", fontsize=8, weight='bold') for tid in train_ids}

def update(frame_idx):
    """Draw animation frame ``frame_idx``: move train markers and labels to their logged positions.

    Returns the artists touched (scatter plus all labels) as FuncAnimation expects.
    """
    sim_t = frame_sim_times[frame_idx]
    frame_records = find_frame_records(sim_t)

    # hide every label; only trains present this second get theirs back
    for lab in labels.values():
        lab.set_text("")

    offsets, colours = [], []
    for rec in frame_records:
        tid = rec['train_id']
        px, py = float(rec['x_m']), float(rec['y_m'])
        colour = train_colors[tid]
        offsets.append((px, py))
        colours.append(colour)
        lab = labels[tid]
        lab.set_position((px + 0.5, py + 0.5))
        lab.set_text(str(tid))
        lab.set_color(colour)

    scatter.set_offsets(np.array(offsets) if offsets else np.empty((0, 2)))
    scatter.set_color(colours)
    ax.set_title(f"Simulation @ {sim_t.strftime('%Y-%m-%d %H:%M:%S')}   Scheduled: {len(scheduled)}")
    return (scatter, *labels.values())

ani = animation.FuncAnimation(fig, update, frames=FRAMES, interval=INTERVAL_MS, blit=False, repeat=False)

# Choose the writer explicitly. Without ffmpeg, matplotlib silently falls back to
# Pillow, which cannot encode .mp4 ("unknown file extension: .mp4"), so only attempt
# MP4 when ffmpeg is actually available and otherwise go straight to GIF.
saved = False
if animation.writers.is_available('ffmpeg'):
    try:
        print("Saving animation to MP4 (this may take a moment)...")
        ani.save(OUT_ANIM_MP4, writer='ffmpeg', fps=FPS, dpi=150)
        print(f"Saved MP4 to {OUT_ANIM_MP4}")
        saved = True
    except Exception as e:
        print("MP4 save failed:", str(e))
else:
    print("ffmpeg not available; skipping MP4.")

if not saved:
    try:
        print("Saving animation to GIF (fallback)...")
        ani.save(OUT_ANIM_GIF, writer='pillow', fps=FPS, dpi=100)
        print(f"Saved GIF to {OUT_ANIM_GIF}")
        saved = True
    except Exception as e2:
        print("GIF save failed:", str(e2))
        print("Animation will still be shown interactively (if running in environment that supports it).")

plt.show()


MovieWriter ffmpeg unavailable; using Pillow instead.


Graph constructed: 62 nodes, 73 edges
Loaded 4 OD pairs (requests)
OD 2->58: found 0 feasible routes
OD 3->58: found 0 feasible routes
OD 59->1: found 0 feasible routes
OD 59->3: found 0 feasible routes
Wrote feasible routes to feasible_routes.txt
Iteration complete: scheduled_count=0
Scheduling finished; total trains scheduled: 0
No trains scheduled. Exiting.
Wrote occupancy log to occupancy_log.csv (records: 0)
Maximum number of trains scheduled safely (within 1 hour starts): 0
Saving animation to MP4 (this may take a moment)...


MovieWriter ffmpeg unavailable; using Pillow instead.


MP4 save failed: unknown file extension: .mp4
Saving animation to GIF (fallback)...
