### CartPole solved with a simple deep REINFORCE implementation

In [None]:
import torch
from tqdm import tqdm
import gymnasium as gym
from RLTools.RLPolicies.REINFORCE import DiscreteREINFORCE
# torch based implementation

# Training hyperparameters.
num_episodes = 1000
learning_rate = 0.01

env = gym.make('CartPole-v1')

# Derive the input/output layer sizes from the environment instead of
# hard-coding them (for CartPole-v1 these are 4 and 2 respectively).
input_size = env.observation_space.shape[0]
hidden_size = 4
output_size = int(env.action_space.n)

network = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(hidden_size, output_size),
    torch.nn.Softmax(dim=-1),  # network outputs a probability per action
)

policy = DiscreteREINFORCE(network)  # DNN implementation of REINFORCE's policy

optimizer = torch.optim.Adam(policy.parameters(), learning_rate)

### --- training loop
# One gradient update per episode: roll out a full episode with the current
# policy, then build a loss from the collected log-probabilities and rewards.
# The env is closed in `finally` so it is released even if training raises
# or is interrupted (e.g. KeyboardInterrupt in the notebook).
try:
    for _ in tqdm(range(num_episodes)):
        observation, info = env.reset()

        rewards = []
        log_probs = []

        episode_over = False
        while not episode_over:
            action, log_prob = policy.sample_training(observation)
            log_probs.append(log_prob)

            observation, reward, terminated, truncated, info = env.step(action)
            rewards.append(reward)

            episode_over = terminated or truncated

        # NOTE(review): policy.reward presumably computes the REINFORCE loss
        # from the episode's log-probs and rewards; see DiscreteREINFORCE.
        loss = policy.reward(log_probs, rewards)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
finally:
    env.close()

### Visualization

In [None]:

env = gym.make('CartPole-v1', render_mode="human")

num_demo_episodes = 10
# Upper bound on steps shown per episode; the loop ends after
# max_render_steps + 1 steps, or earlier if the environment ends the episode.
max_render_steps = 100

try:
    for _ in range(num_demo_episodes):
        observation, info = env.reset()

        # Per-episode trajectory, kept so it can be inspected after the cell runs.
        rewards = []
        actions = []
        observations = [observation]

        counter = 0
        episode_over = False
        while not episode_over:
            # Acting only (no training step here), so skip building an autograd graph.
            with torch.no_grad():
                action = policy.sample_best(observation)
            actions.append(action)

            observation, reward, terminated, truncated, info = env.step(action)
            rewards.append(reward)

            counter += 1
            # Stop as soon as the env reports the episode is over: calling
            # step() after terminated/truncated without reset() is undefined.
            episode_over = terminated or truncated or counter > max_render_steps
            if not episode_over:
                observations.append(observation)
finally:
    env.close()