# Orient Express &ndash; Preprocessing

In [1]:
%%configure -f
{"driverMemory": "4g",
"executorMemory": "4g",
"executorCores": 25,
"numExecutors": 16,
"conf": {"spark.app.name": "orientexpress_preprocessing"},
"kind": "pyspark"}

ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
7933,application_1589299642358_2451,pyspark,idle,Link,Link,
7978,application_1589299642358_2498,pyspark,idle,Link,Link,
7992,application_1589299642358_2514,pyspark,idle,Link,Link,
7994,application_1589299642358_2517,pyspark,idle,Link,Link,
8002,application_1589299642358_2525,pyspark,idle,Link,Link,
8004,application_1589299642358_2527,pyspark,idle,Link,Link,
8008,application_1589299642358_2531,pyspark,idle,Link,Link,
8013,application_1589299642358_2536,pyspark,idle,Link,Link,
8018,application_1589299642358_2541,pyspark,idle,Link,Link,
8019,application_1589299642358_2542,pyspark,idle,Link,Link,


### Start Spark

In [2]:
# Initialization

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
8046,application_1589299642358_2566,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
%%info

ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
7933,application_1589299642358_2451,pyspark,idle,Link,Link,
7978,application_1589299642358_2498,pyspark,idle,Link,Link,
7992,application_1589299642358_2514,pyspark,idle,Link,Link,
7994,application_1589299642358_2517,pyspark,idle,Link,Link,
8002,application_1589299642358_2525,pyspark,idle,Link,Link,
8004,application_1589299642358_2527,pyspark,idle,Link,Link,
8008,application_1589299642358_2531,pyspark,idle,Link,Link,
8013,application_1589299642358_2536,pyspark,idle,Link,Link,
8018,application_1589299642358_2541,pyspark,idle,Link,Link,
8019,application_1589299642358_2542,pyspark,idle,Link,Link,


## 1. Read data from HDFS

We read and check the schemas of the data we find to be necessary for our implementation.

In [4]:
def path(obj, date='2019/05/14'):
    """Build the HDFS path of one SBB timetable CSV file.

    Args:
        obj: Name of the timetable table (e.g. 'stops', 'trips'); it is used
            both as the directory name and as the file name.
        date: Snapshot date as a 'YYYY/MM/DD' path fragment. Defaults to the
            snapshot used throughout this notebook.

    Returns:
        The path string '/data/sbb/timetables/csv/<obj>/<date>/<obj>.txt'.
    """
    return '/data/sbb/timetables/csv/{0}/{1}/{0}.txt'.format(obj, date)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
def _read_timetable(name):
    """Load one timetable CSV into a DataFrame, using its header row and inferring column types."""
    return spark.read.csv(path(name), header=True, inferSchema=True)

# One DataFrame per timetable table needed by the rest of the notebook
calendar = _read_timetable('calendar')
routes = _read_timetable('routes')
stop_times = _read_timetable('stop_times')
stops = _read_timetable('stops')
transfers = _read_timetable('transfers')
trips = _read_timetable('trips')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
# Row counts, in this order: calendar, routes, stops, stop_times, transfers, trips
calendar.count(),routes.count(),stops.count(),stop_times.count(),transfers.count(),trips.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(22615, 5026, 30631, 11128930, 25274, 1017413)

In [7]:
calendar.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- service_id: string (nullable = true)
 |-- monday: integer (nullable = true)
 |-- tuesday: integer (nullable = true)
 |-- wednesday: integer (nullable = true)
 |-- thursday: integer (nullable = true)
 |-- friday: integer (nullable = true)
 |-- saturday: integer (nullable = true)
 |-- sunday: integer (nullable = true)
 |-- start_date: integer (nullable = true)
 |-- end_date: integer (nullable = true)

In [8]:
routes.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- route_id: string (nullable = true)
 |-- agency_id: string (nullable = true)
 |-- route_short_name: string (nullable = true)
 |-- route_long_name: string (nullable = true)
 |-- route_desc: string (nullable = true)
 |-- route_type: integer (nullable = true)

In [9]:
stop_times.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- trip_id: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- stop_sequence: integer (nullable = true)
 |-- pickup_type: integer (nullable = true)
 |-- drop_off_type: integer (nullable = true)

In [10]:
stops.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- stop_id: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- stop_lat: double (nullable = true)
 |-- stop_lon: double (nullable = true)
 |-- location_type: integer (nullable = true)
 |-- parent_station: string (nullable = true)

In [11]:
trips.printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- route_id: string (nullable = true)
 |-- service_id: string (nullable = true)
 |-- trip_id: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- trip_short_name: integer (nullable = true)
 |-- direction_id: integer (nullable = true)

## 2. Select stations

We are interested in stations located within a 15 km radius of Zurich HB. We store these stations in `perimeter_stops`.

In [12]:
import pyspark.sql.functions as F
import math
from pyspark.sql.types import FloatType

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [13]:
from geopy import distance

ZURICH_COORD = (47.378177, 8.540192)

@F.udf(returnType=FloatType())
def coord_distance_udf(lat, long):
    """Distance in kilometres between ZURICH_COORD and the point (lat, long).

    Spark UDF producing a float column, computed with geopy's
    `distance.distance`.

    Args:
        lat: Latitude of the point, in degrees.
        long: Longitude of the point, in degrees.

    Returns:
        The distance in km, or None (a SQL null) when either coordinate is
        missing, so that a single incomplete row does not fail the whole job.
    """
    if lat is None or long is None:
        return None
    return distance.distance(ZURICH_COORD, (lat, long)).km

@F.udf(returnType=FloatType())
def transfer_distance_udf(lat1, long1, lat2, long2):
    """Distance in metres between the points (lat1, long1) and (lat2, long2).

    Spark UDF producing a float column, computed with geopy's
    `distance.distance`.

    Args:
        lat1, long1: Latitude and longitude of the first point, in degrees.
        lat2, long2: Latitude and longitude of the second point, in degrees.

    Returns:
        The distance in m, or None (a SQL null) when any coordinate is
        missing, so that a single incomplete row does not fail the whole job.
    """
    if lat1 is None or long1 is None or lat2 is None or long2 is None:
        return None
    return distance.distance((lat1, long1), (lat2, long2)).m

@F.udf
def get_type(obj):
    """Debugging UDF: name of the Python type Spark hands to a UDF for a column value.

    `F.udf` without an explicit return type declares a string result, so the
    type's name (e.g. 'float', 'str', 'NoneType') is returned rather than the
    class object itself, which is not a string value.
    """
    return type(obj).__name__

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [14]:
stops.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+----------------+-----------------+-------------+--------------+
|stop_id|           stop_name|        stop_lat|         stop_lon|location_type|parent_station|
+-------+--------------------+----------------+-----------------+-------------+--------------+
|1322000|            Altoggio|46.1672513851495|   8.345807131427|         null|          null|
|1322001|        Antronapiana| 46.060121674738| 8.11361957990831|         null|          null|
|1322002|              Anzola|45.9898698225697| 8.34571729989858|         null|          null|
|1322003|              Baceno|46.2614983591677| 8.31925293162473|         null|          null|
|1322004|Beura Cardezza, c...|46.0790618438814| 8.29927439970313|         null|          null|
|1322005|Bognanco, T. Vill...|46.1222963432243| 8.21077237789936|         null|          null|
|1322006|           Boschetto|46.0656504576122| 8.26113193273411|         null|          null|
|1322007|            Cadarese|46.2978807772998|  8

In [15]:
# Compute each stop's distance to Zurich HB (km) and keep the stops closer than 15 km
zh_dist = coord_distance_udf(stops['stop_lat'], stops['stop_lon'])
perimeter_stops = (
    stops.withColumn('zh_dist', zh_dist)
         .filter('zh_dist < 15')
         .select('stop_id', 'stop_name', 'stop_lat', 'stop_lon')
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [16]:
perimeter_stops.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

1880

In [17]:
# Persist the perimeter stops as gzip-compressed Parquet, overwriting any previous output.
# NOTE(review): the destination is a hard-coded user directory -- confirm it before re-running.
perimeter_stops.write.parquet('/user/tvaucher/stops', mode='overwrite', compression='gzip')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 3. Stop times

We select the stop times for the stations identified in `perimeter_stops`, and store them in `perimeter_stop_times`. We also store the trip IDs that are inside the radius in `perimeter_trip_ids`.

In [18]:
stop_times.show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
|             trip_id|arrival_time|departure_time|    stop_id|stop_sequence|pickup_type|drop_off_type|
+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
|1.TA.1-1-B-j19-1.1.R|    04:20:00|      04:20:00|8500010:0:3|            1|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:24:00|      04:24:00|8500020:0:3|            2|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:28:00|      04:28:00|8500021:0:5|            3|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:30:00|      04:30:00|8517131:0:2|            4|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:32:00|      04:32:00|8500300:0:5|            5|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:35:00|      04:35:00|8500313:0:2|            6|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:37:00|      04:38:00|8500301:0:3|           

In [19]:
# Keep the stop_times rows whose stop lies inside the 15 km perimeter.
# A left semi join only filters: no column of perimeter_stops is added.
perimeter_stop_times = stop_times.join(
    perimeter_stops,
    on='stop_id',
    how='leftsemi',
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [20]:
# Unique ids of the trips that stop at least once inside the perimeter
perimeter_trip_ids = (
    perimeter_stop_times
    .select(F.col('trip_id').alias('perimeter_ids'))
    .distinct()
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 4. Transfers

We build the possible walking transfers ourselves: we form every ordered (from stop ID, to stop ID) pair of distinct stops in `perimeter_stops` and keep the pairs whose stops are at most 500 m apart.

In [21]:
# Candidate transfers: every ordered pair of distinct perimeter stops
from_ids = perimeter_stops.select(F.col('stop_id').alias('from_stop_id')).distinct()
to_ids = perimeter_stops.select(F.col('stop_id').alias('to_stop_id')).distinct()
transfers = from_ids.crossJoin(to_ids).filter('from_stop_id != to_stop_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [22]:
# Two renamed copies of the stop coordinates, one per side of a transfer,
# so that both can be joined onto the candidate pairs without name clashes.
stop_coords = perimeter_stops.select('stop_id', 'stop_lat', 'stop_lon')
from_stops = stop_coords.toDF('stop_id1', 'from_stop_lat', 'from_stop_lon')
to_stops = stop_coords.toDF('stop_id2', 'to_stop_lat', 'to_stop_lon')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [23]:
# Attach the coordinates of both ends of every candidate transfer.
# NOTE: this cell rebinds `transfers`; re-run the cell that creates it before re-running this one.
from_match = transfers.from_stop_id == from_stops.stop_id1
to_match = transfers.to_stop_id == to_stops.stop_id2
transfers = (
    transfers.join(from_stops, on=from_match, how='inner')
             .join(to_stops, on=to_match, how='inner')
             .select(['from_stop_id', 'to_stop_id',
                      'from_stop_lat', 'from_stop_lon',
                      'to_stop_lat', 'to_stop_lon'])
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [24]:
# Distance in metres between the two stops; keep the pairs at most 500 m apart
stop_to_stop_m = transfer_distance_udf(
    transfers.from_stop_lat, transfers.from_stop_lon,
    transfers.to_stop_lat, transfers.to_stop_lon,
)
transfers = transfers.withColumn('distance', stop_to_stop_m).filter('distance <= 500')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [25]:
# The coordinates are no longer needed once the distance is known
transfers = transfers.select(['from_stop_id', 'to_stop_id', 'distance'])

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

We also calculate the time it takes to walk each transfer, in seconds: the distance divided by a walking speed of 50 m/min, plus a fixed minimum transfer time of 2 minutes.

In [26]:
SPEED = 50.0 / 60.0    # walking speed in m/s (50 m per minute)
MIN_TRANSFER = 120.    # fixed time added to every transfer, in seconds

# t_time = walking time over `distance` (metres) + minimum transfer time, in seconds
walking_seconds = transfers["distance"] / SPEED
transfer_w_time = transfers.withColumn("t_time", walking_seconds + MIN_TRANSFER)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [27]:
# Persist the transfers (with walking time) as gzip-compressed Parquet, overwriting any previous output.
# NOTE(review): the destination is a hard-coded user directory -- confirm it before re-running.
transfer_w_time.write.parquet('/user/tvaucher/transfers', mode='overwrite', compression='gzip')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 5. Trips

We select the trips that serve at least one stop inside the radius, i.e. the trips whose IDs are in `perimeter_trip_ids`.

In [28]:
trips.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
|1-1-C-j19-1|  TA+b0001|7.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            112|           1|
|1-1-C-j19-1|  TA+b0001|9.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            116|           1|
|1-1-C-j19-1|  TA+b0001|11.TA.1-1-C-j19-1...|Zofingen, Altachen|            120|           1|
|1-1-C-j19-1|  TA+b0001|13.TA.1-1-C-j19-1...|Zofingen, Altachen|            124|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 5 rows

In [29]:
# Keep the trips that stop at least once inside the 15 km perimeter around Zurich HB
same_trip = trips['trip_id'] == perimeter_trip_ids['perimeter_ids']
perimeter_trips = trips.join(perimeter_trip_ids, on=same_trip, how='leftsemi')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [30]:
# Unique route ids and service ids referenced by the perimeter trips
perimeter_routes_ids = perimeter_trips.select('route_id').distinct()
perimeter_service_ids = perimeter_trips.select('service_id').distinct()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 6. Routes

In [31]:
routes.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+---------+----------------+---------------+----------+----------+
|   route_id|agency_id|route_short_name|route_long_name|route_desc|route_type|
+-----------+---------+----------------+---------------+----------+----------+
|11-40-j19-1|      801|             040|           null|       Bus|       700|
|11-61-j19-1|     7031|             061|           null|       Bus|       700|
|11-62-j19-1|     7031|             062|           null|       Bus|       700|
|24-64-j19-1|      801|             064|           null|       Bus|       700|
|11-83-j19-1|      801|             083|           null|       Bus|       700|
+-----------+---------+----------------+---------------+----------+----------+
only showing top 5 rows

In [32]:
# Routes that have at least one trip inside the 15 km perimeter around Zurich HB.
# NOTE(review): no later cell in this part of the notebook reads perimeter_routes -- confirm it is still needed.
perimeter_routes = routes.join(
    perimeter_routes_ids,
    on='route_id',
    how="leftsemi",
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 7. Services

We select services that are available on regular business days, within a 15km radius around Zurich HB. We also only consider the schedule from 7:00 (included) until 19:00 (not included), so we filter the appropriate dataframes based on these conditions.

In [33]:
calendar.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+------+-------+---------+--------+------+--------+------+----------+--------+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|start_date|end_date|
+----------+------+-------+---------+--------+------+--------+------+----------+--------+
|  TA+b0nx9|     1|      1|        1|       1|     1|       0|     0|  20181209|20191214|
|  TA+b03bf|     1|      1|        1|       1|     1|       0|     0|  20181209|20191214|
|  TA+b0008|     1|      1|        1|       1|     1|       0|     0|  20181209|20191214|
|  TA+b0nxg|     1|      1|        1|       1|     1|       0|     0|  20181209|20191214|
|  TA+b08k4|     1|      0|        0|       0|     0|       0|     0|  20181209|20191214|
+----------+------+-------+---------+--------+------+--------+------+----------+--------+
only showing top 5 rows

In [34]:
# Services flagged as running on every day from Monday to Friday
WEEKDAYS = ('monday', 'tuesday', 'wednesday', 'thursday', 'friday')
all_week_calendar = (
    calendar.filter(' and '.join('{} = 1'.format(day) for day in WEEKDAYS))
            .select('service_id')
)

# Of those, keep the services used by at least one trip in the 15 km Zurich HB perimeter
perimeter_week_service_ids = all_week_calendar.join(
    perimeter_service_ids,
    on="service_id",
    how="leftsemi",
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [35]:
# Perimeter trips whose service runs Monday to Friday, enriched with the
# description and short name of their route.
route_info = routes.select('route_id', 'route_desc', 'route_short_name')
weekday_perimeter_trips = perimeter_trips.join(perimeter_week_service_ids, on="service_id", how="leftsemi")
filtered_trips = (
    weekday_perimeter_trips
    .join(route_info, on="route_id", how="inner")
    .select("trip_id", "trip_headsign", "route_desc", "route_short_name")
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [36]:
LOWER_TIME = 7  # included
UPPER_TIME = 19 # trips up to 18:59:59


def _within_service_window(ts_col):
    """Condition that is true when the hour of `ts_col` lies in [LOWER_TIME, UPPER_TIME)."""
    hour_of_day = F.hour(ts_col)
    return (hour_of_day >= LOWER_TIME) & (hour_of_day < UPPER_TIME)


# Stop times of the Monday-to-Friday perimeter trips, with parsed arrival/departure timestamps
stop_times_with_ts = (
    perimeter_stop_times
    .join(filtered_trips, on="trip_id", how="inner")
    .withColumn('arrival_time_ts', F.to_timestamp('arrival_time', 'HH:mm:ss'))
    .withColumn('departure_time_ts', F.to_timestamp('departure_time', 'HH:mm:ss'))
)

# Keep only the stops whose arrival AND departure both fall between LOWER_TIME and UPPER_TIME
filtered_stop_times = stop_times_with_ts.filter(
    _within_service_window('arrival_time_ts')
    & _within_service_window('departure_time_ts')
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [37]:
filtered_stop_times.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

281237

In [38]:
from pyspark.sql import Window

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [39]:
# Per-trip window ordered by stop sequence; used with F.lag below to look up,
# for every stop, the previous stop of the same trip.
w = Window.partitionBy("trip_id").orderBy("stop_sequence")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

We define a dataframe that contains edges to represent the public transport network. We further explain this in `planning.ipynb`, where we construct the graph representation of the network.

In [40]:
# Build one edge per pair of consecutive stops of a trip: the previous stop
# (prev_stop_id, prev_departure_time) -> the current stop (stop_id, arrival_time).
# NOTE(review): the lag runs over filtered_stop_times, so "previous" means the
# previous remaining row of the trip after the perimeter and time filters --
# confirm that this is the intended definition of an edge.
edges = (filtered_stop_times.select("*",
                                    F.lag("departure_time", 1, None).over(w).alias("prev_departure_time"),
                                    F.lag("stop_id", 1, None).over(w).alias("prev_stop_id"))
                            # Drop only the rows without a predecessor (first row of each trip).
                            # A bare dropna() would also discard edges that merely have a null
                            # in an unrelated column such as trip_headsign or route_desc.
                            .dropna(subset=["prev_departure_time", "prev_stop_id"])
                            .withColumn('prev_departure_time_ts', F.to_timestamp('prev_departure_time', 'HH:mm:ss'))
                            # Travel time in seconds from the previous stop's departure to this stop's arrival
                            .withColumn('duration', F.unix_timestamp('arrival_time_ts') - F.unix_timestamp('prev_departure_time_ts')))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [41]:
# Persist the edges as gzip-compressed Parquet (overwriting any previous output),
# without the helper timestamp columns that were only needed to compute `duration`.
# NOTE(review): the destination is a hard-coded user directory -- confirm it before re-running.
edges.drop('arrival_time_ts', 'departure_time_ts', 'prev_departure_time_ts').write.parquet('/user/tvaucher/edges', mode="overwrite", compression="gzip")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…