# Herd Immunity Modeling

In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import itertools
import time
from IPython.display import clear_output
from ipywidgets import widgets, interactive, interact, interact_manual, HBox, Layout, VBox, Label

## Rules:

1. Number of agents is **GRID_WIDTH * GRID_HEIGHT = NUM_AGENTS**. Each point in the rectangular grid is occupied by an agent.

2. Agents are randomly assigned an **IMMUNE** status, either 0 (susceptible) or 1 (immune), at the beginning. **NUM_IMMUNE_START = PERC_IMMUNE_START * NUM_AGENTS**

3. Agents are initially randomly flagged as sick, with **NUM_SICK_START = PERC_SICK_START * NUM_AGENTS.** Immune agents cannot be initially flagged as sick.

4. All agents that aren't immune or sick are set as initially healthy.

5. At each timestep, agents connect with some of the agents around them, with probability of connecting to each agent within their **CONNECTION_DISTANCE** being **PROB_SOCIAL** . This determines the candidate pool a given agent could connect with:

   1. If **CONNECTION_DISTANCE = 1**, agents have a candidate pool of <= 8, that is, the agents directly around them in a square with themselves at the center.

   2. If **CONNECTION_DISTANCE = 2**, agents have a candidate pool of <= 24, that is, the agents in the 2 concentric squares around them, with themselves at the center (in general, <= (2 * CONNECTION_DISTANCE + 1)^2 - 1; fewer near the edges of the grid).

   3. etc. with **CONNECTION_DISTANCE <= MAX(GRID_WIDTH, GRID_HEIGHT) - 1**

6. If a sick agent connects with a healthy agent who is not immune, probability of infection is **PROB_INFECTION.** If a sick agent connects with an immune agent, probability of infection is zero (could change later accounting for imperfect vaccination success rate).

7. After infections occur, each sick agent has a chance to recover with **PROB_RECOVERY**. Agents that are still sick then have a chance to die with **PROB_DEATH**.

In [2]:
def num_to_idxs(num : int, params : dict) -> tuple:
    '''Map a linear (row-major) agent number to its (down, across) grid position.

    The row is the number of whole GRID_WIDTHs that fit into num; the column
    is the remainder. Because only floor-division and modulo are involved,
    num may also be a numpy array of agent numbers, in which case a tuple of
    two arrays is returned.
    '''
    width = params['GRID_WIDTH']
    down, across = divmod(num, width)
    return (down, across)

In [3]:
def idxs_to_num(down : int, across : int, params : dict) -> int:
    '''Map a (down, across) grid position to its linear (row-major) agent number.

    Each full row above the position contributes GRID_WIDTH agents, and the
    column offset is added on top. Works element-wise when down and across
    are numpy arrays of equal shape.
    '''
    row_start = params['GRID_WIDTH'] * down
    return row_start + across

In [4]:
def get_probabilities(N : int, exclude : np.ndarray=None) -> np.ndarray:
    '''Build a uniform probability vector over the non-excluded positions.

    Parameters
    ----------
    N : length of the returned array.
    exclude : optional mask of length N; truthy entries receive probability 0.
        Any array-like is accepted and coerced to bool.

    Returns
    -------
    Array of length N that is 0 where exclude is truthy and 1 / (number of
    remaining positions) everywhere else, so that it sums to 1.

    Raises
    ------
    ValueError : if exclude does not have length N, or if every position is
        excluded (no valid distribution exists).
    '''
    allowed = np.ones(N, dtype=bool)
    if exclude is not None:
        # Coerce before indexing: a 0/1 integer array would otherwise be used
        # as fancy *indices* (zeroing only positions 0 and 1), and a float
        # array would be rejected by numpy altogether.
        exclude = np.asarray(exclude, dtype=bool)
        if exclude.shape != (N,):
            raise ValueError('exclude must be same length as N')
        allowed &= ~exclude

    num_allowed = np.count_nonzero(allowed)
    if num_allowed == 0:
        raise ValueError('every index is excluded; no probabilities to assign')

    probs = np.zeros(N)
    probs[allowed] = 1 / num_allowed
    return probs

In [5]:
def initialize_array(N : int, num_ones : int, exclude : np.ndarray=None) -> np.ndarray:
    '''Create a 1D float array of length N holding exactly num_ones 1's.

    The 1's are placed uniformly at random (without replacement) among the
    indices where exclude is falsy; all other entries are 0.

    Parameters
    ----------
    N : length of the array.
    num_ones : how many entries to set to 1.
    exclude : optional array-like of length N; truthy entries can never be
        selected.

    Returns
    -------
    Array of zeros and ones of length N.
    '''
    arr = np.zeros(N)

    # Nothing to place. Returning here also covers the case where every index
    # is excluded (e.g. 100% immune and 0 sick), for which no selection
    # distribution can be built.
    if num_ones == 0:
        return arr

    if exclude is not None: exclude = np.asarray(exclude).astype(bool)

    # Randomly select locations for ones, incorporating exclusions
    probs = get_probabilities(N, exclude=exclude)
    selections = np.random.choice(N, size=num_ones, replace=False, p=probs)
    arr[selections] = 1

    # Sanity check: sampling without replacement must give distinct indices
    assert np.sum(arr) == num_ones
    return arr

In [6]:
def derive_params(setup):
    '''Turn the user-supplied setup dict into the full simulation state.

    Copies setup into a fresh dict (starting at timestep 0), adds the derived
    constants (agent count, clamped connection distance, starting counts,
    grid index table, per-agent sociability) and randomly draws the initial
    immune / sick / healthy / dead status arrays.

    Returns the new params dict; setup itself is not modified.
    '''
    params = {'timestep': 0}
    params.update(setup)

    # Derived constants
    width, height = params['GRID_WIDTH'], params['GRID_HEIGHT']
    num_agents = width * height
    params['NUM_AGENTS'] = num_agents
    assert params['PERC_IMMUNE_START'] + params['PERC_SICK_START'] <=1, 'Impossible!'

    # A distance beyond the longest grid side adds no new neighbors, so cap it
    furthest = max(width, height) - 1
    params['MAX_REASONABLE_CONNECTION_DISTANCE'] = furthest
    params['CONNECTION_DISTANCE'] = min(params['CONNECTION_DISTANCE'], furthest)

    for count_key, perc_key in (
        ('NUM_IMMUNE_START', 'PERC_IMMUNE_START'),
        ('NUM_SICK_START', 'PERC_SICK_START'),
    ):
        params[count_key] = int(np.round(params[perc_key] * num_agents, 0))

    # Row-major table of every (down, across) position on the grid
    grid_positions = itertools.product(range(height), range(width))
    params['ALL_IDXS'] = np.array(list(grid_positions))
    params['sociability_arr'] = params['PROB_SOCIAL'] * np.ones(num_agents)

    # Immunity is drawn first; sickness is then drawn among non-immune agents
    params['immune_arr'] = initialize_array(N=num_agents, num_ones=params['NUM_IMMUNE_START'])
    params['sick_arr'] = initialize_array(
        N=num_agents,
        num_ones=params['NUM_SICK_START'],
        exclude=params['immune_arr'],
    )

    # Everyone who is not sick starts out healthy; nobody starts out dead
    params['healthy_arr'] = 1 - params['sick_arr']
    params['dead_arr'] = np.zeros(num_agents)

    return params

In [7]:
def get_neighbors(idx : int, params : dict) -> np.ndarray:
    '''Return the linear agent numbers of all agents near agent idx.

    A neighbor is any other agent whose row and column both lie within
    CONNECTION_DISTANCE of idx's own, i.e. the square window centred on idx,
    clipped at the edges of the grid.

    Parameters
    ----------
    idx : linear (row-major) agent number.
    params : must provide GRID_WIDTH, GRID_HEIGHT and CONNECTION_DISTANCE.

    Returns
    -------
    1D integer array of neighbor agent numbers in ascending order, never
    containing idx itself.
    '''
    width, height = params['GRID_WIDTH'], params['GRID_HEIGHT']
    reach = int(params['CONNECTION_DISTANCE'])

    # Same row-major convention as num_to_idxs
    down, across = divmod(idx, width)

    # Build only the clipped window instead of masking the whole grid, so the
    # cost depends on the window size rather than on the number of agents
    rows = np.arange(max(down - reach, 0), min(down + reach, height - 1) + 1)
    cols = np.arange(max(across - reach, 0), min(across + reach, width - 1) + 1)

    # Transform the window's (row, col) pairs back to agent numbers
    neighbor_nums = (rows[:, None] * width + cols[None, :]).ravel()

    return neighbor_nums[neighbor_nums != idx] # exclude input

In [8]:
def get_interaction_matrix(params):
    '''Draw this timestep's interactions and return them as a matrix.

    Each agent (linear index) initiates contact with each of its neighbors
    independently, with that agent's probability from sociability_arr.

    Interactions are bidirectional: either agent can initiate, and once two
    agents interact they affect one another. The matrix is therefore
    symmetric: entry [a, b] is 1 if a contacted b OR b contacted a. Without
    the mirrored write, a healthy agent that initiated contact with a sick
    neighbor would not show up in that sick agent's row.

    Returns
    -------
    NUM_AGENTS x NUM_AGENTS array of 0's and 1's.
    '''
    num_agents = params['NUM_AGENTS']
    interaction_matrix = np.zeros((num_agents, num_agents))

    # Build graph
    for agent, sociability in zip(range(num_agents), params['sociability_arr']):
        neighbors = get_neighbors(agent, params)
        initiated = np.random.binomial(1, sociability, size=len(neighbors)).astype(bool)
        contacts = neighbors[initiated]

        # Record the contact from both sides
        interaction_matrix[agent, contacts] = 1
        interaction_matrix[contacts, agent] = 1

    return interaction_matrix

In [9]:
def set_sick(got_sick_idxs, params):
    '''Input got_sick_idxs array, modify params in place to reflect:
    those agents are flagged sick and no longer healthy. Returns params.'''
    for key, flag in (('sick_arr', 1), ('healthy_arr', 0)):
        params[key][got_sick_idxs] = flag
    return params

def set_healthy(got_better_idxs, params):
    '''Input got_better_idxs array, modify params in place to reflect:
    those agents stop being sick and become healthy again. When
    IMMUNE_AFTER_RECOVERY is set they are also flagged immune. Returns params.'''
    for key, flag in (('sick_arr', 0), ('healthy_arr', 1)):
        params[key][got_better_idxs] = flag

    grants_immunity = params['IMMUNE_AFTER_RECOVERY']
    if grants_immunity:
        params['immune_arr'][got_better_idxs] = 1

    return params

def set_dead(died_idxs, params):
    '''Input died_idxs array, modify params in place to reflect:
    those agents are no longer sick and are flagged dead. Returns params.'''
    for key, flag in (('sick_arr', 0), ('dead_arr', 1)):
        params[key][died_idxs] = flag
    return params

def get_sick_agents(params):
    '''Return the linear indices of all agents currently flagged sick.'''
    is_sick = params['sick_arr'] == 1
    return np.nonzero(is_sick)[0]

def get_exposed_agents(sick_agent, interaction_matrix):
    '''Return the indices of the agents marked 1 in sick_agent's row of the
    interaction matrix.'''
    contacts_row = interaction_matrix[sick_agent]
    return np.nonzero(contacts_row == 1)[0]

def reshape_to_2d(arr, params):
    '''View a per-agent 1D array as a GRID_HEIGHT x GRID_WIDTH board.'''
    board_shape = (params['GRID_HEIGHT'], params['GRID_WIDTH'])
    return arr.reshape(board_shape)

In [10]:
def do_recovery(params):
    '''Give every currently sick agent one chance to recover.

    Each sick agent recovers independently with PROB_RECOVERY. The indices of
    those who recovered are stored in params['got_better_idxs'] and their
    status arrays are updated via set_healthy. Returns params.
    '''
    currently_sick = get_sick_agents(params)
    draws = np.random.binomial(1, params['PROB_RECOVERY'], size=len(currently_sick))
    recovered = currently_sick[draws.astype(bool)]
    params['got_better_idxs'] = list(set(recovered))
    return set_healthy(params['got_better_idxs'], params)

def do_death(params):
    '''Give every currently sick agent one chance to die.

    Each sick agent dies independently with PROB_DEATH. The indices of those
    who died are stored in params['died_idxs'] and their status arrays are
    updated via set_dead. Returns params.
    '''
    currently_sick = get_sick_agents(params)
    draws = np.random.binomial(1, params['PROB_DEATH'], size=len(currently_sick))
    deceased = currently_sick[draws.astype(bool)]
    params['died_idxs'] = list(set(deceased))
    return set_dead(params['died_idxs'], params)

In [11]:
def get_possible_sick(params):
    '''Find, for every sick agent, the agents it could infect this timestep.

    A fresh interaction matrix is drawn, and for each sick agent the agents
    it interacted with are narrowed down to those that are healthy, not
    immune and not dead.

    Returns a dict mapping each sick agent's index to an array of at-risk
    agent indices.
    '''
    # Obtain matrix of interactions for timestep
    interaction_matrix = get_interaction_matrix(params)

    def at_risk_contacts(sick_agent):
        exposed = get_exposed_agents(sick_agent, interaction_matrix)
        not_healthy = ~params['healthy_arr'][exposed].astype(bool)
        immune = params['immune_arr'][exposed].astype(bool)
        dead = params['dead_arr'][exposed].astype(bool)
        # Anyone falling into one of these groups cannot be infected
        protected = not_healthy | immune | dead
        return exposed[~protected]

    return {sick_agent: at_risk_contacts(sick_agent) for sick_agent in get_sick_agents(params)}

In [12]:
def risk_to_sick(params):
    '''Infect at-risk agents and record who got sick.

    For every group of at-risk agents in params['sick_to_risks'], each agent
    is infected independently with PROB_INFECTION. The de-duplicated list of
    newly infected agents is applied via set_sick and stored in
    params['got_sick_idxs']. Returns params.
    '''
    newly_infected = []
    for at_risk in params['sick_to_risks'].values():
        # at_risk agents become sick with PROB_INFECTION
        caught_it = np.random.binomial(1, params['PROB_INFECTION'], size=len(at_risk)).astype(bool)
        newly_infected.extend(at_risk[caught_it])

    # An agent exposed to several sick agents may be drawn more than once.
    # Statuses are only updated here, after all draws, so there are no
    # chaining effects within a single timestep.
    got_sick_idxs = list(set(newly_infected))
    params = set_sick(got_sick_idxs, params)
    params['got_sick_idxs'] = got_sick_idxs
    return params

In [13]:
def do_timestep(params):
    '''Advance the simulation by one timestep, plotting each phase.

    Phases, each followed by a pause of params['WAIT'] seconds:
      1. interact  - draw interactions and plot the sick -> at-risk links
      2. infect    - infect at-risk agents and plot who did / did not get sick
      3. new state - plot the board after infections
      4. fates     - recovery and death (each only if its probability is > 0),
                     then plot who recovered / died and the resulting board

    Returns the updated params.
    '''
    def pause():
        time.sleep(params['WAIT'])

    pause()
    params['timestep'] += 1

    ## interact, then plot board + lines
    params['sick_to_risks'] = get_possible_sick(params)
    plot_interactions(params)
    pause()

    ## make sick, then plot X's + O's
    params = risk_to_sick(params)
    plot_infections(params)
    pause()

    # plot new state
    plot(params)
    pause()

    ## recover / die only when the corresponding probability is non-zero
    can_recover = params['PROB_RECOVERY'] > 0
    if can_recover:
        params = do_recovery(params)

    can_die = params['PROB_DEATH'] > 0
    if can_die:
        params = do_death(params)

    # plot D's + R's, then the new state
    if can_recover or can_die:
        plot_fates(params)
        pause()
        plot(params)

    return params

In [14]:
def plot(params, show=True):
    '''Draw the current board state as a colored grid.

    Each cell is colored by the agent's status (immune, healthy, sick, dead),
    the title shows the current timestep and a legend explains the colors.

    Parameters
    ----------
    params : simulation state; reads the status arrays, grid size, colors,
        timestep and CLEAR_PLOTS.
    show : if True the figure is shown immediately; pass False to draw
        further markers on the returned axes before showing.

    Returns
    -------
    The matplotlib axes the board was drawn on.
    '''
    if params['CLEAR_PLOTS']: clear_output(wait=True)

    width, height = params['GRID_WIDTH'], params['GRID_HEIGHT']

    # Encode each agent's status as an integer code. Later assignments
    # overwrite earlier ones, so immunity takes precedence over the others.
    plot_matrix = np.zeros((height, width))
    for code, key in ((4, 'dead_arr'), (3, 'sick_arr'), (2, 'healthy_arr'), (1, 'immune_arr')):
        plot_matrix[reshape_to_2d(params[key], params) == 1] = code

    # Color i is used for status code i (code 0 = no status = white)
    status_colors = ['white', params['COLOR_IMMUNE'], params['COLOR_HEALTHY'], params['COLOR_SICK'], params['COLOR_DEAD']]
    cmap = colors.ListedColormap(status_colors)
    # Exactly one bin per color, centred on the integer codes 0..4.
    # BoundaryNorm requires the number of bins not to exceed the number of
    # colors; passing more boundaries than that raises in recent matplotlib.
    bounds = np.arange(len(status_colors) + 1) - 0.5
    norm = colors.BoundaryNorm(bounds, cmap.N)

    # Set up plot
    fig, ax = plt.subplots(figsize=(width//2+1, height//2+1))
    for tic in list(ax.xaxis.get_major_ticks()) + list(ax.yaxis.get_major_ticks()):
        tic.tick1line.set_visible(False)
        tic.tick2line.set_visible(False)
    ax.grid(which='major', axis='both', linestyle='-', color='k', linewidth=2)
    ax.set_xlim(left=-0.5, right=width-0.5)
    # The vertical extent is the grid HEIGHT (row 0 at the top); using the
    # width here would pad or crop any non-square board.
    ax.set_ylim(top=-0.5, bottom=height-0.5)
    # Ticks sit on the cell borders so the major grid outlines each cell
    ax.set_xticks(np.arange(-0.5, width, 1))
    ax.set_yticks(np.arange(-0.5, height, 1))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(params['timestep'], fontsize=24)

    legend_elements = [
        Patch(facecolor=params['COLOR_IMMUNE'], label='Immune'),
        Patch(facecolor=params['COLOR_HEALTHY'], label='Healthy'),
        Patch(facecolor=params['COLOR_SICK'], label='Sick'),
        Patch(facecolor=params['COLOR_DEAD'], label='Dead'),
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.01,1), loc="upper left", fontsize=24)

    # Show current board state
    ax.imshow(plot_matrix, cmap=cmap, norm=norm)
    if show: plt.show()
    return ax

In [15]:
def plot_interactions(params):
    '''Plot the board with a line from every sick agent to each of the
    at-risk agents recorded for it in params['sick_to_risks'].'''
    if params['CLEAR_PLOTS']: clear_output(wait=True)
    ax = plot(params, show=False)

    link_style = dict(color='pink', lw=3, marker='o', markerfacecolor='black', markeredgecolor='black')

    # Show connections
    for sick_agent, at_risk_agents in params['sick_to_risks'].items():
        sick_down, sick_across = num_to_idxs(sick_agent, params)
        for at_risk_agent in at_risk_agents:
            at_risk_down, at_risk_across = num_to_idxs(at_risk_agent, params)
            # x coordinates are columns (across), y coordinates are rows (down)
            ax.plot([at_risk_across, sick_across], [at_risk_down, sick_down], **link_style)
    plt.show()

In [16]:
def plot_infections(params):
    '''Plot the board and mark the outcome for every at-risk agent:
    an X on agents that got sick this timestep and a filled circle on
    at-risk agents that stayed healthy.'''
    if params['CLEAR_PLOTS']: clear_output(wait=True)
    ax = plot(params, show=False)

    # Everyone who was at risk from at least one sick agent
    all_candidates = set()
    for candidates in params['sick_to_risks'].values():
        all_candidates.update(candidates)

    # Split the candidates into those that got sick and those that did not
    got_sick = set(params['got_sick_idxs'])
    stayed_healthy = np.array(list(all_candidates - got_sick))
    got_sick = np.array(list(got_sick))

    marker_style = dict(markersize=20, ls='', color='k')
    for agents, marker in ((got_sick, 'X'), (stayed_healthy, 'o')):
        down, across = num_to_idxs(agents, params)
        ax.plot(across, down, marker=marker, **marker_style)
    plt.show()

In [17]:
def plot_fates(params):
    '''Plot the board and mark the agents whose fate was decided this
    timestep: a red plus on agents in params['got_better_idxs'] and a purple
    caret on agents in params['died_idxs']. Either key may be absent, in
    which case that marker set is skipped.'''
    if params['CLEAR_PLOTS']: clear_output(wait=True)
    ax = plot(params, show=False)

    overlays = (
        ('got_better_idxs', dict(marker='P', color='red')),
        ('died_idxs', dict(marker=11, color='purple')),
    )
    for key, look in overlays:
        if key not in params:
            continue
        down, across = num_to_idxs(np.array(params[key]), params)
        ax.plot(across, down, markersize=20, ls='', **look)

    plt.show()

In [18]:
## Widgets

# Settings shared by several sliders, spelled out once
_grid_range = dict(min=1, max=100, step=1)        # grid sides and connection distance
_unit_interval = dict(min=0, max=1., step=0.01)   # fractions and probabilities

# Grid geometry and run length
grid_width = widgets.IntSlider(value=30, description='Grid Width:', **_grid_range)
grid_height = widgets.IntSlider(value=20, description='Grid Height:', **_grid_range)
connection_distance = widgets.IntSlider(value=3, description='Distance:', **_grid_range)
num_timesteps = widgets.IntSlider(value=5, min=2, max=50, step=1, description='NTimesteps:')

# Starting fractions of the population
perc_immune_start = widgets.FloatSlider(value=0.9, description='Immune t=0:', **_unit_interval)
perc_sick_start = widgets.FloatSlider(value=0.03, description='Sick t=0:', **_unit_interval)

# Per-timestep probabilities
prob_social = widgets.FloatSlider(value=0.25, description='P(Interact)', **_unit_interval)
prob_infection = widgets.FloatSlider(value=0.8, description='P(Sick|Interact):', **_unit_interval)
prob_death = widgets.FloatSlider(value=0.1, description='P(Death)', **_unit_interval)
prob_recovery = widgets.FloatSlider(value=0.1, description='P(Recovery)', **_unit_interval)

# Behavior toggles
immune_after_recovery = widgets.ToggleButton(value=False, description='Immune after Recovery')
clear_plots = widgets.ToggleButton(value=True, description='Clear Plots')

# Pause between plotted phases, in seconds
wait = widgets.FloatSlider(value=0., min=0, max=5., step=0.25, description='Delay [s]:')

In [21]:
def run_sim(
    grid_width,
    grid_height,
    connection_distance,
    num_timesteps,
    perc_immune_start,
    perc_sick_start,
    prob_social,
    prob_infection,
    prob_death,
    prob_recovery,
    immune_after_recovery,
    clear_plots,
    wait,
):
    '''Run one full simulation with the given settings.

    Packs the arguments (plus the fixed status colors and the start time)
    into a setup dict, derives the initial simulation state from it, plots
    the starting board and then steps the simulation until NUM_TIMESTEPS
    timesteps have been completed.
    '''
    palette = {
        'COLOR_IMMUNE': 'blue',
        'COLOR_HEALTHY': 'green',
        'COLOR_SICK': 'red',
        'COLOR_DEAD': 'black',
    }
    settings = {
        'GRID_WIDTH': grid_width,
        'GRID_HEIGHT': grid_height,
        'CONNECTION_DISTANCE': connection_distance,
        'NUM_TIMESTEPS': num_timesteps,
        'PERC_IMMUNE_START': perc_immune_start,
        'PERC_SICK_START': perc_sick_start,
        'PROB_SOCIAL': prob_social,
        'PROB_INFECTION': prob_infection,
        'PROB_DEATH': prob_death,
        'PROB_RECOVERY': prob_recovery,
        'IMMUNE_AFTER_RECOVERY': immune_after_recovery,
        'CLEAR_PLOTS': clear_plots,
        'WAIT': wait,
        'TIME_START': time.time(),
    }
    setup = {**palette, **settings}

    params = derive_params(setup)

    # Show the starting board, then step until the requested number of
    # timesteps has elapsed (do_timestep increments params['timestep'])
    plot(params)
    while params['timestep'] < params['NUM_TIMESTEPS']:
        params = do_timestep(params)

######

# Each run_sim keyword argument is fed by the widget of the same name
_sim_controls = dict(
    grid_width=grid_width,
    grid_height=grid_height,
    connection_distance=connection_distance,
    num_timesteps=num_timesteps,
    perc_immune_start=perc_immune_start,
    perc_sick_start=perc_sick_start,
    prob_social=prob_social,
    prob_infection=prob_infection,
    prob_death=prob_death,
    prob_recovery=prob_recovery,
    immune_after_recovery=immune_after_recovery,
    clear_plots=clear_plots,
    wait=wait,
)

# 'manual' mode: the simulation only runs when the run button is pressed
widget = interactive(run_sim, {'manual': True}, **_sim_controls)

# The last two children are taken to be the run button and the output area
run_child, output_child = widget.children[-2:]
output_child.layout.height = '800px'

# Arrange the controls in three columns, with the toggles in a row underneath
gridVBox = VBox([grid_width, grid_height, connection_distance, num_timesteps])
setupVBox = VBox([perc_immune_start, perc_sick_start])
probVBox = VBox([prob_social, prob_infection, prob_death, prob_recovery])
controls = HBox([gridVBox, setupVBox, probVBox])
toggleVBox = HBox([immune_after_recovery, clear_plots, wait])
VBox([controls, toggleVBox, run_child, output_child])

VBox(children=(HBox(children=(VBox(children=(IntSlider(value=30, description='Grid Width:', min=1), IntSlider(…