Skip to content

Commit

Permalink
Merge pull request #74 from synthicity/table-func-upgrades
Browse files Browse the repository at this point in the history
Add ability to get local columns from wrapped functions
  • Loading branch information
jiffyclub committed Jul 29, 2014
2 parents ee464a2 + d47122a commit c53635a
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 29 deletions.
89 changes: 69 additions & 20 deletions urbansim/sim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SimulationError(Exception):
pass


class _DataFrameWrapper(object):
class DataFrameWrapper(object):
"""
Wraps a DataFrame so it can provide certain columns and handle
computed columns.
Expand All @@ -54,7 +54,15 @@ def columns(self):
Columns in this table.
"""
return list(self._frame.columns) + _list_columns_for_table(self.name)
return self.local_columns + _list_columns_for_table(self.name)

@property
def local_columns(self):
"""
Columns that are part of the wrapped DataFrame.
"""
return list(self._frame.columns)

@property
def index(self):
Expand Down Expand Up @@ -136,7 +144,7 @@ def __len__(self):
return len(self._frame)


class _TableFuncWrapper(object):
class TableFuncWrapper(object):
"""
Wrap a function that provides a DataFrame.
Expand All @@ -159,11 +167,25 @@ def __init__(self, name, func):
@property
def columns(self):
"""
Columns in this table. (May often be out of date.)
Columns in this table. (May contain only computed columns
if the wrapped function has not been called yet.)
"""
return self._columns + _list_columns_for_table(self.name)

@property
def local_columns(self):
"""
Only the columns contained in the DataFrame returned by the
wrapped function. (No registered columns included.)
"""
if self._columns:
return self._columns
else:
self._call_func()
return self._columns

@property
def index(self):
"""
Expand All @@ -173,6 +195,19 @@ def index(self):
"""
return self._index

def _call_func(self):
"""
Call the wrapped function and return the result. Also updates
attributes like columns, index, and length.
"""
kwargs = _collect_injectables(self._arg_list)
frame = self._func(**kwargs)
self._columns = list(frame.columns)
self._index = frame.index
self._len = len(frame)
return frame

def to_frame(self, columns=None):
"""
Make a DataFrame with the given columns.
Expand All @@ -188,12 +223,8 @@ def to_frame(self, columns=None):
frame : pandas.DataFrame
"""
kwargs = _collect_injectables(self._arg_list)
frame = self._func(**kwargs)
self._columns = list(frame.columns)
self._index = frame.index
self._len = len(frame)
return _DataFrameWrapper(self.name, frame).to_frame(columns)
frame = self._call_func()
return DataFrameWrapper(self.name, frame).to_frame(columns)

def get_column(self, column_name):
"""
Expand All @@ -220,7 +251,7 @@ def __len__(self):
return self._len


class _TableSourceWrapper(_TableFuncWrapper):
class TableSourceWrapper(TableFuncWrapper):
"""
Wraps a function that returns a DataFrame. After the function
is evaluated the returned DataFrame replaces the function in the
Expand All @@ -232,6 +263,15 @@ class _TableSourceWrapper(_TableFuncWrapper):
func : callable
"""
def convert(self):
"""
Evaluate the wrapped function, store the returned DataFrame as a
table, and return the new DataFrameWrapper instance created.
"""
frame = self._call_func()
return add_table(self.name, frame)

def to_frame(self, columns=None):
"""
Make a DataFrame with the given columns. The first time this
Expand All @@ -249,10 +289,7 @@ def to_frame(self, columns=None):
frame : pandas.DataFrame
"""
kwargs = _collect_injectables(self._arg_list)
frame = self._func(**kwargs)
add_table(self.name, frame)
return _DataFrameWrapper(self.name, frame).to_frame(columns)
return self.convert().to_frame(columns)


class _ColumnFuncWrapper(object):
Expand Down Expand Up @@ -391,16 +428,22 @@ def add_table(table_name, table):
names will be matched to known tables, which will be injected
when this function is called.
Returns
-------
wrapped : `DataFrameWrapper` or `TableFuncWrapper`
"""
if isinstance(table, pd.DataFrame):
table = _DataFrameWrapper(table_name, table)
table = DataFrameWrapper(table_name, table)
elif isinstance(table, Callable):
table = _TableFuncWrapper(table_name, table)
table = TableFuncWrapper(table_name, table)
else:
raise TypeError('table must be DataFrame or function.')

_TABLES[table_name] = table

return table


def table(table_name):
"""
Expand Down Expand Up @@ -430,8 +473,14 @@ def add_table_source(table_name, func):
Function argument names will be matched to known injectables,
which will be injected when this function is called.
Returns
-------
wrapped : `TableSourceWrapper`
"""
_TABLES[table_name] = _TableSourceWrapper(table_name, func)
wrapped = TableSourceWrapper(table_name, func)
_TABLES[table_name] = wrapped
return wrapped


def table_source(table_name):
Expand All @@ -457,7 +506,7 @@ def get_table(table_name):
Returns
-------
table : _DataFrameWrapper or _TableFuncWrapper
table : `DataFrameWrapper`, `TableFuncWrapper`, or `TableSourceWrapper`
"""
if table_name in _TABLES:
Expand Down Expand Up @@ -755,7 +804,7 @@ def merge_tables(target, tables, columns=None):
----------
target : str
Name of the table onto which tables will be merged.
tables : list of _DataFrameWrapper or _TableFuncWrapper
tables : list of `DataFrameWrapper` or `TableFuncWrapper`
All of the tables to merge. Should include the target table.
columns : list of str, optional
If given, columns will be mapped to `tables` and only those columns
Expand Down
12 changes: 6 additions & 6 deletions urbansim/sim/tests/test_mergetables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@pytest.fixture
def dfa():
return sim._DataFrameWrapper('a', pd.DataFrame(
return sim.DataFrameWrapper('a', pd.DataFrame(
{'a1': [1, 2, 3],
'a2': [4, 5, 6],
'a3': [7, 8, 9]},
Expand All @@ -18,7 +18,7 @@ def dfa():

@pytest.fixture
def dfz():
return sim._DataFrameWrapper('z', pd.DataFrame(
return sim.DataFrameWrapper('z', pd.DataFrame(
{'z1': [90, 91],
'z2': [92, 93],
'z3': [94, 95],
Expand All @@ -29,7 +29,7 @@ def dfz():

@pytest.fixture
def dfb():
return sim._DataFrameWrapper('b', pd.DataFrame(
return sim.DataFrameWrapper('b', pd.DataFrame(
{'b1': range(10, 15),
'b2': range(15, 20),
'a_id': ['ac', 'ac', 'ab', 'aa', 'ab'],
Expand All @@ -39,7 +39,7 @@ def dfb():

@pytest.fixture
def dfc():
return sim._DataFrameWrapper('c', pd.DataFrame(
return sim.DataFrameWrapper('c', pd.DataFrame(
{'c1': range(20, 30),
'c2': range(30, 40),
'b_id': ['ba', 'bd', 'bb', 'bc', 'bb', 'ba', 'bb', 'bc', 'bd', 'bb']},
Expand All @@ -48,14 +48,14 @@ def dfc():

@pytest.fixture
def dfg():
return sim._DataFrameWrapper('g', pd.DataFrame(
return sim.DataFrameWrapper('g', pd.DataFrame(
{'g1': [1, 2, 3]},
index=['ga', 'gb', 'gc']))


@pytest.fixture
def dfh():
return sim._DataFrameWrapper('h', pd.DataFrame(
return sim.DataFrameWrapper('h', pd.DataFrame(
{'h1': range(10, 15),
'g_id': ['ga', 'gb', 'gc', 'ga', 'gb']},
index=['ha', 'hb', 'hc', 'hd', 'he']))
Expand Down
45 changes: 42 additions & 3 deletions urbansim/sim/tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def df():


def test_tables(df, clear_sim):
sim.add_table('test_frame', df)
wrapped_df = sim.add_table('test_frame', df)

@sim.table('test_func')
def test_func(test_frame):
Expand All @@ -33,7 +33,9 @@ def test_func(test_frame):
assert set(sim.list_tables()) == {'test_frame', 'test_func'}

table = sim.get_table('test_frame')
assert table is wrapped_df
assert table.columns == ['a', 'b']
assert table.local_columns == ['a', 'b']
assert len(table) == 3
pdt.assert_index_equal(table.index, df.index)
pdt.assert_series_equal(table.get_column('a'), df.a)
Expand Down Expand Up @@ -285,13 +287,50 @@ def source():
return df

table = sim.get_table('source')
assert isinstance(table, sim._TableSourceWrapper)
assert isinstance(table, sim.TableSourceWrapper)

test_df = table.to_frame()
pdt.assert_frame_equal(test_df, df)
assert table.columns == list(df.columns)
assert len(table) == len(df)
pdt.assert_index_equal(table.index, df.index)

table = sim.get_table('source')
assert isinstance(table, sim._DataFrameWrapper)
assert isinstance(table, sim.DataFrameWrapper)

test_df = table.to_frame()
pdt.assert_frame_equal(test_df, df)


def test_table_source_convert(clear_sim, df):
@sim.table_source('source')
def source():
return df

table = sim.get_table('source')
assert isinstance(table, sim.TableSourceWrapper)

table = table.convert()
assert isinstance(table, sim.DataFrameWrapper)
pdt.assert_frame_equal(table.to_frame(), df)

table2 = sim.get_table('source')
assert table2 is table


def test_table_func_local_cols(clear_sim, df):
@sim.table('table')
def table():
return df
sim.add_column('table', 'new', pd.Series(['a', 'b', 'c'], index=df.index))

assert sim.get_table('table').local_columns == ['a', 'b']


def test_table_source_local_cols(clear_sim, df):
@sim.table_source('source')
def source():
return df
sim.add_column('source', 'new', pd.Series(['a', 'b', 'c'], index=df.index))

assert sim.get_table('source').local_columns == ['a', 'b']

0 comments on commit c53635a

Please sign in to comment.