## memory

In [74]:
import numpy as np
import collections

class Memory:
    def __init__(self, max_memory):
        self.max_memory = max_memory
        self.state = []
        self.new_state = []
        self.action = []
        self.reward = []
        self.is_terminal = []
        self.idx = 0

    def __len__(self):
        return len(self.state)

    def store_transition(self, s, s1, a, r, is_terminal):
        if len(self.state) <= self.max_memory:
            self.state.append(s)
            self.new_state.append(s1)
            self.action.append(a)
            self.reward.append(r)
            self.is_terminal.append(is_terminal)
        else:
            self.state[self.idx] = s
            self.new_state[self.idx] = s1
            self.action[self.idx] = a
            self.reward[self.idx] = r
            self.is_terminal[self.idx] = is_terminal
            self.idx = (self.idx + 1) % self.max_memory
        assert len(self.state) == len(self.new_state) == len(self.reward) == len(self.is_terminal) == len(self.action)


    def clear_memory(self):
        del self.state[:]
        del self.new_state[:]
        del self.action[:]
        del self.reward[:]
        del self.is_terminal[:]

    def sample(self, bs):
        idx = np.random.randint(len(self.state), size=bs)
        state, new_state, action, reward, is_terminal = [], [], [], [], []
        for i in idx:
            state.append(self.state[i])
            new_state.append(self.new_state[i])
            action.append(self.action[i])
            reward.append(self.reward[i])
            is_terminal.append(int(self.is_terminal[i]))
        return state, new_state, action, reward, is_terminal

    def update(self, **kwargs):
        pass


## model DQN

In [84]:
import numpy as np
import torch
from rl.model import DDQN_Model, ICM_Model
from rl.per import PERMemory
from torch.nn.functional import mse_loss, cross_entropy, smooth_l1_loss, softmax
import settings as sett
import itertools
import os


class DQN:
    def __init__(self, state_dim, tau, action_dim, gamma, hidd_ch, lam, lr,
                 eps_sub, eps_sub_decay, beta, bs, target_interval, train_steps, max_memory,
                 conv, reward_rescale, n_proc, per=False, norm_input=True, logger=None):
        """
        :param state_dim: Shape of the state
        :param float tau: Weight for agent loss
        :param int action_dim: Number of actions
        :param float gamma: Discount for sub controller
        :param int hidd_ch: Number of hidden channels
        :param float lam: Scaler for ICM reward
        :param float lr: Learning rate
        :param float eps_sub: Eps greedy change for sub policies
        :param float eps_sub_decay: Epsilon decay for sub policy computed as eps * (1 - eps_decay) each step
        :param float beta: Weight for loss of fwd net vs inv net
        :param int bs: Batch size
        :param int target_interval: Number of train steps between target updates
        :param int train_steps: Number of training iterations for each call
        :param int max_memory: Max memory
        :param bool conv: Use or not convolutional networks
        :param bool per: Use or not prioritized experience replay
        """

        # Parameters
        self.logger = logger
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.target_interval = target_interval
        self.lr = lr
        self.bs = bs
        # policy parameters
        self.tau = tau
        self.eps_sub = eps_sub
        self.eps_sub_decay = 1 - eps_sub_decay
        self.gamma = gamma
        # ICM parameters
        self.beta = beta
        self.lam = lam

        self.n_proc = n_proc
        self.train_steps = train_steps
        self.reward_rescale = reward_rescale
        self.norm_input = norm_input
        self.per = per
        self.target_count = 0

        if self.per:
            memory = PERMemory
        else:
            memory = Memory

        # Create Policies / ICM modules / Memories
        self.agent = DDQN_Model(self.state_dim, self.action_dim, hidd_ch)
        self.agent_target = DDQN_Model(self.state_dim, self.action_dim, hidd_ch)
        self.agent_target.update_target(self.agent)
        self.agent_memory = memory(max_memory)
        self.agent_opt = torch.optim.Adam(self.agent.parameters(), lr=self.lr)
        # self.icm = ICM_Model(self.state_dim, self.action_dim, conv).to(sett.device)
        # self.icm_opt = torch.optim.Adam(self.icm.parameters(), lr=1e-3)

        # Send macro to correct device
        self.agent = self.agent.to(sett.device)
        self.agent_target = self.agent_target.to(sett.device)

    def act(self, obs, deterministic=False):
        x = torch.from_numpy(obs).float().to(sett.device)
        if self.norm_input:
            x /= 255

        eps = max(0.01, self.eps_sub) if not deterministic else 0.01
        actions = self.agent.act(x, eps=eps)
        return actions

    def set_mode(self, training=False):
        self.agent.train(training)

    def process_reward(self, reward):
        # Rescale reward if a scaling is provided
        if self.reward_rescale != 0:
            if self.reward_rescale == 1:
                reward = np.sign(reward)
            elif self.reward_rescale == 2:
                reward = np.clip(reward, -1, 1)
            else:
                reward *= self.reward_rescale
        return reward

    def save(self, i):
        if not os.path.isdir(sett.SAVEPATH):
            os.makedirs(sett.SAVEPATH)
        torch.save(self.agent.state_dict(), os.path.join(sett.SAVEPATH, 'agent_%s.pth' % i))

    def load(self, path, i):
        self.agent.load_state_dict(torch.load(os.path.join(path, 'agent_%s.pth' % i), map_location=sett.device))

    def store_transition(self, s, s1, a, reward, is_terminal):
        reward = self.process_reward(reward)
        self.agent_memory.store_transition(s, s1, a, reward, is_terminal)
#         for i in range(len(s)):
#             self.agent_memory.store_transition(s[i], s1[i], a[i], reward[i], is_terminal[i])

    def update(self):
        for i in range(self.train_steps):
            self._update()
            if self.logger is not None:
                self.logger.step += 1

    def _update(self):
        # First train each sub policy
        i = 0
        memory = self.agent_memory
        if len(memory) < self.bs * 10:
            return

        policy = self.agent
        target = self.agent_target
        # icm = self.icm
        policy_opt = self.agent_opt
        # icm_opt = self.icm_opt

        if self.per:
            state, new_state, action, reward, is_terminal, idxs, w_is = memory.sample(self.bs)
            reduction = 'none'
            self.logger.log_scalar(tag='Beta PER %i' % i, value=memory.beta)
        else:
            state, new_state, action, reward, is_terminal = memory.sample(self.bs)
            reduction = 'mean'

        if self.norm_input:
            state = np.array(state, dtype=np.float) / 255
            new_state = np.array(new_state, dtype=np.float) / 255

        state = torch.tensor(np.stack(state), dtype=torch.float).detach().to(sett.device)
        new_state = torch.tensor(np.stack(new_state), dtype=torch.float).detach().to(sett.device)
        action = torch.tensor(action).detach().to(sett.device)
        reward = torch.tensor(reward, dtype=torch.float).detach().to(sett.device)
        is_terminal = 1. - torch.tensor(is_terminal, dtype=torch.float).detach().to(sett.device)

        # Augment rewards with curiosity
        # curiosity_rewards = icm.curiosity_rew(state, new_state, action)
        # reward = (1 - 0.01) * reward + 0.01 * self.lam * curiosity_rewards
        reward = reward

        # Policy loss
        q = policy.forward(state)[torch.arange(self.bs), action]
        max_action = torch.argmax(policy.forward(new_state), dim=1)
        y = reward + self.gamma * target.forward(new_state)[torch.arange(self.bs), max_action] * is_terminal
        policy_loss = smooth_l1_loss(input=q, target=y.detach(), reduction=reduction).mean(-1)

        # ICM Loss
        # phi_hat = icm.forward(state, action)
        # phi_true = icm.phi_state(new_state)
        # fwd_loss = mse_loss(input=phi_hat, target=phi_true.detach(), reduction=reduction).mean(-1)
        # a_hat = icm.inverse_pred(state, new_state)
        # inv_loss = cross_entropy(input=a_hat, target=action.detach(), reduction=reduction)

        # Total loss
        # inv_loss = (1 - self.beta) * inv_loss
        # fwd_loss = self.beta * fwd_loss * 288
        # loss = self.tau * policy_loss + inv_loss + fwd_loss
        loss = policy_loss

        policy_opt.zero_grad()
        # icm_opt.zero_grad()
        loss.backward()
        # for param in policy.parameters():
        #     param.grad.data.clamp(-1, 1)
        policy_opt.step()
        # icm_opt.step()

        self.target_count += 1
        if self.target_count == self.target_interval:
            self.target_count = 0
            self.agent_target.update_target(self.agent)

        print('Policy Loss %i' % i, policy_loss.mean().cpu().data.numpy())
#         print('ICM Fwd Loss %i' % i, fwd_loss.mean().cpu().data.numpy())
#         print('ICM Inv Loss %i' % i, inv_loss.mean().cpu().data.numpy())
#         print('Total Policy Loss %i' % i, loss.mean().cpu().data.numpy())
#         print('Mean Curiosity Reward %i' % i, curiosity_rewards.mean().cpu().data.numpy())
        # print('Q values %i' % i, q.mean().cpu().data.numpy())
#         print('Target Boltz %i' % i, y.mean().cpu().data.numpy())

        # Reduce sub eps
        self.eps_sub = self.eps_sub * self.eps_sub_decay

        
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import settings as sett


class DDQN_Model(nn.Module):
    def __init__(self, state_size, action_size, conv, macro=None, hidd_ch=256, conv_ch=32):
        super(DDQN_Model, self).__init__()
        self.action_size = action_size
        self.hidd_ch = hidd_ch
        state_size = 4
        if macro is None:
            self.backbone = nn.Sequential(
                nn.Linear(4, 256),
                nn.ReLU(),
            )


        self.features = nn.Sequential(
            nn.Linear(256, hidd_ch),
            nn.ReLU(),
            #nn.Linear(hidd_ch, hidd_ch),
            #nn.ReLU()
        )

        out_shape = hidd_ch
        # self.lstm_out = nn.LSTM(out_shape, hidd_ch, 1, batch_first=True)

        self.advantage = nn.Sequential(
            nn.Linear(hidd_ch, self.action_size)
        )

        self.value = nn.Sequential(
            nn.Linear(hidd_ch, 1)
        )

    def forward(self, obs, macro=None):
        obs = obs.view(-1, 4).float()
        backbone = self.backbone if macro is None else macro.backbone
        x = self.features(backbone(obs))
        x = x.view(x.size(0), -1)

        # h0 = torch.zeros(1, self.hidd_ch).to(sett.device)
        # c0 = torch.zeros(1, self.hidd_ch).to(sett.device)

        # lstm_out, (hn, cn) = self.lstm_out(x, (h0.detach(), c0.detach()))
        # x = lstm_out
        adv = self.advantage(x)
        value = self.value(x)
        return value + (adv - adv.mean(-1, keepdim=True))

    def act(self, state, eps, backbone=None):
        if np.random.random() > eps:
            q = self.forward(state, backbone)
            action = torch.argmax(q, dim=-1).cpu().data.numpy()
        else:
            action = np.random.randint(self.action_size, size=1)
        return action.item() if action.shape == (1,) else list(action.astype(np.int64))

    def update_target(self, model):
        self.load_state_dict(model.state_dict())


class ICM_Model(nn.Module):
    def __init__(self, state_size, action_size, conv):
        super(ICM_Model, self).__init__()

        self.action_size = action_size
        self.state_size = 4

        self.phi = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )

        out_shape = self.phi(torch.randn(*((1,) + tuple([self.state_size])))).view(-1).size().numel()

        # Forward Model
        self.fwd = nn.Sequential(
            nn.Linear(out_shape + 1, 256),
            nn.ReLU(),
            nn.Linear(256, out_shape)
        )

        # Inverse Model
        self.inv = nn.Sequential(
            nn.Linear(out_shape * 2, 256),
            nn.ELU(),
            nn.Linear(256, action_size)
        )

    def forward(self, *input):
        obs, action = input
        action = action.view(-1, 1)
        phi = self.phi_state(obs)
        x = torch.cat((phi, action.float()), -1)
        phi_hat = self.fwd(x)
        return phi_hat

    def phi_state(self, s):
        s = s.view(-1, 4)
        x = s
        x = self.phi(x)
        return x.view(x.size(0), -1)

    def inverse_pred(self, s, s1):
        s = self.phi_state(s.view(-1, 4).float())
        s1 = self.phi_state(s1.view(-1, 4).float())
        x = torch.cat((s, s1), -1)
        return self.inv(x)

    def curiosity_rew(self, s, s1, a):
        phi_hat = self.forward(s.view(-1, 4), a)
        phi_s1 = self.phi_state(s1.view(-1, 4))
        cur_rew = 1 / 2 * (torch.norm(phi_hat - phi_s1, p=2, dim=-1) ** 2)
        return cur_rew



## wrapper

In [85]:
import numpy as np
from gym import Wrapper
from gym.spaces import Box


class TimeLimit(Wrapper):
    def __init__(self, env, max_episode_steps=None):
        super(TimeLimit, self).__init__(env)
        self.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = None

    def step(self, action):
        observation, reward, done, terminal, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            done = True
        return observation, reward, done or terminal, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0
        obs, info = self.env.reset(**kwargs)
        return obs

    from collections import deque
import numpy as np

from gym.spaces import Box
from gym import ObservationWrapper
from collections import deque


class LazyFrames(object):
    r"""Ensures common frames are only stored once to optimize memory use. 

    To further reduce the memory use, it is optionally to turn on lz4 to 
    compress the observations.

    .. note::

        This object should only be converted to numpy array just before forward pass. 

    """
    def __init__(self, frames, lz4_compress=False):
        if lz4_compress:
            from lz4.block import compress
            self.shape = frames[0].shape
            self.dtype = frames[0].dtype
            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def __array__(self, dtype=None):
        if self.lz4_compress:
            from lz4.block import decompress
            frames = [np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.shape) for frame in self._frames]
        else:
            frames = self._frames
        out = np.stack(frames, axis=0)
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self.__array__())

    def __getitem__(self, i):
        return self.__array__()[i]


class FrameStack(ObservationWrapper):
    r"""Observation wrapper that stacks the observations in a rolling manner. 

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v0', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [3, 4]. 

    .. note::

        To be memory efficient, the stacked observations are wrapped by :class:`LazyFrame`.

    .. note::

        The observation space must be `Box` type. If one uses `Dict`
        as observation space, it should apply `FlattenDictWrapper` at first. 

    Example::

        >>> import gym
        >>> env = gym.make('PongNoFrameskip-v0')
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(4, 210, 160, 3)

    Args:
        env (Env): environment object
        num_stack (int): number of stacks

    """
    def __init__(self, env, num_stack, lz4_compress=False):
        super(FrameStack, self).__init__(env)
        self.num_stack = num_stack
        self.lz4_compress = lz4_compress

        self.frames = deque(maxlen=num_stack)

        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
        self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)

    def _get_observation(self):
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        self.frames.append(observation)
        return self._get_observation(), reward, done, info

    def reset(self, **kwargs):
        observation = self.env.reset(**kwargs)
        [self.frames.append(observation) for _ in range(self.num_stack)]
        return self._get_observation()


## env

In [86]:
from gym import Env
import numpy as np
from gym import error, spaces, utils
from gym.utils import seeding
import logging
import gym

logger = logging.getLogger(__name__)


class DimGridEnvironment(Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, size, hard):
        super(DimGridEnvironment, self).__init__()
        self.size = size
        self.observation_space = spaces.Box(0, 1, [self.size, self.size, 1])
        self.action_space = spaces.Discrete(5)
        self.alternate_dim = False
        self.dim0 = None
        self.dim1 = None
        self.dict = {'0': {'empty': 0, 'wall': 0.33, 'player': 0.66, 'goal': 1},
                     '1': {'empty': 1, 'wall': 0.66, 'player': 0.33, 'goal': 0}}
        self.action_dict = {0: [0, -1], 1: [-1, 0], 2: [0, 1], 3: [1, 0], 4: 4}
        self.goal = None
        self.player_pos = None
        self.walls = None
        self.hard = hard
        self.seed = self._seed()

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return seed

    def _compute_action(self, action):
        action = self.action_dict[action]
        reward = 0.

        if action == 4: # Change dimension
            self.alternate_dim = not self.alternate_dim
        else:
            new_pos = tuple(np.array(self.player_pos) + action)
            if any(np.array(new_pos) == self.size) or any(np.array(new_pos) < 0):
                reward = -1. if not self.hard else 0.
                return reward, True
            self.player_pos = new_pos
            touched_wall = self.walls[self.player_pos]
            if self.alternate_dim:
                reward += -0.1
            else:
                if touched_wall:
                    reward += -0.5
                    if self.hard:
                        return 0., True

        terminal = True if self.player_pos == self.goal else False
        reward += int(terminal)
        return reward, terminal

    def step(self, action):
        reward, terminal = self._compute_action(action)
        obs = self._get_obs()
        return obs, reward, terminal, {}

    def _get_obs(self):
        if self.alternate_dim:
            self.dim1[self.walls] = self.dict['1']['wall']
            self.dim1[self.player_pos] = self.dict['1']['player']
            return self.dim1
        else:
            self.dim0[self.walls] = self.dict['0']['wall']
            self.dim0[self.player_pos] = self.dict['0']['player']
            return self.dim0

    def reset(self):
        y = np.random.choice(self.size, 2, replace=False)
        x = np.random.choice(self.size, 2, replace=False)
        self.player_pos = tuple([x[0], y[0]])
        self.goal = tuple([x[1], y[1]])

        # Fill dim0
        self.dim0 = np.zeros(self.observation_space.shape)
        self.walls = np.random.rand(*self.observation_space.shape) > 0.5
        self.dim0[self.walls] = self.dict['0']['wall']
        self.dim0[self.goal] = self.dict['0']['goal']
        self.dim0[self.player_pos] = self.dict['0']['player']

        # Fill dim1
        self.dim1 = np.ones(self.observation_space.shape)
        self.dim1[self.walls] = self.dict['1']['wall']
        self.dim1[self.goal] = self.dict['1']['goal']
        self.dim1[self.player_pos] = self.dict['1']['player']

        self.walls[self.goal] = False
        self.walls[self.player_pos] = False
        self.alternate_dim = False
        return self._get_obs()

    def render(self, mode='human', close=False):
        logger.info('\n Dim0: %s \n Dim1:  %s' % (self.dim0, self.dim1))


## core

In [87]:
# Setup env
env = gym.make('CartPole-v1')
env = TimeLimit(env, max_episode_steps=500)
# env = FrameStack(env, num_stack=1)

In [88]:
obs = env.reset()

# Setup Model
n_actions = env.action_space.n if env.action_space.shape == () else env.action_space.shape[0]
n_state = env.observation_space.n if env.observation_space.shape == () else env.observation_space.shape
print('n_actions', n_actions)
print('n_state', n_state)

conv = True if isinstance(n_state, tuple) else False
conv = False
print('conv', conv)

n_actions 2
n_state (4,)
conv False


In [94]:
dqn = DQN(state_dim=4,
          tau=0.001,
          action_dim=n_actions,
          gamma=0.95,
          hidd_ch=256,
          lam=0.001,
          lr=1e-4,
          eps_sub=0.3,
          eps_sub_decay=1,
          beta=0.9,
          bs=128,
          target_interval=10,
          train_steps=1,
          max_memory=int(1e6),
          conv=conv,
          per=None,
          n_proc=1,
          reward_rescale=False,
          logger=None,
          norm_input=False
          )

In [95]:
np.array(obs).reshape(-1)

array([ 0.14419308,  1.5188859 , -0.2162334 , -2.5654423 ], dtype=float32)

In [98]:
for _ in range(10000):
    dqn.set_mode(training=False)
    obs = env.reset()
    for i in range(512):
        action = dqn.act(np.array(obs).reshape(-1))
        obs_new, r, is_terminal, info = env.step(action)
        dqn.store_transition(np.array(obs).reshape(-1), np.array(obs_new).reshape(-1), action, r, is_terminal)
        if is_terminal:
            obs = env.reset()
    dqn.set_mode(training=True)
    dqn.update()

Policy Loss 0 604.7045
Policy Loss 0 675.34216
Policy Loss 0 989.8718
Policy Loss 0 1610.032
Policy Loss 0 1452.9939
Policy Loss 0 1975.2595
Policy Loss 0 1958.0645
Policy Loss 0 1830.3772
Policy Loss 0 2064.954
Policy Loss 0 1726.6636
Policy Loss 0 2004.5303
Policy Loss 0 2927.1484
Policy Loss 0 2753.8633
Policy Loss 0 1953.72
Policy Loss 0 3033.5583
Policy Loss 0 3021.8313
Policy Loss 0 3560.927
Policy Loss 0 3797.3677
Policy Loss 0 3560.0483
Policy Loss 0 3405.7344
Policy Loss 0 4924.661
Policy Loss 0 5250.58
Policy Loss 0 4910.124
Policy Loss 0 4356.9424
Policy Loss 0 4207.614
Policy Loss 0 4679.3457
Policy Loss 0 5033.8857
Policy Loss 0 4336.6626
Policy Loss 0 5366.8877
Policy Loss 0 5742.596
Policy Loss 0 6741.076
Policy Loss 0 6032.355
Policy Loss 0 6656.4326
Policy Loss 0 6418.996
Policy Loss 0 7455.3906
Policy Loss 0 5829.415
Policy Loss 0 6153.4033
Policy Loss 0 6508.756
Policy Loss 0 6364.0664
Policy Loss 0 6597.4478
Policy Loss 0 6847.6997
Policy Loss 0 8106.9336
Policy Los

KeyboardInterrupt: 

In [97]:
tot_reward = 0
dqn.set_mode(training=False)
obs = env.reset()
while True:
    action = dqn.act(np.array(obs), deterministic=True)
    obs_new, r, is_terminal, info = env.step(action)
    tot_reward += r
    obs = obs_new
    if is_terminal:
        break
print('total_reward', tot_reward)

dqn.set_mode(training=True)

total_reward 10.0


Training Progress
total_reward -535.8000000000001
total_reward -3.3000000000000003
total_reward -2.3
total_reward -3.8000000000000003
total_reward -1.7