[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rezer0dai/TD3_BC/blob/her/td3_bc_her.ipynb)

In [None]:
!git clone https://github.com/rezer0dai/TD3_BC -b her

!git clone https://github.com/qgallouedec/panda-gym

!pip install -e panda-gym
!pip install dataclasses

In [3]:
import sys

libs = ["TD3_BC", "panda-gym"]
# Make each cloned repository importable both when running from the
# notebook's working directory (relative path) and from Colab's /content.
for lib_name in libs:
    sys.path.extend([lib_name, "/content/" + lib_name])

In [4]:
import numpy as np
import torch
import gym
import argparse
import os

import utils
import TD3_BC

In [5]:
def eval_policy(policy, eval_env, seed, mean, std, seed_offset=100, eval_episodes=10):
    """Roll out ``policy`` for ``eval_episodes`` episodes and report the mean return.

    Args:
        policy: object exposing ``select_action(state)``.
        eval_env: environment whose ``reset()``/``step()`` observations are
            dicts with an ``"observation"`` entry; it is reseeded with
            ``seed + seed_offset`` before the first episode.
        seed: base seed; ``seed_offset`` is added to it for evaluation.
        mean, std: normalization applied to each observation (reshaped to a
            single row) before it is handed to the policy.
        seed_offset: offset added to ``seed`` for the evaluation seed.
        eval_episodes: number of episodes to average over.

    Returns:
        Total reward collected over all episodes divided by ``eval_episodes``.
    """
    def normalize(obs):
        return (obs["observation"].reshape(1, -1) - mean) / std

    eval_env.seed(seed + seed_offset)

    total = 0.
    for _ in range(eval_episodes):
        obs = eval_env.reset()
        done = False
        while not done:
            action = policy.select_action(normalize(obs))
            obs, reward, done, _ = eval_env.step(action)
            total += reward

    avg_reward = total / eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward


In [6]:
import random
def her(replay_buffer, n_relabel=20, goal_dim=3, threshold=.05):
    """Hindsight-relabel the most recently finished episode in ``replay_buffer``.

    Walks backwards from slot ``replay_buffer.ptr - 1`` (the last stored
    transition), collecting the whole episode until it reaches a slot whose
    ``not_done`` flag is 0. That slot closes the previous episode.

    For each of ``n_relabel`` rounds, every collected transition except the
    final one is copied and appended to the buffer with:
    - its goal replaced by a position reached later in the same episode
      ("future" strategy);
    - its reward recomputed.

    Finally, the episode's final transition is appended once more (with
    ``done=True``). The next call's backward walk therefore stops there
    instead of running into the relabeled copies, which are stored as
    not-done.

    Args:
        replay_buffer: buffer exposing ``ptr``, the per-slot arrays
            ``state``/``action``/``next_state``/``reward``/``not_done`` and
            ``add(state, action, next_state, reward, done)``.
        n_relabel: number of relabeling rounds per episode (default 20).
        goal_dim: number of goal coordinates. The first ``goal_dim`` entries
            of a state are read as the achieved position. The last
            ``goal_dim`` entries are overwritten with the relabeled goal.
        threshold: a relabeled transition whose next-state position is
            within this distance of the goal gets reward 0, otherwise -1.

    NOTE(review): states are stored after the training loop's mean/std
    normalization. Copying position dims into goal dims and thresholding in
    that space is only consistent when mean/std are scalars (e.g. 0/1).
    Confirm this before enabling normalization.
    """
    i = replay_buffer.ptr
    if not i:
        # Nothing stored yet (first call). NOTE(review): this also skips an
        # episode that ends exactly when ``ptr`` wraps back to 0.
        return

    capacity = len(replay_buffer.not_done)

    # Collect the episode newest-first: ep[0] is the final transition.
    # Rows are copied because the add() calls below may wrap around the
    # buffer and overwrite these slots while they are still being read.
    ep = []
    k = i - 1
    while len(ep) < capacity:
        ep.append([
            replay_buffer.state[k].copy(),
            replay_buffer.action[k].copy(),
            replay_buffer.next_state[k].copy(),
            replay_buffer.reward[k].copy(),
            not ep,  # only the final transition is flagged done
        ])
        k -= 1  # a negative k indexes from the end of the buffer arrays
        if not replay_buffer.not_done[k]:
            # Slot k is marked done: it closes the previous episode
            # (assumes never-written slots also hold not_done == 0).
            break

    for _ in range(n_relabel):
        # ep[1:] are the earlier transitions. For ep[j + 1], the entries
        # ep[:j + 1] all occur later in the episode, so their positions are
        # goals that were actually reached afterwards.
        for j, (s, a, n, _r, d) in enumerate(ep[1:]):
            goal = random.choice(ep[:j + 1])[0][:goal_dim]
            s, n = s.copy(), n.copy()
            s[-goal_dim:] = goal
            n[-goal_dim:] = goal
            r = -1. * (np.linalg.norm(n[:goal_dim] - goal) > threshold)
            replay_buffer.add(s, a, n, r, d)

    # Re-append the final transition as the done-flagged episode boundary.
    replay_buffer.add(*ep[0])


In [None]:
from config import Config
from open_gym import make_env

if __name__ == "__main__":
    cfg = Config()

    file_name = f"{cfg.env}_{cfg.seed}"
    print("---------------------------------------")
    print(f"Policy: TD3_BC+HER, Env: {cfg.env}, Seed: {cfg.seed}")
    print("---------------------------------------")

    env = make_env(cfg.env, render=True, colab=True)

    # Set seeds. `random` drives the goal sampling in her(), so it is seeded
    # too; otherwise cfg.seed does not make runs reproducible.
    env.seed(cfg.seed)
    env.action_space.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    state_dim = env.state_size()
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": cfg.discount,
        "tau": cfg.tau,
        # TD3
        "policy_noise": cfg.policy_noise * max_action,
        "noise_clip": cfg.noise_clip * max_action,
        "policy_freq": cfg.policy_freq,
        # TD3 + BC
        "alpha": cfg.alpha,
    }

    # Initialize policy
    policy = TD3_BC.TD3_BC(**kwargs)

    replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
    if cfg.normalize:
        # NOTE(review): the buffer is still empty at this point, so these
        # statistics are not computed from collected data. Confirm this is
        # intended before enabling cfg.normalize.
        mean, std = replay_buffer.normalize_states()
    else:
        mean, std = 0, 1

    print("---------------------------------------")
    print(f"Policy TD+BC+HER: , Env: {cfg.env}, Seed: {cfg.seed}, Observation shape: {state_dim}")
    print("---------------------------------------")

    # Normalize the dict observation's "observation" entry into a single row.
    load_state = lambda obs: (obs["observation"].reshape(1,-1) - mean)/std

    done = True  # forces a reset on the first iteration
    total_steps = cfg.steps_per_epoch * cfg.epochs
    for t in range(total_steps):

        if done:
            # Relabel the episode that just ended, then start a new one.
            her(replay_buffer)
            state = load_state(env.reset())

        # Random actions during warm-up, then the policy's action.
        if t > cfg.start_steps:
            action = policy.select_action(state)
        else:
            action = env.action_space.sample()

        observation, reward, done, _ = env.step(action)
        next_state = load_state(observation)

        replay_buffer.add(state, action, next_state, reward, done)

        state = next_state

        if t > cfg.update_after and 0 == t % cfg.update_every:
            for _ in range(cfg.update_every):
                policy.train(replay_buffer, cfg.batch_size)

        # Evaluate episode
        if (t + 1) % cfg.eval_freq == 0:
            print(f"Time steps: {t+1}")
            # NOTE(review): evaluation runs on (and reseeds) the training env.
            # Afterwards, `state` no longer matches the env, and training
            # episodes replay the same seeded sequence after every
            # evaluation. A separate eval env would avoid both problems.
            eval_policy(policy, env, cfg.seed, mean, std)