### Small example of the REINFORCE algorithm on a GridPath problem

We first create a simple enviroment, a grid word with boundaries plus obstacles (X). We want to reach the goal (G) from the current position of the agent (A).

In [2]:
from RLTools.Envs.GridEnv import GridEnv

env = GridEnv(size=5)
env.set_obstacles([(1, 1), (2, 2), (3, 3)])
env.render(mode='human')
env.close()

 . . . . 
 .X. . . 
 . .X.A. 
 . . .X. 
 . . . .G



In [3]:
import torch
from tqdm import tqdm
from RLTools.RLPolicies.REINFORCE import DiscreteREINFORCE

input_size = 2
hidden_size = 4
output_size = 4

network = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(hidden_size, output_size),
    torch.nn.Softmax(dim=-1)
)

policy = DiscreteREINFORCE(network)

optimizer = torch.optim.Adam(policy.parameters(), 0.01)

### --- training loop

for _ in tqdm(range(100000)):
    observation, info = env.reset()

    episode_over = False

    rewards = []
    actions = []
    observations = [observation]
    log_probs = []
    
    counter = 0
    while not episode_over:
        action, log_prob = policy.sample_training(observation)
        actions.append(action)
        log_probs.append(log_prob)

        observation, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        # print(observation, action, reward)

        counter += 1
        episode_over = terminated or truncated
        if not episode_over:
            observations.append(observation)

    # if sum(rewards) > 0:
    #     print(sum(rewards), len(rewards))

    loss = policy.reward(log_probs, rewards)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

env.close()


 23%|██▎       | 22685/100000 [01:06<03:47, 340.30it/s]


KeyboardInterrupt: 

### Visualization

In [4]:
env = GridEnv(size=5) # new env
env.set_obstacles([(1, 1), (2, 2), (3, 3)])

for _ in range(1):

    observation, info = env.reset(state=(0, 0))


    episode_over = False

    rewards = []
    actions = []
    observations = [observation]

    counter = 0
    while not episode_over:
        env.render()

        #action = env.action_space.sample()  # agent policy that uses the observation and info
        action = policy.sample_best(observation)
        actions.append(action)
        print(env.translate_action_to_human(action))


        observation, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        print(sum(rewards))
        counter +=1
        episode_over = terminated or truncated 
        if not episode_over:
            observations.append(observation)

env.render()
env.close()

A. . . . 
 .X. . . 
 . .X. . 
 . . .X. 
 . . . .G

down
-0.01
 . . . . 
A.X. . . 
 . .X. . 
 . . .X. 
 . . . .G

down
-0.02
 . . . . 
 .X. . . 
A. .X. . 
 . . .X. 
 . . . .G

down
-0.03
 . . . . 
 .X. . . 
 . .X. . 
A. . .X. 
 . . . .G

down
-0.04
 . . . . 
 .X. . . 
 . .X. . 
 . . .X. 
A. . . .G

right
-0.05
 . . . . 
 .X. . . 
 . .X. . 
 . . .X. 
 .A. . .G

right
-0.060000000000000005
 . . . . 
 .X. . . 
 . .X. . 
 . . .X. 
 . .A. .G

right
-0.07
 . . . . 
 .X. . . 
 . .X. . 
 . . .X. 
 . . .A.G

right
999.93
 . . . . 
 .X. . . 
 . .X. . 
 . . .X. 
 . . . .G

