Skip to content

Commit

Permalink
Function for registering table merges as broadcasts of one onto the o…
Browse files Browse the repository at this point in the history
…ther.
  • Loading branch information
jiffyclub committed Jul 25, 2014
1 parent fb1eb8a commit 4bd0439
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
54 changes: 53 additions & 1 deletion urbansim/sim/simulation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import print_function

import inspect
from collections import Callable
from collections import Callable, namedtuple

import pandas as pd
import toolz

_TABLES = {}
_COLUMNS = {}
_MODELS = {}
_BROADCASTS = {}


def clear_sim():
Expand All @@ -19,6 +20,7 @@ def clear_sim():
_TABLES.clear()
_COLUMNS.clear()
_MODELS.clear()
_BROADCASTS.clear()


class _DataFrameWrapper(object):
Expand Down Expand Up @@ -490,3 +492,53 @@ def run(models, years=None):
print('Running model {}'.format(model_name))
model = get_model(model_name)
model(year=year)


_Broadcast = namedtuple(
'_Broadcast',
['cast', 'onto', 'cast_on', 'onto_on', 'cast_index', 'onto_index'])


def broadcast(cast, onto, cast_on=None, onto_on=None,
cast_index=False, onto_index=False):
"""
Register a rule for merging two tables by broadcasting one onto
the other.
Parameters
----------
cast, onto : str
Names of registered tables.
cast_on, onto_on : str, optional
Column names used for merge, equivalent of ``left_on``/``right_on``
parameters of pandas.merge.
cast_index, onto_index : bool, optional
Whether to use table indexes for merge. Equivalent of
``left_index``/``right_index`` parameters of pandas.merge.
"""
_BROADCASTS[(cast, onto)] = \
_Broadcast(cast, onto, cast_on, onto_on, cast_index, onto_index)


def _get_broadcasts(tables):
"""
Get the broadcasts associated with a set of tables.
Parameters
----------
tables : sequence of str
Table names for which broadcasts have been registered.
Returns
-------
casts : dict of `_Broadcast`
Keys are tuples of strings like (cast_name, onto_name).
"""
tables = set(tables)
casts = toolz.keyfilter(
lambda x: x[0] in tables and x[1] in tables, _BROADCASTS)
if tables - set(toolz.concat(casts.keys())):
raise ValueError('Not enough links to merge all tables.')
return casts
17 changes: 17 additions & 0 deletions urbansim/sim/tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,20 @@ def test_model2(test_table):
2000: [2012, 2015, 2018],
3000: [3012, 3017, 3024]},
index=['x', 'y', 'z']))


def test_get_broadcasts(clear_sim):
sim.broadcast('a', 'b')
sim.broadcast('b', 'c')
sim.broadcast('z', 'b')
sim.broadcast('f', 'g')

with pytest.raises(ValueError):
sim._get_broadcasts(['a', 'b', 'g'])

assert set(sim._get_broadcasts(['a', 'b', 'c', 'z']).keys()) == \
{('a', 'b'), ('b', 'c'), ('z', 'b')}
assert set(sim._get_broadcasts(['a', 'b', 'z']).keys()) == \
{('a', 'b'), ('z', 'b')}
assert set(sim._get_broadcasts(['a', 'b', 'c']).keys()) == \
{('a', 'b'), ('b', 'c')}

0 comments on commit 4bd0439

Please sign in to comment.