# Project: Stochastic and Spatial Models

## 0. Init

Most of the code follows the algorithms and practices described in "Modeling infectious diseases in humans and animals" (MID) by M. Keeling and P. Rohani.

In [None]:
import subprocess
import sys

# Install all dependencies into the environment of the interpreter that is
# running this notebook. `sys.executable -m pip` avoids picking up a different
# `pip` that merely happens to be first on PATH.
subprocess.run([
    sys.executable, "-m", "pip", "install", "--no-input",
    "python-slugify", "numpy<1.27", "matplotlib", "scipy", "numba",
    "networkx", "ndlib",
])

In [None]:
# WARNING: Comment if code doesn't run
# Use jit to compile and optimize Python code
from numba import jit

# Arrays and analysis
import numpy as np
import scipy as sp
from scipy.integrate import solve_ivp
from scipy.optimize import curve_fit
from scipy.ndimage import convolve1d

# Plotting and config
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [6, 3]
plt.rcParams['lines.linewidth'] = 1
plt.rcParams['figure.constrained_layout.use'] = True

# Network
import networkx as nx
import ndlib.models.ModelConfig as mc
import ndlib.models.epidemics as ep
from ndlib.viz.mpl.DiffusionTrend import DiffusionTrend

# Misc imports
from slugify import slugify
from functools import partial
import os

# Important directories (relative to the current working directory).
# FIG_DIR receives every figure written by save_fig; all three directories
# are created with create_dirs at the end of this cell.
FIG_DIR = 'fig/'
DATA_DIR = 'data/'
DUMP_DIR = 'dump/'

def save_fig(title):
    """Save the current figure as ``FIG_DIR/<slug>.png``.

    Args:
        title: Human-readable figure title. It is normalised with
            ``slugify`` so it is safe to use as a file name.

    ``bbox_inches='tight'`` trims the whitespace around the figure.
    """
    # os.path.join avoids the doubled separator ('fig//name.png') that plain
    # string formatting produced, because FIG_DIR already ends with '/'.
    path = os.path.join(FIG_DIR, f'{slugify(title)}.png')
    plt.savefig(path, bbox_inches='tight')

def props_fig(props_dict):
    """Apply every key/value pair of *props_dict* to the current Axes.

    Convenience wrapper around ``Axes.update`` for ``plt.gca()``.
    """
    current_ax = plt.gca()
    current_ax.update(props_dict)

def set_ax_props(ax, props_dict):
    """Apply every key/value pair of *props_dict* to the Axes *ax*.

    Simply forwards to ``ax.update``; returns nothing.
    """
    apply_props = ax.update
    apply_props(props_dict)

def create_dirs(path):
    """Ensure the directory *path* exists, creating parent directories too.

    Calling it for a directory that already exists does nothing.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)

# Make sure all output directories exist before anything is written to them.
create_dirs(FIG_DIR)
create_dirs(DATA_DIR)
create_dirs(DUMP_DIR)

In [None]:
%matplotlib inline

## 1. Gillespie’s Direct Algorithm and Stochastic Hallmarks

### 1.1 Implement Gillespie's algorithm

In [None]:
# WARNING: Comment if code doesn't run
# Use jit to compile and optimize Python code
@jit(nopython=True)
def SIR_GDA(y0, params, t_max):
    """Gillespie's Direct Algorithm (GDA) for the SIR model with demography.

    Written after Box 6.3 in MID (p 201). Six events can occur: birth,
    transmission, recovery and a death in each of the three compartments.
    Births and deaths happen at per-capita rate ``mu``. Transmission is
    frequency dependent (``beta*X*Y/N``), as in ``SIR_D``.

    Args:
        y0: Initial numbers ``(X, Y, Z)`` of susceptible, infected and
            recovered individuals.
        params: Rates ``(beta, gamma, mu)``.
        t_max: The simulation stops after the first event past this time.

    Returns:
        ``(ts, ys)``: event times, starting with the initial time 0, and an
        array with one ``[X, Y, Z]`` row per entry of ``ts``. The run stops
        early when the infection dies out (``Y == 0``). In that case
        ``ts[-1]`` is the extinction time.
    """
    X, Y, Z = y0
    beta, gamma, mu = params

    # 1. State change caused by each event (rows align with the rates below).
    ps = np.array([
        (1, 0, 0),    # birth
        (-1, 1, 0),   # transmission
        (0, -1, 1),   # recovery
        (-1, 0, 0),   # death_X
        (0, -1, 0),   # death_Y
        (0, 0, -1),   # death_Z
    ])

    # Bookkeeping variables. The initial state is recorded at t = 0, so
    # trajectories (and interpolations of them) start at y0 rather than at
    # the state after the first event.
    t = 0.0
    ts = [t]
    ys = [[X, Y, Z]]
    while t < t_max:
        # 2. Current population size.
        N = X + Y + Z

        # Rates in the same order as their corresponding events.
        Rs = np.array([mu*N, beta*X*Y/N, gamma*Y, mu*X, mu*Y, mu*Z])

        # 3./6. Cumulative rates; the last entry is the total rate. Taking
        # the total from the cumulative sum guarantees P <= R_cum[-1], so
        # searchsorted can never return an out-of-range index.
        R_cum = np.cumsum(Rs)
        R_total = R_cum[-1]

        # 4. Exponentially distributed waiting time. 1 - rand lies in (0, 1],
        # which avoids log(0) should rand return exactly 0.
        dt = -np.log(1.0 - np.random.rand()) / R_total

        # 5. Choose an event with probability proportional to its rate.
        P = np.random.rand() * R_total
        p_idx = np.searchsorted(R_cum, P)

        # Add event result to current state.
        X, Y, Z = np.array([X, Y, Z]) + ps[p_idx]

        # 7.
        t += dt

        # Bookkeep results.
        ts.append(t)
        ys.append([X, Y, Z])

        # Early exit in case of extinction.
        if Y == 0:
            break

    return np.array(ts), np.array(ys)

In [None]:
def SIR_D(t, y0, params):
    """Right-hand side of the deterministic SIR model with demography.

    Frequency-dependent transmission as in MID 2.1.2. It is written in the
    ``f(t, y, *args)`` form used by ``scipy.integrate.solve_ivp``.

    Args:
        t: Current time (not used; the system is autonomous).
        y0: Current state ``(X, Y, Z)``.
        params: Rates ``(beta, gamma, mu)``.

    Returns:
        Tuple ``(dX/dt, dY/dt, dZ/dt)``.
    """
    beta, gamma, mu = params
    sus, inf, rec = y0
    total = sus + inf + rec

    # New infections per unit time.
    infection = beta * sus * inf / total

    d_sus = mu * total - infection - mu * sus
    d_inf = infection - gamma * inf - mu * inf
    d_rec = gamma * inf - mu * rec
    return d_sus, d_inf, d_rec

In [None]:
def MA(series, filter_length):
    """Simple moving average (SMA) of *series* along its first axis.

    Each column is convolved with a flat kernel of length ``filter_length``.
    Edges are padded by repeating the nearest value (``mode='nearest'``).

    Args:
        series: 1-D array, or 2-D array whose columns are separate series
            (e.g. the ``ys`` array returned by ``SIR_GDA``).
        filter_length: Window size in samples; must be at least 1.

    Returns:
        Float array with the same shape as *series*.

    Raises:
        ValueError: If ``filter_length`` is smaller than 1.
    """
    if filter_length < 1:
        raise ValueError("filter_length must be at least 1")
    # convolve1d returns an array of the input dtype by default. Integer
    # counts (as produced by SIR_GDA) would otherwise be truncated.
    values = np.asarray(series, dtype=float)
    kernel = np.ones(filter_length) / filter_length

    return convolve1d(values, kernel, mode='nearest', axis=0)

In [None]:
def equal_time(x_equal, xp, fp):
    """Resample an event-driven trajectory onto given (equally spaced) times.

    Linear interpolation with ``np.interp``. *xp* are the increasing event
    times, *fp* the values at those times, and *x_equal* the query times.
    Outside ``[xp[0], xp[-1]]`` the first/last value is held constant.
    """
    resampled = np.interp(x_equal, xp, fp)
    return resampled

In [None]:
# Compare individual GDA runs with the deterministic ODE solution.
title="Comparison of GDA with deterministic SIR model"
t_max = 300
# Common, equally spaced time grid onto which every stochastic run is
# interpolated, so runs can be averaged point by point in the next cell.
ts_discrete = np.linspace(0, t_max, 1000)
# (beta, gamma, mu)
params = (0.3, 0.1, 1/80)
N = 1000
# Start with 10 infected individuals and nobody recovered.
y0 = (N-10, 10, 0)
num_runs = 100

fig, ax = plt.subplots()

# Interpolated S, I and R trajectories, one entry per run.
Xs = []
Ys = []
Zs = []
for i in range(num_runs):
    ts, ys = SIR_GDA(y0=y0, params=params, t_max=t_max)

    # Interpolate time axis for statistics
    Xs.append(equal_time(ts_discrete, ts, ys[:, 0]))
    Ys.append(equal_time(ts_discrete, ts, ys[:, 1]))
    Zs.append(equal_time(ts_discrete, ts, ys[:, 2]))
    
    # Plot the results for each trajectory (raw event times, faint lines)
    plt.plot(ts, ys[:, 0], color='b', alpha=0.1)
    plt.plot(ts, ys[:, 1], color='r', alpha=0.1)
    plt.plot(ts, ys[:, 2], color='g', alpha=0.1)

# Deterministic reference solution. NOTE: this rebinds `ys`; the next cell
# plots `ts_det`/`ys` as the "Exact" curves.
ts_det = np.linspace(0, t_max, 200)
sol = solve_ivp(SIR_D, [0, t_max], y0, args=(params,), dense_output=True)
ys = sol.sol(ts_det).T

plt.plot(ts_det, ys[:, 0], color='#000099', linewidth=3, label='Susceptible')
plt.plot(ts_det, ys[:, 1], color='#990000', linewidth=3, label='Infected')
plt.plot(ts_det, ys[:, 2], color='#006600', linewidth=3, label='Recovered')

plt.legend(loc="upper right")
props_fig({"xlabel": "t (days)", "ylabel": "population", "title": title})
save_fig(title)
plt.show()
plt.close()

In [None]:
# Point-wise mean and standard deviation over runs on the ts_discrete grid
# (Xs/Ys/Zs come from the previous cell).
X_mean, X_std = np.mean(Xs, axis=0), np.std(Xs, axis=0)
Y_mean, Y_std = np.mean(Ys, axis=0), np.std(Ys, axis=0)
Z_mean, Z_std = np.mean(Zs, axis=0), np.std(Zs, axis=0)

title="Comparison of means"

fig, ax = plt.subplots()

# Deterministic curves (`ts_det`, `ys` from the previous cell).
plt.plot(ts_det, ys[:, 0], color='b', label='Exact S')
plt.plot(ts_det, ys[:, 1], color='r', label='Exact I')
plt.plot(ts_det, ys[:, 2], color='g', label='Exact R')

plt.plot(ts_discrete, X_mean, color='b', linestyle="--", label='Mean S')
plt.plot(ts_discrete, Y_mean, color='r', linestyle="--", label='Mean I')
plt.plot(ts_discrete, Z_mean, color='g', linestyle="--", label='Mean R')

plt.legend(loc="upper right")
props_fig({"xlabel": "t (days)", "ylabel": "population", "title": title})
save_fig(title)

plt.show()

title="Comparison of stds"

fig, ax = plt.subplots()

# Shaded band: mean +/- one standard deviation across runs.
plt.fill_between(ts_discrete, X_mean - X_std, X_mean + X_std, color='b', alpha=0.3)
plt.fill_between(ts_discrete, Y_mean - Y_std, Y_mean + Y_std, color='r', alpha=0.3)
plt.fill_between(ts_discrete, Z_mean - Z_std, Z_mean + Z_std, color='g', alpha=0.3)

plt.plot(ts_discrete, X_mean, color='b', linestyle="--", label='Mean S')
plt.plot(ts_discrete, Y_mean, color='r', linestyle="--", label='Mean I')
plt.plot(ts_discrete, Z_mean, color='g', linestyle="--", label='Mean R')

plt.legend(loc="upper right")
props_fig({"xlabel": "t (days)", "ylabel": "population", "title": title})
save_fig(title)

plt.show()

In [None]:
# Compare the spread of I across runs with and without moving-average
# smoothing of each trajectory.
title="Comparison of std after smoothing"
t_max = 300
ts_discrete = np.linspace(0, t_max, 1000)
params = (0.3, 0.1, 1/80)
N = 1000
y0 = (N-10, 10, 0)
num_runs = 100

fig, ax = plt.subplots()

Xs = []
Ys = []
Zs = []

Xs_ma = []
Ys_ma = []
Zs_ma = []
for i in range(num_runs):
    ts, ys = SIR_GDA(y0=y0, params=params, t_max=t_max)

    # Interpolate time axis for statistics
    Xs.append(equal_time(ts_discrete, ts, ys[:, 0]))
    Ys.append(equal_time(ts_discrete, ts, ys[:, 1]))
    Zs.append(equal_time(ts_discrete, ts, ys[:, 2]))

    # NOTE(review): the window is 1101 *events* (rows of ys), not a fixed
    # time span, so the effective smoothing time varies between runs.
    ys = MA(ys, 1101)

    Xs_ma.append(equal_time(ts_discrete, ts, ys[:, 0]))
    Ys_ma.append(equal_time(ts_discrete, ts, ys[:, 1]))
    Zs_ma.append(equal_time(ts_discrete, ts, ys[:, 2]))

plt.plot(ts_discrete, np.std(Ys, axis=0), label="Stddev I (unsmoothed)")
plt.plot(ts_discrete, np.std(Ys_ma, axis=0), label="Stddev I (smoothed)")

plt.legend()
props_fig({"xlabel": "t (days)", "ylabel": "stdev", "title": title})
save_fig(title)

plt.show()

In [None]:
def mean_comparison(params, t_max=300, N=1000, num_runs=100):
    """Plot the mean of stochastic GDA runs against the deterministic solution.

    Args:
        params: Rates ``(beta, gamma, mu)``; they also appear in the title.
        t_max: Simulated time span.
        N: Population size; every run starts with 10 infected.
        num_runs: Number of stochastic runs that are averaged.

    The figure is shown and saved under its title.
    """
    beta, gamma, mu = params

    title = f"comparison, b={beta}, c={gamma}, m={mu}"
    ts_discrete = np.linspace(0, t_max, 1000)
    y0 = (N-10, 10, 0)

    # Resample S, I and R of every run onto the common grid ts_discrete.
    runs = []
    for _ in range(num_runs):
        ts, ys = SIR_GDA(y0=y0, params=params, t_max=t_max)
        runs.append([equal_time(ts_discrete, ts, ys[:, k]) for k in range(3)])
    # Shape (num_runs, 3, len(ts_discrete)) -> mean over runs: (3, len(...)).
    means = np.mean(runs, axis=0)

    ts_det = np.linspace(0, t_max, 200)
    sol = solve_ivp(SIR_D, [0, t_max], y0, args=(params,), dense_output=True)
    ys_det = sol.sol(ts_det).T

    # (colour, compartment letter) for S, I and R.
    styles = list(zip('brg', 'SIR'))

    fig, ax = plt.subplots()
    for k, (color, name) in enumerate(styles):
        ax.plot(ts_det, ys_det[:, k], color=color, label=f'Exact {name}')
    for k, (color, name) in enumerate(styles):
        ax.plot(ts_discrete, means[k], color=color, linestyle="--", label=f'Mean {name}')

    ax.legend(loc="upper right")
    props_fig({"xlabel": "t (days)", "ylabel": "population", "title": title})
    save_fig(title)

    plt.show()

# Stochastic mean vs deterministic curves for increasing transmission rates
# beta (gamma = 0.1, mu = 1/80).
mean_comparison(params=(0.2, 0.1, 1/80))
mean_comparison(params=(0.3, 0.1, 1/80))
mean_comparison(params=(0.4, 0.1, 1/80))

### 1.2 Investigate Simulation Variability and Negative Covariance

In [None]:
def errorbar_plot(x, ys, fig_props):
    """Plot the mean of each row of *ys* against *x* with +/-1 std error bars.

    ``ys[i]`` holds all observations made at ``x[i]``. *fig_props* is
    applied to the current Axes, and its ``"title"`` entry also names the
    saved figure.
    """
    _, ax = plt.subplots()

    row_means = np.mean(ys, axis=1)
    row_stds = np.std(ys, axis=1)
    ax.errorbar(x, row_means, row_stds,
                marker='.', color='black', capsize=4, linestyle='--')
    # Scientific notation on the y axis whatever the magnitude.
    ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))

    props_fig(fig_props)
    plt.grid()
    save_fig(fig_props["title"])

    plt.show()


def regular_plot(x, ys, fig_props):
    """Dashed black line plot of *ys* against *x*.

    *ys* is passed unchanged to ``Axes.plot``. *fig_props* is applied to the
    current Axes, and its ``"title"`` entry also names the saved figure.
    """
    _, ax = plt.subplots()

    ax.plot(x, ys, marker='.', color='black', linestyle='--')
    # Scientific notation on the y axis whatever the magnitude.
    ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))

    props_fig(fig_props)
    plt.grid()
    save_fig(fig_props["title"])

    plt.show()

In [None]:
# Scan the transmission rate beta and record, per beta, the S-I covariance
# and the variance of I over repeated GDA runs.
t_max = 20
ts_interpolate = np.linspace(0, t_max, 1000)

y0 = (9990, 10, 0)
num_runs = 20

gamma = 0.1
betas = np.linspace(0.3, 2, 13)
mu = 1/80

# All measured covariances between S and I
# (one list per beta, one value per run)
all_covs = []
# All measured variances
# (one array per beta: variance of I across runs at each interpolated time)
all_variances = []
for beta in betas:
    params = (beta, gamma, mu)

    covs = []
    Is_interpolate = []
    for _ in range(num_runs):
        ts, ys = SIR_GDA(y0=y0, params=params, t_max=t_max)

        S = ys[:, 0]
        I = ys[:, 1]
    
        # Covariance over the event samples of this single trajectory.
        covs.append(np.cov(S, I)[0][1])
        Is_interpolate.append(equal_time(ts_interpolate, ts, I))

    all_covs.append(covs)
    all_variances.append(np.var(Is_interpolate, axis=0))

In [None]:
# Mean (+/- std over time points) of the per-time variance of I.
errorbar_plot(betas, all_variances, {
    "xlabel": "beta", 
    "ylabel": "variance", 
    "title": "Mean variance of I for varying beta"
})

# Mean (+/- std over runs) of the per-run S-I covariance.
errorbar_plot(betas, all_covs, {
    "xlabel": "beta", 
    "ylabel": "covariance", 
    "title": "Mean covariance of S and I for varying beta"
})

In [None]:
# Same measurements as the beta scan, now varying the population size N.
t_max = 100
ts_interpolate = np.linspace(0, t_max, 1000)
num_runs = 20

gamma = 0.1
beta = 0.3
params = (beta, gamma, 1/80)

# NOTE: np.linspace yields floats, so y0 below holds a float N - 10.
Ns = np.linspace(1000, 20000, 21)

all_covs = []
all_variances = []
for N in Ns:
    y0 = (N - 10, 10, 0)
    
    covs = []
    Is_interpolate = []
    for _ in range(num_runs):
        ts, ys = SIR_GDA(y0=y0, params=params, t_max=t_max)
    
        S = ys[:, 0]
        I = ys[:, 1]
    
        # Covariance over the event samples of this single trajectory.
        covs.append(np.cov(S, I)[0][1])
        Is_interpolate.append(equal_time(ts_interpolate, ts, I))

    all_covs.append(covs)
    # Variance of I across runs at each interpolated time point.
    all_variances.append(np.var(Is_interpolate, axis=0))

In [None]:
# Summaries of the N scan, analogous to the beta scan plots.
errorbar_plot(Ns, all_variances, {
    "xlabel": "N", 
    "ylabel": "variance", 
    "title": "Mean variance of I for varying N"
})

errorbar_plot(Ns, all_covs, {
    "xlabel": "N", 
    "ylabel": "covariance", 
    "title": "Mean covariance of S and I for varying N"
})

### 1.3 Stochastic Resonance and Increased Transients

In [None]:
def get_endemic_equilibrium(params):
    """Equilibrium fractions ``(S*, I*, R*)`` of the SIR model with demography.

    With ``R0 = beta / (gamma + mu) > 1`` this is the endemic equilibrium
    ``S* = 1/R0``, ``I* = mu/beta * (R0 - 1)``, ``R* = 1 - S* - I*``.
    For ``R0 <= 1`` no endemic equilibrium exists: the formula would give a
    non-positive prevalence. The disease-free equilibrium ``(1, 0, 0)`` is
    returned instead (which is also what the formula gives at R0 == 1).

    Args:
        params: Rates ``(beta, gamma, mu)``.

    Returns:
        Tuple of population fractions summing to 1.
    """
    beta, gamma, mu = params
    R0 = beta / (gamma + mu)

    if R0 <= 1:
        return 1.0, 0.0, 0.0

    Sp = 1/R0
    Ip = mu/beta*(R0 - 1)
    Rp = 1 - Sp - Ip

    return Sp, Ip, Rp

def normalize_population(ys, N):
    """Convert absolute counts *ys* into fractions of the population size *N*."""
    fractions = ys / N
    return fractions

def plot_deviations(params, N, num_trials=10, t_max=2000):
    """Plot stochastic infected fractions against the endemic equilibrium.

    Runs ``num_trials`` GDA simulations starting from 10 infected in a
    population of ``N``. The infected count is resampled onto 1000 equally
    spaced times and divided by ``N``. Every run is drawn together with the
    equilibrium prevalence ``I*`` from ``get_endemic_equilibrium``. The figure
    is shown and saved under its title.

    Args:
        params: Rates ``(beta, gamma, mu)``.
        N: Population size.
        num_trials: Number of stochastic runs to draw.
        t_max: Simulated time span.
    """
    beta, gamma, mu = params
    R0 = beta / (gamma + mu)

    ts_equal = np.linspace(0, t_max, 1000)
    y0 = (N-10, 10, 0)

    _, Ip, _ = get_endemic_equilibrium(params)

    Is = []
    for _ in range(num_trials):
        ts, ys = SIR_GDA(y0, params, t_max)
        Is.append(normalize_population(equal_time(ts_equal, ts, ys[:, 1]), N))

    title = rf"Plot for $R_0$ = {R0:.2f}, N = {N}"

    fig, ax = plt.subplots()

    for I in Is:
        ax.plot(ts_equal, I, color="red", alpha=0.2)
    ax.axhline(Ip, color="#660000", label="equilibrium")

    set_ax_props(ax, {"xlabel": "time (days)", "ylabel": "I", "title": title})

    ax.set_ylim([0, 1])
    ax.legend()
    save_fig(title)
    plt.show()

In [None]:
# A low (0.3) and a high (0.8) transmission rate, each for a small and a
# larger population; gamma = 0.1 and mu = 1/80 throughout.
plot_deviations(params=(0.3, 0.1, 1/80), N=2000)
plot_deviations(params=(0.8, 0.1, 1/80), N=2000)
plot_deviations(params=(0.3, 0.1, 1/80), N=8000)
plot_deviations(params=(0.8, 0.1, 1/80), N=8000)

In [None]:
def MSE(a, b):
    """Mean squared error between *a* and *b* (element-wise, broadcasting)."""
    residual = a - b
    return np.mean(residual ** 2)

def get_MSE(params, N, num_trials=10, t_max=3000, t_cutoff=500):
    """MSE of the infected fraction against the endemic equilibrium.

    Runs ``num_trials`` GDA simulations (10 initial infected out of ``N``).
    The infected fraction is resampled onto 1000 equally spaced times and
    only the samples from ``t_cutoff`` onwards are kept, so the initial
    transient is discarded. Their mean squared deviation from ``I*``
    (``get_endemic_equilibrium``) is returned.

    Args:
        params: Rates ``(beta, gamma, mu)``.
        N: Population size.
        num_trials: Number of stochastic runs pooled together.
        t_max: Simulated time span.
        t_cutoff: Start of the window that is compared to the equilibrium.
    """
    ts_equal = np.linspace(0, t_max, 1000)
    idx_equilibrium = np.searchsorted(ts_equal, t_cutoff)

    y0 = (N-10, 10, 0)

    _, Ip, _ = get_endemic_equilibrium(params)

    tails = []
    for _ in range(num_trials):
        ts, ys = SIR_GDA(y0, params, t_max)
        I = normalize_population(equal_time(ts_equal, ts, ys[:, 1]), N)
        tails.append(I[idx_equilibrium:])

    return MSE(np.ravel(tails), Ip)

In [None]:
# MSE from the endemic equilibrium as a function of beta (fixed N) ...
N = 2000
betas = np.linspace(0.2, 1.2, 15)
gamma = 0.1
mu = 1/80

results = []
for beta in betas:
    results.append(get_MSE(params=(beta, gamma, mu), N=N))

title = "MSEs for betas"
fig, ax = plt.subplots()

ax.plot(betas, results, color="black", label="MSE", marker=".", linestyle="--")
set_ax_props(ax, {"xlabel": "beta", "ylabel": "MSE", "title": title})

plt.legend()
plt.grid()
save_fig(title)

plt.show()

# ... and as a function of N (fixed beta).
beta = 0.3
gamma = 0.1
mu = 1/80
# NOTE(review): `params` is not used below; get_MSE receives
# (beta, gamma, mu) directly.
params = (0.3, 0.1, 1/80)

Ns = np.linspace(2000, 8000, 15)

results = []
for N in Ns:
    results.append(get_MSE(params=(beta, gamma, mu), N=N))

title = "MSEs for Ns"
fig, ax = plt.subplots()

ax.plot(Ns, results, color="black", label="MSE", marker=".", linestyle="--")
set_ax_props(ax, {"xlabel": "N", "ylabel": "MSE", "title": title})

plt.legend()
plt.grid()
save_fig(title)

plt.show()

### 1.4 Extinction events and Critical Community Size

In [None]:
def is_extinct(ys):
    """Return True when the last infected count (column 1 of *ys*) is zero."""
    final_infected = ys[-1, 1]
    return final_infected == 0

In [None]:
# Parameters of the critical-community-size scan. t_max and num_trials are
# read as globals by get_average_extinction_time.
t_max = 100000
Ns = [1000, 2000, 3000, 4000, 5000, 6000, 7000]
# Ns = [1000, 2000, 3000]
# betas = [0.1, 0.11]
gamma = 0.1
mu = 1/80
# Shifted by mu so that R0 = beta / (gamma + mu) equals 1 for the first
# entry, (0.1 + mu) / (0.1 + mu), and increases slowly from there.
betas = np.array([0.1, 0.1025, 0.105, 0.1075, 0.110, 0.1125, 0.1150, 0.1175]) + mu
num_trials = 50

In [None]:
def get_average_extinction_time(y0, params):
    """Mean extinction time over ``num_trials`` GDA runs.

    Reads the module-level ``num_trials`` and ``t_max``. A run counts as
    extinct when its final infected count is zero. Because SIR_GDA stops at
    extinction, the run's last event time is then the extinction time.

    Runs still infected at ``t_max`` are ignored. The result is therefore
    the mean over extinct runs only, which is biased low when many runs
    survive.

    Returns:
        The mean extinction time, or ``np.nan`` when no run went extinct.
    """
    extinction_times = []
    for _ in range(num_trials):
        ts, ys = SIR_GDA(y0, params, t_max)
        if is_extinct(ys):
            extinction_times.append(ts[-1])

    if not extinction_times:
        # np.mean([]) also gives nan, but with a RuntimeWarning.
        return np.nan

    return np.mean(extinction_times)

# Meshgrid for grid drawing: rows follow betas, columns follow Ns.
XX, YY = np.meshgrid(Ns, betas)
# Prepare to store results as floats. Ns are ints, so np.zeros_like(XX)
# alone would be an integer array: it would truncate the mean times and
# could not hold the NaN returned when no run went extinct.
ZZ = np.zeros_like(XX, dtype=float)

for i in range(XX.shape[0]):
    for j in range(XX.shape[1]):
        N = XX[i, j]
        beta = YY[i, j]

        print("Now processing", (N, beta))

        y0 = (N-10, 10, 0)
        params = (beta, gamma, mu)
        ZZ[i, j] = get_average_extinction_time(y0, params)

In [None]:
title=r"Extinction time, $R_0$ vs $N$"
fig, ax = plt.subplots()

# Convert the transmission rates of the grid into basic reproduction numbers.
YY_R0 = YY / (gamma + mu)
# log10 of the mean extinction times; NaN marks cells where no run went
# extinct.
log_ZZ = np.log10(ZZ)

# Annotate every cell with its log10 mean extinction time.
for i in range(XX.shape[0]):
    for j in range(XX.shape[1]):
        ax.text(XX[i, j], YY_R0[i, j], f'{log_ZZ[i, j]:.1f}', ha='center', va='center', color='black', bbox=dict(facecolor='white', edgecolor='none', pad=0, alpha=0.7))

c = ax.pcolormesh(XX, YY_R0, log_ZZ, cmap='Blues')
fig.colorbar(c, ax=ax, label=r"mean extinction time ($\log_{10}$ days)")

set_ax_props(ax, {"xlabel": r"$N$", "ylabel": r"$R_0$", "title": title})

save_fig(title)
plt.show()

## 2. Networks

### 2.1 Implement SIR Disease Spread on Network 

In [None]:
def SIR_network(graph, beta, gamma, fraction_infected, model_it):
    ''' Run ndlib's SIR model on *graph* and return the compartment counts.

    Args:
        graph: networkx graph on which the epidemic spreads.
        beta, gamma: values for ndlib's 'beta' and 'gamma' model parameters.
        fraction_infected: initially infected fraction of nodes.
        model_it: number of ndlib iterations to run.

    Returns:
        Three lists (susceptible, infected, recovered) with one count per
        iteration, read from ndlib's ``node_count`` entries 0, 1 and 2.
    '''
    model = ep.SIRModel(graph)

    config = mc.Configuration()
    settings = (('beta', beta), ('gamma', gamma), ('fraction_infected', fraction_infected))
    for key, value in settings:
        config.add_model_parameter(key, value)
    model.set_initial_status(config)

    counts = [step['node_count'] for step in model.iteration_bunch(model_it)]

    susceptible = [count[0] for count in counts]
    infected = [count[1] for count in counts]
    recovered = [count[2] for count in counts]
    return susceptible, infected, recovered

    
def viz_subplot(ax, time, all_susceptible, all_infected, all_recovered, avg_susceptible, avg_infected, avg_recovered, title):
    ''' Plot every S, I and R trajectory faintly plus the average trajectories.

    Colours match the rest of the notebook (S blue, I red, R green). A
    legend is drawn so the "(average)" labels are actually shown.
    '''
    for susceptible, infected, recovered in zip(all_susceptible, all_infected, all_recovered):
        ax.plot(time, susceptible, alpha=.2, color="blue")
        ax.plot(time, infected, alpha=.2, color="red")
        ax.plot(time, recovered, alpha=.2, color="green")

    # Plot the averages
    ax.plot(time, avg_susceptible, color="blue", label="S (average)", linewidth=2)
    ax.plot(time, avg_infected, color="red", label="I (average)", linewidth=2)
    ax.plot(time, avg_recovered, color="green", label="R (average)", linewidth=2)

    ax.legend(loc='upper right')
    ax.set_title(title)
    

def SIR_network_plot(time, network, betas, gammas, fraction_infected, model_it, n_it):
    ''' Simulate SIR on *network* for every (beta, gamma) pair and combine
    the resulting panels into one figure (rows: betas, columns: gammas).

    Each panel shows ``n_it`` runs of SIR_network plus their average.
    ``time`` must contain one entry per ndlib iteration (``model_it``).
    The figure is titled and saved under network_name(network).
    '''
    n_rows, n_cols = len(betas), len(gammas)
    # Width scales with the number of columns, height with the rows.
    # squeeze=False keeps a 2-D array of Axes even for a single panel.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
    axs = axs.flatten()
    n_time_points = len(time)

    j = 0
    for beta in betas:
        for gamma in gammas:
            # One row per run, one column per time point.
            all_S = np.zeros((n_it, n_time_points))
            all_I = np.zeros((n_it, n_time_points))
            all_R = np.zeros((n_it, n_time_points))

            for i in range(n_it):
                all_S[i], all_I[i], all_R[i] = SIR_network(network, beta, gamma, fraction_infected, model_it)

            # Compute averages over all iterations
            avg_S = np.mean(all_S, axis=0)
            avg_I = np.mean(all_I, axis=0)
            avg_R = np.mean(all_R, axis=0)

            title = f"Beta: {beta}, Gamma: {gamma}"
            viz_subplot(axs[j], time, all_S, all_I, all_R, avg_S, avg_I, avg_R, title)
            j += 1

    fig.suptitle(network_name(network), fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 1])
    # (The former bare props_fig() call raised TypeError: it needs a dict.)
    save_fig(network_name(network))
    plt.show()

### 2.2 Generating Networks

In [None]:
# Network name
def network_name(network):
    ''' Return a human-readable name for one of the notebook's graphs.

    *network* is matched by identity against the module-level ``g_ER``,
    ``g_BA`` and ``g_WS``. Globals that do not exist yet are skipped. Any
    other graph falls back to ``str(network)`` instead of raising.
    '''
    missing = object()
    known = (
        ("g_ER", "Erdos-Renyi Network"),
        ("g_BA", "Barabasi-Albert Network"),
        ("g_WS", "Watts-Strogatz Network"),
    )
    for global_name, label in known:
        # Identity, not equality: we want the very graph object the caller
        # holds. The sentinel keeps an undefined global from matching.
        if globals().get(global_name, missing) is network:
            return label
    return str(network)

### 2.3 Network Statistics

In [None]:

def statistics_plot(ax, network, statistics):
    ''' Draw a 100-bin histogram of the per-node *statistics* on *ax*.

    The title names the statistic (statistic_name) and the graph
    (network_name).
    '''
    stat_label = statistic_name(statistics)
    net_label = network_name(network)
    title = f"Frequency of {stat_label} in a {net_label}"

    ax.hist(statistics, bins=100, color='blue')
    ax.set_xlabel("value")
    ax.set_ylabel("count")
    ax.set_title(title)
    
def statistics_viz(ax, network, statistics):
    ''' Draw *network* on *ax* with nodes coloured by their *statistics* value.

    Uses a spring layout and the plasma colour map, and adds a colour bar.
    '''
    title = f"Visualization of {statistic_name(statistics)} in a small {network_name(network)}"

    # Node positions for drawing (avoid shadowing the builtin `format`).
    positions = nx.spring_layout(network)
    colour_values = np.array(statistics)

    node_artist = nx.draw_networkx_nodes(
        network, positions,
        cmap=plt.cm.plasma, node_size=50, node_color=colour_values,
        edgecolors='black', linewidths=0.5, ax=ax,
    )
    nx.draw_networkx_edges(network, positions, edge_color='black', alpha=0.4, ax=ax)

    plt.colorbar(node_artist, ax=ax)

    ax.set_title(title)
    ax.set_axis_off()


def statistic_name(statistics):
    ''' Return the display name of one of the notebook's per-node statistic lists.

    *statistics* is matched by identity against the module-level
    ``centrality_values``, ``betweenness_values`` and ``clustering_values``;
    callers pass those very objects. Anything else is labelled
    "statistic" instead of raising UnboundLocalError.
    '''
    missing = object()
    known = (
        ("centrality_values", "Degree centrality"),
        ("betweenness_values", "Betweenness centrality"),
        ("clustering_values", "Clustering coefficient"),
    )
    for global_name, label in known:
        # `is` instead of `==`: element-wise list comparison is O(n), and
        # with independent `if`s two equal-valued lists made the last match
        # win.
        if globals().get(global_name, missing) is statistics:
            return label
    return "statistic"

### 2.3.1 Erdős–Rényi Graph

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(16, 10))
axs = axs.flatten()
it = 100         # number of independent large-graph realisations
n_nodes = 1000   # number of nodes of each large graph

# One row per realisation, one column per node.
all_cent = np.zeros((it, n_nodes))
all_betw = np.zeros((it, n_nodes))
all_clus = np.zeros((it, n_nodes))

for t in range(it):
    # Network topology: Large graph for statistics
    g_ER = nx.erdos_renyi_graph(n_nodes, 0.1)
    network = g_ER

    # Per-node degree centrality, betweenness centrality and clustering.
    centrality_values = list(nx.degree_centrality(network).values())
    betweenness_values = list(nx.betweenness_centrality(network).values())
    clustering_values = list(nx.clustering(network).values())

    all_cent[t] = centrality_values
    all_betw[t] = betweenness_values
    all_clus[t] = clustering_values

# Pool the node values of all realisations into one sample per statistic.
# Node labels of independently generated random graphs are unrelated, so
# averaging "node i" across realisations would show the distribution of
# per-label means (squeezed towards the overall mean) rather than the
# distribution of the statistic. Plain lists are kept on purpose.
centrality_values = all_cent.ravel().tolist()
betweenness_values = all_betw.ravel().tolist()
clustering_values = all_clus.ravel().tolist()

statistics_plot(axs[0], network, centrality_values)
statistics_plot(axs[2], network, betweenness_values)
statistics_plot(axs[4], network, clustering_values)


# Network topology: small graph for visualization
g_ER = nx.erdos_renyi_graph(150, 0.1)  # Erdos-Renyi graph
network = g_ER

centrality_values = list(nx.degree_centrality(network).values())
betweenness_values = list(nx.betweenness_centrality(network).values())
clustering_values = list(nx.clustering(network).values())

statistics_viz(axs[1], network, centrality_values)
statistics_viz(axs[3], network, betweenness_values)
statistics_viz(axs[5], network, clustering_values)

# Call plt.show (a bare `plt.show` only references the function).
plt.show()

### 2.3.2 Barabasi-Albert Graph

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(16, 10))
axs = axs.flatten()

it = 100         # number of independent large-graph realisations
n_nodes = 1000   # number of nodes of each large graph

# One row per realisation, one column per node.
all_cent = np.zeros((it, n_nodes))
all_betw = np.zeros((it, n_nodes))
all_clus = np.zeros((it, n_nodes))

for t in range(it):
    # Network topology: Large graph for statistics
    g_BA = nx.barabasi_albert_graph(n_nodes, 6)
    network = g_BA

    # Per-node degree centrality, betweenness centrality and clustering.
    centrality_values = list(nx.degree_centrality(network).values())
    betweenness_values = list(nx.betweenness_centrality(network).values())
    clustering_values = list(nx.clustering(network).values())

    all_cent[t] = centrality_values
    all_betw[t] = betweenness_values
    all_clus[t] = clustering_values

# Pool the node values of all realisations into one sample per statistic,
# so the histograms show the distribution over nodes rather than the
# distribution of per-label means across realisations. Plain lists are
# kept on purpose.
centrality_values = all_cent.ravel().tolist()
betweenness_values = all_betw.ravel().tolist()
clustering_values = all_clus.ravel().tolist()

statistics_plot(axs[0], network, centrality_values)
statistics_plot(axs[2], network, betweenness_values)
statistics_plot(axs[4], network, clustering_values)


# Network topology: small graph for visualization
g_BA = nx.barabasi_albert_graph(150, 6)
network = g_BA

centrality_values = list(nx.degree_centrality(network).values())
betweenness_values = list(nx.betweenness_centrality(network).values())
clustering_values = list(nx.clustering(network).values())

statistics_viz(axs[1], network, centrality_values)
statistics_viz(axs[3], network, betweenness_values)
statistics_viz(axs[5], network, clustering_values)

# Call plt.show (a bare `plt.show` only references the function).
plt.show()

### 2.3.3 Watts-Strogatz Network

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(16, 10))
axs = axs.flatten()

it = 100         # number of independent large-graph realisations
n_nodes = 1000   # number of nodes of each large graph

# One row per realisation, one column per node.
all_cent = np.zeros((it, n_nodes))
all_betw = np.zeros((it, n_nodes))
all_clus = np.zeros((it, n_nodes))

for t in range(it):
    # Network topology: Large graph for statistics
    g_WS = nx.watts_strogatz_graph(n_nodes, 6, 0.1)
    network = g_WS

    # Per-node degree centrality, betweenness centrality and clustering.
    centrality_values = list(nx.degree_centrality(network).values())
    betweenness_values = list(nx.betweenness_centrality(network).values())
    clustering_values = list(nx.clustering(network).values())

    all_cent[t] = centrality_values
    all_betw[t] = betweenness_values
    all_clus[t] = clustering_values

# Pool the node values of all realisations into one sample per statistic.
# Node labels of independently rewired graphs are unrelated, so averaging
# "node i" across realisations would squeeze the distribution towards its
# mean. Plain lists are kept on purpose.
centrality_values = all_cent.ravel().tolist()
betweenness_values = all_betw.ravel().tolist()
clustering_values = all_clus.ravel().tolist()

statistics_plot(axs[0], network, centrality_values)
statistics_plot(axs[2], network, betweenness_values)
statistics_plot(axs[4], network, clustering_values)


# Network topology: small graph for visualization
g_WS = nx.watts_strogatz_graph(150, 6, 0.1)
network = g_WS

centrality_values = list(nx.degree_centrality(network).values())
betweenness_values = list(nx.betweenness_centrality(network).values())
clustering_values = list(nx.clustering(network).values())

statistics_viz(axs[1], network, centrality_values)
statistics_viz(axs[3], network, betweenness_values)
statistics_viz(axs[5], network, clustering_values)

# Call plt.show (a bare `plt.show` only references the function).
plt.show()

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(16, 10))
axs = axs.flatten()

it = 100         # number of independent large-graph realisations
n_nodes = 1000   # number of nodes of each large graph

# One row per realisation, one column per node.
all_cent = np.zeros((it, n_nodes))
all_betw = np.zeros((it, n_nodes))
all_clus = np.zeros((it, n_nodes))

for t in range(it):
    # Network topology: Large graph for statistics (sparser WS variant)
    g_WS = nx.watts_strogatz_graph(n_nodes, 3, 0.05)
    network = g_WS

    # Per-node degree centrality, betweenness centrality and clustering.
    centrality_values = list(nx.degree_centrality(network).values())
    betweenness_values = list(nx.betweenness_centrality(network).values())
    clustering_values = list(nx.clustering(network).values())

    all_cent[t] = centrality_values
    all_betw[t] = betweenness_values
    all_clus[t] = clustering_values

# Pool the node values of all realisations into one sample per statistic.
# Node labels of independently rewired graphs are unrelated, so averaging
# "node i" across realisations would squeeze the distribution towards its
# mean. Plain lists are kept on purpose.
centrality_values = all_cent.ravel().tolist()
betweenness_values = all_betw.ravel().tolist()
clustering_values = all_clus.ravel().tolist()

statistics_plot(axs[0], network, centrality_values)
statistics_plot(axs[2], network, betweenness_values)
statistics_plot(axs[4], network, clustering_values)


# Network topology: small graph for visualization
g_WS = nx.watts_strogatz_graph(150, 3, 0.05)
network = g_WS

centrality_values = list(nx.degree_centrality(network).values())
betweenness_values = list(nx.betweenness_centrality(network).values())
clustering_values = list(nx.clustering(network).values())

statistics_viz(axs[1], network, centrality_values)
statistics_viz(axs[3], network, betweenness_values)
statistics_viz(axs[5], network, clustering_values)

# Call plt.show (a bare `plt.show` only references the function).
plt.show()

### 2.4 Simulate SIR Disease Spread on Network

In [None]:
def SIR_network(graph, beta, gamma, fraction_infected, model_it):
    ''' Run ndlib's SIR model on *graph* and return the compartment counts.

    Args:
        graph: networkx graph on which the epidemic spreads.
        beta, gamma: values for ndlib's 'beta' and 'gamma' model parameters.
        fraction_infected: initially infected fraction of nodes.
        model_it: number of ndlib iterations to run.

    Returns:
        Three lists (susceptible, infected, recovered) with one count per
        iteration, read from ndlib's ``node_count`` entries 0, 1 and 2.
    '''
    model = ep.SIRModel(graph)

    config = mc.Configuration()
    settings = (('beta', beta), ('gamma', gamma), ('fraction_infected', fraction_infected))
    for key, value in settings:
        config.add_model_parameter(key, value)
    model.set_initial_status(config)

    counts = [step['node_count'] for step in model.iteration_bunch(model_it)]

    susceptible = [count[0] for count in counts]
    infected = [count[1] for count in counts]
    recovered = [count[2] for count in counts]
    return susceptible, infected, recovered

    
def viz_subplot(ax, time, all_susceptible, all_infected, all_recovered, avg_susceptible, avg_infected, avg_recovered, title):
    ''' Draw all individual S/I/R runs faintly and their averages boldly on *ax*.

    Colours: susceptible blue, infected red, recovered green. A legend is
    placed in the upper right and *title* is set on the Axes.
    '''
    palette = ("blue", "red", "green")
    labels = ("Susceptible", "Infected", "Recovered")

    # Individual runs: for each run draw S, then I, then R.
    for run in zip(all_susceptible, all_infected, all_recovered):
        for series, colour in zip(run, palette):
            ax.plot(time, series, alpha=.2, color=colour)

    # Averages, labelled for the legend.
    averages = (avg_susceptible, avg_infected, avg_recovered)
    for series, colour, label in zip(averages, palette, labels):
        ax.plot(time, series, color=colour, label=label, linewidth=2)

    ax.legend(loc='upper right')
    ax.set_title(title)
    

def SIR_network_plot(time, network, betas, gammas, fraction_infected, model_it, n_it, N):
    ''' Grid of SIR network simulations, one panel per (beta, gamma) pair.

    For every combination SIR_network is run ``n_it`` times. The S, I and R
    counts are divided by ``N`` and drawn with viz_subplot (individual runs
    plus their average). Rows correspond to ``betas``, columns to
    ``gammas``. The figure is titled and saved under network_name(network).

    Args:
        time: x values; must contain one entry per ndlib iteration
            (``model_it``).
        network: networkx graph to simulate on.
        betas, gammas: parameter values spanning the grid.
        fraction_infected: initially infected fraction of nodes.
        model_it: number of ndlib iterations per run.
        n_it: number of runs per panel.
        N: population size used to turn counts into fractions.
    '''
    n_rows, n_cols = len(betas), len(gammas)
    # Width scales with the number of columns, height with the rows.
    # squeeze=False keeps a 2-D array of Axes even for a single panel.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
    axs = axs.flatten()
    n_time_points = len(time)

    j = 0
    for beta in betas:
        for gamma in gammas:
            # One row per run, one column per time point.
            all_S = np.zeros((n_it, n_time_points))
            all_I = np.zeros((n_it, n_time_points))
            all_R = np.zeros((n_it, n_time_points))

            for i in range(n_it):
                S, I, R = SIR_network(network, beta, gamma, fraction_infected, model_it)

                # Counts -> population fractions.
                all_S[i] = np.asarray(S) / N
                all_I[i] = np.asarray(I) / N
                all_R[i] = np.asarray(R) / N

            # Compute averages over all iterations
            avg_S = np.mean(all_S, axis=0)
            avg_I = np.mean(all_I, axis=0)
            avg_R = np.mean(all_R, axis=0)

            title = f"Beta: {beta}, Gamma: {gamma}"
            viz_subplot(axs[j], time, all_S, all_I, all_R, avg_S, avg_I, avg_R, title)
            j += 1

    fig.suptitle(network_name(network), fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 1])
    save_fig(network_name(network))
    plt.show()


### 2.4.1 Erdős–Rényi Graph

In [None]:
# SIR on an Erdos-Renyi graph for a 2x2 grid of (beta, gamma) values.
fraction_infected = 0.1
betas = [.1, .01]
gammas = [.01, .001]
N = 1000

model_it = 300 # number of ndlib iterations (time steps) per run
n_it = 50 # number of runs per (beta, gamma) panel

time = list(range(model_it))

# Network topology: N-node Erdos-Renyi graph used for the simulations
g_ER = nx.erdos_renyi_graph(N, 0.1)          
network = g_ER

SIR_network_plot(time, network, betas, gammas, fraction_infected, model_it, n_it, N)

### 2.4.2 Barabasi-Albert graph

In [None]:
# SIR on a Barabasi-Albert graph for a 2x2 grid of (beta, gamma) values.
fraction_infected = 0.1
betas = [.1, .01]
gammas = [.01, .001]
N = 1000

model_it = 300 # number of ndlib iterations (time steps) per run
n_it = 50 # number of runs per (beta, gamma) panel

time = list(range(model_it))

# Network topology: N-node Barabasi-Albert graph used for the simulations
g_BA = nx.barabasi_albert_graph(N, 6)          
network = g_BA

SIR_network_plot(time, network, betas, gammas, fraction_infected, model_it, n_it, N)

### 2.4.3 Watts-Strogatz graph

In [None]:
# SIR on a Watts-Strogatz graph (k=6, p=0.1) for a 2x2 (beta, gamma) grid.
fraction_infected = 0.1
betas = [.1, .01]
gammas = [.01, .001]
N = 1000

model_it = 300 # number of ndlib iterations (time steps) per run
n_it = 50 # number of runs per (beta, gamma) panel

time = list(range(model_it))

# Network topology: N-node Watts-Strogatz graph used for the simulations
g_WS = nx.watts_strogatz_graph(N, 6, 0.1)          
network = g_WS

# NOTE(review): the figure is saved under network_name(network), which is the
# same for both Watts-Strogatz cells, so the next cell overwrites this PNG.
SIR_network_plot(time, network, betas, gammas, fraction_infected, model_it, n_it, N)

In [None]:
# SIR on a sparser Watts-Strogatz graph (k=3, p=0.05), same (beta, gamma) grid.
fraction_infected = 0.1
betas = [.1, .01]
gammas = [.01, .001]
N = 1000

model_it = 300 # number of ndlib iterations (time steps) per run
n_it = 50 # number of runs per (beta, gamma) panel

time = list(range(model_it))

# Network topology: N-node Watts-Strogatz graph used for the simulations
g_WS = nx.watts_strogatz_graph(N, 3, 0.05)          
network = g_WS

# NOTE(review): saved under the same file name as the previous
# Watts-Strogatz cell (network_name(network)), overwriting its PNG.
SIR_network_plot(time, network, betas, gammas, fraction_infected, model_it, n_it, N)