Skip to content

Commit

Permalink
Remove check for NaN flags in MNLLocationChoiceModel
Browse files Browse the repository at this point in the history
When doing LCM prediction, it will be up to the
user to pass in only the choosers who are picking new
locations. Both RelcationModel and TransitionModel return
indexes of choosers, those can be concatenated to to make
an indexer of choosers.
  • Loading branch information
jiffyclub committed May 16, 2014
1 parent a12d911 commit 0ab213b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 19 deletions.
12 changes: 1 addition & 11 deletions urbansim/models/lcm.py
Expand Up @@ -79,11 +79,6 @@ class MNLLocationChoiceModel(object):
A patsy model expression. Should contain only a right-hand side.
sample_size : int
Number of choices to sample for estimating the model.
location_id_col : str, optional
Name of a column in the choosers table that corresponds to the
index of the location being chosen. If given, this is used to
make sure that during prediction only choosers that have nan
in this column choose new alternatives.
choosers_fit_filters : list of str, optional
Filters applied to choosers table before fitting the model.
choosers_predict_filters : list of str, optional
Expand All @@ -108,15 +103,14 @@ class MNLLocationChoiceModel(object):
in output.
"""
def __init__(self, model_expression, sample_size, location_id_col=None,
def __init__(self, model_expression, sample_size,
choosers_fit_filters=None, choosers_predict_filters=None,
alts_fit_filters=None, alts_predict_filters=None,
interaction_predict_filters=None,
estimation_sample_size=None,
choice_column=None, name=None):
self.model_expression = model_expression
self.sample_size = sample_size
self.location_id_col = location_id_col
self.choosers_fit_filters = choosers_fit_filters
self.choosers_predict_filters = choosers_predict_filters
self.alts_fit_filters = alts_fit_filters
Expand Down Expand Up @@ -152,7 +146,6 @@ def from_yaml(cls, yaml_str=None, str_or_buffer=None):
model = cls(
cfg['model_expression'],
cfg['sample_size'],
location_id_col=cfg.get('location_id_col', None),
choosers_fit_filters=cfg.get('choosers_fit_filters', None),
choosers_predict_filters=cfg.get('choosers_predict_filters', None),
alts_fit_filters=cfg.get('alts_fit_filters', None),
Expand Down Expand Up @@ -290,8 +283,6 @@ def predict(self, choosers, alternatives):
"""
self.assert_fitted()

if self.location_id_col:
choosers = choosers[choosers[self.location_id_col].isnull()]
choosers = util.apply_filter_query(
choosers, self.choosers_predict_filters)
alternatives = util.apply_filter_query(
Expand Down Expand Up @@ -331,7 +322,6 @@ def to_dict(self):
'model_expression': self.model_expression,
'sample_size': self.sample_size,
'name': self.name,
'location_id_col': self.location_id_col,
'choosers_fit_filters': self.choosers_fit_filters,
'choosers_predict_filters': self.choosers_predict_filters,
'alts_fit_filters': self.alts_fit_filters,
Expand Down
11 changes: 3 additions & 8 deletions urbansim/models/tests/test_lcm.py
Expand Up @@ -58,7 +58,6 @@ def test_unit_choice_none_available(choosers, alternatives):
def test_mnl_lcm(choosers, alternatives):
model_exp = 'var2 + var1:var3'
sample_size = 5
location_id_col = 'thing_id'
choosers_fit_filters = ['var1 != 5']
choosers_predict_filters = ['var1 != 7']
alts_fit_filters = ['var3 != 15']
Expand All @@ -69,7 +68,7 @@ def test_mnl_lcm(choosers, alternatives):
name = 'Test LCM'

model = lcm.MNLLocationChoiceModel(
model_exp, sample_size, location_id_col,
model_exp, sample_size,
choosers_fit_filters, choosers_predict_filters,
alts_fit_filters, alts_predict_filters,
interaction_predict_filters, estimation_sample_size,
Expand All @@ -83,10 +82,7 @@ def test_mnl_lcm(choosers, alternatives):
assert len(model.fit_parameters) == 2
assert len(model.fit_parameters.columns) == 3

choosers.thing_id = np.nan
choosers.thing_id.iloc[0] = 'a'

choices = model.predict(choosers, alternatives)
choices = model.predict(choosers.iloc[1:], alternatives)

pdt.assert_index_equal(choices.index, pd.Index([1, 3, 4]))
assert choices.isin(alternatives.index).all()
Expand All @@ -102,7 +98,6 @@ def test_mnl_lcm(choosers, alternatives):
def test_mnl_lcm_repeated_alts(choosers, alternatives):
model_exp = 'var2 + var1:var3'
sample_size = 5
location_id_col = None
choosers_fit_filters = ['var1 != 5']
choosers_predict_filters = ['var1 != 7']
alts_fit_filters = ['var3 != 15']
Expand All @@ -113,7 +108,7 @@ def test_mnl_lcm_repeated_alts(choosers, alternatives):
name = 'Test LCM'

model = lcm.MNLLocationChoiceModel(
model_exp, sample_size, location_id_col,
model_exp, sample_size,
choosers_fit_filters, choosers_predict_filters,
alts_fit_filters, alts_predict_filters,
interaction_predict_filters, estimation_sample_size,
Expand Down

0 comments on commit 0ab213b

Please sign in to comment.