In [None]:
#Include this at the top of your colab code
import os
if not os.path.exists('.mujoco_setup_complete'):
  # Get the prereqs
  !apt-get -qq update
  !apt-get -qq install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libgl1-mesa-dev libglew-dev patchelf
  # Get Mujoco
  !mkdir ~/.mujoco
  !wget -q https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz
  !tar -zxf mujoco.tar.gz -C "$HOME/.mujoco"
  !rm mujoco.tar.gz
  # Add it to the actively loaded path and the bashrc path (these only do so much)
  !echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc 
  !echo 'export LD_PRELOAD=$LD_PRELOAD:/usr/lib/x86_64-linux-gnu/libGLEW.so' >> ~/.bashrc 
  # THE ANNOYING ONE, FORCE IT INTO LDCONFIG SO WE ACTUALLY GET ACCESS TO IT THIS SESSION
  !echo "/root/.mujoco/mujoco210/bin" > /etc/ld.so.conf.d/mujoco_ld_lib_path.conf
  !ldconfig
  # Install Mujoco-py
  !pip3 install -U 'mujoco-py<2.2,>=2.1'
  # run once
  !touch .mujoco_setup_complete

try:
  if _mujoco_run_once:
    pass
except NameError:
  _mujoco_run_once = False
if not _mujoco_run_once:
  # Add it to the actively loaded path and the bashrc path (these only do so much)
  try:
    os.environ['LD_LIBRARY_PATH']=os.environ['LD_LIBRARY_PATH'] + ':/root/.mujoco/mujoco210/bin'
    os.environ['LD_LIBRARY_PATH']=os.environ['LD_LIBRARY_PATH'] + ':/usr/lib/nvidia'
  except KeyError:
    os.environ['LD_LIBRARY_PATH']='/root/.mujoco/mujoco210/bin'
  try:
    os.environ['LD_PRELOAD']=os.environ['LD_PRELOAD'] + ':/usr/lib/x86_64-linux-gnu/libGLEW.so'
  except KeyError:
    os.environ['LD_PRELOAD']='/usr/lib/x86_64-linux-gnu/libGLEW.so'
  # presetup so we don't see output on first env initialization
  import mujoco_py
  _mujoco_run_once = True

In [None]:
# Install the tinkoff-ai fork of D4RL straight from its `master` branch; it provides
# the `d4rl` module imported below.
# NOTE(review): the branch is unpinned, so the installed code can change between runs —
# consider pinning a commit hash for reproducibility.
!pip install git+https://github.com/tinkoff-ai/d4rl@master#egg=d4rl

In [None]:
import wandb
import torch
from torch import nn
from dataclasses import dataclass
from copy import deepcopy
from torch.nn import functional as F
import numpy as np
import gym
import d4rl
from typing import List, Tuple
import os
import random
from tqdm import tqdm

In [None]:
# Return value used to rescale episode returns per dataset when
# `train_config.normalize_returns` is enabled. Both variants of an
# environment share the same value.
max_target_returns = {
    **dict.fromkeys(("halfcheetah-medium-replay-v0", "halfcheetah-medium-v0"), 15.743),
    **dict.fromkeys(("hopper-medium-replay-v0", "hopper-medium-v0"), 6.918),
    **dict.fromkeys(("walker2d-medium-replay-v0", "walker2d-medium-v0"), 10.271),
}


@dataclass
class train_config:
    """Hyperparameters and run settings for the fine-tuning notebook."""

    # --- run identity -----------------------------------------------------
    policy: str = "REDQ_BC"
    # other options: halfcheetah-medium-replay-v0, walker2d-medium-replay-v0
    env: str = "hopper-medium-replay-v0"
    seed: int = 42

    # --- schedule ---------------------------------------------------------
    eval_frequency: int = 5000
    max_timesteps: int = 250000
    pretrain_timesteps: int = 1000000
    num_updates: int = 10
    save_model: bool = True
    load_policy_path: str = ""
    episode_length: int = 1000

    # --- optimisation -----------------------------------------------------
    # std of the gaussian noise added to actions for exploration
    exploration_noise: float = 0.1
    batch_size: int = 256
    discount_factor: float = 0.99
    # see algo.jpeg in 'paper' folder
    tau: float = 0.005
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_frequency: int = 2
    alpha: float = 0.4
    alpha_finetune: float = 0.4

    # --- offline data kept in the replay buffer ---------------------------
    # alternative: "best"
    sample_method: str = "random"
    # ratio of offline data to keep in the replay buffer (see algo.jpeg in 'paper' folder)
    sample_ratio: float = 0.05

    # --- critic target / alpha controller ---------------------------------
    # False: randomized ensembles; True: min over Q values (see eq3.PNG in 'paper' folder)
    minimize_over_q: bool = False
    # gains of the alpha update, see eq2.PNG in 'paper' folder
    Kp: float = 0.00003
    Kd: float = 0.0001
    # True: divide returns by the target return listed in `max_target_returns`
    normalize_returns: bool = True


cfg = train_config()

In [None]:
# Identifier for this run: reused below as the checkpoint filename prefix
# (`policy.load` / `policy.save`) and as the wandb run name.
run_name = f"redq_bc_{cfg.env}_{cfg.seed}"

In [None]:
class Actor(nn.Module):
    """Deterministic policy network.

    A three-layer MLP with ReLU activations whose tanh output is multiplied
    by `max_action`, so every action component lies in [-max_action, max_action].
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 max_action: float,
                 hidden_dim: int = 256) -> None:
        super().__init__()
        self.max_action = max_action

        # Layers are created in input-to-output order and wrapped in a single
        # Sequential stored under the attribute name `actor`.
        stack = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        ]
        self.actor = nn.Sequential(*stack)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Map a batch of states to a batch of scaled actions."""
        squashed = self.actor(state)
        return self.max_action * squashed

    @torch.no_grad()
    def act(self, state, device: str = "cpu") -> np.ndarray:
        """Return the action for a single state as a flat numpy array.

        `state` is reshaped to one row first. Non-tensor input is converted to
        a float32 tensor on `device`; a tensor input is used as it is.
        """
        row = state.reshape(1, -1)
        if isinstance(row, torch.Tensor):
            observation = row
        else:
            observation = torch.tensor(row, device=device, dtype=torch.float32)

        action = self(observation)
        return action.cpu().data.numpy().flatten()


class EnsembleLinear(nn.Module):
    """A stack of `ensemble_size` independent linear layers evaluated in one call.

    Parameters are stored as `weight` of shape (ensemble, in, out) and `bias`
    of shape (ensemble, 1, out). The output always has the ensemble index as
    its leading dimension.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int) -> None:
        super().__init__()
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.zeros(*weight_shape))
        self.bias = nn.Parameter(torch.zeros(*bias_shape))

        # Weights: truncated normal with std = 1 / (2 * sqrt(in_features)); biases stay zero.
        fan_in_scale = 2 * in_features ** 0.5
        nn.init.trunc_normal_(self.weight, std=1 / fan_in_scale)

    def forward(self, x: torch.Tensor):
        """Apply every ensemble member to `x`.

        A 2-D `x` (batch, in) is shared by all members; any other input is
        treated as (ensemble, batch, in) with one slice per member.
        """
        shared_input = len(x.shape) == 2
        equation = 'ij,bjk->bik' if shared_input else 'bij,bjk->bik'
        projected = torch.einsum(equation, x, self.weight)
        return projected + self.bias


class EnsembledCritic(nn.Module):
    """Ensemble of `num_critics` Q-networks evaluated together.

    Each member is a three-layer MLP (ReLU activations) built from
    `EnsembleLinear` layers and fed the concatenated state and action.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 num_critics: int = 10) -> None:
        super().__init__()

        input_dim = state_dim + action_dim
        stack = [
            EnsembleLinear(input_dim, hidden_dim, ensemble_size=num_critics),
            nn.ReLU(),
            EnsembleLinear(hidden_dim, hidden_dim, ensemble_size=num_critics),
            nn.ReLU(),
            EnsembleLinear(hidden_dim, 1, ensemble_size=num_critics),
        ]
        self.critics = nn.Sequential(*stack)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape (num_critics, batch, 1) for the given batch."""
        state_action = torch.cat((state, action), dim=1)
        return self.critics(state_action)


class RandomizedEnsembles_BC:
    """Actor-critic agent with an ensemble critic and a behavior-cloning penalty.

    The agent owns an `Actor`, an `EnsembledCritic`, a target copy of each and
    one Adam optimizer per network. The actor loss is the normalized negative
    Q-value plus `alpha * MSE(pi(s), a)`; `alpha` can be adapted online with
    `update_alpha`.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 max_action: float,
                 discount_factor: float = 0.99,
                 tau: float = 0.005,
                 exploration_noise: float = 0.2,
                 noise_clip: float = 0.5,
                 policy_frequency: int = 2,
                 num_q_networks: int = 10,
                 alpha_finetune: float = 0.4,
                 pretrain: bool = False,
                 minimize_over_q: bool = False,
                 Kp: float = 0.00003,
                 Kd: float = 0.0001) -> None:
        """Build the networks, their targets and optimizers.

        Args:
            state_dim: size of a state vector.
            action_dim: size of an action vector.
            max_action: bound used to scale and clamp actions.
            discount_factor: discount applied to the bootstrapped target.
            tau: interpolation factor of the soft target update.
            exploration_noise: std of the gaussian noise added to the target
                action inside `train`.
            noise_clip: clamp applied to that noise.
            policy_frequency: the actor and the targets are updated once every
                `policy_frequency` calls to `train`.
            num_q_networks: number of critics in the ensemble.
            alpha_finetune: initial value and upper bound of `alpha`.
            pretrain: when True, the random-pair target is always used.
            minimize_over_q: when True (and not pretraining), the target is
                the minimum over the whole ensemble.
            Kp, Kd: gains used by `update_alpha`.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = EnsembledCritic(state_dim, action_dim, num_critics=num_q_networks).to(device)
        self.critic_target = deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount_factor = discount_factor
        self.tau = tau
        self.exploration_noise = exploration_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_frequency
        self.num_nets = num_q_networks
        self.alpha = alpha_finetune
        self.alpha_finetune = alpha_finetune
        self.pretrain = pretrain
        self.minimize_over_q = minimize_over_q
        self.kp = Kp
        self.kd = Kd

        # Number of `train` calls so far. It must persist between calls:
        # the delayed actor/target update is keyed on it.
        self.iteration = 0

    def update_alpha(self,
                     episode_timesteps,
                     average_return,
                     current_return,
                     target_return: float = 1.05) -> None:
        """Adjust the behavior-cloning weight `alpha` after an episode.

        The change is `episode_timesteps * (Kp * (average_return - target_return)
        + Kd * max(0, average_return - current_return))` and the result is
        clamped to [0, alpha_finetune]. See eq2.PNG in 'paper' folder.
        """
        proportional = self.kp * (average_return - target_return)
        derivative = self.kd * max(0, average_return - current_return)
        self.alpha += episode_timesteps * (proportional + derivative)
        self.alpha = max(0.0, min(self.alpha, self.alpha_finetune))

    def train(self, data):
        """Run one gradient step on the critic and, periodically, on the actor.

        Args:
            data: tuple `(state, action, reward, next_state, done)` of batched
                tensors, in the order produced by `ReplayBuffer.sample`.

        Returns:
            dict with the critic loss (`"critic_loss"`) and the mean Q-value
            of the first ensemble member (`"critic_Qs"`).
        """
        # Count calls across the lifetime of the agent. Resetting the counter
        # here would make the modulo test below constant, so the actor and the
        # target networks would never be updated.
        self.iteration += 1

        state, action, reward, next_state, done = data

        with torch.no_grad():
            noise = (torch.randn_like(action) * self.exploration_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

            # shape: (num_critics, batch, 1)
            tgt_qs = self.critic_target(next_state, next_action)

            if self.minimize_over_q and not self.pretrain:
                tgt_q, _ = torch.min(tgt_qs, dim=0)
            else:  # REDQ: minimum over two randomly chosen ensemble members
                first, second = (int(i) for i in np.random.permutation(self.num_nets)[:2])
                tgt_q = torch.min(tgt_qs[first], tgt_qs[second])

            tgt_q = reward + (1 - done) * self.discount_factor * tgt_q

        current_qs = self.critic(state, action)

        # Every ensemble member regresses onto the same (batch, 1) target.
        # Expanding the target explicitly gives mse_loss matching shapes.
        critic_loss = F.mse_loss(current_qs, tgt_q.unsqueeze(0).expand_as(current_qs))

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy update
        if self.iteration % self.policy_freq == 0:
            pi = self.actor(state)
            q = self.critic(state, pi).mean(0)

            # Q term normalized by its (detached) magnitude, plus the BC penalty.
            actor_loss = -q.mean() / q.abs().mean().detach() + self.alpha * F.mse_loss(pi, action)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.soft_update("actor")
            self.soft_update("critic")

        return {
            "critic_loss": critic_loss.item(),
            "critic_Qs": current_qs[0].mean().item()}

    def act(self, state):
        """Return the actor's action for `state` as a flat numpy array.

        Accepts a numpy array or a tensor; a 1-D state is treated as a batch
        of one.
        """
        if len(state.shape) == 1:
            state = state.reshape(1, -1)

        if isinstance(state, np.ndarray):
            state = self.to_tensor(state, device=self.device)
        else:
            state = state.to(self.device)

        with torch.no_grad():
            action = self.actor(state)

        return action.cpu().data.numpy().flatten()

    def soft_update(self, regime):
        """Move a target network towards its online network by a factor `tau`.

        `regime == "actor"` updates the actor target; any other value updates
        the critic target.
        """
        if regime == "actor":
            online, target = self.actor, self.actor_target
        else:
            online, target = self.critic, self.critic_target

        for param, tgt_param in zip(online.parameters(), target.parameters()):
            tgt_param.data.copy_(self.tau * param.data + (1 - self.tau) * tgt_param.data)

    @staticmethod
    def to_tensor(data, device=None) -> torch.Tensor:
        """Convert `data` to a float32 tensor on `device` (CUDA if available by default)."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        return torch.tensor(data, dtype=torch.float32, device=device)

    def save(self, filename):
        """Write network and optimizer states to `filename + '_policy.pth'`."""
        torch.save({
            "critic": self.critic.state_dict(),
            "critic_optimizer": self.critic_optimizer.state_dict(),
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict()
        }, filename + '_policy.pth')

    def load(self, filename):
        """Restore the state written by `save` and rebuild both target networks."""
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        policy_dict = torch.load(filename + "_policy.pth", map_location=self.device)

        self.critic.load_state_dict(policy_dict["critic"])
        self.critic_optimizer.load_state_dict(policy_dict["critic_optimizer"])
        self.critic_target = deepcopy(self.critic)

        self.actor.load_state_dict(policy_dict["actor"])
        self.actor_optimizer.load_state_dict(policy_dict["actor_optimizer"])
        self.actor_target = deepcopy(self.actor)

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored as float32 torch tensors.

    The field order used everywhere is (state, action, reward, next_state, done).
    Rewards and dones are stored as column vectors of shape (capacity, 1).
    `size` is the number of valid rows and `pointer` the next row to overwrite.
    """
    # Lower bound on the number of transitions kept by `distill`.
    data_size_threshold = 50000
    distill_methods = ["random", "best"]

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 buffer_size: int = 1000000) -> None:
        """Allocate zero-filled storage for `buffer_size` transitions."""
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.buffer_size = buffer_size
        self.pointer = 0
        self.size = 0

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

    @staticmethod
    def to_tensor(data: np.ndarray, device=None) -> torch.Tensor:
        """Return a float32 copy of `data` on `device` (CUDA if available by default)."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if isinstance(data, torch.Tensor):
            # Copy without going through torch.tensor(), which warns on tensor input.
            return data.detach().clone().to(device=device, dtype=torch.float32)

        return torch.tensor(data, dtype=torch.float32, device=device)

    def _load_transitions(self, states, actions, rewards, next_states, dones) -> None:
        """Replace the buffer content with the given, already converted, tensors.

        If the data fits, it is copied into the first rows of the existing
        storage; otherwise the tensors become the storage and the capacity
        grows to their length.
        """
        n_transitions = states.shape[0]

        if n_transitions < self.buffer_size:
            self.states[:n_transitions] = states
            self.actions[:n_transitions] = actions
            self.rewards[:n_transitions] = rewards
            self.next_states[:n_transitions] = next_states
            self.dones[:n_transitions] = dones
        else:
            self.buffer_size = n_transitions
            self.states = states
            self.actions = actions
            self.rewards = rewards
            self.next_states = next_states
            self.dones = dones

        self.size = n_transitions
        self.pointer = n_transitions % self.buffer_size

    def from_json(self, json_file):
        """Load a dataset stored as JSON under `json_datasets/` into the buffer.

        Every entry is converted to a numpy array; all but "terminals" are
        cast to float32. The result is passed to `from_d4rl`.
        """
        import json
        json_file = os.path.join("json_datasets", json_file)
        output = dict()

        with open(json_file) as f:
            dataset = json.load(f)

        for k, v in dataset.items():
            v = np.array(v)
            if k != "terminals":
                v = v.astype(np.float32)

            output[k] = v

        self.from_d4rl(output)

    def sample(self, batch_size: int):
        """Draw `batch_size` transitions uniformly at random, with replacement.

        Returns:
            tuple (states, actions, rewards, next_states, dones) of tensors on
            the buffer's device.
        """
        indexes = torch.as_tensor(np.random.randint(0, self.size, size=batch_size),
                                  dtype=torch.long,
                                  device=self.device)

        # Indexing with an index tensor already returns new tensors.
        return (
            self.states[indexes],
            self.actions[indexes],
            self.rewards[indexes],
            self.next_states[indexes],
            self.dones[indexes]
        )

    def from_d4rl(self, dataset):
        """Load every transition of a dataset into the buffer.

        `dataset` must provide the keys "observations", "actions",
        "next_observations", "rewards" and "terminals". Existing content is
        overwritten (a warning is printed if there was any).
        """
        if self.size:
            print("Warning: loading data into non-empty buffer")

        self._load_transitions(
            self.to_tensor(dataset["observations"], self.device),
            self.to_tensor(dataset["actions"], self.device),
            self.to_tensor(dataset["rewards"].reshape(-1, 1), self.device),
            self.to_tensor(dataset["next_observations"], self.device),
            self.to_tensor(dataset["terminals"].reshape(-1, 1), self.device),
        )

    def normalize_states(self, eps=1e-3):
        """Standardize the stored states and next states in place.

        Statistics are computed from the `size` valid rows only, so unused
        zero-filled rows do not bias them.

        Returns:
            tuple (mean, std) of shape (1, state_dim); `std` includes `eps`.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("Cannot normalize states of an empty buffer")

        filled = self.states[:self.size]
        mean = filled.mean(0, keepdim=True)
        # The sample std of a single row is undefined; fall back to eps alone.
        spread = filled.std(0, keepdim=True) if self.size > 1 else torch.zeros_like(mean)
        std = spread + eps

        self.states[:self.size] = (filled - mean) / std
        self.next_states[:self.size] = (self.next_states[:self.size] - mean) / std
        return mean, std

    def get_all(self):
        """Return the valid part of every field, in storage order."""
        return (
            self.states[:self.size].to(self.device),
            self.actions[:self.size].to(self.device),
            self.rewards[:self.size].to(self.device),
            self.next_states[:self.size].to(self.device),
            self.dones[:self.size].to(self.device)
        )

    def add_transition(self,
                       state: torch.Tensor,
                       action: torch.Tensor,
                       reward: torch.Tensor,
                       next_state: torch.Tensor,
                       done: torch.Tensor):
        """Write one transition at `pointer`, overwriting the oldest row when full.

        Each argument may be a tensor or anything `to_tensor` accepts
        (numpy array, Python number); non-tensors are converted individually.
        """
        def as_tensor(value):
            if isinstance(value, torch.Tensor):
                return value
            return self.to_tensor(value, self.device)

        self.states[self.pointer] = as_tensor(state)
        self.actions[self.pointer] = as_tensor(action)
        self.rewards[self.pointer] = as_tensor(reward)
        self.next_states[self.pointer] = as_tensor(next_state)
        self.dones[self.pointer] = as_tensor(done)

        self.pointer = (self.pointer + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def add_batch(self,
                  states: List[torch.Tensor],
                  actions: List[torch.Tensor],
                  rewards: List[torch.Tensor],
                  next_states: List[torch.Tensor],
                  dones: List[torch.Tensor]):
        """Add several transitions one by one with `add_transition`."""
        for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
            self.add_transition(state, action, reward, next_state, done)

    def distill(self,
                dataset,
                env_name,
                sample_method,
                ratio=0.05):
        """Fill the buffer with a subset of an offline dataset.

        The subset holds `max(int(ratio * N), data_size_threshold)` transitions.

        Args:
            dataset: mapping with the keys "observations", "actions",
                "next_observations", "rewards" and "terminals".
            env_name: not used by this method.
            sample_method: "random" draws indices uniformly with replacement;
                "best" keeps the last transitions of the dataset.
                NOTE(review): "best" assumes the dataset is ordered so that
                its most recent transitions are the best ones — confirm for
                the datasets in use.
            ratio: fraction of the dataset to keep.
        """
        assert sample_method in self.distill_methods, "Unknown sample method"

        full_data_size = dataset["observations"].shape[0]
        data_size = max(int(ratio * full_data_size), self.data_size_threshold)

        if sample_method == "random":
            indexes = np.random.randint(0, full_data_size, size=data_size)
        else:  # "best"
            # Exactly `data_size` trailing transitions, never more than the
            # dataset holds (a negative start would wrap around).
            data_size = min(data_size, full_data_size)
            indexes = np.arange(full_data_size - data_size, full_data_size)

        # Rewards and terminals are 1-D in the dataset and stored as columns;
        # observations keep their (n, state_dim) shape.
        self._load_transitions(
            self.to_tensor(dataset["observations"][indexes], self.device),
            self.to_tensor(dataset["actions"][indexes], self.device),
            self.to_tensor(dataset["rewards"][indexes].reshape(-1, 1), self.device),
            self.to_tensor(dataset["next_observations"][indexes], self.device),
            self.to_tensor(dataset["terminals"][indexes].reshape(-1, 1), self.device),
        )

    @staticmethod
    def dataset_stats(dataset, max_episode_length: int = 1000):
        """Return the mean and std of the episode returns found in `dataset`.

        An episode ends at a terminal transition or after `max_episode_length`
        steps, whichever comes first. The reward of the closing transition is
        part of the episode. A trailing, unfinished episode is ignored.
        """
        episode_returns = []
        returns = 0
        episode_length = 0

        for reward, done in zip(dataset["rewards"], dataset["terminals"]):
            # Count the transition before testing for the end of the episode,
            # so the terminal step's reward is not lost.
            returns += reward
            episode_length += 1

            if done or episode_length == max_episode_length:
                episode_returns.append(returns)
                returns = 0
                episode_length = 0

        episode_returns = np.array(episode_returns)
        return episode_returns.mean(), episode_returns.std()

In [None]:
# `RandomizedEnsembles_BC` needs the state/action dimensions and the action bound,
# none of which are defined yet at this point of the notebook. They are read from a
# short-lived environment instance so that this cell works on a fresh kernel.
_probe_env = gym.make(cfg.env)
policy = RandomizedEnsembles_BC(
    state_dim=_probe_env.observation_space.shape[0],
    action_dim=_probe_env.action_space.shape[0],
    max_action=float(_probe_env.action_space.high[0]),
    discount_factor=cfg.discount_factor,
    tau=cfg.tau,
    # the class uses this value as the std of the noise added to target actions
    exploration_noise=cfg.policy_noise,
    noise_clip=cfg.noise_clip,
    policy_frequency=cfg.policy_frequency,
    alpha_finetune=cfg.alpha_finetune,
    minimize_over_q=cfg.minimize_over_q,
    Kp=cfg.Kp,
    Kd=cfg.Kd,
)
_probe_env.close()

# Restore the weights and optimizer states stored under the `run_name` prefix.
policy.load(run_name)

In [None]:
# Training environment for the online fine-tuning loop.
env = gym.make(cfg.env)

# Seed every source of randomness used in this notebook with the configured seed.
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

env.seed(cfg.seed)
env.action_space.seed(cfg.seed)
env.observation_space.seed(cfg.seed)


# Problem dimensions; `max_action` takes the first component of the upper action bound.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0] 
max_action = float(env.action_space.high[0])

# Loop state consumed by the training cell below.
state, done = env.reset(), False
episode_timesteps = 0
update_info, eval_info = {}, {}

# Replay buffer pre-filled with a subset of the offline dataset.
buffer = ReplayBuffer(state_dim, action_dim, buffer_size=cfg.max_timesteps)
buffer.distill(d4rl.qlearning_dataset(env), cfg.env, cfg.sample_method, cfg.sample_ratio)

# Start from the fine-tuning value of alpha and leave pretraining mode.
policy.alpha = cfg.alpha_finetune
policy.pretrain = False

In [None]:
def evaluate_policy(policy: RandomizedEnsembles_BC,
                    env_name: str,
                    seed=cfg.seed,
                    eval_episodes=10):
    """Roll out `policy` without exploration noise and report average performance.

    A fresh environment is created and seeded with `seed + 42`.

    Args:
        policy: agent whose `act(state)` returns an action.
        env_name: id passed to `gym.make`.
        seed: base seed for the evaluation environment.
        eval_episodes: number of episodes to average over; must be positive.

    Returns:
        dict with the normalized score times 100 (`"d4rl"`), the average
        episode return (`"evaluation"`) and the average episode length
        (`"length"`, truncated to int).

    Raises:
        ValueError: if `eval_episodes` is not positive.
    """
    if eval_episodes <= 0:
        raise ValueError("eval_episodes must be positive")

    eval_env = gym.make(env_name)
    eval_env.seed(seed + 42)

    total_reward, total_steps = 0.0, 0

    try:
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False

            while not done:
                action = policy.act(state)
                state, reward, done, _ = eval_env.step(action)
                total_reward += reward
                total_steps += 1

        # Average only after ALL episodes have been played.
        average_reward = total_reward / eval_episodes
        average_length = int(total_steps / eval_episodes)

        d4rl_score = eval_env.get_normalized_score(average_reward) * 100
    finally:
        eval_env.close()

    return {
        "d4rl": d4rl_score,
        "evaluation": average_reward,
        "length": average_length
    }

In [None]:

# Online fine-tuning: act with exploration noise, store the transition, run
# `cfg.num_updates` gradient steps per environment step, adapt alpha at the end
# of every episode and evaluate every `cfg.eval_frequency` steps.
max_return = max_target_returns[cfg.env]
with wandb.init(project='adaptive_bc', group=cfg.env, job_type="finetune", name=run_name):
    wandb.config.update({k: v for k, v in cfg.__dict__.items() if not k.startswith('__')})

    episode_return = 0.0
    # Reference return of the loaded policy, on the same scale as `current_return`.
    if cfg.normalize_returns:
        last_return = evaluate_policy(policy, cfg.env)["evaluation"] / max_return
    else:
        last_return = evaluate_policy(policy, cfg.env)["d4rl"] * 0.01

    current_return = last_return
    target_return = 1.05

    for timestep in tqdm(range(cfg.max_timesteps)):
        episode_timesteps += 1

        action = (policy.act(state) + np.random.normal(0, scale=cfg.exploration_noise, size=action_dim)).clip(-max_action, max_action)

        next_state, reward, done, _ = env.step(action)

        episode_return += reward

        # Flag stored in the buffer: an episode cut by the time limit is not a
        # true terminal state, so bootstrapping must continue through it.
        # `done` itself is kept untouched because it decides below whether the
        # environment has to be reset.
        done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0.0
        buffer.add_transition(state, action, reward, next_state, done_bool)
        state = next_state

        for _ in range(cfg.num_updates):
            update_info = policy.train(buffer.sample(cfg.batch_size))

        update_info.update({'current_return': current_return, 'last_return': last_return})

        wandb.log({'online_training/': update_info,
                   'online_training/alpha': policy.alpha})

        if done:
            state, done = env.reset(), False

            if cfg.normalize_returns:
                current_return = episode_return / max_return
            else:
                current_return = env.get_normalized_score(episode_return)

            # NOTE(review): `last_return` is never refreshed after the initial
            # evaluation, so alpha is always adapted against the return of the
            # loaded policy — confirm this is intended.
            policy.update_alpha(episode_timesteps, last_return, current_return, target_return)

            episode_timesteps = 0
            episode_return = 0

        if not timestep % cfg.eval_frequency:
            eval_info = evaluate_policy(policy, cfg.env, cfg.seed)
            wandb.log({'online_evaluation/': eval_info})

            if cfg.save_model:
                policy.save(f"online_{run_name}")