In [None]:
import numpy as np
from scipy.ndimage import convolve
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

In [None]:
# Set the global NumPy seed so the stochastic simulation below is reproducible.
np.random.seed(212)

# Cell states (the integer values double as colormap indices when plotting).
IDLE, SECRETORY, RESTING = 0, 1, 2

# Model parameters (rates and durations are in simulation time units)
dt = 0.04      # time step; small so the discrete update approximates continuous time
beta = 1.0     # activation rate contributed by each SECRETORY neighbour
gamma = 1.0    # how long a cell stays SECRETORY (IL2 secretion duration)
rho = 1.0      # how long a cell stays RESTING (refractory duration)

# Snapshot times 1, 3, ..., 29; the simulation runs until the last one.
time_space = np.arange(1, 30, 2)
steps = int(np.round(np.max(time_space) / dt))
print("steps:", steps)

# Grid initialization
grid_shape = (200, 200)
grid = np.zeros(grid_shape, dtype=np.uint8)
initiation_timer = np.zeros(grid_shape)  # time spent SECRETORY in the current episode
recovery_timer = np.zeros(grid_shape)    # time spent RESTING in the current episode
sensitive_timer = np.zeros(grid_shape)   # time spent IDLE (sensitive) since the last reset

# Seed one SECRETORY cell near each of 3 random centres.
centers = np.random.randint(3, 180, size=(3, 2))
# Largest valid (row, col) index; clipping to grid_shape itself would allow an
# out-of-bounds index of 200.
max_index = np.array(grid_shape) - 1

for x, y in centers:
    dx, dy = np.random.randint(-5, 5, 2)  # jitter drawn from [-5, 4]
    xi, yi = np.clip([x + dx, y + dy], 0, max_index)
    grid[xi, yi] = SECRETORY

# 8-cell Moore neighborhood kernel (centre excluded)
kernel = np.array([[1, 1, 1],
                   [1, 0, 1],
                   [1, 1, 1]])

# Run the simulation: one synchronous update of the whole grid per time step.
# Within a step the transitions are applied in order:
#   IDLE -> SECRETORY (stochastic, driven by SECRETORY neighbours),
#   SECRETORY -> RESTING after gamma, RESTING -> IDLE after rho.
# Timers accumulate through repeated `+= dt` in floating point, so they are
# compared against the durations with a small tolerance; otherwise an
# accumulated value landing just below gamma/rho delays a transition by one step.
TIMER_EPS = 1e-9
frames = []
for t in range(steps):
    # Number of SECRETORY cells in each cell's 8-cell Moore neighbourhood
    # (cells outside the grid count as non-secretory: mode='constant').
    initiated_neighbors = convolve((grid == SECRETORY).astype(np.uint8), kernel, mode='constant')

    # Probability of activation within dt at rate beta per SECRETORY neighbour
    p_inf = 1 - np.exp(-beta * initiated_neighbors * dt)

    # IDLE -> SECRETORY
    rand_vals = np.random.rand(*grid_shape)
    new_initiations = (grid == IDLE) & (rand_vals < p_inf)
    grid[new_initiations] = SECRETORY

    recovery_timer[new_initiations] = 0
    initiation_timer[new_initiations] = 0
    sensitive_timer[new_initiations] = 0

    # SECRETORY -> RESTING once secretion has lasted gamma
    initiation_timer[grid == SECRETORY] += dt
    to_recover = (grid == SECRETORY) & (initiation_timer >= gamma - TIMER_EPS)
    grid[to_recover] = RESTING

    sensitive_timer[to_recover] = 0
    recovery_timer[to_recover] = 0
    initiation_timer[to_recover] = 0

    # RESTING -> IDLE once the refractory period rho has elapsed
    recovery_timer[grid == RESTING] += dt
    to_reset = (grid == RESTING) & (recovery_timer >= rho - TIMER_EPS)
    grid[to_reset] = IDLE

    initiation_timer[to_reset] = 0
    sensitive_timer[to_reset] = 0
    recovery_timer[to_reset] = 0

    # Accumulate time spent IDLE (sensitive)
    sensitive_timer[grid == IDLE] += dt

    # Record the state grid after this step: frames[k] is the grid at time (k + 1) * dt
    frames.append(grid.copy())

# Stack into a (steps, rows, cols) array of integer cell states (not one-hot).
frames = np.array(frames)
# The snapshot at time T is frame T/dt - 1. Round (as `steps` does) rather than
# truncate, so floating-point error in T/dt cannot select the previous frame.
frames_idx = np.round(time_space / dt).astype(int) - 1
output = frames[frames_idx]

In [None]:
# Colormap indexed by cell state: IDLE=0 -> black, SECRETORY=1 -> green, RESTING=2 -> amber
my_cmap = ListedColormap(['#000000',  '#009E73', "#C98B05"]) # using colorblind-friendly colors

# Plot one panel per snapshot on a 5-column grid. The row count is derived from
# the number of snapshots so none are silently dropped if `time_space` changes;
# with 15 snapshots this reproduces the original 3x5 layout and 15x10 figure.
n_snapshots = len(output)
n_cols = 5
n_rows = max(1, (n_snapshots + n_cols - 1) // n_cols)  # ceiling division
# squeeze=False keeps `axs` 2-D even when there is only one row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
for i, ax in enumerate(axs.flat):
    if i < n_snapshots:
        ax.imshow(output[i], cmap=my_cmap, vmin=0, vmax=2)
        ax.set_title(f'Time: {time_space[i]}')
    else:
        ax.axis('off')  # Hide unused subplots
plt.tight_layout()
plt.show()

In [None]:
# # visualize a single frame
# plt.figure(figsize=(6, 6))
# plt.imshow(output[14], cmap=my_cmap, vmin=0, vmax=2)
# plt.axis('off')
# # add legend
# import matplotlib.patches as mpatches
# sensitive_patch = mpatches.Patch(color='#009E73', label='SECRETORY')
# initiated_patch = mpatches.Patch(color='#000000', label='IDLE')
# recovered_patch = mpatches.Patch(color='#C98B05', label='RESTING')
# plt.legend(handles=[sensitive_patch, initiated_patch, recovered_patch], loc='lower right', fontsize='large')

In [None]:
# # plot simulation as a video
# import matplotlib.animation as animation    
# fig, ax = plt.subplots(figsize=(6, 6))
# im = ax.imshow(frames[0], cmap=my_cmap, vmin=0, vmax=2)
# ax.set_title('Model Simulation')   
# def update(frame_idx):
#     # add title with time
#     ax.set_title(f'Time: {frame_idx*dt:.1f}')
#     im.set_array(frames[frame_idx])
#     return [im]
# ani = animation.FuncAnimation(fig, update, frames=len(frames), interval=100, blit=True)

# # faster framerate for gif
# # Try using pillow writer for gif format, which is more universally available
# ani.save('sirs_simulation2.gif', writer='pillow', fps=60)