In [None]:
import gym
import pybullet_envs
import torch as T
import torch.nn as nn
import pickle
from torch.optim import Adam
import itertools
import numpy as np
import scipy.signal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

import dmc2gym
import os

In [None]:
def combined_shape(length, shape=None):
    """Build a shape tuple with ``length`` as the leading dimension.

    Args:
        length: Size of the leading dimension.
        shape: ``None`` for no trailing dimensions, a scalar for exactly one
            trailing dimension, or an iterable of trailing dimension sizes.

    Returns:
        ``(length,)``, ``(length, shape)`` or ``(length, *shape)`` respectively.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def mlp(sizes, activation, output_activation=nn.Identity):
    """Assemble a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: Layer widths, input first and output last. Each consecutive
            pair becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer
            except the last one.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        ``nn.Sequential`` alternating linear layers and activations; empty
        when ``sizes`` has fewer than two entries.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        nonlinearity = output_activation if idx >= final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(nonlinearity())
    return nn.Sequential(*modules)

def count_vars(module):
    """Return the total number of scalar parameters held by ``module``.

    Sums the product of each parameter's shape over ``module.parameters()``;
    a module without parameters yields ``0``.
    """
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total


# Clamp range applied to the predicted log standard deviation before exp().
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class SquashedGaussianMLPActor(nn.Module):
    """Stochastic policy: an MLP-parameterised Gaussian squashed through tanh.

    A shared trunk (``net``) feeds two linear heads that output the mean and
    the log standard deviation of a diagonal Gaussian. Actions drawn from it
    are passed through ``tanh`` and multiplied by ``act_limit``.

    Args:
        obs_dim: Size of the observation vector.
        act_dim: Size of the action vector.
        hidden_sizes: Sequence of trunk layer widths. Must be non-empty
            because both heads read ``hidden_sizes[-1]``.
        activation: Module class instantiated after every trunk layer
            (including the last one).
        act_limit: Scale applied to the tanh-squashed action.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Compute an action and, optionally, its log-probability.

        Args:
            obs: Observation tensor whose last dimension is the observation
                vector; any leading batch dimensions (or none) are accepted.
            deterministic: When true, use the Gaussian mean instead of a
                reparameterised sample.
            with_logprob: When true, also return the log-probability of the
                action with the tanh change-of-variables correction applied.

        Returns:
            ``(pi_action, logp_pi)``. ``pi_action`` is the squashed, scaled
            action; ``logp_pi`` is reduced over the last (action) dimension,
            or is ``None`` when ``with_logprob`` is false.
        """
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Gaussian log-prob followed by the tanh-squashing correction.
            # The correction is a numerically stable equivalent of Eq. 21 in
            # appendix C of the SAC paper (arXiv 1801.01290).
            # Both reductions run over the last axis so that a single
            # unbatched observation works as well as a batch.
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2*(np.log(2) - pi_action - F.softplus(-2*pi_action))).sum(axis=-1)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """State-action value network: maps ``(obs, act)`` to one scalar per sample.

    Args:
        obs_dim: Size of the observation vector.
        act_dim: Size of the action vector.
        hidden_sizes: Iterable of hidden-layer widths.
        activation: Module class instantiated after every hidden layer.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q-values for ``obs``/``act`` joined along their last axis."""
        joint_input = torch.cat((obs, act), dim=-1)
        q_value = self.q(joint_input)
        # Drop the trailing singleton so the result has one value per sample.
        return q_value.squeeze(-1)

class MLPActorCritic(nn.Module):
    """Container bundling one policy (``pi``) with two Q-networks (``q1``, ``q2``).

    Args:
        observation_space: Object whose ``shape[0]`` gives the observation size.
        action_space: Object whose ``shape[0]`` gives the action size and
            whose ``high[0]`` is used as the action scale.
        hidden_sizes: Hidden-layer widths shared by all three networks.
        activation: Module class used as the hidden activation.
    """

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        super().__init__()

        n_obs = observation_space.shape[0]
        n_act = action_space.shape[0]
        # Only the first component of the upper bound is used as the scale.
        limit = action_space.high[0]

        # Policy first, then the two critics.
        self.pi = SquashedGaussianMLPActor(n_obs, n_act, hidden_sizes, activation, limit)
        critic_args = (n_obs, n_act, hidden_sizes, activation)
        self.q1 = MLPQFunction(*critic_args)
        self.q2 = MLPQFunction(*critic_args)

    def act(self, obs, deterministic=False):
        """Return the policy's action for ``obs`` as a NumPy array (no gradients, no log-prob)."""
        with torch.no_grad():
            action = self.pi(obs, deterministic, False)[0]
        return action.numpy()

In [None]:
def get_action(o, deterministic=False):
    """Ask the module-level actor-critic ``ac`` for an action.

    ``o`` is converted to a float32 tensor and passed, together with
    ``deterministic``, to ``ac.act``; whatever ``ac.act`` returns is returned.
    """
    obs_tensor = T.as_tensor(o, dtype=T.float32)
    return ac.act(obs_tensor, deterministic)

# Seed identifiers, kept as strings because they are concatenated onto the
# run-directory prefix to locate each run's files.
seeds = ['123', '666', '742', '637', '4637']

In [None]:
# Roll out each seed's saved policy deterministically and store every visited
# transition next to the model as demos.pkl.
test_env = dmc2gym.make(domain_name='walker', task_name='stand')
print(test_env.observation_space.shape, test_env.action_space.shape)

max_ep_len = 1000          # hard cap on steps per evaluation episode
num_demo_episodes = 100    # episodes rolled out per seed

for seed in seeds:
    dir_path = "data/sac_dmWalkerStand_256_retroloss_2/sac_dmWalkerStand_256_retroloss_2_s" + seed

    # NOTE: T.load unpickles a whole model object, which can execute arbitrary
    # code; only load files you trust.
    # `ac` must stay a module-level name because get_action() reads it.
    ac = T.load(dir_path+"/pyt_save/model.pt")
    ac.eval()

    # Flat list of [obs, action, next_obs] triples across all episodes.
    trajs = []

    for episode in range(num_demo_episodes):
        o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
        while not (d or ep_len == max_ep_len):
            # Take deterministic actions at test time
            prev_o = o
            action = get_action(o, True)
            o, r, d, _ = test_env.step(action)
            trajs.append([prev_o, action, o])
            ep_ret += r
            ep_len += 1
        print(ep_ret)

    with open(dir_path+'/demos.pkl','wb') as f:
        pickle.dump(trajs, f)

In [None]:
def _checkpoint_q_losses(run_dir, obs, act, q1_ref, q2_ref, ac_kwargs=None,
                         first_step=10000, stop_step=2000000, step_size=10000):
    """Measure how far each checkpoint's Q-networks are from reference Q-values.

    For every step in ``range(first_step, stop_step, step_size)`` the file
    ``run_dir + "/checkpoints/model_checkpoint_<step>.tar"`` is loaded into a
    fresh ``MLPActorCritic`` and the summed mean-squared difference between
    its ``q1``/``q2`` outputs and ``q1_ref``/``q2_ref`` on ``(obs, act)`` is
    recorded.

    Args:
        run_dir: Directory of the training run containing ``checkpoints/``.
        obs: Float tensor of observations to evaluate on.
        act: Float tensor of actions paired with ``obs``.
        q1_ref: Reference values compared against each checkpoint's ``q1``.
        q2_ref: Reference values compared against each checkpoint's ``q2``.
        ac_kwargs: Optional keyword arguments for ``MLPActorCritic``.
        first_step: First checkpoint step to evaluate.
        stop_step: Exclusive upper bound on the checkpoint step.
        step_size: Spacing between checkpoint steps.

    Returns:
        List of Python floats, one per evaluated checkpoint, in step order.
    """
    ac_kwargs = ac_kwargs or {}
    losses = []
    for step in range(first_step, stop_step, step_size):
        print(step)

        # Only the network weights are needed for evaluation; the optimizer
        # state stored in the checkpoint is ignored.
        ac_ckpt = MLPActorCritic(test_env.observation_space, test_env.action_space, **ac_kwargs)
        # NOTE: T.load unpickles data; only load checkpoints you trust.
        checkpoint = T.load(run_dir + "/checkpoints/model_checkpoint_{}.tar".format(step))
        ac_ckpt.load_state_dict(checkpoint['ac_state_dict'])
        ac_ckpt.eval()

        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            loss_q1 = ((q1_ref - ac_ckpt.q1(obs, act))**2).mean()
            loss_q2 = ((q2_ref - ac_ckpt.q2(obs, act))**2).mean()
            loss_q = (loss_q1 + loss_q2).item()

        print(loss_q)
        losses.append(loss_q)
    return losses


for seed in seeds:
    dir_path = "data/sac_dmWalkerStand_256_retroloss_2/sac_dmWalkerStand_256_retroloss_2_s" + seed

    # Load this seed's own final model for the reference Q-values, so they
    # match the policy that produced this seed's demos rather than whatever
    # model an earlier cell happened to leave in a global.
    # NOTE: T.load / pickle.load execute arbitrary code from the file; only
    # load files you trust.
    ac_ref = T.load(dir_path+"/pyt_save/model.pt")
    ac_ref.eval()

    with open(dir_path+'/demos.pkl', 'rb') as f:
        trajs_final = pickle.load(f)

    # Each stored transition is [obs, action, next_obs]; only the first two are used.
    obss = np.array([traj[0] for traj in trajs_final])
    acts = np.array([traj[1] for traj in trajs_final])

    # Random evaluation batch of 500 transitions (sampled with replacement).
    idxs = np.random.randint(0, len(obss), size=500)
    o = torch.as_tensor(obss[idxs], dtype=torch.float32)
    a = torch.as_tensor(acts[idxs], dtype=torch.float32)

    with torch.no_grad():
        q1 = ac_ref.q1(o, a)
        q2 = ac_ref.q2(o, a)

    # The same reference values are compared against the checkpoints of this
    # run and of the baseline run trained with the same seed.
    # NOTE(review): the baseline is scored against this run's final Q-values,
    # as before — confirm that is the intended comparison.
    dir_path2 = dir_path
    dir_path3 = "data/sac_dmWalkerStand_256/sac_dmWalkerStand_256_s" + seed

    for run_dir in (dir_path2, dir_path3):
        q_losses = _checkpoint_q_losses(run_dir, o, a, q1, q2)
        with open(run_dir + '/q_losses.pkl', 'wb') as f:
            pickle.dump(q_losses, f)

