Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Chunk Commit

  • Loading branch information...
commit 3db8f14865b763f4997121d00350c1fa45696a8e 1 parent 3edaa77
@snktagarwal authored
Showing with 204 additions and 30 deletions.
  1. +190 −26 delay_profile/TrainModel.py
  2. +14 −4 delay_profile/tests.py
View
216 delay_profile/TrainModel.py
@@ -6,30 +6,37 @@
import Segments
from Utilities import *
from scipy.stats.stats import pearsonr
+import pickle
util = UtilitiesStat()
class Station:
- def __init__(self, tr_no, tr_nm, stn_nm, sch_arr, del_arr, act_arr, sch_dep, del_dep, act_dep):
+ def __init__(self, tr_no, tr_nm, stn_nm, sch_arr, del_arr, act_arr, sch_dep, del_dep, act_dep, stn_code = None, src_dist = 0):
- self.SEG_INFO_TYPE = ['START', 'MID', 'END', 'NONE']
- self.seg_number = -1
- self.seg_info_type = self.SEG_INFO_TYPE[3]
+ self.SEG_INFO_TYPE = ['START', 'MID', 'END']
+
+ # A station might be part of multiple segments, hence use of lists
+ self.seg_info = []
+
+ self.run_day = 1 # By default train in running on the same day
self.tr_no = tr_no
self.tr_nm = tr_nm
self.stn_nm = stn_nm
- try:
- self.stn_code = Station.disambiguate(util.stn_codes[self.stn_nm])
- except:
- self.stn_code = 'NA'
+ if stn_code: self.stn_code = stn_code
+ else:
+ try:
+ self.stn_code = Station.disambiguate(util.stn_codes[self.stn_nm])
+ except:
+ self.stn_code = 'NA'
self.sch_arr = sch_arr
self.act_arr = act_arr
self.del_arr = del_arr
self.sch_dep = sch_dep
self.act_dep = act_dep
self.del_dep = del_dep
+ self.src_dist = src_dist
def _print(self):
print 'Train Name: ' + str(self.tr_nm)
@@ -42,8 +49,7 @@ def augment_seg_info(self, seg_number, info_type):
""" Augments information to a train regarding the segment.
see self.SEG_INFO_TYPE for the types of information available"""
- self.seg_number = seg_number
- self.seg_info_type = info_type
+ self.seg_info.append([seg_number, info_type])
@staticmethod
def disambiguate(stn_name):
@@ -73,31 +79,40 @@ class Train:
delay over a segment, and other statistics *particular to a train* """
def __init__(self, tr_no):
+
""" Starts a simple train as an empty object """
self.stn_list = []
self.tr_no = tr_no
+ def update_segment(self, seg_no):
+
+ """ Takes a segment number as argument and fills in information
+ to various stations if they lie on this segment. This is useful
+ for constructing Train -> STN1 || STN2 .. list augmented with
+ segment information """
+
+
def isSegment(self, seg):
+
""" Given a segment as a list of stops and train information
figure out the left and right end points of segment on the train
A Train is said to cross a segment if it has atleast 2 stops from
the segment as a part of the train schedule """
- self.seg = seg
- train = self
+
# Find the intersection of segment with train
- if len(set(map(lambda x: x.stn_code, train.stn_list)).intersection(set(self.seg)))>=2:
+ if len(set(map(lambda x: x.stn_code, self.stn_list)).intersection(set(seg)))>=2:
# Find the min intersection point
mini = 100
- for stn in self.seg:
- for i in range(len(train.stn_list)):
- if train.stn_list[i].stn_code==stn and mini > i:
+ for stn in seg:
+ for i in range(len(self.stn_list)):
+ if self.stn_list[i].stn_code==stn and mini > i:
mini = i
maxi = -1
- for stn in self.seg:
- for i in range(len(train.stn_list)):
- if train.stn_list[i].stn_code==stn and maxi < i:
+ for stn in seg:
+ for i in range(len(self.stn_list)):
+ if self.stn_list[i].stn_code==stn and maxi < i:
maxi = i
return [mini, maxi]
@@ -106,6 +121,16 @@ def isSegment(self, seg):
def addStn(self, stn):
+
+ """ Given a station, adds it to the list of existing station stops
+ of a train. Additional check require to update the day if train has
+ run overnight """
+
+ if len(self.stn_list) > 1 and self.stn_list[-1].sch_arr > stn.sch_arr:
+ stn.run_day = self.stn_list[-1].run_day + 1
+ elif len(self.stn_list) > 1:
+ stn.run_day = self.stn_list[-1].run_day
+
self.stn_list.append(stn)
def getDelayOverSegment(self, lidx, ridx):
@@ -160,11 +185,6 @@ def getTrafficOverSegment(self, lidx, ridx):
#print delay
return act_stay_seg
- @staticmethod
- def runConsistencyCheck(idx):
- for (k,v) in idx.iteritems():
- v.checkStationsWithTT()
-
class Segment:
""" This class describes each segment defined in Segments.py.
@@ -343,6 +363,124 @@ def __init__(self, filename):
self.filename = filename
+ def constructTimeTableIndex(self):
+
+ """ Constructs a static index of train information using a train
+ station index. """
+
+ f = file(self.filename, 'r')
+ idx = {}
+
+ lines = f.readlines()
+
+ # Skip all the comments
+ while lines[0].startswith('#') or lines[0].startswith(' '):
+ lines = lines[1:]
+
+ # Read each train carefully
+ for train in lines:
+
+ parts = train.strip().split()
+ if len(parts) < 3: continue
+ tr_no = parts[0]
+ parts = parts[1:]
+
+ # Create a new train
+ idx[tr_no] = Train(tr_no)
+
+ # Read each station
+ while parts:
+ [stn_code, arr_time, dep_time, dist_source] = parts[0:4]
+ print parts[0:4]
+ arr_time = util.toMin(arr_time)
+ dep_time = util.toMin(dep_time)
+ parts = parts[4:]
+ idx[tr_no].addStn(Station(tr_no, 'X', 'X', arr_time, 0, 0, dep_time, 0, 0, stn_code, int(dist_source)))
+
+ self.tt_idx = idx
+ return self.tt_idx
+
+ def augmentTimeTableWithSegments(self):
+
+ """ Augments information about segments in the time table object """
+
+ for i in range(len(Segments.all_segments)):
+
+ s = Segment(Segments.all_segments[i])
+
+ for (tr_no, train) in self.tt_idx.iteritems():
+
+ l1, l2 = s.isSegment(train)
+
+
+ if not -1 in [l1, l2]:
+
+ # Starting station
+ train.stn_list[l1].augment_seg_info(i, 0)
+
+ # Ending station
+ train.stn_list[l2].augment_seg_info(i, 2)
+
+ # Rest are all mid segments
+ for j in range(l1+1, l2):
+ train.stn_list[j].augment_seg_info(i, 1)
+
+ return self.tt_idx
+
+ def printAugmentedTimeTable(self, output):
+
+ """ Prints the new segment information augmented timetable
+ The format is as follows:
+ TRAIN || STN1 || STN2 || ... || STN_N
+ where each STN_i expands as:
+ STN_i : STN_CODE || ARR || DEP || SOURCE || SEG_INFO_LIST
+ where SEG_INFO_LIST expands as:
+ SEG_INFO_LIST : SEG_NO, STATUS, SEG_NO, STATUS, -1 -1 (ends with -1)
+ Please look at the final output to gain idea. """
+
+ f = file(output, 'w')
+
+ for (tr_no, train) in self.tt_idx.iteritems():
+
+ f.write(tr_no + '||')
+
+ for stn in train.stn_list:
+
+ f.write(stn.stn_code+ '||')
+ f.write(str(stn.sch_arr) + '||')
+ f.write(str(stn.sch_dep) + '||')
+
+ for e in stn.seg_info:
+ f.write(str(e[0]) + ',' + str(e[1]))
+ f.write('-1,-1')
+ f.write('||')
+
+ f.write('\n')
+
+ def printAugmentedTimeTableP(self, output):
+ """ Prints the augmented time table in augmented form. Read the
+ above documentation for an explanation.
+ A dictionary containing train no as keys, has a dictionary
+ having each key as a station name having list arr, dep, source_dist,
+ and a dict having seg_no and status -- complicated huh ;) """
+
+ sol_dict= {}
+
+ for (tr_no, train) in self.tt_idx.iteritems():
+
+ sol_dict[tr_no] = []
+
+ for stn in train.stn_list:
+
+ sol_dict[tr_no].append([ stn.stn_code, stn.sch_arr, stn.sch_dep, {}, stn.src_dist])
+
+ for e in stn.seg_info:
+ sol_dict[tr_no][-1][3][str(e[0])] = str(e[1])
+
+ pickle.dump(sol_dict, open(output, 'wb'))
+
+
+
def constructSegmentsIndex(self):
""" Constructs a segment index. Note that we do not do any
processing about average delay daily, hourly or averge traffic
@@ -372,8 +510,6 @@ def dailyAverageDelayPS(self):
s = s + float(sum(x.total_delay.values()))/len(x.total_delay.values())
self.daily_delay.append(s)
-
-
def dailyAverageTrafficPS(self):
""" Finds the average traffic, which is simply the number of trains
@@ -563,6 +699,34 @@ def totalHourlyDelay(self):
t = map(lambda x: float(x)/len(self.idx_list), t)
self.total_hourly_delay.append(t)
+ def hourVsSegmentDelayMat(self):
+
+ """ Constructs a segment vs hours matrix ( 54 x 12 ) which contains
+ Delay metrics """
+
+ self.seg_hour_del_mat = [[0]*len(self._hours)]*len(Segments.all_segments)
+
+ for i in range(len(Segments.all_segments)):
+ for j in range(len(self._hours)):
+ for idx in self.idx_list:
+ self.seg_hour_del_mat[i][j] += idx.hourly_delay_dict[i][j]
+ self.seg_hour_del_mat[i][j] = float(self.seg_hour_del_mat[i][j])/len(self.seg_hour_del_mat)
+
+ def hourVsSegmentTrafficMat(self):
+
+ """ Constructs a segment vs hours matrix ( 54 x 12 ) which contains
+ Traffic metrics """
+
+ self.seg_hour_traf_mat = [[0]*len(self._hours)]*len(Segments.all_segments)
+
+ for i in range(len(Segments.all_segments)):
+ for j in range(len(self._hours)):
+ for idx in self.idx_list:
+ self.seg_hour_traf_mat[i][j] += idx.hourly_traffic_dict[i][j]
+ self.seg_hour_traf_mat[i][j] = float(self.seg_hour__traf_mat[i][j])/len(self.seg_hour_traf_mat)
+
+
+
def plotHourlyTraffic(self):
""" For a first view let us create segmentsxhours insatances of
View
18 delay_profile/tests.py
@@ -75,7 +75,7 @@ def getAvgDelayPerSegmentForAllFiles(delay_set):
def new_test():
- handle = DailySets(glob.glob('daily_data/*.out'))
+ handle = DailySets(glob.glob('daily_data/*23.out'))
handle.index()
for idx in handle.idx_list:
@@ -84,9 +84,6 @@ def new_test():
idx.hourlyAverageTrafficPS()
idx.hourlyAverageDelayPS()
- #print idx.daily_traffic
- #print idx.daily_delay
-
handle.totalAverageTraffic()
handle.totalAverageDelay()
handle.totalHourlyTraffic()
@@ -132,8 +129,21 @@ def getSegmentsSortedByDelay(handle):
print seg_del_list_sorted[i][0]
print seg_del_list_sorted[i][1]
+def getTimeTableAugmented():
+
+ c = Indexing('delay_profile/datasets/NewTrainStationDetail.txt')
+ c.constructTimeTableIndex()
+ c.augmentTimeTableWithSegments()
+ c.printAugmentedTimeTableP('NewTrainStationDetailWSegments.p')
+ idx = c.tt_idx
+
+ # Some debugging info
+ print map(lambda x: [x.stn_nm, x.seg_info, x.src_dist ], idx['12280'].stn_list)
+ return idx
+
if __name__=='__main__':
#d_s = getDelayForEachFile('daily_data/')
#getAvgDelayPerSegmentForAllFiles(d_s)
new_test()
+ #getTimeTableAugmented()
Please sign in to comment.
Something went wrong with that request. Please try again.