In [1]:
sc  # presumably the SparkContext pre-created by the PySpark kernel (not defined in this notebook) — confirm

In [2]:
# Raw text RDDs with one element per CSV line (header line included).
taxi, bike = sc.textFile('yellow.csv.gz'), sc.textFile('citibike.csv')

In [3]:
[(idx, column) for idx, column in enumerate(bike.first().split(','))]  # column index -> header name

[(0, 'cartodb_id'),
 (1, 'the_geom'),
 (2, 'tripduration'),
 (3, 'starttime'),
 (4, 'stoptime'),
 (5, 'start_station_id'),
 (6, 'start_station_name'),
 (7, 'start_station_latitude'),
 (8, 'start_station_longitude'),
 (9, 'end_station_id'),
 (10, 'end_station_name'),
 (11, 'end_station_latitude'),
 (12, 'end_station_longitude'),
 (13, 'bikeid'),
 (14, 'usertype'),
 (15, 'birth_year'),
 (16, 'gender')]

In [4]:
[(idx, column) for idx, column in enumerate(taxi.first().split(','))]  # column index -> header name

[(0, 'tpep_pickup_datetime'),
 (1, 'tpep_dropoff_datetime'),
 (2, 'pickup_latitude'),
 (3, 'pickup_longitude'),
 (4, 'dropoff_latitude'),
 (5, 'dropoff_longitude')]

In [16]:
def filterBike(records, station='Greenwich Ave & 8 Ave', day='2015-02-01'):
    """Yield ``(start_time, 1)`` for Citi Bike trips leaving ``station`` on ``day``.

    Intended for ``bike.mapPartitions``: ``records`` iterates over the raw CSV
    lines of one partition. Column positions follow the header listed above
    (index 3 = ``starttime``, index 6 = ``start_station_name``).

    Args:
        records: iterable of CSV text lines.
        station: exact start-station name to keep.
        day: ``YYYY-MM-DD`` prefix that ``starttime`` must begin with.

    Yields:
        ``(timestamp, 1)``, where ``timestamp`` is the first 19 characters of
        ``starttime`` (``YYYY-MM-DD HH:MM:SS``) and ``1`` tags the row as a bike
        trip (the taxi filter tags its rows with ``0``).
    """
    import csv  # imported here so the function is self-contained on executors

    # csv.reader instead of str.split(',') so that a quoted field containing a
    # comma cannot shift the column positions; plain lines parse the same way.
    for fields in csv.reader(records):
        # A blank or truncated line would otherwise raise IndexError and fail
        # the whole Spark task. The header row never passes the filters below.
        if len(fields) <= 6:
            continue
        if fields[6] == station and fields[3].startswith(day):
            yield (fields[3][:19], 1)

# Run filterBike once per partition; each partition's lines arrive as one iterator.
matchedBike = bike.mapPartitions(filterBike, preservesPartitioning=False)

In [17]:
matchedBike.take(num=2)  # peek at the first two matched bike trips

[('2015-02-01 00:05:00', 1), ('2015-02-01 00:05:00', 1)]

In [18]:
# (longitude, latitude) in degrees, longitude first: the order filterTaxi
# passes to pyproj. Same pair as the literal used inside filterTaxi.
bikeStation = -74.00263761, 40.73901691

In [25]:
def filterTaxi(pid, lines, day='2015-02-01',
               station_lonlat=(-74.00263761, 40.73901691), radius_ft=1320):
    """Yield ``(dropoff_time, 0)`` for taxi trips dropped off near a point on ``day``.

    Intended for ``taxi.mapPartitionsWithIndex``. Column positions follow the
    header listed above: 1 = ``tpep_dropoff_datetime``, 4 = ``dropoff_latitude``,
    5 = ``dropoff_longitude``.

    Args:
        pid: partition index. The first line of partition 0 (the CSV header
            returned by ``taxi.first()``) is dropped.
        lines: iterator over the raw CSV lines of the partition.
        day: ``YYYY-MM-DD`` prefix the dropoff time must begin with.
        station_lonlat: ``(longitude, latitude)`` of the target point; the
            default is the same pair as ``bikeStation``.
        radius_ft: inclusive search radius in EPSG:2263 units, which
            ``preserve_units=True`` keeps as US survey feet (1320 ft = 1/4 mile).

    Yields:
        ``(timestamp, 0)``, where ``timestamp`` is the first 19 characters of
        the dropoff time and ``0`` tags the row as a taxi trip.
    """
    import pyproj  # imported here so executors resolve it themselves

    if pid == 0:
        # Default value: an empty partition 0 must not raise StopIteration,
        # which becomes RuntimeError inside a generator (PEP 479).
        next(lines, None)
    proj = pyproj.Proj(init="epsg:2263", preserve_units=True)
    station_x, station_y = proj(*station_lonlat)
    squared_radius = radius_ft ** 2
    for trip in lines:
        fields = trip.split(',')
        if len(fields) < 6:
            continue  # blank or truncated line
        dropoff_time = fields[1]
        # Cheap string test first, so rows from other days are never projected.
        if not dropoff_time.startswith(day):
            continue
        try:
            lat, lon = float(fields[4]), float(fields[5])
        except ValueError:
            continue  # 'NULL', empty, or otherwise non-numeric coordinates
        x, y = proj(lon, lat)
        # Compare squared distances to avoid a sqrt per row.
        if (x - station_x) ** 2 + (y - station_y) ** 2 <= squared_radius:
            yield (dropoff_time[:19], 0)

# The partition index is passed through so filterTaxi can drop the header line.
matchedTaxi = taxi.mapPartitionsWithIndex(filterTaxi, preservesPartitioning=False)
matchedTaxi.count()

7278

In [20]:
matchedTaxi.take(num=2)  # peek at the first two matched taxi dropoffs

[('2015-02-01 00:11:03', 0), ('2015-02-01 00:10:23', 0)]

In [22]:
# Bike departures (mode 1) and taxi dropoffs (mode 0) in one RDD, ordered by the
# 'YYYY-MM-DD HH:MM:SS' key (lexicographic = chronological); cached because
# later cells reuse it.
allTrips = matchedBike.union(matchedTaxi).sortByKey().cache()

In [30]:
def connectTrips(_, records, window_seconds=600):
    """Count bike departures preceded by a taxi dropoff within ``window_seconds``.

    Intended for ``mapPartitionsWithIndex`` over ``(timestamp, mode)`` pairs,
    with ``mode`` 0 = taxi dropoff and 1 = bike departure. Only the most recent
    taxi seen so far is remembered, so ``records`` must already be sorted by
    time, with a taxi placed before a bike that shares its timestamp.

    The "last taxi" state lives only for one call, i.e. one partition. A taxi
    at the end of one partition is invisible to a bike at the start of the
    next, so the total is exact only over a single sorted partition.

    Args:
        _: partition index (unused).
        records: iterable of ``('YYYY-MM-DD HH:MM:SS', mode)`` pairs.
        window_seconds: inclusive look-back window in seconds.

    Yields:
        One int: the number of bike rows that have a taxi row between 0 and
        ``window_seconds`` seconds earlier.
    """
    import datetime  # imported here so the function is self-contained on executors

    lastTaxiTime = None
    count = 0
    for dt, mode in records:
        t = datetime.datetime.strptime(dt, '%Y-%m-%d %H:%M:%S')
        if mode != 1:
            lastTaxiTime = t
        elif lastTaxiTime is not None:
            diff = (t - lastTaxiTime).total_seconds()
            if 0 <= diff <= window_seconds:
                count += 1
    yield count

# connectTrips remembers the last taxi only within one partition. A
# multi-partition sortByKey would drop bikes whose preceding taxi sits at the
# end of the previous partition, so sort into ONE partition on the whole
# (timestamp, mode) tuple. That also puts a taxi (mode 0) before a bike
# (mode 1) at the same second, matching the SQL RANGE window further down,
# which counts rows tied with the current one. A single partition is
# affordable here: it is one day of matches (~7.3k taxi rows).
(allTrips
    .sortBy(lambda trip: trip, numPartitions=1)
    .mapPartitionsWithIndex(connectTrips)
    .sum())

65

In [35]:
# Columns: `time` (timestamp string) and `mode` (0 = taxi dropoff, 1 = bike departure).
dfAll = sqlContext.createDataFrame(allTrips, ['time', 'mode'])
dfAll.show(20)

+-------------------+----+
|               time|mode|
+-------------------+----+
|2015-02-01 00:03:12|   0|
|2015-02-01 00:04:39|   0|
|2015-02-01 00:05:00|   1|
|2015-02-01 00:05:00|   1|
|2015-02-01 00:05:38|   0|
|2015-02-01 00:06:15|   0|
|2015-02-01 00:07:07|   0|
|2015-02-01 00:07:29|   0|
|2015-02-01 00:07:57|   0|
|2015-02-01 00:08:56|   0|
|2015-02-01 00:08:57|   0|
|2015-02-01 00:09:17|   0|
|2015-02-01 00:09:52|   0|
|2015-02-01 00:10:12|   0|
|2015-02-01 00:10:14|   0|
|2015-02-01 00:10:23|   0|
|2015-02-01 00:10:34|   0|
|2015-02-01 00:10:56|   0|
|2015-02-01 00:11:01|   0|
|2015-02-01 00:11:02|   0|
+-------------------+----+
only showing top 20 rows



In [37]:
# Whole seconds since the epoch. The strings are interpreted in a local time
# zone (the output shows a UTC-5 offset), which shifts every row equally and
# so leaves the time differences used below unchanged.
dfTrips = dfAll.select(
    dfAll.time.cast('timestamp').cast('long').alias('epoch'),
    'mode',
)
dfTrips.take(num=2)

[Row(epoch=1422766992, mode=0), Row(epoch=1422767079, mode=0)]

In [38]:
# Make dfTrips queryable from SQL as `trips`.
# NOTE(review): registerTempTable is deprecated since Spark 2.0 in favour of
# createOrReplaceTempView; presumably kept for Spark 1.x (sqlContext) — confirm.
dfTrips.registerTempTable(name='trips')

In [64]:
# SQL version of the count. For every row, MIN(mode) is taken over all rows
# whose epoch lies in [epoch - 600, epoch]. It is a RANGE frame, so rows tied
# with the current epoch are included. MIN is 0 exactly when that frame
# contains a taxi row (mode 0), so 1 - MIN flags a bike row that has a taxi
# dropoff within the preceding 600 s. Summing the flag over bike rows
# (mode = 1) gives the count.
# NOTE(review): the window has no PARTITION BY, so Spark moves all rows into a
# single partition to evaluate it — fine for one day of trips, not at scale.
statement = '''
SELECT sum(has_taxi)
FROM (SELECT mode, 1-MIN(mode) OVER 
                    (ORDER BY epoch RANGE BETWEEN 600 PRECEDING AND CURRENT ROW) 
                    AS has_taxi
        FROM trips) newTrips
WHERE mode=1
'''
sqlContext.sql(statement).show()

+-------------+
|sum(has_taxi)|
+-------------+
|           65|
+-------------+



In [57]:
import pyspark.sql.functions as sf
import pyspark.sql.window as sw

# DataFrame-API version of the SQL query. The frame covers epochs in
# [current - 600, current]; 1 - MIN(mode) over it is 1 exactly when a taxi row
# (mode 0) falls inside, and those flags are summed over bike rows.
window = sw.Window.orderBy('epoch').rangeBetween(-600, 0)
results = (
    dfTrips
    .select('mode', (1 - sf.min(dfTrips['mode']).over(window)).alias('has_taxi'))
    .filter(dfTrips['mode'] == 1)
    .select(sf.sum(sf.col('has_taxi')))
)
results.show()

+-------------+
|sum(has_taxi)|
+-------------+
|           65|
+-------------+

