In [7]:
import numpy as np
import jax.numpy as jnp
import jax

In [45]:
def prob_survival_mother(period, options):
    """Compute the mother's survival probability from a quadratic-in-age logit.

    The mother's age is ``period + options["mother_min_age"]``. The logit index
    is ``constant + age_coef * age + age_squared_coef * age**2``, with the
    coefficients read from ``options``, and the survival probability ``p`` is
    its logistic transform ``1 / (1 + exp(-logit))``.

    Args:
        period (int or array-like): Model period.
        options (dict): Must contain ``"mother_min_age"``,
            ``"survival_probability_mother_constant"``,
            ``"survival_probability_mother_age"`` and
            ``"survival_probability_mother_age_squared"``.

    Returns:
        jnp.ndarray: ``[1 - p, p]``, i.e. the probabilities of dying and of
            surviving (the order of the ``mother_alive`` states 0 and 1 when
            used as that exogenous transition).

    NOTE(review): ``p`` does not depend on whether the mother is currently
    alive. Used as the ``mother_alive`` exogenous transition, this gives a
    currently dead mother probability ``p`` of being alive next period —
    confirm this is intended.

    """
    mother_age = period + options["mother_min_age"]

    logit = (
        options["survival_probability_mother_constant"]
        + options["survival_probability_mother_age"] * mother_age
        + options["survival_probability_mother_age_squared"] * (mother_age**2)
    )
    prob_survival = 1 / (1 + jnp.exp(-logit))

    return jnp.array([1 - prob_survival, prob_survival])


def prob_survival_father(period, options):
    """Compute the father's survival probability from a quadratic-in-age logit.

    The father's age is ``period + options["father_min_age"]``. The logit index
    is ``constant + age_coef * age + age_squared_coef * age**2``, with the
    coefficients read from ``options``, and the survival probability ``p`` is
    its logistic transform ``1 / (1 + exp(-logit))``.

    Args:
        period (int or array-like): Model period.
        options (dict): Must contain ``"father_min_age"``,
            ``"survival_probability_father_constant"``,
            ``"survival_probability_father_age"`` and
            ``"survival_probability_father_age_squared"``.

    Returns:
        jnp.ndarray: ``[1 - p, p]``, i.e. the probabilities of dying and of
            surviving (the order of the ``father_alive`` states 0 and 1 when
            used as that exogenous transition).

    NOTE(review): ``p`` does not depend on whether the father is currently
    alive. Used as the ``father_alive`` exogenous transition, this gives a
    currently dead father probability ``p`` of being alive next period —
    confirm this is intended.

    """
    father_age = period + options["father_min_age"]

    logit = (
        options["survival_probability_father_constant"]
        + options["survival_probability_father_age"] * father_age
        + options["survival_probability_father_age_squared"] * (father_age**2)
    )
    prob_survival = 1 / (1 + jnp.exp(-logit))

    return jnp.array([1 - prob_survival, prob_survival])

In [78]:
def prob_exog_care_demand(
    period,
    mother_alive,
    father_alive,
    options,
):
    """Compute the transition probabilities of the exogenous care-demand state.

    Care demand can only arise from a parent who is currently alive. The
    current ``mother_alive`` / ``father_alive`` indicators select one of three
    household types:

    - single mother (mother alive, father not alive),
    - single father (father alive, mother not alive),
    - couple (both alive).

    For each type, the care-demand probability is the type-specific logit
    probability (``_exog_care_demand_mother`` / ``_father`` / ``_couple``)
    multiplied by the survival probability of the parent(s) involved. If
    neither parent is alive, the care-demand probability is zero.

    NOTE(review): for a couple, only the case where *both* parents survive
    contributes. If exactly one parent survives, no care demand is assigned —
    confirm this is intended.

    Args:
        period (int): Model period. Parental ages are ``period`` plus
            ``options["mother_min_age"]`` / ``options["father_min_age"]``.
        mother_alive (int): 1 if the mother is currently alive, else 0.
        father_alive (int): 1 if the father is currently alive, else 0.
        options (dict): Model parameters holding the parents' minimum ages and
            the survival and care-demand logit coefficients.

    Returns:
        jnp.ndarray: Array of shape (2,) representing the probabilities of
            no care demand and care demand, respectively.

    """
    # Survival vectors are [P(dies), P(survives)]. After broadcasting with the
    # scalar care probabilities below, index [1] picks the survival component.
    mother_survival_prob = prob_survival_mother(period, options)
    father_survival_prob = prob_survival_father(period, options)

    # Type-specific (scalar) care-demand probabilities.
    care_prob_single_mother = jnp.array(
        _exog_care_demand_mother(period=period, options=options),
    )
    care_prob_single_father = jnp.array(
        _exog_care_demand_father(period=period, options=options),
    )
    care_prob_couple = jnp.array(
        _exog_care_demand_couple(period=period, options=options),
    )

    # With 0/1 indicators, the household-type weights are mutually exclusive,
    # so at most one of the three terms below is non-zero.
    mother_single_prob_care_demand = (
        mother_survival_prob * mother_alive * (1 - father_alive)
    ) * care_prob_single_mother

    father_single_prob_care_demand = (
        father_survival_prob * father_alive * (1 - mother_alive)
    ) * care_prob_single_father

    couple_prob_care_demand = (
        father_survival_prob * father_alive * mother_survival_prob * mother_alive
    ) * care_prob_couple

    prob_care_demand = (
        mother_single_prob_care_demand[1]
        + father_single_prob_care_demand[1]
        + couple_prob_care_demand[1]
    )

    return jnp.array([1 - prob_care_demand, prob_care_demand])



In [79]:
def _exog_care_demand_mother(period, options):
    """Compute scalar care demand probability.

    Returns:
        float: Probability of needing care given health state.

    """
    mother_age = period + options["mother_min_age"]

    logit = (
        options["exog_care_single_mother_constant"]
        + options["exog_care_single_mother_age"] * mother_age
        + options["exog_care_single_mother_age_squared"] * (mother_age**2)
    )
    return 1 / (1 + np.exp(-logit))


def _exog_care_demand_father(period, options):
    """Compute scalar care demand probability.

    Returns:
        float: Probability of needing care given health state.

    """
    father_age = period + options["father_min_age"]

    logit = (
        options["exog_care_single_father_constant"]
        + options["exog_care_single_father_age"] * father_age
        + options["exog_care_single_father_age_squared"] * (father_age**2)
    )
    return 1 / (1 + np.exp(-logit))


def _exog_care_demand_couple(period, options):
    """Compute scalar care demand probability.

    Returns:
        float: Probability of needing care given health state.

    """
    mother_age = period + options["mother_min_age"]
    father_age = period + options["father_min_age"]

    logit = (
        options["exog_care_couple_constant"]
        + options["exog_care_couple_mother_age"] * mother_age
        + options["exog_care_couple_mother_age_squared"] * (mother_age**2)
        + options["exog_care_couple_father_age"] * father_age
        + options["exog_care_couple_father_age_squared"] * (father_age**2)
    )
    return 1 / (1 + np.exp(-logit))

In [80]:
# Model parameters: solution settings, monetary amounts, and the coefficients of
# the exogenous processes (parental survival, health and care demand).
model_params = {
    "quadrature_points_stochastic": 5,
    "n_choices": 12,
    # "min_age": MIN_AGE,
    # "max_age": MAX_AGE,
    # Parental age in period 0: the transition functions use age = period + *_min_age.
    "mother_min_age": 70,
    "father_min_age": 70,
    # annual
    # (amounts below are written as a monthly value times 12; each trailing
    # comment shows the weighted sum the monthly value was computed from)
    "consumption_floor": 400 * 12,
    "unemployment_benefits": 500 * 12,
    "informal_care_benefits": 444.0466
    * 12,  # 0.4239 * 316 + 0.2793 * 545 + 728 *0.1405 + 901 * 0.0617
    "formal_care_costs": 118.10658099999999
    * 12,  # >>> 79.31 * 0.0944 + 0.4239 * 70.77 + 0.2793 * 176.16 + 224.26 *0.1401
    "interest_rate": 0.04,  # Adda et al (2017)
    # ===================
    # EXOGENOUS PROCESSES
    # ===================
    # survival probability
    # Logit coefficients (constant, age, age^2) read by prob_survival_mother /
    # prob_survival_father.
    "survival_probability_mother_constant": 17.01934835131644,
    "survival_probability_mother_age": -0.21245937682111807,
    "survival_probability_mother_age_squared": 0.00047537366767865137,
    "survival_probability_father_constant": 11.561515476144223,
    "survival_probability_father_age": -0.11058331994203506,
    "survival_probability_father_age_squared": -1.0998977981246952e-05,
    # health
    # NOTE(review): these health coefficients are not read by any function
    # visible here (the health transitions are commented out in the
    # state-space options) — confirm where they are consumed.
    "mother_medium_health": {
        "medium_health_age": 0.0304,
        "medium_health_age_squared": -1.31e-05,
        "medium_health_lagged_good_health": -1.155,
        "medium_health_lagged_medium_health": 0.736,
        "medium_health_lagged_bad_health": 1.434,
        "medium_health_constant": -1.550,
    },
    "mother_bad_health": {
        "bad_health_age": 0.196,
        "bad_health_age_squared": -0.000885,
        "bad_health_lagged_good_health": -2.558,
        "bad_health_lagged_medium_health": -0.109,
        "bad_health_lagged_bad_health": 2.663,
        "bad_health_constant": -9.220,
    },
    "father_medium_health": {
        "medium_health_age": 0.176,
        "medium_health_age_squared": -0.000968,
        "medium_health_lagged_good_health": -1.047,
        "medium_health_lagged_medium_health": 1.016,
        "medium_health_lagged_bad_health": 1.743,
        "medium_health_constant": -7.374,
    },
    "father_bad_health": {
        "bad_health_age": 0.260,
        "bad_health_age_squared": -0.00134,
        "bad_health_lagged_good_health": -2.472,
        "bad_health_lagged_medium_health": 0.115,
        "bad_health_lagged_bad_health": 3.067,
        "bad_health_constant": -11.89,
    },
    # TODO: care demand
    # Logit coefficients read by _exog_care_demand_mother / _father / _couple.
    "exog_care_single_mother_constant": 22.322551,
    "exog_care_single_mother_age": -0.661611,
    "exog_care_single_mother_age_squared": 0.004840,
    "exog_care_single_father_constant": 16.950484,
    "exog_care_single_father_age": -0.541042,
    "exog_care_single_father_age_squared": 0.004136,
    "exog_care_couple_constant": 22.518664,
    "exog_care_couple_mother_age": -0.622648,
    "exog_care_couple_mother_age_squared": 0.004346,
    "exog_care_couple_father_age": -0.068347,
    "exog_care_couple_father_age_squared": 0.000769,
}

In [81]:
# Sanity check: survival and single-parent care-demand probabilities at period 20.
period = 20
mother_survival_prob = prob_survival_mother(period, options=model_params)
father_survival_prob = prob_survival_father(period, options=model_params)
(
    _exog_care_demand_father(period=period, options=model_params),
    _exog_care_demand_mother(period=period, options=model_params),
)

(0.852997120909907, 0.8788474672693936)

In [82]:
(father_survival_prob[1], mother_survival_prob[1])  # index 1 = P(survive)

(Array(0.82052743, dtype=float32), Array(0.8517675, dtype=float32))

In [83]:
_exog_care_demand_couple(options=model_params, period=period)  # couple household type

0.8532865426853096

In [86]:
prob_exog_care_demand(options=model_params, period=10, father_alive=1, mother_alive=0)  # only father alive

Array([0.5011158, 0.4988842], dtype=float32)

In [87]:
prob_exog_care_demand(options=model_params, period=10, father_alive=0, mother_alive=0)  # no parent alive -> [1, 0]

Array([1., 0.], dtype=float32)

## State transition and feasible choice set

In [39]:
# State-space specification. Every exogenous process here is binary (states 0/1).
# NOTE(review): the recorded run of this cell raised
# ``NameError: name 'prob_part_time_offer' is not defined`` — the job-offer
# transition functions must be defined before this cell is executed.
options = {
    "state_space": {
        "n_periods": 20,
        "n_choices": 12,
        "choices": np.arange(12),
        "endogenous_states": {
            "married": np.arange(2),
            "has_sibling": np.arange(2),
        },
        "exogenous_processes": {
            process_name: {"states": np.arange(2), "transition": transition_func}
            for process_name, transition_func in (
                ("part_time_offer", prob_part_time_offer),
                ("full_time_offer", prob_full_time_offer),
                ("care_demand", prob_exog_care_demand),
                ("mother_alive", prob_survival_mother),
                ("father_alive", prob_survival_father),
                # Parental health (3 states) is disabled for now:
                # ("mother_health", exog_health_transition_mother),
                # ("father_health", exog_health_transition_father),
            )
        },
    },
}

NameError: name 'prob_part_time_offer' is not defined

In [40]:
def update_endog_state(
    period,
    choice,
    married,
    has_sibling,
    options,
):
    """Build the endogenous part of next period's state.

    The period advances by one, the current choice becomes the lagged choice,
    and ``married`` / ``has_sibling`` carry over unchanged. ``options`` is
    currently unused.

    Returns:
        dict: Keys ``"period"``, ``"lagged_choice"``, ``"married"`` and
            ``"has_sibling"``, in that order.

    """
    # Parental alive/health status comes from the exogenous states, so it is
    # not updated here. Parental ages were sketched as
    # options["<parent>_min_age"] + <parent>_age + 1 but are not tracked.
    return {
        "period": period + 1,
        "lagged_choice": choice,
        "married": married,
        "has_sibling": has_sibling,
    }


def get_state_specific_feasible_choice_set(
    part_time_offer,
    full_time_offer,
    mother_alive,
    father_alive,
    care_demand,
    options,
):
    """Return the choices that are feasible in the given state.

    No need to be JAX compatible (plain Python / NumPy).

    Starting from all ``options["n_choices"]`` choices, the set is restricted

    1. by care demand: only ``CARE`` choices if ``care_demand`` is truthy,
       otherwise only ``NO_CARE`` choices;
    2. by job offers:

       - part-time and full-time offer -> ``WORK`` choices,
       - part-time offer only          -> ``PART_TIME`` choices,
       - full-time offer only          -> ``FULL_TIME`` choices,
       - no offer                      -> ``NO_WORK`` choices.

    ``CARE``, ``NO_CARE``, ``WORK``, ``PART_TIME``, ``FULL_TIME`` and
    ``NO_WORK`` are module-level collections of choice codes that must be
    defined before this function is called.

    NOTE(review): whenever an offer exists, non-working choices are excluded.
    Confirm that declining an offer is not meant to be feasible.

    Args:
        part_time_offer (int): 1 if a part-time offer is available, else 0.
        full_time_offer (int): 1 if a full-time offer is available, else 0.
        mother_alive (int): Currently unused.
        father_alive (int): Currently unused.
        care_demand (int): 1 if a parent demands care, else 0.
        options (dict): Must contain ``"n_choices"``.

    Returns:
        np.ndarray: Feasible choice codes in ascending order.

    """
    feasible_choice_set = list(np.arange(options["n_choices"]))

    # Care demand: either only care choices or only no-care choices.
    care_filter = CARE if care_demand else NO_CARE
    feasible_choice_set = [i for i in feasible_choice_set if i in care_filter]

    # Job offers. The previous chain used ``|`` for the first test, which made
    # the part-time branch unreachable and mapped "no offer" to FULL_TIME.
    has_part_time_offer = bool(part_time_offer)
    has_full_time_offer = bool(full_time_offer)

    if has_part_time_offer and has_full_time_offer:
        work_filter = WORK
    elif has_part_time_offer:
        work_filter = PART_TIME
    elif has_full_time_offer:
        work_filter = FULL_TIME
    else:
        work_filter = NO_WORK
    feasible_choice_set = [i for i in feasible_choice_set if i in work_filter]

    return np.array(feasible_choice_set)

## Random number generation with JAX

In [8]:
def draw_random_array(seed, n_agents, values, probabilities):
    """Draw ``n_agents`` entries from ``values`` with the given probabilities.

    Draws are made with replacement via ``jax.random.choice``. They are
    deterministic in ``seed``: the same seed and inputs give the same array.

    Args:
        seed (int): Seed passed to ``jax.random.PRNGKey``.
        n_agents (int): Number of draws (length of the returned array).
        values (jnp.ndarray): 1-D array of candidate values.
        probabilities (jnp.ndarray): Probability of each entry of ``values``
            (same length as ``values``).

    Returns:
        jnp.ndarray: Array of shape ``(n_agents,)`` with entries from ``values``.

    Example:
        draws = draw_random_array(
            seed=2024,
            n_agents=10_000,
            values=jnp.array([-1, 0, 1, 2]),
            probabilities=jnp.array([0.3, 0.3, 0.2, 0.2]),
        )
        # Empirical shares should be close to ``probabilities``:
        pd.Series(np.asarray(draws)).value_counts(normalize=True)

    """
    key = jax.random.PRNGKey(seed)
    return jax.random.choice(key, values, shape=(n_agents,), p=probabilities)

In [28]:
# Simulation settings: base seed, number of agents, number of choices, max iterations.
seed, n_agents, n_choices, max_iter = 2024, 10_000, 12, 2_000

In [29]:
# Share of ones in a 0/1 draw with P(1) = 0.7; should be close to 0.7.
(
    draw_random_array(
        seed=seed - 1,
        n_agents=n_agents,
        values=jnp.array([0, 1]),
        probabilities=jnp.array([0.3, 0.7]),
    )
    .astype(np.int16)
    .mean()
)

Array(0.70379996, dtype=float32)

Array([[2747046042, 3946412319],
       [ 456906278, 1557099613],
       [2891578254, 2601165911],
       ...,
       [ 121031298, 1228053634],
       [1331106446, 1419707387],
       [3752188396, 4285775511]], dtype=uint32)

In [34]:
# Base PRNG key for the global seed, shown as the cell output.
key = jax.random.PRNGKey(seed=seed)
key

Array([   0, 2024], dtype=uint32)

In [38]:
iter_specific_keys = jax.random.split(key=jax.random.PRNGKey(seed), num=max_iter)  # one key per iteration

## Put it together: initial states and estimation criterion

In [None]:
def _draw_binary_state(seed_offset, probabilities):
    """Draw a 0/1 initial state for every agent (seed ``seed - seed_offset``), as int16."""
    return draw_random_array(
        seed=seed - seed_offset,
        n_agents=n_agents,
        values=jnp.array([0, 1]),
        probabilities=probabilities,
    ).astype(np.int16)


# Initial (period-0) states of all simulated agents. Each random state uses its
# own seed offset so the draws are independent of each other.
initial_states = {
    "period": jnp.zeros(n_agents, dtype=np.int16),
    "lagged_choice": draw_random_array(
        seed=seed - 1,
        n_agents=n_agents,
        values=jnp.arange(n_choices),
        probabilities=lagged_choice_probs,
    ).astype(np.int16),
    "married": _draw_binary_state(2, married),
    "has_sibling": _draw_binary_state(3, has_sibling),
    # exogenous states: both job offers start switched on, no initial care demand
    "part_time_offer": jnp.ones(n_agents, dtype=np.int16),
    "full_time_offer": jnp.ones(n_agents, dtype=np.int16),
    "care_demand": jnp.zeros(n_agents, dtype=np.int16),
    "mother_alive": _draw_binary_state(6, mother_alive),
    "father_alive": _draw_binary_state(7, father_alive),
}

In [None]:
col_weights = jnp.eye(len(emp_moments))


def criterion_solve_and_simulate(params, n_agents=10_000, seed=2024):
    """Solve the model, simulate it, and evaluate the method-of-moments criterion.

    Args:
        params (dict): Model parameters passed to the solver and the simulator.
        n_agents (int): Currently unused; the number of simulated agents is
            fixed by the module-level ``initial_states``.
        seed (int): Seed forwarded to ``simulate_all_periods``.

    Returns:
        dict: ``{"root_contributions": r, "value": v}`` where
            ``err = simulated moments - empirical moments``,
            ``v = err' W err`` with ``W = col_weights``, and
            ``r = err @ L`` with ``L`` the lower Cholesky factor of ``W``
            (``W = L @ L.T``), so that ``sum(r**2) == v``.

    """
    value, policy_left, policy_right, endog_grid = solve_func(params)

    result = simulate_all_periods(
        states_initial=initial_states,
        resources_initial=initial_resources,
        n_periods=options["state_space"]["n_periods"],
        params=params,
        state_space_names=state_space_names,
        seed=seed,
        endog_grid_solved=endog_grid,
        value_solved=value,
        policy_left_solved=policy_left,
        policy_right_solved=policy_right,
        map_state_choice_to_index=jnp.array(map_state_choice_to_index),
        choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=jnp.int16),
        compute_exog_transition_vec=model_funcs["compute_exog_transition_vec"],
        compute_utility=model_funcs["compute_utility"],
        compute_beginning_of_period_resources=model_funcs[
            "compute_beginning_of_period_resources"
        ],
        exog_state_mapping=exog_state_mapping,
        update_endog_state_by_state_and_choice=update_endog_state_by_state_and_choice,
        compute_utility_final_period=model_funcs["compute_utility_final"],
    )

    df = create_simulation_df(result, options=options, params=params)
    sim_moments = simulate_moments(df)

    err = sim_moments - emp_moments
    crit_val = np.dot(np.dot(err.T, col_weights), err)

    # Root contributions r = L' err satisfy r @ r == err' W err. The previous
    # code referenced undefined names (moms_model, moms_data, chol_weights).
    chol_weights = np.linalg.cholesky(np.asarray(col_weights))
    root_contribs = np.asarray(err) @ chol_weights

    return {"root_contributions": root_contribs, "value": crit_val}