[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rezer0dai/TD3_BC/blob/her/td3_bc_her.ipynb)

In [None]:
# Fetch the sources this notebook imports from: the `her` branch of TD3_BC,
# the TD3 repository, and panda-gym.
!git clone https://github.com/rezer0dai/TD3_BC -b her
!git clone https://github.com/sfujim/TD3
    
!git clone https://github.com/qgallouedec/panda-gym

# Install the cloned panda-gym checkout in editable mode.
!pip install -e panda-gym

In [None]:
import sys

# Directories of the cloned repositories. Each one is put on the import path
# twice: as a path relative to the working directory and under /content.
libs = ["TD3_BC", "TD3", "panda-gym"]
for lib in libs:
    sys.path.extend((lib, "/content/" + lib))

In [None]:
import numpy as np
import torch
import gym
import argparse
import os

import utils
import TD3_BC
import TD3
import OurDDPG
import config

In [None]:
def eval_policy(policy, eval_env, seed, normalize_state, seed_offset=100, eval_episodes=10):
    """Roll out ``policy`` on ``eval_env`` and report the average episode return.

    The environment is seeded once with ``seed + seed_offset``. For every
    step, the ``"observation"`` entry of the environment's observation dict
    is reshaped to a single row, passed through ``normalize_state`` and fed
    to ``policy.select_action``.

    Args:
        policy: Object exposing ``select_action(state)``.
        eval_env: Environment exposing ``seed``, ``reset`` and a ``step``
            that returns ``(observation, reward, done, info)``.
        seed: Base seed for the evaluation environment.
        normalize_state: Callable applied to the row-shaped observation
            before it reaches the policy.
        seed_offset: Value added to ``seed`` before seeding the environment.
        eval_episodes: Number of episodes to roll out.

    Returns:
        The sum of all rewards collected, divided by ``eval_episodes``.
    """
    def as_row(obs):
        return obs["observation"].reshape(1, -1)

    eval_env.seed(seed + seed_offset)

    total_reward = 0.
    for _episode in range(eval_episodes):
        obs = eval_env.reset()
        finished = False
        while not finished:
            features = normalize_state(as_row(obs))
            chosen = policy.select_action(features)
            obs, step_reward, finished, _info = eval_env.step(chosen)
            total_reward += step_reward

    mean_reward = total_reward / eval_episodes

    rule = "---------------------------------------"
    print(rule)
    print(f"Evaluation over {eval_episodes} episodes: {mean_reward:.3f}")
    print(rule)
    return mean_reward


In [None]:
import random
from copy import deepcopy

def her(replay_buffer, achieved_goals):
    """Re-write the most recent episode in ``replay_buffer`` with hindsight goals.

    The last ``len(achieved_goals)`` transitions in the buffer are treated as
    one episode, with ``achieved_goals[i]`` belonging to the i-th of them.
    The buffer's ``ptr``/``size`` are rewound past that episode, and the
    episode is then re-added ``config.HER_PER_EP`` times with goals replaced
    by achieved goals sampled from the same or a later step. The original
    transitions are interleaved subject to ``config.HER_RATIO``.

    Args:
        replay_buffer: Buffer exposing ``ptr``, ``size``, the ``state`` /
            ``action`` / ``next_state`` / ``reward`` arrays, ``add`` and
            ``normalize_state``. It is mutated in place.
        achieved_goals: One achieved-goal vector per transition of the
            episode, in step order.

    Returns:
        ``False`` when ``achieved_goals`` is empty or when every achieved goal
        lies within 0.05 of the first one; ``True`` otherwise.
    """
    if not len(achieved_goals):
        return False
    # 1 when the buffer holds more than just this episode, else 0.
    # presumably this accounts for the extra done=True transition appended at
    # the end of an earlier call - TODO confirm.
    leftover = 1 * (replay_buffer.size != len(achieved_goals))
    # Rewind the write position past the episode (plus `leftover` slots) so
    # the adds below overwrite it.
    # NOTE(review): no wrap-around handling is visible here; this assumes the
    # episode did not straddle the end of the buffer - verify against
    # ReplayBuffer.add.
    replay_buffer.ptr = replay_buffer.ptr - len(achieved_goals) - leftover
    replay_buffer.size = replay_buffer.size - len(achieved_goals) - leftover

    # Bail out when the achieved goal never left a 0.05 ball around its first
    # value. NOTE(review): ptr/size were already rewound above and nothing is
    # re-added on this path, so the episode's slots are released.
    if all(np.linalg.norm(achieved_goals[0]- g) < .05 for g in achieved_goals):
        return False

    # First slot that this call will write; used for the normalisation slice
    # at the end.
    norm_ind = replay_buffer.ptr

    # Snapshot the episode before its slots are overwritten. The done flag of
    # every snapshot entry is forced to False.
    ep = deepcopy([ (
            replay_buffer.state[replay_buffer.ptr + i + leftover],
            replay_buffer.action[replay_buffer.ptr + i + leftover],
            replay_buffer.next_state[replay_buffer.ptr + i + leftover],
            replay_buffer.reward[replay_buffer.ptr + i + leftover],
            False
            ) for i in range(len(achieved_goals)) ])

    for _ in range(config.HER_PER_EP):
        ep_ = []
        for j, e in enumerate(ep):
            s, a, n, r, d = deepcopy(e)
            # Hindsight goal: an achieved goal from this step or a later one.
            goal = random.choice(achieved_goals[j:])
            # Goals still within 0.05 of the episode's first achieved goal are
            # skipped 90% of the time; other goals are never skipped here.
            if random.random() < (.9 if (np.linalg.norm(achieved_goals[0] - goal) < .05) else 0.):
                continue
            
            #assert all(achieved_goals[j] == n[:config.GOAL_SIZE]), "A)failed with {} + {} [{}][{}] <{}>".format(
            #    j, len(ep), achieved_goals[j], n[:config.GOAL_SIZE], np.linalg.norm(achieved_goals[j] - n[:config.GOAL_SIZE])
            #    )
            
            # The goal occupies the last config.GOAL_SIZE entries of both the
            # state and the next state.
            s[-config.GOAL_SIZE:] = deepcopy(goal)
            n[-config.GOAL_SIZE:] = deepcopy(goal)
            
            # -1.0 while this step's achieved goal is more than 0.05 away from
            # the relabelled goal, otherwise -0.0.
            r = -1. * (np.linalg.norm(achieved_goals[j] - goal) > .05)
            # Keep only 20% of the relabelled transitions that still carry the
            # -1 reward.
            if -1 == r and random.random() < .8:
                continue
            replay_buffer.add(s, a, n, r, d)

        # With probability config.HER_RATIO, skip re-adding the original
        # (non-relabelled) episode for this round.
        if random.random() < config.HER_RATIO:
            continue

        for e in ep:
            replay_buffer.add(*e)

    #assert all(replay_buffer.not_done[:replay_buffer.size])
    #print("\n diff", norm_ind, replay_buffer.ptr, replay_buffer.ptr-norm_ind, sum(0 == replay_buffer.reward[norm_ind:replay_buffer.ptr]))

    #assert all(replay_buffer.not_done[:replay_buffer.size])
    # Close the rewritten span with a single done=True transition.
    # NOTE(review): s, a, n, r are whatever the last inner-loop iteration left
    # behind (possibly a transition that was itself skipped), and they are
    # unbound if config.HER_PER_EP is 0.
    replay_buffer.add(s, a, n, r, True)
    # The slice has at most one row when ptr is not beyond norm_ind + 1, e.g.
    # after the write position wrapped; normalisation is skipped in that case.
    # presumably update=True refreshes the buffer's normalisation statistics
    # from the newly written states - TODO confirm.
    if len(replay_buffer.state[norm_ind:replay_buffer.ptr]) > 1:# edge of buffer
        replay_buffer.normalize_state(replay_buffer.state[norm_ind:replay_buffer.ptr], update=True)

    return True

In [None]:
import random
from open_gym import make_env

if True:#__name__ == "__main__":
    # Training driver: collects experience with a TD3+BC policy, rewrites
    # every finished episode through her(), trains periodically, and stops
    # early once an evaluation scores above -10.
    file_name = f"{config.ENV}_{config.SEED}"
    print("---------------------------------------")
    print(f"Policy: , Env: {config.ENV}, Seed: {config.SEED}")
    print("---------------------------------------")

    # Separate environments for data collection and for (rendered) evaluation.
    env = make_env(config.ENV, render=False, colab=True)
    eval_env = make_env(config.ENV, render=True, colab=True)

    # Set seeds
    env.seed(config.SEED)
    env.action_space.seed(config.SEED)
    torch.manual_seed(config.SEED)
    np.random.seed(config.SEED)
    # The stdlib generator drives the HER relabelling in her(), the
    # per-episode exploration weight and the exploration draw below, so it
    # has to be seeded as well for runs to be reproducible.
    random.seed(config.SEED)
    
    state_dim = env.state_size()
    action_dim = env.action_space.shape[0] 
    max_action = float(env.action_space.high[0])

    kwargs = { # let it default for td3 and td3+bc
            "state_dim": state_dim,
            "action_dim": action_dim,
            "max_action": max_action,
            "discount": config.DISCOUNT,
            "tau": config.TAU,
    }

    # Initialize policy
#    policy = TD3.TD3(**kwargs)
    policy = TD3_BC.TD3_BC(**kwargs)
#    policy = OurDDPG.DDPG(**kwargs)

    replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
    #replay_buffer.ptr = replay_buffer.size = replay_buffer.max_size - 10

    print("---------------------------------------")
    print(f"Policy TD3+BC+HER: , Env: {config.ENV}, Seed: {config.SEED}, Observation shape: {state_dim}")
    print("---------------------------------------")

    # done starts as True so the first loop iteration resets the environment.
    done = True
    achieved_goals = []
    load_state = lambda obs: obs["observation"].reshape(1,-1)

    t = 0
    add_prev_exp = 0
    # No reward has been observed before the first environment step. None
    # never equals -1., so the exploration test below is well defined even
    # when config.START_STEPS <= 0 (previously a NameError).
    reward = None
    total_steps = config.STEPS_PER_EPOCH * config.EPOCHS
    while t < total_steps:

        if done:
            # Rewrite the episode that just ended (a no-op returning False on
            # the very first pass, when no goals have been recorded yet).
            add_prev_exp = her(replay_buffer, achieved_goals)
            achieved_goals = []
            # Per-episode exploration weight: 0 for roughly 30% of episodes,
            # otherwise an integer in 1..5 (used as mc_w / 10 below).
            mc_w = random.randint(1, 5) if random.random() > .3 else 0.
            state = load_state(env.reset())

        # The step counter only advances while the last her() call returned
        # True (True counts as 1, False as 0).
        # NOTE(review): while t is frozen, the training and evaluation
        # conditions below can fire again on every environment step.
        t += add_prev_exp

        # Random action during warm-up, or, when the previous step's reward
        # was -1, with probability mc_w / 10; otherwise act with the policy.
        if t < config.START_STEPS or (-1. == reward and random.random() < mc_w / 10.):
            action = env.action_space.sample()
        else:
            action = policy.select_action(replay_buffer.normalize_state(state))

        observation, reward, done, _ = env.step(action)
        next_state = load_state(observation)
        # her() expects exactly one achieved goal per stored transition.
        achieved_goals.append(observation["achieved_goal"])

        replay_buffer.add(state, action, next_state, reward, done)

        state = next_state

        if t > config.UPDATE_AFTER and 0 == t % config.UPDATE_EVERY:
            print("LEARN")
            for j in range(config.UPDATE_COUNT):
                policy.train(replay_buffer, config.BATCH_SIZE)
            # presumably a soft (Polyak) update of the target networks -
            # TODO confirm against TD3_BC.polyak.
            policy.polyak()

        # Sentinel below the stopping threshold for steps without evaluation.
        score = -100
        
        # Evaluate episode
        if (t + 1) % config.EVAL_FREQ == 0:
            print(f"Epochs : {(t + 1) / config.EVAL_FREQ}")
            score = eval_policy(policy, eval_env, config.SEED, replay_buffer.normalize_state)
        # Stop training as soon as an evaluation averages better than -10.
        if score > -10.:
            break

In [None]:
# Ten further evaluation runs, each seeded with a freshly drawn integer in
# [0, 100000]; the expression evaluates to the list of average rewards.
[ eval_policy(policy, eval_env, random.randint(0, 100000), replay_buffer.normalize_state) for _ in range(10) ]