In [2]:
import matplotlib.pyplot as plt
import numpy as np

In [14]:
# Cell state labels stored in the `cells` array.
SUSCEPTIBLE = 0
INFECTED = 1
RECOVERED = 2


def init(N):
    """Build the initial state of an SIR epidemic on an N-cell square lattice.

    Cells are numbered 0..N-1 and laid out column by column on a
    side x side grid (side = sqrt(N)): cell i sits at row ``i % side`` of
    column ``i // side``. Both coordinates wrap around, so the grid has
    periodic boundaries.

    Parameters
    ----------
    N : int
        Total number of cells; must be a positive perfect square.

    Returns
    -------
    cells : np.ndarray of int, shape (N,)
        State of every cell: all SUSCEPTIBLE except cell 0, which is INFECTED.
    neighbors_list : np.ndarray of int, shape (N, 4)
        For each cell, the indices of its neighbours in the order
        [up, right, down, left], where up/down are the next/previous row in
        the same column and right/left are the same row in the next/previous
        column.

    Raises
    ------
    ValueError
        If N is not a positive perfect square.
    """
    # Guard before taking the square root so negative N gives a clean error
    # instead of a NumPy "invalid value" warning followed by NaN.
    side = int(round(np.sqrt(N))) if N > 0 else 0
    # Exact integer check; does not rely on float equality for large N.
    if side < 1 or side * side != N:
        raise ValueError(f'N must be a positive perfect square, got {N!r}')

    cells = np.full(N, SUSCEPTIBLE, dtype=int)
    cells[0] = INFECTED  # single initially infected cell

    index = np.arange(N, dtype=int)
    row = index % side
    col = index // side

    def flat(r, c):
        # Map (row, column) back to a flat index, wrapping both coordinates
        # modulo `side` (NumPy's % is non-negative for a positive divisor).
        return (c % side) * side + (r % side)

    up = flat(row + 1, col)
    right = flat(row, col + 1)
    down = flat(row - 1, col)
    left = flat(row, col - 1)
    neighbors_list = np.stack([up, right, down, left], axis=1)

    return cells, neighbors_list

In [25]:
def simulate(N, prob_contamination, prob_recovery, rng=None, max_steps=None):
    """Run a stochastic SIR epidemic on a periodic square lattice of N cells.

    The lattice and initial state come from ``init(N)``. In each time step,
    every cell that is infected at the start of the step independently tries
    to infect each of its susceptible neighbours with probability
    ``prob_contamination``, then recovers with probability ``prob_recovery``.
    Only susceptible cells can be infected, so recovered cells stay recovered.

    Parameters
    ----------
    N : int
        Number of cells; must be a perfect square (see ``init``).
    prob_contamination : float
        Per-neighbour, per-step infection probability.
    prob_recovery : float
        Per-step recovery probability of an infected cell.
    rng : object with a ``random()`` method, optional
        Source of uniform [0, 1) draws, e.g. ``np.random.default_rng(seed)``.
        Defaults to the legacy global ``np.random`` state, so
        ``np.random.seed(...)`` makes runs reproducible.
    max_steps : int, optional
        Upper bound on the number of time steps; ``None`` means no bound.

    Returns
    -------
    susceptible_over_time, infected_over_time, recovered_over_time : list of int
        Count of cells in each state, starting with the initial state, so
        each list has (number of steps run + 1) entries.
    """
    if rng is None:
        rng = np.random

    cells, neighbors_list = init(N)

    susceptible_count = np.count_nonzero(cells == SUSCEPTIBLE)
    infected_count = np.count_nonzero(cells == INFECTED)
    recovered_count = np.count_nonzero(cells == RECOVERED)

    susceptible_over_time = [susceptible_count]
    infected_over_time = [infected_count]
    recovered_over_time = [recovered_count]

    steps = 0
    while infected_count > 0 and (max_steps is None or steps < max_steps):
        # Without recovery the state freezes once no one else can be infected;
        # stop instead of looping forever. (Stopping on "everyone infected"
        # instead would cut off the recovery phase when prob_recovery > 0.)
        if prob_recovery <= 0 and (susceptible_count == 0 or prob_contamination <= 0):
            break

        # Snapshot the infectious cells so cells infected during this step only
        # start spreading on the next step (synchronous update).
        infectious = np.flatnonzero(cells == INFECTED)
        for i in infectious:
            for neighbor in neighbors_list[i]:
                # Independent draw per neighbour; only susceptible cells can catch it.
                if cells[neighbor] == SUSCEPTIBLE and rng.random() < prob_contamination:
                    cells[neighbor] = INFECTED
                    infected_count += 1
                    susceptible_count -= 1

            # Recovery uses its own draw, independent of the infection attempts.
            if rng.random() < prob_recovery:
                cells[i] = RECOVERED
                infected_count -= 1
                recovered_count += 1

        steps += 1
        susceptible_over_time.append(susceptible_count)
        infected_over_time.append(infected_count)
        recovered_over_time.append(recovered_count)

    return susceptible_over_time, infected_over_time, recovered_over_time


In [None]:
N = 100
SEED = 42

# Scenario k uses (probs_contamination[k], probs_recovery[k]).
probs_contamination = [0.1, 0.3, 0.01]
probs_recovery = [0.1, 0.01, 0.3]
assert len(probs_contamination) == len(probs_recovery), "one recovery probability per scenario"

# simulate() draws from the global NumPy RNG by default; seed it so the
# figure is reproducible on Restart & Run All.
np.random.seed(SEED)

susceptible_over_time_arr = []
infected_over_time_arr = []
recovered_over_time_arr = []

for prob_contamination, prob_recovery in zip(probs_contamination, probs_recovery):
    susceptible_over_time, infected_over_time, recovered_over_time = simulate(
        N, prob_contamination, prob_recovery
    )
    susceptible_over_time_arr.append(susceptible_over_time)
    infected_over_time_arr.append(infected_over_time)
    recovered_over_time_arr.append(recovered_over_time)

# One panel per scenario so the runs can be compared side by side.
n_scenarios = len(probs_contamination)
fig, axes = plt.subplots(
    1, n_scenarios, figsize=(5 * n_scenarios, 4), sharey=True, squeeze=False
)
for ax, p_contamination, p_recovery, susceptible, infected, recovered in zip(
    axes[0],
    probs_contamination,
    probs_recovery,
    susceptible_over_time_arr,
    infected_over_time_arr,
    recovered_over_time_arr,
):
    ax.plot(susceptible, label="susceptible")
    ax.plot(infected, label="infected")
    ax.plot(recovered, label="recovered")
    ax.set(
        title=f"contamination={p_contamination}, recovery={p_recovery}",
        xlabel="time step",
    )
    ax.legend()

axes[0, 0].set_ylabel("number of cells")
fig.tight_layout()

plt.show()