Skip to content

Commit

Permalink
Modify and test datacollector
Browse files Browse the repository at this point in the history
Adding and testing tables
  • Loading branch information
dmasad committed Apr 28, 2015
1 parent ca1e91c commit 0037f07
Show file tree
Hide file tree
Showing 4 changed files with 1,367 additions and 1,129 deletions.
1,151 changes: 594 additions & 557 deletions examples/ForestFire/.ipynb_checkpoints/Forest Fire Model-checkpoint.ipynb

Large diffs are not rendered by default.

1,151 changes: 594 additions & 557 deletions examples/ForestFire/Forest Fire Model.ipynb

Large diffs are not rendered by default.

100 changes: 85 additions & 15 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,27 @@
=====================================================
DataCollector is meant to provide a simple, standard way to collect data
generated by a Mesa model. It collects two types of data: model-level and
agent-level data.
generated by a Mesa model. It collects three types of data: model-level data,
agent-level data, and tables.
A DataCollector is instantiated with two dictionaries of reporter names and
associated functions for each, one for model-level data and one for
agent-level data. When the collect() method is called, each model-level
function is called, with the model as the argument, and the results associated
with the relevant variable. Then the agent-level functions are called on each
agent-level data; a third dictionary provides table names and columns.
When the collect() method is called, each model-level function is called, with
the model as the argument, and the results associated with the relevant
variable. Then the agent-level functions are called on each
agent in the model scheduler.
The DataCollector then stores the data it collects in two dictionaries:
Additionally, other objects can write directly to tables by passing in an
appropriate dictionary object for a table row.
The DataCollector then stores the data it collects in dictionaries:
* model_vars maps each reporter to a list of its values
* agent_vars maps each reporter to a list of lists, where each nested list
stores (agent_id, value) pairs.
* tables maps each table to a dictionary, with each column as a key with a
list as its value.
Finally, DataCollector can create a pandas DataFrame from each collection.
Expand All @@ -43,10 +50,11 @@ class DataCollector(object):

model_vars = {}
agent_vars = {}
tables = {}

model = None

def __init__(self, model_reporters=None, agent_reporters=None):
def __init__(self, model_reporters={}, agent_reporters={}, tables={}):
'''
Instantiate a DataCollector with lists of model and agent reporters.
Expand All @@ -59,21 +67,62 @@ def __init__(self, model_reporters=None, agent_reporters=None):
it might look like this:
{"energy": lambda a: a.energy}
The tables arg accepts a dictionary mapping names of tables to lists of
columns. For example, if we want to allow agents to write their age
when they are destroyed (to keep track of lifespans), it might look
like:
{"Lifespan": ["unique_id", "age"]}
Args:
model_reporters: Dictionary of reporter names and functions.
agent_reporters: Dictionary of reporter names and functions.
'''

self.model_reporters = model_reporters
self.agent_reporters = agent_reporters
self.model_reporters = {}
self.agent_reporters = {}
self.tables = {}

if self.model_reporters:
for var in self.model_reporters:
self.model_vars[var] = []
for name, func in model_reporters.items():
self._new_model_reporter(name, func)

if self.agent_reporters:
for var in self.agent_reporters:
self.agent_vars[var] = []
for name, func in agent_reporters.items():
self._new_agent_reporter(name, func)

for name, columns in tables.items():
self._new_table(name, columns)

def _new_model_reporter(self, reporter_name, reporter_function):
'''
Add a new model-level reporter to collect.
Args:
reporter_name: Name of the model-level variable to collect.
reporter_function: Function object that returns the variable when
given a model instance.
'''

self.model_reporters[reporter_name] = reporter_function
self.model_vars[reporter_name] = []

def _new_agent_reporter(self, reporter_name, reporter_function):
'''
Add a new agent-level reporter to collect.
Args:
reporter_name: Name of the agent-level variable to collect.
reporter_function: Function object that returns the variable when
given an agent object.
'''
self.agent_reporters[reporter_name] = reporter_function
self.agent_vars[reporter_name] = []

def _new_table(self, table_name, table_columns):
'''
Add a new table that objects can write to.
Args:
table_name: Name of the new table.
table_columns: List of columns to add to the table.
'''
new_table = {column: [] for column in table_columns}
self.tables[table_name] = new_table

def collect(self, model):
'''
Expand All @@ -90,6 +139,27 @@ def collect(self, model):
agent_records.append((agent.unique_id, reporter(agent)))
self.agent_vars[var].append(agent_records)

def add_table_row(self, table_name, row, ignore_missing=False):
'''
Add a row dictionary to a specific table.
Args:
table_name: Name of the table to append a row to.
row: A dictionary of the form {column_name: value...}
ignore_missing: If True, fill any missing columns with Nones;
if False, throw an error if any columns are missing
'''
if table_name not in self.tables:
raise Exception("Table does not exist.")

for column in self.tables[table_name]:
if column in row:
self.tables[table_name][column].append(row[column])
elif ignore_missing:
self.tables[table_name][column].append(None)
else:
raise Exception("Could not insert row with missing column")

def get_model_vars_dataframe(self):
'''
Create a pandas DataFrame from the model variables.
Expand Down
94 changes: 94 additions & 0 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
'''
Test the DataCollector
'''
import unittest

from mesa import Model, Agent
from mesa.time import BaseScheduler
from mesa.datacollection import DataCollector


class MockAgent(Agent):
'''
Minimalistic agent for testing purposes.
'''
def __init__(self, unique_id, val):
self.unique_id = unique_id
self.val = val

def step(self, model):
'''
Increment val by 1.
'''
self.val += 1

def write_final_values(self, model):
'''
Write the final value to the appropriate table.
'''
row = {"agent_id": self.unique_id, "final_value": self.val}
model.datacollector.add_table_row("Final_Values", row)


class MockModel(Model):
'''
Minimalistic model for testing purposes.
'''

schedule = BaseScheduler(None)

def __init__(self):
self.schedule = BaseScheduler(self)
for i in range(10):
a = MockAgent(i, i)
self.schedule.add(a)
self.datacollector = DataCollector(
{"total_agents": lambda m: m.schedule.get_agent_count()},
{"value": lambda a: a.val},
{"Final_Values": ["agent_id", "final_value"]})

def step(self):
self.schedule.step()
self.datacollector.collect(self)


class TestDataCollector(unittest.TestCase):
def setUp(self):
'''
Create the model and run it a set number of steps.
'''
self.model = MockModel()
for i in range(7):
self.model.step()
# Write to table:
for agent in self.model.schedule.agents:
agent.write_final_values(self.model)

def test_model_vars(self):
data_collector = self.model.datacollector
assert "total_agents" in data_collector.model_vars
assert len(data_collector.model_vars["total_agents"]) == 7
for element in data_collector.model_vars["total_agents"]:
assert element == 10

def test_agent_vars(self):
data_collector = self.model.datacollector
assert len(data_collector.agent_vars["value"]) == 7
for step in data_collector.agent_vars["value"]:
assert len(step) == 10
for record in step:
assert len(record) == 2

def test_table_rows(self):
data_collector = self.model.datacollector
assert len(data_collector.tables["Final_Values"]) == 2
assert "agent_id" in data_collector.tables["Final_Values"]
assert "final_value" in data_collector.tables["Final_Values"]
for key, data in data_collector.tables["Final_Values"].items():
assert len(data) == 10

with self.assertRaises(Exception):
data_collector.add_table_row("error_table", {})

with self.assertRaises(Exception):
data_collector.add_table_row("Final_Values", {"final_value": 10})

0 comments on commit 0037f07

Please sign in to comment.