In [1]:
import pandas as pd
import numpy as np
import networkx as nx
import copy
from src.util import *
from src.route_planning import *
from src.transport_network import *
from src.delay_model import *
from sanity_test.graph_test import *

In [2]:
import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
import plotly.express as px

In [3]:
# Build the transport network from the timetable, stop-to-stop and stops files,
# then expose its two tables and the stop id <-> stop name lookups that the
# cells below read as notebook-level globals.
sbb_network = TransportNetwork(
    'data/sbb_timetable_stop_times.parquet',
    'data/stop_to_stop.csv',
    'data/stops.csv',
)
stops = sbb_network.stops
stop_to_stop = sbb_network.stop_to_stop

# Both lookups keep the last row when a key occurs more than once.
id_to_stop = dict(zip(stops['stop_id'], stops['stop_name']))
stop_to_id = dict(zip(stops['stop_name'], stops['stop_id']))

In [4]:
# Per-stop record keyed by stop_id:
#   {stop_id: {'stop_lat': ..., 'stop_lon': ..., 'stop_name': ...}}
stop_info = (
    stops
    .set_index('stop_id')
    .loc[:, ['stop_lat', 'stop_lon', 'stop_name']]
    .to_dict(orient='index')
)

In [5]:
# Process data/sbb_real_stop_times.parquet once and keep the result in a
# notebook-level global: print_paths() and on_button_clicked() below pass
# `grouped_istdaten` to route_confidence().
# Original note: only one run is needed for one objective id.
# NOTE(review): "objective id" is not defined anywhere in this notebook --
# confirm what it refers to and whether this cell must be re-run when it changes.
grouped_istdaten = process_istdaten_data("data/sbb_real_stop_times.parquet")

In [6]:
def count_transfers(path):
    """Return how many steps of ``path`` are transfer steps.

    A step counts as a transfer when its first element (the node id) ends
    with the suffix ``'-transfer'``.
    """
    transfer_steps = [step for step in path if step[0].endswith('-transfer')]
    return len(transfer_steps)


def calculate_walking_time(path):
    """Return the total walking time along ``path`` as an ``int``.

    Each step is a ``(node, predecessor, time)`` triple. A step contributes
    only when its predecessor is not ``None`` and the predecessor's third
    element is ``'walking'``; the contribution is the step's ``time`` minus
    ``time_to_minutes2(predecessor[1])``.
    NOTE(review): presumably ``predecessor[1]`` is the departure time of the
    walking leg and the result is in minutes -- confirm against
    ``time_to_minutes2``.
    """
    total = sum(
        time - time_to_minutes2(predecessor[1])
        for _node, predecessor, time in path
        if predecessor is not None and predecessor[2] == 'walking'
    )
    return int(total)


def _mode_label(transport_mode):
    """Return the text shown after "via" for a predecessor's transport mode.

    ``'transfer'`` and ``'walking'`` are shown verbatim. Any other string is
    split on ``'-'`` and its second field is shown; a value without a dash (or
    a non-string value) is returned unchanged instead of raising.
    """
    if transport_mode in ('transfer', 'walking') or not isinstance(transport_mode, str):
        return transport_mode
    fields = transport_mode.split('-')
    return fields[1] if len(fields) > 1 else transport_mode


def print_paths(paths, id_to_stop, stop_info, output_widgets, accordion, expected_time):
    """Write a textual itinerary and a map for each path into its Output widget.

    Paths are sorted by ``(cost, number of transfers, walking time)`` and paired
    in that order with ``output_widgets``; the accordion section at the same
    index receives a one-line summary as its title.

    Args:
        paths: iterable of ``(path, cost)`` pairs, where ``path`` is a sequence
            of ``(node, predecessor, time)`` steps. ``predecessor`` is ``None``
            for the departure step, otherwise its third element is the
            transport mode.
        id_to_stop: dict mapping stop id -> stop name. Ids that are not in the
            dict are printed as ``unknown stop``.
        stop_info: per-stop record dict, forwarded to ``plot_routes``.
        output_widgets: one Output context manager per path.
        accordion: object whose ``set_title(index, text)`` is called per path.
        expected_time: forwarded to ``route_confidence``.

    Reads the notebook-level global ``grouped_istdaten``.
    """
    sorted_paths = sorted(
        paths,
        key=lambda x: (x[1], count_transfers(x[0]), calculate_walking_time(x[0])),
    )

    for i, (output, (path, cost)) in enumerate(zip(output_widgets, sorted_paths)):
        # Fallback title: an ipywidgets Output context displays and swallows
        # exceptions raised inside it, so the summary values computed below may
        # never be assigned. Without this, set_title() would hit a NameError
        # (first path) or silently reuse the previous path's numbers.
        title = f"Path {i+1}"
        with output:
            transfers = count_transfers(path)
            walking_time = calculate_walking_time(path)
            confidence = route_confidence(path, expected_time, grouped_istdaten)
            title = f"Path {i+1}, Cost: {int(cost)} minutes, Transfers: {transfers}, Walking: {walking_time} minutes, Confidence: {confidence}"
            print(f"Path {i + 1}: Cost: {int(cost)} minutes, Transfers: {transfers}, Walking: {walking_time} minutes, Confidence: {confidence}")
            for node, predecessor, time in path:
                # Explicit lookup with a default instead of a bare ``except``:
                # only a missing stop name is tolerated, real errors surface.
                stop_name = id_to_stop.get(node, 'unknown stop')
                if predecessor is None:
                    print(f"Depart from {stop_name}({node}) at {minutes_to_hours(time)}: ")
                else:
                    transport_mode = _mode_label(predecessor[2])
                    print(f"            {stop_name}({node}) at {minutes_to_hours(time)} via {transport_mode}")
            arrival_node, _, arrival_time = path[-1]
            arrival_name = id_to_stop.get(arrival_node, 'unknown stop')
            print(f"Arrival at {arrival_name}({arrival_node}) at {minutes_to_hours(arrival_time)}")
            plot_routes([sorted_paths[i]], stop_info)
        accordion.set_title(i, title)


def _trip_legend_label(trip_type):
    """Return the legend name for a key of the trip-colour table.

    ``'transfer'`` and ``'walking'`` are used verbatim. Any other key is split
    on ``'-'`` and its second field is used; a key without a dash falls back to
    the whole key instead of raising ``IndexError``.
    """
    if trip_type in ('transfer', 'walking'):
        return trip_type
    fields = str(trip_type).split('-')
    return fields[1] if len(fields) > 1 else str(trip_type)


def _describe_neighbor(stop_id, stop_info):
    """Describe a neighbouring stop for the "stop not found" message.

    Returns the stop name when ``stop_id`` is in ``stop_info``,
    ``'missing stop-<id>'`` when it is not, and ``'none'`` when there is no
    neighbour at all (the missing stop is the first or last step of the path).
    """
    if stop_id is None:
        return 'none'
    if stop_id in stop_info:
        return stop_info[stop_id]['stop_name']
    return 'missing stop-' + str(stop_id)


def plot_routes(paths, stop_info):
    """Draw the given routes on an OpenStreetMap-styled plotly figure and show it.

    Args:
        paths: iterable of ``(path, cost)`` pairs; ``path`` is a sequence of
            ``(node, predecessor, time)`` steps. ``predecessor`` is ``None``
            for the departure step, otherwise ``predecessor[0]`` is the
            previous node id and ``predecessor[2]`` the transport mode /
            trip id. ``cost`` is not used here.
        stop_info: dict mapping stop id -> record with the keys
            ``'stop_lat'``, ``'stop_lon'`` and ``'stop_name'``.

    Behaviour:
        * Steps whose node id ends in ``'-transfer'`` are not plotted; they
          only append ", and then transfer." to the hover text of the most
          recently plotted stop (if there is one).
        * Steps whose node id is not in ``stop_info`` are skipped with a
          printed notice.
        * A leg is drawn only when the predecessor stop is known and the leg's
          mode has a colour ('transfer' legs have none).
        * The map is centred on the first plotted stop of the last path that
          has any plotted stop; with nothing to plot, plotly's default centre
          is used.
    """
    fig = go.Figure()

    # Colour table: walking has a fixed colour, every trip id gets a palette
    # colour. Trip ids are sorted so the assignment is the same on every run
    # (set iteration order is not).
    trip_colors = {
        'walking': '#9CBA35'
    }
    palette = px.colors.qualitative.Plotly
    all_trip_ids = sorted(
        {
            step[1][2]
            for path, _ in paths
            for step in path
            if step[1] is not None and step[1][2] not in ['walking', 'transfer']
        },
        key=str,
    )
    for i, trip_id in enumerate(all_trip_ids):
        trip_colors[trip_id] = palette[i % len(palette)]

    map_center = None  # (lat, lon) of the stop the map is centred on

    for path, _cost in paths:
        latitudes = []
        longitudes = []
        hover_texts = []

        for idx, (node, predecessor, time) in enumerate(path):
            if node.endswith('-transfer'):
                # Annotate the last plotted stop; there may be none if every
                # earlier stop was skipped.
                if predecessor is not None and hover_texts:
                    hover_texts[-1] += ", and then transfer."
                continue

            if node not in stop_info:
                # The missing stop can be the first step (no predecessor) or
                # the last one (no successor); guard both ends.
                prev_node = predecessor[0] if predecessor is not None else None
                next_node = path[idx + 1][0] if idx + 1 < len(path) else None
                print(f"Stop {node} not found in stop_info. Skipping this stop. Its previous stop: {_describe_neighbor(prev_node, stop_info)}, next stop: {_describe_neighbor(next_node, stop_info)}.")
                continue
            coords = stop_info[node]

            latitudes.append(coords['stop_lat'])
            longitudes.append(coords['stop_lon'])

            if predecessor is None:
                hover_texts.append(f"Depart from {coords['stop_name']} at {minutes_to_hours(int(time))}.")
                continue

            trip_type = predecessor[2]
            hover_texts.append(f"{coords['stop_name']} at {minutes_to_hours(int(time))} via {trip_type}")

            # A leg that starts at a transfer node is drawn from the real stop.
            pred_node = predecessor[0]
            if pred_node.endswith('-transfer'):
                pred_node = pred_node.replace('-transfer', '')

            # Draw the route line only when both its origin and colour are known.
            pred_coords = stop_info.get(pred_node)
            line_color = trip_colors.get(trip_type)
            if pred_coords is not None and line_color is not None:
                fig.add_trace(go.Scattermapbox(
                    mode="lines",
                    lat=[pred_coords['stop_lat'], coords['stop_lat']],
                    lon=[pred_coords['stop_lon'], coords['stop_lon']],
                    line=dict(color=line_color, width=3),
                    hoverinfo="text",
                    text=f"{pred_coords['stop_name']} to {coords['stop_name']} via {trip_type}",
                    showlegend=False
                ))

        # Draw the stops
        last_index = len(latitudes) - 1
        for i, (lat, lon, hover_text) in enumerate(zip(latitudes, longitudes, hover_texts)):
            if i == 0 or i == last_index:
                marker_color = 'Black'  # Start or end point
            elif 'transfer' in hover_text:
                marker_color = 'red'  # Transfer stop
            else:
                marker_color = 'Blue'  # Intermediate stop

            fig.add_trace(go.Scattermapbox(
                lat=[lat],
                lon=[lon],
                mode='markers',
                showlegend=False,
                marker=go.scattermapbox.Marker(size=9, color=marker_color, symbol='circle'),
                text=hover_text,
                hoverinfo='text'
            ))

        if latitudes:
            map_center = (latitudes[0], longitudes[0])

    # Add stop legends (empty traces that exist only for their legend entry)
    for legend_color, legend_name in (
        ('black', 'Start/End Station'),
        ('blue', 'Intermediate Station'),
        ('red', 'Transfer Station'),
    ):
        fig.add_trace(go.Scattermapbox(
            lat=[None], lon=[None], mode='markers',
            marker=go.scattermapbox.Marker(size=9, color=legend_color, symbol='circle'),
            showlegend=True,
            name=legend_name
        ))

    # Add trip legends
    for trip_type, color in trip_colors.items():
        fig.add_trace(go.Scattermapbox(
            lat=[None], lon=[None], mode='lines',
            line=dict(color=color, width=4),
            showlegend=True,
            name=f"{_trip_legend_label(trip_type)}",
        ))

    mapbox_layout = dict(zoom=12)
    if map_center is not None:
        mapbox_layout['center'] = go.layout.mapbox.Center(
            lat=map_center[0],
            lon=map_center[1]
        )

    fig.update_layout(
        width=1000,
        height=600,
        mapbox_style="open-street-map",
        mapbox=mapbox_layout,
    )

    fig.show()


def show_paths(paths, id_to_stop, stop_info, expected_time):
    """Display ``paths`` in an accordion with one section per path.

    One Output widget is created for every entry of ``paths``; ``print_paths``
    fills the widgets and sets the section titles, and the accordion is
    displayed afterwards.
    """
    # One output area per path (not a fixed number of sections).
    section_count = len(paths)
    sections = [widgets.Output() for _ in range(section_count)]

    # The accordion owns the output areas as its collapsible children.
    path_accordion = widgets.Accordion(children=sections)

    # Fill each section with the itinerary text and map, and title it.
    print_paths(paths, id_to_stop, stop_info, sections, path_accordion, expected_time)

    # Render the filled accordion.
    display(path_accordion)


In [7]:
# date_picker = widgets.DatePicker(
#     description='Date:',
#     disabled=False
# )

# objectID = widgets.Combobox(
#     placeholder='City ID',
#     options=cities,
#     description='City: ',
#     ensure_option=True,
#     disabled=False
# )

# --- Route-planner form ------------------------------------------------------
# Station pickers: restricted (ensure_option=True) to the known stop names.
departure_station = widgets.Combobox(
    description='From:',
    placeholder='Type or select',
    options=stops['stop_name'].tolist(),
    ensure_option=True,
    disabled=False,
)

destination_station = widgets.Combobox(
    description='To:',
    placeholder='Type or select',
    options=stops['stop_name'].tolist(),
    ensure_option=True,
    disabled=False,
)

# Time pickers: every minute of the day, formatted as HH:MM.
depart_time = widgets.Combobox(
    description='Depart:',
    placeholder='Choose Time (HH:MM)',
    options=[f"{hour:02}:{minute:02}" for hour in range(24) for minute in range(60)],
    ensure_option=True,
    disabled=False,
)

expected_arrival_time = widgets.Combobox(
    description='Arrival at: ',
    placeholder='Choose Time (HH:MM)',
    options=[f"{hour:02}:{minute:02}" for hour in range(24) for minute in range(60)],
    ensure_option=True,
    disabled=False,
)

# Minimum confidence a route must reach to be shown (0 to 1 in steps of 0.05).
min_confidence_level = widgets.FloatSlider(
    description='Min Confidence Level:',
    value=0.9,
    min=0,
    max=1.0,
    step=0.05,
    orientation='horizontal',
    continuous_update=False,
    disabled=False,
    readout=True,
    readout_format='.2f',
    style={'description_width': 'initial'},
    layout={'width': '30%'},
)


# Trigger button and the output area that receives the search results.
button = widgets.Button(description="Find Routes")

output = widgets.Output()

def on_button_clicked(b):
    """Handle a click on "Find Routes": search routes and show the confident ones.

    Reads the current values of the form widgets, asks ``yen_ksp`` for up to
    five routes on a graph built for the chosen time window, keeps the routes
    whose ``route_confidence`` is at least the selected minimum, and renders
    them with ``show_paths``. Everything is written into the ``output`` widget,
    which is cleared first.

    Args:
        b: the clicked button (unused; required by the ``on_click`` callback
            signature).
    """
    with output:
        output.clear_output()

        start_time = depart_time.value
        departure = departure_station.value
        destination = destination_station.value
        expected_time = expected_arrival_time.value
        min_confidence = min_confidence_level.value

        # A Combobox holds '' until an option is chosen. Stop early with a
        # readable message instead of a KeyError from stop_to_id[''] or a
        # graph built from empty times.
        form_values = (
            ('From', departure),
            ('To', destination),
            ('Depart', start_time),
            ('Arrival at', expected_time),
        )
        missing_fields = [label for label, value in form_values if not value]
        if missing_fields:
            print(f"Please fill in: {', '.join(missing_fields)}")
            return

        departure_id = stop_to_id[departure]
        destination_id = stop_to_id[destination]
        G = sbb_network.build_graph(start_time, expected_time)

        # ``or []`` treats a missing result like "no route found".
        paths = yen_ksp(G, start_time, departure_id, destination_id, K=5) or []

        # Keep only the routes that reach the requested confidence level.
        filtered_paths = [
            route for route in paths
            if route_confidence(route[0], expected_time, grouped_istdaten) >= min_confidence
        ]
        if len(filtered_paths) == 0:
            print('Oops, you cannot make it. Choose a different time or confidence level')
        else:
            show_paths(filtered_paths, id_to_stop, stop_info, expected_time)

# Register the click handler, then render the form from top to bottom with the
# results area last.
button.on_click(on_button_clicked)

display(
    departure_station,
    destination_station,
    depart_time,
    expected_arrival_time,
    min_confidence_level,
    button,
    output,
)

Combobox(value='', description='From:', ensure_option=True, options=('Belmont-sur-L., Blessoney', 'Belmont-sur…

Combobox(value='', description='To:', ensure_option=True, options=('Belmont-sur-L., Blessoney', 'Belmont-sur-L…

Combobox(value='', description='Depart:', ensure_option=True, options=('00:00', '00:01', '00:02', '00:03', '00…

Combobox(value='', description='Arrival at: ', ensure_option=True, options=('00:00', '00:01', '00:02', '00:03'…

FloatSlider(value=0.9, continuous_update=False, description='Min Confidence Level:', layout=Layout(width='30%'…

Button(description='Find Routes', style=ButtonStyle())

Output()