# Semester 3 Coding Portfolio Topic 6 Formative Part 2/2:
# Agent-Based Modeling and Managing Epidemics 🦠🦠🦠

This notebook covers the following topics:
 - Disease modelling

This notebook is expected to take around 5 hours to complete:
 - 2 hours for the formative part
 - 3 hours of self-study on the topics covered by this notebook

<b>Formative section</b><br>
Simply complete the given functions such that they pass the automated tests. This part is graded Pass/Fail; you must get 100% correct!
You can submit your notebook through Canvas as often as you like. Make sure to start doing so early to ensure that your code passes all tests!
You may ask for help from fellow students and TAs on this section, and solutions might be provided later on.

In [None]:
# TODO: Please enter your student number here
STUDENT_NUMBER = ...

In this notebook, we’ll return to one of the defining global experiences of our time, the COVID-19 pandemic, to explore how we can use simple models to understand the spread of an epidemic and the impact of public policies.

As we’ve discussed in the lectures, modeling is a powerful tool for thinking through complex social and biological processes.
It allows us to test interventions in a controlled, simulated environment, asking “what if” questions that would be impossible or unethical to test in the real world.

We’ll start with basic epidemic models and gradually increase their complexity.
Along the way, we’ll see how modeling can help us reason about the effects of policy choices (such as distancing, vaccination, or lockdowns) on the course of an outbreak.

By the end, you’ll understand how these models can shed light on real-world decisions and why they remain central to public health and policy planning.

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

## 1. Mathematical/Compartmental model: The SIR Model

The **SIR model** is one of the simplest ways to describe how an infectious disease spreads through a population. It is not an agent-based model (ABM) but a mathematical (compartmental) model. It divides people into three groups:

* **S (Susceptible):** individuals who can catch the disease
* **I (Infected):** individuals who currently have the disease and can spread it
* **R (Recovered):** individuals who have recovered (or died) and can no longer spread the disease

The model assumes a fixed population size (N = S + I + R).
At each time step:

* Some susceptible people become infected, depending on how often they meet infected individuals and how contagious the disease is.
* Infected people recover at a certain rate.
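The two rules above can be written as a single explicit Euler step of length `dt`, using the transmission rate `beta` and recovery rate `gamma` (the parameters listed next). This is a minimal sketch of the classic SIR model without deaths; the helper name `sir_euler_step` and the starting numbers are illustrative assumptions, not part of the notebook's own code (which is `simulate_sird`).

```python
def sir_euler_step(S, I, R, beta, gamma, dt):
    """Advance the classic SIR model by one explicit Euler step of length dt (days).

    Hypothetical helper for illustration only.
    """
    N = S + I + R                              # fixed population size
    new_infections = beta * S * I / N * dt     # S -> I: depends on S-I contacts and contagiousness
    new_recoveries = gamma * I * dt            # I -> R: a fixed fraction recovers per day
    return (S - new_infections,
            I + new_infections - new_recoveries,
            R + new_recoveries)


# One day (dt = 1) starting from 10 infected people in a population of 1,000,000
S1, I1, R1 = sir_euler_step(999_990.0, 10.0, 0.0, beta=0.30, gamma=0.10, dt=1.0)
print(round(S1, 5), round(I1, 5), round(R1, 5))
```

Whatever leaves S enters I, and whatever leaves I enters R, so S + I + R stays equal to N at every step.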

The key parameters are:

* **beta (β):** transmission rate (how quickly the disease spreads)
* **gamma (γ):** recovery rate (how quickly infected people recover)
* **R0 = beta / gamma:** the basic reproduction number, i.e. the average number of people infected by one sick person in a fully susceptible population.
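As a quick worked example, β = 0.30 and γ = 0.10 (the values used later in this notebook) give R0 = 3. For the classic SIR model with homogeneous mixing, the share of the population that is ever infected, z, satisfies the standard final-size relation $z = 1 - e^{-R_0 z}$. The sketch below solves it by fixed-point iteration; the function name `sir_final_size` is hypothetical and the method is intended for R0 > 1.

```python
import math


def sir_final_size(R0, tol=1e-12, max_iter=10_000):
    """Solve z = 1 - exp(-R0 * z) for the final attack rate z of the classic SIR model.

    Fixed-point iteration started at z = 1; it converges to the positive root when R0 > 1.
    """
    z = 1.0
    for _ in range(max_iter):
        z_new = 1.0 - math.exp(-R0 * z)
        if abs(z_new - z) < tol:
            return z_new
        z = z_new
    return z


beta, gamma = 0.30, 0.10
R0 = beta / gamma
z = sir_final_size(R0)
print(f"R0 = {R0:.2f}, share ever infected = {z:.3f}")
```

This is a useful reference point for the simulations: without intervention, a disease with R0 = 3 eventually reaches roughly 94% of a fully susceptible, well-mixed population.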

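The code below uses a simple numerical simulation (Euler’s method) to model how the compartments change over time and plots the results. It implements an extended version of the model that also tracks deaths (SIRD, discussed below). Its optional `beta_fn` argument lets beta vary over time, which can be used to model an intervention that reduces transmission.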


In [None]:
# --------------------------------------------------------
# SIMULATION FUNCTION (SIRD) # SIR MODEL + DEATH
# --------------------------------------------------------
def simulate_sird(beta=0.30, gamma=0.10, mu=0.005, N=1_000_000,
                  I0=10, R0_init=0, D0_init=0, days=180, dt=0.1,
                  beta_fn=None, mu_fn=None, use_live_population=True):
    """
    Simulate an SIRD model (Susceptible-Infected-Recovered-Dead)
    with explicit Euler time stepping.

    Parameters
    ----------
    beta : float
        Transmission rate per day.
    gamma : float
        Recovery rate per day (while infected).
    mu : float
        Death rate per day (while infected).
    N : int
        Initial total population (S + I + R + D at t=0).
    I0, R0_init, D0_init : int
        Initial counts of Infected, Recovered, and Dead.
    days : int
        Number of simulated days.
    dt : float
        Time step in days (smaller -> more accurate).
    beta_fn : callable or None
        Optional function beta(t) for time-varying transmission.
    mu_fn : callable or None
        Optional function mu(t) for time-varying mortality.
    use_live_population : bool
        If True, force of infection uses S+I+R (living only).
        If False, uses N (initial total), like classic mass-action SIR.

    Returns
    -------
    t : (T,) array of times (days)
    S, I, R, D : arrays of compartment sizes over time
    """

    # 1) Discretization
    steps = int(days / dt) + 1
    t = np.linspace(0, days, steps)

    # 2) State arrays
    S = np.zeros(steps)
    I = np.zeros(steps)
    R = np.zeros(steps)
    D = np.zeros(steps)

    # 3) Initial conditions
    S[0] = N - I0 - R0_init - D0_init
    I[0] = I0
    R[0] = R0_init
    D[0] = D0_init

    # 4) Time integration
    for k in range(steps - 1):
        # Allow time-varying beta and mu if functions are provided
        b = beta if beta_fn is None else beta_fn(t[k])
        m = mu   if mu_fn   is None else mu_fn(t[k])

        # Choose denominator for the infection term:
        #   live = S + I + R (nobody meets the dead)
        #   or classical = N (initial total)
        if use_live_population:
            N_den = max(S[k] + I[k] + R[k], 1e-12)
        else:
            N_den = max(S[0] + I[0] + R[0] + D[0], 1e-12)

        # --- SIRD ODEs (per-unit-time rates) ---
        # dS/dt = - beta * S * I / N_den
        # dI/dt =   beta * S * I / N_den - gamma * I - mu * I
        # dR/dt =   gamma * I
        # dD/dt =   mu * I
        dS = -b * S[k] * I[k] / N_den
        dI =  b * S[k] * I[k] / N_den - (gamma + m) * I[k]
        dR =  gamma * I[k]
        dD =  m * I[k]

        # Euler updates
        S[k+1] = S[k] + dS * dt
        I[k+1] = I[k] + dI * dt
        R[k+1] = R[k] + dR * dt
        D[k+1] = D[k] + dD * dt

        # Numerical safety: clamp tiny negatives from rounding
        S[k+1] = max(S[k+1], 0.0)
        I[k+1] = max(I[k+1], 0.0)
        R[k+1] = max(R[k+1], 0.0)
        D[k+1] = max(D[k+1], 0.0)

    return t, S, I, R, D


# --------------------------------------------------------
# PLOTTING FUNCTION
# --------------------------------------------------------
def plot_sird(t, S, I, R, D, title="SIRD model dynamics"):
    """
    Plot S, I, R, D over time.
    """
    plt.figure(figsize=(8, 4.8))
    plt.plot(t, S, label="Susceptible", color="tab:blue")
    plt.plot(t, I, label="Infected",    color="tab:red")
    plt.plot(t, R, label="Recovered",   color="tab:green")
    plt.plot(t, D, label="Dead",        color="tab:gray")
    plt.xlabel("Time (days)")
    plt.ylabel("Number of people")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Parameters (tweak!)
N = 1_000_000
beta = 0.30
gamma = 0.10
mu    = 0.005   # per-day mortality while infected (~0.5% per day)

# R0 in SIRD is beta / (gamma + mu)
R0 = beta / (gamma + mu)
print(f"R0 (SIRD) = beta / (gamma + mu) = {R0:.2f}")

# Run
t, S, I, R, D = simulate_sird(beta=beta, gamma=gamma, mu=mu, N=N, I0=10, days=200, dt=0.1)

# Plot
title = f"SIRD (R0 ≈ {R0:.2f}, infectious period ≈ {1/(gamma+mu):.1f} days)"
plot_sird(t, S, I, R, D, title=title)

### 🧮 Understanding the relationship between β, γ, μ, and R₀

In this extended **SIR** model (often called an **SIRD** model), three key parameters control how an epidemic evolves:

- **β (beta)** – the **transmission rate**: roughly how many new infections one infected person causes per day while almost everyone is still susceptible.  
- **γ (gamma)** – the **recovery rate**: the fraction of infected people who recover each day.  
- **μ (mu)** – the **mortality rate**: the fraction of infected people who die each day.

The **basic reproduction number (R₀)** tells us how many people, on average, each infected person will infect in a fully susceptible population.

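For this model, it is:

$$R_0 = \frac{\beta}{\gamma + \mu}$$

An infected person causes about β new infections per day and stops being infectious (by recovering or dying) at the combined rate γ + μ, so stays infectious for 1/(γ + μ) days on average; multiplying the two gives R₀.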

As you may recall, R₀ was a central point of discussion during the COVID-19 pandemic, and the concept comes largely from SIR-type models.
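A worked example with the parameter values used in the simulation cell above (β = 0.30, γ = 0.10, μ = 0.005). The last quantity, μ/(γ + μ), is the probability that an infected person leaves the infected group by dying rather than recovering; it follows from the same competing rates. The variable names are illustrative.

```python
beta, gamma, mu = 0.30, 0.10, 0.005

infectious_period = 1 / (gamma + mu)   # average number of days spent infected
R0 = beta / (gamma + mu)               # equals beta * infectious_period
death_share = mu / (gamma + mu)        # P(an infection ends in death rather than recovery)

print(f"infectious period = {infectious_period:.2f} days")
print(f"R0 = {R0:.2f}")
print(f"share of infections ending in death = {death_share:.1%}")
```

Note that raising μ also shortens the infectious period, so in this model a deadlier disease has a slightly lower R₀.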


## Exercise 1: How does the deadliness of the disease affect how many people die? 

Try to vary the deadliness of the disease and see how it affects how many people die, keeping the other parameters fixed. 

Do a parameter sweep to see how varying the deadliness changes how many people die. 

Explain the result! 


In [None]:

# TODO: Complete the sweep to explore how deadliness (mu) affects total deaths.
def deadliness_to_total_death(
    N=0,
    beta=0.0,  # transmission rate per day
    gamma=0.0,  # recovery rate per day
    I0=0,
    days=0.0,
    dt=0.0,
):
    """
    Sweep mu values and return:
      - deaths_pct: total deaths as % of population
      - final_attack_rate: share ever infected
      - R0_vals: basic reproduction numbers at each mu
      - IFR_vals: infection fatality ratios at each mu
      - mu_vals: grid of mu values explored
    """
    # TODO: create a grid of mu, simulate with simulate_sird, and collect the metrics above
    mu_vals = None
    deaths_pct = None
    final_attack_rate = None
    R0_vals = None
    IFR_vals = None
    return deaths_pct, final_attack_rate, R0_vals, IFR_vals, mu_vals

#. Your solution here ...



## 🏥 Exercise 2: Policy response 1 - Flatten the curve?

In this exercise, we will explore the idea of “flattening the curve.”

As you may recall, the core idea is that hospitals have limited capacity: only a certain fraction of the population can receive treatment at the same time.
If the number of severely ill patients exceeds that capacity, not everyone can get the care they need, and the mortality rate increases for those left untreated.

In our model:
- A fixed share of infected individuals (for example, 5%) require hospital care.
- The health system can only care for a given maximum number of patients (capacity).
- When hospital demand exceeds that limit, patients outside the system face a higher death rate.
- By lowering R₀ (for instance, through distancing, lockdowns, or mask use), we slow the spread of infections.
- This can keep the number of hospitalized cases below the system’s capacity, preventing overload and ultimately saving lives, even if the total number of infections stays similar.
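One possible way to turn these bullet points into a death rate, as a hedged sketch rather than a reference solution: hospital demand is H = h·I; any demand above the capacity C is "overflow"; overflow patients die at the higher rate μ_over, while all other infected people die at μ_base. The helper name `deaths_per_day`, this split, and reading capacity as `capacity_frac * N` are assumptions; the automated tests define the exact form expected in `simulate_sird_capacity`.

```python
def deaths_per_day(I, mu_base, mu_over, hospital_need_frac, capacity):
    """Hypothetical helper: expected deaths per day at infection level I.

    Assumption: untreated overflow patients die at mu_over, everyone else at mu_base.
    """
    H = hospital_need_frac * I            # patients who need hospital care
    overflow = max(H - capacity, 0.0)     # patients the system cannot treat
    return mu_base * (I - overflow) + mu_over * overflow


# Defaults of simulate_sird_capacity, assuming capacity = 0.004 * 1,000,000 = 4,000 patients
for infected in (50_000, 200_000):
    print(infected, deaths_per_day(infected, mu_base=0.004, mu_over=0.020,
                                   hospital_need_frac=0.05, capacity=4_000))
```

Below capacity the death rate is simply μ_base·I; once demand exceeds capacity, every additional hospital case is charged the higher rate, which is what makes a high, narrow peak deadlier than a flatter one.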

We will:
- Extend the SIRD model to include hospital capacity and overflow deaths.

Simulate two scenarios:
- Baseline: No intervention (high R₀)
- Policy: R₀ reduced by 30%
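Because R₀ = β/(γ + μ), reducing R₀ by 30% while γ and μ stay fixed is the same as multiplying β by 0.7. A small check, assuming the defaults of `simulate_sird_capacity` and using μ_base as the death rate (during overload the effective death rate is higher, so this is only the baseline R₀):

```python
gamma, mu_base = 0.10, 0.004
beta_baseline = 0.30
beta_policy = 0.7 * beta_baseline          # 30% lower transmission rate

R0_baseline = beta_baseline / (gamma + mu_base)
R0_policy = beta_policy / (gamma + mu_base)

print(f"baseline R0 = {R0_baseline:.2f}, policy R0 = {R0_policy:.2f}")
print(f"ratio = {R0_policy / R0_baseline:.2f}")
```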

**Key question: How many fewer people die when R₀ is reduced by 30%?**

The goal is to understand why the strength of an intervention matters and how reducing transmission can save lives, not only by preventing infections but also by avoiding the collapse of the healthcare system.

In [None]:

# TODO: Implement the SIRD model with hospital capacity and overload mortality.

def simulate_sird_capacity(
    beta=0.30,
    gamma=0.10,
    mu_base=0.004,
    mu_over=0.020,
    hospital_need_frac=0.05,
    capacity_frac=0.004,
    N=1_000_000,
    I0=50,
    days=240,
    dt=0.1,
):
    """
    Simulate SIRD with hospital capacity.
    Returns a dictionary with keys: 't', 'S', 'I', 'R', 'D', 'H', 'overload'.
    """
    # TODO: set up arrays, compute hospital load H = h * I, and update with Euler steps.
    return {}


#. Your solution here ...




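**Visualization helpers (provided)**  
These functions plot the epidemic with capacity and summarize outcomes so you can inspect your results quickly after implementing simulate_sird_capacity. The capacity line and the days-over-capacity estimate only appear if the returned dictionary also contains a `'capacity'` entry.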

In [None]:

def plot_epidemic_with_capacity(res, title="SIRD with capacity"):
    t, S, I, R, D = res["t"], res["S"], res["I"], res["R"], res["D"]
    H, C = res["H"], res.get("capacity", None)

    plt.figure(figsize=(9, 4.6))
    plt.plot(t, S, label="S")
    plt.plot(t, I, label="I")
    plt.plot(t, R, label="R")
    plt.plot(t, D, label="D")
    plt.title(title)
    plt.xlabel("Day")
    plt.ylabel("People")
    plt.legend()
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(9, 4.2))
    plt.plot(t, H, label="Hospital load (h * I)")
    if C is not None:
        plt.axhline(C, ls="--", label="Capacity")
    plt.title("Hospital load vs capacity")
    plt.xlabel("Day")
    plt.ylabel("Patients in care")
    plt.legend()
    plt.tight_layout()
    plt.show()


def summarize_capacity_outcomes(res):
    t, I, D, H = res["t"], res["I"], res["D"], res["H"]
    C = res.get("capacity", None)
    total_deaths = D[-1]
    peak_I = I.max()
    peak_day = t[I.argmax()]
    peak_H = H.max()
    days_over = np.sum(H > C) * (t[1] - t[0]) if C is not None else 0.0
    print(f"Total deaths: {total_deaths:,.0f}")
    print(f"Peak infected: {peak_I:,.0f} on day {peak_day:.1f}")
    print(f"Peak hospital load: {peak_H:,.0f}")
    if C is not None:
        print(f"Capacity: {C:,.0f}; approx days over capacity: {days_over:.1f}")

## Policy Response 2: "Focused Protection" 

Focused Protection was a famous and hotly debated alternative to "flattening the curve". It was most prominently advocated in the Great Barrington Declaration (October 2020), a statement by a small group of scientists arguing that governments should protect only the elderly and other high-risk groups, while allowing younger, lower-risk individuals to live normally and acquire immunity through infection, an approach sometimes described as pursuing *natural herd immunity*.

How do we study this?

This gets at a key limitation of the SIR model, and indeed of nearly any mathematical model: the homogeneous mixing assumption. The SIR model treats everyone as equal. Real life isn’t like that: contacts are networked and uneven (families, classes, offices, “hubs”/superspreaders). Some people are more vulnerable than others, and some are more likely to get sick than others.

So let's go to a simple ABM to address this! 


## New Model: ABM of Epidemic

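This ABM models the world as a square box in which agents move in straight lines at constant speed, starting in random directions and bouncing off the walls. The disease spreads when an infected agent comes within contact distance of a susceptible one.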

It is a very simple model, but it does allow us to capture population heterogeneity.
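In each time step, a susceptible agent within contact range of k infected agents gets infected with probability 1 - (1 - p_trans)^k: each contact is treated as an independent chance of transmission, and the agent escapes only if every one of them fails. A small illustration of that rule, as used in `infection_step` in the next cell (the helper name here is hypothetical):

```python
import numpy as np


def infection_probability(k_contacts, p_trans):
    """P(infected this step) for a susceptible agent with k infected contacts,
    assuming each contact transmits independently with probability p_trans."""
    return 1.0 - (1.0 - p_trans) ** np.asarray(k_contacts)


p = infection_probability([0, 1, 2, 5], p_trans=0.15)
print(p)
```

Crowding therefore matters: being near five infected agents at once more than triples the per-step risk compared with a single contact.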

In [None]:
# =====================================================
# Minimal moving-dots epidemic ABM 
# =====================================================

rng = np.random.default_rng(1)

# --- core parameters (tweak freely) ---
N = 300                 # number of agents
init_infected = 5       # initially infected
box = 100.0             # square [0, box] x [0, box]
steps = 700             # number of frames (time steps)
dt = 1.0                # step size
speed = 1.0             # constant speed for all living agents
radius = 2.0            # agent radius; two agents are in contact when within 2*radius

# disease parameters (per step)
p_trans   = 0.15        # infection probability per contact per step
p_recover = 0.05        # recovery probability per step
p_die     = 0.002       # death probability per step (while infected)

# state codes and colors
SUS, INF, REC, DEA = 0, 1, 2, 3
COLORS = {SUS:"#1f77b4", INF:"#d62728", REC:"#2ca02c", DEA:"#7f7f7f"}

# ---------- initialization ----------
pos = rng.uniform(0, box, size=(N, 2))
ang = rng.uniform(0, 2*np.pi, size=N)
vel = np.c_[np.cos(ang), np.sin(ang)] * speed

state = np.full(N, SUS, dtype=np.int8)
state[rng.choice(N, size=init_infected, replace=False)] = INF

S_hist, I_hist, R_hist, D_hist = [], [], [], []

# ---------- helpers ----------
def move_and_reflect(pos, vel, L, dt):
    pos += vel * dt
    # bounce on walls (elastic reflection)
    hit_left = pos[:, 0] < 0
    pos[hit_left, 0] = -pos[hit_left, 0]
    vel[hit_left, 0] *= -1
    hit_right = pos[:, 0] > L
    pos[hit_right, 0] = 2 * L - pos[hit_right, 0]
    vel[hit_right, 0] *= -1
    hit_bot = pos[:, 1] < 0
    pos[hit_bot, 1] = -pos[hit_bot, 1]
    vel[hit_bot, 1] *= -1
    hit_top = pos[:, 1] > L
    pos[hit_top, 1] = 2 * L - pos[hit_top, 1]
    vel[hit_top, 1] *= -1

def infection_step(pos, state, radius, p_trans):
    sus_idx = np.where(state == SUS)[0]
    inf_idx = np.where(state == INF)[0]
    if len(sus_idx) == 0 or len(inf_idx) == 0:
        return np.zeros_like(state, dtype=bool)

    sus_pos = pos[sus_idx]
    inf_pos = pos[inf_idx]

    # pairwise distance^2 between each susceptible and infected
    diff = sus_pos[:, None, :] - inf_pos[None, :, :]
    dist2 = np.sum(diff*diff, axis=2)

    # contact if distance <= 2*radius
    contact_thresh2 = (2*radius)**2
    k_contacts = np.sum(dist2 <= contact_thresh2, axis=1)  # how many infected within range

    # P(get infected) = 1 - (1 - p_trans)^k
    p_inf = 1.0 - (1.0 - p_trans)**k_contacts
    draws = rng.random(len(sus_idx)) < p_inf

    new_inf_mask = np.zeros_like(state, dtype=bool)
    new_inf_mask[sus_idx[draws]] = True
    return new_inf_mask

def disease_progress(state, p_recover, p_die):
    infected = (state == INF)
    if not np.any(infected):
        return np.zeros_like(state, dtype=bool), np.zeros_like(state, dtype=bool)
    # deaths first, then recoveries among remaining infected
    die_draw = (rng.random(len(state)) < p_die) & infected
    still_inf = infected & (~die_draw)
    rec_draw = (rng.random(len(state)) < p_recover) & still_inf
    return die_draw, rec_draw

# ---------- animation ----------
fig, ax = plt.subplots(figsize=(6.8, 6.8))
ax.set_xlim(0, box); ax.set_ylim(0, box); ax.set_aspect('equal', adjustable='box')
ax.set_title("Minimal random-movement ABM")
scat_S = ax.scatter([], [], s=[], c=COLORS[SUS], label="S")
scat_I = ax.scatter([], [], s=[], c=COLORS[INF], label="I")
scat_R = ax.scatter([], [], s=[], c=COLORS[REC], label="R")
scat_D = ax.scatter([], [], s=[], c=COLORS[DEA], label="D")
ax.legend(loc="upper right", frameon=False)
txt = ax.text(0.02, 0.98, "", transform=ax.transAxes, va="top", ha="left")

# marker size proportional to area of the interaction radius
marker_area = 6.0 * (radius**2)

def update(frame):
    # dead agents don't move or interact
    vel[state == DEA] = 0.0

    # 1) move and bounce
    move_and_reflect(pos, vel, box, dt)

    # 2) infections
    new_inf = infection_step(pos, state, radius, p_trans)

    # 3) disease progression
    die, rec = disease_progress(state, p_recover, p_die)

    # 4) apply updates
    state[new_inf] = INF
    state[die] = DEA
    state[rec] = REC

    # 5) record counts
    S_hist.append(np.sum(state == SUS))
    I_hist.append(np.sum(state == INF))
    R_hist.append(np.sum(state == REC))
    D_hist.append(np.sum(state == DEA))

    # 6) update scatter artists
    for scat, st in [(scat_S, SUS), (scat_I, INF), (scat_R, REC), (scat_D, DEA)]:
        idx = np.where(state == st)[0]
        offsets = pos[idx] if len(idx) else np.empty((0,2))
        sizes = np.full(len(idx), marker_area)
        scat.set_offsets(offsets)
        scat.set_sizes(sizes)

    txt.set_text(f"S:{S_hist[-1]}  I:{I_hist[-1]}  R:{R_hist[-1]}  D:{D_hist[-1]}  t={frame}")
    return scat_S, scat_I, scat_R, scat_D, txt

anim = FuncAnimation(fig, update, frames=steps, interval=30, blit=True)
HTML(anim.to_jshtml())


In [None]:
t = np.arange(len(S_hist))
plt.figure(figsize=(8,4.5))
plt.plot(t, S_hist, label="S")
plt.plot(t, I_hist, label="I")
plt.plot(t, R_hist, label="R")
plt.plot(t, D_hist, label="D")
plt.xlabel("Time step"); plt.ylabel("Agents")
plt.title("Minimal ABM: S/I/R/D over time")
plt.legend(); plt.tight_layout(); plt.show()


**Play with the model**: how do the model dynamics differ from those of the SIRD model? What else can we learn from it?

## Exercise 3: Heterogeneous population: Young and Old

Your task is to introduce a separation between young and old agents, so that we can see the effects of interventions targeting specifically the older cohort.

Building on the model above, create a model that has young and old agents.

Old agents are less likely to be infected, but more likely to die if they do get infected. 

We will use this to test the effect of "focused protection": a lockdown of the old population!
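To see what "more likely to die" means with the per-step probabilities in the signature of `simulate_abm_age` (p_recover = 0.05, p_die_young = 0.002, p_die_old = 0.010): the provided helper `_disease_progress_age` draws death first and recovery only among those who survive the step. The chance that an infection ends in death is then the per-step death probability divided by the per-step probability of leaving the infected state. A minimal sketch (the helper name is hypothetical):

```python
def infection_fatality(p_die, p_recover):
    """Probability that an infection ends in death in the per-step ABM.

    Each step: die with probability p_die; otherwise recover with probability p_recover.
    """
    p_leave = p_die + (1 - p_die) * p_recover   # per-step probability of leaving the infected state
    return p_die / p_leave


young = infection_fatality(p_die=0.002, p_recover=0.05)
old = infection_fatality(p_die=0.010, p_recover=0.05)
print(f"young: {young:.1%}, old: {old:.1%}")
```

With these defaults an infected old agent is roughly four times as likely to die as an infected young one, which is why reducing infections among the old can matter more than their share of the population suggests.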

In [None]:

# TODO: Implement the ABM with age structure (young vs old).

def simulate_abm_age(
    N=300,
    frac_old=0.20,
    init_infected=5,
    box=100.0,
    steps=700,
    dt=1.0,
    speed=1.0,
    radius=2.0,
    p_recover=0.05,
    p_trans_young=0.15,
    p_trans_old=0.10,
    p_die_young=0.002,
    p_die_old=0.010,
    seed=1,
):
    """
    Simulate ABM with age structure (Young/Old).
    Returns a dictionary with keys: 'S', 'I', 'R', 'D', 'D_old', 'D_young' (and others if needed)
    """
    # TODO: move agents, handle infections and age-specific disease progression
    return {}

#. Your solution here ...



**Helper functions**  
Use these building blocks for movement, infection, and animation when you implement simulate_abm_age or explore the model.

In [None]:
SUS, INF, REC, DEA = 0, 1, 2, 3
COLORS = {SUS: "#1f77b4", INF: "#d62728", REC: "#2ca02c", DEA: "#7f7f7f"}


def _move_and_reflect(pos, vel, L, dt):
    pos += vel * dt
    hit_left = pos[:, 0] < 0
    pos[hit_left, 0] = -pos[hit_left, 0]
    vel[hit_left, 0] *= -1
    hit_right = pos[:, 0] > L
    pos[hit_right, 0] = 2 * L - pos[hit_right, 0]
    vel[hit_right, 0] *= -1
    hit_bot = pos[:, 1] < 0
    pos[hit_bot, 1] = -pos[hit_bot, 1]
    vel[hit_bot, 1] *= -1
    hit_top = pos[:, 1] > L
    pos[hit_top, 1] = 2 * L - pos[hit_top, 1]
    vel[hit_top, 1] *= -1


def _infection_step_age(pos, state, radius, p_trans_young, p_trans_old, age_is_old, rng):
    sus_idx = np.where(state == SUS)[0]
    inf_idx = np.where(state == INF)[0]
    if len(sus_idx) == 0 or len(inf_idx) == 0:
        return np.zeros_like(state, dtype=bool)

    diff = pos[sus_idx][:, None, :] - pos[inf_idx][None, :, :]
    dist2 = np.sum(diff * diff, axis=2)
    contact2 = (2 * radius) ** 2
    k_contacts = np.sum(dist2 <= contact2, axis=1)

    p_sus = np.where(age_is_old[sus_idx], p_trans_old, p_trans_young)
    p_inf = 1.0 - np.power((1.0 - p_sus), k_contacts)

    draws = rng.random(len(sus_idx)) < p_inf
    new_inf_mask = np.zeros_like(state, dtype=bool)
    new_inf_mask[sus_idx[draws]] = True
    return new_inf_mask


def _disease_progress_age(state, age_is_old, p_recover, p_die_young, p_die_old, rng):
    infected = state == INF
    if not np.any(infected):
        return np.zeros_like(state, dtype=bool), np.zeros_like(state, dtype=bool)
    p_die = np.where(age_is_old, p_die_old, p_die_young)
    die_draw = (rng.random(len(state)) < p_die) & infected
    still_inf = infected & (~die_draw)
    rec_draw = (rng.random(len(state)) < p_recover) & still_inf
    return die_draw, rec_draw


def animate_abm_age(
    N=300,
    frac_old=0.20,
    init_infected=5,
    box=100.0,
    steps=400,
    dt=1.0,
    speed=1.0,
    radius=2.0,
    p_recover=0.05,
    p_trans_young=0.15,
    p_trans_old=0.10,
    p_die_young=0.002,
    p_die_old=0.010,
    seed=1,
    fps=30,
):
    rng = np.random.default_rng(seed)

    age_is_old = rng.random(N) < frac_old
    pos = rng.uniform(0, box, size=(N, 2))
    ang = rng.uniform(0, 2 * np.pi, size=N)
    vel = np.c_[np.cos(ang), np.sin(ang)] * speed
    state = np.full(N, SUS, dtype=np.int8)
    state[rng.choice(N, size=min(init_infected, N), replace=False)] = INF

    fig, ax = plt.subplots(figsize=(6.6, 6.6))
    ax.set_xlim(0, box)
    ax.set_ylim(0, box)
    ax.set_aspect('equal', adjustable='box')
    ax.set_title("ABM with Old vs Young (helpers provided)")
    scatter_kwargs = dict(s=6 * (radius ** 2))
    scat_S = ax.scatter([], [], c=COLORS[SUS], label="S", **scatter_kwargs)
    scat_I = ax.scatter([], [], c=COLORS[INF], label="I", **scatter_kwargs)
    scat_R = ax.scatter([], [], c=COLORS[REC], label="R", **scatter_kwargs)
    scat_D = ax.scatter([], [], c=COLORS[DEA], label="D", **scatter_kwargs)
    ax.legend(loc="upper right", frameon=False)
    txt = ax.text(0.02, 0.98, "", transform=ax.transAxes, va="top", ha="left")

    S_hist = []
    I_hist = []
    R_hist = []
    D_hist = []
    D_old_hist = []
    D_young_hist = []

    def update(frame):
        vel[state == DEA] = 0.0
        _move_and_reflect(pos, vel, box, dt)
        new_inf = _infection_step_age(pos, state, radius, p_trans_young, p_trans_old, age_is_old, rng)
        die, rec = _disease_progress_age(state, age_is_old, p_recover, p_die_young, p_die_old, rng)

        state[new_inf] = INF
        state[die] = DEA
        state[rec] = REC

        for scat, st in [(scat_S, SUS), (scat_I, INF), (scat_R, REC), (scat_D, DEA)]:
            idx = np.where(state == st)[0]
            scat.set_offsets(pos[idx] if len(idx) else np.empty((0, 2)))

        S_hist.append(np.sum(state == SUS))
        I_hist.append(np.sum(state == INF))
        R_hist.append(np.sum(state == REC))
        D_hist.append(np.sum(state == DEA))
        D_old_hist.append(np.sum((state == DEA) & age_is_old))
        D_young_hist.append(np.sum((state == DEA) & (~age_is_old)))

        txt.set_text(
            f"S:{S_hist[-1]}  I:{I_hist[-1]}  R:{R_hist[-1]}  D:{D_hist[-1]}  "
            f"(old D:{D_old_hist[-1]}, young D:{D_young_hist[-1]})  t={frame}"
        )
        return scat_S, scat_I, scat_R, scat_D, txt

    anim = FuncAnimation(fig, update, frames=steps, interval=1000 / fps, blit=True)
    return HTML(anim.to_jshtml())

In [None]:
# Here is code for running the new version

res = simulate_abm_age(
    N=300, frac_old=0.20, init_infected=5,
    p_trans_young=0.15, p_trans_old=0.10,
    p_die_young=0.002, p_die_old=0.010,
    steps=600, seed=42
)

t = np.arange(len(res["S"]))
plt.figure(figsize=(8,4.5))
plt.plot(t, res["S"], label="S")
plt.plot(t, res["I"], label="I")
plt.plot(t, res["R"], label="R")
plt.plot(t, res["D"], label="D")
plt.xlabel("Time step"); plt.ylabel("Agents")
plt.title("ABM (old vs young): S/I/R/D over time")
plt.legend(); plt.tight_layout(); plt.show()


## Examining "Focused Protection"

Let's use the model to examine the effects of focused protection. Draw on the code below to analyze the model dynamics and to find out under which conditions the strategy can save lives.


In [None]:
# ------------------------------------------------------
# Simple comparison: Focused Protection vs Everyone Less
# ------------------------------------------------------
# Uses simulate_abm_age(...) from previous cell!

def run_once_scenario(scenario, seed):
    """Run a single simulation under one scenario."""
    base = dict(
        N=300, frac_old=0.2, steps=400,
        p_trans_young=0.15, p_trans_old=0.10,
        p_die_young=0.002, p_die_old=0.010,
        seed=seed
    )

    # EXAMPLE! Play with these parameters!
    if scenario == "default":
        pass
    elif scenario == "everyone_less":
        base["p_trans_young"] *= 0.8
        base["p_trans_old"]   *= 0.8
    elif scenario == "focused_protection":
        base["p_trans_young"] *= 1.1
        base["p_trans_old"]   *= 0.5 
    else:
        raise ValueError("Unknown scenario")

    res = simulate_abm_age(**base)
    return res["D"][-1], res["D_old"][-1], res["D_young"][-1]


# ---- run a few replicates for each scenario ----
scenarios = ["default", "everyone_less", "focused_protection"]
runs = 50
rng = np.random.default_rng(42)

results = {s: {"total": [], "old": [], "young": []} for s in scenarios}

for s in scenarios:
    for _ in range(runs):
        D_tot, D_old, D_yng = run_once_scenario(s, seed=int(rng.integers(1_000_000_000)))
        results[s]["total"].append(D_tot)
        results[s]["old"].append(D_old)
        results[s]["young"].append(D_yng)

# ---- compute means ----
means_total = [np.mean(results[s]["total"]) for s in scenarios]
means_old   = [np.mean(results[s]["old"])   for s in scenarios]
means_young = [np.mean(results[s]["young"]) for s in scenarios]

# ---- plot results ----
x = np.arange(len(scenarios))
plt.figure(figsize=(8,4))
plt.bar(x - 0.2, means_young, width=0.4, label="Young deaths")
plt.bar(x + 0.2, means_old,   width=0.4, label="Old deaths")
plt.xticks(x, ["Default", "Everyone less", "Focused protection"])
plt.ylabel("Average deaths")
plt.title("Focused Protection vs Everyone Less")
plt.legend()
plt.tight_layout()
plt.show()

# ---- print small summary ----
print("Average total deaths:")
for s, m in zip(scenarios, means_total):
    print(f"{s:>18}: {m:6.1f}")
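Each scenario average above comes from a few dozen stochastic runs, so part of any difference between scenarios may be noise. A minimal sketch for attaching a standard error of the mean to each average (the helper name `mean_and_sem` is hypothetical and uses only NumPy; the demo input is a toy list, not model output):

```python
import numpy as np


def mean_and_sem(values):
    """Return the sample mean and its standard error (sample std / sqrt(n))."""
    x = np.asarray(values, dtype=float)
    if x.size < 2:
        raise ValueError("need at least two runs to estimate a standard error")
    return x.mean(), x.std(ddof=1) / np.sqrt(x.size)


m, se = mean_and_sem([10, 12, 14])   # toy input
print(f"{m:.1f} ± {se:.2f}")
```

With the `results` dictionary from the cell above, `mean_and_sem(results["focused_protection"]["total"])` gives the average total deaths under focused protection together with its standard error; differences between scenarios that are much smaller than about two standard errors should be read with caution.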


What is your conclusion from your model? Is it a good strategy? What are potential weaknesses/limitations? 

What are limitations of the model in answering this question? 