In [10]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import ListedColormap, BoundaryNorm
import time
import os

In [11]:
class RK4Result:
    """
    Minimal holder for the output of an RK4 integration.

    ``t`` is the time grid and ``y`` the stored states; ``SEIQRDP_2D.solve_rk4``
    passes ``y`` with one column per entry of ``t``. Both are kept as given,
    with no copying or validation.
    """

    def __init__(self, t, y):
        self.t, self.y = t, y

In [12]:
class SEIQRDP_2D:
    """
    SEIQRDP model for 2d fire spread simulation.

    The domain is a ``rows x cols`` grid of patches. Each patch holds seven
    compartments in the order ``[S, E, I, Q, R, D, P]`` (indices 0..6).
    Patches interact only through the ``beta_between`` term, which couples a
    patch's ``S`` to the ``I`` of its 4-connected (up/down/left/right)
    neighbours; the grid edges do not wrap around.

    Parameters
    ----------
    grid_shape : tuple of int
        ``(rows, cols)`` of the patch grid.
    N_total : float
        Total population over all patches; every patch is normalised to
        ``N_total / (rows * cols)``.
    params : dict of callables
        Time-dependent rates, each called as ``f(t)``. Required keys:
        ``'alpha'``, ``'beta_within'``, ``'beta_between'``, ``'gamma'``,
        ``'delta'``, ``'kappa'``, ``'rho'``. Optional key ``'lamda'`` is a
        Q -> R rate; when absent it is treated as 0 (so ``R`` stays constant).
    initial_state : array_like, shape (rows, cols, 7)
        Initial compartment values. A float copy is taken and each patch is
        rescaled to sum to the per-patch population; the caller's array is
        left untouched.
    times : array_like
        Output time grid. ``solve_rk4`` advances from ``times[i-1]`` to
        ``times[i]`` using the actual spacing of this grid.
    dt : float
        Nominal step size, stored as ``self.dt``.

    Raises
    ------
    ValueError
        If ``initial_state`` does not have shape ``(rows, cols, 7)``.
    """
    def __init__(self, grid_shape, N_total, params, initial_state, times, dt):
        self.rows, self.cols = grid_shape
        self.n_patches = self.rows*self.cols
        self.N = N_total
        self.N_patch = self.N/self.n_patches

        self.params = params
        self.times = times
        self.t_span = (times[0], times[-1])
        self.dt = dt

        # Work on a float copy. Normalising the caller's array in place
        # mutated their data, and an integer array would truncate the
        # rescaled values and later break the in-place RK4 update.
        state = np.array(initial_state, dtype=float)
        expected_shape = (self.rows, self.cols, 7)
        if state.shape != expected_shape:
            raise ValueError(
                f"initial_state must have shape {expected_shape}, got {state.shape}")

        totals = state.sum(axis=2)
        off_target = np.abs(totals - self.N_patch) > 1e-3
        # Patches with some content are scaled proportionally to N_patch ...
        rescale = off_target & (totals > 0)
        state[rescale] = state[rescale]*(self.N_patch/totals[rescale])[:, None]
        # ... and empty (or non-positive) patches get their S set to N_patch.
        empty = off_target & ~(totals > 0)
        state[empty, 0] = self.N_patch

        self.initial_state = state
        self.y0 = state.flatten()

    def _idx(self, r, c):
        """Flat patch index of grid cell ``(r, c)`` (row-major)."""
        return r*self.cols + c

    def _coords(self, idx):
        """Inverse of ``_idx``: ``(row, col)`` of flat patch index ``idx``."""
        row = idx//self.cols 
        col = idx % self.cols
        return row, col

    def _neighbors(self, r, c):
        """Flat indices of the in-bounds 4-connected neighbours of ``(r, c)``."""
        neighbors = []
        
        if r > 0:
            neighbors.append(self._idx(r - 1, c))
        
        if r < self.rows - 1:
            neighbors.append(self._idx(r + 1, c))

        if c > 0:
            neighbors.append(self._idx(r, c - 1))
        
        if c < self.cols - 1:
            neighbors.append(self._idx(r, c + 1))

        return neighbors

    def _neighbor_sum(self, values):
        """Per-patch sum of ``values`` over the 4-connected neighbours (no wrap)."""
        grid = values.reshape(self.rows, self.cols)
        total = np.zeros_like(grid)
        total[1:, :] += grid[:-1, :]    # neighbour at r - 1
        total[:-1, :] += grid[1:, :]    # neighbour at r + 1
        total[:, 1:] += grid[:, :-1]    # neighbour at c - 1
        total[:, :-1] += grid[:, 1:]    # neighbour at c + 1
        return total.ravel()

    def _dydt(self, t, y):
        """
        Right-hand side of the ODE system at time ``t``.

        ``y`` is the flattened ``(n_patches, 7)`` state. Values are clipped at 0
        before evaluating the rates, so transient negatives inside an RK4 stage
        never drive a flow backwards. Returns the flattened derivative, same
        layout as ``y``.
        """
        state = np.maximum(0, y.reshape((self.n_patches, 7)))
        S, E, I, Q = state[:, :4].T

        alpha = self.params['alpha'](t)
        beta_within = self.params['beta_within'](t)
        beta_between = self.params['beta_between'](t)
        gamma = self.params['gamma'](t)
        delta = self.params['delta'](t)
        kappa = self.params['kappa'](t)
        r_prob = self.params['rho'](t)
        # Optional Q -> R rate; absent means R is never fed (previous behaviour).
        lam = self.params['lamda'](t) if 'lamda' in self.params else 0.0

        local_spread = beta_within*S*I/self.N_patch
        neighbor_spread = beta_between*S*self._neighbor_sum(I)/self.N_patch

        dS = -alpha*S - local_spread - neighbor_spread
        dE = local_spread + neighbor_spread - gamma*E
        dI = gamma*E - delta*I
        dQ = r_prob*delta*I - kappa*Q - lam*Q
        dR = lam*Q
        dD = (1 - r_prob)*delta*I + kappa*Q
        dP = alpha*S

        return np.stack([dS, dE, dI, dQ, dR, dD, dP], axis=1).ravel()

    def solve_rk4(self):
        """
        Integrate over ``self.times`` with classic fourth-order Runge-Kutta.

        Each step uses the actual spacing ``times[i] - times[i-1]`` so the stored
        time labels always match the integrated state. After each step the state
        is clipped at 0.

        Returns
        -------
        RK4Result
            ``t`` is ``self.times``; ``y`` has shape ``(len(y0), len(times))``.
        """
        grid = np.asarray(self.times, dtype=float)
        num_steps = len(grid)
        y = self.y0.copy()
        result = np.zeros((num_steps, y.size))
        result[0] = y

        for i in range(1, num_steps):
            t = grid[i - 1]
            h = grid[i] - t
            k1 = self._dydt(t, y)
            k2 = self._dydt(t + h/2, y + k1*h/2)
            k3 = self._dydt(t + h/2, y + k2*h/2)
            k4 = self._dydt(t + h, y + k3*h)
            y = np.maximum(y + (h/6)*(k1 + 2*k2 + 2*k3 + k4), 0)
            result[i] = y

        print("RK4 complete.")
        return RK4Result(t=self.times, y=result.T)


In [13]:
def plot_snapshots(model, sol, times, save_loc = None):
    """
    Save one PNG per milestone (initial, 1/4, 1/2, 3/4, final) showing, for
    every patch, which compartment (S, E, I, Q, D or P; R is skipped) is largest.

    Parameters
    ----------
    model : object with ``rows`` and ``cols``
        Grid shape used to reshape each state column.
    sol : object with ``t`` and ``y``
        ``y`` column ``k`` is reshaped to ``(rows, cols, 7)``; ``t`` is the
        matching time grid.
    times : array_like
        Its first and last values span the milestone times.
    save_loc : str or None, optional
        Output directory. ``None`` (default) writes to the current working
        directory, as before. Files are named ``snapshot_<label>.png``.
    """
    colors = ['green', 'orange', 'red', 'hotpink', 'black', 'purple']
    cmap = ListedColormap(colors)
    norm = BoundaryNorm(np.arange(len(colors)+1) - 0.5, cmap.N)
    tick_labels = ['S', 'E', 'I', 'Q', 'D', 'P']
    indices = [0, 1, 2, 3, 5, 6]

    t_start, t_end = times[0], times[-1]
    fractions = [0.0, 0.25, 0.5, 0.75, 1.0]
    labels = ["Initial", "Quarter", "Half", "3/4", "Final"]

    for frac, label in zip(fractions, labels):
        # Milestones are measured from the start of the grid, not from t = 0.
        tp = t_start + frac*(t_end - t_start)
        idx = np.argmin(np.abs(sol.t - tp))
        grid = sol.y[:, idx].reshape((model.rows, model.cols, 7))
        dominant = np.argmax(grid[:, :, indices], axis = 2)

        fig, ax = plt.subplots(figsize = (8, 6))
        im = ax.imshow(dominant, cmap = cmap, norm = norm, origin = 'lower')

        ax.set_xticks(np.arange(model.cols))
        ax.set_yticks(np.arange(model.rows))
        # Cell borders sit at half-integer positions; gridlines on the major
        # (integer) ticks would run through cell centres instead.
        ax.set_xticks(np.arange(model.cols + 1) - 0.5, minor = True)
        ax.set_yticks(np.arange(model.rows + 1) - 0.5, minor = True)
        ax.grid(True, which = 'minor', color = 'gray', linewidth = 0.5, linestyle = '-', alpha = 0.5)
        ax.tick_params(axis = 'both', which = 'both', length = 0)

        cbar = fig.colorbar(im, ax = ax, ticks = np.arange(len(colors)))
        cbar.ax.set_yticklabels(tick_labels)
        ax.set_title(f"{label} - Time: {sol.t[idx]:.1f}")
        fig.tight_layout()

        filename = f"snapshot_{label.lower().replace('/', '')}.png"
        if save_loc is not None:
            filename = os.path.join(save_loc, filename)
        fig.savefig(filename, dpi = 150)
        plt.close(fig)

In [None]:
def animate_grid(model, sol, times, interval = 100, filename = "fire_spread.gif"):
    """
    Animate the per-patch dominant compartment over time and save it as a GIF.

    For each column of ``sol.y`` (reshaped to ``(rows, cols, 7)``) the frame
    shows which of S, E, I, Q, D, P is largest in each patch (R is skipped).

    Parameters
    ----------
    model : object with ``rows`` and ``cols``
    sol : object with ``t`` (time grid) and ``y`` (one state column per time)
    times : unused here; accepted so existing call sites keep working.
    interval : int
        Delay between frames in milliseconds; the GIF is written at
        ``max(1, 1000 // interval)`` frames per second.
    filename : str
        Output path, written with the ``pillow`` writer.
    """
    palette = ['green', 'orange', 'red', 'hotpink', 'black', 'purple']
    shown = [0, 1, 2, 3, 5, 6]
    cmap = ListedColormap(palette)
    norm = BoundaryNorm(np.arange(len(palette)+1) - 0.5, cmap.N)

    def dominant_at(column):
        cube = sol.y[:, column].reshape((model.rows, model.cols, 7))
        return np.argmax(cube[:, :, shown], axis = 2)

    fig, ax = plt.subplots(figsize = (8, 6))
    img = ax.imshow(dominant_at(0), cmap = cmap, norm = norm, animated = True, origin = 'lower')

    colorbar = fig.colorbar(img, ticks = np.arange(len(palette)))
    colorbar.ax.set_yticklabels(['S', 'E', 'I', 'Q', 'D', 'P'])
    title = ax.set_title(f"Time: {sol.t[0]:.1f}")

    # Minor ticks at half-integers put the grid lines on the cell borders.
    ax.set_xticks(np.arange(model.cols+1)-0.5, minor = True)
    ax.set_yticks(np.arange(model.rows+1)-0.5, minor = True)
    ax.grid(True, which = 'minor', color = 'gray', linewidth = 0.5)
    ax.tick_params(which = 'minor', size = 0)

    def update(frame):
        img.set_array(dominant_at(frame))
        title.set_text(f"Time: {sol.t[frame]:.1f} Days")
        return img, title

    movie = animation.FuncAnimation(fig, update, frames = len(sol.t), interval = interval, blit = True)
    movie.save(filename, writer = 'pillow', fps = max(1, 1000//interval))
    plt.close(fig)

In [15]:
def plot_param_curves(params, times, save_loc):
    """
    Plot every parameter function against ``times`` on one set of axes.

    Each value in ``params`` is called as ``f(t)`` for every ``t`` in ``times``
    and drawn with its key as the legend label. The figure is written to
    ``<save_loc>/parameter_evolution.png`` at 500 dpi and then closed.
    """
    fig, ax = plt.subplots(figsize = (10, 6))
    for name, rate in params.items():
        ax.plot(times, list(map(rate, times)), label = name)
    ax.set_xlabel("Time")
    ax.set_ylabel("Rate/Probability")
    ax.set_title("Model Parameters Over Time")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(save_loc, "parameter_evolution.png"), dpi = 500)
    plt.close(fig)

In [16]:

def alpha_func(t, rate=0.01, max_effort_day=15):
    """Linear ramp from 0 at t=0 up to ``rate`` at ``max_effort_day``, constant afterwards."""
    return rate*t/max_effort_day if t < max_effort_day else rate

def beta_func(t, initial_spread=0.9, peak_day=10, decay_rate=0.01):
    """
    Within-patch spread rate: rises linearly from ``0.5 * initial_spread`` at
    t=0 to ``initial_spread`` at ``peak_day`` and stays there.

    ``decay_rate`` is accepted but not used by this profile.
    """
    if t < peak_day:
        ramp = 0.5 + 0.5*t/peak_day
        return initial_spread*ramp
    return initial_spread

def gamma_func(t, constant_rate=1.0/3.0):
    """Time-independent E -> I rate; ``t`` is ignored."""
    rate = constant_rate
    return rate

def delta_func(t, initial_rate=0.05, max_rate=0.4, growth_factor=0.15, midpoint_day=20):
    """
    Logistic ramp of the I-outflow rate from ``initial_rate`` toward ``max_rate``.

    The curve is centred on ``midpoint_day`` (where it is halfway between the
    two rates) and ``growth_factor`` sets its steepness. Works elementwise on
    numpy arrays as well as scalars.
    """
    span = max_rate - initial_rate
    exponent = -growth_factor*(t - midpoint_day)
    return initial_rate + span/(1 + np.exp(exponent))

def lamda_func(t, lambda0=0):
    """Constant rate stored under the ``'lamda'`` params key; ``t`` is ignored."""
    value = lambda0
    return value

def kappa_func(t, k0=0.03, center_t=30, sigma=8):
    """Gaussian bump peaking at ``k0`` when ``t == center_t``, with width ``sigma``."""
    offset = t - center_t
    return k0*np.exp(-(offset**2)/(2*sigma**2))


In [17]:
def main():
    """
    Run the 25x25 fire-spread scenario end to end.

    Seeds 10 distinct random patches (global numpy seed 0) with 80% S and
    20% I. Every other patch starts fully susceptible. The grid model is
    integrated with RK4 over 0..100 days at dt = 0.5. The animation and the
    parameter plot are then written to ``./output`` (created if needed).
    """
    rows, cols = 25, 25
    N_total = 1850000
    N_patch = N_total/(rows*cols)
    t_max = 100
    dt = 0.5
    times = np.arange(0, t_max + dt, dt)

    # Legacy global seed, kept so the same ignition patches as earlier runs are chosen.
    np.random.seed(0)
    initial_state = np.zeros((rows, cols, 7))
    initial_state[:, :, 0] = N_patch  # every patch starts fully susceptible

    n_infected = 10
    infected_indices = np.random.choice(rows*cols, size = n_infected, replace = False)  # random ignition
    inf_r, inf_c = np.unravel_index(infected_indices, (rows, cols))
    initial_state[inf_r, inf_c, 0] = N_patch*0.8  # S
    initial_state[inf_r, inf_c, 2] = N_patch*0.2  # I

    params = {
        'alpha': alpha_func,
        'beta_within': beta_func,
        # Between-patch spread ramps from 0.1 to 0.2 over the first 12 days.
        'beta_between': lambda t: 0.2*(0.5 + 0.5*min(t/12, 1)),
        'gamma': gamma_func,
        'delta': delta_func,
        'lamda': lamda_func,
        'kappa': kappa_func,
        'rho': lambda t: 0.8,
    }

    model = SEIQRDP_2D(grid_shape = (rows, cols), N_total = N_total, params = params,
                       initial_state = initial_state, times = times, dt = dt)
    sol = model.solve_rk4()

    # Static milestone images are available via plot_snapshots(model, sol, times).
    save_loc = os.path.join(os.getcwd(), 'output')
    os.makedirs(save_loc, exist_ok = True)
    animate_grid(model, sol, times, filename = os.path.join(save_loc, "fire_spread.gif"))
    plot_param_curves(params, times, save_loc)
    print('Animation complete, please find it in the output folder')

In [18]:
# Execute the full pipeline: solve the grid model, then save the GIF and parameter plot under ./output.
main()

RK4 complete.
Animation complete, please find it in the output folder
