Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions mesa/batchrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BatchRunner:
entire DataCollector object.

"""
def __init__(self, model_cls, variable_parameters=None,
def __init__(self, model_cls, variable_parameters={},
Copy link
Member

@dmasad dmasad Apr 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's generally not a good practice to use a mutable value (e.g. {}) as a default argument. In this case I don't think anything harmful would happen, but I'd suggest leaving None as the default value just in case, and for general Pythonic-ness. Then you can add something like if not variable_parameters: variable_parameters = {}

fixed_parameters=None, iterations=1, max_steps=1000,
model_reporters=None, agent_reporters=None, display_progress=True):
""" Create a new BatchRunner for a given model with the given
Expand Down Expand Up @@ -110,31 +110,50 @@ def _process_parameters(self, params):

def run_all(self):
""" Run the model at all parameter combinations and store results. """
param_names, param_ranges = zip(*self.variable_parameters.items())
run_count = count()
total_iterations = self.iterations
for param_range in param_ranges:
total_iterations *= len(param_range)
with tqdm(total_iterations, disable=not self.display_progress) as pbar:
for param_values in product(*param_ranges):
kwargs = dict(zip(param_names, param_values))
kwargs.update(self.fixed_parameters)

if len(self.variable_parameters.keys()) > 0:
param_names, param_ranges = zip(*self.variable_parameters.items())
for param_range in param_ranges:
total_iterations *= len(param_range)

with tqdm(total_iterations, disable=not self.display_progress) as pbar:
for param_values in product(*param_ranges):
kwargs = dict(zip(param_names, param_values))
kwargs.update(self.fixed_parameters)

for _ in range(self.iterations):
self.run_iteration(kwargs, param_values, run_count)
pbar.update()
else:
kwargs = self.fixed_parameters
param_values = None

with tqdm(total_iterations, disable=not self.display_progress) as pbar:
for _ in range(self.iterations):
kwargscopy = copy.deepcopy(kwargs)
model = self.model_cls(**kwargscopy)
self.run_model(model)
# Collect and store results:
model_key = param_values + (next(run_count),)
if self.model_reporters:
self.model_vars[model_key] = self.collect_model_vars(model)
if self.agent_reporters:
agent_vars = self.collect_agent_vars(model)
for agent_id, reports in agent_vars.items():
agent_key = model_key + (agent_id,)
self.agent_vars[agent_key] = reports
self.run_iteration(kwargs, param_values, run_count)
pbar.update()

def run_iteration(self, kwargs, param_values, run_count):
kwargscopy = copy.deepcopy(kwargs)
model = self.model_cls(**kwargscopy)
self.run_model(model)

# Collect and store results:
if param_values is not None:
model_key = param_values + (next(run_count),)
else:
model_key = (next(run_count),)

if self.model_reporters:
self.model_vars[model_key] = self.collect_model_vars(model)
if self.agent_reporters:
agent_vars = self.collect_agent_vars(model)
for agent_id, reports in agent_vars.items():
agent_key = model_key + (agent_id,)
self.agent_vars[agent_key] = reports

def run_model(self, model):
""" Run a model object to completion, or until reaching max steps.

Expand Down