# Using Apache Spark to create Station by Route dataset
**Rationale:** Because the volume of data could become very large, it is important to build a stack that can handle distributed compute, even if the dataset currently in use does not *require* it.

In [1]:
# start a Spark session
import pyspark

session = pyspark.sql.SparkSession
# Request a 10 GB heap for the driver, then build the session (or pick up an existing one)
spark = (
    session.builder
    .config('spark.driver.memory', '10g')
    .getOrCreate()
)
spark

## Get basic info on datasets

In [3]:
# Set dataset paths
stations_path = "MTA_data/MTA_Subway_Stations_20250717.csv"

# Function to read in CSV files
def read_data(path):
    """Load a CSV file into a Spark DataFrame and print its schema.

    The first line of the file is treated as the header row. No schema
    inference is requested, so the reader's default column types apply.

    Parameters
    ----------
    path : str
        Location of the CSV file to read.

    Returns
    -------
    pyspark.sql.DataFrame
        The loaded data. The schema is printed as a side effect.
    """
    frame = spark.read.csv(path, header=True)
    frame.printSchema()
    return frame

In [4]:
stations = read_data(stations_path)

root
 |-- GTFS Stop ID: string (nullable = true)
 |-- Station ID: string (nullable = true)
 |-- Complex ID: string (nullable = true)
 |-- Division: string (nullable = true)
 |-- Line: string (nullable = true)
 |-- Stop Name: string (nullable = true)
 |-- Borough: string (nullable = true)
 |-- CBD: string (nullable = true)
 |-- Daytime Routes: string (nullable = true)
 |-- Structure: string (nullable = true)
 |-- GTFS Latitude: string (nullable = true)
 |-- GTFS Longitude: string (nullable = true)
 |-- North Direction Label: string (nullable = true)
 |-- South Direction Label: string (nullable = true)
 |-- ADA: string (nullable = true)
 |-- ADA Northbound: string (nullable = true)
 |-- ADA Southbound: string (nullable = true)
 |-- ADA Notes: string (nullable = true)
 |-- Georeference: string (nullable = true)



In [4]:
stations.limit(5).toPandas()

Unnamed: 0,GTFS Stop ID,Station ID,Complex ID,Division,Line,Stop Name,Borough,CBD,Daytime Routes,Structure,GTFS Latitude,GTFS Longitude,North Direction Label,South Direction Label,ADA,ADA Northbound,ADA Southbound,ADA Notes,Georeference
0,R01,1,1,BMT,Astoria,Astoria-Ditmars Blvd,Q,False,N W,Elevated,40.775036,-73.912034,Last Stop,Manhattan,0,0,0,,POINT (-73.912034 40.775036)
1,R03,2,2,BMT,Astoria,Astoria Blvd,Q,False,N W,Elevated,40.770258,-73.917843,Astoria,Manhattan,1,1,1,,POINT (-73.917843 40.770258)
2,R04,3,3,BMT,Astoria,30 Av,Q,False,N W,Elevated,40.766779,-73.921479,Astoria,Manhattan,0,0,0,,POINT (-73.921479 40.766779)
3,R05,4,4,BMT,Astoria,Broadway,Q,False,N W,Elevated,40.76182,-73.925508,Astoria,Manhattan,0,0,0,,POINT (-73.925508 40.76182)
4,R06,5,5,BMT,Astoria,36 Av,Q,False,N W,Elevated,40.756804,-73.929575,Astoria,Manhattan,0,0,0,,POINT (-73.929575 40.756804)


In [5]:
stations.select("Station ID").count()

496

In [6]:
stations.select("Station ID").distinct().count()

493

In [7]:
stations.select("GTFS Stop ID").distinct().count() ## Most unique value here -- should try to match on this

496

In [8]:
stations.select("Stop Name").distinct().count() ## There are duplicate stop names -- need to understand why

378

## From the [codebook](https://data.ny.gov/Transportation/MTA-Subway-Stations/39hk-dx4f/about_data)

- **Division:** 	The division of the subway system (IRT, BMT, or IND) that the station is a part of.
- **CBD:** 	    This indicates whether or not a station is in Manhattan’s Central Business District (CBD). This value is either TRUE or FALSE.
- **Daytime Routes:** The subway routes that serve the station during weekdays.

In [9]:
basic_df = stations.select("GTFS Stop ID", "Division", "Stop Name", "Borough", "CBD", "Daytime Routes", "GTFS Latitude", "GTFS Longitude")

In [10]:
basic_df.limit(3).toPandas()

Unnamed: 0,GTFS Stop ID,Division,Stop Name,Borough,CBD,Daytime Routes,GTFS Latitude,GTFS Longitude
0,R01,BMT,Astoria-Ditmars Blvd,Q,False,N W,40.775036,-73.912034
1,R03,BMT,Astoria Blvd,Q,False,N W,40.770258,-73.917843
2,R04,BMT,30 Av,Q,False,N W,40.766779,-73.921479


In [11]:
# Want one row for each route at each subway stop, will use split to create a list of routes at each stop
# Code modified from HW2 solutions
# Tokenize
from pyspark.sql.functions import split, trim, col

 # split on one or more spaces
# Tokenise the space-separated route string into an array column, e.g. "N W" -> [N, W].
# The string is trimmed first, then split on runs of one or more spaces.
routes_expr = split(trim(col("Daytime Routes")), " +")
basic_df = basic_df.withColumn("Routes", routes_expr)
basic_df.select("Routes").show(5)

+------+
|Routes|
+------+
|[N, W]|
|[N, W]|
|[N, W]|
|[N, W]|
|[N, W]|
+------+
only showing top 5 rows



In [12]:
from pyspark.sql.functions import explode
columns = [
    "GTFS Stop ID",
    "Division",
    "Stop Name",
    "Borough",
    "CBD",
    "Route",
    "GTFS Latitude",
    "GTFS Longitude",
]
# explode() emits one row per element of the Routes array, giving one row per (stop, route)
exploded_df = basic_df.withColumn("Route", explode("Routes"))
# Keep only the columns of interest; this drops the intermediate Routes array
basic_df = exploded_df.select(*columns)
basic_df.show(5)

+------------+--------+--------------------+-------+-----+-----+-------------+--------------+
|GTFS Stop ID|Division|           Stop Name|Borough|  CBD|Route|GTFS Latitude|GTFS Longitude|
+------------+--------+--------------------+-------+-----+-----+-------------+--------------+
|         R01|     BMT|Astoria-Ditmars Blvd|      Q|False|    N|    40.775036|    -73.912034|
|         R01|     BMT|Astoria-Ditmars Blvd|      Q|False|    W|    40.775036|    -73.912034|
|         R03|     BMT|        Astoria Blvd|      Q|False|    N|    40.770258|    -73.917843|
|         R03|     BMT|        Astoria Blvd|      Q|False|    W|    40.770258|    -73.917843|
|         R04|     BMT|               30 Av|      Q|False|    N|    40.766779|    -73.921479|
+------------+--------+--------------------+-------+-----+-----+-------------+--------------+
only showing top 5 rows



## Check that route explode worked correctly
- Filter for one line, and check the stops listed

In [13]:
basic_df.where(col("Route") == "7").show(25)

+------------+--------+--------------------+-------+-----+-----+-------------+--------------+
|GTFS Stop ID|Division|           Stop Name|Borough|  CBD|Route|GTFS Latitude|GTFS Longitude|
+------------+--------+--------------------+-------+-----+-----+-------------+--------------+
|         701|     IRT|    Flushing-Main St|      Q|False|    7|      40.7596|     -73.83003|
|         702|     IRT|  Mets-Willets Point|      Q|False|    7|    40.754622|    -73.845625|
|         705|     IRT|              111 St|      Q|False|    7|     40.75173|    -73.855334|
|         706|     IRT| 103 St-Corona Plaza|      Q|False|    7|    40.749865|      -73.8627|
|         707|     IRT|       Junction Blvd|      Q|False|    7|    40.749145|    -73.869527|
|         708|     IRT|   90 St-Elmhurst Av|      Q|False|    7|    40.748408|    -73.876613|
|         709|     IRT|   82 St-Jackson Hts|      Q|False|    7|    40.747659|    -73.883697|
|         710|     IRT|      74 St-Broadway|      Q|False|  

In [14]:
basic_df.where(col("GTFS Stop ID") == "725").show(10) ## This is the Times Square station, should have many routes

+------------+--------+--------------+-------+----+-----+-------------+--------------+
|GTFS Stop ID|Division|     Stop Name|Borough| CBD|Route|GTFS Latitude|GTFS Longitude|
+------------+--------+--------------+-------+----+-----+-------------+--------------+
|         725|     IRT|Times Sq-42 St|      M|True|    7|    40.755477|    -73.987691|
+------------+--------+--------------+-------+----+-----+-------------+--------------+



In [15]:
condition = col("Stop Name").contains("Times Sq")
basic_df.where(condition).show(10)

+------------+--------+--------------+-------+----+-----+-------------+--------------+
|GTFS Stop ID|Division|     Stop Name|Borough| CBD|Route|GTFS Latitude|GTFS Longitude|
+------------+--------+--------------+-------+----+-----+-------------+--------------+
|         R16|     BMT|Times Sq-42 St|      M|True|    N|    40.754672|    -73.986754|
|         R16|     BMT|Times Sq-42 St|      M|True|    Q|    40.754672|    -73.986754|
|         R16|     BMT|Times Sq-42 St|      M|True|    R|    40.754672|    -73.986754|
|         R16|     BMT|Times Sq-42 St|      M|True|    W|    40.754672|    -73.986754|
|         127|     IRT|Times Sq-42 St|      M|True|    1|     40.75529|    -73.987495|
|         127|     IRT|Times Sq-42 St|      M|True|    2|     40.75529|    -73.987495|
|         127|     IRT|Times Sq-42 St|      M|True|    3|     40.75529|    -73.987495|
|         725|     IRT|Times Sq-42 St|      M|True|    7|    40.755477|    -73.987691|
|         902|     IRT|Times Sq-42 St|     

In [16]:
basic_df.count() ## After exploding, we have 767 rows rather than the original 496

767

In [19]:
# Source: https://stackoverflow.com/questions/38611418/writing-a-csv-with-column-names-and-reading-a-csv-file-which-is-being-generated
# Save data as CSV file
# Write the station-by-route table as CSV with a header row.
# mode='overwrite' replaces output left by an earlier run; the default save mode
# raises an error when the target path already exists, which breaks re-running this cell.
basic_df.write.csv('Station_by_Line_Data', header=True, mode='overwrite')

# Query MTA API
1) Get random list of dates
2) Query API to get randomly sampled dataset of crowdedness on subways
3) Use Spark for memory management

In [5]:
# Create a list of dates from available
# Source: https://stackoverflow.com/questions/993358/creating-a-range-of-dates-in-python
import pandas as pd
from datetime import datetime

# Every calendar day from 1 January 2025 through 10 July 2025, inclusive.
# ISO 8601 (YYYY-MM-DD) literals avoid the month-first/day-first ambiguity of '07/10/2025'.
datelist = pd.date_range(start = '2025-01-01', end = '2025-07-10').tolist()

In [6]:
datelist[0:4]

[Timestamp('2025-01-01 00:00:00'),
 Timestamp('2025-01-02 00:00:00'),
 Timestamp('2025-01-03 00:00:00'),
 Timestamp('2025-01-04 00:00:00')]

In [7]:
# Source: https://stackoverflow.com/questions/48918009/set-the-date-to-the-next-day-date-in-python
from dateutil.relativedelta import relativedelta

def tomorrow(date):
    """Return ``date`` moved forward by one day.

    Parameters
    ----------
    date : date-like
        Any object that supports addition with ``relativedelta``
        (the notebook passes pandas ``Timestamp`` values).

    Returns
    -------
    The result of adding a one-day ``relativedelta`` to ``date``.
    """
    one_day = relativedelta(days=1)
    return date + one_day

tomorrow(datelist[0])

Timestamp('2025-01-02 00:00:00')

In [8]:
len(datelist)

191

In [9]:
from random import randint, seed
from datetime import datetime

In [10]:
seed(20250719)  # fixed seed so the same dates are drawn on every run

start = 0
end = len(datelist) - 1
num = 20  # how many distinct dates to draw
rand_dates = []  # distinct dates, in the order they were drawn

# Rejection sampling: pick a random index (both ends inclusive) and keep the
# date only if it has not been drawn before, until `num` dates are collected.
while len(rand_dates) < num:
    rand_date = datelist[randint(start, end)]
    if rand_date in rand_dates:
        continue
    rand_dates.append(rand_date)

In [11]:
rand_dates.sort()
rand_dates

[Timestamp('2025-01-07 00:00:00'),
 Timestamp('2025-01-09 00:00:00'),
 Timestamp('2025-01-18 00:00:00'),
 Timestamp('2025-02-14 00:00:00'),
 Timestamp('2025-02-21 00:00:00'),
 Timestamp('2025-03-14 00:00:00'),
 Timestamp('2025-03-17 00:00:00'),
 Timestamp('2025-03-25 00:00:00'),
 Timestamp('2025-04-20 00:00:00'),
 Timestamp('2025-04-23 00:00:00'),
 Timestamp('2025-04-29 00:00:00'),
 Timestamp('2025-05-03 00:00:00'),
 Timestamp('2025-05-05 00:00:00'),
 Timestamp('2025-05-28 00:00:00'),
 Timestamp('2025-06-11 00:00:00'),
 Timestamp('2025-06-20 00:00:00'),
 Timestamp('2025-06-24 00:00:00'),
 Timestamp('2025-07-05 00:00:00'),
 Timestamp('2025-07-09 00:00:00'),
 Timestamp('2025-07-10 00:00:00')]

## Test with small subset CSV matching timestamps

In [12]:
jan1_path = "MTA_data/MTA_Subway_Hourly_Ridership__Beginning_2025_20250719.csv"
jan1 = read_data(jan1_path)

root
 |-- transit_timestamp: string (nullable = true)
 |-- transit_mode: string (nullable = true)
 |-- station_complex_id: string (nullable = true)
 |-- station_complex: string (nullable = true)
 |-- borough: string (nullable = true)
 |-- payment_method: string (nullable = true)
 |-- fare_class_category: string (nullable = true)
 |-- ridership: string (nullable = true)
 |-- transfers: string (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)
 |-- Georeference: string (nullable = true)



In [13]:
jan1.limit(5).toPandas()

Unnamed: 0,transit_timestamp,transit_mode,station_complex_id,station_complex,borough,payment_method,fare_class_category,ridership,transfers,latitude,longitude,Georeference
0,01/01/2025 07:00:00 PM,subway,10,"49 St (N,R,W)",Manhattan,omny,OMNY - Full Fare,753,9,40.7599,-73.98414,POINT (-73.98414 40.7599)
1,01/01/2025 07:00:00 PM,subway,418,"233 St (2,5)",Bronx,omny,OMNY - Students,4,0,40.893192,-73.857475,POINT (-73.857475 40.893192)
2,01/01/2025 07:00:00 PM,subway,409,Spring St (6),Manhattan,omny,OMNY - Other,2,0,40.7223,-73.99714,POINT (-73.99714 40.7223)
3,01/01/2025 07:00:00 PM,subway,455,69 St (7),Queens,omny,OMNY - Students,7,0,40.746326,-73.8964,POINT (-73.8964 40.746326)
4,01/01/2025 07:00:00 PM,subway,65,79 St (D),Brooklyn,omny,OMNY - Students,6,0,40.613503,-74.00061,POINT (-74.00061 40.613503)


## Sources
https://stackoverflow.com/questions/54775015/soda-api-filtering

https://stackoverflow.com/questions/72127658/use-a-python-list-in-a-socrata-request

In [14]:
!pip install sodapy
from sodapy import Socrata
from TOKENS import APP_TOKEN ## APP_TOKEN allows for unlimited requests as of July 2025



## Get January 7 Data

In [15]:
# SoQL query for a single day: total subway ridership per station complex on
# rand_dates[0], selected with the half-open window [that date, the following day).
# This cell only builds the query string.
query = f"""SELECT date_trunc_ymd(transit_timestamp) AS transit_date,
                   station_complex,
                   SUM(ridership) AS total_ridership
            WHERE transit_timestamp >= "{rand_dates[0].isoformat()}" AND transit_timestamp < "{tomorrow(rand_dates[0]).isoformat()}"
            AND transit_mode = 'subway'
            GROUP BY 
              transit_date, 
              station_complex
            ORDER BY station_complex ASC NULLS LAST """

## Get All Data from Random Dates

In [16]:
def getJSON(dates, limit = 10000):
    """Fetch daily subway ridership totals per station complex for each date.

    Sends one SoQL query per element of ``dates`` to Socrata dataset
    ``5wq4-mkjj`` on data.ny.gov and concatenates the returned rows.

    Parameters
    ----------
    dates : sequence of date-like objects
        Days to query. Each element must provide ``isoformat()`` and be
        accepted by ``tomorrow``.
    limit : int, default 10000
        Maximum number of rows requested per day. If a day has more grouped
        rows than this, the extra rows are not returned.

    Returns
    -------
    list
        Rows returned by the API for all dates, in the order of ``dates``.
    """
    final = []
    # A single client is reused for every request and closed when done
    client = Socrata("data.ny.gov", APP_TOKEN)
    try:
        for day in dates:
            # Half-open window [day, next day). Both bounds come from the same
            # element of `dates`, so the function works for any list passed in,
            # not just the notebook-level `rand_dates`.
            query = f"""SELECT date_trunc_ymd(transit_timestamp) AS transit_date,
                           station_complex,
                           SUM(ridership) AS total_ridership
                    WHERE transit_timestamp >= "{day.isoformat()}" AND transit_timestamp < "{tomorrow(day).isoformat()}"
                    AND transit_mode = 'subway'
                    GROUP BY 
                      transit_date, 
                      station_complex
                    ORDER BY station_complex ASC NULLS LAST 
                    LIMIT {limit} """
            final.extend(client.get("5wq4-mkjj", query=query))
    finally:
        client.close()
    return final

In [17]:
rand_ridership = getJSON(rand_dates)

In [18]:
len(rand_ridership)

8364

In [19]:
rand_ridership[8363]

{'transit_date': '2025-07-10T00:00:00.000',
 'station_complex': 'Zerega Av (6)',
 'total_ridership': '1658.0'}

In [20]:
rand_ridership[0:5]

[{'transit_date': '2025-01-07T00:00:00.000',
  'station_complex': '103 St (1)',
  'total_ridership': '8693.0'},
 {'transit_date': '2025-01-07T00:00:00.000',
  'station_complex': '103 St (6)',
  'total_ridership': '9483.0'},
 {'transit_date': '2025-01-07T00:00:00.000',
  'station_complex': '103 St (C,B)',
  'total_ridership': '2978.0'},
 {'transit_date': '2025-01-07T00:00:00.000',
  'station_complex': '103 St-Corona Plaza (7)',
  'total_ridership': '18261.0'},
 {'transit_date': '2025-01-07T00:00:00.000',
  'station_complex': '104 St (A)',
  'total_ridership': '1238.0'}]

In [23]:
# Source: https://stackoverflow.com/questions/75473246/how-to-make-a-list-of-json-objects-into-a-json-file-in-python-and-join-multiple

import json

# Serialise the sampled ridership records to disk as a single JSON document
with open('subway_ridership_sample_2025.json', 'w') as f:
    json.dump(rand_ridership, f)