# Preparing the data

This notebook prepares the data for the predictive modelling part and the network building part.

In [1]:
%%configure
{"conf": {
    "spark.app.name": "MVAY_final"
}}

In [2]:
spark

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
8709,application_1589299642358_3234,pyspark,idle,Link,Link,✔


SparkSession available as 'spark'.


<pyspark.sql.session.SparkSession object at 0x7f7a622056d0>

In [3]:
# Import necessary packages
import pyspark.sql.functions as f
import math
from pyspark.sql.types import FloatType, IntegerType

## Load the data

Load the stop_times dataset, which lists the stops served by each trip together with their arrival and departure times.

In [4]:
stop_times = spark.read.orc("/data/sbb/timetables/orc/stop_times/")
stop_times.show(3)

+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
|             trip_id|arrival_time|departure_time|    stop_id|stop_sequence|pickup_type|drop_off_type|
+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
|1.TA.1-1-B-j19-1.1.R|    04:20:00|      04:20:00|8500010:0:3|            1|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:24:00|      04:24:00|8500020:0:3|            2|          0|            0|
|1.TA.1-1-B-j19-1.1.R|    04:28:00|      04:28:00|8500021:0:5|            3|          0|            0|
+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
only showing top 3 rows

Load the stops dataset which includes the stops and their coordinates.

In [5]:
stops = spark.read.orc("/data/sbb/timetables/orc/stops/") 
stops.show(3)

+-------+------------+----------------+----------------+-------------+--------------+
|stop_id|   stop_name|        stop_lat|        stop_lon|location_type|parent_station|
+-------+------------+----------------+----------------+-------------+--------------+
|1322000|    Altoggio|46.1672513851495|  8.345807131427|             |              |
|1322001|Antronapiana| 46.060121674738|8.11361957990831|             |              |
|1322002|      Anzola|45.9898698225697|8.34571729989858|             |              |
+-------+------------+----------------+----------------+-------------+--------------+
only showing top 3 rows

Load the calendar dataset, which gives the weekdays on which each service runs.

In [6]:
calendar = spark.read.orc("/data/sbb/timetables/orc/calendar/") 
calendar.show(3)

+----------+------+-------+---------+--------+------+--------+------+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|
+----------+------+-------+---------+--------+------+--------+------+
|  TA+b0nx9|  true|   true|     true|    true|  true|   false| false|
|  TA+b03bf|  true|   true|     true|    true|  true|   false| false|
|  TA+b0008|  true|   true|     true|    true|  true|   false| false|
+----------+------+-------+---------+--------+------+--------+------+
only showing top 3 rows

Load the trips dataframe, which links each trip to its route and to its service calendar.

In [7]:
trips = spark.read.orc("/data/sbb/timetables/orc/trips/") 
trips.show(3)

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
|1-1-C-j19-1|  TA+b0001|7.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            112|           1|
|1-1-C-j19-1|  TA+b0001|9.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            116|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 3 rows

**TO DO: Is this dataset used?** Load the routes dataframe in order to be able to retrieve the transport type.

In [8]:
routes = spark.read.orc("/data/sbb/timetables/orc/routes/")
routes.show(3)

+-----------+---------+----------------+---------------+----------+----------+
|   route_id|agency_id|route_short_name|route_long_name|route_desc|route_type|
+-----------+---------+----------------+---------------+----------+----------+
|11-40-j19-1|      801|             040|               |       Bus|       700|
|11-61-j19-1|     7031|             061|               |       Bus|       700|
|11-62-j19-1|     7031|             062|               |       Bus|       700|
+-----------+---------+----------------+---------------+----------+----------+
only showing top 3 rows

Load the transfers dataframe to get the time that it takes to walk between two stops.

In [9]:
transfers = spark.read.orc("/data/sbb/timetables/orc/transfers/")
transfers.show(3)

+------------+-----------+-------------+-----------------+
|from_stop_id| to_stop_id|transfer_type|min_transfer_time|
+------------+-----------+-------------+-----------------+
| 8500309:0:2|8500309:0:4|            2|              180|
| 8500309:0:2|8500309:0:5|            2|              180|
| 8500309:0:2|8500309:0:3|            2|              180|
+------------+-----------+-------------+-----------------+
only showing top 3 rows

## Clean the data

### Filter calendar

The calendar dataframe is filtered to only include services that run every day from Monday to Friday and not on weekends.

In [10]:
calendar.show(1)

+----------+------+-------+---------+--------+------+--------+------+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|
+----------+------+-------+---------+--------+------+--------+------+
|  TA+b0nx9|  true|   true|     true|    true|  true|   false| false|
+----------+------+-------+---------+--------+------+--------+------+
only showing top 1 row

In [11]:
# Filter for services that have the schedule mon-fri
calendar_filt = calendar.filter((calendar.monday == True) & (calendar.tuesday == True)
                                & (calendar.wednesday == True) & (calendar.thursday == True) 
                                & (calendar.friday == True) & (calendar.saturday == False) 
                                & (calendar.sunday == False))

### Filter trips

The trips dataframe is joined with the filtered calendar dataframe so that it only includes trips that run Monday to Friday.

In [12]:
trips.show(1)

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 1 row

**TO DO: Don't filter for trips with mon-fri schedule anymore. Why? And why join with routes? What do we need in routes?**

In [14]:
# Get the service_ids that have the calendar mon-fri
ids = calendar_filt.select('service_id').distinct()

# Join in order to filter for trips with the schedule mon-fri and keep only relevant columns
trips_filt = trips.join(ids, 'service_id', 'inner').select(['service_id', 'route_id', 'trip_id'])
trips_filt = trips_filt.join(routes, 'route_id', 'inner').select(['service_id', 'route_id', 'trip_id','route_desc','route_type'])
trips_filt.show(3)

+----------+-----------+--------------------+----------+----------+
|service_id|   route_id|             trip_id|route_desc|route_type|
+----------+-----------+--------------------+----------+----------+
|  TA+b0003|1-305-j19-1|12.TA.1-305-j19-1...|       Bus|       700|
|  TA+b0003|1-305-j19-1|14.TA.1-305-j19-1...|       Bus|       700|
|  TA+b0003|1-305-j19-1|20.TA.1-305-j19-1...|       Bus|       700|
+----------+-----------+--------------------+----------+----------+
only showing top 3 rows

### Filter stops

The stops dataframe is filtered to only include stops within a 15 km radius of Zürich HB.

In [15]:
stops.show(1)

+-------+---------+----------------+--------------+-------------+--------------+
|stop_id|stop_name|        stop_lat|      stop_lon|location_type|parent_station|
+-------+---------+----------------+--------------+-------------+--------------+
|1322000| Altoggio|46.1672513851495|8.345807131427|             |              |
+-------+---------+----------------+--------------+-------------+--------------+
only showing top 1 row

To use our user-defined haversine function, we add two columns to the dataframe holding the latitude and longitude of Zurich HB, and we cast the four latitude and longitude columns to floats.

In [16]:
# Define the latitude and longitude for Zurich HB
lat_zur = 47.378177
lon_zur = 8.540192

# Add columns with Zurich HB lat & lon values
temp_stops = stops.withColumn('lat_zur', f.lit(lat_zur))
temp_stops = temp_stops.withColumn('lon_zur', f.lit(lon_zur))

# Cast the columns used in the distance calculation to floats
temp_stops = temp_stops.withColumn('stop_lat', temp_stops.stop_lat.cast("float"))\
                        .withColumn('stop_lon', temp_stops.stop_lon.cast("float"))\
                        .withColumn('lat_zur', temp_stops.lat_zur.cast("float"))\
                        .withColumn('lon_zur', temp_stops.lon_zur.cast("float"))

temp_stops.show(3)

+-------+------------+---------+--------+-------------+--------------+---------+--------+
|stop_id|   stop_name| stop_lat|stop_lon|location_type|parent_station|  lat_zur| lon_zur|
+-------+------------+---------+--------+-------------+--------------+---------+--------+
|1322000|    Altoggio| 46.16725|8.345807|             |              |47.378178|8.540192|
|1322001|Antronapiana|46.060123| 8.11362|             |              |47.378178|8.540192|
|1322002|      Anzola| 45.98987|8.345717|             |              |47.378178|8.540192|
+-------+------------+---------+--------+-------------+--------------+---------+--------+
only showing top 3 rows
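
Spark's `float` type is a 32-bit single-precision number, which is why the table above shows 47.378178 rather than the 47.378177 that was defined. The following sketch, in plain Python outside Spark, estimates how large that rounding error is. The helper name `to_float32` and the rough conversion of 111 km per degree of latitude are assumptions made for illustration.

```python
import struct

def to_float32(x):
    """Round a 64-bit Python float to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", x))[0]

lat_zur = 47.378177
lat_zur_f32 = to_float32(lat_zur)

# Rounding error in degrees, and roughly in metres
# (assuming about 111 km per degree of latitude)
err_deg = abs(lat_zur_f32 - lat_zur)
err_m = err_deg * 111_000

print(lat_zur_f32, err_deg, err_m)
```

The error stays well below one metre, so it should not matter for a 15 km filter. If more precision were needed, casting to `"double"` instead of `"float"` would be one option.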

We are using the haversine formula in a user defined function to calculate the distance in kilometers between two latitude and longitude values.

In [17]:
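# Haversine formula for calculating the great-circle distance between two stops
def calc_dist(lat_1, lat_2, lon_1, lon_2):
    """Return the distance in km between (lat_1, lon_1) and (lat_2, lon_2), given in degrees."""
    R = 6371  # mean Earth radius in km
    dLat = math.radians(lat_1 - lat_2)
    dLon = math.radians(lon_1 - lon_2)
    # a is the squared half-chord length between the two points
    a = math.sin(dLat/2) * math.sin(dLat/2) + math.cos(math.radians(lat_2)) \
        * math.cos(math.radians(lat_1)) * math.sin(dLon/2) * math.sin(dLon/2)
    # c is the angular distance in radians
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
    d = R * c
    return d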
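
As a quick sanity check, the same formula can be evaluated in plain Python for one stop that appears in this notebook's output. The sketch below restates the formula under a hypothetical name, `haversine_km`, with a (lat, lon, lat, lon) argument order chosen for readability. The coordinates for Dietikon Stoffelbach are the rounded values printed in the filtered stops table, so the result should only be close to, not identical with, the `dist_zurich_km` value shown there (about 10.77 km).

```python
import math

def haversine_km(lat_1, lon_1, lat_2, lon_2, radius_km=6371):
    """Great-circle distance in km between two points given in degrees."""
    d_lat = math.radians(lat_2 - lat_1)
    d_lon = math.radians(lon_2 - lon_1)
    a = (math.sin(d_lat / 2) ** 2
         + math.cos(math.radians(lat_1)) * math.cos(math.radians(lat_2))
         * math.sin(d_lon / 2) ** 2)
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return radius_km * c

zurich_hb = (47.378177, 8.540192)
dietikon_stoffelbach = (47.393406, 8.398943)  # rounded values from the stops table

d = haversine_km(*zurich_hb, *dietikon_stoffelbach)
print(round(d, 2))
```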

The distance from each stop to Zurich HB is calculated and then we filter to only include stops within a 15 km radius.

In [18]:
# Wrap the haversine function as a Spark user defined function that returns a float
udf_func = f.udf(calc_dist, FloatType())

# Calculate the distance from each stop to Zurich HB
stops_filt = temp_stops.withColumn('dist_zurich_km', 
                                   udf_func(temp_stops.lat_zur, temp_stops.stop_lat, 
                                            temp_stops.lon_zur, temp_stops.stop_lon))\
                        .drop(*['lat_zur', 'lon_zur'])

# Filter for stops within 15 km radius from Zurich HB
stops_filt = stops_filt.filter(stops_filt.dist_zurich_km <= 15.0)

stops_filt.show(3)

+-----------+--------------------+---------+--------+-------------+--------------+--------------+
|    stop_id|           stop_name| stop_lat|stop_lon|location_type|parent_station|dist_zurich_km|
+-----------+--------------------+---------+--------+-------------+--------------+--------------+
|    8500926|Oetwil a.d.L., Sc...|47.423626|8.403183|             |              |     11.483568|
|    8502186|Dietikon Stoffelbach|47.393406|8.398943|             |      8502186P|     10.767946|
|8502186:0:1|Dietikon Stoffelbach|47.393467|8.398943|             |      8502186P|      10.76901|
+-----------+--------------------+---------+--------+-------------+--------------+--------------+
only showing top 3 rows

### Filter stop_times

#### Weekday trips
The stop_times dataframe is joined with the filtered trips dataframe to only include the trips that are scheduled Monday to Friday.

In [19]:
stop_times.show(1)

+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
|             trip_id|arrival_time|departure_time|    stop_id|stop_sequence|pickup_type|drop_off_type|
+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
|1.TA.1-1-B-j19-1.1.R|    04:20:00|      04:20:00|8500010:0:3|            1|          0|            0|
+--------------------+------------+--------------+-----------+-------------+-----------+-------------+
only showing top 1 row

In [20]:
trips_filt.show(1)

+----------+-----------+--------------------+----------+----------+
|service_id|   route_id|             trip_id|route_desc|route_type|
+----------+-----------+--------------------+----------+----------+
|  TA+b0003|1-305-j19-1|12.TA.1-305-j19-1...|       Bus|       700|
+----------+-----------+--------------------+----------+----------+
only showing top 1 row

In [21]:
# Join in order to filter for trips scheduled mon-fri and drop unnecessary columns
temp_filt = stop_times.join(trips_filt, 'trip_id', 'inner')\
                        .drop(*['pickup_type', 'drop_off_type'])
temp_filt.show(3)

+--------------------+------------+--------------+-------+-------------+----------+------------+----------+----------+
|             trip_id|arrival_time|departure_time|stop_id|stop_sequence|service_id|    route_id|route_desc|route_type|
+--------------------+------------+--------------+-------+-------------+----------+------------+----------+----------+
|1.TA.16-440-j19-1...|    06:30:00|      06:30:00|8574776|            1|  TA+b0ei3|16-440-j19-1|       Bus|       700|
|1.TA.16-440-j19-1...|    06:32:00|      06:32:00|8509650|            2|  TA+b0ei3|16-440-j19-1|       Bus|       700|
|1.TA.16-440-j19-1...|    06:32:00|      06:32:00|8574777|            3|  TA+b0ei3|16-440-j19-1|       Bus|       700|
+--------------------+------------+--------------+-------+-------------+----------+------------+----------+----------+
only showing top 3 rows

#### Within 15km radius

The stop_times dataframe is also joined with the filtered stops dataframe, so that it only includes stops within the 15 km radius of Zurich HB.

In [22]:
stops_filt.show(1)

+-------+--------------------+---------+--------+-------------+--------------+--------------+
|stop_id|           stop_name| stop_lat|stop_lon|location_type|parent_station|dist_zurich_km|
+-------+--------------------+---------+--------+-------------+--------------+--------------+
|8500926|Oetwil a.d.L., Sc...|47.423626|8.403183|             |              |     11.483568|
+-------+--------------------+---------+--------+-------------+--------------+--------------+
only showing top 1 row

In [23]:
# Join in order to filter for stops within 15km radius and drop unnecessary columns
stop_times_filt = temp_filt.join(stops_filt, 'stop_id', 'inner')\
                            .drop(*['location_type', 'dist_zurich_km','route_type'])
stop_times_filt.show(3)

+-----------+--------------------+------------+--------------+-------------+----------+------------+----------+--------------------+---------+--------+--------------+
|    stop_id|             trip_id|arrival_time|departure_time|stop_sequence|service_id|    route_id|route_desc|           stop_name| stop_lat|stop_lon|parent_station|
+-----------+--------------------+------------+--------------+-------------+----------+------------+----------+--------------------+---------+--------+--------------+
|8503855:0:F|1005.TA.26-131-j1...|    16:14:00|      16:14:00|            1|  TA+b0nfm|26-131-j19-1|       Bus|     Horgen, Bahnhof|47.261856|8.596976|      8503855P|
|    8589111|1005.TA.26-131-j1...|    16:15:00|      16:15:00|            2|  TA+b0nfm|26-131-j19-1|       Bus|Horgen, Gumelenst...|47.260857|8.592305|              |
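|    8573553|1005.TA.26-131-j1...|    16:16:00|      16:16:00|            3|  TA+b0nfm|26-131-j19-1|       Bus|     Horgen, Stocker|47.261517|8.588927|              |
+-----------+--------------------+------------+--------------+-------------+----------+------------+----------+--------------------+---------+--------+--------------+
only showing top 3 rows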

Note that the stop_times_filt dataset likely has gaps within some trips. If a trip leaves the 15 km radius and later re-enters it, the stops outside the radius are dropped, so numbers will be missing from the stop_sequence of the trip in question.
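
One way to spot such gaps is to look for non-adjacent values in a trip's sequence numbers. The sketch below uses a hypothetical helper, `find_gaps`, on a made-up sequence. It assumes the sequence numbers are consecutive in the unfiltered data, which the samples shown in this notebook suggest but which GTFS does not require in general.

```python
def find_gaps(stop_sequence):
    """Return (previous, next) pairs where sorted stop_sequence values are not adjacent."""
    ordered = sorted(stop_sequence)
    return [(a, b) for a, b in zip(ordered, ordered[1:]) if b - a > 1]

# Made-up trip that leaves the radius after stop 3 and re-enters at stop 7
gaps = find_gaps([1, 2, 3, 7, 8])
print(gaps)
```

Each reported pair marks a place where the trip should be split, or at least not treated as a direct connection, when the graph is built.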


#### Journey times within business hours

The stop_times dataframe is further filtered to business hours, 8AM to 5PM. We also include the hour from 7AM to 8AM, since a person who wants to arrive at 8AM may have to leave up to an hour in advance. Because the filter compares only the hour component of each time, stop times up to 17:59:59 are kept.

In [24]:
# Creating 2 new columns where the hour is extracted from the arrival and departure timestamps
stop_times_filt = stop_times_filt.withColumn('ar', stop_times_filt.arrival_time.substr(1, 2).cast(IntegerType()))
stop_times_filt = stop_times_filt.withColumn('dep', stop_times_filt.departure_time.substr(1, 2).cast(IntegerType()))

# Keep only rows whose arrival and departure hours are both between 7 and 17
stop_times_filt = stop_times_filt.where((stop_times_filt.ar >= 7) 
                                        & (stop_times_filt.ar <= 17) 
                                        & (stop_times_filt.dep <= 17) 
                                        & (stop_times_filt.dep >= 7))

# Drop unnecessary columns
stop_times_filt = stop_times_filt.drop(*['ar', 'dep'])

stop_times_filt.cache()
stop_times_filt.show(3)

+-----------+--------------------+------------+--------------+-------------+----------+------------+----------+--------------------+---------+--------+--------------+
|    stop_id|             trip_id|arrival_time|departure_time|stop_sequence|service_id|    route_id|route_desc|           stop_name| stop_lat|stop_lon|parent_station|
+-----------+--------------------+------------+--------------+-------------+----------+------------+----------+--------------------+---------+--------+--------------+
|8503855:0:F|1005.TA.26-131-j1...|    16:14:00|      16:14:00|            1|  TA+b0nfm|26-131-j19-1|       Bus|     Horgen, Bahnhof|47.261856|8.596976|      8503855P|
|    8589111|1005.TA.26-131-j1...|    16:15:00|      16:15:00|            2|  TA+b0nfm|26-131-j19-1|       Bus|Horgen, Gumelenst...|47.260857|8.592305|              |
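|    8573553|1005.TA.26-131-j1...|    16:16:00|      16:16:00|            3|  TA+b0nfm|26-131-j19-1|       Bus|     Horgen, Stocker|47.261517|8.588927|              |
+-----------+--------------------+------------+--------------+-------------+----------+------------+----------+--------------------+---------+--------+--------------+
only showing top 3 rows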
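
The hour-based filter can be mimicked in plain Python to see exactly which times it keeps. This is a minimal sketch with a hypothetical helper, `in_hour_window`, which extracts the first two characters of the time string and compares them as an integer, in the same way as the filter above.

```python
def in_hour_window(time_str, first_hour=7, last_hour=17):
    """Mimic the Spark filter: compare only the HH part of an HH:MM:SS string."""
    hour = int(time_str[:2])
    return first_hour <= hour <= last_hour

times = ["06:59:00", "07:00:00", "17:00:00", "17:59:00", "18:00:00"]
checks = {t: in_hour_window(t) for t in times}

for t, kept in checks.items():
    print(t, kept)
```

Times such as 17:59:00 pass the filter, so the effective window runs from 07:00:00 to 17:59:59. If a strict 5PM cut-off were wanted, comparing the full time rather than only the hour would be one way to get it.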

## Create the timetable

Re-format the stop_times_filt dataframe and group it such that we get the table in the desired format for creating the graph. 

In [34]:
# Convert timestamps to unix timestamp format
timetable_data = stop_times_filt.withColumn('arrival_time', f.unix_timestamp(stop_times_filt.arrival_time, 'HH:mm:ss'))
timetable_data = timetable_data.withColumn('departure_time', f.unix_timestamp(timetable_data.departure_time, 'HH:mm:ss'))
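
Note that `unix_timestamp` interprets the string in the Spark session time zone, so the result is not necessarily the number of seconds since midnight. The grouped output in this notebook shows 54840 for what appears to be the 16:14:00 arrival, which is one hour less than seconds since midnight and would be consistent with a session time zone one hour ahead of UTC. Differences between two times are unaffected. The sketch below shows a time-zone-independent conversion in plain Python; the helper name `seconds_since_midnight` is an assumption, not part of the notebook.

```python
def seconds_since_midnight(time_str):
    """Convert an HH:MM:SS string to seconds since midnight, without any time zone."""
    hours, minutes, seconds = (int(part) for part in time_str.split(":"))
    return hours * 3600 + minutes * 60 + seconds

arrival = seconds_since_midnight("16:14:00")

# Difference to the value 54840 shown in the notebook's grouped output
offset = arrival - 54840

print(arrival, offset)
```

This form also handles GTFS times past midnight such as 25:10:00, which a timestamp parser with an `HH` pattern may reject. That does not matter here, because only times between 7 and 17 are kept.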

In [32]:
# Group by trip and collect the stops, times and coordinates of each trip into lists
timetable = timetable_data.groupBy(['trip_id']).agg(f.collect_list('stop_id').alias('stop_ids'), 
                                                    f.collect_list('stop_sequence').alias('stop_sequence'), 
                                                    f.collect_list('arrival_time').alias('arrival_times'), 
                                                    f.collect_list('departure_time').alias('departure_times'), 
                                                    f.collect_list('stop_lat').alias('lats'), 
                                                    f.collect_list('stop_lon').alias('longs'))
timetable.show(3)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|             trip_id|            stop_ids|       stop_sequence|       arrival_times|     departure_times|                lats|               longs|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|1005.TA.26-131-j1...|[8503855:0:F, 858...|[1, 2, 3, 4, 5, 6...|[54840, 54900, 54...|[54840, 54900, 54...|[47.261856, 47.26...|[8.596976, 8.5923...|
|103.TA.26-925-j19...|  [8576082, 8576080]|            [20, 21]|      [23880, 24120]|      [23880, 24120]|[47.26727, 47.26944]|[8.650713, 8.644883]|
|104.TA.26-733-j19...|[8573205:0:D, 858...|[1, 2, 3, 4, 5, 6...|[23580, 23580, 23...|[23580, 23580, 23...|[47.450684, 47.45...|[8.563729, 8.5656...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows
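
One caveat: `collect_list` does not guarantee the order of the collected elements after a shuffle, so the lists of a trip are not necessarily sorted by `stop_sequence`, even though they look sorted in the sample above. The sketch below illustrates the intended result in plain Python, sorting each trip's rows by `stop_sequence` before building the parallel lists. The function name `build_timetable` and the sample rows are made up for illustration.

```python
from collections import defaultdict

def build_timetable(rows):
    """Group rows by trip_id and build parallel lists ordered by stop_sequence."""
    by_trip = defaultdict(list)
    for row in rows:
        by_trip[row["trip_id"]].append(row)

    timetable = {}
    for trip_id, trip_rows in by_trip.items():
        # Sort explicitly instead of relying on the incoming row order
        trip_rows.sort(key=lambda r: r["stop_sequence"])
        timetable[trip_id] = {
            "stop_ids": [r["stop_id"] for r in trip_rows],
            "stop_sequence": [r["stop_sequence"] for r in trip_rows],
            "arrival_times": [r["arrival_time"] for r in trip_rows],
        }
    return timetable

# Made-up rows arriving out of order
rows = [
    {"trip_id": "t1", "stop_id": "C", "stop_sequence": 3, "arrival_time": 300},
    {"trip_id": "t1", "stop_id": "A", "stop_sequence": 1, "arrival_time": 100},
    {"trip_id": "t1", "stop_id": "B", "stop_sequence": 2, "arrival_time": 200},
]
result = build_timetable(rows)
print(result["t1"])
```

In Spark, one possible way to get the same guarantee is to collect a struct whose first field is `stop_sequence` and sort the resulting array with `sort_array` before extracting the individual fields.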

## Saving necessary data for further use

We want to use some of these dataframes in other notebooks, so we save them as follows.

In [29]:
# Saving dataframes in orc format
stop_times_filt.write.format("orc").mode('overwrite').save("/user/fristedt/stop_times_filt.orc")

In [33]:
# Saving dataframes in parquet format
timetable.write.format("parquet").mode('overwrite').save("/user/fristedt/timetable.parquet")
stops_filt.write.format("parquet").mode('overwrite').save("/user/fristedt/stops_filt.parquet")
