<a href="https://colab.research.google.com/github/whsu00/project/blob/master/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
!rm -rf .config .*
!git clone https://github.com/whsu00/project .
!apt-get install -y libglu1-mesa-dev libgl1-mesa-dev libosmesa6-dev xvfb ffmpeg curl patchelf libglfw3 libglfw3-dev cmake zlib1g zlib1g-dev swig libosmesa6-dev libgl1-mesa-glx libglfw3
!pip install mujoco_py==2.0.2.8 mpi4py

!wget https://whsu00.com/mujoco.zip
!unzip mujoco.zip
!rm mujoco.zip
!mv .mujoco ~

rm: refusing to remove '.' or '..' directory: skipping '.'
rm: refusing to remove '.' or '..' directory: skipping '..'
fatal: destination path '.' already exists and is not an empty directory.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
libglu1-mesa-dev is already the newest version (9.0.0-2.1build1).
zlib1g is already the newest version (1:1.2.11.dfsg-0ubuntu2).
zlib1g-dev is already the newest version (1:1.2.11.dfsg-0ubuntu2).
libglfw3 is already the newest version (3.2.1-1).
libglfw3-dev is already the newest version (3.2.1-1).
patchelf is already the newest version (0.9-1).
swig is already the newest version (3.0.12-1).
cmake is already the newest version (3.10.2-1ubuntu2.18.04.1).
curl is already the newest version (7.58.0-2ubuntu3.10).
libgl1-mesa-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).
libgl1-mesa-glx is already the newest version (20.0.8-0ubuntu1~18.04.1).
libosmesa6-dev is already the newest version (20.0

In [53]:
 # VALOR implementation 
import numpy as np 
import torch
import torch.nn.functional as F 
import gym 
import time
import scipy.signal
from network import Discriminator, ActorCritic, count_vars
from buffer import Buffer
from torch.distributions.categorical import Categorical
# from utils.mpi_tools import mpi_fork, proc_id, mpi_statistics_scalar, num_procs
# from utils.mpi_torch import average_gradients, sync_all_params
from utils.logx import EpochLogger

def valor(env_fn, actor_critic=ActorCritic, ac_kwargs=dict(), disc=Discriminator, dc_kwargs=dict(), seed=0, episodes_per_epoch=40,
        epochs=50, gamma=0.99, pi_lr=3e-4, vf_lr=1e-3, dc_lr=5e-4, train_v_iters=80, train_dc_iters=10, train_dc_interv=10,
        lam=0.97, max_ep_len=1000, logger_kwargs=dict(), con_dim=4, max_context_dim = 64, save_freq=10, k=1,
        init_context_dim=None):
    '''
    Train a context-conditioned actor-critic jointly with a decoder ("discriminator")
    that predicts the context from per-episode state differences (VALOR).

    Each epoch rolls out ``episodes_per_epoch`` episodes, each with a freshly sampled
    context. It then takes one policy-gradient step and ``train_v_iters`` value-function
    steps. Every ``train_dc_interv`` epochs it also takes ``train_dc_iters`` decoder
    steps and runs the context curriculum check.

    Args:
        env_fn: zero-argument callable returning a gym environment.
        actor_critic: class constructed with ``input_dim=obs_dim + con_dim`` and ``ac_kwargs``.
        ac_kwargs: extra constructor kwargs for ``actor_critic``. A copy is made and
            ``action_space`` is added to it; the caller's dict is not modified.
        disc: decoder class constructed with ``input_dim=obs_dim``,
            ``context_dim=con_dim`` and ``dc_kwargs``.
        dc_kwargs: extra constructor kwargs for ``disc``.
        seed: seed for torch and numpy.
        episodes_per_epoch: episodes collected before each update.
        epochs: number of collect/update cycles.
        gamma, lam: NOTE(review): these are not passed to ``Buffer`` below. Confirm
            whether the buffer is meant to use them.
        pi_lr, vf_lr, dc_lr: Adam learning rates for the policy, value function and decoder.
        train_v_iters: value-function gradient steps per epoch.
        train_dc_iters: decoder gradient steps per decoder update.
        train_dc_interv: the decoder is updated (and the curriculum checked) every this
            many epochs.
        max_ep_len: step limit per episode.
        logger_kwargs: kwargs for ``EpochLogger``.
        con_dim: width of the one-hot context fed to both networks (fixed for the run).
        max_context_dim: upper bound on the curriculum's number of active contexts.
            It is effectively capped at ``con_dim``.
        save_freq: models are saved every this many epochs and at the last epoch.
        k: weight on the advantage in the policy loss ``-(logp * (k*adv + pos)).mean()``.
        init_context_dim: number of contexts sampled uniformly at the start
            (default ``con_dim``, i.e. all of them). The curriculum grows it as
            ``min(int(1.5*K + 1), max_context_dim, con_dim)`` whenever the decoder
            check passes.

    Open questions (TODO):
        1) Go through the policy update.
        2) Why is logp instead of log_gt stored in a lot of places?
        3) Why isn't the LSTM policy used in the encoder?
        4) Why is the policy loss the way it is?
        5) Why is the expectation taken the way that it is in the decoder?
        NOTE: a buffer bug was previously observed where data is overwritten because
        indices are reset every time the buffer is read. Verify ``Buffer.retrieve_*``.
    '''
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    # Number of currently active contexts (the curriculum's K).
    context_dim = con_dim if init_context_dim is None else int(init_context_dim)
    if not 1 <= context_dim <= con_dim:
        raise ValueError('init_context_dim must be in [1, con_dim=%d], got %d' % (con_dim, context_dim))

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Copy before adding a key so the shared default dict / caller's dict is not mutated.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs['action_space'] = env.action_space

    # Models: the policy sees observation + one-hot context; the decoder sees state diffs.
    actor_critic = actor_critic(input_dim=obs_dim[0]+con_dim, **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=con_dim, **dc_kwargs)

    local_episodes_per_epoch = episodes_per_epoch
    buffer = Buffer(con_dim, obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len, train_dc_interv)

    var_counts = tuple(count_vars(module) for module in
        [actor_critic.policy, actor_critic.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n'%var_counts)

    # Separate optimizers for the policy, the value function and the decoder.
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(actor_critic.value_f.parameters(), lr=vf_lr)
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    def make_context_dist(num_active):
        '''Uniform categorical over the first ``num_active`` of the ``con_dim`` context slots.'''
        probs = torch.zeros(con_dim)
        probs[:num_active] = 1.0 / num_active
        return Categorical(probs=probs)

    def update(e):
        '''One training step on the data collected in epoch ``e``.'''
        obs, act, adv, pos, ret, logp_old = [torch.Tensor(x) for x in buffer.retrieve_all()]

        # Single policy-gradient step.
        _, logp, _ = actor_critic.policy(obs, act)
        entropy = (-logp).mean()  # negative mean log-prob of the taken actions
        pi_loss = -(logp*(k*adv+pos)).mean()

        train_pi.zero_grad()
        pi_loss.backward()
        train_pi.step()

        # Value-function regression onto the stored returns.
        with torch.no_grad():
            v_l_old = F.mse_loss(actor_critic.value_f(obs), ret).item()
        for _ in range(train_v_iters):
            v = actor_critic.value_f(obs)
            v_loss = F.mse_loss(v, ret)
            train_v.zero_grad()
            v_loss.backward()
            train_v.step()

        # Decoder: maximize log-likelihood of the true context given the state diffs.
        if (e+1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            with torch.no_grad():
                _, logp_dc, _ = disc(s_diff, con)
                d_l_old = -logp_dc.mean().item()

            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                train_dc.step()

            with torch.no_grad():
                _, logp_dc, _ = disc(s_diff, con)
                dc_l_new = -logp_dc.mean().item()
        else:
            d_l_old = 0.0
            dc_l_new = 0.0

        # Post-update diagnostics. Values are stored as Python floats: tensors that still
        # require grad cannot be converted to numpy by the logger.
        with torch.no_grad():
            _, logp, _, v = actor_critic(obs, act)
            pi_l_new = -(logp*(k*adv+pos)).mean()
            v_l_new = F.mse_loss(v, ret).item()
            kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss.item(), LossV=v_l_old, KL=kl.item(), Entropy=entropy.item(),
            DeltaLossPi=(pi_l_new - pi_loss).item(), DeltaLossV=(v_l_new - v_l_old),
            LossDC=d_l_old, DeltaLossDC=(dc_l_new - d_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    context_dist = make_context_dist(context_dim)
    total_t = 0

    for epoch in range(epochs):
        actor_critic.eval()
        disc.eval()

        # Collect local_episodes_per_epoch episodes, each with one freshly sampled context.
        for _ in range(local_episodes_per_epoch):
            c = context_dist.sample()
            c_onehot = F.one_hot(c, con_dim).squeeze().float()
            for _ in range(max_ep_len):
                concat_obs = torch.cat([torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
                # Returns (action, _, log-prob of the action, value estimate).
                with torch.no_grad():
                    a, _, logp_t, v_t = actor_critic(concat_obs)

                buffer.store(c, concat_obs.squeeze().detach().numpy(), a.detach().numpy(), r, v_t.item(), logp_t.detach().numpy())
                logger.store(VVals=v_t.item())

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                if d or (ep_len == max_ep_len):
                    # Score the finished trajectory's state differences against its context.
                    with torch.no_grad():
                        dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                        con = torch.Tensor([float(c)]).unsqueeze(0)
                        _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                    # One episode per context: without this break, an early-terminated
                    # episode's remaining steps would run a new env episode under the old context.
                    break

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [actor_critic, disc], None)

        actor_critic.train()
        disc.train()
        update(epoch)

        # Curriculum: after a decoder update, if the decoder scores the current contexts
        # well enough, enlarge the number of active contexts K <- min(int(1.5*K + 1), Kmax).
        # K is capped at con_dim because both networks' one-hot context width is fixed.
        if (epoch + 1) % train_dc_interv == 0 and epoch > 0:
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            print("num_contexts", len(con))
            with torch.no_grad():
                _, logp_dc, _ = disc(s_diff, con)
            log_p_context_sample = logp_dc.mean().item()
            print("Log Probability context sample", log_p_context_sample)

            # exp(mean log-prob) is the geometric mean of the decoder's probabilities.
            decoder_accuracy = np.exp(log_p_context_sample)
            print("Decoder Accuracy", decoder_accuracy)

            if decoder_accuracy >= 0.86:
                new_context_dim = min(int(1.5*context_dim + 1), max_context_dim, con_dim)
                print("new_context_dim: ", new_context_dim)
                context_dist = make_context_dist(new_context_dim)
                context_dim = new_context_dim

            buffer.clear_dc_buff()

        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time()-start_time)
        logger.dump_tabular()

In [61]:
#@title Run
env = "HalfCheetah-v2" #@param {type:"string"}
hid = 64 #@param {type:"integer"}
l = 2 #@param {type:"integer"}
gamma = 0.97 #@param {type:"number"}
seed = 0 #@param {type:"integer"}
episodes = 40 #@param {type:"integer"}
epochs = 1000 #@param {type:"integer"}
exp_name = "valor" #@param {type:"string"}
con = 5 #@param {type:"integer"}

# Launch one VALOR training run with the form parameters above.
from utils.run_utils import setup_logger_kwargs

logger_kwargs = setup_logger_kwargs(exp_name, seed)

# Network sizes: the actor-critic gets `l` hidden layers of width `hid`;
# the decoder receives `hid` as-is.
actor_critic_kwargs = dict(hidden_dims=[hid] * l)
decoder_kwargs = dict(hidden_dims=hid)

valor(
    lambda: gym.make(env),
    actor_critic=ActorCritic,
    ac_kwargs=actor_critic_kwargs,
    disc=Discriminator,
    dc_kwargs=decoder_kwargs,
    gamma=gamma,
    seed=seed,
    episodes_per_epoch=episodes,
    epochs=epochs,
    logger_kwargs=logger_kwargs,
    con_dim=con,
)


[32;1mLogging data to /content/data/valor/valor_s0/progress.txt[0m
[36;1mSaving config:
[0m
{
    "ac_kwargs":	{
        "hidden_dims":	[
            64,
            64
        ]
    },
    "actor_critic":	"ActorCritic",
    "con_dim":	5,
    "dc_kwargs":	{
        "hidden_dims":	64
    },
    "dc_lr":	0.0005,
    "disc":	"Discriminator",
    "env_fn":	"<function <lambda> at 0x7feed08e26a8>",
    "episodes_per_epoch":	40,
    "epochs":	1000,
    "exp_name":	"valor",
    "gamma":	0.97,
    "k":	1,
    "lam":	0.97,
    "logger":	{
        "<utils.logx.EpochLogger object at 0x7feed08f7518>":	{
            "epoch_dict":	{},
            "exp_name":	"valor",
            "first_row":	true,
            "log_current_row":	{},
            "log_headers":	[],
            "output_dir":	"/content/data/valor/valor_s0",
            "output_file":	{
                "<_io.TextIOWrapper name='/content/data/valor/valor_s0/progress.txt' mode='w' encoding='UTF-8'>":	{
                    "mode":	"w"
   

Exception: ignored