In [8]:
import os
import time
import torch
from torch import nn, optim
import argparse
import numpy as np
from collections import deque
import gym

In [21]:
import torch
import os
import datetime
import numpy as np
torch.manual_seed(42)
np.random.seed(42)

# Run-wide configuration: compute device, home directory and a timestamp that names
# this run's checkpoint / log sub-directories.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
HOME = os.path.expanduser('~')
TIME = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Checkpoint root: $EHDQN_CKPT_DIR when set; otherwise the original dataset directory,
# falling back to ../ckpts/ when that directory does not exist on this machine.
CKPT_ROOT = os.environ.get('EHDQN_CKPT_DIR')
if CKPT_ROOT is None:
    CKPT_ROOT = '/home/adryw/dataset/ehdqn/ckpts/'
    if not os.path.isdir(CKPT_ROOT):
        CKPT_ROOT = '../ckpts/'
SAVEPATH = os.path.join(CKPT_ROOT, TIME)

# Log root: $EHDQN_LOG_DIR when set (e.g. '../logs'), otherwise ~/dataset/ehdqn/logs.
LOG_ROOT = os.environ.get('EHDQN_LOG_DIR', os.path.join(HOME, 'dataset', 'ehdqn', 'logs'))
LOGPATH = os.path.join(LOG_ROOT, TIME)
class Setting(object):
    """Plain, initially empty attribute container for run-wide settings.

    Attributes are attached after construction (e.g. ``sett.device = ...``).
    """

    def __init__(self) -> None:
        super().__init__()
# Global settings object; the classes in later cells read sett.device and sett.SAVEPATH.
sett = Setting()
vars(sett).update(device=device, HOME=HOME, SAVEPATH=SAVEPATH, LOGPATH=LOGPATH)

In [22]:
import numpy as np
import collections

class Memory:
    """Fixed-capacity replay buffer of ``(s, s1, a, r, is_terminal)`` transitions.

    Transitions are appended until ``max_memory`` entries are stored; after that the
    buffer behaves as a ring and overwrites the oldest entry at write position ``idx``.
    Every field is kept in its own parallel Python list.
    """

    def __init__(self, max_memory):
        """:param int max_memory: Maximum number of transitions kept."""
        self.max_memory = max_memory
        self.state = []
        self.new_state = []
        self.action = []
        self.reward = []
        self.is_terminal = []
        # Next slot to overwrite once the buffer is full.
        self.idx = 0

    def __len__(self):
        return len(self.state)

    def store_transition(self, s, s1, a, r, is_terminal):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        # Strict ``<`` so the buffer never holds more than ``max_memory`` entries.
        if len(self.state) < self.max_memory:
            self.state.append(s)
            self.new_state.append(s1)
            self.action.append(a)
            self.reward.append(r)
            self.is_terminal.append(is_terminal)
        else:
            self.state[self.idx] = s
            self.new_state[self.idx] = s1
            self.action[self.idx] = a
            self.reward[self.idx] = r
            self.is_terminal[self.idx] = is_terminal
            self.idx = (self.idx + 1) % self.max_memory
        assert len(self.state) == len(self.new_state) == len(self.reward) == len(self.is_terminal) == len(self.action)

    def clear_memory(self):
        """Drop every stored transition and reset the ring write position."""
        del self.state[:]
        del self.new_state[:]
        del self.action[:]
        del self.reward[:]
        del self.is_terminal[:]
        self.idx = 0

    def sample(self, bs):
        """Sample ``bs`` transitions uniformly with replacement.

        :return: five lists ``(state, new_state, action, reward, is_terminal)``; the
            terminal flags are converted to ``int``.
        """
        idx = np.random.randint(len(self.state), size=bs)
        state, new_state, action, reward, is_terminal = [], [], [], [], []
        for i in idx:
            state.append(self.state[i])
            new_state.append(self.new_state[i])
            action.append(self.action[i])
            reward.append(self.reward[i])
            is_terminal.append(int(self.is_terminal[i]))
        return state, new_state, action, reward, is_terminal

    def update(self, *args, **kwargs):
        """No-op priority update (uniform replay).

        Accepts any arguments so prioritized-replay style calls such as
        ``memory.update(idx, error)`` do not fail.
        """
        pass

In [23]:
class ICM_Model(nn.Module):
    """Intrinsic Curiosity Module: a feature encoder ``phi`` plus forward and inverse models.

    The forward model predicts ``phi(s1)`` from ``phi(s)`` and the action; its squared
    error is used as a curiosity reward. The inverse model predicts the action (logits)
    from ``phi(s)`` and ``phi(s1)``.

    :param int state_size: Size of the axis consumed by the encoder's first Linear layer.
    :param int action_size: Number of discrete actions (inverse-model output size).
    :param conv: Unused; accepted so callers that pass a ``conv`` flag still work.
    """

    def __init__(self, state_size, action_size, conv=False):
        # Must run before any submodule is assigned to ``self``.
        super(ICM_Model, self).__init__()
        self.state_size = state_size
        self.action_size = action_size

        self.phi = nn.Sequential(
            nn.Linear(self.state_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )

        out_shape = 256

        # Forward Model: (phi(s), action) -> predicted phi(s1)
        self.fwd = nn.Sequential(
            nn.Linear(out_shape + 1, 256),
            nn.ReLU(),
            nn.Linear(256, out_shape)
        )

        # Inverse Model: (phi(s), phi(s1)) -> action logits
        self.inv = nn.Sequential(
            nn.Linear(out_shape * 2, 256),
            nn.ELU(),
            nn.Linear(256, action_size)
        )

    def forward(self, *input):
        """Predict ``phi(s1)`` from ``(obs, action)``; the action is appended as one float feature."""
        obs, action = input
        action = action.view(-1, 1)
        phi = self.phi_state(obs)
        x = torch.cat((phi, action.float()), -1)
        phi_hat = self.fwd(x)
        return phi_hat

    def phi_state(self, s):
        """Encode a batch of states into flat ``(batch, features)`` vectors.

        Keeps only the last entry along axis 1 (presumably the newest frame of a stack -
        TODO confirm), swaps axes 1 and 3 so the encoder's Linear layers act on the
        original axis 2, then flattens per sample.
        """
        s = s[:, -1]
        x = s.float().transpose(1, 3)
        x = self.phi(x)
        return x.view(x.size(0), -1)

    def inverse_pred(self, s, s1):
        """Return action logits predicted from the encoded pair ``(s, s1)``."""
        s = self.phi_state(s.float())
        s1 = self.phi_state(s1.float())
        x = torch.cat((s, s1), -1)
        return self.inv(x)

    def curiosity_rew(self, s, s1, a):
        """Curiosity reward ``0.5 * ||fwd(phi(s), a) - phi(s1)||^2`` per sample."""
        phi_hat = self.forward(s, a)
        phi_s1 = self.phi_state(s1)
        cur_rew = 1 / 2 * (torch.norm(phi_hat - phi_s1, p=2, dim=-1) ** 2)
        return cur_rew


In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class DDQN_Model(nn.Module):
    """Dueling Q-network built from Linear layers, used for the macro and the sub-policies.

    When ``macro`` is None the model owns a ``backbone``; otherwise it builds no backbone
    and must be called with a model that has one (``forward(obs, macro=that_model)``),
    which lets several heads share a single backbone.

    :param int state_size: Size of the axis consumed by the backbone's Linear layer.
    :param int action_size: Number of discrete actions (Q-values per output row).
    :param conv: Unused; kept for interface compatibility.
    :param macro: If not None, no backbone is created.
    :param int hidd_ch: Hidden width.
    :param int conv_ch: Unused.
    """

    def __init__(self, state_size, action_size, conv, macro=None, hidd_ch=256, conv_ch=32):
        super(DDQN_Model, self).__init__()
        self.action_size = action_size
        self.hidd_ch = hidd_ch
        self.state_size = state_size
        if macro is None:
            self.backbone = nn.Sequential(
                nn.Linear(state_size, hidd_ch),
                nn.ReLU(),
            )

        # ``features`` is applied to the backbone output, whose width is ``hidd_ch``.
        self.features = nn.Sequential(
            nn.Linear(hidd_ch, hidd_ch),
            nn.ReLU(),
        )

        self.advantage = nn.Sequential(
            nn.Linear(hidd_ch, self.action_size)
        )

        self.value = nn.Sequential(
            nn.Linear(hidd_ch, 1)
        )

    def forward(self, obs, macro=None):
        """Return dueling Q-values ``V + (A - mean(A))``.

        A 4-D ``obs`` gets a leading batch axis; the tensor is then treated as
        ``(batch, stack, ...)`` and each stack entry is encoded separately.
        NOTE(review): the output is ``(batch, stack, action_size)`` (3-D); confirm that
        callers indexing it as ``q[rows, action]`` expect this layout.
        """
        if obs.ndimension() == 4:
            obs = obs[None]
        # Swap axes 2 and 4 so the backbone's Linear layer acts on the original axis 2.
        obs = obs.float().transpose(2, 4)
        stack = obs.shape[1]
        backbone = self.backbone if macro is None else macro.backbone
        x = torch.cat([self.features(backbone(obs[:, i]))[:, None] for i in range(stack)], dim=1)
        x = x.view(x.size(0), stack, -1)

        adv = self.advantage(x)
        value = self.value(x)
        return value + (adv - adv.mean(-1, keepdim=True))

    def act(self, state, eps, backbone=None):
        """Epsilon-greedy action(s).

        :return: a Python ``int`` when exactly one action is produced, otherwise a list of ints.
        """
        if np.random.random() > eps:
            q = self.forward(state, backbone)
            action = torch.argmax(q, dim=-1).cpu().data.numpy()
        else:
            action = np.random.randint(self.action_size, size=1 if len(state.shape) == 1 else state.shape[0])
        # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
        return action.item() if action.shape == (1,) else action.astype(int).tolist()

    def update_target(self, model):
        """Copy all weights from ``model`` into this network."""
        self.load_state_dict(model.state_dict())

In [None]:
import numpy as np
import torch
from torch.nn.functional import mse_loss, cross_entropy, smooth_l1_loss, softmax
import itertools
import os


class EHDQN:
    """Hierarchical double-DQN agent: a macro policy selects sub-policies, which act.

    For each of the ``n_proc`` observation rows, the macro network (``self.macro``) picks
    one of ``n_subpolicy`` sub-policies. That sub-policy then chooses primitive actions
    until ``max_time`` steps have elapsed or the episode terminates. Sub-policies have
    no backbone of their own and are always evaluated with ``macro=self.macro``. Each
    sub-policy has its own ICM module, whose forward-model error is mixed into the
    extrinsic reward as a curiosity bonus.
    """

    def __init__(self, state_dim, tau, action_dim, gamma, n_subpolicy, max_time, hidd_ch, lam, lr, eps,
                 eps_decay, eps_sub, eps_sub_decay, beta, bs, target_interval, train_steps, max_memory, max_memory_sub,
                 conv, gamma_macro, reward_rescale, n_proc, per=False, norm_input=True, logger=None):
        """
        :param state_dim: Shape of the state
        :param float tau: Weight for agent loss
        :param int action_dim: Number of actions
        :param float gamma: Discount for sub controller
        :param int n_subpolicy: Number of sub policies
        :param int max_time: Maximum number of steps a selected sub policy acts before the macro policy picks again
        :param int hidd_ch: Number of hidden channels
        :param float lam: Scaler for ICM reward
        :param float lr: Learning rate (the macro optimizer uses 4 * lr when ``per`` is True)
        :param float eps: Eps greedy chance for macro policy
        :param float eps_decay: Epsilon decay computed as eps * (1 - eps_decay) at each macro update
        :param float eps_sub: Eps greedy chance for sub policies
        :param float eps_sub_decay: Sub policy epsilon decay computed as eps_sub * (1 - eps_sub_decay) each update
        :param float beta: Weight for loss of fwd net vs inv net
        :param int bs: Batch size
        :param int target_interval: Number of train steps between target updates
        :param int train_steps: Number of training iterations for each call to ``update``
        :param int max_memory: Capacity of the macro replay memory
        :param int max_memory_sub: Capacity of each sub policy replay memory
        :param bool conv: Use or not convolutional networks (forwarded to DDQN_Model)
        :param float gamma_macro: Discount for macro controller
        :param reward_rescale: 0 = unchanged, 1 = sign, 2 = clip to [-1, 1], other = multiply by this value
        :param int n_proc: Number of observation rows handled per ``act`` call (presumably parallel environments)
        :param bool per: Use or not prioritized experience replay. NOTE: this needs a memory whose ``sample``
            returns ``(..., idxs, weights)`` and that exposes ``beta``; ``Memory`` does neither.
        :param bool norm_input: Divide observations by 255 before feeding the networks
        :param logger: Optional logger exposing ``step``, ``log_scalar``, ``log_text`` and ``log_histogram``
        """

        # Parameters
        self.logger = logger
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.target_interval = target_interval
        self.lr = lr
        self.bs = bs
        # Macro Policy parameters
        self.eps = eps
        self.eps_decay = 1 - eps_decay
        self.gamma_macro = gamma_macro
        # Sub policy parameters
        self.n_subpolicy = n_subpolicy
        self.tau = tau
        self.eps_sub = eps_sub
        self.eps_sub_decay = 1 - eps_sub_decay
        self.gamma = gamma
        # ICM parameters
        self.beta = beta
        self.lam = lam

        self.n_proc = n_proc
        # Per-row bookkeeping: selected sub policy (None = a new one must be picked), the
        # observation at which it was picked, steps since the pick, and the extrinsic
        # reward accumulated since the pick.
        self.selected_policy = np.full((self.n_proc,), fill_value=None)
        self.macro_state = np.full((self.n_proc,), fill_value=None)
        self.max_time = max_time
        self.train_steps = train_steps
        self.reward_rescale = reward_rescale
        self.norm_input = norm_input
        self.per = per
        # Builtin int/float: the np.int / np.float aliases were removed in NumPy 1.24.
        self.curr_time = np.zeros((self.n_proc,), dtype=int)
        self.macro_reward = np.zeros((self.n_proc,), dtype=float)
        self.target_count = np.zeros((self.n_subpolicy,), dtype=int)
        self.counter_macro = np.zeros((self.n_subpolicy,), dtype=int)
        self.counter_policies = np.zeros((self.n_subpolicy, self.action_dim), dtype=int)
        self.macro_count = 0

        memory = Memory

        # Macro policy, target and memory. The networks are moved to the device before
        # their optimizer is created.
        self.macro = DDQN_Model(state_dim, n_subpolicy, conv=conv, hidd_ch=hidd_ch).to(sett.device)
        self.macro_target = DDQN_Model(state_dim, n_subpolicy, conv=conv, hidd_ch=hidd_ch).to(sett.device)
        self.macro_target.update_target(self.macro)
        self.macro_memory = Memory(max_memory)
        self.macro_opt = torch.optim.Adam(self.macro.parameters(), lr=self.lr * 4 if self.per else self.lr)

        # Create Policies / ICM modules / Memories
        self.memory, self.policy, self.target, self.icm, self.policy_opt, self.icm_opt = [], [], [], [], [], []
        for i in range(n_subpolicy):
            # Sub policies get no backbone of their own (macro is given); they reuse the macro's.
            self.policy.append(DDQN_Model(state_dim, action_dim, conv=conv, hidd_ch=hidd_ch, macro=self.macro).to(sett.device))
            self.target.append(DDQN_Model(state_dim, action_dim, conv=conv, hidd_ch=hidd_ch, macro=self.macro).to(sett.device))
            self.target[-1].update_target(self.policy[-1])
            self.memory.append(memory(max_memory_sub))

            # ICM module; ``conv`` is not passed because ICM_Model does not use it.
            self.icm.append(ICM_Model(self.state_dim, self.action_dim).to(sett.device))

            # Create sub optimizers
            self.policy_opt.append(torch.optim.Adam(self.policy[i].parameters(), lr=self.lr))
            self.icm_opt.append(torch.optim.Adam(self.icm[i].parameters(), lr=1e-3))

    def save(self, i):
        """Save macro, sub policy and ICM weights under ``sett.SAVEPATH`` with suffix ``i``."""
        if not os.path.isdir(sett.SAVEPATH):
            os.makedirs(sett.SAVEPATH)
        torch.save(self.macro.state_dict(), os.path.join(sett.SAVEPATH, 'Macro_%s.pth' % i))
        for sub in range(self.n_subpolicy):
            torch.save(self.policy[sub].state_dict(), os.path.join(sett.SAVEPATH, 'Sub_%s_%s.pth' % (sub, i)))
            torch.save(self.icm[sub].state_dict(), os.path.join(sett.SAVEPATH, 'Icm_%s_%s.pth' % (sub, i)))

    def load(self, path, i):
        """Load weights written by ``save`` and resync every target network with its online network."""
        self.macro.load_state_dict(torch.load(os.path.join(path, 'Macro_%s.pth' % i), map_location=sett.device))
        self.macro_target.update_target(self.macro)
        for sub in range(self.n_subpolicy):
            self.policy[sub].load_state_dict(torch.load(os.path.join(path, 'Sub_%s_%s.pth' % (sub, i)), map_location=sett.device))
            self.icm[sub].load_state_dict(torch.load(os.path.join(path, 'Icm_%s_%s.pth' % (sub, i)), map_location=sett.device))
            self.target[sub].update_target(self.policy[sub])

    def act(self, obs, deterministic=False):
        """Return one primitive action per row of ``obs``.

        A row whose sub policy is unset, or has acted for ``max_time`` steps, gets a new
        sub policy from the macro policy. In non-deterministic mode the completed
        (non-terminal) macro transition is stored first.

        :param np.ndarray obs: Batch of ``n_proc`` observations.
        :param bool deterministic: Evaluation mode (eps fixed at 0.01, nothing stored).
        :return: ``np.ndarray`` of ``n_proc`` integer actions.
        """
        x = torch.from_numpy(obs).float().to(sett.device)
        if self.norm_input:
            # Out-of-place: for float32 CPU input ``x`` shares memory with ``obs``. An
            # in-place division would also rescale the caller's array and the states
            # stored below.
            x = x / 255

        for i in range(self.n_proc):
            sel_policy = self.selected_policy[i]
            if sel_policy is None or self.curr_time[i] == self.max_time:
                if sel_policy is not None and not deterministic:
                    # Store non terminal macro transition
                    self.macro_memory.store_transition(self.macro_state[i], obs[i], sel_policy, self.macro_reward[i], False)
                    self.macro_reward[i] = 0

                # Pick macro action
                new_policy = self.pick_policy(x[i][None], deterministic=deterministic)
                assert isinstance(new_policy, int)
                self.selected_policy[i] = new_policy
                self.curr_time[i] = 0
                if not deterministic:
                    self.macro_state[i] = obs[i]

                # Count the newly selected sub policy. The previous one may be None.
                self.counter_macro[new_policy] += 1

        eps = max(0.01, self.eps_sub) if not deterministic else 0.01
        sel_pol = np.unique(self.selected_policy)
        sel_indices = [(self.selected_policy == p).nonzero()[0] for p in sel_pol]
        action = -np.ones((self.n_proc,), dtype=int)
        for policy_idx, indices in zip(sel_pol, sel_indices):
            action[indices] = self.policy[policy_idx].act(x[indices], eps=eps, backbone=self.macro)
            # np.add.at counts repeated actions; fancy-index ``+=`` would count each only once.
            np.add.at(self.counter_policies[policy_idx], action[indices], 1)

        self.curr_time += 1  # Is a vector
        return action

    def pick_policy(self, obs, deterministic=False):
        """Epsilon-greedy sub policy index from the macro network (eps floored at 0.01)."""
        eps = max(0.01, self.eps) if not deterministic else 0.01
        policy = self.macro.act(obs, eps=eps)
        return policy

    def set_mode(self, training=False):
        """Switch sub policies and macro to train/eval mode and force a fresh macro pick."""
        for policy in self.policy:
            policy.train(training)
        self.macro.train(training)
        self.selected_policy[:] = None
        self.curr_time[:] = 0

    def process_reward(self, reward):
        """Rescale rewards according to ``reward_rescale`` without mutating the input.

        0 leaves them unchanged, 1 takes the sign, 2 clips to [-1, 1], and any other
        value multiplies the rewards by it.
        """
        if self.reward_rescale == 0:
            return reward
        if self.reward_rescale == 1:
            return np.sign(reward)
        if self.reward_rescale == 2:
            return np.clip(reward, -1, 1)
        # Out-of-place, so the caller's array is untouched and int arrays are accepted.
        return np.asarray(reward) * self.reward_rescale

    def store_transition(self, s, s1, a, reward, is_terminal):
        """Store each row's transition in its sub policy memory and accumulate macro reward.

        On a terminal row, the macro transition is stored as terminal and the row's sub
        policy is cleared.
        """
        reward = self.process_reward(reward)

        for i, sel_policy in enumerate(self.selected_policy):
            # Store sub policy experience
            self.memory[sel_policy].store_transition(s[i], s1[i], a[i], reward[i], is_terminal[i])
            self.macro_reward[i] += reward[i]

            # Store terminal macro transition
            if is_terminal[i]:
                self.macro_memory.store_transition(self.macro_state[i], s1[i], sel_policy, self.macro_reward[i], is_terminal[i])
                self.macro_reward[i] = 0
                self.selected_policy[i] = None

    def update(self):
        """Run ``train_steps`` training iterations, advancing ``logger.step`` when a logger is set."""
        for i in range(self.train_steps):
            self._update()
            if self.logger is not None:
                self.logger.step += 1

    def _update(self):
        """One training iteration: every ready sub policy first, then the macro policy."""
        # Zeroed first because sub policy losses back-propagate into the shared macro
        # backbone; those gradients accumulate until the macro optimizer step.
        self.macro_opt.zero_grad()

        for i in range(self.n_subpolicy):
            self._update_subpolicy(i)

        # Reduce sub eps
        self.eps_sub = self.eps_sub * self.eps_sub_decay

        self._update_macro()

    def _batch_to_tensors(self, state, new_state, action, reward, is_terminal):
        """Convert a sampled batch into tensors on ``sett.device``.

        :return: ``(state, new_state, action, reward, not_terminal)``, where ``not_terminal``
            is ``1 - is_terminal`` and masks the bootstrap term.
        """
        if self.norm_input:
            state = np.array(state, dtype=float) / 255
            new_state = np.array(new_state, dtype=float) / 255
        device = sett.device
        state = torch.tensor(np.asarray(state), dtype=torch.float, device=device)
        new_state = torch.tensor(np.asarray(new_state), dtype=torch.float, device=device)
        action = torch.tensor(np.asarray(action), dtype=torch.long, device=device)
        reward = torch.tensor(np.asarray(reward), dtype=torch.float, device=device)
        not_terminal = 1. - torch.tensor(np.asarray(is_terminal), dtype=torch.float, device=device)
        return state, new_state, action, reward, not_terminal

    def _update_subpolicy(self, i):
        """One double-DQN + ICM step for sub policy ``i``; skipped until its memory holds ``100 * bs`` items."""
        memory = self.memory[i]
        if len(memory) < self.bs * 100:
            return

        policy = self.policy[i]
        target = self.target[i]
        icm = self.icm[i]
        policy_opt = self.policy_opt[i]
        icm_opt = self.icm_opt[i]

        if self.per:
            state, new_state, action, reward, is_terminal, idxs, w_is = memory.sample(self.bs)
            reduction = 'none'
            if self.logger is not None:
                self.logger.log_scalar(tag='Beta PER %i' % i, value=memory.beta)
        else:
            state, new_state, action, reward, is_terminal = memory.sample(self.bs)
            reduction = 'mean'

        state, new_state, action, reward, not_terminal = self._batch_to_tensors(
            state, new_state, action, reward, is_terminal)

        # Augment rewards with curiosity
        curiosity_rewards = icm.curiosity_rew(state, new_state, action)
        reward = (1 - 0.01) * reward + 0.01 * self.lam * curiosity_rewards

        # Policy loss (double DQN: online net picks the next action, target net evaluates it)
        q = policy.forward(state, macro=self.macro)[torch.arange(self.bs), action]
        max_action = torch.argmax(policy.forward(new_state, macro=self.macro), dim=1)
        y = reward + self.gamma * target.forward(new_state, macro=self.macro)[torch.arange(self.bs), max_action] * not_terminal
        policy_loss = smooth_l1_loss(input=q, target=y.detach(), reduction=reduction).mean(-1)

        # ICM Loss
        phi_hat = icm.forward(state, action)
        phi_true = icm.phi_state(new_state)
        fwd_loss = mse_loss(input=phi_hat, target=phi_true.detach(), reduction=reduction).mean(-1)
        a_hat = icm.inverse_pred(state, new_state)
        inv_loss = cross_entropy(input=a_hat, target=action, reduction=reduction)

        # Total loss
        inv_loss = (1 - self.beta) * inv_loss
        fwd_loss = self.beta * fwd_loss * 288  # NOTE(review): 288 is an undocumented scale factor
        loss = self.tau * policy_loss + inv_loss + fwd_loss

        if self.per:
            error = np.clip((torch.abs(q - y)).cpu().data.numpy(), 0, 0.8)
            inv_prob = (1 - softmax(a_hat, dim=1)[torch.arange(self.bs), action]) / 5
            curiosity_error = torch.abs(inv_prob).cpu().data.numpy()
            total_error = error + curiosity_error

            # update priorities
            for k in range(self.bs):
                memory.update(idxs[k], total_error[k])

            loss = (loss * torch.FloatTensor(w_is).to(sett.device)).mean()

        policy_opt.zero_grad()
        icm_opt.zero_grad()
        loss.backward()
        # Element-wise clip to [-1, 1], applied in place (skips parameters without a grad).
        torch.nn.utils.clip_grad_value_(policy.parameters(), 1)
        policy_opt.step()
        icm_opt.step()

        self.target_count[i] += 1
        if self.target_count[i] == self.target_interval:
            self.target_count[i] = 0
            self.target[i].update_target(self.policy[i])

        if self.logger is not None:
            self.logger.log_scalar(tag='Policy Loss %i' % i, value=policy_loss.mean().cpu().data.numpy())
            self.logger.log_scalar(tag='ICM Fwd Loss %i' % i, value=fwd_loss.mean().cpu().data.numpy())
            self.logger.log_scalar(tag='ICM Inv Loss %i' % i, value=inv_loss.mean().cpu().data.numpy())
            self.logger.log_scalar(tag='Total Policy Loss %i' % i, value=loss.mean().cpu().data.numpy())
            self.logger.log_scalar(tag='Mean Curiosity Reward %i' % i, value=curiosity_rewards.mean().cpu().data.numpy())
            self.logger.log_scalar(tag='Q values %i' % i, value=q.mean().cpu().data.numpy())
            self.logger.log_scalar(tag='Target Boltz %i' % i, value=y.mean().cpu().data.numpy())
            actions = self.counter_policies[i] / max(1, self.counter_policies[i].sum())
            self.logger.log_text(tag='Policy actions %i Text' % i, value=[str(v) for v in actions],
                                 step=self.logger.step)
            if self.per:
                self.logger.log_scalar(tag='PER Error %i' % i, value=total_error.mean())
                self.logger.log_scalar(tag='PER Error Policy %i' % i, value=error.mean())
                self.logger.log_scalar(tag='PER Error Curiosity %i' % i, value=curiosity_error.mean())

    def _update_macro(self):
        """One double-DQN step for the macro policy; skipped until its memory holds ``100 * bs`` items.

        NOTE(review): while skipped, the backbone gradients accumulated by the sub policies
        are never applied (they are zeroed at the next ``_update``) - confirm this is intended.
        """
        if len(self.macro_memory) < self.bs * 100:
            return

        # Reduce eps
        self.eps = self.eps * self.eps_decay

        state, new_state, action, reward, not_terminal = self._batch_to_tensors(*self.macro_memory.sample(self.bs))

        q = self.macro.forward(state)[torch.arange(self.bs), action]
        max_action = torch.argmax(self.macro.forward(new_state), dim=1)
        y = reward + self.gamma_macro * self.macro_target.forward(new_state)[torch.arange(self.bs), max_action] * not_terminal
        loss = smooth_l1_loss(input=q, target=y.detach())

        loss.backward()
        # Clips the macro heads and the shared backbone (including sub policy contributions).
        torch.nn.utils.clip_grad_value_(self.macro.parameters(), 1)
        self.macro_opt.step()

        self.macro_count += 1
        if self.macro_count == self.target_interval:
            self.macro_count = 0
            self.macro_target.update_target(self.macro)

        if self.logger is not None:
            self.logger.log_scalar(tag='Macro Loss', value=loss.cpu().detach().numpy())
            self.logger.log_scalar(tag='Sub Eps', value=self.eps_sub)
            self.logger.log_scalar(tag='Macro Eps', value=self.eps)
            values = self.counter_macro / max(1, sum(self.counter_macro))
            self.logger.log_text(tag='Macro Policy Actions Text', value=[str(v) for v in values],
                                 step=self.logger.step)
            self.logger.log_histogram(tag='Macro Policy Actions Hist', values=values,
                                      step=self.logger.step, bins=self.n_subpolicy)
            self.logger.log_scalar(tag='Macro Q values', value=q.cpu().detach().numpy().mean())
            self.logger.log_scalar(tag='Marcro Target Boltz', value=y.cpu().detach().numpy().mean())