In [3]:
import numpy as np
import plotly.graph_objects as go
from mesa import Agent, Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
import ipywidgets as widgets
from IPython.display import display

In [5]:
class Person(Agent):
    """Agent representing a person in the epidemic simulation.

    Attributes
    ----------
    age_group : str
        Age label drawn by the model ('young', 'adult' or 'elderly').
    state : str
        One of "Susceptible", "Infected" or "Recovered".
    is_vaccinated : bool
        Set to True by the model's vaccination routine.
    """

    # Probability that a vaccinated agent sits out a day's activity (moving and
    # spreading), and that a vaccinated susceptible agent resists an infectious contact.
    VACCINE_PROTECTION = 0.95

    def __init__(self, unique_id, model, age_group):
        super().__init__(unique_id, model)
        self.age_group = age_group
        self.state = "Susceptible"
        self.is_vaccinated = False

    def step(self):
        """Advance this agent by one day: maybe move and spread, and progress the disease."""
        # Vaccinated agents are active on only ~5% of days.
        active = not self.is_vaccinated or np.random.random() > self.VACCINE_PROTECTION
        if active:
            self.move()
            self.contact()
        elif self.state == "Infected" and np.random.random() < self.model.recovery_rate:
            # Recovery is a disease process, not an activity. Without this branch an
            # inactive (vaccinated) infected agent would only roll for recovery on ~5%
            # of days, so vaccination would keep it infected ~20x longer.
            self.state = "Recovered"

    def move(self):
        """Move to a random Moore-neighbour cell unless complying with social distancing."""
        # With distancing on, the agent stays put with probability compliance_rate.
        if not self.model.social_distancing or np.random.random() > self.model.compliance_rate:
            new_position = self.random.choice(self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False))
            self.model.grid.move_agent(self, new_position)

    def contact(self):
        """If infected, either recover or try to infect susceptible agents nearby."""
        if self.state != "Infected":
            return
        if np.random.random() < self.model.recovery_rate:
            self.state = "Recovered"
            return
        # include_center=True: agents stacked in this same MultiGrid cell are the closest
        # contacts and must be included. This agent is Infected, so the state check
        # below skips it.
        neighbors = self.model.grid.get_neighbors(self.pos, moore=True, include_center=True)
        for neighbor in neighbors:
            if neighbor.state != "Susceptible":
                continue
            # Infection is driven from the infected side, so a vaccinated neighbour's
            # protection has to be applied here as well.
            if neighbor.is_vaccinated and np.random.random() <= self.VACCINE_PROTECTION:
                continue
            if np.random.random() < self.model.infection_rate:
                neighbor.state = "Infected"

In [9]:
class EpidemicModel(Model):
    """A model simulating an epidemic on a wrap-around grid, with optional social
    distancing and age-prioritised vaccination.

    Parameters
    ----------
    N : int
        Number of agents.
    width, height : int
        Grid size; the grid is a torus (``MultiGrid(..., True)``).
    infection_rate : float
        Per-contact transmission probability, read by the agents.
    recovery_rate : float
        Per-step recovery probability, read by the agents.
    vaccination_start_day : int
        First value of ``self.day`` on which ``vaccinate_agents`` runs.
    compliance_rate : float
        Read by the agents when ``social_distancing`` is True.
    initial_infected : int, optional
        Number of agents that start infected (default 10).
    daily_doses : int, optional
        Maximum number of vaccinations per day (default 10).
    vaccine_uptake : float, optional
        Probability that an offered agent accepts the vaccine (default 0.5).

    Counts are collected once at construction (day 0) and once after every step.
    Row ``d`` of ``datacollector.get_model_vars_dataframe()`` is therefore the state
    at the end of day ``d``.
    """

    # Vaccination order: lower values are vaccinated first. Sorting the raw labels
    # with reverse=True would give the alphabetical order young > elderly > adult.
    AGE_PRIORITY = {'elderly': 0, 'adult': 1, 'young': 2}

    def __init__(self, N, width, height, infection_rate, recovery_rate, vaccination_start_day, compliance_rate,
                 initial_infected=10, daily_doses=10, vaccine_uptake=0.5):
        # Mesa's Model docs require calling super().__init__() before anything else.
        # It also resets self.schedule, so it must come first.
        super().__init__()
        self.num_agents = N
        self.grid = MultiGrid(width, height, True)
        self.schedule = RandomActivation(self)
        self.infection_rate = infection_rate
        self.recovery_rate = recovery_rate
        self.vaccination_start_day = vaccination_start_day
        self.compliance_rate = compliance_rate
        self.daily_doses = daily_doses
        self.vaccine_uptake = vaccine_uptake
        self.day = 0
        self.social_distancing = False
        self.unvaccinated_agents = []

        for i in range(self.num_agents):
            age_group = np.random.choice(['young', 'adult', 'elderly'], p=[0.3, 0.5, 0.2])
            a = Person(i, self, age_group)
            self.schedule.add(a)
            self.grid.place_agent(a, (np.random.randint(0, width), np.random.randint(0, height)))
            self.unvaccinated_agents.append(a)
            if i < initial_infected:
                a.state = "Infected"

        self.datacollector = DataCollector(
            model_reporters={
                "Susceptible": lambda m: sum(agent.state == "Susceptible" for agent in m.schedule.agents),
                "Infected": lambda m: sum(agent.state == "Infected" for agent in m.schedule.agents),
                "Recovered": lambda m: sum(agent.state == "Recovered" for agent in m.schedule.agents),
                "Vaccinated": lambda m: sum(agent.is_vaccinated for agent in m.schedule.agents)
            }
        )
        # Record the initial state as day 0.
        self.datacollector.collect(self)

    def step(self):
        """Advance one day: vaccinate (once started), step all agents, then record counts."""
        self.day += 1
        if self.day >= self.vaccination_start_day:
            self.vaccinate_agents()
        self.schedule.step()
        # Collect after the agents act, so this row describes the end of `self.day`.
        self.datacollector.collect(self)

    def vaccinate_agents(self):
        """Offer the vaccine to unvaccinated agents, elderly first, up to `daily_doses` per day."""
        vaccinated_today = 0
        fallback_rank = len(self.AGE_PRIORITY)  # unknown labels are offered last
        # sorted() is stable, so agents within an age group keep their insertion order.
        queue = sorted(self.unvaccinated_agents,
                       key=lambda agent: self.AGE_PRIORITY.get(agent.age_group, fallback_rank))
        for agent in queue:
            if vaccinated_today >= self.daily_doses:
                break  # today's supply is used up
            if not agent.is_vaccinated and np.random.random() < self.vaccine_uptake:
                agent.is_vaccinated = True
                vaccinated_today += 1
        self.unvaccinated_agents = [agent for agent in self.unvaccinated_agents if not agent.is_vaccinated]


In [11]:
def plot_results(data, vaccination_start_day):
    """Plot the Susceptible/Infected/Recovered/Vaccinated curves and show the figure.

    The infection peak is annotated, and a dashed vertical line marks the start of
    vaccination.

    Parameters
    ----------
    data : DataFrame
        Must have columns 'Susceptible', 'Infected', 'Recovered' and 'Vaccinated'.
        Its index is used as the x-axis. In this notebook it comes from
        ``datacollector.get_model_vars_dataframe()``.
    vaccination_start_day : int
        x-position of the "Vaccination Starts" marker.
    """
    # Column -> line style. Insertion order fixes the trace/legend order.
    trace_styles = {
        'Susceptible': dict(color='blue', width=2),
        'Infected': dict(color='red', width=2, dash='dash'),
        'Recovered': dict(color='green', width=2),
        'Vaccinated': dict(color='gold', width=2),
    }

    fig = go.Figure()
    for column, line_style in trace_styles.items():
        fig.add_trace(go.Scatter(x=data.index, y=data[column], mode='lines+markers', name=column, line=line_style))

    fig.update_layout(
        title="Epidemic Model Simulation",
        xaxis_title="Days",
        yaxis_title="Number of Agents",
        legend_title="Agent State",
        hovermode="x",
        plot_bgcolor='rgba(245, 246, 249, 1)',
        xaxis_showgrid=False,
        yaxis_gridcolor='gray',
    )

    # Mark the day with the most infected agents.
    infected = data['Infected']
    fig.add_annotation(x=infected.idxmax(), y=infected.max(), text="Peak Infection", showarrow=True, arrowhead=1)

    fig.add_vline(
        x=vaccination_start_day,
        line=dict(color='purple', dash='dash'),
        annotation_text="Vaccination Starts",
        annotation_position="bottom right",
    )

    fig.show()


In [13]:
def interactive_model(infection_rate, recovery_rate, social_distancing_enabled, vaccination_start_day,
                      compliance_rate, population, days=100, width=40, height=40):
    """Run one epidemic simulation, plot it, and return the collected data.

    The first six parameters are the widget values wired up through
    ``widgets.interactive_output``. The remaining ones replace values that were
    previously hard-coded, and their defaults match those values.

    Parameters
    ----------
    infection_rate, recovery_rate, vaccination_start_day, compliance_rate
        Forwarded to ``EpidemicModel``.
    social_distancing_enabled : bool
        Assigned to ``model.social_distancing`` before the run.
    population : int
        Number of agents (``N``).
    days : int, optional
        Horizon; the model is stepped while ``model.day < days`` (default 100).
    width, height : int, optional
        Grid size (default 40 x 40).

    Returns
    -------
    DataFrame
        The model-level data from ``model.datacollector``. Returning it lets a run be
        inspected outside the widget; the widget wiring does not need it.
    """
    model = EpidemicModel(population, width, height, infection_rate, recovery_rate, vaccination_start_day, compliance_rate)
    model.social_distancing = social_distancing_enabled

    while model.day < days:
        model.step()
    data = model.datacollector.get_model_vars_dataframe()
    plot_results(data, vaccination_start_day)
    return data

In [15]:
# Widgets for controlling parameters.
# Every value change re-runs a full simulation. continuous_update=False makes the
# sliders fire only when released, not on every intermediate drag position.
# description_width='initial' stops long labels (e.g. 'Vaccination Start Day:')
# from being truncated.
slider_opts = dict(continuous_update=False, style={'description_width': 'initial'})
population_slider = widgets.IntSlider(value=300, min=100, max=1000, step=50, description='Population:', **slider_opts)
infection_rate_slider = widgets.FloatSlider(value=0.1, min=0.01, max=0.5, step=0.01, description='Infection Rate:', **slider_opts)
recovery_rate_slider = widgets.FloatSlider(value=0.05, min=0.01, max=0.5, step=0.01, description='Recovery Rate:', **slider_opts)
social_distancing_toggle = widgets.ToggleButton(value=False, description='Enable Social Distancing', button_style='success')
vaccination_day_slider = widgets.IntSlider(value=50, min=1, max=100, step=1, description='Vaccination Start Day:', **slider_opts)
compliance_slider = widgets.FloatSlider(value=0.8, min=0.5, max=1.0, step=0.01, description='Compliance Rate:', **slider_opts)

ui = widgets.VBox([
    population_slider,
    infection_rate_slider,
    recovery_rate_slider,
    social_distancing_toggle,
    vaccination_day_slider,
    compliance_slider,
])

# Map each interactive_model keyword argument to the widget that supplies it.
out = widgets.interactive_output(interactive_model, {
    'infection_rate': infection_rate_slider,
    'recovery_rate': recovery_rate_slider,
    'social_distancing_enabled': social_distancing_toggle,
    'vaccination_start_day': vaccination_day_slider,
    'compliance_rate': compliance_slider,
    'population': population_slider,
})

In [17]:
# Render the parameter controls, then the output area the plot is drawn into.
display(ui)
display(out)

VBox(children=(IntSlider(value=300, description='Population:', max=1000, min=100, step=50), FloatSlider(value=…

Output(outputs=({'output_type': 'display_data', 'data': {'application/vnd.plotly.v1+json': {'data': [{'line': …