In [None]:
import gymnasium as gym
import time
import torch
import ppo

In [None]:
# Build the FrozenLake environment (4x4 map, slippery transitions enabled)
# and show its action and observation spaces.
env = gym.make(
    "FrozenLake-v1",
    desc=None,
    map_name="4x4",
    is_slippery=True,
)
for space in (env.action_space, env.observation_space):
    print(space)

In [None]:
# Problem dimensions are read from the environment's discrete spaces.
state_dim = env.observation_space.n
action_dim = env.action_space.n

# Hyperparameters handed positionally to ppo.PPO below. Their exact meaning is
# defined by the ppo module (not visible here); the names suggest optimizer
# learning rates, a discount factor, update epochs and a clipping range.
lr_actor = lr_critic = 1e-3
gamma = 0.99
K_epochs = 5
eps_clip = 0.2

# The action space is discrete; action_std_init is still passed because the
# constructor takes it positionally (presumably only relevant for continuous
# actions -- verify against ppo.PPO).
has_continuous_action_space = False
action_std_init = 1

ppo_agent = ppo.PPO(
    state_dim,
    action_dim,
    lr_actor,
    lr_critic,
    gamma,
    K_epochs,
    eps_clip,
    has_continuous_action_space,
    action_std_init,
)

In [None]:
# Training loop: roll out episodes on `env`, store rewards / episode-boundary
# flags in the agent's buffer, and call ppo_agent.update() every
# `update_timestep` environment steps.
time_step = 0
max_training_timesteps = 1000  # no new episode is started once this is exceeded
max_ep_len = 100               # hard cap on the number of steps in one episode
update_timestep = 5            # environment steps between ppo_agent.update() calls
print_freq = 10                # interpreted here as: episodes between progress reports

episode_count = 0
reward_since_report = 0.0

try:
    while time_step <= max_training_timesteps:
        observation, info = env.reset()
        current_ep_reward = 0

        for t in range(1, max_ep_len + 1):

            # select action with policy (the discrete state is one-hot encoded)
            obs = torch.nn.functional.one_hot(
                torch.LongTensor([observation]), num_classes=state_dim
            )
            action = ppo_agent.select_action(obs.float().squeeze(0))
            observation, reward, terminated, truncated, info = env.step(action)

            # The episode ends on termination, on truncation (e.g. a time
            # limit), or when the local step cap is hit. All three are recorded
            # as a boundary because the next buffered transition comes after
            # env.reset(); flagging only `terminated` would leave consecutive
            # episodes joined in the buffer.
            # NOTE(review): presumably ppo_agent.update() uses is_terminals to
            # reset the discounted return at episode boundaries -- confirm in ppo.py.
            done = terminated or truncated or t == max_ep_len

            # saving reward and is_terminals
            ppo_agent.buffer.rewards.append(reward)
            ppo_agent.buffer.is_terminals.append(done)

            time_step += 1
            current_ep_reward += reward

            # update PPO agent
            if time_step % update_timestep == 0:
                ppo_agent.update()

            # break; if the episode is over
            if done:
                break

        # Progress report: mean episode reward over the last `print_freq` episodes.
        episode_count += 1
        reward_since_report += current_ep_reward
        if episode_count % print_freq == 0:
            avg_reward = reward_since_report / print_freq
            print(
                f"Episode: {episode_count}\tTimestep: {time_step}\t"
                f"Average reward: {avg_reward:.2f}"
            )
            reward_since_report = 0.0
finally:
    # Release the environment even if training is interrupted or raises.
    env.close()