In [1]:
import numpy as np
from scipy.integrate import odeint
import rasterio
import rasterio.mask
import matplotlib.pyplot as plt
import os

# Directory set: the project folder that holds the input rasters.
# Set the GUAN_SDM_DIR environment variable to run the notebook on another
# machine; without it, the original author's path is used unchanged.
PROJECT_DIR = os.environ.get(
    "GUAN_SDM_DIR",
    "/Users/pranavkulkarni/SDM/main_repo/Climate_Models_Arenaviruses/Guan_SDM",
)
os.chdir(PROJECT_DIR)
os.getcwd()



'/Users/pranavkulkarni/SDM/main_repo/Climate_Models_Arenaviruses/Guan_SDM'

In [2]:
# Define parameters (rates are per unit of model time)
#   beta0        - baseline transmission rate
#   gamma        - recovery rate; the mean infectious period is 1 / gamma
#   max_distance - cut-off handed to distance_kernel; presumably measured in
#                  raster cells (TODO confirm the unit)
beta0, gamma, max_distance = 0.5, 0.1, 10



In [None]:
# Load your population density raster (band 1).
# Replace 'population_raster.tif' with the path to your population density raster file.
# Cells flagged as nodata are read as masked and filled with 0. Otherwise the
# raster's nodata sentinel value would be counted as population in the totals
# and densities computed later. Rasters without a nodata flag are unaffected.
with rasterio.open('population_raster.tif') as src:
    population_raster = src.read(1, masked=True).filled(0)

# Create a distance kernel function
def distance_kernel(distance, max_distance):
    """Binary (step) transmission kernel.

    Gives 1 when ``distance`` is within ``max_distance`` (inclusive) and 0
    otherwise, so the result can be used as a multiplicative on/off weight.
    """
    within_range = distance <= max_distance
    return 1 if within_range else 0

# Define the differential equations for the SIR model with distance kernel
def sir_model(state, t, beta0, gamma, max_distance, population_raster, location=None):
    """Right-hand side of a well-mixed SIR model, in the form ``odeint`` expects.

    Parameters
    ----------
    state : sequence of 3 floats
        Current compartment sizes ``[S, I, R]``.
    t : float
        Current time. Unused, but required by the ``odeint`` call signature.
    beta0 : float
        Baseline transmission rate.
    gamma : float
        Recovery rate.
    max_distance : float
        Cut-off passed to ``distance_kernel``. The kernel is evaluated at a fixed
        distance of 1 (presumably one raster cell - TODO confirm), so it only
        switches transmission on or off: transmission is zero when
        ``max_distance < 1``.
    population_raster : array_like or float
        Population density grid. Its mean over non-NaN cells scales the force
        of infection. A precomputed scalar density can be passed instead, which
        avoids recomputing the mean on every ``odeint`` evaluation.
    location : tuple of int, optional
        ``(row, col)`` of a single cell whose density is used instead of the
        raster mean.

    Returns
    -------
    list of float
        ``[dS/dt, dI/dt, dR/dt]``.
    """
    S, I, R = state

    # Take the total from the state instead of a global N defined in a later
    # cell. The equations below conserve S + I + R, so this equals the
    # initial total population.
    N = S + I + R
    if N == 0:
        # No population means nothing can change (and it avoids dividing by 0).
        return [0.0, 0.0, 0.0]

    # The ODE state has no spatial coordinates, so S and I cannot be used as
    # raster indices. Use an explicit cell when one is given, otherwise the
    # landscape-mean density.
    if location is not None:
        population_density = np.asarray(population_raster)[tuple(location)]
    else:
        population_density = np.nanmean(population_raster)

    infection = beta0 * S * I / N * distance_kernel(1, max_distance) * population_density
    dS = -infection
    dI = infection - gamma * I
    dR = gamma * I

    return [dS, dI, dR]



In [None]:
# Set the initial state: seed a few infections into an otherwise susceptible
# population whose size is the raster total. nansum keeps NaN cells (if any)
# from turning the total into NaN.
N = np.nansum(population_raster)
initial_infected = 10
if N < initial_infected:
    raise ValueError(
        f"Total population ({N}) is smaller than the number of initial infections ({initial_infected})."
    )
initial_state = [N - initial_infected, initial_infected, 0]

# Define time points (0..100 inclusive, unit steps)
times = np.arange(0, 101, 1)

# Solve the differential equations. Columns of `output` are S, I, R and rows
# follow `times`.
output = odeint(sir_model, initial_state, times, args=(beta0, gamma, max_distance, population_raster))

# Create a new raster to store the results. output[:, 1] is a time series, not
# a per-cell map, so it cannot be assigned to the 2-D raster directly. The ODE
# tracks aggregate S/I/R under a well-mixed assumption, so every cell shares the
# same prevalence I/N. Infected individuals per cell at the final time point are
# therefore that prevalence times the cell's population. The array is cast to
# float so that integer rasters do not truncate the result.
final_prevalence = output[-1, 1] / N if N > 0 else 0.0
result_raster = np.asarray(population_raster, dtype=float) * final_prevalence

In [None]:
# Plot the results. origin='upper' matches raster row order (row 0 is the top
# row of the GeoTIFF), so the map is not drawn upside down.
n_rows, n_cols = result_raster.shape
fig, ax = plt.subplots()
image = ax.imshow(result_raster, cmap='Reds', origin='upper', extent=(0, n_cols, n_rows, 0))
fig.colorbar(image, ax=ax, label='Infected Individuals')
ax.set(title='Infected Individuals', xlabel='Column (cell index)', ylabel='Row (cell index)')
plt.show()

# Export the result as a GeoTIFF. The CRS and geotransform are copied from the
# input raster so the output is georeferenced and lines up with it in GIS
# tools. Without them, rasterio writes an ungeoreferenced file.
with rasterio.open('population_raster.tif') as src:
    source_crs, source_transform = src.crs, src.transform
with rasterio.open('disease_transmission_results.tif', 'w', driver='GTiff',
                   height=n_rows, width=n_cols, count=1,
                   dtype=str(result_raster.dtype),
                   crs=source_crs, transform=source_transform) as dst:
    dst.write(result_raster, 1)