In [None]:
# IMPORT BLOCK
import networkx as nx 
import matplotlib.pyplot as plt 
import numpy as np
import EoN
import time
import seaborn as sns

In [None]:
# BASIC UTILITIES BLOCK 
def selector(els, idxs):
    """Lazily yield ``els[idx]`` for every ``idx`` in ``idxs``, preserving order."""
    yield from (els[position] for position in idxs)
        
def tuple_filter(tup_iter, idx):
    """Collect the ``idx``-th entry of each item of ``tup_iter`` into a list."""
    picked = []
    for item in tup_iter:
        picked.append(item[idx])
    return picked

        
def c(i):
    """Return a matplotlib single-letter colour code for series number ``i``.

    The palette has seven entries; indices wrap around so that plotting more
    than seven series re-uses colours instead of raising IndexError.
    """
    palette = 'bgrcmyk'
    return palette[i % len(palette)]


def invert_dict(d):
    """Map each value of ``d`` to the list of keys that carry it.

    Keys appear in each list in the iteration order of ``d``.
    """
    keys_by_value = {}
    for key, value in d.items():
        keys_by_value.setdefault(value, []).append(key)
    return keys_by_value
        
def argmax(iterable):
    """Return ``(index, value)`` of the first largest element of ``iterable``.

    For an empty iterable the result is ``(None, -inf)``.
    """
    best_idx = None
    best_val = float('-inf')
    for position, candidate in enumerate(iterable):
        if not candidate > best_val:
            continue
        best_idx, best_val = position, candidate
    return best_idx, best_val

In [None]:
# GENERIC CODE BLOCK (WILL APPEAR IN ALL NOTEBOOKS)
class TupleSIR:
    """One SIR trajectory stored as parallel sequences ``t``, ``S``, ``I``, ``R``.

    ``tmax`` is the duration of the run; when it is not given it defaults to
    the last entry of ``t``.
    """
    def __init__(self, t, S, I, R, tmax=None):
        self.t = t
        self.S = S
        self.I = I
        self.R = R
        if tmax is not None:
            self.tmax = tmax
        else:
            self.tmax = t[-1]

    @classmethod
    def from_summary(cls, summary, tmax=None):
        """Build from an object exposing ``t()``, ``S()``, ``I()``, ``R()`` accessors."""
        return cls(summary.t(), summary.S(), summary.I(), summary.R(), tmax=tmax)

    @classmethod
    def empty(cls):
        """Return a zero-length trajectory with ``tmax == 0.0``."""
        return cls(*[np.empty(0) for _ in range(4)], tmax=0.0)

    def get_by_str(self, s):
        """Return the attribute named ``s`` (e.g. ``'I'`` or ``'R'``)."""
        return getattr(self, s)

    def get_max_I(self):
        """Return the largest value in ``I``."""
        return max(self.I)

    def get_final_R(self):
        """Return the last value in ``R``."""
        return self.R[-1]

    def get_peak_I_time(self):
        """Return ``(time_of_max_I, max_I)``.

        The first maximum wins on ties.  For an empty trajectory the result is
        ``(None, -inf)``.
        """
        if len(self.I) == 0:
            return None, -float('inf')
        # Translate the peak *index* into the matching entry of ``t``; the
        # previous implementation returned the raw index instead of the time.
        peak_idx = int(np.argmax(self.I))
        return self.t[peak_idx], self.I[peak_idx]
    
        
    

def run_until_time(G, tau, gamma, rho, tmax):
    """Run a basic SIR simulation on ``G`` for ``tmax`` time units, then quarantine.

    RETURNS:
        G', TupleSIR
        G': copy of ``G`` with the nodes that are I or R at time ``tmax``
            removed (when ``tmax == 0`` the very same ``G`` object is returned)
        TupleSIR: trajectory of the run, with ``tmax`` recorded as its duration
    """
    # handle edge case where tmax == 0: nothing to simulate
    if tmax == 0:
        return G, TupleSIR.empty()
    summary = EoN.fast_SIR(G, tau, gamma, rho=rho, tmax=tmax, return_full_data=True)

    nodes_by_status = invert_dict(summary.get_statuses(time=tmax))

    quarantined = G.copy()
    quarantined.remove_nodes_from(nodes_by_status.get('I', []))
    quarantined.remove_nodes_from(nodes_by_status.get('R', []))
    # Record the requested duration explicitly.  Without it TupleSIR falls back
    # to the last recorded time, and cat_tuples would then offset the next
    # segment by that time instead of by the quarantine time.
    return quarantined, TupleSIR.from_summary(summary, tmax=tmax)


    
    
def cat_tuples(tup_list):
    """Concatenate a list of TupleSIR segments into one TupleSIR.

    Anything that is not a ``list`` is returned unchanged.  Each segment's
    times are shifted by the accumulated duration of the earlier segments and
    its R counts by the accumulated number of nodes that were I or R at the end
    of earlier segments.  Segments with ``tmax > 0`` also contribute one extra
    sample at their end time that repeats their last S/I/R values.
    """
    if not isinstance(tup_list, list):
        return tup_list  # should just be a single TupleSIR object

    pieces = {'t': [], 'S': [], 'I': [], 'R': []}
    elapsed = 0
    removed = 0
    for segment in tup_list:
        # in-run samples
        pieces['t'].append(elapsed + segment.t)
        pieces['S'].append(segment.S)
        pieces['I'].append(segment.I)
        pieces['R'].append(removed + segment.R)

        if not segment.tmax > 0.0:
            continue

        # closing sample at the end of the segment
        pieces['t'].append(np.array([segment.tmax + elapsed]))
        pieces['S'].append(np.array([segment.S[-1]]))
        pieces['I'].append(np.array([segment.I[-1]]))
        pieces['R'].append(np.array([segment.R[-1] + removed]))

        # carry the offsets forward to the next segment
        elapsed += segment.tmax
        removed += segment.I[-1] + segment.R[-1]

    return TupleSIR(np.concatenate(pieces['t']),
                    np.concatenate(pieces['S']),
                    np.concatenate(pieces['I']),
                    np.concatenate(pieces['R']),
                    tmax=elapsed)

def plot_sir_counts(tup_list, plot_series='IR', ax=None):
    """Plot the requested series of every TupleSIR in ``tup_list`` against time.

    tup_list: iterable of objects with ``t`` and ``get_by_str``; an optional
        ``label`` attribute is used for the legend.
    plot_series: string whose characters are keys of the label table below
        ('I' and/or 'R').
    ax: optional Axes to draw on.  When None (the default) a new 8x8 figure is
        created for each series; when given, every requested series is drawn
        on that Axes.
    """
    str_dict = {'I': '# Infected',
                'R': '# Recovered'}
    for serie in plot_series:
        if ax is None:
            _, series_ax = plt.subplots(figsize=(8, 8))
        else:
            # previously the caller's Axes was discarded and replaced
            series_ax = ax
        for i, tup in enumerate(tup_list):
            series_ax.plot(tup.t, tup.get_by_str(serie), c=c(i), alpha=0.5,
                           label=getattr(tup, 'label', None))
        series_ax.set_xlabel('Time')
        series_ax.set_ylabel(str_dict[serie])
        series_ax.set_title(str_dict[serie] + ' by time')
        # only build a legend when at least one curve actually carries a label
        if series_ax.get_legend_handles_labels()[1]:
            series_ax.legend()
        
        
def quarantines_by_time(G, tau, gamma, rho, qtimes, tmax):
    """Run a model until ``tmax`` with a quarantine at each time in ``qtimes``.

    qtimes: a single time, or a list/tuple of cumulative (non-decreasing)
        quarantine times, none later than ``tmax``.  An empty sequence means
        no quarantine at all.
    Returns the concatenated TupleSIR of all segments.
    Raises ValueError when the schedule would produce a negative segment length.
    """
    if isinstance(qtimes, (list, tuple)):
        schedule = list(qtimes)
    else:
        schedule = [qtimes]

    # segment lengths between consecutive boundaries 0, q1, q2, ..., tmax
    boundaries = [0] + schedule + [tmax]
    deltas = [later - earlier for earlier, later in zip(boundaries, boundaries[1:])]
    if any(delta < 0 for delta in deltas):
        raise ValueError(
            "qtimes must be non-negative, non-decreasing and at most tmax; got %r with tmax=%r"
            % (schedule, tmax))
    print("Deltas", deltas)

    tups = []
    for delta in deltas:
        # each segment runs on the graph left over from the previous quarantine
        G, tup = run_until_time(G, tau, gamma, rho=rho, tmax=delta)
        tups.append(tup)
    return cat_tuples(tups)


def quarantine_by_pop(G, tau, gamma, rho, prop, tmax):
    """Run with a SINGLE quarantine triggered by population share.

    The quarantine is applied once the I+R proportion reaches ``prop``; the
    simulation then continues on the quarantined graph for the time remaining
    up to ``tmax``.  Returns the concatenated TupleSIR of both phases.
    """
    G1, tup1 = run_until_prop_IR(G, tau, gamma, rho, tmax, prop)
    remaining_time = tmax - tup1.tmax
    # Continue on the quarantined graph G1.  The previous code passed the
    # original G here, so the quarantine had no effect on the second phase.
    G2, tup2 = run_until_time(G1, tau, gamma, rho, remaining_time)

    return cat_tuples([tup1, tup2])



# Position of each series inside the (t, S, I, R) tuple returned by EoN.fast_SIR
SERIES_IDX = {'S': 1, 'I': 2, 'R': 3}
def plot_vanilla_run(G, tau, gamma, rho, tmax, series='IR', num_runs=5):
    """Plot ``num_runs`` independent un-quarantined SIR runs, one figure per series.

    series: string of keys of SERIES_IDX; each character gets its own 8x8 figure.
    num_runs: number of simulations overlaid on each figure (default 5).
    Returns the single Axes when one series is requested, else the list of Axes.
    """
    series_names = {'S': '# Susceptible', 'I': '# Infected', 'R': '# Recovered'}
    axlist = []
    for serie in series:
        fig, ax = plt.subplots(figsize=(8, 8))
        column = SERIES_IDX[serie]
        for _ in range(num_runs):
            runtup = EoN.fast_SIR(G, tau, gamma, rho=rho, tmax=tmax)
            ax.plot(runtup[0], runtup[column], c='k', alpha=0.3)
        ax.set_xlabel('Time')
        ax.set_ylabel(series_names[serie])
        axlist.append(ax)
    if len(axlist) > 1:
        return axlist
    return axlist[0]


# THINGS TO PLOT
# X-axis: time at which the quarantine was performed
# Y-axis 1: maximum number of infected at any one time
# Y-axis 2: number of recovered at t=infinity (i.e. the number of people who were ever infected)


def plot_single_qs(qrange, G, tau, gamma, rho, maxt):
    """Simulate one quarantine schedule per entry of ``qrange`` and plot summaries.

    qrange: iterable of quarantine schedules, each passed as ``qtimes`` to
        quarantines_by_time; the entries are also used as x-values of the plot.
    Plots max I and final R against the schedule on the current pyplot figure
    and returns the list of simulation results (one TupleSIR per schedule).
    """
    qrange = list(qrange)

    # step 1, run each simulation:
    sum_ranges = []
    for qseries in qrange:
        print("Running sim on series:", qseries)
        # quarantines_by_time is the runner defined in this notebook; the name
        # used here before (quarantine_cycle) is not defined anywhere.
        sum_ranges.append(quarantines_by_time(G, tau, gamma, rho, qseries, maxt))

    # step 2, triples we care about (quarantine_time, maxI, finalR)
    triples = [(qseries, run.get_max_I(), run.get_final_R())
               for qseries, run in zip(qrange, sum_ranges)]

    # step 3, plot both series
    plt.plot(tuple_filter(triples, 0), tuple_filter(triples, 1), c='b', alpha=0.5, label='maxI')
    plt.plot(tuple_filter(triples, 0), tuple_filter(triples, 2), c='r', alpha=0.5, label='finalR')
    plt.legend()

    return sum_ranges

    

In [None]:
# CUSTOM CODE BLOCK (WILL APPEAR ONLY IN THIS NOTEBOOK)
# NOTE: SERIES_IDX and plot_vanilla_run are also defined in the generic code
# cell; this later definition is the one in effect after the cell runs.
# Position of each series inside the (t, S, I, R) tuple returned by EoN.fast_SIR
SERIES_IDX = {'S': 1, 'I': 2, 'R': 3}
def plot_vanilla_run(G, tau, gamma, rho, tmax, series='IR', num_runs=5):
    """Plot ``num_runs`` independent un-quarantined SIR runs, one figure per series.

    series: string of keys of SERIES_IDX; each character gets its own 8x8 figure.
    num_runs: number of simulations overlaid on each figure (default 5).
    Returns the single Axes when one series is requested, else the list of Axes.
    """
    series_names = {'S': '# Susceptible', 'I': '# Infected', 'R': '# Recovered'}
    axlist = []
    for serie in series:
        fig, ax = plt.subplots(figsize=(8, 8))
        column = SERIES_IDX[serie]
        for _ in range(num_runs):
            runtup = EoN.fast_SIR(G, tau, gamma, rho=rho, tmax=tmax)
            ax.plot(runtup[0], runtup[column], c='k', alpha=0.3)
        ax.set_xlabel('Time')
        ax.set_ylabel(series_names[serie])
        axlist.append(ax)
    if len(axlist) > 1:
        return axlist
    return axlist[0]

def data_getter(G, tau, gamma, rho, tmax, qtimes):
    """Run one quarantined simulation per schedule in ``qtimes``.

    Returns a list with the quarantines_by_time result for every schedule, in
    the same order as ``qtimes``.
    """
    # quarantine_cycle, the name called here before, is not defined anywhere
    # in this notebook; quarantines_by_time takes the same arguments.
    return [quarantines_by_time(G, tau, gamma, rho, schedule, tmax) for schedule in qtimes]

    
    
def plot_maxI_finalR(data, qtimes):
    """Plot max I and final R of each run against its quarantine time.

    data: indexable collection of runs exposing ``get_max_I()`` and
        ``get_final_R()`` (e.g. TupleSIR), aligned with ``qtimes``.
    qtimes: x-values, one per run.
    Returns the Axes.
    """
    fig, ax = plt.subplots(figsize=(8,8))
    # use the TupleSIR accessors; free functions get_max_I / get_final_R are
    # not defined in this notebook
    triples = [(qtime, data[i].get_max_I(), data[i].get_final_R())
               for i, qtime in enumerate(qtimes)]
    ax.plot(tuple_filter(triples, 0), tuple_filter(triples, 1), c='b', alpha=0.5, label='maxI')
    ax.plot(tuple_filter(triples, 0), tuple_filter(triples, 2), c='r', alpha=0.5, label='finalR')
    ax.set_xlabel("Quaratine Time")
    ax.set_ylabel("# Individuals")
    ax.legend()
    return ax
    
def plot_SIR_runs(data, qtimes, idx_selectors, series='I'):
    """Overlay one series of selected runs on a single 12x12 figure.

    data / qtimes: parallel indexable collections of runs and quarantine times.
    idx_selectors: indices of the runs to draw.
    Each run is indexed positionally: ``run[0]`` supplies the x-values and
    ``run[SERIES_IDX[series]]`` the y-values.
    Returns the Axes.
    """
    fig, ax = plt.subplots(figsize=(12,12))
    chosen_qtimes = [qtimes[k] for k in idx_selectors]
    chosen_runs = [data[k] for k in idx_selectors]
    curves = []
    for run in chosen_runs:
        curves.append((run[0], run[SERIES_IDX[series]]))
    for i, (qtime, (xs, ys)) in enumerate(zip(chosen_qtimes, curves)):
        ax.plot(xs, ys, c=c(i), alpha=0.5, label="Q@ time: %s" % qtime)
    ax.legend()
    return ax


def compare_to_vanilla(vanilla_run, comp_run, series='I'):
    """Draw one series of two runs on a shared 12x12 figure and return the Axes.

    Both runs are indexed positionally: element 0 supplies the x-values and
    element ``SERIES_IDX[series]`` the y-values.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    idx = SERIES_IDX[series]
    styled_runs = (
        (vanilla_run, 'k', 'vanilla'),
        (comp_run, 'b', 'comparison'),
    )
    for run, colour, name in styled_runs:
        ax.plot(run[0], run[idx], c=colour, alpha=0.5, label=name)
    ax.legend()
    return ax

In [None]:
def run_until_prop_IR(G, tau, gamma, rho, tmax, prop, total_nodes=None):
    """
    Runs SIR model until prop (in [0,1]) fraction of nodes are in I+R states
    (or until tmax when that fraction is never reached).
    If total_nodes is not None, then the proportion is WRT total_nodes (and not len(G))

    RETURNS: G2, TupleSIR
    G2 is a copy of the graph G with the nodes that are I or R at the cut-off
        time removed (the same G object when tmax or prop is 0)
    TupleSIR holds the trajectory up to and including the cut-off sample, with
        tmax set to the cut-off time
    """
    if min([tmax, prop]) == 0:
        return G, TupleSIR.empty()

    total_nodes = total_nodes or len(G)
    threshold = total_nodes * prop
    # This has to be slower because we need to run the infection
    # and then figure out which time to cut things off (ex-post facto)
    summary = EoN.fast_SIR(G, tau, gamma, rho=rho, tmax=tmax, return_full_data=True)

    times, S, I, R = summary.t(), summary.S(), summary.I(), summary.R()
    # index of the first sample at which I + R reaches the threshold, if any
    break_idx = next((i for i in range(len(I)) if I[i] + R[i] >= threshold), None)

    if break_idx is None:
        # Threshold never reached: keep the whole run.  The previous code
        # evaluated ``None + 1`` here and raised TypeError.
        breaktime = tmax
        stop = len(times)
    else:
        breaktime = times[break_idx]
        stop = break_idx + 1

    summary_dict = invert_dict(summary.get_statuses(time=breaktime))
    I_nodes = summary_dict.get('I', [])
    R_nodes = summary_dict.get('R', [])
    G = G.copy()
    G.remove_nodes_from(I_nodes)
    G.remove_nodes_from(R_nodes)
    return G, TupleSIR(times[:stop], S[:stop], I[:stop], R[:stop], tmax=breaktime)

    

In [None]:
# Contact network: Barabasi-Albert graph on N nodes, 8 edges per new node
N = 10000
G = nx.barabasi_albert_graph(N, 8)

# Epidemic parameters
tau, gamma, rho = (
    0.15,   # transmission rate
    1.0,    # recovery rate
    0.005,  # random fraction initially infected
)

# Simulation horizon and number of simulations to run
tmax, iterations = 20, 5


In [None]:
# Baseline: un-quarantined runs over the full horizon (tmax is 20)
plot_vanilla_run(G, tau, gamma, rho, tmax=tmax)

In [None]:
# Single quarantine once 10% of the nodes are I or R, horizon tmax (20)
tup_out = quarantine_by_pop(G, tau, gamma, rho, prop=0.1, tmax=tmax)

In [None]:
# First phase only: run until 10% of the nodes are I or R (overwrites tup_out)
G2, tup_out = run_until_prop_IR(G, tau, gamma, rho, tmax=float(tmax), prop=0.1)

In [None]:
# Inspect the raw R, I and t series of the truncated run
tup_out.R, tup_out.I, tup_out.t

In [None]:
# One figure each for the I and R counts of the truncated run
plot_sir_counts([tup_out], plot_series='IR', ax=None)

In [None]:
# HEATMAPS DONE RIGHT:
# Step 1: Define quarantine grid and run multiple runs (don't need to return full data...)
def get_quarantine_grid_data(G, tau, gamma, rho, tmax, first_qrange, second_qrange, num_iter=3):
    """Run repeated two-quarantine simulations over a grid of schedules.

    G, tau, gamma, rho, tmax are the standard simulation parameters.
    first_qrange is a list of times for the first quarantine.
    second_qrange is a list of delays for the second quarantine (e.g. 1 means
    the second quarantine happens 1 time unit after the first).
    The cartesian product of the two ranges is simulated, ``num_iter`` times each.

    OUTPUT DATA STRUCTURE IS A DICT WITH KEYS BEING (q0, q1) TUPLES and values
    being lists of ``num_iter`` simulation results.
    """
    output_data = {}
    for q0 in first_qrange:
        for q1 in second_qrange:
            # quarantines_by_time expects cumulative times, hence q0 + q1.
            # quarantine_cycle, the name called here before, is not defined
            # anywhere in this notebook.
            output_data[(q0, q1)] = [
                quarantines_by_time(G, tau, gamma, rho, [q0, q0 + q1], tmax)
                for _ in range(num_iter)
            ]
    return output_data
                             
# 5x5 grid: first quarantine at t = 0..4, second quarantine 0..4 time units later
sample_grid = get_quarantine_grid_data(G, tau, gamma, rho, tmax, list(range(5)), list(range(5)))

In [None]:
def process_into_grid(data_dict, func=None):
    """Arrange a dict keyed by tuples into rows grouped by the first key element.

    data_dict: mapping whose keys are sortable tuples (e.g. ``(q0, q1)``).
    func: optional transform applied to every value (identity when None).

    Returns ``(grid, grid_check)``: ``grid`` is a list of rows of transformed
    values, ``grid_check`` the same-shaped list of the keys they came from.
    Keys are visited in sorted order and a new row starts whenever the first
    key element changes.  An empty dict yields ``([], [])``.
    """
    if func is None:
        func = lambda value: value

    grid = []
    grid_check = []
    row_key = None
    for key in sorted(data_dict.keys()):
        # ``not grid`` opens the first row; afterwards a change in key[0] does
        if not grid or key[0] != row_key:
            grid.append([])
            grid_check.append([])
            row_key = key[0]
        grid[-1].append(func(data_dict[key]))
        grid_check[-1].append(key)
    return grid, grid_check
        

def avg_max_I(sum_list):
    """Return the mean of ``get_max_I()`` over the runs in ``sum_list``."""
    # use the accessor on each run; a free function get_max_I is not defined
    # in this notebook
    return sum(run.get_max_I() for run in sum_list) / float(len(sum_list))

def avg_final_R(sum_list):
    """Return the mean of ``get_final_R()`` over the runs in ``sum_list``."""
    # use the accessor on each run; a free function get_final_R is not defined
    # in this notebook
    return sum(run.get_final_R() for run in sum_list) / float(len(sum_list))


def heatmapify(grid, grid_idxs, title=None):
    """Build and format a heatmap of ``grid`` and return its Axes.

    grid: 2-D list of numbers (rows -> y axis).
    grid_idxs: same-shaped list of (first, second) keys, as returned by
        process_into_grid; used for the tick labels.
    title: optional figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    yticks = [row_keys[0][0] for row_keys in grid_idxs]
    xticks = [key[1] for key in grid_idxs[0]]
    # Hand the labels to seaborn so there is exactly one label per cell;
    # relabelling afterwards can mismatch the number of ticks seaborn drew.
    sns.heatmap(grid, ax=ax, xticklabels=xticks, yticklabels=yticks)
    ax.set_ylabel("Time of first quarantine")
    ax.set_xlabel("Time of second quarantine (after first)")
    if title is not None:
        ax.set_title(title)

    return ax

In [None]:
# The title announces the peak infection count, so aggregate with avg_max_I
# (this cell previously aggregated with avg_final_R under the same title).
heatmapify(*process_into_grid(sample_grid, avg_max_I), 'Max # Infected')

In [None]:
# Row labels (first-quarantine times) of the grid.  The key grid is taken from
# process_into_grid because no module-level ``grid_idxs`` exists.
[row_keys[0][0] for row_keys in process_into_grid(sample_grid)[1]]

In [None]:
# Wall-clock cost of one run with, then without, full per-node data
for full_data in (True, False):
    start = time.time()
    EoN.fast_SIR(G, tau, gamma, rho=rho, tmax=tmax, return_full_data=full_data)
    print(time.time() - start)

In [None]:
# Baseline un-quarantined runs for reference before the experiments
plot_vanilla_run(G, tau, gamma, rho, tmax)

In [None]:
# EXPERIMENT 1: Infection spreads more slowly after first quarantine:
# Show that time to peak increases after first quarantine
def time_to_second_peak(G, tau, gamma, rho, tmax, qtime):
    """Return the time, measured from the quarantine, at which I peaks afterwards.

    Runs until ``qtime``, quarantines, then runs the remaining ``tmax - qtime``
    on the quarantined graph.  Returns 0.0 when the second phase has no samples.
    """
    g1, summ1 = run_until_time(G, tau, gamma, rho=rho, tmax=qtime)
    g2, summ2 = run_until_time(g1, tau, gamma, rho=rho, tmax=tmax - qtime)
    # run_until_time returns a TupleSIR, whose ``t`` and ``I`` are attributes
    # rather than callables; argmax gives (index, value) of the peak.
    peak_idx, _ = argmax(summ2.I)
    if peak_idx is None:
        return 0.0
    return summ2.t[peak_idx]

def avg_by_tt2p(G, tau, gamma, rho, tmax, qtime, num_iter=5):
    """Average time_to_second_peak over ``num_iter`` runs, ignoring non-positive results.

    Returns 0.0 when no run produced a positive time.
    """
    samples = []
    for _ in range(num_iter):
        samples.append(time_to_second_peak(G, tau, gamma, rho, tmax, qtime))
    positive = [sample for sample in samples if sample > 0]
    return sum(positive) / len(positive) if positive else 0.0
    

# Quarantine times 0.0, 0.1, ..., 2.4 and the average time to the second peak for each
qtimes = [tenth / 10.0 for tenth in range(25)]
second_peaks = []
for qtime in qtimes:
    second_peaks.append((qtime, avg_by_tt2p(G, tau, gamma, rho, 20, qtime)))

In [None]:
# Average time to second peak (y) against quarantine time (x)
peak_qtimes = [pair[0] for pair in second_peaks]
peak_delays = [pair[1] for pair in second_peaks]
plt.plot(peak_qtimes, peak_delays)

In [None]:
# NOTE(review): `qpairs` is first assigned in the Experiment 2 cell below, so on
# a fresh kernel run top-to-bottom this cell raises NameError.
qpairs

In [None]:
# EXPERIMENT 2: Max I/Final R by first_qtime, second_qtime-first_qtime: 
# 10x10 schedules: first quarantine at i/4, second quarantine j/4 later (cumulative times)
qpairs = []
for i in range(10):
    for j in range(10):
        qpairs.append([i / 4.0, i / 4.0 + j / 4.0])
grid_run = data_getter(G, tau, gamma, rho, 20, qpairs)



In [None]:
# Stack into a 2d array:
empty_grid = [[None for _ in range(10)] for _ in range(10)]
for flat_idx, el in enumerate(grid_run):
    # qpairs is ordered first-quarantine-major (flat_idx = 10 * i + j), so the
    # row is flat_idx // 10 (first quarantine) and the column is flat_idx % 10
    # (delay of the second).  The heatmaps label rows as the first quarantine;
    # the previous indexing filled the grid transposed.
    empty_grid[flat_idx // 10][flat_idx % 10] = el
# use the accessors on each run; free functions get_max_I / get_final_R are
# not defined in this notebook
max_i_grid = [[el.get_max_I() for el in row] for row in empty_grid]
last_r_grid = [[el.get_final_R() for el in row] for row in empty_grid]

In [None]:
# seaborn is already imported as `sns` in the import cell at the top.
quarter_steps = [_ / 4.0 for _ in range(10)]
# Pass the labels to seaborn so there is exactly one label per cell.
ax = sns.heatmap(np.array(max_i_grid), xticklabels=quarter_steps, yticklabels=quarter_steps)
ax.set_title("Max # Infected")
ax.set_ylabel("Time of first quarantine")
ax.set_xlabel("Time after first quarantine before second quarantine")

In [None]:
# Locate the smallest final-R entry of the grid (first occurrence in row-major order)
min_val, min_idx = float('inf'), None
indexed_cells = (
    (value, (row_no, col_no))
    for row_no, grid_row in enumerate(last_r_grid)
    for col_no, value in enumerate(grid_row)
)
for value, position in indexed_cells:
    if not value < min_val:
        continue
    min_val, min_idx = value, position
print(min_val, min_idx)
last_r_grid[7][9]

In [None]:
# Schedule stored at flat index 7 * 10 + 1 of qpairs
qpairs[7 * 10 + 1]

In [None]:
# The original cell held the incomplete expression `last_r_grid[]`, which is a
# SyntaxError; show the first row of the grid instead.
last_r_grid[0]

In [None]:
# Final-R value in row 0, column 0 of the grid
last_r_grid[0][0]

In [None]:
fig, ax = plt.subplots(figsize=(10,10))
quarter_steps = [_ / 4.0 for _ in range(10)]
# Pass the labels to seaborn so there is exactly one label per cell.
ax = sns.heatmap(np.array(last_r_grid), ax=ax, xticklabels=quarter_steps, yticklabels=quarter_steps)
ax.set_title("Final # Recovered")
ax.set_ylabel("Time of first quarantine")
ax.set_xlabel("Time after first quarantine before second quarantine")

In [None]:
# Creates an empty 10x10 figure; nothing is drawn on it in this cell
fig, ax = plt.subplots(figsize=(10, 10))


In [None]:
# EXPERIMENT 3: Power of the second quarantine: 
# Hypothesis: can use 2 quarantines to drop peak #I by a factor of 4:

def _first_time_above(times, counts, level, default):
    # Time of the first sample whose count exceeds ``level``, else ``default``.
    for when, count in zip(times, counts):
        if count > level:
            return when
    return default


def double_peak_cut(G, tau, gamma, rho, tmax):
    """Apply two quarantines, each timed at a quarter of the un-quarantined peak.

    1. A probe run on ``G`` finds the peak I and the first time I exceeds a
       quarter of it; a fresh run is quarantined at that time.
    2. A probe run on the quarantined graph finds the first time I again
       exceeds the same level; a fresh run is quarantined at that time.
    3. A final run of length ``tmax`` follows on the twice-quarantined graph.

    Returns the three TupleSIR segments ``(summ1, summ2, summ3)``.
    """
    # First do a basic run to figure out max peak and the time to apply first quarantine
    t_, S_, I_, R_ = EoN.fast_SIR(G, tau, gamma, rho=rho, tmax=tmax)
    peak_I_quarter = max(I_) / 4.0
    quarter_I_time = _first_time_above(t_, I_, peak_I_quarter, 0.0)

    # Now do a run until quarter_time.  run_until_time always works on a copy
    # of the graph and accepts no ``copy_graph`` argument.
    G1, summ1 = run_until_time(G, tau, gamma, rho, quarter_I_time)

    # And do another run until peak hits quarter peak:
    t_, S_, I_, R_ = EoN.fast_SIR(G1, tau, gamma, rho=rho, tmax=tmax)
    # NOTE(review): if the level is never exceeded here, the first quarantine
    # time is reused (as the original loop did) — confirm this is intended.
    quarter_I_time = _first_time_above(t_, I_, peak_I_quarter, quarter_I_time)
    G2, summ2 = run_until_time(G1, tau, gamma, rho, quarter_I_time)
    G3, summ3 = run_until_time(G2, tau, gamma, rho, tmax)
    return summ1, summ2, summ3

# The three run segments returned by double_peak_cut, as a list
summs = list(double_peak_cut(G, tau, gamma, rho, tmax))

In [None]:
# Infected counts of the three segments, overlaid on one figure
plot_sir_counts(summs, 'I')