In [1]:
%cd ./covid_households
import recipes
from constants import STATE
import torch
from settings import device

/Users/thayer/covid_households/covid_households


In [3]:
# the base rate of transmission
beta = 0.55

In [None]:
# adjmat: 1 if individuals are connected (in the same house but not identical) and 0 otherwise
adjmat = pop.make_connectivity_matrix(x._adjmat)

# make the initial state
state = pop.make_initial_state(recipes.InitialSeedingConfig.seed_one_by_susceptibility)
# but let's skip straight to the part where people are infected, not exposed
state[state == STATE.exposed] = STATE.infectious

In [None]:
# send the state and adjacencies to the GPU (if applicable)
state = torch.from_numpy(state).to(device)
adjmat = torch.from_numpy(adjmat).to(device)

In [None]:
sus = torch.from_numpy(pop.sus).to(device)
inf = torch.from_numpy(pop.inf).to(device)
print(state)
find_propensities(state, adjmat, beta, sus, inf)

In [None]:
import state_lengths
state_length_sampler = state_lengths.lognormal_state_length_sampler
state_lengths = torch.zeros_like(state, dtype=torch.double)
for s in STATE:
    if state_lengths[state==s].nelement() > 0:
        state_lengths[state==s] = state_length_sampler(s, state[state == s]) ## how long spent in each state; already on device

In [None]:
state_lengths

In [27]:
def find_propensities(state, beta, sus, inf, connectivity_matrix):
    """
    Determines the propensities towards each event (read: different person being infected) possible in each household.
    
    Parameters
    ----------
    state : ndarray / tensor
        Array of current state for each individual in the population
    connectivity_matrix : ndarray / tensor
        Matrix with A_ij = 1 if individuals i and j are connected
        (in the same house but not identical) and 0 otherwise
    beta : float
        Constant rate of infection
    sus : ndarray / tensor
        Array of the individuals' relative susceptibilities
    inf : ndarray / tensor
        Array of the individuals' relative infectivities
        
    Returns
    -------
    propensities : ndarray
        Propensity (~instantaneous probability) for each individual in each household
        to be infected given current state
    time : float
        Time it took for the reaction to occur.
    """
    inf_mask = (state == STATE.infectious)
    sus_mask = (state == STATE.susceptible)
    
    population_matrix = (sus @ inf) * connectivity_matrix
    p_mat = beta * population_matrix
    propensities = p_mat * sus_mask * inf_mask.permute(0, 2, 1)
    propensities = propensities.sum(axis=2)
    return propensities

def gillespie_simulation(numpy_initial_state, beta, state_length_sampler, numpy_sus, numpy_inf, numpy_connectivity_matrix):
    # move everything onto the torch device
    state = torch.from_numpy(numpy_initial_state).to(device)
    connectivity_matrix = torch.from_numpy(numpy_connectivity_matrix).to(device)
    sus = torch.from_numpy(numpy_sus).to(device)
    inf = torch.from_numpy(numpy_inf).to(device)
    t = torch.zeros((state.shape[0], 1), dtype=torch.float)
    
    # find the duration of the exposed state for everyone who starts exposed
    state_lengths = torch.zeros_like(state, dtype=torch.double)
    for s in STATE:
        if state_lengths[state==s].nelement() > 0:
            state_lengths[state==s] = state_length_sampler(s, state[state == s])

    # while anyone is infected or exposed, we continue simulating
    while (state == STATE.exposed).any() or (state == STATE.infectious).any():
        print(state)
        # perform an update step by finding the next event (via Gillespie simulation) in each household
        dstate, dtime = vector_gillespie_step(find_propensities, state, t, state_lengths, beta, sus, inf, connectivity_matrix)
        state = state + dstate
        #print(t, '\n', state_lengths, '\n', dtime)
        #print(state_lengths.shape, dtime.shape)
        t     = t     + dtime
        state_lengths = state_lengths - dtime.unsqueeze(1)
        #print(t, '\n', state_lengths, '\n', dtime)
        
        changed_states = (dstate != 0)
        # when persons move to a new state, we generate the time that they'll spend in that state
        for s in STATE:
            # find all the people who entered state s
            entrants = state[torch.logical_and(changed_states, state==s)]
            if entrants.nelement() > 0:
                # no one should be entering the susceptible state (they start there)
                assert s>STATE.susceptible
                # find the duration of the state for all the entrants
                entrant_lengths = state_length_sampler(s, entrants)
                state_lengths[torch.logical_and(changed_states, state==s)] = entrant_lengths

            state_lengths[state == STATE.susceptible] = np.inf
            state_lengths[state == STATE.removed] = np.inf
        #break

    return_state = state.cpu().numpy()
    return return_state != STATE.susceptible

In [48]:
torch.manual_seed(0)
np.random.seed(0)

beta = 0.055
x = recipes.PopulationStructure({4:10})
# the population will know how different households are connected
pop = x.make_population()

initial_state = pop.make_initial_state(recipes.InitialSeedingConfig.seed_one_by_susceptibility)
print(initial_state)
import state_lengths
state_length_sampler = state_lengths.lognormal_state_length_sampler

gillespie_simulation(initial_state, beta, state_length_sampler, pop.sus, pop.inf, x._adjmat)

[[[0]
  [0]
  [1]
  [0]]

 [[0]
  [0]
  [1]
  [0]]

 [[0]
  [0]
  [1]
  [0]]

 [[0]
  [0]
  [1]
  [0]]

 [[0]
  [1]
  [0]
  [0]]

 [[0]
  [0]
  [1]
  [0]]

 [[0]
  [1]
  [0]
  [0]]

 [[0]
  [0]
  [0]
  [1]]

 [[0]
  [0]
  [0]
  [1]]

 [[0]
  [1]
  [0]
  [0]]]
tensor([[[0],
         [0],
         [1],
         [0]],

        [[0],
         [0],
         [1],
         [0]],

        [[0],
         [0],
         [1],
         [0]],

        [[0],
         [0],
         [1],
         [0]],

        [[0],
         [1],
         [0],
         [0]],

        [[0],
         [0],
         [1],
         [0]],

        [[0],
         [1],
         [0],
         [0]],

        [[0],
         [0],
         [0],
         [1]],

        [[0],
         [0],
         [0],
         [1]],

        [[0],
         [1],
         [0],
         [0]]])
PROP: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]

array([[[ True],
        [False],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        [ True]],

       [[ True],
        [False],
        [ True],
        [ True]],

       [[ True],
        [False],
        [ True],
        [False]],

       [[ True],
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [False],
        [False]],

       [[ True],
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        [ True]]])

In [44]:
import numpy as np
def vector_gillespie_step(propensity_func, state, t, state_lengths, *propensity_args):
    """
    Draws the next event for each household in the population. Assesses whether the drawn event takes place or time advances so far that someone ages out of their state.
    Returns the change in state and the time advancement for each household.
    
    Parameters
    ----------
    propensity_func : function
        Function used for computing propensities. Signature: propensity_func(state, propensity_args)
        Returns an array of propensities.
    state : ndarray
        Array of current state for each individual in the population
    t : ndarray
        The current time in each household
    state_lengths: ndarray
        The time left for each individual in their current state (if transitory) or ~infinity if stationary
    propensity_args : additional arguments
        Arguments to be passed to the propensity-finding-function
        
    Returns
    -------
    dstate : ndarray
        An array such that state + dstate properly represents the new state after one event in each household.
    dtime : ndarray
        Time that passed in each household before its event.
    """
    # time until next aging event (someone leaving one compartment for the next due to time passing) in each household
    dtime_aging, dstate_aging_indices = state_lengths.min(axis=1)
    dtime_aging *= 0.1
    
    propensities = propensity_func(state, *propensity_args)

    # sum of propensity per household
    household_total_propensity = propensities.sum(axis=1)
    print("PROP:", propensities, household_total_propensity)


    valid_propensity_mask = (household_total_propensity != 0.0)
    # time until the drawn event for each household
    dist = torch.distributions.Exponential(household_total_propensity[valid_propensity_mask])
    # sample from the exponential distribution for time until events
    # (and then transposing futzing to get it pointed the right way)
    dtime_gillespie = torch.full_like(dtime_aging, np.inf)
    dtime_gillespie[valid_propensity_mask] = dist.sample().unsqueeze(0).transpose(0,1)

    # relative probability of each event (read: of each person being infected)
    relative_probabilities = propensities[valid_propensity_mask] / household_total_propensity[valid_propensity_mask].unsqueeze(1)

    # randomly choose an event to happen in proportion to the relative probability
    #print(propensities, household_total_propensity, relative_probabilities)
    dstate_gillespie_indices = torch.zeros_like(dtime_gillespie, dtype=torch.long)
    #print(valid_propensity_mask, relative_probabilities)
    #import pdb; pdb.set_trace()
    dstate_gillespie_indices[valid_propensity_mask] = relative_probabilities.multinomial(num_samples=1, replacement=True)

    # we choose the type of event (infection from gillespie or aging out of a state at fixed time)
    # in each household based on which happens first
    dstate_indices = torch.where((dtime_aging < dtime_gillespie), dstate_aging_indices, dstate_gillespie_indices)
    # the time we should advance is similarly determined by which type of event happens in each household
    dtime = torch.where(dtime_aging < dtime_gillespie, dtime_aging, dtime_gillespie)
    print(dtime, '\n', dtime_aging, '\n', dtime_gillespie)
    # create the vector such that state + dstate = new_state
    dstate = torch.zeros_like(state)
    # take advantage of the fact that compartments are sequential
    # we would need to more fastidiously track which events took place if backtracking (ex. SIRS) were possible
    dstate[torch.arange(dstate.shape[0]), dstate_indices.transpose(0,1)] = 1
    return dstate, dtime