In [2]:
%load_ext autoreload

%autoreload 2
import numpy as np
import numba
from scipy.stats import chi2
import numba
from robupy.auxiliary import get_worst_case_probs
from ruspy.estimation.estimation_cost_parameters import (
    lin_cost,
    cost_func,
    create_transition_matrix,
    calc_fixp
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


ModuleNotFoundError: No module named 'robupy'

In [2]:
# Model configuration shared by the cells below.
num_states = 50  # size of the state space
beta = 0.9999  # discount factor
params = np.array([10, 2.7])  # parameters passed to lin_cost through cost_func
costs = cost_func(num_states, lin_cost, params)
# Transition probabilities used to build the baseline transition matrix.
p_ml = np.array([0.39189189, 0.59529357, 0.01281454])
# Probability level at which the chi-squared quantile below is evaluated.
omega = 0.95
# NOTE(review): presumably the number of observations behind p_ml - confirm.
num_obs = 4292
# Chi-squared quantile with len(p_ml) - 1 degrees of freedom, scaled by 2 * num_obs.
rho = chi2.ppf(omega, len(p_ml) - 1) / (2 * num_obs)
# rho = 0

In [3]:
@numba.jit(nopython=True)
def calc_fixp_worst(num_states, p_ml, costs, beta, rho, threshold=1e-12, max_it=1000000):
    """
    Calculate the expected value fixed point when, in every iteration, each
    row of the transition matrix is replaced by the probabilities returned by
    ``get_worst_case_probs``.

    :param num_states:  The size of the state space.
    :type num_states:   int
    :param p_ml:        A one dimensional numpy array of transition
                        probabilities from which the baseline transition
                        matrix is built via ``create_transition_matrix``.
    :param costs:       A two dimensional float numpy array containing for each
                        state the cost to maintain in the first and to replace the bus
                        engine in the second column.
    :param beta:        The discount factor.
    :type beta:         float
    :param rho:         Passed unchanged to ``get_worst_case_probs``; presumably
                        the size of the ambiguity set around ``p_ml``
                        (TODO confirm against robupy).
    :type rho:          float
    :param threshold:   A threshold for the convergence. By default set to 1e-12.
    :type threshold:    float
    :param max_it:      Maximum number of iterations. By default set to 1000000.
    :type max_it:       int

    :return: A tuple ``(ev, worst_trans_mat)``: a numpy array containing for
             each state the last computed expected value, and the
             num_states x num_states transition matrix used in the last
             iteration. If no iteration is executed the matrix is all zeros.
    """
    dim_p = len(p_ml)
    ev = np.zeros(num_states)
    trans_mat = create_transition_matrix(num_states, p_ml)
    # Starting value computed with the baseline transition matrix.
    ev_new = np.dot(trans_mat, np.log(np.sum(np.exp(-costs), axis=1)))
    # Allocated once, before the loop: the name is then always bound at the
    # return statement, and every iteration overwrites exactly the same
    # entries, so a fresh allocation per iteration is unnecessary.
    worst_trans_mat = np.zeros(shape=(num_states, num_states), dtype=np.float64)
    while (max_it != 0) and (np.max(np.abs(ev_new - ev)) > threshold):
        ev = ev_new
        maint_cost = beta * ev - costs[:, 0]
        repl_cost = beta * ev[0] - costs[0, 1] - costs[0, 0]
        # Subtract a reference value before exponentiating and add it back
        # after taking the logarithm.
        ev_min = maint_cost[0]
        log_sum = ev_min + np.log(
            np.exp(maint_cost - ev_min) + np.exp(repl_cost - ev_min)
        )
        # Rows 0 .. num_states - dim_p - 1: window of dim_p entries starting
        # on the diagonal.
        for s in range(num_states - dim_p):
            p = trans_mat[s, s : s + dim_p]
            v = log_sum[s : s + dim_p]
            worst_trans_mat[s, s : s + dim_p] = get_worst_case_probs(
                v, p, rho, is_cost=False
            )
        # Rows near the end of the state space, indexed from the back: row -i
        # only has i entries from the diagonal to the last column.
        for i in range(2, dim_p + 1):  # Indexing from the back starts with 1
            v = log_sum[-i:]
            p = trans_mat[-i, -i:]
            worst_trans_mat[-i, -i:] = get_worst_case_probs(v, p, rho, is_cost=False)
        # The last state keeps all probability mass on itself.
        worst_trans_mat[-1, -1] = 1
        ev_new = np.dot(worst_trans_mat, log_sum)
        max_it -= 1
    # Report non-convergence based on the actual residual, so that reaching the
    # threshold on the very last permitted iteration is not flagged.
    if np.max(np.abs(ev_new - ev)) > threshold:
        print("The value function didn't converge.")
    return ev_new, worst_trans_mat

In [4]:
# Worst-case fixed point and the transition matrix of its final iteration,
# computed with the default threshold and iteration limit.
worst_ev, worst_trans_mat = calc_fixp_worst(
    num_states=num_states,
    p_ml=p_ml,
    costs=costs,
    beta=beta,
    rho=rho,
)

990000
980000
970000
960000
950000
940000
930000
920000
910000
900000
890000
880000
870000
860000
850000
840000
830000
820000
810000
800000
790000
780000
770000
760000
750000
740000
730000
720000
710000
700000
690000
680000
670000
660000
650000
640000
630000
620000
610000
600000
590000
580000
570000
560000
550000
540000
530000
520000
510000
500000
490000
480000
470000
460000
450000
440000
430000
420000
410000
400000
390000
380000
370000
360000
350000
340000
330000
320000
310000
300000
290000
280000
270000
260000
250000
240000
230000
220000
210000
200000
190000
180000
170000
160000
150000
140000
130000
120000
110000
100000
90000
80000
70000
60000
50000
40000
30000
20000
10000
0
The value function didn't converge.


In [5]:
# Baseline for comparison: transition matrix built directly from p_ml and the
# fixed point calc_fixp computes from it.
trans_mat = create_transition_matrix(num_states, p_ml)
ev = calc_fixp(num_states, trans_mat, costs, beta)

In [8]:
# Display the transition matrix returned by calc_fixp_worst.
worst_trans_mat

array([[0.37416461, 0.61168642, 0.01414897, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.37416433, 0.61168712, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.37416384, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.39072236, 0.59640738,
        0.01287026],
       [0.        , 0.        , 0.        , ..., 0.        , 0.39130355,
        0.60869645],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        1.        ]])