### DDPG on the CartPole problem

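We consider a variant of the CartPole problem with a continuous action space, using Gymnasium's MuJoCo `InvertedPendulum-v5` environment: at each step the agent applies a single real-valued force to the cart, bounded in [-3, 3].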

In [None]:
import torch
from tqdm import tqdm
import gymnasium as gym
from torch.utils.tensorboard import SummaryWriter
from RLTools.RLPolicies.DDPG import DDPG
from RLTools.Utils.ReplayBuffer import ReplayBuffer

# We also show how to make tensorboard logs!
import datetime
log_dir = 'logs/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/"

env = gym.make('InvertedPendulum-v5', reset_noise_scale=0.1)

# Layer sizes are derived from the environment instead of being hard-coded
# (InvertedPendulum-v5: 4-dimensional observation, 1-dimensional action).
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]
hidden_size_p = 2   # NOTE(review): a 2-unit hidden layer is very small for the actor — confirm intended
hidden_size_q = 16
output_size = 1     # the critic outputs one scalar value per (state, action) input

# Actor network: state -> action (one output per action dimension).
p_net = torch.nn.Sequential(
    torch.nn.Linear(state_size, hidden_size_p),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(hidden_size_p, action_size),
)
# Critic network: takes state_size + action_size input features
# (how state and action are combined is up to DDPG) and outputs a scalar.
q_net = torch.nn.Sequential(
    torch.nn.Linear(state_size + action_size, hidden_size_q),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(hidden_size_q, output_size),
)

# Fall back to the CPU so the notebook also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
policy = DDPG(q_net, p_net, device=device)

optimizer = torch.optim.Adam(policy.parameters(), 0.0001)
writer = SummaryWriter(log_dir)

n_episodes = 100000
n_rollouts = 5
batch_size = 256
training_iters = 5

rb = ReplayBuffer()  # for this algorithm we rely on a replay buffer

### --- training loop
#
# Each outer iteration ("episode") first collects `n_rollouts` exploratory
# rollouts into the replay buffer, then performs `training_iters` gradient
# steps on mini-batches sampled from it.
#
# The Q (critic) and P (actor) losses are back-propagated separately, each with
# the *other* network frozen: the critic loss must not move the actor, and the
# actor loss must not move the critic. After every optimizer step
# `policy.soft_update()` is called (presumably the DDPG target-network update).

# Valid action range of the environment ([-3, 3] for InvertedPendulum-v5).
action_low = float(env.action_space.low[0])
action_high = float(env.action_space.high[0])
exploration_std = 1.0  # std of the Gaussian noise added to the actor's action

try:
    for episode in tqdm(range(n_episodes)):

        for rollout in range(n_rollouts):
            observation, info = env.reset()
            episode_over = False

            rewards = []
            actions = []
            observations = [observation]

            while not episode_over:  # unroll one episode
                # Detach so the action stored below does not keep the actor's
                # autograd graph alive inside the replay buffer.
                action = policy.sample(observation, training=True).detach()

                # Gaussian exploration noise, then clip to the valid action range.
                noise = exploration_std * torch.randn_like(action)
                action = torch.clip(action + noise, action_low, action_high)

                # Stored on the CPU: batches are rebuilt with torch.tensor(...)
                # and moved to policy.device at training time anyway.
                actions.append(action.cpu())

                observation, reward, terminated, truncated, info = env.step(
                    action.cpu().numpy().reshape(env.action_space.shape)
                )
                rewards.append(reward)

                episode_over = terminated or truncated
                # NOTE(review): the final observation is not stored, so
                # len(observations) == len(actions); presumably ReplayBuffer.add
                # marks the last transition as terminal — confirm, since a
                # time-limit truncation would then also be treated as terminal.
                if not episode_over:
                    observations.append(observation)

            writer.add_scalar("reward", sum(rewards), episode * n_rollouts + rollout)
            rb.add(observations, actions, rewards)

        for training_iter in range(training_iters):
            batch = rb.sample(batch_size)
            # Normalise by the number of samples actually returned, which may be
            # smaller than batch_size depending on ReplayBuffer.sample.
            n_samples = len(batch)

            bstate = torch.tensor([b[0] for b in batch], dtype=torch.float32).to(policy.device)
            baction = torch.tensor([b[1] for b in batch], dtype=torch.float32).to(policy.device).unsqueeze(1)
            bnew_state = torch.tensor([b[2] for b in batch], dtype=torch.float32).to(policy.device)
            breward = torch.tensor([b[3] for b in batch], dtype=torch.float32).to(policy.device)
            bis_terminal = torch.tensor([b[4] for b in batch], dtype=torch.float32).to(policy.device)

            optimizer.zero_grad()

            # Critic loss with the actor frozen: its gradients must not reach P.
            policy.freeze_p()
            q_loss = policy.q_loss(bstate, baction, bnew_state, breward, bis_terminal).sum() / n_samples
            q_loss.backward()
            policy.unfreeze_p()

            # Actor loss with the critic frozen: its gradients must not reach Q.
            policy.freeze_q()
            p_loss = policy.p_loss(bstate, baction, bnew_state, breward, bis_terminal).sum() / n_samples
            p_loss.backward()
            policy.unfreeze_q()

            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=0.5)  # grad clipping

            optimizer.step()

            global_step = episode * training_iters + training_iter
            writer.add_scalar("Ploss", -p_loss.item(), global_step)
            writer.add_scalar("Qloss", q_loss.item(), global_step)

            policy.soft_update()
finally:
    # Write pending TensorBoard events even if training is interrupted.
    writer.flush()

env.close()

### Visualization

In [3]:

# Render a few episodes of the learned (noise-free) policy.
env = gym.make('InvertedPendulum-v5', render_mode="human", reset_noise_scale=0.3)

n_eval_episodes = 10
max_eval_steps = 200  # an episode runs for at most max_eval_steps + 1 steps
# Valid action range of the environment ([-3, 3] for InvertedPendulum-v5).
action_low = float(env.action_space.low[0])
action_high = float(env.action_space.high[0])

try:
    for _ in range(n_eval_episodes):
        observation, info = env.reset()

        episode_over = False

        rewards = []
        actions = []
        observations = [observation]

        counter = 0
        while not episode_over:
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                action = policy.sample(observation)
            action = torch.clip(action, action_low, action_high)
            actions.append(action)

            observation, reward, terminated, truncated, info = env.step(
                action.detach().cpu().numpy().reshape(env.action_space.shape)
            )
            rewards.append(reward)
            counter += 1
            # Stepping after `terminated`/`truncated` is undefined in the
            # Gymnasium API, so stop there as well as at the step cap.
            episode_over = terminated or truncated or counter > max_eval_steps
            if not episode_over:
                observations.append(observation)
finally:
    # Always close the render window, even when the cell is interrupted.
    env.close()



KeyboardInterrupt: 