## Initialize the environment

In [None]:
%load_ext sparkmagic.magics

In [None]:
import os
from IPython import get_ipython
# Base URL of the server that the `%spark add -u ...` call further down connects to.
server = "http://iccluster044.iccluster.epfl.ch:8998"
# Renku user name, used further down to build a per-user Spark session name.
# Raises KeyError if RENKU_USERNAME is not set in the environment.
username = os.environ['RENKU_USERNAME']
print(username)

In [None]:
# JSON settings for the per-user Spark session; equivalent to running `%%spark config`
# with this text as the cell body.
session_config = f'{{ "name": "{username}-homework3", "executorMemory": "4G", "executorCores": 4, "numExecutors": 10, "driverMemory": "4G" }}'
get_ipython().run_cell_magic('spark', 'config', session_config)

Add the `<username>-homework3` session on the Spark server (`server`); this will first start the Spark application if there is no active session.

In [None]:
# Register a Python session named "<username>-homework3" against `server`.
# NOTE(review): `-k` presumably skips creation when the session already exists —
# confirm against the sparkmagic documentation.
session_name = f"{username}-homework3"
add_session_args = f"""add -s {session_name} -l python -u {server} -k"""
get_ipython().run_line_magic("spark", add_session_args)

In [None]:
%%spark
import pyspark.sql.functions as F
# Smoke test of the remote session: report the Spark version it runs.
spark_version = spark.version
print('We are using Spark %s' % spark_version)

# Pre-processing

We picked the week starting April 17th (published by SBB on April 12th) to build our public infrastructure model because there are no bank holidays in Switzerland that week. Since our planner treats all weekdays the same, we start by removing the services that are not available on every weekday (Monday to Friday).

In [None]:
%%spark
# Partition sub-path (year/month/day) appended to every timetable directory when the
# file paths are built in the next cell.
WEEK = 'year=2023/month=04/day=12/'

In [None]:
%%spark
# Locations of the timetable text files for the selected snapshot (see WEEK).
stops_path = '/data/sbb/part_csv/timetables/stops/' + WEEK + 'stops.txt'
STOP_TIMES_PATH = '/data/sbb/part_csv/timetables/stop_times/' + WEEK + 'stop_times.txt'
CALENDAR_PATH = '/data/sbb/part_csv/timetables/calendar/' + WEEK + 'calendar.txt'
TRIPS_PATH = '/data/sbb/part_csv/timetables/trips/' + WEEK + 'trips.txt'


def _read_timetable_csv(path):
    """Load one comma-separated file with a header row, letting Spark infer column types."""
    return spark.read.csv(path, inferSchema=True, header=True, sep=',')


stops = _read_timetable_csv(stops_path)
stop_times = _read_timetable_csv(STOP_TIMES_PATH)
calendar = _read_timetable_csv(CALENDAR_PATH)
trips = _read_timetable_csv(TRIPS_PATH)


In [None]:
%%spark
# Print the first five rows of each raw table, in the same order as before.
for raw_table in (trips, stop_times, calendar, stops):
    raw_table.show(5)

## Keep only services that operate on weekdays

In [None]:
%%spark
# A service is kept only if its calendar row is flagged 1 on each of Monday..Friday.
weekday_names = ('monday', 'tuesday', 'wednesday', 'thursday', 'friday')

runs_every_weekday = calendar[weekday_names[0]] == 1
for day_name in weekday_names[1:]:
    runs_every_weekday = runs_every_weekday & (calendar[day_name] == 1)

weekday_service_ids = calendar.filter(runs_every_weekday).select('service_id')

weekday_service_ids.show(n=5)
weekday_service_ids.count()

In [None]:
%%spark
# Trips whose service runs on every weekday; duplicates produced by the join are removed.
trips_with_weekday_service = trips.join(weekday_service_ids, 'service_id', 'inner')
weekday_trips = trips_with_weekday_service.distinct()

weekday_trips.show(n=5)
weekday_trips.count()

## Merge with the stops

In [None]:
%%spark
# Attach every stop-time row to its weekday trip (rows of other trips are dropped).
nodes = weekday_trips.join(stop_times, 'trip_id', 'inner')

nodes.show(n=5)

In [None]:
%%spark
# Enrich each stop-time row with the attributes of its stop.
final_nodes = stops.join(nodes, 'stop_id', 'inner')

final_nodes.show(n=5)

In [None]:
%%spark
from pyspark.sql.functions import *
def _as_event_nodes(frame, kept_time, dropped_time, arrival_flag):
    """Turn stop-time rows into event nodes.

    Drops `dropped_time`, renames `kept_time` to "time" and appends an
    "is_arrival" column holding the constant `arrival_flag`.
    """
    return (
        frame
        .drop(dropped_time)
        .withColumnRenamed(kept_time, "time")
        .withColumn("is_arrival", lit(arrival_flag))
    )


# One node per arrival event and one per departure event of every stop-time row.
nodes_arr = _as_event_nodes(final_nodes, "arrival_time", "departure_time", 1)
nodes_dep = _as_event_nodes(final_nodes, "departure_time", "arrival_time", 0)

# union() is positional; both frames have the same column layout by construction.
nodes_2 = nodes_arr.union(nodes_dep)

## Gather `stop_id`s around Zurich

In [None]:
%%spark
from pyspark.sql import functions as F
from math import radians, cos, sin, asin, sqrt, atan2

# The return type is declared explicitly: without it F.udf produces a StringType
# column, and every numeric comparison on the result relies on implicit casts.
@F.udf(returnType="double")
def haversine_distance(lat1, lon1, lat2=47.378177, lon2=8.540192):
    """
    Calculates the Haversine distance between two sets of latitude and longitude coordinates.

    Parameters:
    - lat1, lon1: Latitude and longitude of the first point in degrees.
    - lat2, lon2: Latitude and longitude of the second point in degrees. (default: Zurich coordinates)

    Returns:
    The Haversine distance between the two points in meters (a double),
    or None (SQL NULL) when any coordinate is missing.
    """
    # A missing coordinate yields NULL instead of failing the whole Spark task.
    if lat1 is None or lon1 is None or lat2 is None or lon2 is None:
        return None

    # Convert degrees to radians
    lat1_rad = radians(float(lat1))
    lon1_rad = radians(float(lon1))
    lat2_rad = radians(float(lat2))
    lon2_rad = radians(float(lon2))

    # Calculate the differences
    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad

    # Apply Haversine formula
    a = sin(dlat / 2) ** 2 + cos(lat1_rad) * cos(lat2_rad) * sin(dlon / 2) ** 2
    # Rounding can push `a` marginally outside [0, 1], which would make sqrt(1 - a) fail.
    # Plain comparisons are used on purpose: the notebook star-imports
    # pyspark.sql.functions, which shadows the builtin min/max.
    if a > 1.0:
        a = 1.0
    elif a < 0.0:
        a = 0.0
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    R = 6371e3  # Earth's radius in meters

    return R * c

In [None]:
%%spark
# Search radius around the default reference point of haversine_distance, in meters,
# widened by a 20% margin.
RADIUS = 15e3
MARGIN = 0.2*RADIUS

# Distance of each row's stop to the reference point (haversine_distance defaults).
distance_to_centre = haversine_distance(stops.stop_lat, stops.stop_lon)
stops_zurich = nodes_2.filter(distance_to_centre < RADIUS + MARGIN)

# One row per distinct stop inside the area.
nodes_zurich = stops_zurich.select(
    "stop_id",
    "stop_name",
    "stop_lat",
    "stop_lon",
    "parent_station",
).distinct()

In [None]:
%%spark
# Preview six stops, then report how many distinct stops fall inside the area.
nodes_zurich.show(n=6)
nodes_zurich.count()

## Create a Unique ID

In [None]:
%%spark
# Node key: stop, event time, trip and arrival flag joined with "_".
# NOTE(review): a later cell splits this key on "_" again, which presumably assumes
# that none of the parts contains an underscore — confirm against the data.
node_key_parts = ["stop_id", "time", "trip_id", "is_arrival"]
stops_zurich = stops_zurich.withColumn(
    "unique_stop_id",
    F.concat_ws("_", *node_key_parts),
)

In [None]:
%%spark
# Distinct event nodes. `is_arrival` is selected explicitly because the next cell
# filters on it; leaving it out makes that filter depend on the analyzer re-adding a
# missing column underneath distinct(). The flag is already encoded in unique_stop_id,
# so it does not change which rows are distinct.
stops_zh_dist = stops_zurich.select(
    "stop_id", "trip_id", "route_id", "unique_stop_id", "time", "is_arrival"
).distinct()

In [None]:
%%spark

def _stop_times_with_node_ids(arrival_flag, time_column):
    """Join stop_times with the event nodes of one kind.

    Nodes whose is_arrival equals `arrival_flag` are matched on stop, trip and
    `time_column` (the node's "time" is exposed under that name for the join).
    """
    node_ids = stops_zh_dist.filter(col("is_arrival") == arrival_flag).select(
        col("stop_id"),
        col("trip_id"),
        col("route_id"),
        col("unique_stop_id"),
        col("time").alias(time_column),
    )
    return stop_times.join(node_ids, on=["stop_id", "trip_id", time_column], how="inner")


stop_times_zh_arr = _stop_times_with_node_ids(1, "arrival_time")
stop_times_zh_dep = _stop_times_with_node_ids(0, "departure_time")

# NOTE(review): union() matches columns by position, and the two joins list their keys
# in a different order (departure_time third vs arrival_time third). For the arrival
# rows the values under `arrival_time` / `departure_time` therefore appear to end up
# swapped. The ordering used when building the edges may rely on this — confirm before
# switching to unionByName.
stop_times_zh = stop_times_zh_dep.union(stop_times_zh_arr)

## Building the edges

In [None]:
%%spark
from pyspark.sql.window import Window
from pyspark.sql.functions import to_timestamp, col

# Order of the event nodes inside one trip; shared by the three lead() lookups below.
trip_node_order = Window.partitionBy('trip_id').orderBy(
    [col('stop_sequence').asc(), col("departure_time").asc(), col('unique_stop_id').desc()]
)

# NOTE(review): unix_timestamp(..., 'HH:mm:ss') may not handle times of 24:00:00 or
# later (trips running past midnight) the same way on every Spark version — confirm.
stop_times_zh_pairs = (
    stop_times_zh
    # Pair every node with the next node of the same trip.
    .withColumn('stop_id_dest', F.lead('stop_id').over(trip_node_order))
    .withColumn('arrival_time_dest', F.lead('arrival_time').over(trip_node_order))
    .withColumn('unique_stop_id_dest', F.lead('unique_stop_id').over(trip_node_order))
    # From here on `arrival_time` is the time taken from the destination node.
    .drop('arrival_time')
    .withColumnRenamed('arrival_time_dest', 'arrival_time')
    # The last node of a trip has no successor.
    .dropna(subset='stop_id_dest')
    .withColumn(
        'expected_travel_time',
        F.unix_timestamp(col('arrival_time'), 'HH:mm:ss')
        - F.unix_timestamp(col('departure_time'), 'HH:mm:ss'),
    )
    # Drop pairs that stay at the same stop.
    .filter(col('stop_id') != col('stop_id_dest'))
)

In [None]:
%%spark

# Keep only the two node keys and the travel time of every edge.
edge_keys = stop_times_zh_pairs.select(
    col("unique_stop_id").alias("start_id"),
    col("unique_stop_id_dest").alias("end_id"),
    col("expected_travel_time"),
)

# A node key is "<stop_id>_<time>_<trip_id>_<is_arrival>"; recover its parts by position.
start_parts = split(col("start_id"), "_")
end_parts = split(col("end_id"), "_")

stop_times_zh_pairs = (
    edge_keys
    .withColumn("start_stop_id", start_parts[0])
    .withColumn("start_time", start_parts[1])
    .withColumn("trip_id", start_parts[2])
    .withColumn("end_stop_id", end_parts[0])
    .withColumn("end_time", end_parts[1])
    # The raw keys are no longer needed once their parts are extracted.
    .drop("start_id", "end_id")
    # Edges built from the timetable are transit edges, not walking transfers.
    .withColumn("is_walking", lit(0))
)



In [None]:
%%spark
# Preview the transit edges and report how many there are.
stop_times_zh_pairs.show(n=5)
stop_times_zh_pairs.count()

## Adding walking edges

In [None]:
%%spark
from pyspark.sql.functions import col 

# Fixed transfer time between two stops that share a parent station (2*60 = 120).
time_in_station = 2*60
# Stops of different stations are linked only if they are at most this far apart
# (haversine_distance is documented to return meters).
MAX_WALKING_DISTANCE_M = 500
# Walking transfers are charged 60 time units per 50 distance units (60.0 / 50 below).
WALKING_SPEED_M_PER_MIN = 50

# --- Transfers between stops of the same parent station -----------------------------
# stops_zurich holds one row per stop *event*, so it is reduced to one row per
# (stop, parent station) before the self-join; the final distinct() result is the same,
# but the join no longer scales with the number of events.
platforms = (
    stops_zurich
    .select("stop_id", "parent_station")
    .filter(col("parent_station").isNotNull())
    .distinct()
)

same_station_different_platform_edges = (
    platforms
    .select(col("stop_id").alias("stop_1"), col("parent_station").alias("arr_par"))
    .join(
        platforms.select(col("stop_id").alias("stop_2"), col("parent_station").alias("dep_par")),
        on=(col("arr_par") == col("dep_par")) & (col("stop_1") != col("stop_2")),
        how="inner",
    )
    .select("stop_1", "stop_2", F.lit(time_in_station).alias("transfer_time"))
    .distinct()
)

# --- Transfers between nearby stops of different stations ---------------------------
# Same deduplication as above, applied to both sides of the cross join.
located_stops = stops_zurich.select("stop_id", "stop_lat", "stop_lon").distinct()

walk_origins = located_stops.select(
    col("stop_id").alias("stop_1"),
    col("stop_lat").alias("lat_1"),
    col("stop_lon").alias("lon_1"),
)
walk_targets = located_stops.select(
    col("stop_id").alias("stop_2"),
    col("stop_lat").alias("lat_2"),
    col("stop_lon").alias("lon_2"),
)

different_station_within_walking_distance_edges = (
    walk_origins
    .crossJoin(walk_targets)
    # Only link stops whose ids differ in the part before the first ':'.
    .filter(F.expr("split(stop_1, ':')[0] != split(stop_2, ':')[0]"))
    # Cast explicitly so the numeric comparisons below never depend on an implicit
    # cast of the UDF result (the UDF returns strings unless it declares a type).
    .withColumn(
        "distance",
        haversine_distance(col("lat_1"), col("lon_1"), col("lat_2"), col("lon_2")).cast("double"),
    )
    .filter((col("distance") <= MAX_WALKING_DISTANCE_M) & (col("distance") > 0.0))
    # F.round is referenced through F so this does not depend on a star import
    # having shadowed the builtin round().
    .select(
        "stop_1",
        "stop_2",
        F.round(col("distance") * (60.0 / WALKING_SPEED_M_PER_MIN), 0).alias("transfer_time"),
    )
    .distinct()
)


In [None]:
%%spark
# Preview the same-station transfers and report how many there are.
same_station_different_platform_edges.show(n=5)
same_station_different_platform_edges.count()


In [None]:
%%spark
# Preview the walking transfers between different stations and report how many there are.
different_station_within_walking_distance_edges.show(n=5)
different_station_within_walking_distance_edges.count()

In [None]:
%%spark

# All walking transfers, flagged with is_walking = 1 and renamed to the column names
# used by the transit edges. union() is positional; both inputs are
# (stop_1, stop_2, transfer_time).
walking_edges = (
    different_station_within_walking_distance_edges
    .union(same_station_different_platform_edges)
    .withColumn("is_walking", lit(1))
    .withColumnRenamed("stop_1", "start_stop_id")
    .withColumnRenamed("stop_2", "end_stop_id")
    .withColumnRenamed("transfer_time", "expected_travel_time")
)

# Mark for caching; the count() below is the action that materializes it.
walking_edges.cache()
walking_edges.show(n=5)
walking_edges.count()

## Grouping all edges

In [None]:
%%spark
# Get the column sets of both DataFrames
def _pad_missing_columns(frame, reference):
    """Return `frame` with a typed NULL column for every column only `reference` has."""
    present = set(frame.columns)
    for field in reference.schema.fields:
        if field.name not in present:
            frame = frame.withColumn(field.name, F.lit(None).cast(field.dataType))
    return frame


# Pad both frames symmetrically so the union is correct whichever frame has more
# columns. The previous version always used stop_times_zh_pairs as the left side and
# the padded "smaller" frame as the right side, which would have unioned the transit
# edges with themselves had the walking frame ever been the wider one.
transit_edges_padded = _pad_missing_columns(stop_times_zh_pairs, walking_edges)
walking_edges_padded = _pad_missing_columns(walking_edges, stop_times_zh_pairs)

# Transit edges first, then walking edges; columns are matched by name.
complete_edges = transit_edges_padded.unionByName(walking_edges_padded)


In [None]:
%%spark
# cache() returns the same DataFrame, so the preview can be chained onto it;
# the count is the cell's displayed result.
complete_edges.cache().show(n=5)
complete_edges.count()

In [None]:
%%spark
# Persist the distinct stops as ORC, replacing any previous export at this path.
nodes_zurich.write.format("orc").mode("overwrite").save("/group/grande_envergure/graph/nodes_zurich.orc")

In [None]:
%%spark
# Persist the combined transit and walking edges as ORC, replacing any previous export.
complete_edges.write.format("orc").mode("overwrite").save("/group/grande_envergure/graph/complete_edges.orc")