In [2]:
from typing import Callable

import numpy as np
import matplotlib.pyplot as plt
import torch

In [None]:
def sample_next_state(index: int, P: np.ndarray) -> int:
    """Draw the next state of a Markov chain from row ``index`` of ``P``.

    Args:
        index: Current state; selects the row of ``P`` to sample from.
        P: Transition matrix whose row ``index`` is a probability
            distribution over next states (one entry per column).

    Returns:
        The sampled next-state index, in ``range(P.shape[1])``.

    Raises:
        ValueError: If ``P[index]`` is not a valid probability vector
            (raised by ``np.random.choice``).
    """
    probs = P[index]
    # ``p`` must be passed by keyword: the second positional parameter of
    # ``np.random.choice`` is ``size``, not the probability vector.
    return int(np.random.choice(P.shape[1], p=probs))

In [None]:
def td_loop(
        n_iter: int,
        X: np.ndarray,
        y: np.ndarray,
        P: np.ndarray,
        link: Callable[[np.ndarray], np.ndarray],
        inv_link: Callable[[np.ndarray], np.ndarray],
        gamma: float,
        alpha: float
        ) -> np.ndarray:
    """Fit linear weights with TD(0)-style updates along a sampled Markov chain.

    Starting from a uniformly random state, the chain is advanced ``n_iter``
    times via ``sample_next_state`` using the transition matrix ``P``. At each
    step a reward is derived from the labels so that ``inv_link(y)`` satisfies
    the Bellman relation ``inv_link(y) = r + gamma * inv_link(y')``. The
    weights are then moved towards the bootstrapped TD target
    ``r + gamma * <x', w>``, with the comparison made in ``link`` space.

    Args:
        n_iter: Number of transitions (weight updates) to perform.
        X: Feature matrix with one row per state.
        y: Per-state labels, indexed like the rows of ``X``.
        P: Square transition matrix over the ``X.shape[0]`` states; row ``i``
            is the next-state distribution from state ``i``.
        link: Maps the linear predictor ``<x, w>`` to label space.
        inv_link: Maps labels to linear-predictor space.
        gamma: Discount factor.
        alpha: Step size.

    Returns:
        The learned weight vector of length ``X.shape[1]``.

    Raises:
        ValueError: If ``X`` has no rows, or ``y``/``P`` do not match the
            number of rows of ``X``.
    """
    n_samples = X.shape[0]
    n_features = X.shape[1]

    if n_samples == 0:
        raise ValueError("X must contain at least one state (row).")
    if y.shape[0] != n_samples:
        raise ValueError(
            f"y has {y.shape[0]} entries but X has {n_samples} rows."
        )
    if P.shape != (n_samples, n_samples):
        raise ValueError(
            f"P must have shape ({n_samples}, {n_samples}), got {P.shape}."
        )

    w = np.zeros(n_features)

    curr_index = np.random.randint(n_samples)
    curr_x = X[curr_index]
    curr_y = y[curr_index]

    for _ in range(n_iter):
        # Next state samples
        next_index = sample_next_state(index=curr_index, P=P)
        next_x = X[next_index]
        next_y = y[next_index]

        # Reward implied by the labels: inv_link(y) = r + gamma * inv_link(y').
        r = inv_link(curr_y) - gamma * inv_link(next_y)

        # Bootstrapped TD target in linear-predictor space.
        z = r + gamma * np.dot(next_x, w)

        # Semi-gradient step: the target z is treated as a constant w.r.t. w.
        grad = (link(np.dot(curr_x, w)) - link(z)) * curr_x

        w -= alpha * grad

        curr_index, curr_x, curr_y = next_index, next_x, next_y

    return w