-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Very basic table registration and injection.
- Loading branch information
Showing
4 changed files
with
167 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import inspect | ||
from collections import Callable | ||
|
||
import pandas as pd | ||
|
||
_TABLES = {} | ||
|
||
|
||
class _DataFrameWrapper(object): | ||
""" | ||
Wraps a DataFrame so it can provide certain columns and handle | ||
computed columns. | ||
Parameters | ||
---------- | ||
name : str | ||
Name for the table. | ||
frame : pandas.DataFrame | ||
""" | ||
def __init__(self, name, frame): | ||
self.name = name | ||
self._frame = frame | ||
|
||
def to_frame(self, columns=None): | ||
""" | ||
Make a DataFrame with the given columns. | ||
Parameters | ||
---------- | ||
columns : sequence, optional | ||
Sequence of the column names desired in the DataFrame. | ||
If None all columns are returned. | ||
Returns | ||
------- | ||
frame : pandas.DataFrame | ||
""" | ||
return self._frame | ||
|
||
|
||
class _TableFuncWrapper(object): | ||
""" | ||
Wrap a function that provides a DataFrame. | ||
Parameters | ||
---------- | ||
name : str | ||
Name for the table. | ||
func : callable | ||
Callable that returns a DataFrame. | ||
""" | ||
def __init__(self, name, func): | ||
self.name = name | ||
self._func = func | ||
self._arg_list = inspect.getargspec(func).args | ||
|
||
def to_frame(self, columns=None): | ||
""" | ||
Make a DataFrame with the given columns. | ||
Parameters | ||
---------- | ||
columns : sequence, optional | ||
Sequence of the column names desired in the DataFrame. | ||
If None all columns are returned. | ||
Returns | ||
------- | ||
frame : pandas.DataFrame | ||
""" | ||
kwargs = {t: get_table(t) for t in self._arg_list} | ||
frame = self._func(**kwargs) | ||
return _DataFrameWrapper(self.name, frame).to_frame(columns) | ||
|
||
|
||
def add_table(table_name, table): | ||
""" | ||
Register a table with the simulation. | ||
Parameters | ||
---------- | ||
table_name : str | ||
Should be globally unique to this table. | ||
table : pandas.DataFrame or function | ||
If a function it should return a DataFrame. Function argument | ||
names will be matched to known tables, which will be injected | ||
when this function is called. | ||
""" | ||
if isinstance(table, pd.DataFrame): | ||
table = _DataFrameWrapper(table_name, table) | ||
elif isinstance(table, Callable): | ||
table = _TableFuncWrapper(table_name, table) | ||
else: | ||
raise TypeError('table must be DataFrame or function.') | ||
|
||
_TABLES[table_name] = table | ||
|
||
|
||
def table(table_name): | ||
""" | ||
Decorator version of `add_table` used for decorating functions | ||
that return DataFrames. | ||
Decorated function argument names will be matched to known tables, | ||
which will be injected when this function is called. | ||
""" | ||
def decorator(func): | ||
add_table(table_name, func) | ||
return func | ||
return decorator | ||
|
||
|
||
def get_table(table_name): | ||
""" | ||
Get a registered table. | ||
Parameters | ||
---------- | ||
table_name : str | ||
Returns | ||
------- | ||
table : _DataFrameWrapper or _TableFuncWrapper | ||
""" | ||
if table_name in _TABLES: | ||
return _TABLES[table_name] | ||
else: | ||
raise KeyError('table not found: {}'.format(table_name)) | ||
|
||
|
||
def list_tables(): | ||
""" | ||
List of table names. | ||
""" | ||
return list(_TABLES.keys()) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pandas as pd | ||
import pytest | ||
from pandas.util import testing as pdt | ||
|
||
from .. import simulation as sim | ||
|
||
|
||
@pytest.fixture | ||
def df(): | ||
return pd.DataFrame( | ||
{'a': [1, 2, 3], | ||
'b': [4, 5, 6]}) | ||
|
||
|
||
def test_tables(df): | ||
sim.add_table('test_frame', df) | ||
|
||
@sim.table('test_func') | ||
def test_func(test_frame): | ||
return test_frame.to_frame() / 2 | ||
|
||
table = sim.get_table('test_func') | ||
|
||
pdt.assert_frame_equal(table.to_frame(), df / 2) |