In [None]:
import torch
import numpy as np
from vctr.pg.model import PolicyNetwork

def _discounted_rewards_to_go(rewards, gamma):
    """Return the discounted reward-to-go for every step of one episode.

    Entry ``t`` equals ``sum(gamma**k * rewards[t + k] for k >= 0)``. The
    values are accumulated backwards in a single pass, so the cost is linear
    in the episode length instead of quadratic.

    Args:
        rewards: Sequence of per-step scalar rewards, in time order.
        gamma: Discount factor applied once per step.

    Returns:
        A 1-D float64 NumPy array with the same length as ``rewards``.
    """
    returns = np.empty(len(rewards), dtype=np.float64)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        # G_t = r_t + gamma * G_{t+1}
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# NOTE: `env`, `gamma`, `optimizer` and `compute_loss` are not defined in this
# cell; they must already exist in the enclosing scope (e.g. an earlier cell).
num_episodes = 1000
policy = PolicyNetwork()

for episode in range(num_episodes):
    # Collect one full episode with the current policy.
    state = env.reset()
    states, actions, rewards = [], [], []

    done = False
    while not done:
        # Add a leading batch dimension before feeding the state to the policy.
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        # Sampling needs no gradients, so skip autograd bookkeeping here.
        with torch.no_grad():
            action_probs = policy(state_tensor).numpy().squeeze()

        # Sample an action index according to the policy's probabilities.
        action = np.random.choice(len(action_probs), p=action_probs)

        # NOTE(review): assumes `env.step` returns a 4-tuple
        # (next_state, reward, done, info) -- confirm against the env in use.
        next_state, reward, done, _ = env.step(action)

        # Record the transition taken from `state`.
        states.append(state)
        actions.append(action)
        rewards.append(reward)

        state = next_state

    # Discounted return from each step to the end of the episode.
    rewards_to_go = _discounted_rewards_to_go(rewards, gamma)

    # Stack the per-step states into a single array first: building a tensor
    # directly from a Python list of arrays is much slower.
    states_tensor = torch.tensor(np.asarray(states), dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long)
    rewards_to_go_tensor = torch.tensor(rewards_to_go, dtype=torch.float32)

    # One gradient step on the episode's loss.
    optimizer.zero_grad()
    loss = compute_loss(states_tensor, actions_tensor, rewards_to_go_tensor)
    loss.backward()
    optimizer.step()

    # Print the episode and loss for monitoring
    if episode % 100 == 0:
        print(f'Episode {episode}, Loss: {loss.item()}')
