In [3]:
from itertools import product
from typing import Any
from typing import Dict
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

In [33]:
# Bounds of the parental age grid.
PARENT_MIN_AGE = 68
PARENT_MAX_AGE = 98

RETIREMENT_AGE = 62

# Parental health codes. They must be distinct; their values follow the order
# of the health transition probabilities (index 0 = good, 1 = medium,
# 2 = bad) and the three-state health grid ``np.arange(3)``.
GOOD_HEALTH = 0
MEDIUM_HEALTH = 1
BAD_HEALTH = 2

In [35]:
NO_WORK = [0, 1, 2, 3]
PART_TIME = [4, 5, 6, 7]
FULL_TIME = [8, 9, 10, 11]
WORK = PART_TIME + FULL_TIME

NO_CARE = [0, 4, 8]
FORMAL_CARE = [1, 3, 5, 7, 9, 11]  # % 2 == 1
INFORMAL_CARE = [2, 3, 6, 7, 10, 11]

TOTAL_WEEKLY_HOURS = 80
WEEKLY_INTENSIVE_INFORMAL_HOURS = 14  # (21 + 7) / 2

In [32]:
def _belongs_to(choice, choice_group):
    """Return True if ``choice`` is one of the codes in ``choice_group``."""
    return choice in choice_group


def is_not_working(lagged_choice):
    """Return True if the (lagged) choice involves no paid work."""
    return _belongs_to(lagged_choice, NO_WORK)


def is_part_time(lagged_choice):
    """Return True if the (lagged) choice involves part-time work."""
    return _belongs_to(lagged_choice, PART_TIME)


def is_full_time(lagged_choice):
    """Return True if the (lagged) choice involves full-time work."""
    return _belongs_to(lagged_choice, FULL_TIME)


def is_formal_care(lagged_choice):
    """Return True if the (lagged) choice involves formal care."""
    return _belongs_to(lagged_choice, FORMAL_CARE)


def is_informal_care(lagged_choice):
    """Return True if the (lagged) choice involves informal care.

    Only intensive informal care is part of the choice set.
    """
    return _belongs_to(lagged_choice, INFORMAL_CARE)

# 0. Exogenous Processes

In [None]:
def probability_full_time_offer(period, lagged_choice, options, params):
    """Compute the probability of receiving a full-time job offer.

    An individual who worked full time in the previous period keeps the offer
    with probability one. For everyone else the probability is a logit in the
    lagged labor status and an indicator for having reached
    ``RETIREMENT_AGE`` (age = ``period + options["min_age"]``).

    Args:
        period (int): Current period.
        lagged_choice (int): Choice code of the previous period.
        options (dict): Must contain ``"min_age"``.
        params (dict): Must contain ``full_time_constant``,
            ``full_time_not_working_last_period``,
            ``full_time_working_part_time_last_period`` and
            ``full_time_above_retirement_age``.

    Returns:
        Probability of a full-time offer (1 if already working full time).

    """
    reached_retirement_age = period + options["min_age"] >= RETIREMENT_AGE

    # Linear index of the logit, accumulated term by term.
    logit = params["full_time_constant"]
    logit = logit + params["full_time_not_working_last_period"] * is_not_working(
        lagged_choice
    )
    logit = logit + params["full_time_working_part_time_last_period"] * is_part_time(
        lagged_choice
    )
    logit = logit + params["full_time_above_retirement_age"] * reached_retirement_age
    # TODO: education term params["full_time_high_education"] * high_educ is disabled.

    prob_logit = 1 / (1 + jnp.exp(-logit))

    already_full_time = is_full_time(lagged_choice)
    return already_full_time * 1 + (1 - already_full_time) * prob_logit


def probability_part_time_offer(period, lagged_choice, options, params):
    """Compute the probability of receiving a part-time job offer.

    An individual who worked part time in the previous period keeps the offer
    with probability one. For everyone else the probability is a logit in the
    lagged labor status and an indicator for having reached
    ``RETIREMENT_AGE`` (age = ``period + options["min_age"]``).

    Args:
        period (int): Current period.
        lagged_choice (int): Choice code of the previous period.
        options (dict): Must contain ``"min_age"``.
        params (dict): Must contain ``part_time_constant``,
            ``part_time_not_working_last_period``,
            ``part_time_working_part_time_last_period`` and
            ``part_time_above_retirement_age``.

    Returns:
        Probability of a part-time offer (1 if already working part time).

    """
    reached_retirement_age = period + options["min_age"] >= RETIREMENT_AGE

    # Linear index of the logit, accumulated term by term.
    logit = params["part_time_constant"]
    logit = logit + params["part_time_not_working_last_period"] * is_not_working(
        lagged_choice
    )
    logit = logit + params["part_time_working_part_time_last_period"] * is_part_time(
        lagged_choice
    )
    logit = logit + params["part_time_above_retirement_age"] * reached_retirement_age
    # TODO: education term params["part_time_high_education"] * high_educ is disabled.

    prob_logit = 1 / (1 + jnp.exp(-logit))

    already_part_time = is_part_time(lagged_choice)
    return already_part_time * 1 + (1 - already_part_time) * prob_logit

In [None]:
def prob_survival_mother(mother_age, params):
    """Compute the mother's survival probability from a logit model.

    The logit index is quadratic in age::

        logit = constant + age_coef * mother_age + age_squared_coef * mother_age**2

    Args:
        mother_age (int or array): Age of the mother. The coefficients were
            presumably estimated for ages >= 65 — TODO confirm.
        params (dict): Must contain ``survival_probability_mother_constant``,
            ``survival_probability_mother_age`` and
            ``survival_probability_mother_age_squared``.

    Returns:
        float or jnp.ndarray: Probability that the mother survives, in (0, 1).

    """
    logit = (
        params["survival_probability_mother_constant"]
        + params["survival_probability_mother_age"] * mother_age
        + params["survival_probability_mother_age_squared"] * (mother_age**2)
    )

    return 1 / (1 + jnp.exp(-logit))

In [None]:
def prob_survival_father(father_age, params):
    """Compute the father's survival probability from a logit model.

    The logit index is quadratic in age::

        logit = constant + age_coef * father_age + age_squared_coef * father_age**2

    Args:
        father_age (int or array): Age of the father. The coefficients were
            presumably estimated for ages >= 65 — TODO confirm.
        params (dict): Must contain ``survival_probability_father_constant``,
            ``survival_probability_father_age`` and
            ``survival_probability_father_age_squared``.

    Returns:
        float or jnp.ndarray: Probability that the father survives, in (0, 1).

    """
    logit = (
        params["survival_probability_father_constant"]
        + params["survival_probability_father_age"] * father_age
        + params["survival_probability_father_age_squared"] * (father_age**2)
    )

    return 1 / (1 + jnp.exp(-logit))

In [34]:
def exog_health_transition_mother(mother_age, mother_health, params):
    """Compute the mother's health transition probabilities.

    Multinomial logit model with three health states: good, medium, bad.
    Good health is the base category (linear index fixed at 0). The linear
    indices for medium and bad health depend on age, squared age and dummies
    for the current (lagged) health state. A softmax maps the three indices to
    transition probabilities.

    Args:
        mother_age (int): Current age of the mother.
        mother_health (int): Current health state, coded as ``GOOD_HEALTH``,
            ``MEDIUM_HEALTH`` or ``BAD_HEALTH``.
        params (dict): Must contain the nested dicts ``"mother_medium_health"``
            (``medium_health_*`` coefficients) and ``"mother_bad_health"``
            (``bad_health_*`` coefficients).

    Returns:
        jnp.ndarray: Array of shape (3,) representing the probabilities of
            transitioning to good, medium, and bad health states, respectively.

    """
    mother_age_squared = mother_age**2

    good_health = mother_health == GOOD_HEALTH
    medium_health = mother_health == MEDIUM_HEALTH
    bad_health = mother_health == BAD_HEALTH

    medium_params = params["mother_medium_health"]
    bad_params = params["mother_bad_health"]

    # Linear combination for medium health
    lc_medium_health = (
        medium_params["medium_health_age"] * mother_age
        + medium_params["medium_health_age_squared"] * mother_age_squared
        + medium_params["medium_health_lagged_good_health"] * good_health
        + medium_params["medium_health_lagged_medium_health"] * medium_health
        + medium_params["medium_health_lagged_bad_health"] * bad_health
        + medium_params["medium_health_constant"]
    )

    # Linear combination for bad health
    lc_bad_health = (
        bad_params["bad_health_age"] * mother_age
        + bad_params["bad_health_age_squared"] * mother_age_squared
        + bad_params["bad_health_lagged_good_health"] * good_health
        + bad_params["bad_health_lagged_medium_health"] * medium_health
        + bad_params["bad_health_lagged_bad_health"] * bad_health
        + bad_params["bad_health_constant"]
    )

    # Build the indices with jnp (np.array fails on traced JAX values);
    # jax.nn.softmax is numerically stable.
    linear_comb = jnp.array([0.0, lc_medium_health, lc_bad_health])
    return jax.nn.softmax(linear_comb)


def exog_health_transition_father(father_age, father_health, params):
    """Compute the father's health transition probabilities.

    Multinomial logit model with three health states: good, medium, bad.
    Good health is the base category (linear index fixed at 0). The linear
    indices for medium and bad health depend on age, squared age and dummies
    for the current (lagged) health state. A softmax maps the three indices to
    transition probabilities.

    Args:
        father_age (int): Current age of the father.
        father_health (int): Current health state, coded as ``GOOD_HEALTH``,
            ``MEDIUM_HEALTH`` or ``BAD_HEALTH``.
        params (dict): Must contain the nested dicts ``"father_medium_health"``
            (``medium_health_*`` coefficients) and ``"father_bad_health"``
            (``bad_health_*`` coefficients).

    Returns:
        jnp.ndarray: Array of shape (3,) representing the probabilities of
            transitioning to good, medium, and bad health states, respectively.

    """
    father_age_squared = father_age**2

    good_health = father_health == GOOD_HEALTH
    medium_health = father_health == MEDIUM_HEALTH
    bad_health = father_health == BAD_HEALTH

    medium_params = params["father_medium_health"]
    bad_params = params["father_bad_health"]

    # Linear combination for medium health
    lc_medium_health = (
        medium_params["medium_health_age"] * father_age
        + medium_params["medium_health_age_squared"] * father_age_squared
        + medium_params["medium_health_lagged_good_health"] * good_health
        + medium_params["medium_health_lagged_medium_health"] * medium_health
        + medium_params["medium_health_lagged_bad_health"] * bad_health
        + medium_params["medium_health_constant"]
    )

    # Linear combination for bad health
    lc_bad_health = (
        bad_params["bad_health_age"] * father_age
        + bad_params["bad_health_age_squared"] * father_age_squared
        + bad_params["bad_health_lagged_good_health"] * good_health
        + bad_params["bad_health_lagged_medium_health"] * medium_health
        + bad_params["bad_health_lagged_bad_health"] * bad_health
        + bad_params["bad_health_constant"]
    )

    # Build the indices with jnp (np.array fails on traced JAX values);
    # jax.nn.softmax is numerically stable.
    linear_comb = jnp.array([0.0, lc_medium_health, lc_bad_health])
    return jax.nn.softmax(linear_comb)


def _softmax(lc):
    """Compute the softmax of each element in an array of linear combinations.

    The softmax function is applied to an array of linear combination values (lc)
    to calculate the probabilities of each class in a multinomial logistic
    regression model.
    This function is typically used for multi-class classification problems.

    Args:
        lc (np.ndarray): An array of linear combination values. This can be a 1D array
            representing linear combinations for each class in a single data point,
            or a 2D array representing multiple data points.

    Returns:
        np.ndarray: An array of the same shape as `lc` where each value is transformed
            into the probability of the corresponding class, ensuring that the sum of
            probabilities across classes (for each data point if 2D) equals 1.

    Example:
    >>> lc = np.array([0, 1, 2])
    >>> softmax(lc)
    array([0.09003057, 0.24472847, 0.66524096])

    Note:
    - The function applies np.exp to each element in `lc` and then normalizes so that
      the sum of these exponentials is 1.
    - For numerical stability, the maximum value in each set of linear combinations
      is subtracted from each linear combination before exponentiation.

    """
    e_lc = np.exp(lc - np.max(lc))  # Subtract max for numerical stability
    return e_lc / e_lc.sum(axis=0)

In [None]:
def prob_care_demand_mother(
    mother_age,
    mother_alive,
    mother_health,
    options,
    params,
):
    """Compute the probability that the mother demands care.

    Nested exogenous transitions:

    1. The mother survives with probability ``prob_survival_mother``.
    2. Her health evolves according to ``exog_health_transition_mother``
       (depends on age and current health).
    3. Given the health state, care is demanded with the logit probability
       ``_exog_care_demand_mother``.

    Care demand is only possible if the mother is currently alive.

    Args:
        mother_age (int): Age of the mother.
        mother_alive (int): Binary indicator of whether the mother is alive.
        mother_health (int): Current health state (``GOOD_HEALTH``,
            ``MEDIUM_HEALTH`` or ``BAD_HEALTH``).
        options (dict): Model options. Currently unused; kept for interface
            compatibility.
        params (dict): Survival, health transition and care demand parameters.

    Returns:
        jnp.ndarray: Array of shape (2,) representing the probabilities of
            no care demand and care demand, respectively.

    """
    survival_prob = prob_survival_mother(mother_age, params)

    # Health transition probabilities, ordered [good, medium, bad].
    trans_probs_health = exog_health_transition_mother(
        mother_age,
        mother_health,
        params,
    )

    # Care-demand probability conditional on each health state. The order
    # [good, medium, bad] must match ``trans_probs_health`` so that the dot
    # product pairs matching states.
    prob_care_given_health = jnp.array(
        [
            _exog_care_demand_mother(
                mother_age=mother_age, mother_health=health_state, params=params
            )
            for health_state in (GOOD_HEALTH, MEDIUM_HEALTH, BAD_HEALTH)
        ],
    )

    # Non-zero probability of care demand only if the mother is alive,
    # weighted by her survival probability.
    prob_care_demand = (survival_prob * mother_alive) * (
        trans_probs_health @ prob_care_given_health
    )

    return jnp.array([1 - prob_care_demand, prob_care_demand])


def _exog_care_demand_mother(mother_age, mother_health, params):
    """Compute the mother's care demand probability given her health state.

    Logit in age, squared age and dummies for medium and bad health (good
    health is the base category).

    Args:
        mother_age (int): Age of the mother.
        mother_health (int): Health state (``GOOD_HEALTH``, ``MEDIUM_HEALTH``
            or ``BAD_HEALTH``).
        params (dict): Must contain the ``exog_care_mother_*`` coefficients.

    Returns:
        float: Probability of needing care given health state.

    """
    logit = (
        params["exog_care_mother_constant"]
        + params["exog_care_mother_age"] * mother_age
        + params["exog_care_mother_age_squared"] * (mother_age**2)
        + params["exog_care_mother_medium_health"] * (mother_health == MEDIUM_HEALTH)
        + params["exog_care_mother_bad_health"] * (mother_health == BAD_HEALTH)
    )
    # jnp.exp, as in the survival probabilities, keeps this JAX-traceable.
    return 1 / (1 + jnp.exp(-logit))

In [None]:
def prob_care_demand_father(
    father_age,
    father_alive,
    father_health,
    options,
    params,
):
    """Compute the probability that the father demands care.

    Nested exogenous transitions:

    1. The father survives with probability ``prob_survival_father``.
    2. His health evolves according to ``exog_health_transition_father``
       (depends on age and current health).
    3. Given the health state, care is demanded with the logit probability
       ``_exog_care_demand_father``.

    Care demand is only possible if the father is currently alive.

    Args:
        father_age (int): Age of the father.
        father_alive (int): Binary indicator of whether the father is alive.
        father_health (int): Current health state (``GOOD_HEALTH``,
            ``MEDIUM_HEALTH`` or ``BAD_HEALTH``).
        options (dict): Model options. Currently unused; kept for interface
            compatibility.
        params (dict): Survival, health transition and care demand parameters.

    Returns:
        jnp.ndarray: Array of shape (2,) representing the probabilities of
            no care demand and care demand, respectively.

    """
    survival_prob = prob_survival_father(father_age, params)

    # Health transition probabilities, ordered [good, medium, bad].
    trans_probs_health = exog_health_transition_father(
        father_age,
        father_health,
        params,
    )

    # Care-demand probability conditional on each health state. The order
    # [good, medium, bad] must match ``trans_probs_health`` so that the dot
    # product pairs matching states.
    prob_care_given_health = jnp.array(
        [
            _exog_care_demand_father(
                father_age=father_age, father_health=health_state, params=params
            )
            for health_state in (GOOD_HEALTH, MEDIUM_HEALTH, BAD_HEALTH)
        ],
    )

    # Non-zero probability of care demand only if the father is alive,
    # weighted by his survival probability.
    prob_care_demand = (survival_prob * father_alive) * (
        trans_probs_health @ prob_care_given_health
    )

    return jnp.array([1 - prob_care_demand, prob_care_demand])


def _exog_care_demand_father(father_age, father_health, params):
    """Compute the father's care demand probability given his health state.

    Logit in age, squared age and dummies for medium and bad health (good
    health is the base category).

    Args:
        father_age (int): Age of the father.
        father_health (int): Health state (``GOOD_HEALTH``, ``MEDIUM_HEALTH``
            or ``BAD_HEALTH``).
        params (dict): Must contain the ``exog_care_father_*`` coefficients.

    Returns:
        float: Probability of needing care given health state.

    """
    logit = (
        params["exog_care_father_constant"]
        + params["exog_care_father_age"] * father_age
        + params["exog_care_father_age_squared"] * (father_age**2)
        + params["exog_care_father_medium_health"] * (father_health == MEDIUM_HEALTH)
        + params["exog_care_father_bad_health"] * (father_health == BAD_HEALTH)
    )
    # jnp.exp, as in the survival probabilities, keeps this JAX-traceable.
    return 1 / (1 + jnp.exp(-logit))

# 1. State Space

In [24]:
model_params = {
    "quadrature_points_stochastic": 5,
    "min_age": 50,
    "max_age": 65,
    "consumption_floor": 400,
    "unemployment_benefits": 500,
    "informal_care_benefits": 444.0466,  # 0.4239 * 316 + 0.2793 * 545 + 728 *0.1405 + 901 * 0.0617
    "formal_care_costs": 118.10658099999999,  # 79.31 * 0.0944 + 0.4239 * 70.77 + 0.2793 * 176.16 + 224.26 *0.1401
    "interest_rate": 0.04,  # Adda et al (2017)
    "choices": np.arange(12),
    # ===================
    # EXOGENOUS PROCESSES
    # ===================
    # survival probability
    "survival_probability_mother_constant": 17.01934835131644,
    "survival_probability_mother_age": -0.21245937682111807,
    "survival_probability_mother_age_squared": 0.00047537366767865137,
    "survival_probability_father_constant": 11.561515476144223,
    "survival_probability_father_age": -0.11058331994203506,
    "survival_probability_father_age_squared": -1.0998977981246952e-05,
    # health: the top-level keys "<parent>_medium_health" / "<parent>_bad_health"
    # are the ones read by exog_health_transition_mother / _father.
    "mother_medium_health": {
        "medium_health_age": 0.0304,
        "medium_health_age_squared": -1.31e-05,
        "medium_health_lagged_good_health": -1.155,
        "medium_health_lagged_medium_health": 0.736,
        "medium_health_lagged_bad_health": 1.434,
        "medium_health_constant": -1.550,
    },
    "mother_bad_health": {
        "bad_health_age": 0.196,
        "bad_health_age_squared": -0.000885,
        "bad_health_lagged_good_health": -2.558,
        "bad_health_lagged_medium_health": -0.109,
        "bad_health_lagged_bad_health": 2.663,
        "bad_health_constant": -9.220,
    },
    "father_medium_health": {
        "medium_health_age": 0.176,
        "medium_health_age_squared": -0.000968,
        "medium_health_lagged_good_health": -1.047,
        "medium_health_lagged_medium_health": 1.016,
        "medium_health_lagged_bad_health": 1.743,
        "medium_health_constant": -7.374,
    },
    "father_bad_health": {
        "bad_health_age": 0.260,
        "bad_health_age_squared": -0.00134,
        "bad_health_lagged_good_health": -2.472,
        "bad_health_lagged_medium_health": 0.115,
        "bad_health_lagged_bad_health": 3.067,
        "bad_health_constant": -11.89,
    },
    # TODO: care demand
    # mother
    "exog_care_mother_constant": 0,
    "exog_care_mother_age": 0,
    "exog_care_mother_age_squared": 0,
    "exog_care_mother_medium_health": 0,
    "exog_care_mother_bad_health": 0,
    # father
    "exog_care_father_constant": 0,
    "exog_care_father_age": 0,
    "exog_care_father_age_squared": 0,
    "exog_care_father_medium_health": 0,
    "exog_care_father_bad_health": 0,
}

In [26]:
# Solver options: number of periods, choice set, endogenous state grids and
# exogenous processes.
# NOTE(review): every ``transition`` entry is ``jnp.array([0])`` — presumably
# a placeholder until the transition functions are wired in; confirm.
options = dict(
    state_space=dict(
        n_periods=30,
        choices=np.arange(12),
        endogenous_states=dict(
            married=np.arange(2),
            has_sibling=np.arange(2),
            mother_age=np.arange(start=PARENT_MIN_AGE, stop=PARENT_MAX_AGE),
            father_age=np.arange(start=PARENT_MIN_AGE, stop=PARENT_MAX_AGE),
            # mother_alive=np.arange(2),
            # father_alive=np.arange(2),
            # mother_health=np.arange(3),
            # father_health=np.arange(3),
        ),
        exogenous_processes=dict(
            prob_part_time_offer=dict(
                states=np.arange(2),
                transition=jnp.array([0]),
            ),
            prob_full_time_offer=dict(
                states=np.arange(2),
                transition=jnp.array([0]),
            ),
            prob_care_demand=dict(states=np.arange(2), transition=jnp.array([0])),
            prob_mother_alive=dict(states=np.arange(2), transition=jnp.array([0])),
            prob_father_alive=dict(
                states=np.arange(2),
                transition=jnp.array([0]),
            ),
            health_transition_mother=dict(
                states=np.arange(3),
                transition=jnp.array([0]),
            ),
            health_transition_father=dict(
                states=np.arange(3),
                transition=jnp.array([0]),
            ),
        ),
    ),
)

def sparsity_condtion(
    mother_age, father_age, mother_alive, father_alive, mother_health, father_health
):
    """Return False for impossible states, True otherwise.

    A deceased parent (``*_alive == 0``) must carry a negative sentinel
    (e.g. -1) for both age and health. A deceased parent with a non-negative
    age or health value is an invalid state.
    """
    mother_dead = mother_alive == 0
    father_dead = father_alive == 0

    # Checks are evaluated in the order: mother age, father age,
    # mother health, father health.
    invalid_state = (
        (mother_dead & (mother_age >= 0))
        or (father_dead & (father_age >= 0))
        or (mother_dead & (mother_health >= 0))
        or (father_dead & (father_health >= 0))
    )
    return not invalid_state

In [None]:
def update_endog_state(
    period,
    married,
    has_sister,
    mother_age,
    father_age,
    choice,
):
    """Compute next period's endogenous states.

    Time and both parents' ages advance by one, and the current choice
    becomes next period's lagged choice. ``married`` and ``has_sister`` are
    accepted but not carried into the returned dict.

    NOTE(review): the state space defines ``has_sibling``, not ``has_sister``;
    confirm which name the solver passes.
    """
    # TODO: alive status and health are to be driven by the exogenous states.
    return {
        "period": period + 1,
        "lagged_choice": choice,
        "mother_age": mother_age + 1,
        "father_age": father_age + 1,
    }

In [11]:
# Inspect the parental age grid: np.arange excludes ``stop``, so ages run 68..97.
np.arange(start=68, stop=98)

array([68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97])

# 2. Choices

### Labor choices (3)

- no work
- part-time
- full-time


### Caregiving choices (4)

- no informal + formal
- no informal + no formal (no care)
- intensive informal + no_formal
- intensive informal + formal

The outside care option (neither providing informal care nor organizing
formal care once care demand arises) captures that siblings, the healthier
parent, or others organize or provide care to the parent (BFischer, p. 13).

In [68]:
# Enumerate every (labor, caregiving) combination. Codes are assigned
# labor-major: code = len(caregiving) * labor_index + caregiving_index.
labor = ["no_work", "part_time", "full_time"]
caregiving = [
    "no_informal_no_formal",
    "no_informal_formal",
    "intensive_informal_no_formal",
    "intensive_informal_formal",
]

combinations_dict = {
    code: [labor_option, care_option]
    for code, (labor_option, care_option) in enumerate(product(labor, caregiving))
}

combinations_dict

{0: ['no_work', 'no_informal_no_formal'],
 1: ['no_work', 'no_informal_formal'],
 2: ['no_work', 'intensive_informal_no_formal'],
 3: ['no_work', 'intensive_informal_formal'],
 4: ['part_time', 'no_informal_no_formal'],
 5: ['part_time', 'no_informal_formal'],
 6: ['part_time', 'intensive_informal_no_formal'],
 7: ['part_time', 'intensive_informal_formal'],
 8: ['full_time', 'no_informal_no_formal'],
 9: ['full_time', 'no_informal_formal'],
 10: ['full_time', 'intensive_informal_no_formal'],
 11: ['full_time', 'intensive_informal_formal']}

In [69]:
# Choice combinations that include formal care (odd choice codes) ...
formal = [combinations_dict[code] for code in range(1, 12, 2)]

# ... and those that include intensive informal care.
intensive_informal = [combinations_dict[code] for code in (2, 3, 6, 7, 10, 11)]

In [70]:
# Display the combinations that include formal care.
formal

[['no_work', 'no_informal_formal'],
 ['no_work', 'intensive_informal_formal'],
 ['part_time', 'no_informal_formal'],
 ['part_time', 'intensive_informal_formal'],
 ['full_time', 'no_informal_formal'],
 ['full_time', 'intensive_informal_formal']]

In [71]:
# Display the combinations that include intensive informal care.
intensive_informal

[['no_work', 'intensive_informal_no_formal'],
 ['no_work', 'intensive_informal_formal'],
 ['part_time', 'intensive_informal_no_formal'],
 ['part_time', 'intensive_informal_formal'],
 ['full_time', 'intensive_informal_no_formal'],
 ['full_time', 'intensive_informal_formal']]

In [72]:
# Decode one example choice into its care and labor components.
# ``choice`` is set explicitly so the cell runs on a fresh kernel.
choice = 3  # no work + intensive informal + formal care

formal_care = choice % 2 == 1  # uneven numbers mark formal care
intensive_informal_care = choice in INFORMAL_CARE

no_work = choice in NO_WORK
part_time = choice in PART_TIME
full_time = choice in FULL_TIME

In [73]:
def test_choice(choice):
    """Decode ``choice`` into the flags (intensive informal care, formal care).

    Odd choice codes include formal care; codes listed in ``INFORMAL_CARE``
    include intensive informal care.
    """
    return (choice in INFORMAL_CARE, choice % 2 == 1)

In [74]:
# Print the (intensive informal care, formal care) flags for every choice code.
n_choices = 12

for choice in range(n_choices):
    print(test_choice(choice))

(False, False)
(False, True)
(True, False)
(True, True)
(False, False)
(False, True)
(True, False)
(True, True)
(False, False)
(False, True)
(True, False)
(True, True)


In [75]:
# Show the choice-code lookup table for reference.
combinations_dict

{0: ['no_work', 'no_informal_no_formal'],
 1: ['no_work', 'no_informal_formal'],
 2: ['no_work', 'intensive_informal_no_formal'],
 3: ['no_work', 'intensive_informal_formal'],
 4: ['part_time', 'no_informal_no_formal'],
 5: ['part_time', 'no_informal_formal'],
 6: ['part_time', 'intensive_informal_no_formal'],
 7: ['part_time', 'intensive_informal_formal'],
 8: ['full_time', 'no_informal_no_formal'],
 9: ['full_time', 'no_informal_formal'],
 10: ['full_time', 'intensive_informal_no_formal'],
 11: ['full_time', 'intensive_informal_formal']}

In [76]:
# Reference age for the age term in the leisure utility (``utility_func``).
# NOTE(review): duplicates model_params["min_age"] (50); keep both in sync.
MIN_AGE = 50

In [77]:
def utility_func(
    consumption: jnp.array, period, choice: int, options: dict, params: dict
) -> jnp.array:
    """Compute the agent's period utility.

    Utility is the sum of:
    - CRRA utility from consumption,
    - a disutility of part-time or full-time work,
    - log utility from leisure, with an age-dependent weight,
    - utility terms for providing intensive informal care, arranging formal
      care, and doing both.

    Args:
        consumption (jnp.array): Level of the agent's consumption.
            Array of shape (i) (n_quad_stochastic * n_grid_wealth,)
            when called by :func:`~dcgm.call_egm_step.map_exog_to_endog_grid`
            and :func:`~dcgm.call_egm_step.get_next_period_value`, or
            (ii) of shape (n_grid_wealth,) when called by
            :func:`~dcgm.call_egm_step.get_current_period_value`.
        period (int): Current period; age is ``period + options["min_age"]``.
        choice (int): Discrete choice code (0-11) combining labor supply and
            caregiving.
        options (dict): Must contain ``"min_age"``, ``"utility_informal_care"``,
            ``"utility_formal_care"`` and ``"utility_informal_and_formal_care"``.
        params (dict): Must contain the CRRA coefficient ``"theta"`` (theta == 1
            divides by zero in the CRRA formula), ``"utility_leisure_constant"``,
            ``"utility_leisure_age"``, ``"disutility_part_time"`` and
            ``"disutility_full_time"``.

    Returns:
        utility (jnp.array): Agent's utility, broadcast to the shape of
            ``consumption``.

    """
    theta = params["theta"]
    age = period + options["min_age"]

    intensive_informal_care = is_informal_care(choice)
    formal_care = is_formal_care(choice)
    part_time = is_part_time(choice)
    full_time = is_full_time(choice)

    working_hours = (
        part_time * WEEKLY_HOURS_PART_TIME + full_time * WEEKLY_HOURS_FULL_TIME
    )
    # From SOEP data we know that the 25% and 75% percentile in the care hours
    # distribution are 7 and 21 hours per week in a comparative sample.
    # We use these discrete mass-points as discrete choices of non-intensive and
    # intensive informal care.
    # In SHARE, respondents inform about the frequency with which they provide
    # informal care. We use this information to proxy the care provision in the data.
    caregiving_hours = intensive_informal_care * WEEKLY_INTENSIVE_INFORMAL_HOURS
    # Weekly hours scaled to monthly hours (4.33 weeks per month).
    leisure_hours = (TOTAL_WEEKLY_HOURS - working_hours - caregiving_hours) * 4.33

    utility_consumption = (consumption ** (1 - theta) - 1) / (1 - theta)

    # Age is a proxy for health impacting the taste for free time.
    # jnp.log keeps the whole function in jax.numpy.
    utility_leisure = (
        params["utility_leisure_constant"]
        + params["utility_leisure_age"] * (age - MIN_AGE)
    ) * jnp.log(leisure_hours)

    # NOTE(review): the care utility terms are read from ``options`` while all
    # other utility parameters come from ``params`` — confirm this is intended.
    return (
        utility_consumption
        - params["disutility_part_time"] * part_time
        - params["disutility_full_time"] * full_time
        + utility_leisure
        ## utility from caregiving
        + options["utility_informal_care"] * intensive_informal_care
        + options["utility_formal_care"] * formal_care
        + options["utility_informal_and_formal_care"]
        * (formal_care & intensive_informal_care)
    )

In [None]:
def marginal_utility(consumption, params):
    """Marginal CRRA utility: u'(c) = c ** (-theta)."""
    theta = params["theta"]
    return consumption ** (-theta)


def inverse_marginal_utility(marginal_utility, params):
    """Inverse of the CRRA marginal utility: c = u'(c) ** (-1 / theta)."""
    exponent = -1 / params["theta"]
    return marginal_utility ** exponent


# Utility-related functions handed to the solver, keyed by role.
utility_functions = dict(
    utility=utility_func,
    inverse_marginal_utility=inverse_marginal_utility,
    marginal_utility=marginal_utility,
)

In [78]:
# Scratch check of ``|`` on Python booleans. ``l`` is reassigned to False
# before use, so ``l | i`` is False | False == False.
l = True
i = False
l = False

a = l | i
a

False

In [79]:
# Show the choice-code lookup table for reference.
combinations_dict

{0: ['no_work', 'no_informal_no_formal'],
 1: ['no_work', 'no_informal_formal'],
 2: ['no_work', 'intensive_informal_no_formal'],
 3: ['no_work', 'intensive_informal_formal'],
 4: ['part_time', 'no_informal_no_formal'],
 5: ['part_time', 'no_informal_formal'],
 6: ['part_time', 'intensive_informal_no_formal'],
 7: ['part_time', 'intensive_informal_formal'],
 8: ['full_time', 'no_informal_no_formal'],
 9: ['full_time', 'no_informal_formal'],
 10: ['full_time', 'intensive_informal_no_formal'],
 11: ['full_time', 'intensive_informal_formal']}

In [80]:
# Choice 0: no work, no care.
combinations_dict[0]

['no_work', 'no_informal_no_formal']

In [81]:
# Choice 4: part time, no care.
combinations_dict[4]

['part_time', 'no_informal_no_formal']

In [82]:
# Choice 8: full time, no care.
combinations_dict[8]

['full_time', 'no_informal_no_formal']

In [83]:
# Split the 12 choice codes into no-care choices (0, 4, 8) and care choices.
no_care = [0, 4, 8]
all_choices = list(np.arange(12))
care = list(filter(lambda code: code not in no_care, all_choices))

In [84]:
# Display the care choices (every code except 0, 4 and 8).
care

[1, 2, 3, 5, 6, 7, 9, 10, 11]

In [87]:
def get_state_specific_feasible_choice_set(
    lagged_choice,
    mother_age,
    father_age,
    part_time_offer,
    full_time_offer,
    options,
    mother_alive=1,
    father_alive=1,
):
    """Return the choice codes that are feasible in the given state.

    Care:
        Caregiving choices (informal and/or formal care) are feasible only
        while at least one parent is alive. The no-care choices are always
        feasible. They act as the outside option in which siblings, the other
        parent or others provide care.

    Labor:
        Not working is always feasible. Part-time work requires a part-time
        offer, and full-time work requires a full-time offer.

    Args:
        lagged_choice (int): Previous period's choice. Currently unused.
        mother_age (int): Age of the mother. Currently unused.
        father_age (int): Age of the father. Currently unused.
        part_time_offer (int or bool): Whether a part-time offer is available.
        full_time_offer (int or bool): Whether a full-time offer is available.
        options (dict): Must contain ``"n_discrete_choices"``.
        mother_alive (int): 1 if the mother is alive, 0 otherwise. Defaults
            to 1, i.e. caregiving is not ruled out.
        father_alive (int): 1 if the father is alive, 0 otherwise. Defaults
            to 1.

    Returns:
        np.ndarray: Feasible choice codes in increasing order.

    """
    any_parent_alive = bool(mother_alive) or bool(father_alive)

    # Copy NO_WORK so the module-level constant is never mutated.
    feasible_labor = list(NO_WORK)
    if part_time_offer:
        feasible_labor += PART_TIME
    if full_time_offer:
        feasible_labor += FULL_TIME

    feasible_choice_set = [
        choice
        for choice in range(options["n_discrete_choices"])
        if choice in feasible_labor and (any_parent_alive or choice in NO_CARE)
    ]

    return np.array(feasible_choice_set)

In [88]:
# Scratch: all 12 choice codes before any feasibility filtering.
feasible_choice_set = list(np.arange(12))
feasible_choice_set

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

In [19]:
# Scratch flags for the boolean-condition checks below.
var = False
bar = True

In [20]:
# Prints because bar is True and var is False.
if bar and not var:
    print("success")

success


In [22]:
# Prints because var is False.
if not var:
    print("corr")

corr


In [23]:
# Python bools are ints, so 0 == False evaluates to True.
0 == False

True

In [89]:
# NOTE(review): overrides RETIREMENT_AGE = 62 from the constants cell. The
# job-offer probability functions read this global at call time, so they see
# whichever value was assigned last. Confirm the intended retirement age.
RETIREMENT_AGE = 65

In [90]:
# Decode the current ``choice`` (set in an earlier cell) into care and labor
# flags, using literal choice-code lists.
formal_care, intensive_informal_care = (
    choice % 2 == 1,  # uneven numbers mark formal care
    choice in [2, 3, 6, 7, 10, 11],
)

no_work, part_time, full_time = (
    choice in [0, 1, 2, 3],
    choice in [4, 5, 6, 7],
    choice in [8, 9, 10, 11],
)

In [91]:
# Scratch: sorted union of two informal-care code lists. They apparently come
# from an earlier, larger choice encoding with light and intensive informal
# care — TODO confirm or delete.
l = [2, 3, 8, 9, 14, 15, 20, 21] + [4, 5, 10, 11, 16, 17, 22, 23]
l.sort()
l

[2, 3, 4, 5, 8, 9, 10, 11, 14, 15, 16, 17, 20, 21, 22, 23]

In [92]:
# NOTE(review): these helpers repeat the definitions from the earlier helper
# cell; keep both copies in sync or drop one.
def is_not_working(lagged_choice):
    """Check whether the lagged choice is a no-work choice."""
    no_work_choices = NO_WORK
    return lagged_choice in no_work_choices


def is_part_time(lagged_choice):
    """Check whether the lagged choice is a part-time choice."""
    part_time_choices = PART_TIME
    return lagged_choice in part_time_choices


def is_full_time(lagged_choice):
    """Check whether the lagged choice is a full-time choice."""
    full_time_choices = FULL_TIME
    return lagged_choice in full_time_choices


def is_formal_care(lagged_choice):
    """Check whether the lagged choice includes formal care."""
    formal_care_choices = FORMAL_CARE
    return lagged_choice in formal_care_choices


def is_informal_care(lagged_choice):
    """Check whether the lagged choice includes (intensive) informal care."""
    informal_care_choices = INFORMAL_CARE
    return lagged_choice in informal_care_choices

In [93]:
# Show the choice-code lookup table for reference.
combinations_dict

{0: ['no_work', 'no_informal_no_formal'],
 1: ['no_work', 'no_informal_formal'],
 2: ['no_work', 'intensive_informal_no_formal'],
 3: ['no_work', 'intensive_informal_formal'],
 4: ['part_time', 'no_informal_no_formal'],
 5: ['part_time', 'no_informal_formal'],
 6: ['part_time', 'intensive_informal_no_formal'],
 7: ['part_time', 'intensive_informal_formal'],
 8: ['full_time', 'no_informal_no_formal'],
 9: ['full_time', 'no_informal_formal'],
 10: ['full_time', 'intensive_informal_no_formal'],
 11: ['full_time', 'intensive_informal_formal']}

In [94]:
# Sanity check: choice 0 (no work, no care) is not full-time.
is_full_time(0)

False

## Budget constraint

In [None]:
# NOTE(review): duplicates model_params["unemployment_benefits"] (500); the
# inline "400?" suggests the value is undecided — confirm which one is used.
non_working_benefits = 500 # 400?

In [None]:
from typing import Any, Dict


def budget_constraint(
    period: int,
    married: int,
    high_educ: int,
    lagged_choice: int,
    savings_end_of_previous_period: float,
    income_shock_previous_period: float,
    options: Dict[str, Any],
    params: Dict[str, float],
) -> float:
    """Compute beginning-of-period wealth from last period's choice and savings.

    Labor income is the wage returned by ``_calc_stochastic_wage`` times the
    monthly working hours implied by ``lagged_choice`` (20 or 40 weekly hours).
    The wage is presumably hourly, since it is multiplied by hours. On top of
    that:

    - unemployment benefits are added if the agent did not work,
    - informal-care benefits are added for informal care,
    - formal-care costs are subtracted for formal care,
    - savings earn ``options["interest_rate"]``.

    The result is floored at ``options["consumption_floor"]``.

    Args:
        period: Current model period; age is ``period + options["min_age"]``.
        married: Marital status. Currently unused (spousal income is disabled).
        high_educ: High-education indicator. Currently unused.
        lagged_choice: Index of last period's combined labor/care choice.
        savings_end_of_previous_period: Savings carried into this period.
        income_shock_previous_period: Realized log-wage shock of last period.
        options: Must contain "min_age", "unemployment_benefits",
            "informal_care_benefits", "formal_care_costs", "interest_rate"
            and "consumption_floor".
        params: Wage-equation coefficients, forwarded to
            ``_calc_stochastic_wage``.

    Returns:
        Wealth at the beginning of the current period.
    """
    # Weekly hours -> monthly hours.
    weeks_per_month = 4.33
    working_hours = (
        is_part_time(lagged_choice) * 20 * weeks_per_month
        + is_full_time(lagged_choice) * 40 * weeks_per_month
    )

    # _calc_stochastic_wage takes the whole params dict (it reads the
    # "wage_*" coefficients itself), not a single constant.
    wage_from_previous_period = _calc_stochastic_wage(
        period=period,
        lagged_choice=lagged_choice,
        wage_shock=income_shock_previous_period,
        min_age=options["min_age"],
        params=params,
    )

    wealth_beginning_of_period = (
        wage_from_previous_period * working_hours
        # + non_labor_income(age, high_educ, options)
        # + spousal_income(period, high_educ, options) * married
        + options["unemployment_benefits"] * is_not_working(lagged_choice)
        + options["informal_care_benefits"] * is_informal_care(lagged_choice)
        - options["formal_care_costs"] * is_formal_care(lagged_choice)
        + (1 + options["interest_rate"]) * savings_end_of_previous_period
    )

    # NOTE(review): it is unclear whether this floor is needed at all; it
    # guarantees wealth never falls below options["consumption_floor"].
    wealth_beginning_of_period = jnp.maximum(
        wealth_beginning_of_period, options["consumption_floor"]
    )

    return wealth_beginning_of_period

In [None]:
@jax.jit
def _calc_stochastic_wage(
    period: int,
    lagged_choice: int,
    wage_shock: float,
    min_age: int,
    params: Dict[str, float],
) -> float:
    """Compute the wage implied by a log-wage equation plus a stochastic shock.

    Income is paid at the end of a period, so the wage depends on the labor
    supply chosen in the previous period (``lagged_choice``). The log wage is
    quadratic in age, with level shifts for having worked part time or not at
    all in the previous period:

        log_wage = wage_constant
                   + wage_age * age + wage_age_squared * age**2
                   + wage_part_time * 1[part time]
                   + wage_not_working * 1[no work]

    with ``age = period + min_age``. The returned wage is
    ``exp(log_wage + wage_shock)``.

    Args:
        period: Current model period.
        lagged_choice: Index of last period's combined labor/care choice.
        wage_shock: Realized shock to the log wage, carried over from the
            previous period.
        min_age: Age in period 0.
        params: Must contain "wage_constant", "wage_age", "wage_age_squared",
            "wage_part_time" and "wage_not_working".

    Returns:
        The wage ``exp(log_wage + wage_shock)`` as a JAX scalar.
    """
    age = period + min_age

    # Membership is tested with jnp.isin rather than the Python-level helpers
    # (``lagged_choice in PART_TIME``). Under jax.jit ``lagged_choice`` is a
    # tracer, and Python's ``in`` would have to convert it to a bool, which
    # raises a tracer bool-conversion error.
    worked_part_time = jnp.isin(lagged_choice, jnp.asarray(PART_TIME))
    did_not_work = jnp.isin(lagged_choice, jnp.asarray(NO_WORK))

    log_wage = (
        params["wage_constant"]
        + params["wage_age"] * age
        + params["wage_age_squared"] * age**2
        # + params["wage_high_educ"] * high_educ
        + params["wage_part_time"] * worked_part_time
        + params["wage_not_working"] * did_not_work
    )

    return jnp.exp(log_wage + wage_shock)

### Non-labor income

In [67]:
def non_labor_income(period, age, high_educ, params, options):
    """Return the *log* of non-labor income for a given age and education.

    The log income is linear in the following terms:

    - a constant,
    - an above-retirement-age dummy (``age >= RETIREMENT_AGE``),
    - a high-education indicator,
    - age and age squared.

    All coefficients are read from ``options``. ``period`` and ``params`` are
    accepted but currently unused (a calendar-year fixed effect is disabled).

    NOTE(review): the result is a log. It presumably needs ``exp`` before being
    added to wealth as an income level (``_calc_stochastic_wage`` returns
    ``exp(log_wage)``) — confirm.
    """
    coeffs = options
    return (
        coeffs["non_labor_inc_constant"]
        + coeffs["non_labor_inc_above_retirement_age"] * (age >= RETIREMENT_AGE)
        + coeffs["non_labor_inc_high_educ"] * high_educ
        + coeffs["non_labor_inc_age"] * age
        + coeffs["non_labor_inc_age_squared"] * age**2
        # + coeffs["non_labor_year_fixed_effect"] * (period + 2004)
    )

In [68]:
# Coefficients of the log non-labor income equation.
# NOTE(review): "non_labor_married" is not read by non_labor_income.
options_non_labor = dict(
    non_labor_inc_constant=7.883,
    non_labor_inc_high_educ=0.752,
    non_labor_inc_age=-0.147,
    non_labor_inc_age_squared=0.00184,
    non_labor_inc_above_retirement_age=0.00264,
    non_labor_married=0.497,
    # non_labor_year_fixed_effect=0.0210,
)

### Spousal income

In [73]:
def spousal_income(period, age, high_educ, params, options):
    """Return the *log* of spousal income.

    The log income is linear in the following terms:

    - a constant,
    - an above-retirement-age dummy (``age >= RETIREMENT_AGE``),
    - a high-education indicator,
    - age and age squared,
    - a linear calendar-year trend.

    The calendar year is ``period + 2004``, so period 0 presumably corresponds
    to 2004. Coefficients are read from ``options``; ``params`` is unused.
    """
    coeffs = options
    return (
        coeffs["spousal_inc_constant"]
        + coeffs["spousal_inc_above_retirement_age"] * (age >= RETIREMENT_AGE)
        + coeffs["spousal_inc_high_educ"] * high_educ
        + coeffs["spousal_inc_age"] * age
        + coeffs["spousal_inc_age_squared"] * age**2
        + coeffs["spousal_labor_year_fixed_effect"] * (period + 2004)
    )

In [74]:
# Coefficients of the log spousal income equation.
options_spousal = dict(
    spousal_inc_constant=-35.27,
    spousal_inc_high_educ=0.358,
    spousal_inc_age=0.117,
    spousal_inc_age_squared=-0.00129,
    spousal_inc_above_retirement_age=-0.00548,
    spousal_labor_year_fixed_effect=0.0210,
)

In [78]:
spousal_income(options=options_spousal, params={}, high_educ=0, age=60, period=25)  # log income

9.714999999999996

In [72]:
non_labor_income(options=options_non_labor, params={}, high_educ=1, age=70, period=5)  # log income

49.552640000000004

In [79]:
# NOTE(review): 4.35 presumably is weeks per month, but budget_constraint uses
# 4.33 — confirm which conversion factor is intended.
1_800 / 4.35

413.7931034482759