In [1]:
import numpy as np
import jax.numpy as jnp

In [28]:
def prob_survival_mother(period, options):
    """Return the mother's death and survival probabilities for a period.

    The mother's age is ``period + options["mother_min_age"]``. A logit index
    that is quadratic in this age is passed through the logistic function to
    obtain the survival probability.

    Args:
        period: Model period; added to the minimum age to obtain the age.
        options (dict): Must provide the keys ``"mother_min_age"``,
            ``"survival_probability_mother_constant"``,
            ``"survival_probability_mother_age"`` and
            ``"survival_probability_mother_age_squared"``.

    Returns:
        jnp.ndarray: ``[1 - p, p]``, where ``p`` is the survival probability.

    """
    age = period + options["mother_min_age"]

    intercept = options["survival_probability_mother_constant"]
    linear_term = options["survival_probability_mother_age"] * age
    quadratic_term = options["survival_probability_mother_age_squared"] * (age**2)
    index = intercept + linear_term + quadratic_term

    # Logistic transformation of the index.
    alive = 1 / (1 + jnp.exp(-index))
    dead = 1 - alive

    return jnp.array([dead, alive])


def prob_survival_father(period, options):
    """Return the father's death and survival probabilities for a period.

    The father's age is ``period + options["father_min_age"]``. A logit index
    that is quadratic in this age is passed through the logistic function to
    obtain the survival probability.

    Args:
        period: Model period; added to the minimum age to obtain the age.
        options (dict): Must provide the keys ``"father_min_age"``,
            ``"survival_probability_father_constant"``,
            ``"survival_probability_father_age"`` and
            ``"survival_probability_father_age_squared"``.

    Returns:
        jnp.ndarray: ``[1 - p, p]``, where ``p`` is the survival probability.

    """
    age = period + options["father_min_age"]

    intercept = options["survival_probability_father_constant"]
    linear_term = options["survival_probability_father_age"] * age
    quadratic_term = options["survival_probability_father_age_squared"] * (age**2)
    index = intercept + linear_term + quadratic_term

    # Logistic transformation of the index.
    alive = 1 / (1 + jnp.exp(-index))
    dead = 1 - alive

    return jnp.array([dead, alive])

In [29]:
def prob_exog_care_demand(
    period,
    mother_alive,
    mother_health,
    father_alive,
    father_health,
    options,
):
    """Compute the probabilities of no care demand and of care demand.

    Three mutually exclusive household constellations are considered: only
    the mother is alive, only the father is alive, or both parents are alive.
    For each one, the constellation-specific care demand probability is
    weighted by the alive indicator(s) and the survival probability of the
    parent(s) involved. If no parent is alive, the care demand probability
    is zero.

    Args:
        period: Model period; parental ages are derived from it inside the
            helper functions.
        mother_alive: Indicator (0 or 1) of whether the mother is alive.
        mother_health: Currently unused; kept for interface compatibility.
        father_alive: Indicator (0 or 1) of whether the father is alive.
        father_health: Currently unused; kept for interface compatibility.
        options (dict): Parameters read by the survival and care demand
            helper functions.

    Returns:
        jnp.ndarray: ``[1 - p, p]``, where ``p`` is the probability of care
            demand. Shape (2,) for scalar inputs.

    """
    # The survival functions return [death, survival]; only the survival
    # component (index 1) is a valid weight. Using the full vector would
    # produce a (2, 2) result instead of the documented shape (2,).
    mother_survival_prob = prob_survival_mother(period, options)[1]
    father_survival_prob = prob_survival_father(period, options)[1]

    # Care demand probability conditional on the household constellation.
    prob_care_single_mother = jnp.array(
        _exog_care_demand_mother(period=period, options=options),
    )
    prob_care_single_father = jnp.array(
        _exog_care_demand_father(period=period, options=options),
    )
    prob_care_couple = jnp.array(
        _exog_care_demand_couple(period=period, options=options),
    )

    # Non-zero probability of care demand only if a parent is alive,
    # weighted by the survival probability of the parent(s).
    mother_single_prob_care_demand = (
        mother_survival_prob * mother_alive * (1 - father_alive)
    ) * prob_care_single_mother

    father_single_prob_care_demand = (
        father_survival_prob * father_alive * (1 - mother_alive)
    ) * prob_care_single_father

    couple_prob_care_demand = (
        father_survival_prob * father_alive * mother_survival_prob * mother_alive
    ) * prob_care_couple

    prob_care_demand = (
        mother_single_prob_care_demand
        + father_single_prob_care_demand
        + couple_prob_care_demand
    )

    return jnp.array([1 - prob_care_demand, prob_care_demand])


#

In [30]:
def _exog_care_demand_mother(period, options):
    """Compute scalar care demand probability.

    Returns:
        float: Probability of needing care given health state.

    """
    mother_age = period + options["mother_min_age"]

    logit = (
        options["exog_care_single_mother_constant"]
        + options["exog_care_single_mother_age"] * mother_age
        + options["exog_care_single_mother_age_squared"] * (mother_age**2)
    )
    return 1 / (1 + np.exp(-logit))


def _exog_care_demand_father(period, options):
    """Compute scalar care demand probability.

    Returns:
        float: Probability of needing care given health state.

    """
    father_age = period + options["father_min_age"]

    logit = (
        options["exog_care_single_father_constant"]
        + options["exog_care_single_father_age"] * father_age
        + options["exog_care_single_father_age_squared"] * (father_age**2)
    )
    return 1 / (1 + np.exp(-logit))


def _exog_care_demand_couple(period, options):
    """Compute scalar care demand probability.

    Returns:
        float: Probability of needing care given health state.

    """
    mother_age = period + options["mother_min_age"]
    father_age = period + options["father_min_age"]

    logit = (
        options["exog_care_couple_constant"]
        + options["exog_care_couple_mother_age"] * mother_age
        + options["exog_care_couple_mother_age_squared"] * (mother_age**2)
        + options["exog_care_couple_father_age"] * father_age
        + options["exog_care_couple_father_age_squared"] * (father_age**2)
    )
    return 1 / (1 + np.exp(-logit))

In [35]:
# General model settings and monetary quantities (monthly amounts times 12).
# The keys "min_age" and "max_age" are intentionally not set here.
_general_params = {
    "quadrature_points_stochastic": 5,
    "n_choices": 12,
    "mother_min_age": 70,
    "father_min_age": 70,
    "consumption_floor": 400 * 12,
    "unemployment_benefits": 500 * 12,
    # monthly: 0.4239 * 316 + 0.2793 * 545 + 728 * 0.1405 + 901 * 0.0617
    "informal_care_benefits": 444.0466 * 12,
    # monthly: 79.31 * 0.0944 + 0.4239 * 70.77 + 0.2793 * 176.16 + 224.26 * 0.1401
    "formal_care_costs": 118.10658099999999 * 12,
    # Adda et al (2017)
    "interest_rate": 0.04,
}

# Logit coefficients of the parental survival probabilities.
_survival_params = {
    "survival_probability_mother_constant": 17.01934835131644,
    "survival_probability_mother_age": -0.21245937682111807,
    "survival_probability_mother_age_squared": 0.00047537366767865137,
    "survival_probability_father_constant": 11.561515476144223,
    "survival_probability_father_age": -0.11058331994203506,
    "survival_probability_father_age_squared": -1.0998977981246952e-05,
}

# Coefficients of the parental health processes, one nested dict per
# parent and health state.
_health_params = {
    "mother_medium_health": {
        "medium_health_age": 0.0304,
        "medium_health_age_squared": -1.31e-05,
        "medium_health_lagged_good_health": -1.155,
        "medium_health_lagged_medium_health": 0.736,
        "medium_health_lagged_bad_health": 1.434,
        "medium_health_constant": -1.550,
    },
    "mother_bad_health": {
        "bad_health_age": 0.196,
        "bad_health_age_squared": -0.000885,
        "bad_health_lagged_good_health": -2.558,
        "bad_health_lagged_medium_health": -0.109,
        "bad_health_lagged_bad_health": 2.663,
        "bad_health_constant": -9.220,
    },
    "father_medium_health": {
        "medium_health_age": 0.176,
        "medium_health_age_squared": -0.000968,
        "medium_health_lagged_good_health": -1.047,
        "medium_health_lagged_medium_health": 1.016,
        "medium_health_lagged_bad_health": 1.743,
        "medium_health_constant": -7.374,
    },
    "father_bad_health": {
        "bad_health_age": 0.260,
        "bad_health_age_squared": -0.00134,
        "bad_health_lagged_good_health": -2.472,
        "bad_health_lagged_medium_health": 0.115,
        "bad_health_lagged_bad_health": 3.067,
        "bad_health_constant": -11.89,
    },
}

# TODO: care demand
# Logit coefficients of the exogenous care demand, by household constellation.
_care_demand_params = {
    "exog_care_single_mother_constant": 22.322551,
    "exog_care_single_mother_age": -0.661611,
    "exog_care_single_mother_age_squared": 0.004840,
    "exog_care_single_father_constant": 16.950484,
    "exog_care_single_father_age": -0.541042,
    "exog_care_single_father_age_squared": 0.004136,
    "exog_care_couple_constant": 22.518664,
    "exog_care_couple_mother_age": -0.622648,
    "exog_care_couple_mother_age_squared": 0.004346,
    "exog_care_couple_father_age": -0.068347,
    "exog_care_couple_father_age_squared": 0.000769,
}

# Flat parameter dictionary; the merge keeps the section order above.
model_params = {
    **_general_params,
    **_survival_params,
    **_health_params,
    **_care_demand_params,
}

In [36]:
# Smoke test: evaluate the survival vectors and the single-parent care demand
# probabilities at one period, using the parameters in `model_params`.
period = 20
mother_survival_prob = prob_survival_mother(period, options=model_params)
father_survival_prob = prob_survival_father(period, options=model_params)
(
    _exog_care_demand_father(period=period, options=model_params),
    _exog_care_demand_mother(period=period, options=model_params),
)

(0.852997120909907, 0.8788474672693936)

In [37]:
# Survival components (index 1) of the father's and mother's transition vectors.
father_survival_prob[1], mother_survival_prob[1]

(Array(0.82052743, dtype=float32), Array(0.8517675, dtype=float32))

In [38]:
# Care demand probability from the couple specification at the period set above.
_exog_care_demand_couple(period=period, options=model_params)

0.8532865426853096

## State transition and feasible choice set

In [39]:
# Specification of the state space: periods, choices, endogenous states and
# exogenous processes. All states listed here are binary (values 0 and 1).
# NOTE(review): `prob_part_time_offer` and `prob_full_time_offer` are not
# defined in the visible cells; they must exist before this cell is run.
options = {
    "state_space": {
        "n_periods": 20,
        "n_choices": 12,
        "choices": np.arange(12),
        "endogenous_states": {
            state: np.arange(2) for state in ("married", "has_sibling")
        },
        "exogenous_processes": {
            process: {"states": np.arange(2), "transition": transition}
            for process, transition in (
                ("part_time_offer", prob_part_time_offer),
                ("full_time_offer", prob_full_time_offer),
                ("care_demand", prob_exog_care_demand),
                ("mother_alive", prob_survival_mother),
                ("father_alive", prob_survival_father),
                # Disabled for now: three-state health processes
                # ("mother_health", exog_health_transition_mother),
                # ("father_health", exog_health_transition_father),
            )
        },
    },
}

NameError: name 'prob_part_time_offer' is not defined

In [40]:
def update_endog_state(
    period,
    choice,
    married,
    has_sibling,
    options,
):
    """Build the endogenous part of next period's state.

    Args:
        period: Current period; the next state holds ``period + 1``.
        choice: Current choice; stored as ``"lagged_choice"``.
        married: Marital status indicator; carried over unchanged.
        has_sibling: Sibling indicator; carried over unchanged.
        options: Not used by this function.

    Returns:
        dict: Next state with the keys ``"period"``, ``"lagged_choice"``,
            ``"married"`` and ``"has_sibling"``.

    """
    # Parental ages are not tracked here; whether a parent is alive and the
    # parent's health are meant to come from the exogenous states.
    return {
        "period": period + 1,
        "lagged_choice": choice,
        "married": married,
        "has_sibling": has_sibling,
    }


def get_state_specific_feasible_choice_set(
    part_time_offer,
    full_time_offer,
    mother_alive,
    father_alive,
    care_demand,
    options,
):
    """Return the choices that are feasible in the given state.

    Starting from all ``options["n_choices"]`` choices, the set is first
    restricted by care demand and then by the available job offers:

    - care demand: choices in ``CARE``; no care demand: choices in ``NO_CARE``
    - both offers: choices in ``WORK``
    - only a part-time offer: choices in ``PART_TIME``
    - only a full-time offer: choices in ``FULL_TIME``
    - no offer: choices in ``NO_WORK``

    The choice groups are module-level constants defined outside this
    function. No need to be jax compatible.

    Args:
        part_time_offer: Truthy if a part-time job offer is available.
        full_time_offer: Truthy if a full-time job offer is available.
        mother_alive: Currently unused; kept for interface compatibility.
        father_alive: Currently unused; kept for interface compatibility.
        care_demand: Truthy if there is care demand.
        options (dict): Must provide the key ``"n_choices"``.

    Returns:
        np.ndarray: Feasible choices in ascending order.

    """
    all_choices = list(np.arange(options["n_choices"]))

    # Care demand restricts the set to caregiving choices and vice versa.
    care_group = CARE if care_demand else NO_CARE

    # Job offers: the four offer combinations map to four distinct groups.
    has_part_time_offer = bool(part_time_offer)
    has_full_time_offer = bool(full_time_offer)

    if has_full_time_offer and has_part_time_offer:
        work_group = WORK
    elif has_part_time_offer:
        work_group = PART_TIME
    elif has_full_time_offer:
        work_group = FULL_TIME
    else:
        # NOTE(review): with an offer present, non-working choices are only
        # feasible if the respective group contains them -- confirm.
        work_group = NO_WORK

    feasible_choice_set = [
        choice
        for choice in all_choices
        if choice in care_group and choice in work_group
    ]

    return np.array(feasible_choice_set)