diff --git a/mesa/batchrunner.py b/mesa/batchrunner.py index b86f44845f4..4c3a49153b7 100644 --- a/mesa/batchrunner.py +++ b/mesa/batchrunner.py @@ -45,7 +45,7 @@ class BatchRunner: entire DataCollector object. """ - def __init__(self, model_cls, variable_parameters=None, + def __init__(self, model_cls, variable_parameters={}, fixed_parameters=None, iterations=1, max_steps=1000, model_reporters=None, agent_reporters=None, display_progress=True): """ Create a new BatchRunner for a given model with the given @@ -110,31 +110,50 @@ def _process_parameters(self, params): def run_all(self): """ Run the model at all parameter combinations and store results. """ - param_names, param_ranges = zip(*self.variable_parameters.items()) run_count = count() total_iterations = self.iterations - for param_range in param_ranges: - total_iterations *= len(param_range) - with tqdm(total_iterations, disable=not self.display_progress) as pbar: - for param_values in product(*param_ranges): - kwargs = dict(zip(param_names, param_values)) - kwargs.update(self.fixed_parameters) + if len(self.variable_parameters.keys()) > 0: + param_names, param_ranges = zip(*self.variable_parameters.items()) + for param_range in param_ranges: + total_iterations *= len(param_range) + + with tqdm(total_iterations, disable=not self.display_progress) as pbar: + for param_values in product(*param_ranges): + kwargs = dict(zip(param_names, param_values)) + kwargs.update(self.fixed_parameters) + + for _ in range(self.iterations): + self.run_iteration(kwargs, param_values, run_count) + pbar.update() + else: + kwargs = self.fixed_parameters + param_values = None + + with tqdm(total_iterations, disable=not self.display_progress) as pbar: for _ in range(self.iterations): - kwargscopy = copy.deepcopy(kwargs) - model = self.model_cls(**kwargscopy) - self.run_model(model) - # Collect and store results: - model_key = param_values + (next(run_count),) - if self.model_reporters: - self.model_vars[model_key] = self.collect_model_vars(model) - if self.agent_reporters: - agent_vars = self.collect_agent_vars(model) - for agent_id, reports in agent_vars.items(): - agent_key = model_key + (agent_id,) - self.agent_vars[agent_key] = reports + self.run_iteration(kwargs, param_values, run_count) pbar.update() + def run_iteration(self, kwargs, param_values, run_count): + kwargscopy = copy.deepcopy(kwargs) + model = self.model_cls(**kwargscopy) + self.run_model(model) + + # Collect and store results: + if param_values is not None: + model_key = param_values + (next(run_count),) + else: + model_key = (next(run_count),) + + if self.model_reporters: + self.model_vars[model_key] = self.collect_model_vars(model) + if self.agent_reporters: + agent_vars = self.collect_agent_vars(model) + for agent_id, reports in agent_vars.items(): + agent_key = model_key + (agent_id,) + self.agent_vars[agent_key] = reports + def run_model(self, model): """ Run a model object to completion, or until reaching max steps.