Skip to content

Commit

Permalink
add ability to restrict a gtfs import to a single day
Browse files Browse the repository at this point in the history
  • Loading branch information
bmander committed Nov 3, 2010
1 parent d3aa0de commit a69a12b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
27 changes: 20 additions & 7 deletions pygs/graphserver/compiler/gdb_import_gtfs.py
@@ -1,10 +1,11 @@
from graphserver.core import Graph, TripBoard, HeadwayBoard, HeadwayAlight, Crossing, TripAlight, Timezone, Street, Link, ElapseTime
from optparse import OptionParser
from graphserver.graphdb import GraphDatabase
from graphserver.ext.gtfs.gtfsdb import GTFSDatabase
from graphserver.ext.gtfs.gtfsdb import GTFSDatabase, parse_gtfs_date
import sys
import pytz
from tools import service_calendar_from_timezone
import datetime

def cons(ary):
for i in range(len(ary)-1):
Expand Down Expand Up @@ -109,7 +110,7 @@ def bundle_to_boardalight_edges(self, bundle, service_id):
to_patternstop_vx_name,
crossing )

def gtfsdb_to_scheduled_edges(self, maxtrips=None):
def gtfsdb_to_scheduled_edges(self, maxtrips=None, service_ids=None):

# compile trip bundles from gtfsdb
if self.reporter: self.reporter.write( "Compiling trip bundles...\n" )
Expand All @@ -122,6 +123,9 @@ def gtfsdb_to_scheduled_edges(self, maxtrips=None):
if self.reporter: self.reporter.write( "%d/%d loading %s\n"%(i+1, n_bundles, bundle) )

for service_id in [x.encode("ascii") for x in self.gtfsdb.service_ids()]:
if service_ids is not None and service_id not in service_ids:
continue

for fromv_label, tov_label, edge in self.bundle_to_boardalight_edges(bundle, service_id):
yield fromv_label, tov_label, edge

Expand Down Expand Up @@ -194,8 +198,8 @@ def gtfsdb_to_transfer_edges( self ):
elif conn_type == 3: # Transfers are not possible between routes at this location.
print "WARNING: Support for no-transfer (transfers.txt transfer_type=3) not implemented."

def gtfsdb_to_edges( self, maxtrips=None ):
for edge_tuple in self.gtfsdb_to_scheduled_edges(maxtrips):
def gtfsdb_to_edges( self, maxtrips=None, service_ids=None ):
for edge_tuple in self.gtfsdb_to_scheduled_edges(maxtrips, service_ids=service_ids):
yield edge_tuple

for edge_tuple in self.gtfsdb_to_headway_edges(maxtrips):
Expand All @@ -204,12 +208,20 @@ def gtfsdb_to_edges( self, maxtrips=None ):
for edge_tuple in self.gtfsdb_to_transfer_edges():
yield edge_tuple

def gdb_load_gtfsdb(gdb, agency_namespace, gtfsdb, cursor, agency_id=None, maxtrips=None, reporter=sys.stdout):
def gdb_load_gtfsdb(gdb, agency_namespace, gtfsdb, cursor, agency_id=None, maxtrips=None, sample_date=None, reporter=sys.stdout):

# determine which service periods run on the given day, if a day is given
if sample_date is not None:
sample_date = datetime.date( *parse_gtfs_date( sample_date ) )
acceptable_service_ids = gtfsdb.service_periods( sample_date )
print "Importing only service periods operating on %s: %s"%(sample_date, acceptable_service_ids)
else:
acceptable_service_ids = None

compiler = GTFSGraphCompiler( gtfsdb, agency_namespace, agency_id, reporter )
c = gdb.get_cursor()
v_added = set([])
for fromv_label, tov_label, edge in compiler.gtfsdb_to_edges( maxtrips ):
for fromv_label, tov_label, edge in compiler.gtfsdb_to_edges( maxtrips, service_ids=acceptable_service_ids ):
if fromv_label not in v_added:
gdb.add_vertex( fromv_label, c )
v_added.add(fromv_label)
Expand Down Expand Up @@ -237,6 +249,7 @@ def main():
parser.add_option("-n", "--namespace", dest="namespace", default="0",
help="agency namespace")
parser.add_option("-m", "--maxtrips", dest="maxtrips", default=None, help="maximum number of trips to load")
parser.add_option("-d", "--date", dest="sample_date", default=None, help="only load transit running on a given day. YYYYMMDD" )

(options, args) = parser.parse_args()

Expand All @@ -254,7 +267,7 @@ def main():
gdb = GraphDatabase( graphdb_filename, overwrite=False )

maxtrips = int(options.maxtrips) if options.maxtrips else None
gdb_load_gtfsdb( gdb, options.namespace, gtfsdb, gdb.get_cursor(), agency_id, maxtrips=maxtrips)
gdb_load_gtfsdb( gdb, options.namespace, gtfsdb, gdb.get_cursor(), agency_id, maxtrips=maxtrips, sample_date=options.sample_date)
gdb.commit()

print "done"
Expand Down
14 changes: 7 additions & 7 deletions pygs/graphserver/ext/gtfs/gtfsdb.py
Expand Up @@ -405,26 +405,26 @@ def date_range(self):
DOWS = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']
DOW_INDEX = dict(zip(range(len(DOWS)),DOWS))

def service_periods(self, datetime):
datetimestr = datetime.strftime( "%Y%m%d" ) #datetime to string like "20081225"
def service_periods(self, sample_date):
datetimestr = sample_date.strftime( "%Y%m%d" ) #sample_date to string like "20081225"
datetimeint = int(datetimestr) #int like 20081225. These ints have the same ordering as regular dates, so comparison operators work

# Get the gtfs date range. If the datetime is out of the range, no service periods are in effect
# Get the gtfs date range. If the sample_date is out of the range, no service periods are in effect
start_date, end_date = self.date_range()
if datetime < start_date or datetime > end_date:
if sample_date < start_date or sample_date > end_date:
return []

# Use the day-of-week name to query for all service periods that run on that day
dow_name = self.DOW_INDEX[datetime.weekday()]
dow_name = self.DOW_INDEX[sample_date.weekday()]
service_periods = list( self.execute( "SELECT service_id, start_date, end_date FROM calendar WHERE %s=1"%dow_name ) )

# Exclude service periods whose range does not include this datetime
# Exclude service periods whose range does not include this sample_date
service_periods = [x for x in service_periods if (int(x[1]) <= datetimeint and int(x[2]) >= datetimeint)]

# Cut service periods down to service IDs
sids = set( [x[0] for x in service_periods] )

# For each exception on the given datetime, add or remove service_id to the accumulating list
# For each exception on the given sample_date, add its service_id to, or remove it from, the accumulating set

for exception_sid, exception_type in self.execute( "select service_id, exception_type from calendar_dates WHERE date = ?", (datetimestr,) ):
if exception_type == 1:
Expand Down

0 comments on commit a69a12b

Please sign in to comment.