# In [None]:
##



import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Define the MDP class
class MDP:
    """A finite Markov decision process with tabular transitions and rewards.

    Attributes:
        num_states: Number of states ``S``.
        num_actions: Number of actions ``A``.
        transition_probabilities: Array of shape ``(S, A, S)``; entry
            ``[s, a, s']`` is the probability of moving to ``s'`` after taking
            action ``a`` in state ``s``.
        rewards: Array of shape ``(S, A)``; entry ``[s, a]`` is the reward for
            taking action ``a`` in state ``s``.
    """

    def __init__(self, num_states, num_actions, transition_probabilities=None, reward=None):
        """Create an MDP, randomly generating any table that is not supplied.

        Args:
            num_states: Number of states.
            num_actions: Number of actions.
            transition_probabilities: Optional array-like of shape
                ``(num_states, num_actions, num_states)``. When omitted, each
                ``[s, a, :]`` row is drawn uniformly from [0, 1) and normalized
                to sum to 1. A supplied table is used as given (not normalized).
            reward: Optional array-like of shape ``(num_states, num_actions)``.
                When omitted, rewards are drawn uniformly from [0, 1).

        Raises:
            ValueError: If a supplied table does not have the expected shape.
        """
        self.num_states = num_states
        self.num_actions = num_actions

        # Transitions are generated before rewards so a seeded global RNG
        # yields the same tables as before.
        if transition_probabilities is None:
            probs = np.random.rand(num_states, num_actions, num_states)
            probs /= probs.sum(axis=2, keepdims=True)
        else:
            # asarray lets callers pass nested lists; float arrays are not copied.
            probs = np.asarray(transition_probabilities, dtype=float)
            expected = (num_states, num_actions, num_states)
            if probs.shape != expected:
                raise ValueError(
                    f"transition_probabilities must have shape {expected}, got {probs.shape}"
                )
        self.transition_probabilities = probs

        if reward is None:
            rewards = np.random.rand(num_states, num_actions)
        else:
            rewards = np.asarray(reward, dtype=float)
            expected = (num_states, num_actions)
            if rewards.shape != expected:
                raise ValueError(f"reward must have shape {expected}, got {rewards.shape}")
        self.rewards = rewards

    def _edge_labels(self, threshold=0.05):
        """Return ``{("S<s>", "S<s'>"): label}`` for transitions above *threshold*.

        Each label reads ``"$a_{a}$: P=<prob>, R=<reward>"``. When several
        actions lead from the same state to the same successor, their labels
        are joined with newlines in increasing action order.
        """
        labels = {}
        for s in range(self.num_states):
            for a in range(self.num_actions):
                reward = self.rewards[s, a]
                for s_prime in range(self.num_states):
                    prob = self.transition_probabilities[s, a, s_prime]
                    if prob > threshold:  # only include significant transitions
                        text = f"$a_{{{a}}}$: P={prob:.2f}, R={reward:.2f}"
                        key = (f"S{s}", f"S{s_prime}")
                        labels[key] = f"{labels[key]}\n{text}" if key in labels else text
        return labels

    def visualize_mdp(self, threshold=0.05):
        """Plot the MDP as a directed graph with matplotlib.

        Nodes are states. There is one curved edge per (state, successor) pair
        reached with probability above *threshold* by at least one action. Each
        edge is labelled with the action, probability and reward of every such
        transition.

        Args:
            threshold: Transitions with probability ``<= threshold`` are omitted.
        """
        edge_labels = self._edge_labels(threshold)

        G = nx.DiGraph()  # DiGraph: parallel actions share one labelled edge
        # Add every state explicitly so isolated states are still drawn and
        # the node-label dict below never references a node missing from pos.
        G.add_nodes_from(f"S{s}" for s in range(self.num_states))
        G.add_edges_from(edge_labels)

        pos = nx.spring_layout(G)  # Position nodes for better visibility
        plt.figure(figsize=(8, 6))

        node_size = 3000
        # Curving edges keeps s->s' and s'->s apart instead of overlapping.
        connectionstyle = "arc3,rad=0.2"

        node_labels = {f"S{s}": f"$s_{{{s}}}$" for s in range(self.num_states)}
        nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color="lightblue")
        nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10, font_weight="bold")
        # Draw edges exactly once. node_size lets networkx stop the arrowheads
        # at the node boundary instead of hiding them under the large nodes.
        nx.draw_networkx_edges(
            G,
            pos,
            node_size=node_size,
            arrows=True,
            arrowstyle="-|>",
            arrowsize=20,
            connectionstyle=connectionstyle,
        )

        try:
            # networkx >= 3.3 can place labels along the same curved edges.
            nx.draw_networkx_edge_labels(
                G,
                pos,
                edge_labels=edge_labels,
                font_size=8,
                node_size=node_size,
                connectionstyle=connectionstyle,
            )
        except TypeError:
            # Older networkx lacks these keywords; fall back to straight placement.
            nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

        plt.axis("off")
        plt.title("MDP Visualization with LaTeX-style Labels")
        plt.show()

# ---------------------------------------------------------------------------
# Interactive driver: collect the MDP definition from the user, then plot it.
# ---------------------------------------------------------------------------


def _prompt_float(prompt):
    """Ask for a number until the reply parses as a float, then return it."""
    while True:
        reply = input(prompt)
        try:
            return float(reply)
        except ValueError:
            print(f"{reply!r} is not a number; please try again.")


def _prompt_positive_int(prompt):
    """Ask for a strictly positive integer until one is given, then return it."""
    while True:
        reply = input(prompt)
        try:
            value = int(reply)
        except ValueError:
            print(f"{reply!r} is not an integer; please try again.")
            continue
        if value > 0:
            return value
        print("Please enter a positive integer.")


def _prompt_probability_row(s, a, n_states):
    """Read the successor weights for (state s, action a) and return them normalized.

    The whole row is re-requested when any entry is negative or non-finite, or
    when the entries sum to zero. Such a row cannot be normalized into a
    probability distribution; dividing by a zero sum would produce NaNs.
    """
    while True:
        print(f"Enter transition probabilities for state {s}, action {a}:")
        row = np.array([
            _prompt_float(f"P(s'={s_prime} | s={s}, a={a}): ")
            for s_prime in range(n_states)
        ])
        total = row.sum()
        if np.all(np.isfinite(row)) and np.all(row >= 0) and np.isfinite(total) and total > 0:
            # Normalize the probabilities
            return row / total
        print("Probabilities must be finite, non-negative, and not all zero; please re-enter this row.")


# Get user inputs
num_states = _prompt_positive_int("Enter the number of states: ")
num_actions = _prompt_positive_int("Enter the number of actions: ")

# Get user input for transition probabilities
user_input_transition = input("Do you want to provide a transition probability matrix? (yes/no): ").strip().lower()
transition_probabilities = None

if user_input_transition in ("yes", "y"):
    transition_probabilities = np.zeros((num_states, num_actions, num_states))
    for s in range(num_states):
        for a in range(num_actions):
            transition_probabilities[s, a] = _prompt_probability_row(s, a, num_states)
else:
    print("Random transition probabilities will be generated.")

# Optionally allow user to input rewards
user_input_reward = input("Do you want to provide a reward matrix? (yes/no): ").strip().lower()
reward = None
if user_input_reward in ("yes", "y"):
    reward = np.zeros((num_states, num_actions))
    for s in range(num_states):
        for a in range(num_actions):
            reward[s, a] = _prompt_float(f"Enter the reward for state {s}, action {a}: ")

# Create the MDP
mdp = MDP(num_states, num_actions, transition_probabilities, reward)

# Visualize the MDP
mdp.visualize_mdp()
