Skip to content

Commit

Permalink
Have sim.get_table convert table sources.
Browse files Browse the repository at this point in the history
  • Loading branch information
jiffyclub committed Aug 12, 2014
1 parent fbf9084 commit cd7c668
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
9 changes: 7 additions & 2 deletions urbansim/sim/simulation.py
Expand Up @@ -813,17 +813,22 @@ def get_table(table_name):
"""
Get a registered table.
Table sources will be converted to `DataFrameWrapper`.
Parameters
----------
table_name : str
Returns
-------
table : `DataFrameWrapper`, `TableFuncWrapper`, or `_TableSourceWrapper`
table : `DataFrameWrapper` or `TableFuncWrapper`
"""
if table_name in _TABLES:
return _TABLES[table_name]
table = _TABLES[table_name]
if isinstance(table, _TableSourceWrapper):
table = table.convert()
return table
else:
raise KeyError('table not found: {}'.format(table_name))

Expand Down
36 changes: 32 additions & 4 deletions urbansim/sim/tests/test_simulation.py
Expand Up @@ -475,7 +475,9 @@ def test_table_source(df):
def source():
return df

table = sim.get_table('source')
_source = lambda: sim._TABLES['source']

table = _source()
assert isinstance(table, sim._TableSourceWrapper)

test_df = table.to_frame()
Expand All @@ -484,7 +486,7 @@ def source():
assert len(table) == len(df)
pdt.assert_index_equal(table.index, df.index)

table = sim.get_table('source')
table = _source()
assert isinstance(table, sim.DataFrameWrapper)

test_df = table.to_frame()
Expand All @@ -496,14 +498,16 @@ def test_table_source_convert(df):
def source():
return df

table = sim.get_table('source')
_source = lambda: sim._TABLES['source']

table = _source()
assert isinstance(table, sim._TableSourceWrapper)

table = table.convert()
assert isinstance(table, sim.DataFrameWrapper)
pdt.assert_frame_equal(table.to_frame(), df)

table2 = sim.get_table('source')
table2 = _source()
assert table2 is table


Expand Down Expand Up @@ -589,3 +593,27 @@ def model(year, table):
for x in range(11):
pdt.assert_series_equal(
store['final/table'][year_key(x)], series_year(x))


def test_get_table(df):
sim.add_table('frame', df)

@sim.table('table')
def table():
return df

@sim.table_source('source')
def source():
return df

fr = sim.get_table('frame')
ta = sim.get_table('table')
so = sim.get_table('source')

assert isinstance(fr, sim.DataFrameWrapper)
assert isinstance(ta, sim.TableFuncWrapper)
assert isinstance(so, sim.DataFrameWrapper)

pdt.assert_frame_equal(fr.to_frame(), df)
pdt.assert_frame_equal(ta.to_frame(), df)
pdt.assert_frame_equal(so.to_frame(), df)

0 comments on commit cd7c668

Please sign in to comment.