From a0b0158bf7214d2b0cfc8beee45867a53c2f3666 Mon Sep 17 00:00:00 2001 From: Corvince Date: Thu, 9 Nov 2017 14:28:26 +0100 Subject: [PATCH 1/3] Added include_fixed parameter --- mesa/batchrunner.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mesa/batchrunner.py b/mesa/batchrunner.py index e5f23652317..5fa7c384ba0 100644 --- a/mesa/batchrunner.py +++ b/mesa/batchrunner.py @@ -160,12 +160,15 @@ def collect_agent_vars(self, model): agent_vars[agent.unique_id] = agent_record return agent_vars - def get_model_vars_dataframe(self): + def get_model_vars_dataframe(self, include_fixed=False): """ Generate a pandas DataFrame from the model-level variables collected. + Args: + include_fixed: Set True to include fixed parameters """ - return self._prepare_report_table(self.model_vars) + return self._prepare_report_table(self.model_vars, + include_fixed=include_fixed) def get_agent_vars_dataframe(self): """ Generate a pandas DataFrame from the agent-level variables @@ -175,7 +178,8 @@ def get_agent_vars_dataframe(self): return self._prepare_report_table(self.agent_vars, extra_cols=['AgentId']) - def _prepare_report_table(self, vars_dict, extra_cols=None): + def _prepare_report_table(self, vars_dict, extra_cols=None, + include_fixed=False): """ Creates a dataframe from collected records and sorts it using 'Run' column as a key. @@ -193,4 +197,7 @@ def _prepare_report_table(self, vars_dict, extra_cols=None): rest_cols = set(df.columns) - set(index_cols) ordered = df[index_cols + list(sorted(rest_cols))] ordered.sort_values(by='Run', inplace=True) + if include_fixed: + for param in self.fixed_parameters.keys(): + ordered[param] = self.fixed_parameters[param] return ordered From b9976012eed25b02c34d6749c9ce039215f3e080 Mon Sep 17 00:00:00 2001 From: Corvince Date: Wed, 6 Dec 2017 15:11:55 +0100 Subject: [PATCH 2/3] New approach that defaults to include fixed parameters --- mesa/batchrunner.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/mesa/batchrunner.py b/mesa/batchrunner.py index 5fa7c384ba0..c10a4570107 100644 --- a/mesa/batchrunner.py +++ b/mesa/batchrunner.py @@ -82,6 +82,7 @@ def __init__(self, model_cls, variable_parameters=None, self.model_cls = model_cls self.variable_parameters = self._process_parameters(variable_parameters) self.fixed_parameters = fixed_parameters or {} + self._include_fixed = len(self.fixed_parameters.keys()) > 0 self.iterations = iterations self.max_steps = max_steps @@ -160,15 +161,12 @@ def collect_agent_vars(self, model): agent_vars[agent.unique_id] = agent_record return agent_vars - def get_model_vars_dataframe(self, include_fixed=False): + def get_model_vars_dataframe(self): """ Generate a pandas DataFrame from the model-level variables collected. - Args: - include_fixed: Set True to include fixed parameters """ - return self._prepare_report_table(self.model_vars, - include_fixed=include_fixed) + return self._prepare_report_table(self.model_vars) def get_agent_vars_dataframe(self): """ Generate a pandas DataFrame from the agent-level variables @@ -178,8 +176,7 @@ def get_agent_vars_dataframe(self): return self._prepare_report_table(self.agent_vars, extra_cols=['AgentId']) - def _prepare_report_table(self, vars_dict, extra_cols=None, - include_fixed=False): + def _prepare_report_table(self, vars_dict, extra_cols=None): """ Creates a dataframe from collected records and sorts it using 'Run' column as a key. @@ -197,7 +194,7 @@ def _prepare_report_table(self, vars_dict, extra_cols=None, rest_cols = set(df.columns) - set(index_cols) ordered = df[index_cols + list(sorted(rest_cols))] ordered.sort_values(by='Run', inplace=True) - if include_fixed: + if self._include_fixed: for param in self.fixed_parameters.keys(): ordered[param] = self.fixed_parameters[param] return ordered From 59edb4eebacc632d513d301092478b6965dac820 Mon Sep 17 00:00:00 2001 From: Corvince Date: Wed, 6 Dec 2017 15:21:44 +0100 Subject: [PATCH 3/3] Fixed test --- tests/test_batchrunner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_batchrunner.py b/tests/test_batchrunner.py index e9b9904b9f8..d815f30b95e 100644 --- a/tests/test_batchrunner.py +++ b/tests/test_batchrunner.py @@ -155,6 +155,7 @@ def test_model_with_variable_and_fixed_kwargs(self): batch = self.launch_batch_processing() model_vars = batch.get_model_vars_dataframe() expected_cols = (len(self.variable_params) + + len(self.fixed_params) + len(self.model_reporters) + 1) self.assertEqual(model_vars.shape, (self.model_runs, expected_cols))