# 1. Discretize the state space

In [None]:
import numpy as np
from LSTM import LSTM
from System import Plant
from Policy import Estimator
import warnings
from tqdm import tqdm
# Suppress every warning category for the remainder of the session.
warnings.filterwarnings('ignore')
# Seed NumPy's global RNG so the random load samples drawn later are reproducible.
np.random.seed(0)

# Grid spacing shared by the three continuous state dimensions.
resolution = 0.2

# np.arange excludes its stop value; the extra 0.01 on each stop keeps the
# round upper bounds (7, 15 and 10) on the grid as the last point.
c_dis = np.arange(0, 7.01, resolution)
p_dis = np.arange(0, 15.01, resolution)
l_dis = np.arange(0, 10.01, resolution)

# Integer control grid: -5, -4, ..., 3 (the stop value 4 is excluded).
u_dis = np.arange(-5, 4)

def close_idx(list, target):
    """Return the index of the entry in ``list`` that is nearest to ``target``.

    Args:
        list: Grid of numeric values; a NumPy array or any sequence that
            ``np.asarray`` can convert (plain lists and tuples included).
        target: Scalar value to locate on the grid.

    Returns:
        Index of the element with the smallest absolute difference to
        ``target``. Ties resolve to the lowest index.
    """
    # NOTE: the parameter name shadows the built-in ``list``; it is kept so
    # that existing keyword callers continue to work.
    grid = np.asarray(list)
    return np.argmin(np.abs(grid - target))

def get_idx(p_idx, c_idx, l_idx):
    """Flatten per-dimension grid indices into a single state index.

    The layout is row-major over (p, c, l): ``l`` varies fastest and ``p``
    slowest, with the extents taken from ``c_dis`` and ``l_dis``.
    """
    n_c = c_dis.shape[0]
    n_l = l_dis.shape[0]
    # Nested form of p_idx * n_c * n_l + c_idx * n_l + l_idx.
    return (p_idx * n_c + c_idx) * n_l + l_idx
    
def get_pcl(idx):
    """Map a flat state index back to its grid values.

    Returns:
        A list ``[p, c, l]`` holding the entries of ``p_dis``, ``c_dis`` and
        ``l_dis`` that correspond to ``idx`` (``l`` varies fastest).
    """
    n_l = l_dis.shape[0]
    states_per_p = c_dis.shape[0] * n_l
    # Peel off the slowest-varying index first, then split the remainder.
    p_idx, remainder = divmod(idx, states_per_p)
    c_idx, l_idx = divmod(remainder, n_l)
    return [p_dis[p_idx], c_dis[c_idx], l_dis[l_idx]]

# Number of random load samples drawn per (action, state) pair.
n_sample = 10

# Total number of discrete states: one per (p, c, l) grid combination.
n = len(p_dis) * len(c_dis) * len(l_dis)

# 2. Generate transition matrix

In [None]:
# Sparse transition model: P_trans[u_idx, state_idx, hit_state_idx] holds the
# empirical probability of landing in hit_state_idx when action u_idx is
# applied in state_idx. Keys that never occur are simply absent.
P_trans = {}

plant = Plant(dt=1/30)
estimator = Estimator()
std = 0.1 # standard deviation of the sampled load around the grid value

for u_idx, u in enumerate(u_dis):
    for p_idx, p in enumerate(tqdm(p_dis)):
        for c_idx, c in enumerate(c_dis):
            for l_idx, l in enumerate(l_dis):
                # The source state is the same for every sample drawn at this
                # grid point, so its flat index is computed once.
                state_idx = get_idx(p_idx, c_idx, l_idx)
                # l_hat = estimator.estimate(l)
                real_loads = np.random.normal(loc=l, scale=std, size=n_sample)
                for real_l in real_loads:
                    # Restart from the grid point and simulate one step
                    # under the sampled load.
                    plant.reset(c, p, l)
                    plant.step(u, real_l)
                    # Snap the simulated outcome back onto the grid.
                    hit_state_idx = get_idx(close_idx(p_dis, plant.p),
                                            close_idx(c_dis, plant.battery.c),
                                            close_idx(l_dis, plant.l))
                    # Each sample contributes an equal 1/n_sample share of
                    # probability mass to the state it lands in.
                    key = (u_idx, state_idx, hit_state_idx)
                    P_trans[key] = P_trans.get(key, 0) + (1/n_sample)



In [None]:
import matplotlib.pyplot as plt
# Sanity check on the transition model: total probability mass that reaches
# state 0 under action index 0, accumulated over every source state j.
# NOTE(review): keys are (u_idx, state_idx, hit_state_idx), so this adds up
# mass flowing INTO state 0; a row-normalisation check would instead sum
# P_trans[0, s, j] over j for a fixed source s - confirm which was intended.
# The accumulator gets its own name because rebinding ``sum`` would shadow
# the built-in for every later cell.
incoming_mass = 0
for j in range(n):
    incoming_mass += P_trans.get((0, j, 0), 0)

print(incoming_mass)


# 3. Dynamic programming

In [None]:
def dict_to_np(dict, state_idx):
    """Build the dense transition table for a single source state.

    Args:
        dict: Sparse transition model keyed by
            ``(u_idx, state_idx, target_state_idx)`` with probabilities as
            values. Absent keys are treated as probability zero.
        state_idx: Flat index of the source state to extract.

    Returns:
        numpy array with dimension [u, target_state_idx]; entry
        ``[u_idx, target_state_idx]`` is the stored probability, or 0 when
        the key is missing.
    """
    # NOTE: the parameter name shadows the built-in ``dict``; it is kept so
    # that existing keyword callers continue to work.
    result = np.zeros(shape=[u_dis.shape[0], n])
    for u_idx in range(u_dis.shape[0]):
        for target_state_idx in range(n):
            key = (u_idx, state_idx, target_state_idx)
            # The array starts zero-filled, so only present keys are written.
            if key in dict:
                result[u_idx, target_state_idx] = dict[key]
    return result

# Time horizon for the backward recursion.
T = 1000
# Policy table: Actions[i, t] stores an action index for state i at step t.
Actions = np.zeros((n, T), dtype=np.int8)
# Cost at time T: the p grid value of every state, stacked as an (n, 1) column.
J = np.array([get_pcl(state)[0] for state in range(n)]).reshape(n, 1)

# One dense [u, target_state_idx] transition table per source state.
probs = [dict_to_np(P_trans, state) for state in tqdm(range(n))]

# Backward recursion over the horizon. Only the final step (t = T - 1) is
# solved here.
# NOTE(review): J is never replaced by J_hat, so widening this range would
# keep bootstrapping every step from the terminal cost - confirm the intended
# update before extending the loop.
for t in range(T - 1, T - 2, -1):
    J_hat = np.zeros(shape=[n, 1])
    for i in tqdm(range(n)):
        # Expected next-step cost of each action from state i; computed once
        # and reused for both the minimum and its arg-minimum.
        expected_cost = probs[i] @ J
        J_hat[i] = np.min(expected_cost)
        Actions[i, t] = np.argmin(expected_cost)
