# In [3]:
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

# --- env ---
from envs import MultiAgentMazeEnv

# --- policies ---
from policies import SoftmaxPolicy
from utils.init_policy import initialize_policy_with_manual_probs   # 👈 REQUIRED

# --- training ---
from training import reinforce_multi_rwd2go_alt_barrier

# Grid geometry: the hand-written policy covers an H0 x W0 maze, which is then
# surrounded on every side by `pad` layers of wall cells.
H0 = W0 = 5
A = 4        # number of actions, ordered (up, right, down, left)
eps = 1e-3   # floor applied to every action probability before renormalising
pad = 2
H, W = H0 + 2 * pad, W0 + 2 * pad   # size of the padded grid

# Hand-written initial action distribution for the unpadded maze, shape
# (H0, W0, A): probs[row, col] = [P(up), P(right), P(down), P(left)].
# In open cells, moves that would leave the grid or step into the inner wall
# (column 3, rows 0-3) get probability 0; the wall cells themselves are uniform.
probs = np.array(
    [
        [  # row 0
            [0.0, 1/2, 1/2, 0.0],
            [0.0, 1/3, 1/3, 1/3],
            [0.0, 0.0, 1/2, 1/2],
            [1/4, 1/4, 1/4, 1/4],  # inner wall
            [0.0, 0.0, 1.0, 0.0],
        ],
        [  # row 1
            [1/3, 1/3, 1/3, 0.0],
            [1/4, 1/4, 1/4, 1/4],
            [1/3, 0.0, 1/3, 1/3],
            [1/4, 1/4, 1/4, 1/4],  # inner wall
            # NOTE(review): this cell becomes the goal (3, 6) once padded. Unlike
            # (2, 4)/(3, 4) it keeps a uniform row, so right (off-grid) and
            # left (wall) get mass too -- confirm this is intended.
            [1/4, 1/4, 1/4, 1/4],
        ],
        [  # row 2
            [1/3, 1/3, 1/3, 0.0],
            [1/4, 1/4, 1/4, 1/4],
            [1/3, 0.0, 1/3, 1/3],
            [1/4, 1/4, 1/4, 1/4],  # inner wall
            [1/2, 0.0, 1/2, 0.0],
        ],
        [  # row 3
            [1/3, 1/3, 1/3, 0.0],
            [1/4, 1/4, 1/4, 1/4],
            [1/3, 0.0, 1/3, 1/3],
            [1/4, 1/4, 1/4, 1/4],  # inner wall
            [1/2, 0.0, 1/2, 0.0],
        ],
        [  # row 4
            [1/2, 1/2, 0.0, 0.0],
            [1/3, 1/3, 0.0, 1/3],
            [1/3, 1/3, 0.0, 1/3],
            [0.0, 1/2, 0.0, 1/2],
            [1/2, 0.0, 0.0, 1/2],
        ],
    ],
    dtype=float,
)

# Embed the hand-written H0 x W0 policy in the centre of the padded H x W grid.
# The `pad`-wide border (the outer wall cells) gets a uniform 1/A distribution.
probs_padded = np.pad(
    probs,
    pad_width=((pad, pad), (pad, pad), (0, 0)),
    mode="constant",
    constant_values=1 / A,
)

# Numerical safety: floor every probability at eps so no action has exactly
# zero mass, then renormalise each cell's distribution to sum to 1.
probs_padded = np.clip(probs_padded, eps, None)
probs_padded /= probs_padded.sum(axis=-1, keepdims=True)

# Every cell in the `pad`-wide border of the padded grid is an outer wall,
# listed in row-major order.
outer_walls = [
    (i, j)
    for i in range(H)
    for j in range(W)
    if not (pad <= i < H - pad and pad <= j < W - pad)
]

# Interior walls are written in unpadded (row, col) coordinates: column 3,
# rows 0-3. They are shifted by `pad` into padded-grid coordinates.
inner_walls = [(row + pad, col + pad) for row, col in [(0, 3), (1, 3), (2, 3), (3, 3)]]


# Build the maze environment. One start/goal pair is passed, so presumably
# a single agent. Coordinates are padded-grid (row, col), i.e. unpadded maze
# cells offset by `pad`: start = maze cell (0, 1), goal = maze cell (1, 4).
# The size and the coordinates are derived from H/W/pad so they stay correct
# if the padding or the maze size changes.
env = MultiAgentMazeEnv(
    size=(H, W),  # NOTE(review): assumes (rows, cols); H == W here, so order is moot for now
    starts=[(pad + 0, pad + 1)],
    goals=[(pad + 1, pad + 4)],
    inner_walls=inner_walls,
    outer_walls=outer_walls,
)
env.render()

# One SoftmaxPolicy per agent over the padded H x W grid, each with its own
# Adam optimizer (lr=0.1).
policies = [
    SoftmaxPolicy(width=W, height=H, num_actions=A)
    for _ in range(env.n_agents)
]
optimizers = [optim.Adam(policy.parameters(), lr=0.1) for policy in policies]

# Seed the first agent's policy from the padded, hand-written probabilities.
# NOTE(review): any further agents keep SoftmaxPolicy's default initialisation -- confirm intended.
initialize_policy_with_manual_probs(policies[0], probs_padded)

# Train. The three results are unpacked as scores, barrier and violation; the
# plotting code below indexes barrier/violation once per agent.
scores, barrier, violation = reinforce_multi_rwd2go_alt_barrier(env, policies, optimizers)

# Twin-axis figure: barrier penalty (solid, left axis) and constraint
# violation count (dashed, right axis), one pair of curves per agent.
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis

for i, penalty_curve in enumerate(barrier):
    ax1.plot(penalty_curve, linestyle='-', label=f"Agent {i} Penalty")
    ax2.plot(violation[i], linestyle='--', color='red', label=f"Agent {i} Violations")

# Shared x label; each y label is coloured to match its side of the plot.
ax1.set_xlabel("Episode")
ax1.set_ylabel("Mean Barrier Penalty", color='tab:blue')
ax2.set_ylabel("Violation Number", color='tab:red')

# One legend per axis, in opposite top corners so they do not overlap.
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')

plt.title("Barrier Penalty and Constraint Violations Over Training")
plt.show()

# --- Plot Expected Reward (Value Function) ---
# NOTE(review): nothing in this cell defines `V`. The training call above
# returns `scores`, which is otherwise unused, so it is plotted here as the
# return curve. Confirm this is the intended series.
V = scores
# If `scores` is a single flat curve rather than one curve per agent, wrap it.
# Otherwise the per-agent loop below would plot one scalar point per "agent".
if len(V) > 0 and np.ndim(V[0]) == 0:
    V = [V]

plt.figure(figsize=(8, 5))

for i, curve in enumerate(V):
    plt.plot(curve, label=f"Agent {i} Value", linewidth=2)

plt.xlabel("Batch Number")
plt.ylabel("Estimated Expected Return V(s)")
plt.title("Expected Return (Value Function) Over Training")
plt.legend()
plt.grid(True)

plt.show()

# # # # # # # # #
# # # # # # # # #
# #   0   #   # #
# #       # [92m0[0m # #
# #       #   # #
# #       #   # #
# #           # #
# # # # # # # # #
# # # # # # # # #
# ----------


# NameError: name 'torch' is not defined