In [15]:
from scipy.integrate import odeint
import numpy as np

from functools import partial

In [16]:
def dpopdt(
    pop: np.ndarray,
    t: float,
    birth_rates: np.ndarray,
    death_rates: np.ndarray,
    mutation_rates: np.ndarray,
    transition_matrix: np.ndarray,
) -> np.ndarray:
    """Right-hand side of the population ODE, in the ``func(y, t, ...)`` form odeint expects.

    For each type ``j`` the returned rate of change is

        birth_rates[j] * pop[j] - death_rates[j] * pop[j]
        + sum_i mutation_rates[i] * pop[i] * transition_matrix[i, j]

    Args:
        pop: current population of each type.
        t: current time. Unused because the right-hand side does not depend
            on time, but odeint always passes it as the second positional
            argument.
        birth_rates: per-type birth rates, multiplied elementwise with ``pop``.
        death_rates: per-type death rates, multiplied elementwise with ``pop``.
        mutation_rates: per-type mutation rates, multiplied elementwise with
            ``pop`` before being redistributed by ``transition_matrix``.
        transition_matrix: ``transition_matrix[i, j]`` weights how much of the
            mutation flux of type ``i`` is added to type ``j``.

    Returns:
        The time derivative of ``pop``.

    NOTE(review): mutation flux is only added to the target types; nothing is
    subtracted from the mutating type. Confirm this is the intended model.
    """
    net_growth = birth_rates * pop - death_rates * pop
    # ``*`` and ``@`` share precedence and bind left to right; the explicit
    # parentheses only make the existing grouping visible.
    mutation_influx = (mutation_rates * pop) @ transition_matrix
    return net_growth + mutation_influx

In [28]:
# multitype_high_birth_fitness
# Two types whose transition matrix swaps them: mutation flux of type 0 is
# added to type 1 and vice versa. The run starts with population 1.0 of
# type 1 and none of type 0.
init_pop = np.array([0.0, 1.0])
birth_rates, death_rates, mutation_rates = (
    np.array([0.25, 1.0]),
    np.array([0.5, 0.25]),
    np.array([0.1, 0.25]),
)
transition_matrix = np.array([[0.0, 1.0], [1.0, 0.0]])

# Bind the rate parameters by keyword so odeint can call model(pop, t).
model = partial(
    dpopdt,
    transition_matrix=transition_matrix,
    mutation_rates=mutation_rates,
    death_rates=death_rates,
    birth_rates=birth_rates,
)

# Five evenly spaced sample times on [0, 25]; odeint returns one row per time.
ts = np.linspace(start=0, stop=25, num=5)
pop_hist = odeint(func=model, y0=init_pop, t=ts)
pop_hist

array([[0.00000000e+00, 1.00000000e+00],
       [3.01038906e+01, 1.23534194e+02],
       [3.81274261e+03, 1.56231940e+04],
       [4.82209438e+05, 1.97590993e+06],
       [6.09864071e+07, 2.49898981e+08]])

In [18]:
# Two types, swap transition matrix. Type 0 has birth rate 0, so with
# dpopdt its only gain is the mutation flux 0.8 * pop[1] coming from type 1.
# NOTE(review): unlike its sibling cells this case has no name comment; add one.
init_pop = np.array([0.0, 1.0])
birth_rates, death_rates, mutation_rates = (
    np.array([0.0, 1.0]),
    np.array([0.01, 0.1]),
    np.array([0.01, 0.8]),
)
transition_matrix = np.array([[0.0, 1.0], [1.0, 0.0]])

# Bind the rate parameters by keyword so odeint can call model(pop, t).
model = partial(
    dpopdt,
    transition_matrix=transition_matrix,
    mutation_rates=mutation_rates,
    death_rates=death_rates,
    birth_rates=birth_rates,
)

# Five evenly spaced sample times on [0, 6]; odeint returns one row per time.
ts = np.linspace(start=0, stop=6, num=5)
pop_hist = odeint(func=model, y0=init_pop, t=ts)
pop_hist

array([[  0.        ,   1.        ],
       [  2.53246755,   3.88057562],
       [ 12.35963016,  15.13903429],
       [ 50.69743016,  59.13942161],
       [200.46070999, 231.09986611]])

In [19]:
# single_type_w_death_huge
# One type and zero mutation rate: dpopdt reduces to pop' = (1.0 - 0.5) * pop,
# so the population grows like exp(0.5 * t), reaching ~7.2e10 by t = 50.
init_pop = np.ones(1)
birth_rates = np.ones(1)
death_rates = np.full(1, 0.5)
mutation_rates = np.zeros(1)
transition_matrix = np.ones((1, 1))

# Bind the rate parameters by keyword so odeint can call model(pop, t).
model = partial(
    dpopdt,
    transition_matrix=transition_matrix,
    mutation_rates=mutation_rates,
    death_rates=death_rates,
    birth_rates=birth_rates,
)

# Five evenly spaced sample times on [0, 50]; odeint returns one row per time.
ts = np.linspace(start=0, stop=50, num=5)
pop_hist = odeint(func=model, y0=init_pop, t=ts)
pop_hist

array([[1.00000000e+00],
       [5.18012910e+02],
       [2.68337362e+05],
       [1.39002211e+08],
       [7.20049365e+10]])