In [None]:
import numpy as np

from mesa import Agent, Model
from mesa.space import Grid
from mesa.datacollection import DataCollector
from mesa.time import SimultaneousActivation
import matplotlib.pyplot as plt
from mesa.visualization.modules import CanvasGrid, ChartModule
from mesa.visualization.ModularVisualization import ModularServer

# Helper functions

In [None]:
def _count_state(model, state):
    """Return how many agents on ``model.grid`` are currently in ``state``.

    Mesa's ``coord_iter`` yields ``(contents, x, y)`` in Mesa 1.x and
    ``(contents, (x, y))`` in Mesa 2.x; only the first item (the cell
    contents) is needed, so both shapes are accepted. Empty cells (``None``)
    are skipped instead of raising ``AttributeError``.
    """
    return sum(
        1
        for entry in model.grid.coord_iter()
        if entry[0] is not None and entry[0].state == state
    )


def get_sus(model):
    """DataCollector model reporter: number of susceptible cells."""
    return _count_state(model, "susceptible")


def get_inf(model):
    """DataCollector model reporter: number of infected cells."""
    return _count_state(model, "infected")


def get_rec(model):
    """DataCollector model reporter: number of recovered cells."""
    return _count_state(model, "recovered")

# Cellular Automaton class

In [None]:
class CellularAutomaton(Model):
    """Epidemic cellular automaton on a non-toroidal ``width`` x ``height`` grid.

    Every grid cell holds exactly one ``FiniteStateMachine`` agent. Agents are
    updated with ``SimultaneousActivation`` (every agent's ``step`` runs before
    any agent's ``advance``). The susceptible / infected / recovered counts are
    recorded once after initialisation and once after every step, under the
    DataCollector keys ``"sus"``, ``"inf"`` and ``"rec"``.

    Parameters
    ----------
    height, width : int
        Grid dimensions in cells.
    p : float
        Transmission probability per infected neighbour (read by the agents).
    q : float
        Per-step recovery probability of an infected agent (read by the agents).
    s : float
        Per-step probability that a recovered agent, once past its immunity
        period, becomes susceptible again (read by the agents). ``0`` means
        recovered agents stay recovered forever (SIR rather than SIRS).
    initial_infected : float
        Probability that each cell starts out infected.
    seed : int or None
        Seed for ``self.random``. Mesa's ``Model.__new__`` consumes it; it is
        accepted here so that ``CellularAutomaton(seed=...)`` does not raise
        ``TypeError`` and runs become reproducible.
    """

    def __init__(
        self,
        height=50,
        width=50,
        p=0.2,
        q=0.3,
        s=0,
        initial_infected=0.005,
        seed=None,
    ):
        # Initialise Mesa's bookkeeping (required for agent registration in
        # newer Mesa versions; harmless in older ones).
        super().__init__()
        self.height = height
        self.width = width

        self.p = p
        self.q = q
        self.s = s
        self.initial_infected = initial_infected

        self.schedule = SimultaneousActivation(self)
        self.grid = Grid(self.width, self.height, torus=False)

        self.datacollector = DataCollector(
            model_reporters={
                "sus": get_sus,
                "inf": get_inf,
                "rec": get_rec,
            }
        )

        # Fill every cell with one agent. The cell coordinate doubles as the
        # agent's unique_id, so no id counter is needed. Drawing from the
        # model's own RNG (instead of the global numpy one) makes ``seed``
        # control the initial outbreak.
        for x in range(self.width):
            for y in range(self.height):
                fsm = FiniteStateMachine((x, y), self)
                if self.random.random() < self.initial_infected:
                    fsm.state = "infected"
                self.grid.place_agent(fsm, (x, y))
                self.schedule.add(fsm)

        self.running = True
        # Record the initial state so the time series starts at step 0.
        self.datacollector.collect(self)

    def step(self):
        """Advance every agent by one synchronous tick, then record the counts."""
        self.schedule.step()
        self.datacollector.collect(self)

    def run(self, n):
        """Run ``n`` steps, ignoring ``self.running``."""
        for _ in range(n):
            self.step()

# Finite State Machine class

In [None]:
class FiniteStateMachine(Agent):
    """One grid cell of the epidemic automaton.

    The cell is always in one of ``"susceptible"``, ``"infected"`` or
    ``"recovered"``. Updates are two-phase. ``step`` only computes
    ``_nextState`` from the *current* states of the Moore neighbourhood (up to
    8 cells; fewer on the border of a non-toroidal grid). ``advance`` then
    commits it. Under simultaneous activation, no cell therefore sees another
    cell's update from the same tick.

    Transition probabilities are read from the model: ``p`` (infection, per
    infected neighbour), ``q`` (recovery) and ``s`` (loss of immunity).
    """

    # A recovered cell can only lose immunity once it has been recovered for
    # more than this many steps.
    IMMUNITY_STEPS = 4

    def __init__(self, pos, model):
        # The (x, y) cell coordinate doubles as the agent's unique_id.
        super().__init__(pos, model)
        self.x, self.y = pos

        self.state = "susceptible"
        self._nextState = None
        # Steps spent recovered since the agent last became susceptible.
        self.counter = 0

    def step(self):
        """Compute ``_nextState`` from the current neighbourhood (does not commit)."""
        neighbors = self.model.grid.get_neighbors((self.x, self.y), moore=True)
        n_infected = sum(1 for nb in neighbors if nb.state == "infected")

        # Draw from the model-seeded RNG (``self.random``) rather than the
        # global numpy RNG, so a model seed makes runs reproducible.
        if self.state == "susceptible":
            # Each infected neighbour independently transmits with
            # probability p, so the cell escapes with (1 - p) ** n_infected.
            p_infect = 1 - (1 - self.model.p) ** n_infected
            infected = self.random.random() < p_infect
            self._nextState = "infected" if infected else "susceptible"
        elif self.state == "infected":
            recovers = self.random.random() < self.model.q
            self._nextState = "recovered" if recovers else "infected"
        else:
            self.counter += 1
            # Short-circuit: only draw a random number once immunity can lapse.
            if self.counter > self.IMMUNITY_STEPS and self.random.random() < self.model.s:
                self.counter = 0
                self._nextState = "susceptible"
            else:
                self._nextState = "recovered"

    def advance(self):
        """Commit the state computed in ``step``."""
        self.state = self._nextState

# Manually run model

In [None]:
N_STEPS = 30  # number of synchronous ticks to simulate

model = CellularAutomaton()
model.run(N_STEPS)

# One row per collected time point: the initial state plus one per step.
df = model.datacollector.get_model_vars_dataframe()

fig, ax = plt.subplots()
df.rename(
    columns={"sus": "susceptible", "inf": "infected", "rec": "recovered"}
).plot(ax=ax)
ax.set(
    title=f"Epidemic on a {model.width} x {model.height} grid",
    xlabel="step",
    ylabel="number of cells",
);

# Mesa visualization server

In [None]:
def portrayCell(fsm):
    """Build the CanvasGrid portrayal dict for a single cell agent.

    The cell is drawn as a filled 1x1 rectangle on layer 0 at ``(fsm.x, fsm.y)``.
    It is coloured yellow when susceptible, red when infected, and white for
    any other state.
    """
    assert fsm is not None
    colour_by_state = {"susceptible": "yellow", "infected": "red"}
    return {
        "Shape": "rect",
        "w": 1,
        "h": 1,
        "Filled": "true",
        "Layer": 0,
        "x": fsm.x,
        "y": fsm.y,
        "Color": colour_by_state.get(fsm.state, "white"),
    }

SIZE = 50  # grid width and height in cells

# Chart labels must match the DataCollector model-reporter keys. Colours
# follow portrayCell (susceptible = yellow, infected = red). Recovered cells are
# white on the grid, which would be invisible on the chart, so they are grey here.
chart = ChartModule(
    [
        {"Label": "sus", "Color": "#D4A017"},
        {"Label": "inf", "Color": "#CC0000"},
        {"Label": "rec", "Color": "#808080"},
    ],
    canvas_height=300,
    canvas_width=500,
)

canvas_element = CanvasGrid(portrayCell, SIZE, SIZE, 500, 500)
server = ModularServer(
    CellularAutomaton,
    [canvas_element, chart],
    "Epidemic",
    {"height": SIZE, "width": SIZE},
)

# Starts the web server; this call blocks the cell until the kernel is interrupted.
server.launch()