In [7]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
from numba import jit

In [179]:
@jit
def generate_policy(probs, max_period):
    num_choices = 2
    
    num_probs = len(probs)

    max_streak = max_period
    num_states = max_streak + 1

    ## Get array of choices and possible states
    max_streaks = np.arange(0, num_states)
    streaks = np.arange(0, num_states)
    choices = np.arange(0, num_choices)
    probs_plus_1 = np.concatenate((probs , np.array([1])))

    ## Calculate matrices
    c_repeat = choices.repeat(num_states*num_states*num_probs)
    c_reshape = c_repeat.reshape((-1, num_states, num_states, num_probs))
    choices_mat = c_reshape.transpose((1, 0, 2, 3))

    s_repeat = streaks.repeat(num_choices*num_states*num_probs)
    s_reshape = s_repeat.reshape((-1, num_choices, num_states, num_probs))
    streaks_mat = s_reshape

    ms_repeat = max_streaks.repeat(num_choices*num_states*num_probs)
    ms_reshape = ms_repeat.reshape((-1, num_states, num_choices, num_probs))
    max_streaks_mat = ms_reshape.transpose((1, 2, 0, 3))

    p_repeat = probs.repeat(num_states*num_choices*num_states)
    p_reshape = p_repeat.reshape((-1, num_states, num_choices, num_states))
    probs_mat = p_reshape.transpose((1, 2, 3, 0))
    
    ##### Current Streak Updating
    ## If choice is to skip, then states are just the same as what they were
    streaks_stay_mat = streaks_mat[:, 0, :, :]

    ## If choice is to take the risk, then potential state update is current streak increases by 1
    streaks_win_mat = np.minimum(streaks_mat[:, 1, :, :] + 1, max_period)

    ## The risk though is current streak goes to zero
    streaks_lose_mat = np.zeros((num_states, num_states, num_probs), dtype='int')


    ###### Max streak Updating
    ## If choice is to skip, then states are just the same as what they were
    max_of_max_current = np.maximum(max_streaks_mat[:, 0, :, :], streaks_mat[:, 0, :, :])
    max_streaks_stay_mat = max_of_max_current
    max_streaks_lose_mat = max_of_max_current

    ## If choice is to take the risk, then potential state update is increased if current streak == max streak
    max_of_max_winstreak = np.maximum(max_streaks_mat[:, 1, :, :], streaks_mat[:, 1, :, :] + 1)
    max_streaks_win_mat = np.minimum(max_of_max_winstreak , max_period)


    #### Probs updating
    probs_ind = np.arange(0, num_probs)
    p_ind_repeat = probs_ind.repeat(num_states*num_states)
    p_ind_reshape = p_ind_repeat.reshape((-1, num_states, num_states))
    probs_new_mat = p_ind_reshape.transpose((1, 2, 0))

    orig_shape = streaks_stay_mat.shape

    streaks_stay_mat_flat = streaks_stay_mat.flatten()
    max_streaks_stay_mat_flat = max_streaks_stay_mat.flatten()
    probs_new_mat_flat = probs_new_mat.flatten()
    
    streaks_win_mat_flat = streaks_win_mat.flatten()
    max_streaks_win_mat_flat = max_streaks_win_mat.flatten()
    
    streaks_lose_mat_flat = streaks_lose_mat.flatten()
    max_streaks_lose_mat_flat = max_streaks_lose_mat.flatten()
    
    
    V_funcs = {}
#     G_funcs = {}
#     Cutoffs = {}

    V_funcs['V' + str(max_period)] = max_of_max_current
    for period in range(max_period-1, -1, -1):
        next = period + 1
        next_V = V_funcs['V' + str(next)]

        
        stay = next_V[streaks_stay_mat_flat, max_streaks_stay_mat_flat, probs_new_mat_flat].reshape(orig_shape)
        win = next_V[streaks_win_mat_flat, max_streaks_win_mat_flat, probs_new_mat_flat].reshape(orig_shape)
        lose = next_V[streaks_lose_mat_flat, max_streaks_lose_mat_flat, probs_new_mat_flat].reshape(orig_shape)
        Exp_V_stay = np.mean(stay, axis=2)
        Exp_V_win = np.mean(win, axis=2)
        Exp_V_lose = np.mean(lose, axis=2)

#         Exp = np.zeros((num_states, num_choices, num_states, num_probs))
#         Exp[:, 0, :, :] = np.tile(Exp_V_stay, (num_probs, 1, 1)).transpose([1, 2, 0])
#         Exp[:, 1, :, :] = (
#             np.einsum("ij,k->ijk", Exp_V_win, probs_mat[0, 1, 0, :]) + 
#             np.einsum("ij,k->ijk", Exp_V_lose, (1 - probs_mat[0, 1, 0, :]))
#         )

#         G_funcs['G' + str(period)] = np.array(np.argmax(Exp, axis=1), dtype=float)
#         G_funcs['G' + str(period)][next:, :, :] = np.nan
#         G_funcs['G' + str(period)][:, next:, :] = np.nan
#         V_funcs['V' + str(period)] = np.array(np.amax(Exp, axis=1), dtype=float)
#         V_funcs['V' + str(period)][next:, :] = np.nan
#         V_funcs['V' + str(period)][:, next:, :] = np.nan
      
        V_funcs['V' + str(period)] = V_funcs['V' + str(next)]

In [180]:
probs = np.arange(.7, .8, .01)

In [181]:
%%timeit
generate_policy(probs, 200)

Compilation is falling back to object mode WITH looplifting enabled because Function "generate_policy" failed type inference due to: No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(int64, 3d, C), UniTuple(array(int64, 1d, C) x 3))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(int64, 3d, C), UniTuple(array(int64, 1d, C) x 3))':
       No match.
      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 162.
        With argument(s): '(array(int64, 3d, C), UniTuple(array(int64, 1d, C) x 3))':
       Rejected as the implementation raised a specific error:
         NotImplementedError: only one advanced index supported
  raised from /Users/rsfletch/opt/anaconda3/envs/bts/lib/python3.7/site-packages/numba/core/typing/a

2.86 s ± 196 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
max_period = 5
num_choices = 2

num_probs = len(probs)

max_streak = max_period
num_states = max_streak + 1

## Get array of choices and possible states
max_streaks = np.arange(0, num_states)
streaks = np.arange(0, num_states)
choices = np.arange(0, num_choices)
probs_plus_1 = np.concatenate((probs , np.array([1])))

In [155]:
next_V[streaks_stay_mat, max_streaks_stay_mat, probs_new_mat]

NameError: name 'next_V' is not defined

In [157]:
num_choices = 2

num_probs = len(probs)

max_streak = max_period
num_states = max_streak + 1

## Get array of choices and possible states
max_streaks = np.arange(0, num_states)
streaks = np.arange(0, num_states)
choices = np.arange(0, num_choices)
probs_plus_1 = np.concatenate((probs , np.array([1])))

## Calculate matrices
c_repeat = choices.repeat(num_states*num_states*num_probs)
c_reshape = c_repeat.reshape((-1, num_states, num_states, num_probs))
choices_mat = c_reshape.transpose((1, 0, 2, 3))

s_repeat = streaks.repeat(num_choices*num_states*num_probs)
s_reshape = s_repeat.reshape((-1, num_choices, num_states, num_probs))
streaks_mat = s_reshape

ms_repeat = max_streaks.repeat(num_choices*num_states*num_probs)
ms_reshape = ms_repeat.reshape((-1, num_states, num_choices, num_probs))
max_streaks_mat = ms_reshape.transpose((1, 2, 0, 3))

p_repeat = probs.repeat(num_states*num_choices*num_states)
p_reshape = p_repeat.reshape((-1, num_states, num_choices, num_states))
probs_mat = p_reshape.transpose((1, 2, 3, 0))

##### Current Streak Updating
## If choice is to skip, then states are just the same as what they were
streaks_stay_mat = streaks_mat[:, 0, :, :]

## If choice is to take the risk, then potential state update is current streak increases by 1
streaks_win_mat = np.minimum(streaks_mat[:, 1, :, :] + 1, max_period)

## The risk though is current streak goes to zero
streaks_lose_mat = np.zeros((num_states, num_states, num_probs), dtype='int')


###### Max streak Updating
## If choice is to skip, then states are just the same as what they were
max_of_max_current = np.maximum(max_streaks_mat[:, 0, :, :], streaks_mat[:, 0, :, :])
max_streaks_stay_mat = max_of_max_current
max_streaks_lose_mat = max_of_max_current

## If choice is to take the risk, then potential state update is increased if current streak == max streak
max_of_max_winstreak = np.maximum(max_streaks_mat[:, 1, :, :], streaks_mat[:, 1, :, :] + 1)
max_streaks_win_mat = np.minimum(max_of_max_winstreak , max_period)


#### Probs updating
probs_ind = np.arange(0, num_probs)
p_ind_repeat = probs_ind.repeat(num_states*num_states)
p_ind_reshape = p_ind_repeat.reshape((-1, num_states, num_states))
probs_new_mat = p_ind_reshape.transpose((1, 2, 0))

V_funcs = {}
#     G_funcs = {}
#     Cutoffs = {}

V_funcs['V' + str(max_period)] = max_of_max_current

In [175]:
test = V_funcs['V5']
test_get = test[streaks_stay_mat, max_streaks_stay_mat, probs_new_mat]

orig_shape = streaks_stay_mat.shape

streaks_stay_mat_flat = streaks_stay_mat.flatten()
max_streaks_stay_mat_flat = max_streaks_stay_mat.flatten()
probs_new_mat_flat = probs_new_mat.flatten()
test_get_2 = test[streaks_stay_mat_flat, max_streaks_stay_mat_flat, probs_new_mat_flat].reshape(orig_shape)

(test_get != test_get_2).sum()

0