In [1]:
# Fetch the Git LFS-tracked files (presumably the ../data assets read further down — confirm); shell output is discarded.
!git lfs pull > /dev/null

In [2]:
import datetime
import re

import ipywidgets as widgets
import ipymaterialui as mui
import networkx as nx
import pandas as pd

from orientexpress.planner import RoutePlanner
from orientexpress.utils import print_timestamp
from orientexpress.delay import Delay

# Automatically re-import edited modules (e.g. the orientexpress package) before executing code.
%load_ext autoreload
%autoreload 2

In [3]:
# Accent color, used as the background of the "Get me a trip" button.
COLOR = "#e5191d"

# Read graph and stations data
# stops: station table; the notebook relies on its `stop_id` and `stop_name` columns.
stops = pd.read_parquet('../data/stops.parquet.gz')
# Transport graph handed to RoutePlanner.
# NOTE(review): nx.read_gpickle was removed in NetworkX 3.0; on newer versions load this file with
# pickle.load instead. Unpickling can execute arbitrary code, so only load trusted files.
G = nx.read_gpickle('../data/sbb_graph.gpickle')
# Delay model passed to RoutePlanner.plan_robust when planning robust routes.
delay = Delay()

# Animated GIF shown while a route is being computed.
with open("../data/pikachu_loading.gif", "rb") as f:
    loading_img = f.read()

pikachu_loading = widgets.Image(value=loading_img, format='gif')

# Banner image; being the last expression, it is displayed as this cell's output.
with open("../data/orient-express.png", "rb") as f:
    img = f.read()

widgets.Image(value=img, layout=widgets.Layout(width="99%"))

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x06@\x00\x00\x01\xf4\x08\x06\x00\x00\x00e\xf0\xe1J\x…

In [4]:
# One entry per distinct station name, labelled "<stop_name> (<stop_id>)" using the shortest
# stop_id found for that name (first occurrence on ties). extract_id() later parses the id back.
shortest_stop_id = stops.groupby('stop_name')['stop_id'].apply(lambda ids: min(ids, key=len))
stations_list = [f"{name} ({stop_id})" for name, stop_id in shortest_stop_id.items()]


def _station_combobox(placeholder, description):
    """Build a station picker that only accepts entries of ``stations_list``."""
    return widgets.Combobox(placeholder=placeholder,
                            options=stations_list,
                            description=description,
                            ensure_option=True,
                            layout=widgets.Layout(padding='2rem 1rem', width="auto", height='100px'))


start = _station_combobox('Choose a Starting Point', 'Start station')
end = _station_combobox('Choose a Destination', 'Destination')

time_picker = mui.TextField(type="time", label="Arrival time", value="12:30",
                            style_={"width": "100px", "margin": "1rem 2rem"})

# Confidence is entered in percent; the '%' label is pulled left so it sits next to the slider.
confidence = widgets.FloatSlider(value=90, min=50, max=100, step=0.1, description="Confidence",
                                 layout=widgets.Layout(padding='2rem 0 2rem 1rem', height="100px"))
percent_label = widgets.Label('%', layout=widgets.Layout(padding='2.4rem 1rem 2rem 0',
                                                         margin="0 0 0 -1.2rem",
                                                         height="100px"))

widgets_group = widgets.HBox((start, end, time_picker, confidence, percent_label),
                             layout=widgets.Layout(width="80%", margin="auto"))
widgets_group

HBox(children=(Combobox(value='', description='Start station', ensure_option=True, layout=Layout(height='100px…

In [5]:
def get_station_name(stop_id):
    """Return the ``stop_name`` of the first ``stops`` row whose ``stop_id`` equals ``stop_id``.

    Raises IndexError when no row matches.
    """
    matching_names = stops.loc[stops.stop_id == stop_id, 'stop_name']
    return matching_names.iloc[0]

def get_pretty_time(s):
    """Format a duration given in seconds as a compact string.

    ``s`` is rounded to whole seconds first. Examples: ``3725 -> 1h02'05"``,
    ``65 -> 1'05"``, ``5 -> 5"``.

    Negative durations (e.g. a negative expected delay) are formatted by
    magnitude with a leading ``-`` (``-65 -> -1'05"``). Floor-based ``divmod``
    on the raw negative value would otherwise produce nonsense such as
    ``58'55"``.
    """
    total = int(round(s))
    sign = '-' if total < 0 else ''
    minutes, seconds = divmod(abs(total), 60)
    hours, minutes = divmod(minutes, 60)
    if hours > 0:
        return f'{sign}{hours}h{minutes:02d}\'{seconds:02d}"'
    if minutes > 0:
        return f'{sign}{minutes}\'{seconds:02d}"'
    return f'{sign}{seconds}"'

def get_small_widget(text):
    """Wrap ``text`` in a ``<span>`` rendered with a 10px font."""
    small_font = {"font-size": "10px"}
    return mui.Html(tag="span", children=[text], style_=small_font)

def get_station_id_widget(station_id):
    """Small-font widget showing `` (<station_id>)``, appended after a station name."""
    id_label = f" ({station_id})"
    return get_small_widget(id_label)

def get_arrival_time_widget(arr_time):
    """Italic, slightly smaller widget showing ``arr_time`` formatted by ``print_timestamp``."""
    formatted_time = print_timestamp(arr_time)
    return mui.Html(tag='i', style_={'font-size': '0.9em'}, children=[formatted_time])

def get_arrival_station_widget(station_id):
    """Italic, slightly smaller widget with the station's name followed by its id."""
    station_label = [get_station_name(station_id), get_station_id_widget(station_id)]
    return mui.Html(tag='i', style_={'font-size': '0.9em'}, children=station_label)

def get_delay_widget(delay, is_last):
    """Build the children of the 'Delay' cell for one leg.

    ``delay`` is either falsy (no delay information: an empty list is returned)
    or an ``(expected_delay, proba)`` pair. ``expected_delay`` is rendered with
    get_pretty_time, i.e. as seconds. ``proba`` is shown as a percentage and is
    labelled "Before deadline" when ``is_last`` is true, "Catch next" otherwise.
    """
    if delay:
        expected_delay, proba = delay
        text = f"Before deadline: {proba:.2%}" if is_last else f"Catch next: {proba:.2%}"
        expected_line = mui.Html(tag='i', children=[f"Expected: {get_pretty_time(expected_delay)}"])
        return [expected_line,
                mui.Html(tag='br'),
                mui.Html(tag='span', style_={'font-size': '0.9em'}, children=[text])]
    return []

def is_last(i, path, route):
    """Return True when edge ``i`` of ``path`` is the journey's last vehicle leg.

    That is either the final edge (``i == len(path) - 2``), or the second-to-last
    edge when the final edge is a walking transfer (its ``trip_id`` starts with
    "TRANSFER"). The route lookup only happens in the second case.
    """
    n_stops = len(path)
    if i == n_stops - 2:
        return True
    if i == n_stops - 3:
        final_edge = route[path[-2]][path[-1]]
        return final_edge["trip_id"].startswith("TRANSFER")
    return False

GRAY = "#f8f8f8"


def _transfer_row(station1, trip_data, is_first):
    """Build the gray 'Walk' row for a walking-transfer edge."""
    time_cell = mui.TableCell(children=[mui.Html(tag='i', children=[get_pretty_time(trip_data['t_time'])]),
                                        mui.Html(tag='br'), 'Leave latest: ', print_timestamp(trip_data['dep_time'])])
    # Only a walk that opens the journey names its origin; later walks start where the previous row ended.
    station_extra = [f" from {get_station_name(station1)}", get_station_id_widget(station1)] if is_first else []
    station_cell = mui.TableCell(children=[mui.Icon(children="directions_walk", style_={'position': 'relative', 'top': '5px'}),
                                           " Walk", *station_extra])
    return mui.TableRow(style_={"background-color": GRAY}, children=[time_cell, station_cell, mui.TableCell()])


def _trip_row(i, path, route, boards_here):
    """Build the row for vehicle edge ``i``; a bold, detailed row when the trip is boarded on this edge."""
    station1, station2 = path[i], path[i + 1]
    trip_data = route[station1][station2]
    trip_info = trip_data['trip_info']
    if boards_here:
        time_cell = mui.TableCell(children=[mui.Html(tag='b', children=[print_timestamp(trip_data['dep_time'])]),
                                            mui.Html(tag='br'), mui.Html(tag='br'),
                                            get_arrival_time_widget(trip_info['arr_time'])])
        station_cell = mui.TableCell(children=[mui.Html(tag='b', children=[get_station_name(station1), get_station_id_widget(station1)]),
                                               mui.Html(tag='br'),
                                               mui.Html(tag='span', style_={'padding-left': '1em', 'font-size': '0.9em'}, children=[
                                                   f"{trip_info['route_desc']} {trip_info['route_short_name']} headed to {trip_info['trip_headsign']} ",
                                                   get_small_widget(f"({trip_data['trip_id']})")
                                               ]),
                                               mui.Html(tag='br'),
                                               get_arrival_station_widget(station2)])
    else:
        # Staying on the same trip: a plainer row with just the intermediate stop.
        time_cell = mui.TableCell(children=[print_timestamp(trip_data['dep_time']),
                                            mui.Html(tag='br'), get_arrival_time_widget(trip_info['arr_time'])])
        station_cell = mui.TableCell(children=[get_station_name(station1), get_station_id_widget(station1),
                                               mui.Html(tag='br'),
                                               get_arrival_station_widget(station2)])
    delay_cell = mui.TableCell(children=get_delay_widget(trip_info.get("delay"), is_last(i, path, route)))
    return mui.TableRow(children=[time_cell, station_cell, delay_cell])


def _summary_row(path, proba, dep_time, arr_time):
    """Build the closing row: total travel time, destination and overall success probability."""
    # No edge was traversed (path shorter than two stops): report a zero duration instead of crashing.
    if dep_time is None or arr_time is None:
        total_travel_time = 0
    else:
        total_travel_time = (arr_time - dep_time).total_seconds()
    final_proba = mui.TableCell(children=[f"Success probability: {proba:.2%}"])
    return mui.TableRow(children=[mui.TableCell(children=[f"Total time: {get_pretty_time(total_travel_time)}"]),
                                  mui.TableCell(children=[mui.Html(tag='b', children=[get_station_name(path[-1]),
                                                                                      get_station_id_widget(path[-1])])]),
                                  final_proba])


def get_rows(path, route, proba):
    """Build the itinerary table rows for a planned journey.

    Parameters
    ----------
    path : sequence of stop ids, from departure to destination.
    route : nested mapping; ``route[a][b]`` is the edge dict for the hop ``a -> b``.
        Walking edges have a ``trip_id`` starting with "TRANSFER", a ``dep_time``
        and ``t_time`` (walking time in seconds). Vehicle edges have ``trip_id``,
        ``dep_time`` and a ``trip_info`` dict (``arr_time``, ``route_desc``,
        ``route_short_name``, ``trip_headsign``, optional ``delay``).
    proba : overall success probability, shown in the last row.

    Returns
    -------
    list of ``mui.TableRow``: one row per edge, followed by a summary row.
    """
    rows = []
    current_trip_id = ''
    dep_time, arr_time = None, None
    for i, (station1, station2) in enumerate(zip(path, path[1:])):
        trip_data = route[station1][station2]
        if i == 0:
            dep_time = trip_data['dep_time']
        if trip_data['trip_id'].startswith("TRANSFER"):
            rows.append(_transfer_row(station1, trip_data, is_first=(i == 0)))
            current_trip_id = ''
            walk = datetime.timedelta(seconds=trip_data['t_time'])
            # A walk before any vehicle has no previous arrival to extend; anchor it on its own
            # departure so walking-only journeys still get a total travel time.
            arr_time = (trip_data['dep_time'] if arr_time is None else arr_time) + walk
        else:
            arr_time = trip_data['trip_info']['arr_time']
            boards_here = trip_data['trip_id'] != current_trip_id
            current_trip_id = trip_data['trip_id']
            rows.append(_trip_row(i, path, route, boards_here))
    rows.append(_summary_row(path, proba, dep_time, arr_time))
    return rows

def draw_table(rows):
    """Wrap ``rows`` in a table whose header reads Time / Station / Delay."""
    header_cells = [mui.TableCell(children=[title]) for title in ('Time', 'Station', 'Delay')]
    header = mui.TableHead(children=[mui.TableRow(children=header_cells)])
    body = mui.TableBody(children=rows)
    return mui.Table(children=[header, body])

In [6]:
# Results container: process() replaces its children with a loading animation,
# an error banner or the itinerary.
output = mui.Html(
    tag='div',
    children='',
    style_={"width": "80%", "margin": "auto"},
)

def get_error_widget(text):
    """Return a red, rounded banner showing an 'error_outline' icon followed by ``text``."""
    banner_style = {
        "width": "80%", "margin": "auto",
        "color": "#D8000C", "background-color": "#FFBABA",
        "border": "2px solid", "border-radius": "10px",
        "padding": "0.5em", "font-size": "1.1em",
        # flex layout keeps the icon vertically aligned with the message
        "display": "flex", "align-items": "center",
    }
    error_icon = mui.Icon(children="error_outline", style_={"margin-right": "0.25em"})
    return mui.Html(tag='div', style_=banner_style, children=[error_icon, text])

def extract_id(text):
    """Extract the stop id from a station label of the form ``"<stop_name> (<stop_id>)"``.

    The id is taken from the parenthesised group at the *end* of the label. As a
    result, station names that themselves contain parentheses do not yield part
    of the name instead of the id.

    Returns ``None`` when ``text`` is not a string (e.g. ``None``) or has no
    trailing ``(...)`` group.
    """
    if not isinstance(text, str):
        return None
    match = re.search(r'\(([^()]+)\)\s*$', text)
    return match.group(1) if match else None

def process(widget, event, data):
    """Click handler of the 'Get me a trip' button.

    Reads the station comboboxes, the arrival-time picker and the confidence
    slider, and plans a route with ``RoutePlanner.plan_robust``. It then renders
    either an error banner or the itinerary table into ``output``.

    ``widget``, ``event`` and ``data`` are the arguments supplied by
    ``on_event('onClick', ...)``; they are not used.
    """
    # The slider is expressed in percent; the planner works with a probability.
    threshold = confidence.value / 100
    if threshold >= 1:
        output.children = get_error_widget("You can never be 100% sure to be always on time 😉")
        return

    start_id = extract_id(start.value)
    end_id = extract_id(end.value)
    if not start_id or not end_id:
        output.children = get_error_widget(f"Illegal start ({start_id}) or end station ({end_id})")
        return

    # Show the loading animation while the planner runs.
    output.children = mui.Html(tag="div", children=[pikachu_loading], style_={"width" :"100px", "margin":"auto"})

    try:
        planner = RoutePlanner(G, start_id, end_id, time_picker.value)
        path, route, proba = planner.plan_robust(threshold, delay, stops, info=True, print_path=False)
        if not path or not route:
            msg = f"No path found from {get_station_name(start_id)} to {get_station_name(end_id)}"
            output.children = get_error_widget(msg)
            return
        # plan_robust can return its best attempt even when it misses the threshold: warn, but still show it.
        not_robust_enough = [get_error_widget("No robust enough path found within 20 iterations. Try to reduce the confidence. Here's the most robust path we found for you.")]\
                            if proba < threshold else []
        header = mui.Html(tag="h3", children=[f"From {get_station_name(start_id)} ",
                                              get_station_id_widget(start_id),
                                              f" to {get_station_name(end_id)} ",
                                              get_station_id_widget(end_id),
                                              f" by {time_picker.value} with probability bigger than {confidence.value:.2f}%"])
        output.children = not_robust_enough + [header, draw_table(get_rows(path, route, proba))]
    except Exception as exc:
        # Replace the loading animation, which would otherwise stay on screen forever: exceptions
        # raised in widget callbacks are typically only logged, not shown in the cell. Re-raise so
        # the traceback still reaches the log.
        output.children = get_error_widget(f"Route planning failed: {exc}")
        raise

# Primary action button; its click is handled by process().
execute = mui.Button(
    center_ripple=True,
    children=['Get me a trip', mui.Icon(children='commute')],
    variant="contained",
    style_={"float": "right", "background-color": COLOR, "color": "white"},
)
execute.on_event('onClick', process)

# Wrapper aligning the button with the 80%-wide widget row above.
layout_button = mui.Html(
    tag="div",
    children=execute,
    style_={"width" :"80%", "margin":"auto"},
)
layout_button

Html(children=Button(center_ripple=True, children=['Get me a trip', Icon(children='commute')], style_={'float'…

In [7]:
# Visual separator between the action button and the results area.
mui.Divider()

Divider()

In [8]:
# Display the results container; process() updates its children on each click.
output

Html(children='', style_={'width': '80%', 'margin': 'auto'}, tag='div')