In [None]:
# One-time MuJoCo 2.1.0 / mujoco-py installation for a Colab runtime.
# Include this at the top of your Colab notebook.
import os
# '.mujoco_setup_complete' is touched as the last step below, so re-running this
# cell on the same filesystem skips the whole installation.
if not os.path.exists('.mujoco_setup_complete'):
  # Install the system prerequisites (OSMesa, GL, GLFW, GLEW development files and patchelf)
  !apt-get -qq update
  !apt-get -qq install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libgl1-mesa-dev libglew-dev patchelf
  # Download the MuJoCo 2.1.0 archive, unpack it into ~/.mujoco and delete the archive
  !mkdir ~/.mujoco
  !wget -q https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz
  !tar -zxf mujoco.tar.gz -C "$HOME/.mujoco"
  !rm mujoco.tar.gz
  # Append the loader variables to ~/.bashrc; this only affects shells started later,
  # not the already-running kernel process.
  !echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc 
  !echo 'export LD_PRELOAD=$LD_PRELOAD:/usr/lib/x86_64-linux-gnu/libGLEW.so' >> ~/.bashrc 
  # Register the MuJoCo bin directory with ldconfig so the libraries can be found in this session.
  # NOTE(review): the path hard-codes /root, i.e. it assumes $HOME is /root — confirm outside Colab.
  !echo "/root/.mujoco/mujoco210/bin" > /etc/ld.so.conf.d/mujoco_ld_lib_path.conf
  !ldconfig
  # Install mujoco-py (constrained to >=2.1,<2.2)
  !pip3 install -U 'mujoco-py<2.2,>=2.1'
  # Drop the marker file so the installation runs only once
  !touch .mujoco_setup_complete

# Per-kernel setup. `_mujoco_run_once` lives in the kernel namespace, so this part
# runs once per kernel session even when the on-disk install marker already exists.
try:
  _mujoco_run_once
except NameError:
  _mujoco_run_once = False
if not _mujoco_run_once:
  # Extend the loader variables of the *current* process. Existing values are kept
  # in front; when a variable is unset (or empty) only the new entries are used.
  # Both LD_LIBRARY_PATH entries are always added, including when the variable was
  # previously unset.
  os.environ['LD_LIBRARY_PATH'] = ':'.join(
      filter(None, [os.environ.get('LD_LIBRARY_PATH'),
                    '/root/.mujoco/mujoco210/bin',
                    '/usr/lib/nvidia']))
  os.environ['LD_PRELOAD'] = ':'.join(
      filter(None, [os.environ.get('LD_PRELOAD'),
                    '/usr/lib/x86_64-linux-gnu/libGLEW.so']))
  # presetup so we don't see output on first env initialization
  import mujoco_py
  _mujoco_run_once = True

In [None]:
# Install the tinkoff-ai fork of d4rl directly from the master branch of its GitHub repository.
# NOTE(review): the install tracks `master` and is therefore not pinned to a fixed revision.
!pip install git+https://github.com/tinkoff-ai/d4rl@master#egg=d4rl

In [None]:
# Mount Google Drive into the Colab filesystem under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions import Normal
from typing import Tuple, List, Optional
from math import sqrt, log
import numpy as np
import os
from dataclasses import dataclass
from copy import deepcopy
import d4rl
import gym
import random
from tqdm import trange
from imageio import mimsave

In [5]:
# Per-task argument lists keyed by "<environment>-<dataset>". Both dataset variants
# of an environment share the same values (each key still gets its own list object).
# NOTE(review): the four entries presumably are (state_dim, action_dim, latent_dim,
# max_action), matching the leading ConditionalVAE constructor parameters — confirm
# against the cell that consumes this table.
init_arguments = {
    f"{robot}-{split}": list(shared_args)
    for robot, shared_args in (
        ("hopper", (11, 3, 6, 1.0)),
        ("walker2d", (17, 6, 12, 1.0)),
        ("halfcheetah", (17, 6, 12, 1.0)),
    )
    for split in ("medium-replay", "medium")
}

def seed_everything(seed: int,
                    env: Optional[gym.Env] = None,
                    use_deterministic_algos: bool = False):
    """Seed every random number source used in this notebook with ``seed``.

    Args:
        seed: value handed to each generator.
        env: optional environment; when given, both the environment and its
            action space are seeded.
        use_deterministic_algos: forwarded to
            ``torch.use_deterministic_algorithms``.
    """
    # Environment first (only when one was supplied).
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)

    # Python-level sources. PYTHONHASHSEED is only exported to the environment here.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current CUDA device, determinism switch.
    for torch_seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seed_fn(seed)
    torch.use_deterministic_algorithms(use_deterministic_algos)

In [28]:
@dataclass
class spot_config:
    """Hyper-parameters and paths for a SPOT run.

    ``env_name`` and ``vae_path`` are derived from ``env``/``dataset``/``version``
    in ``__post_init__`` unless they are passed explicitly, so overriding e.g.
    ``env`` keeps the derived names consistent with it.
    """

    # --- run / bookkeeping ---
    save_video: bool = False
    buffer_size: int = 1000000
    env: str = "hopper"  # halfcheetah walker2d hopper
    dataset: str = "medium-replay"  # medium, medium-replay
    version: str = "v0"
    # None means "derive as f'{env}-{dataset}-{version}'" (see __post_init__)
    env_name: Optional[str] = None
    seed: int = 0
    eval_frequency: int = 5000
    max_timesteps: int = 1000000
    save_model: bool = False
    save_final_model: bool = True
    eval_episodes: int = 10
    clip: bool = False
    exploration_noise: float = 0.1
    batch_size: int = 256
    # --- agent hyper-parameters ---
    # NOTE(review): these names mirror SPOT.__init__ parameters and are presumably
    # forwarded to it — confirm against the training cell.
    discount_factor: float = 0.99
    tau: float = 0.005
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_frequency: int = 2
    lr: float = 3e-4
    actor_lr: Optional[float] = None
    actor_hidden_dim: int = 256
    critic_hidden_dim: int = 256
    actor_dropout: float = 0.1
    alpha: float = 0.4
    normalize_env: bool = True
    # Earlier layout, kept for reference:
    # vae_model_path: str = os.path.join("spot", "weights", f"vae_{env}-{dataset}.pt")
    # None means "derive as f'vae_{env}-{dataset}.pt'" (see __post_init__)
    vae_path: Optional[str] = None
    beta: float = 0.5
    use_importance_sampling: bool = False
    num_samples: int = 1
    lambda_: float = 1.0
    with_q_norm: bool = True
    lambda_cool: bool = False
    lambda_end: float = 0.2
    # --- output locations ---
    base_dir: str = "spot"
    weights_dir: str = "online_policy_weights"

    def __post_init__(self) -> None:
        """Fill in the derived names from the final field values of this instance."""
        if self.env_name is None:
            self.env_name = f"{self.env}-{self.dataset}-{self.version}"
        if self.vae_path is None:
            self.vae_path = f"vae_{self.env}-{self.dataset}.pt"


cfg = spot_config()

In [7]:
#https://github.com/MishaLaskin/rad/blob/master/logger.py
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import json
import os
import shutil
import torch
import torchvision
import numpy as np
from termcolor import colored

# Console layouts, selected by config name and then by log group.
# Every entry is a (data key, label printed on the console, format type) triple;
# the format type is one of 'int', 'float' or 'time' (see MetersGroup._format).
FORMAT_CONFIG = {
    'rl': {
        'train': [
            ('episode', 'E', 'int'),
            ('step', 'S', 'int'),
            ('duration', 'D', 'time'),
            ('episode_reward', 'R', 'float'),
            ('batch_reward', 'BR', 'float'),
            ('actor_loss', 'A_LOSS', 'float'),
            ('critic_loss', 'CR_LOSS', 'float'),
        ],
        'eval': [
            ('step', 'S', 'int'),
            ('episode_reward', 'ER', 'float'),
        ],
    },
}


class AverageMeter(object):
    """Accumulates a running total and a running count and reports their ratio."""

    def __init__(self):
        self._sum, self._count = 0, 0

    def update(self, value, n=1):
        """Add ``value`` to the total and ``n`` to the count."""
        self._sum += value
        self._count += n

    def value(self):
        """Return total / count, using 1 as the divisor while the count is below 1."""
        divisor = self._count if self._count > 1 else 1
        return self._sum / divisor


class MetersGroup(object):
    """A set of named AverageMeters that is periodically flushed to a file and the console.

    Args:
        file_name: path of the JSON-lines log file; an existing file is deleted
            on construction.
        formating: iterable of ``(data key, display key, format type)`` triples
            describing the console line (see FORMAT_CONFIG).
    """

    # printf-style templates for the supported format types
    _TEMPLATES = {
        'int': '%s: %d',
        'float': '%s: %.04f',
        'time': '%s: %.01f s',
    }

    def __init__(self, file_name, formating):
        self._file_name = file_name
        # start every run with a fresh log file
        if os.path.exists(file_name):
            os.remove(file_name)
        self._formating = formating
        self._meters = defaultdict(AverageMeter)

    def log(self, key, value, n=1):
        """Accumulate ``value`` (weighted by ``n``) into the meter called ``key``."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Return ``{short key: meter value}`` with the group prefix stripped.

        ``'train/...'`` loses its ``'train/'`` prefix; every other key is assumed
        to start with ``'eval/'``. Remaining slashes become underscores.
        """
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _dump_to_file(self, data):
        """Append ``data`` as one JSON line to the log file."""
        with open(self._file_name, 'a') as f:
            f.write(json.dumps(data) + '\n')

    def _format(self, key, value, ty):
        """Render ``'<key>: <value>'`` according to the format type ``ty``.

        Raises:
            ValueError: if ``ty`` is not one of 'int', 'float' or 'time'.
        """
        try:
            template = self._TEMPLATES[ty]
        except KeyError:
            # a bare string cannot be raised in Python 3, so raise a real exception
            raise ValueError('invalid format type: %s' % ty) from None
        return template % (key, value)

    def _dump_to_console(self, data, prefix):
        """Print one '|'-separated line; keys missing from ``data`` are shown as 0."""
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = ['{:5}'.format(prefix)]
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print('| %s' % (' | '.join(pieces)))

    def dump(self, step, prefix):
        """Flush the accumulated meters to file and console, then reset them.

        Does nothing when no value was logged since the last dump.
        """
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['step'] = step
        self._dump_to_file(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()


class Logger(object):
    """Scalar/image/video/histogram logger writing to TensorBoard and to text logs.

    Args:
        log_dir: directory receiving ``train.log``, ``eval.log`` and, when
            ``use_tb`` is set, a ``tb`` sub-directory (wiped if it already exists).
        use_tb: create a TensorBoard ``SummaryWriter``; when False all
            TensorBoard calls become no-ops.
        config: key into FORMAT_CONFIG selecting the console layout.
    """

    def __init__(self, log_dir, use_tb=True, config='rl'):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(
            os.path.join(log_dir, 'train.log'),
            formating=FORMAT_CONFIG[config]['train']
        )
        self._eval_mg = MetersGroup(
            os.path.join(log_dir, 'eval.log'),
            formating=FORMAT_CONFIG[config]['eval']
        )

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        """Log a scalar under ``key`` (which must start with 'train' or 'eval').

        Tensors, including Tensor subclasses such as ``nn.Parameter``, are
        converted to Python numbers first. TensorBoard receives ``value / n``;
        the meter group receives ``value`` with weight ``n``.
        """
        assert key.startswith('train') or key.startswith('eval')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        """Log histograms of a layer's weight/bias and of their gradients if present.

        A layer without a bias (attribute missing or set to ``None``) only
        contributes its weight histograms.
        """
        self.log_histogram(key + '_w', param.weight.data, step)
        if getattr(param.weight, 'grad', None) is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        bias = getattr(param, 'bias', None)
        if bias is not None:
            self.log_histogram(key + '_b', bias.data, step)
            if getattr(bias, 'grad', None) is not None:
                self.log_histogram(key + '_b_g', bias.grad.data, step)

    def log_image(self, key, image, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        """Flush both meter groups for ``step``."""
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')

In [8]:
def make_dir(dir_path):
    """Create ``dir_path`` (including missing parents) if needed and return it.

    An already existing directory is not an error. Other failures, such as
    missing permissions or a regular file occupying the path, are raised as
    ``OSError`` instead of being silently ignored.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored as float32 tensors.

    All storage lives on CUDA when available, otherwise on the CPU.
    Field order for every input/output: state, action, reward, next_state, done.

    Args:
        state_dim: width of a state row.
        action_dim: width of an action row.
        buffer_size: number of transitions the buffer can hold.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 buffer_size: int = 1000000) -> None:
        # kept so that other code (e.g. train_val_split) can build a matching buffer
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.buffer_size = buffer_size
        self.pointer = 0  # next row to be written
        self.size = 0     # number of valid rows

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

    def from_json(self, json_file: str):
        """Load a dataset from ``json_datasets/<json_file>[.json]`` via ``from_d4rl``.

        Every array except ``"terminals"`` is cast to float32.
        """
        import json

        if not json_file.endswith('.json'):
            json_file = json_file + '.json'

        json_file = os.path.join("json_datasets", json_file)
        output = dict()

        with open(json_file) as f:
            dataset = json.load(f)

        for k, v in dataset.items():
            v = np.array(v)
            if k != "terminals":
                v = v.astype(np.float32)

            output[k] = v

        self.from_d4rl(output)

    @staticmethod
    def to_tensor(data: np.ndarray, device=None) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on ``device`` (CUDA if available by default)."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        return torch.tensor(data, dtype=torch.float32, device=device)

    def sample(self, batch_size: int):
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        Indices come from ``np.random`` so sampling follows the NumPy seed.

        Raises:
            ValueError: if the buffer holds no transitions.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        indexes = np.random.randint(0, self.size, size=batch_size)
        index = torch.tensor(indexes, dtype=torch.long, device=self.device)

        # advanced indexing already returns fresh tensors, so no extra copy is needed
        return (
            self.states[index],
            self.actions[index],
            self.rewards[index],
            self.next_states[index],
            self.dones[index]
        )

    def from_d4rl(self, dataset):
        """Fill the buffer from a dict with the keys ``observations``, ``actions``,
        ``next_observations``, ``rewards`` and ``terminals``.

        A dataset smaller than the capacity is copied into the first rows; a
        dataset at least as large as the capacity replaces the storage and
        becomes the new capacity.
        """
        if self.size:
            print("Warning: loading data into non-empty buffer")
        n_transitions = dataset["observations"].shape[0]

        columns = {
            "states": dataset["observations"],
            "actions": dataset["actions"],
            "next_states": dataset["next_observations"],
            "rewards": dataset["rewards"].reshape(-1, 1),
            "dones": dataset["terminals"].reshape(-1, 1),
        }
        replace_storage = n_transitions >= self.buffer_size

        for name, values in columns.items():
            tensor = self.to_tensor(values, self.device)
            if replace_storage:
                setattr(self, name, tensor)
            else:
                getattr(self, name)[:n_transitions] = tensor

        if replace_storage:
            self.buffer_size = n_transitions

        self.size = n_transitions
        self.pointer = n_transitions % self.buffer_size

    def from_d4rl_finetune(self, dataset):
        raise NotImplementedError()

    def normalize_states(self, eps=1e-3):
        """Standardise stored states and next states; return ``(mean, std)``.

        The statistics are computed over the ``size`` stored transitions only,
        so unused zero rows of a partially filled buffer do not distort them.
        ``std`` already includes ``eps``.

        Raises:
            ValueError: if the buffer holds no transitions.
        """
        if self.size == 0:
            raise ValueError("cannot normalize an empty replay buffer")
        filled = slice(0, self.size)
        mean = self.states[filled].mean(0, keepdim=True)
        std = self.states[filled].std(0, keepdim=True) + eps
        self.states[filled] = (self.states[filled] - mean) / std
        self.next_states[filled] = (self.next_states[filled] - mean) / std
        return mean, std

    def clip(self, eps=1e-5):
        """Clip all stored actions into ``[-1 + eps, 1 - eps]``."""
        self.actions = torch.clip(self.actions, - 1 + eps, 1 - eps)

    def add_transition(self,
                       state: torch.Tensor,
                       action: torch.Tensor,
                       reward: torch.Tensor,
                       next_state: torch.Tensor,
                       done: torch.Tensor):
        """Write one transition at the write pointer, overwriting the oldest row when full.

        Any argument that is not already a tensor is converted to a float32
        tensor on the buffer's device.
        """
        state, action, reward, next_state, done = (
            item if isinstance(item, torch.Tensor) else self.to_tensor(item, self.device)
            for item in (state, action, reward, next_state, done)
        )

        self.states[self.pointer] = state
        self.actions[self.pointer] = action
        self.rewards[self.pointer] = reward
        self.next_states[self.pointer] = next_state
        self.dones[self.pointer] = done

        self.pointer = (self.pointer + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def add_batch(self,
                  states: List[torch.Tensor],
                  actions: List[torch.Tensor],
                  rewards: List[torch.Tensor],
                  next_states: List[torch.Tensor],
                  dones: List[torch.Tensor]):
        """Add transitions one by one from five parallel sequences."""
        for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
            self.add_transition(state, action, reward, next_state, done)

    @staticmethod
    def dataset_stats(dataset, max_episode_steps: int = 1000):
        """Return ``(mean, std)`` of the episode returns found in ``dataset``.

        An episode ends at a step whose ``terminals`` flag is set or after
        ``max_episode_steps`` steps. The reward of the final step is part of the
        return. A trailing unfinished episode is ignored.
        """
        episode_returns = []
        returns = 0
        episode_length = 0

        for reward, done in zip(dataset["rewards"], dataset["terminals"]):
            # count the step before testing for the episode end so the
            # terminal-step reward is not lost
            returns += reward
            episode_length += 1
            if done or episode_length == max_episode_steps:
                episode_returns.append(returns)
                returns = 0
                episode_length = 0

        episode_returns = np.array(episode_returns)
        return episode_returns.mean(), episode_returns.std()


def train_val_split(replay_buffer: ReplayBuffer, val_size: float) -> Tuple[ReplayBuffer, ReplayBuffer]:
    """Randomly split the stored transitions into a training and a validation buffer.

    Args:
        replay_buffer: source buffer; only its first ``size`` rows are used.
        val_size: fraction of the stored transitions that goes to validation.

    Returns:
        ``(train_rb, val_rb)``. Both have the same capacity as ``replay_buffer``,
        so neither subset can overflow its buffer.
    """
    data_size = replay_buffer.size
    n_val = int(data_size * val_size)

    permutation = torch.randperm(data_size)

    # read the row widths from the storage itself rather than from attributes
    # the buffer may not define
    state_dim = replay_buffer.states.shape[1]
    action_dim = replay_buffer.actions.shape[1]
    field_names = ("states", "actions", "rewards", "next_states", "dones")

    def _subset(indices: torch.Tensor) -> ReplayBuffer:
        """Build a buffer holding the rows of ``replay_buffer`` selected by ``indices``."""
        subset = ReplayBuffer(state_dim, action_dim, replay_buffer.buffer_size)
        count = indices.shape[0]
        for field in field_names:
            source = getattr(replay_buffer, field)
            # one bulk copy per field instead of a Python loop over transitions
            getattr(subset, field)[:count] = source[indices.to(source.device)].to(subset.device)
        subset.size = count
        subset.pointer = count % subset.buffer_size if subset.buffer_size else 0
        return subset

    train_rb = _subset(permutation[n_val:])
    val_rb = _subset(permutation[:n_val])

    return train_rb, val_rb


class Actor(nn.Module):
    """Deterministic policy: a three-layer MLP mapping a state to an action.

    Args:
        state_dim: input width.
        action_dim: output width.
        max_action: when not None, the output is ``max_action * tanh(raw)``.
        dropout: dropout probability applied after each hidden linear layer
            (``None`` is treated as 0).
        hidden_dim: width of both hidden layers.
        uniform_initialization: accepted but not used by this class.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 max_action: float = None,
                 dropout: float = None,
                 hidden_dim: int = 256,
                 uniform_initialization: bool = False) -> None:
        super().__init__()

        self.max_action = max_action
        drop_p = 0 if dropout is None else dropout

        # two hidden stages of Linear -> Dropout -> ReLU, then the output layer
        widths = (state_dim, hidden_dim, hidden_dim)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend([nn.Linear(fan_in, fan_out), nn.Dropout(drop_p), nn.ReLU()])
        stages.append(nn.Linear(hidden_dim, action_dim))

        self.actor = nn.Sequential(*stages)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the action for ``state`` (squashed and scaled if ``max_action`` is set)."""
        raw_action = self.actor(state)

        if self.max_action is None:
            return raw_action
        return self.max_action * torch.tanh(raw_action)


class Critic(nn.Module):
    """Two independent Q-networks over the concatenated (state, action) input.

    Args:
        state_dim: width of the state input.
        action_dim: width of the action input.
        hidden_dim: width of both hidden layers of each network.
        uniform_initialization: accepted but not used by this class.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 uniform_initialization: bool = False) -> None:
        super().__init__()

        input_dim = state_dim + action_dim
        self.q1_ = self._q_network(input_dim, hidden_dim)
        self.q2_ = self._q_network(input_dim, hidden_dim)

    @staticmethod
    def _q_network(input_dim: int, hidden_dim: int) -> nn.Sequential:
        """Build one Linear-ReLU-Linear-ReLU-Linear network with a scalar output."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self,
                state: torch.Tensor,
                action: torch.Tensor):
        """Return the estimates of both Q-networks as a ``(q1, q2)`` pair."""
        state_action = torch.cat((state, action), dim=1)
        first, second = self.q1_(state_action), self.q2_(state_action)
        return first, second

    def q1(self,
           state: torch.Tensor,
           action: torch.Tensor) -> torch.Tensor:
        """Return the estimate of the first Q-network only."""
        state_action = torch.cat((state, action), dim=1)
        return self.q1_(state_action)


class ConditionalVAE(nn.Module):
    """Conditional VAE over actions given states.

    The encoder maps ``(state, action)`` to the mean and standard deviation of
    a Gaussian over the latent ``z``; the decoder maps ``(state, z)`` back to
    an action.

    Args:
        state_dim: width of the state input.
        action_dim: width of the action input/output.
        latent_dim: width of the latent code.
        max_action: when not None, decoded actions are ``max_action * tanh(raw)``.
        hidden_dim: width of the hidden layers of encoder and decoder.
        device: device used for sampled latents and for ``load``; defaults to
            CUDA when available, otherwise the CPU.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 latent_dim: int,
                 max_action: int = None,
                 hidden_dim: int = 750,
                 device: torch.device = None) -> None:
        super().__init__()

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.max_action = max_action
        self.latent_dim = latent_dim

        # encoder
        self.e1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        # decoder
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_dim)
        self.d2 = nn.Linear(hidden_dim, hidden_dim)
        self.d3 = nn.Linear(hidden_dim, action_dim)

    @staticmethod
    def _tile(tensor: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Turn a ``(batch, dim)`` tensor into ``(batch, num_samples, dim)`` by repetition."""
        return tensor.repeat(num_samples, 1, 1).permute(1, 0, 2)

    def encode(self,
               state: torch.Tensor,
               action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` of the latent distribution for ``(state, action)``."""
        z = F.relu(self.e1(torch.cat([state, action], -1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # the log-std is clamped to [-4, 15] before exponentiation (see 'paper' folder)
        std = torch.exp(self.log_std(z).clamp(-4, 15))
        return mean, std

    def decode(self,
               state: torch.Tensor,
               z: torch.Tensor = None) -> torch.Tensor:
        """Decode an action from ``state`` and latent ``z``.

        When ``z`` is None it is drawn from a standard normal and clamped to
        ``[-0.5, 0.5]`` (see 'paper' folder).
        """
        if z is None:
            z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5, 0.5)

        action = F.relu(self.d1(torch.cat([state, z], -1)))
        action = F.relu(self.d2(action))
        action = self.d3(action)

        if self.max_action is not None:
            return self.max_action * torch.tanh(action)
        return action

    def forward(self,
                state: torch.Tensor,
                action: torch.Tensor):
        """Encode, sample a latent and decode; return ``(reconstruction, mean, std)``."""
        mean, std = self.encode(state, action)
        # reparameterisation: z = mean + std * eps, the same form both loss methods use
        z = mean + std * torch.randn_like(std)

        return self.decode(state, z), mean, std

    def importance_sampling_loss(self,
                                 state: torch.Tensor,
                                 action: torch.Tensor,
                                 beta: float,
                                 num_samples: int = 10) -> torch.Tensor:
        """Return one negative importance-weighted log-likelihood estimate per batch row.

        See eq8 in 'paper' folder. The decoder likelihood is a Gaussian with
        standard deviation ``sqrt(beta) / 2``.
        """
        mean, std = self.encode(state, action)

        mean = self._tile(mean, num_samples)
        std = self._tile(std, num_samples)
        z = mean + std * torch.randn_like(std)
        state = self._tile(state, num_samples)
        action = self._tile(action, num_samples)

        mean_decoded = self.decode(state, z)
        scale_factor = sqrt(beta) / 2

        log_prob_q_zx = Normal(loc=mean, scale=std).log_prob(z)
        mean_prior = torch.zeros_like(z).to(self.device)
        std_prior = torch.ones_like(z).to(self.device)
        log_prob_p_z = Normal(loc=mean_prior, scale=std_prior).log_prob(z)
        std_decoded = torch.ones_like(mean_decoded).to(self.device) * scale_factor
        log_prob_p_xz = Normal(loc=mean_decoded, scale=std_decoded).log_prob(action)

        w = log_prob_p_xz.sum(-1) + log_prob_p_z.sum(-1) - log_prob_q_zx.sum(-1)
        score = w.logsumexp(dim=-1) - log(num_samples)
        return -score

    def elbo_loss(self,
                  state: torch.Tensor,
                  action: torch.Tensor,
                  beta: float,
                  num_samples: int = 10) -> torch.Tensor:
        """Return one negative-ELBO style loss per batch row, shape ``(batch,)``.

        See eq7 in 'paper' folder. The reconstruction error is averaged over
        the ``num_samples`` latent draws and the action dimensions; the KL term
        is weighted by ``beta``.
        """
        mean, std = self.encode(state, action)

        # keep the un-tiled (batch, latent) statistics for the KL term so that both
        # loss terms have shape (batch,) and add element-wise
        mean_tiled = self._tile(mean, num_samples)
        std_tiled = self._tile(std, num_samples)
        z = mean_tiled + std_tiled * torch.randn_like(std_tiled)
        state = self._tile(state, num_samples)
        action = self._tile(action, num_samples)

        decoded = self.decode(state, z)
        reconstruction_loss = ((decoded - action) ** 2).mean(dim=(1, 2))

        kl_loss = -1 / 2 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean(-1)
        loss = reconstruction_loss + beta * kl_loss
        return loss

    def load(self, filename):
        """Load the state dict stored at ``filename`` onto this model's device."""
        self.load_state_dict(torch.load(filename, map_location=self.device))


class SPOT:
    """Actor-critic agent whose actor loss adds a VAE-based density penalty.

    The critic is a pair of Q-networks trained against the minimum of the two
    target Q-values at a noise-perturbed target action. Every
    ``policy_frequency`` iterations the actor is updated with
    ``-Q + lambda * density_estimator_loss`` and both target networks are
    moved towards the online networks with step ``tau``.

    Args:
        vae: density estimator; its ``elbo_loss`` or
            ``importance_sampling_loss`` (chosen by ``use_importance_sampling``)
            provides the penalty, called with ``beta`` and ``num_samples``.
        state_dim, action_dim: network input/output sizes.
        max_action: action bound; ``None`` disables both the actor's tanh
            scaling and the clamping of target actions.
        discount_factor: discount applied to the bootstrapped target.
        tau: target-network interpolation step.
        policy_noise, noise_clip: scale and clamp of the target action noise.
        policy_frequency: the actor is updated every this many iterations.
        lambda_: weight of the density penalty.
        lr: critic learning rate, also used for the actor when ``actor_lr`` is None.
        actor_lr: optional separate actor learning rate.
        with_q_norm: divide the Q term of the actor loss by the detached mean |Q|.
        actor_hidden_dim, critic_hidden_dim, actor_dropout: network settings.
        actor_init_w, critic_init_w: forwarded to the networks'
            ``uniform_initialization`` parameter.
        lambda_cool: in ``train_online``, decay the penalty weight linearly
            with the iteration count, floored at ``lambda_end * lambda_``.
        lambda_end: floor factor of that decay.
    """

    # `train` aborts when the mean Q of a policy update exceeds this value
    diverging_threshold = 1e4
    # number of iterations over which `lambda_cool` decays the penalty weight
    lambda_cool_steps = 1000000

    def __init__(self,
                 vae: ConditionalVAE,
                 state_dim: int,
                 action_dim: int,
                 max_action: float = None,
                 discount_factor: float = 0.99,
                 tau: float = 0.005,
                 policy_noise: float = 0.2,
                 noise_clip: float = 0.5,
                 policy_frequency: int = 2,
                 beta: float = 0.5,
                 lambda_: float = 1.0,
                 lr: float = 3e-4,
                 actor_lr: float = None,
                 with_q_norm: bool = True,
                 num_samples: int = 1,
                 use_importance_sampling: bool = False,
                 actor_hidden_dim: int = 256,
                 critic_hidden_dim: int = 256,
                 actor_dropout: float = 0.1,
                 actor_init_w: bool = False,
                 critic_init_w: bool = False,
                 lambda_cool: bool = False,
                 lambda_end: float = 0.2) -> None:

        self.iterations = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.actor = Actor(state_dim, action_dim, max_action, actor_dropout, actor_hidden_dim, actor_init_w).to(device)
        self.actor_target = deepcopy(self.actor)
        # an explicitly given actor learning rate takes precedence over the shared one
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=lr if actor_lr is None else actor_lr)

        self.critic = Critic(state_dim, action_dim, critic_hidden_dim, critic_init_w).to(device)
        self.critic_target = deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.discount_factor = discount_factor
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_frequency = policy_frequency
        self.vae = vae
        self.beta = beta
        self.num_samples = num_samples
        self.use_importance_sampling = use_importance_sampling
        self.with_q_norm = with_q_norm
        self.lambda_ = lambda_
        self.lambda_cool = lambda_cool
        self.lambda_end = lambda_end

    @staticmethod
    def to_tensor(data, device=None) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on ``device`` (CUDA if available by default)."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.tensor(data, dtype=torch.float32, device=device)

    @torch.no_grad()
    def act(self, state: np.ndarray) -> np.ndarray:
        """Return the actor's action for one state as a flat NumPy array.

        The actor is switched to eval mode for the forward pass (disabling
        dropout) and back to train mode afterwards.
        """
        self.actor.eval()
        state = self.to_tensor(state.reshape(1, -1), device=self.device)
        action = self.actor(state).cpu().data.numpy().flatten()
        self.actor.train()
        return action

    def soft_update(self, regime):
        """Move a target network towards its online network by ``tau``.

        ``regime == "actor"`` updates the actor target; any other value updates
        the critic target.
        """
        if regime == "actor":
            online, target = self.actor, self.actor_target
        else:
            online, target = self.critic, self.critic_target
        for param, tgt_param in zip(online.parameters(), target.parameters()):
            tgt_param.data.copy_(self.tau * param.data + (1 - self.tau) * tgt_param.data)

    def _update(self,
                replay_buffer,
                batch_size: int,
                logger,
                online: bool) -> None:
        """Run one training iteration shared by ``train`` and ``train_online``.

        ``online=True`` enables the ``lambda_cool`` schedule and disables the
        divergence check; ``online=False`` does the opposite.
        """
        self.iterations += 1

        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)

            next_action = self.actor_target(next_state) + noise
            # without an action bound there is nothing to clamp to
            if self.max_action is not None:
                next_action = next_action.clamp(-self.max_action, self.max_action)

            tgt_q1, tgt_q2 = self.critic_target(next_state, next_action)
            tgt_q = torch.min(tgt_q1, tgt_q2)

            tgt_q = reward + (1 - done) * self.discount_factor * tgt_q  # eq1 in 'paper' folder

        current_q1, current_q2 = self.critic(state, action)

        critic_loss = F.mse_loss(current_q1, tgt_q) + F.mse_loss(current_q2, tgt_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        if logger is not None:
            logger.log("train/critic_loss", critic_loss, self.iterations)

        # delayed policy update: only every `policy_frequency` iterations
        if self.iterations % self.policy_frequency:
            return

        pi = self.actor(state)
        q = self.critic.q1(state, pi)

        if self.use_importance_sampling:
            density_estimator_loss = self.vae.importance_sampling_loss(state, pi, self.beta, self.num_samples)
        else:
            density_estimator_loss = self.vae.elbo_loss(state, pi, self.beta, self.num_samples)

        # additional component for online learning: decay the penalty weight
        lambda_ = self.lambda_
        if online and self.lambda_cool:
            lambda_ = self.lambda_ * max(self.lambda_end, (1.0 - self.iterations / self.lambda_cool_steps))

            if logger is not None:
                logger.log("train/lambda_", lambda_, self.iterations)

        # see practical_algo.jpeg in 'paper' folder
        if self.with_q_norm:
            actor_loss = -q.mean() / q.abs().mean().detach() + lambda_ * density_estimator_loss.mean()
        else:
            actor_loss = -q.mean() + lambda_ * density_estimator_loss.mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        if logger is not None:
            logger.log("train/Q", q.mean(), self.iterations)
            logger.log("train/actor_loss", actor_loss, self.iterations)
            logger.log("train/neg_log_beta", density_estimator_loss.mean(), self.iterations)
            logger.log("train/neg_log_beta_max", density_estimator_loss.max(), self.iterations)

        if not online:
            mean_q = q.mean().item()
            if mean_q > self.diverging_threshold:
                # raise instead of calling exit(), which would end the whole
                # session without any explanation
                raise RuntimeError(
                    f"training diverged at iteration {self.iterations}: "
                    f"mean Q {mean_q:.3f} exceeds {self.diverging_threshold}")

        self.soft_update(regime="actor")
        self.soft_update(regime="critic")

    def train(self,
              replay_buffer: ReplayBuffer,
              batch_size: int = 256,
              logger: Logger = None) -> None:
        """Run one offline training iteration on a batch from ``replay_buffer``.

        Raises:
            RuntimeError: if the mean Q-value of a policy update exceeds
                ``diverging_threshold``.
        """
        self._update(replay_buffer, batch_size, logger, online=False)

    def train_online(self,
                     replay_buffer: ReplayBuffer,
                     batch_size: int = 256,
                     logger: Logger =None) -> None:
        """Run one online training iteration.

        Same update as ``train`` but with the optional ``lambda_cool`` decay of
        the penalty weight and without the divergence check.
        """
        self._update(replay_buffer, batch_size, logger, online=True)

    def _checkpoint_paths(self, model_dir, step):
        """Yield ``(object with state_dict/load_state_dict, file path)`` for every saved part."""
        parts = (
            ("critic", self.critic),
            ("critic_target", self.critic_target),
            ("critic_optimizer", self.critic_optimizer),
            ("actor", self.actor),
            ("actor_target", self.actor_target),
            ("actor_optimizer", self.actor_optimizer),
        )
        for name, part in parts:
            yield part, os.path.join(model_dir, f"{name}_s{str(step)}.pth")

    def save(self, model_dir):
        """Save networks, targets and optimizers to ``model_dir``.

        File names carry the current iteration count: ``<part>_s<iterations>.pth``.
        """
        make_dir(model_dir)

        for part, path in self._checkpoint_paths(model_dir, self.iterations):
            torch.save(part.state_dict(), path)

    def load(self, model_dir, step=1000000):
        """Load the files written by ``save`` at iteration ``step`` from ``model_dir``.

        Tensors are mapped onto this agent's device, so a checkpoint written on
        a GPU can be loaded on a CPU-only machine.
        """
        for part, path in self._checkpoint_paths(model_dir, step):
            part.load_state_dict(torch.load(path, map_location=self.device))

In [9]:
class VideoRecorder:
    """Buffers rendered environment frames and writes them out as a video file.

    Recording is active only when a target directory was given and the most
    recent `init` call enabled it; otherwise `record` and `save` do nothing.
    """

    def __init__(self, dir_name, height=512, width=512, camera_id=0, fps=60):
        """
        Args:
            dir_name: directory videos are written to; `None` disables recording.
            height: frame height passed to `env.render`.
            width: frame width passed to `env.render`.
            camera_id: stored on the instance; not forwarded to `env.render`.
            fps: frame rate passed to `mimsave`.
        """
        self.dir_name = dir_name
        self.height = height
        self.width = width
        self.camera_id = camera_id
        self.fps = fps
        self.frames = []
        # Define `enabled` up front so `record`/`save` are usable (or no-ops)
        # even if the caller never invokes `init`; mirrors `init(enabled=True)`.
        self.enabled = dir_name is not None

    def init(self, enabled=True):
        """Start a fresh recording: drop buffered frames and set the enabled flag."""
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, env: "gym.Env"):
        """Render the current state of `env` as an RGB array and buffer it."""
        if self.enabled:
            # `camera_id` is deliberately not passed to `render` here.
            frame = env.render(
                mode='rgb_array',
                height=self.height,
                width=self.width,
            )
            self.frames.append(frame)

    def save(self, file_name):
        """Write the buffered frames to `<dir_name>/<file_name>` via `mimsave`."""
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            mimsave(path, self.frames, fps=self.fps)

In [29]:
def eval_policy(cfg: spot_config,
                iteration: int,
                recorder: VideoRecorder,
                logger: Logger,
                policy: SPOT,
                env_name: str,
                seed: int,
                mean: np.ndarray,
                std: np.ndarray,
                eval_episodes: int = 10):
    """Roll out `policy` in a fresh environment and return its normalized score.

    A new `env_name` environment is created, seeded with `seed + 100`, run for
    `eval_episodes` episodes and closed again. Every observation is
    standardized as `(state - mean) / std` before it is handed to
    `policy.act`. Each episode is optionally recorded to a video through
    `recorder` (controlled by `cfg.save_video`).

    Args:
        cfg: run configuration; only `cfg.save_video` is read here.
        iteration: training step, used as the logging step and in video names.
        recorder: video recorder, re-initialized at the start of every episode.
        logger: metrics sink; pass `None` to skip logging.
        policy: agent exposing `act(state)`.
        env_name: id passed to `gym.make`.
        seed: base seed; the environment is seeded with `seed + 100`.
        mean: offset subtracted from each observation.
        std: scale each observation is divided by.
        eval_episodes: number of episodes to run; must be at least 1.

    Returns:
        `env.get_normalized_score` applied to the mean episode return.

    Raises:
        ValueError: if `eval_episodes` is smaller than 1.
    """
    if eval_episodes < 1:
        raise ValueError(f"eval_episodes must be at least 1, got {eval_episodes}")

    env = gym.make(env_name)
    try:
        env.seed(seed + 100)

        lengths, returns = [], []
        total_return = 0.0  # summed over every step of every episode

        for episode in trange(eval_episodes):
            recorder.init(enabled=cfg.save_video)
            state, done = env.reset(), False

            steps = 0
            episode_return = 0

            while not done:
                # The policy expects a standardized (1, state_dim) observation.
                state = (np.array(state).reshape(1, -1) - mean) / std
                action = policy.act(state)

                state, reward, done, _ = env.step(action)
                recorder.record(env)

                total_return += reward
                episode_return += reward
                steps += 1

            lengths.append(steps)
            returns.append(episode_return)
            recorder.save(f"evaluation_{iteration}_episode{episode}_return_{episode_return}.mp4")

        mean_return = total_return / eval_episodes
        d4rl_score = env.get_normalized_score(mean_return)
    finally:
        # A new environment is built on every evaluation; release it
        # explicitly instead of leaving that to garbage collection.
        env.close()

    if logger is not None:
        logger.log('eval/lengths_mean', np.mean(lengths), iteration)
        logger.log('eval/lengths_std', np.std(lengths), iteration)
        logger.log('eval/returns_mean', np.mean(returns), iteration)
        logger.log('eval/returns_std', np.std(returns), iteration)
        logger.log('eval/d4rl_score', d4rl_score, iteration)

    return d4rl_score


def train_policy(cfg=spot_config()):
    """Load a pretrained SPOT policy and continue training it with `train_online`.

    Steps, in order:
      1. create the output directories and the `cfg.env_name` environment,
         then seed everything with `cfg.seed`;
      2. build the conditional VAE, fetch its weights if they are not on disk
         yet, load them and switch the VAE to eval mode;
      3. build the SPOT policy and load its checkpoint from
         `spot_weights/<cfg.env>_policy_weights` (using `load`'s default step);
      4. fill the replay buffer from the D4RL dataset, optionally normalizing
         states (`cfg.normalize_env`) and clipping (`cfg.clip`);
      5. call `policy.train_online` for `cfg.max_timesteps` steps, evaluating
         every `cfg.eval_frequency` steps and saving weights when
         `cfg.save_model` / `cfg.save_final_model` are set.

    Args:
        cfg: run configuration object.
    """
    # Local import: only this function needs it.
    import urllib.parse
    import urllib.request

    video_dir = os.path.join(cfg.base_dir, "video")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weights_dir = os.path.join(cfg.base_dir, cfg.weights_dir)
    make_dir(weights_dir)
    make_dir(video_dir)

    env = gym.make(cfg.env_name)

    seed_everything(cfg.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    vae = ConditionalVAE(state_dim, action_dim, action_dim * 2, max_action).to(device)
    # Download the VAE weights only when they are missing, so repeated calls
    # neither hit the network again nor pile up duplicate copies on disk.
    if not os.path.exists(cfg.vae_path):
        vae_url = ("https://github.com/zzmtsvv/rl_task/raw/main/spot/vae_weights/"
                   + urllib.parse.quote(cfg.vae_path))
        urllib.request.urlretrieve(vae_url, cfg.vae_path)
    vae.load(cfg.vae_path)
    vae.eval()

    policy = SPOT(vae,
                  state_dim=state_dim,
                  action_dim=action_dim,
                  max_action=max_action,
                  discount_factor=cfg.discount_factor,
                  tau=cfg.tau,
                  policy_noise=cfg.policy_noise,
                  noise_clip=cfg.noise_clip,
                  policy_frequency=cfg.policy_frequency,
                  beta=cfg.beta,
                  lambda_=cfg.lambda_,
                  lr=cfg.lr,
                  actor_lr=cfg.actor_lr,
                  with_q_norm=cfg.with_q_norm,
                  num_samples=cfg.num_samples,
                  use_importance_sampling=cfg.use_importance_sampling,
                  actor_hidden_dim=cfg.actor_hidden_dim,
                  actor_dropout=cfg.actor_dropout)

    # NOTE(review): this reads `cfg.env` while the rest of the function uses
    # `cfg.env_name` — confirm that `spot_config` defines both attributes.
    model_dir = os.path.join("spot_weights", f"{cfg.env}_policy_weights")
    policy.load(model_dir)

    replay_buffer = ReplayBuffer(state_dim, action_dim, buffer_size=cfg.buffer_size)
    replay_buffer.from_d4rl(d4rl.qlearning_dataset(env))

    # Identity normalization unless the buffer statistics are requested.
    mean, std = 0, 1
    if cfg.normalize_env:
        mean, std = replay_buffer.normalize_states()
        mean, std = mean.cpu().numpy(), std.cpu().numpy()

    if cfg.clip:
        replay_buffer.clip()

    logger = Logger(os.path.join(cfg.base_dir, "runs"), use_tb=True)
    try:
        recorder = VideoRecorder(video_dir)

        for timestep in trange(cfg.max_timesteps):
            policy.train_online(replay_buffer, batch_size=cfg.batch_size, logger=logger)

            if not (timestep + 1) % cfg.eval_frequency:
                eval_policy(cfg,
                            timestep + 1,
                            recorder,
                            logger,
                            policy,
                            cfg.env_name,
                            cfg.seed,
                            mean,
                            std,
                            cfg.eval_episodes)

                if cfg.save_model:
                    policy.save(weights_dir)

        if cfg.save_final_model:
            policy.save(weights_dir)
    finally:
        # Close the logger's writer even when training is interrupted or fails.
        logger._sw.close()

In [None]:
# Copy the pretrained-weights archive from the mounted Google Drive into
# /content and unpack it into the current working directory. train_policy()
# reads policy checkpoints from a relative "spot_weights" directory —
# presumably what this archive contains (TODO confirm).
# `unzip -u` only extracts files that are new or newer than those on disk.
!cp /content/drive/MyDrive/spot_weights.zip /content/
!unzip -u /content/spot_weights.zip

In [None]:
# Run training with the default configuration (the `spot_config()` default argument).
train_policy()