In [7]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import Callable

SEED = 42  # single seed shared by every RNG this notebook imports

torch.manual_seed(SEED)  # torch is imported above; seed it too so any torch randomness is reproducible
np.random.seed(SEED)  # legacy global NumPy RNG: np.random.randn / randint / choice all draw from it

In [8]:
def sample_next_state(index: int, P: np.ndarray, rng: "np.random.Generator | None" = None) -> int:
    """Draw the next Markov-chain state from row ``index`` of transition matrix ``P``.

    Args:
        index: Current state; selects the row ``P[index]``.
        P: Transition matrix. Row ``P[index]`` is the probability vector over the next
            states ``0 .. P.shape[1] - 1``. ``choice`` raises ``ValueError`` if that row
            is not a valid probability vector.
        rng: Optional ``np.random.Generator`` for an independent, reproducible stream.
            ``None`` (the default) draws from the legacy global ``np.random`` state, so
            results follow ``np.random.seed``.

    Returns:
        The sampled next-state index as a plain Python ``int``.
    """
    probs = P[index]
    sampler = np.random if rng is None else rng
    # Passing the population size samples from 0..n-1 without building a candidate
    # array each call; naming it `next_state` avoids shadowing the builtin `next`.
    next_state = sampler.choice(P.shape[1], p=probs)
    return int(next_state)

In [9]:
def td_loop(
        n_iter: int, 
        X: np.ndarray, 
        y: np.ndarray, 
        P: np.ndarray,
        link: Callable[[np.ndarray], np.ndarray], 
        inv_link: Callable[[np.ndarray], np.ndarray], 
        gamma: float, 
        alpha: float,
        epsilon: float,
        patience: int = 1,
    ) -> np.ndarray:
    """Learn linear TD(0) weights along one trajectory of the Markov chain ``P``.

    Starting from a uniformly random state, the loop repeatedly samples a next state
    from ``P`` and builds the reward
    ``r = inv_link(y[s]) - gamma * inv_link(y[s'])``.
    By construction, ``inv_link(y[s]) = r + gamma * inv_link(y[s'])``. So the mapping
    ``s -> inv_link(y[s])`` satisfies the Bellman equation for these rewards, and the
    linear score ``X[s] @ w`` is pushed toward it. It gets only as close as the linear
    features allow.

    Args:
        n_iter: Maximum number of updates (an int, or any number compared with ``<``).
        X: Feature matrix of shape ``(n_samples, n_features)``; row ``s`` is state ``s``.
        y: Per-state targets, length ``n_samples``.
        P: ``(n_samples, n_samples)`` transition matrix. Row ``P[s]`` is the
            next-state distribution and is consumed by ``sample_next_state``.
        link: Applied to the linear score ``X[s] @ w`` and to the TD target before
            they are compared.
        inv_link: Applied to the targets ``y`` when forming rewards. Presumably the
            inverse of ``link``; verify for non-identity choices.
        gamma: Discount factor.
        alpha: Step size.
        epsilon: A weight update whose 2-norm is ``<= epsilon`` counts as "small".
        patience: Stop after this many consecutive small updates. The default of 1
            stops on the first small update. With noisy targets the TD error crosses
            zero by chance, so values > 1 guard against stopping too early.

    Returns:
        Weight vector of shape ``(n_features,)``.

    Raises:
        ValueError: If the shapes of ``X``, ``y`` and ``P`` disagree, or if
            ``patience < 1``.
    """
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (n_samples, n_features), got shape {X.shape}")
    n_samples, n_features = X.shape
    if y.shape[0] != n_samples:
        raise ValueError(f"y has {y.shape[0]} entries but X has {n_samples} rows")
    if P.shape != (n_samples, n_samples):
        raise ValueError(f"P must have shape {(n_samples, n_samples)}, got {P.shape}")
    if patience < 1:
        raise ValueError(f"patience must be >= 1, got {patience}")

    w = np.zeros(n_features)

    curr_index = int(np.random.randint(n_samples))
    curr_x = X[curr_index]
    curr_y = y[curr_index]

    # Consecutive updates whose norm fell at or below epsilon. The stopping test runs
    # after each update, so at least one update always happens when n_iter > 0.
    # (No placeholder gradient is needed to enter the loop.)
    n_small_steps = 0
    i = 0
    while i < n_iter:
        # Next state samples
        next_index = sample_next_state(index=curr_index, P=P)
        next_x = X[next_index]
        next_y = y[next_index]

        # Rewards chosen so that inv_link(y) solves the Bellman equation exactly
        r = inv_link(curr_y) - gamma * inv_link(next_y)

        # TD target (bootstrapped from the current weights)
        z = r + gamma * np.dot(next_x, w)

        # Update direction: prediction error in link space, scaled by the features
        grad = (link(np.dot(curr_x, w)) - link(z)) * curr_x

        # Update weights
        step = alpha * grad
        w -= step

        # Update state and index
        curr_index, curr_x, curr_y = next_index, next_x, next_y
        i += 1

        # `not (norm > epsilon)` also counts a NaN step (divergence) as small, so the
        # loop terminates instead of iterating forever on NaN weights.
        if not np.linalg.norm(step, 2) > epsilon:
            n_small_steps += 1
            if n_small_steps >= patience:
                break
        else:
            n_small_steps = 0

    return w

In [10]:
# Generate synthetic data: linear targets plus Gaussian noise (std 0.1)
num_samples = 100
num_features = 3

X = np.random.randn(num_samples, num_features)
true_w = np.array([2.0, -3.5, 1.0])
y = X @ true_w + np.random.randn(num_samples) * 0.1  # Adding noise

alpha = 0.01  # Learning rate
gamma = 0.9   # Discount factor
num_iterations = 1_000_000  # Max number of TD updates (an int, not the float 1e6)
epsilon = 1e-7  # Stop once a weight update has 2-norm <= epsilon

P = np.ones((num_samples, num_samples)) / num_samples  # Equal probability to move to any state

# Invariants the TD loop relies on
assert true_w.shape == (num_features,), "true_w must have one weight per feature"
assert np.allclose(P.sum(axis=1), 1.0), "each row of P must be a probability distribution"

In [11]:
# Identity link on both sides: TD regresses the linear score directly onto y.
td_config = dict(
    n_iter=num_iterations,
    X=X,
    y=y,
    P=P,
    link=lambda v: v,
    inv_link=lambda v: v,
    gamma=gamma,
    alpha=alpha,
    epsilon=epsilon,
)
w_hat = td_loop(**td_config)

# Euclidean distance between the learned weights and the data-generating ones
error = np.linalg.norm(w_hat - true_w, 2)

print(w_hat)
print(f'Error: {error:.5f}')

[ 2.00494837 -3.50096249  0.99077277]
Error: 0.01051
