In [6]:
import pytest
import pandas as pd
import numpy as np
from pySODM.models.base import SDEModel

##################################
## Model without stratification ##
##################################

class SIR(SDEModel):
    """Unstratified stochastic SIR model used as a fixture by the tests below.

    The argument names of ``compute_rates`` and ``apply_transitionings`` must
    line up with ``state_names`` and ``parameter_names``; the validation test
    relies on the model constructor checking exactly that.
    """

    state_names = ['S', 'I', 'R']
    parameter_names = ['beta', 'gamma']

    @staticmethod
    def compute_rates(t, S, I, R, beta, gamma):
        """Basic SIR model: return the transition rates keyed by origin state.

        'S' holds the infection rate beta * I / N (N = S + I + R) and 'I' holds
        the constant recovery rate gamma, each wrapped in a one-element list.
        """
        population = S + I + R
        infection_rate = beta * (I / population)
        recovery_rate = np.array([gamma])
        return {'S': [infection_rate], 'I': [recovery_rate]}

    @staticmethod
    def apply_transitionings(t, tau, transitionings, S, I, R, beta, gamma):
        """Move the drawn transitions between compartments.

        Individuals leaving S enter I, and individuals leaving I enter R.
        Returns the updated (S, I, R).
        """
        infected = transitionings['S'][0]
        recovered = transitionings['I'][0]
        return S - infected, I + infected - recovered, R + recovered

def test_SIR_time():
    """Simulate the SIR model on a numeric time axis and validate the output."""

    # Define parameters and initial states
    parameters = {"beta": 0.9, "gamma": 0.2}
    initial_states = {"S": [1_000_000 - 10], "I": [10], "R": [0]}
    # Build model
    model = SIR(initial_states, parameters)

    # Smoke checks: these time specifications must be accepted without raising
    model.sim([int(10), float(50.3)])  # mixture of int/float
    model.sim(50)                      # just one (end) timestep

    # Simulate using a list of timesteps; this is the output validated below
    output = model.sim([0, 50])

    # Validate
    # `in` works whether xarray's Dataset.dims is a mapping or a set of names
    assert 'time' in output.dims
    np.testing.assert_allclose(output["time"], np.arange(0, 51))
    S = output["S"].values.squeeze()
    assert S[0] == 1_000_000 - 10
    assert S.shape == (51, )
    assert S[-1] < 12_000
    I = output["I"].values.squeeze()
    assert I[0] == 10
    # previously re-checked S.shape by mistake; I's shape was never validated
    assert I.shape == (51, )

def test_SIR_date():
    """Simulate the SIR model on a date axis and check date-input validation."""

    # Define parameters and initial states
    parameters = {"beta": 0.9, "gamma": 0.2}
    initial_states = {"S": [1_000_000 - 10], "I": [10], "R": [0]}
    # Build model
    model = SIR(initial_states, parameters)

    # Simulate using dates: date strings must be accepted without raising ...
    model.sim(['2020-01-01', '2020-02-20'])
    # ... and so must pandas Timestamps; this is the output validated below
    output = model.sim([pd.Timestamp('2020-01-01'), pd.Timestamp('2020-02-20')])

    # Validate
    # `in` works whether xarray's Dataset.dims is a mapping or a set of names
    assert 'date' in output.dims
    # 2020-01-01 .. 2020-02-20 inclusive is 31 + 20 = 51 days
    S = output["S"].values.squeeze()
    assert S[0] == 1_000_000 - 10
    assert S.shape == (51, )
    assert S[-1] < 12_000
    I = output["I"].values.squeeze()
    assert I[0] == 10
    # previously re-checked S.shape by mistake; I's shape was never validated
    assert I.shape == (51, )

    # Simulate using a mixture of timestamp and string: must be rejected
    with pytest.raises(TypeError, match="List-like input of simulation start"):
        model.sim(['2020-01-01', pd.Timestamp('2020-02-20')])

def test_model_init_validation():
    """Check that SIR construction validates initial states, parameters and the
    consistency of the model class definition itself."""
    # valid initialization
    parameters = {"beta": 0.9, "gamma": 0.2}
    initial_states = {"S": [1_000_000 - 10], "I": [10], "R": [0]}
    model = SIR(initial_states, parameters)
    assert model.initial_states == initial_states
    assert model.parameters == parameters
    # model state/parameter names didn't change
    assert model.state_names == ['S', 'I', 'R']
    assert model.parameter_names == ['beta', 'gamma']

    # wrong initial states
    initial_states2 = {"S": [1_000_000 - 10], "II": [10]}
    with pytest.raises(ValueError, match="specified initial states don't"):
        SIR(initial_states2, parameters)

    # wrong parameters
    parameters2 = {"beta": 0.9, "gamma": 0.2, "other": 1}
    with pytest.raises(ValueError, match="specified parameters don't"):
        SIR(initial_states, parameters2)

    # validate model class itself.
    # These checks mutate class attributes shared by every test and cell, so the
    # restore lives in `finally`: a failing check must not leave SIR broken for
    # all later constructions in this session.
    try:
        SIR.state_names = ["S", "R"]
        with pytest.raises(
            ValueError,
            match="The states in the 'compute_rates' function definition do not match the provided 'state_names'",
        ):
            SIR(initial_states, parameters)

        SIR.state_names = ["S", "II", "R"]
        with pytest.raises(ValueError, match="The states in the 'compute_rates' function definition"):
            SIR(initial_states, parameters)

        SIR.state_names = ["S", "I", "R"]
        SIR.parameter_names = ['beta', 'alpha']
        with pytest.raises(ValueError, match="The parameters in the 'compute_rates' function"):
            SIR(initial_states, parameters)
    finally:
        # ensure to set back to correct ones, even if a check above failed
        SIR.state_names = ["S", "I", "R"]
        SIR.parameter_names = ['beta', 'gamma']

In [7]:
# Run the init-validation test directly in the kernel, outside the pytest runner.
# NOTE(review): the ValueError recorded below appears to be stale output from an
# earlier kernel state (e.g. SIR's class attributes left modified by an
# interrupted run); Restart Kernel -> Run All before relying on it.
test_model_init_validation()

ValueError: The states in the 'compute_rates' function definition do not match the provided 'state_names': ['S', 'I'] vs ['S', 'R']