In [None]:
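# NOTE: throughout this notebook, variables named `reward`/`rewards` hold per-step
# *costs* (the policy update minimises them); they are not rewards to maximise.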
import numpy as np
from scipy.optimize import minimize
from scipy.special import xlogy
# Seed NumPy's global RNG so the randomly generated MDP and the simulated
# rollouts below are reproducible from run to run.
np.random.seed(seed=0)


class SimpleMDP:
    """Small random tabular MDP used for the policy-optimisation experiment.

    Attributes:
        num_states, num_actions: sizes of the state and action spaces.
        transition_probs: array of shape (num_states, num_actions, num_states)
            with transition_probs[s, a, s'] = P(s' | s, a); entries are drawn
            with np.random.rand and each [s, a] row is normalised to sum to 1.
        rewards: array of shape (num_states, num_actions) holding per-step
            costs: 1.0 where s % 3 == a, 0.5 everywhere else.
        gamma: discount factor, fixed at 0.95.

    The constructor prints both tables.
    """

    def __init__(self, num_states, num_actions):
        self.num_states = num_states
        self.num_actions = num_actions

        # Random transition kernel, normalised over the next-state axis.
        weights = np.random.rand(num_states, num_actions, num_states)
        self.transition_probs = weights / weights.sum(axis=2, keepdims=True)
        print(self.transition_probs)

        # Cost table: broadcast state indices (column) against action indices (row).
        state_idx = np.arange(num_states).reshape(-1, 1)
        action_idx = np.arange(num_actions).reshape(1, -1)
        self.rewards = np.where(state_idx % 3 == action_idx, 1.0, 0.5)
        self.gamma = 0.95
        print(self.rewards)

    def get_reward(self, state, action):
        """Return the (cost-valued) reward of taking ``action`` in ``state``."""
        return self.rewards[state, action]

    def get_transition_probs(self, state, action):
        """Return the next-state distribution P(. | state, action)."""
        return self.transition_probs[state, action]

    def next_state(self, state, action):
        """Sample a successor state from P(. | state, action) using the global NumPy RNG."""
        return np.random.choice(self.num_states, p=self.get_transition_probs(state, action))

def initialize_policy(num_states, num_actions):
    """Return the uniform random policy as a (num_states, num_actions) float array.

    Every row is the same distribution: probability 1 / num_actions per action.
    """
    per_action = 1.0 / num_actions
    return np.ones((num_states, num_actions)) * per_action

def compute_bregman_divergence(pi, pi_0):
    """Return KL(pi || pi_0), the Bregman divergence generated by negative entropy.

    Args:
        pi: a 1-D action distribution, or a 2-D array whose rows are distributions.
        pi_0: reference distribution(s), broadcastable against ``pi``.

    Returns:
        A scalar when ``pi`` is 1-D; otherwise the per-row divergences
        (summed over axis 1).

    Zero entries of ``pi`` contribute 0 (the 0 * log 0 = 0 convention) rather
    than NaN, so probability vectors on the boundary of the simplex -- which the
    SLSQP ``p >= 0`` constraint allows -- give a finite value. ``pi_0`` is offset
    by 1e-10, so zero reference mass yields a large but finite penalty.
    """
    pi = np.asarray(pi)
    # xlogy(x, y) == x * log(y), defined as 0 wherever x == 0.
    terms = xlogy(pi, pi) - xlogy(pi, pi_0 + 1e-10)
    if pi.ndim == 1:
        return np.sum(terms)
    return np.sum(terms, axis=1)




def compute_q_values(mdp, pi, num_states, num_actions):
    """Return the exact action-value (cost-to-go) table Q^pi of policy ``pi``.

    Solves the policy-evaluation Bellman equation

        Q(s, a) = c(s, a) + gamma * sum_{s'} P(s' | s, a) * sum_{a'} pi(a' | s') Q(s', a')

    exactly, via the state-value linear system V = c_pi + gamma * P_pi V.
    A single in-place sweep starting from Q = 0 reads entries that have not
    been computed yet, so it neither reaches this fixed point nor is
    independent of the state ordering.

    Args:
        mdp: object exposing ``get_transition_probs(s, a)`` (length-num_states
            next-state distribution), ``get_reward(s, a)`` (per-step cost) and
            the discount factor ``gamma``.
        pi: array of shape (num_states, num_actions); row s is the action
            distribution used in state s.
        num_states: number of states.
        num_actions: number of actions.

    Returns:
        ndarray of shape (num_states, num_actions) with Q^pi.

    Raises:
        numpy.linalg.LinAlgError: if I - gamma * P_pi is singular (possible
            only when gamma >= 1).
    """
    if num_states == 0 or num_actions == 0:
        return np.zeros((num_states, num_actions))

    P = np.array(
        [[mdp.get_transition_probs(s, a) for a in range(num_actions)] for s in range(num_states)],
        dtype=float,
    )  # (S, A, S)
    costs = np.array(
        [[mdp.get_reward(s, a) for a in range(num_actions)] for s in range(num_states)],
        dtype=float,
    )  # (S, A)
    pi = np.asarray(pi, dtype=float)

    # State-to-state kernel and expected one-step cost when following pi.
    P_pi = np.einsum("sa,sat->st", pi, P)
    c_pi = np.sum(pi * costs, axis=1)

    # V^pi = (I - gamma * P_pi)^{-1} c_pi
    V = np.linalg.solve(np.eye(num_states) - mdp.gamma * P_pi, c_pi)

    # Q^pi(s, a) = c(s, a) + gamma * E_{s' ~ P(.|s, a)}[V^pi(s')]
    return costs + mdp.gamma * (P @ V)


def update_policy(pi_k, Q, eta_k, num_states, num_actions, pi_0):
    """Compute the next policy by solving one small optimisation per state.

    For every state ``s`` this minimises, over the probability simplex,

        eta_k * <Q[s], p> + H(p) + KL(p || pi_0[s])

    where ``H(p) = -sum(p * log(p + 1e-10))`` and the KL term comes from
    ``compute_bregman_divergence``. SLSQP starts from the uniform distribution
    under the constraints ``sum(p) == 1`` and ``p >= 0``.

    NOTE(review): KL(p || pi_0) = -H(p) - sum(p * log pi_0), so the ``+H(p)``
    term cancels the entropy part of the KL (up to the 1e-10 offsets). The
    objective is therefore essentially linear in ``p``,
    ``<eta_k * Q[s] - log pi_0[s], p>``. If entropy *regularisation* was
    intended, the sign of the entropy term probably needs flipping -- confirm.
    NOTE(review): ``pi_k`` is only used for the output shape/dtype; mirror-descent
    updates usually take the divergence against ``pi_k[s]`` rather than
    ``pi_0[s]`` -- confirm which reference is intended.

    Args:
        pi_k: current policy, shape (num_states, num_actions).
        Q: action-value (cost) table, shape (num_states, num_actions).
        eta_k: step size multiplying the Q term.
        num_states: number of states.
        num_actions: number of actions.
        pi_0: reference policy for the divergence term, shape (num_states, num_actions).

    Returns:
        Array shaped like ``pi_k`` whose rows are the new action distributions.

    Raises:
        ValueError: if SLSQP reports failure for a state, or returns a vector
            with no positive probability mass.
    """
    pi_next = np.zeros_like(pi_k)

    # Feasible set and starting point are identical for every state.
    constraints = [
        {'type': 'eq', 'fun': lambda p: np.sum(p) - 1},
        {'type': 'ineq', 'fun': lambda p: p},
    ]
    initial_p = np.full(num_actions, 1.0 / num_actions)

    for s in range(num_states):
        q_s = Q[s]
        ref_s = pi_0[s]

        def objective(p):
            entropy_term = -np.sum(p * np.log(p + 1e-10))
            divergence = compute_bregman_divergence(p, ref_s)
            return eta_k * np.dot(q_s, p) + entropy_term + divergence

        result = minimize(objective, initial_p, constraints=constraints, method='SLSQP', options={'disp': False})
        if not result.success:
            raise ValueError("Optimization failed: {}".format(result.message))

        # SLSQP satisfies constraints only to within its tolerance, so entries may
        # be slightly negative and the sum slightly off 1. Clip that noise and
        # renormalise. Subtracting the minimum instead would always zero out the
        # least-likely action, and would give 0/0 = NaN for a uniform solution.
        p = np.clip(result.x, 0.0, None)
        total = p.sum()
        if total <= 0:
            raise ValueError("Optimization returned no positive probability mass for state {}".format(s))
        pi_next[s] = p / total

    return pi_next


def simulate_policy(mdp, pi, num_states, num_actions, timesteps=50):
    """Roll out ``pi`` on ``mdp`` and return the mean per-step reward (a cost here).

    The start state is drawn uniformly with the global NumPy RNG. Each step
    samples an action from ``pi[state]``, records ``mdp.get_reward``, and moves
    to ``mdp.next_state``. The average is undiscounted over ``timesteps`` steps.
    """
    state = np.random.randint(num_states)
    step_costs = []
    for _ in range(timesteps):
        chosen = np.random.choice(num_actions, p=pi[state])
        step_costs.append(mdp.get_reward(state, chosen))
        state = mdp.next_state(state, chosen)
    return sum(step_costs) / timesteps  # Average reward

# --- Experiment configuration ---------------------------------------------
num_states = 5
num_actions = 3
eta_k = 1e-2          # step size multiplying Q in update_policy's objective
num_iterations = 1    # number of policy updates to run

# Build the random MDP (consumes the seeded global RNG and prints its tables),
# start from the uniform policy, and evaluate it once.
mdp = SimpleMDP(num_states, num_actions)
pi_0 = initialize_policy(num_states, num_actions)
pi_k = np.copy(pi_0)
Q = compute_q_values(mdp, pi_k, num_states, num_actions)

for k in range(num_iterations):
    # pi_0 is passed as the divergence reference on every iteration.
    pi_k = update_policy(pi_k, Q, eta_k, num_states, num_actions, pi_0)
    print(pi_k)
    print(Q)  # the Q table that produced this update (re-evaluated on the next line)
    Q = compute_q_values(mdp, pi_k, num_states, num_actions)
    # Mean per-step value over one 500-step rollout; per the file's convention
    # this is a cost, despite the "Reward" label in the printout.
    avg_reward = simulate_policy(mdp, pi_k, num_states, num_actions, timesteps=500)
    print(f"Iteration {k + 1}: Average Reward = {avg_reward:.4f}")

print("Final policy:", pi_k)
