In [1]:
import math, random

import gym
import numpy as np

import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn

from tqdm import tqdm, trange
from mxboard import SummaryWriter

In [123]:
from collections import deque

class ReplayBuffer(object):
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    When ``capacity`` transitions are held, pushing another silently drops the oldest.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition; both states get a leading axis so sample() can concatenate them."""
        transition = (np.expand_dims(state, 0), action, reward,
                      np.expand_dims(next_state, 0), done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones): ``states`` / ``next_states`` are
            arrays stacked along axis 0; the other three are tuples.
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.concatenate(states), actions, rewards,
                np.concatenate(next_states), dones)

    def __len__(self):
        return len(self.buffer)

In [2]:
def ger(x1, x2):
    """Outer product of two 1-D arrays: ``result[i, j] = x1[i] * x2[j]``.

    Named after BLAS ``ger`` / ``torch.ger``. Uses broadcasting instead of a Python loop
    over ``x2`` followed by ``stack(...).T`` (one kernel launch per element), so it is a
    single operation and works for MXNet NDArrays and NumPy arrays alike.

    Args:
        x1: 1-D array of length m.
        x2: 1-D array of length n.

    Returns:
        Array of shape (m, n).
    """
    return x1.reshape((-1, 1)) * x2.reshape((1, -1))

In [3]:
# All parameters and batches live on the first GPU (mx.gpu() defaults to device 0).
ctx = mx.gpu(0)

In [4]:
from wrappers import make_atari, wrap_deepmind, wrap_mxnet

In [5]:
env_id = "PongNoFrameskip-v4"
# Raw Atari env -> DeepMind preprocessing with frame stacking -> MXNet-friendly observations.
env = wrap_mxnet(wrap_deepmind(make_atari(env_id), frame_stack=True))

In [14]:
class NoiseLayer(nn.Block):
    """Factorised-Gaussian noisy linear layer (NoisyNet).

    In training mode the effective parameters are
    ``weight = weight_mu + weight_sigma * weight_epsilon`` and
    ``bias = bias_mu + bias_sigma * bias_epsilon``; otherwise only the ``*_mu`` parts are
    used. The ``*_epsilon`` parameters hold noise (``grad_req='null'``) that is resampled
    by :meth:`reset_noise`.

    Gluon parameters only hold data after ``initialize()``, so the NoisyNet initialisation
    (:meth:`reset_parameters` + :meth:`reset_noise`) cannot run in ``__init__``. It runs
    lazily on the first :meth:`forward` unless :meth:`reset` / :meth:`reset_parameters`
    was already called.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        ctx: MXNet context the parameters are initialised on.
        std_init: scale used to initialise the ``*_sigma`` parameters.
        mode: ``'training'`` enables the noise (the historical misspelling ``'trainig'`` is
            also accepted and stays the default); any other value uses only the means.
    """

    _TRAINING_MODES = ('training', 'trainig')

    def __init__(self, in_features, out_features, ctx, std_init=0.4, mode='trainig', **kwargs):
        super(NoiseLayer, self).__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init
        self.mode = mode
        self.ctx = ctx

        with self.name_scope():
            self.weight_mu = self.params.get('weight_mu', shape=(out_features, in_features))
            self.weight_sigma = self.params.get('weight_sigma', shape=(out_features, in_features))
            self.weight_epsilon = self.params.get('weight_epsilon', shape=(out_features, in_features),
                                                  grad_req='null')

            self.bias_mu = self.params.get('bias_mu', shape=(out_features,))
            self.bias_sigma = self.params.get('bias_sigma', shape=(out_features,))
            self.bias_epsilon = self.params.get('bias_epsilon', shape=(out_features,),
                                                grad_req='null')
        # 0 until the NoisyNet initialisation has run once (checked in forward()).
        self.initiali_idx = 0

    def reset_parameters(self):
        """Re-initialise mu ~ U(-1/sqrt(in), 1/sqrt(in)) and sigma to constants.

        sigma_weight = std_init / sqrt(in_features), sigma_bias = std_init / sqrt(out_features).
        Requires the block's parameters to be initialised; discards anything learned.
        """
        mu_range = 1.0 / math.sqrt(self.in_features)
        weight_sigma_init = self.std_init / math.sqrt(self.in_features)
        bias_sigma_init = self.std_init / math.sqrt(self.out_features)
        # Writes to parameters must not be recorded by autograd.
        with autograd.pause():
            self.weight_mu.set_data(
                mx.random.uniform(-mu_range, mu_range, shape=self.weight_mu.shape, ctx=self.ctx))
            self.weight_sigma.set_data(
                nd.full(self.weight_sigma.shape, weight_sigma_init, ctx=self.ctx))
            self.bias_mu.set_data(
                mx.random.uniform(-mu_range, mu_range, shape=self.bias_mu.shape, ctx=self.ctx))
            self.bias_sigma.set_data(
                nd.full(self.bias_sigma.shape, bias_sigma_init, ctx=self.ctx))
        self.initiali_idx = 1

    def reset(self):
        """Resample the noise and re-initialise mu/sigma (discards anything learned)."""
        self.reset_noise()
        self.reset_parameters()

    def reset_noise(self):
        """Resample noise: weight_epsilon = f(eps_out) f(eps_in)^T; bias_epsilon is a fresh f(eps)."""
        with autograd.pause():
            epsilon_in = self._scale_noise(self.in_features)
            epsilon_out = self._scale_noise(self.out_features)
            self.weight_epsilon.set_data(ger(epsilon_out, epsilon_in))
            self.bias_epsilon.set_data(self._scale_noise(self.out_features))

    def _scale_noise(self, size):
        """Draw ``size`` standard normals on ``self.ctx`` and apply f(x) = sign(x) * sqrt(|x|)."""
        x = nd.random_normal(shape=size, ctx=self.ctx)
        return x.sign() * x.abs().sqrt()

    def forward(self, x):
        """Apply the (noisy) affine map; returns shape (batch, out_features)."""
        if not self.initiali_idx:
            # First use after initialize(): run the NoisyNet initialisation once.
            self.reset()
        ctx = x.context
        if self.mode in self._TRAINING_MODES:
            weight = self.weight_mu.data(ctx) + self.weight_sigma.data(ctx) * self.weight_epsilon.data(ctx)
            bias = self.bias_mu.data(ctx) + self.bias_sigma.data(ctx) * self.bias_epsilon.data(ctx)
        else:
            weight = self.weight_mu.data(ctx)
            bias = self.bias_mu.data(ctx)
        return nd.FullyConnected(x, weight, bias, num_hidden=self.out_features)

In [85]:
class RainbowDQN(nn.Block):
    """Dueling, distributional Q-network with noisy linear heads (Rainbow-style).

    ``forward`` maps a batch of stacked frames to per-action categorical distributions over
    ``num_atoms`` support points evenly spaced in ``[Vmin, Vmax]``; output shape is
    ``(batch, num_actions, num_atoms)``. Inputs are divided by 255 inside, so raw pixel
    values are presumably expected. The flattened conv size 3136 = 64 * 7 * 7 assumes
    84x84 frames.
    """

    def __init__(self, input_shape, num_actions, num_atoms, Vmin, Vmax, ctx, **kwargs):
        super(RainbowDQN, self).__init__(**kwargs)

        self.input_shape = input_shape
        self.num_actions = num_actions
        self.num_atoms = num_atoms
        self.Vmin = Vmin
        self.Vmax = Vmax
        self.ctx = ctx
        # Becomes True once the noisy layers received their one-time NoisyNet init.
        self._noisy_initialized = False

        with self.name_scope():
            self.conv1 = nn.Conv2D(32, 8, 4, in_channels=input_shape[0])
            self.bn1 = nn.BatchNorm()
            self.conv2 = nn.Conv2D(64, 4, 2, in_channels=32)
            self.bn2 = nn.BatchNorm()
            self.conv3 = nn.Conv2D(64, 3, 1, in_channels=64)
            self.bn3 = nn.BatchNorm()

            self.noisy_value1 = NoiseLayer(3136, 512, ctx)
            self.noisy_value2 = NoiseLayer(512, self.num_atoms, ctx)

            self.noisy_value_adv1 = NoiseLayer(3136, 512, ctx)
            self.noisy_value_adv2 = NoiseLayer(512, self.num_atoms * self.num_actions, ctx)

    def _noisy_layers(self):
        """The four noisy heads, in a fixed order."""
        return (self.noisy_value1, self.noisy_value2,
                self.noisy_value_adv1, self.noisy_value_adv2)

    def forward(self, x):
        """Return atom probabilities of shape (batch, num_actions, num_atoms)."""
        batch_size = x.shape[0]
        if not self._noisy_initialized:
            # Gluon parameters only get data in initialize(), so the noisy layers are
            # initialised here, once. Doing it on every call (reset() re-randomises
            # mu/sigma) erased all learning; resample noise with reset_noise() instead.
            for layer in self._noisy_layers():
                layer.reset()
            self._noisy_initialized = True

        x = x / 255.0
        x = nd.relu(self.bn1(self.conv1(x)))
        x = nd.relu(self.bn2(self.conv2(x)))
        x = nd.relu(self.bn3(self.conv3(x)))
        x = x.reshape(batch_size, -1)

        value = self.noisy_value2(nd.relu(self.noisy_value1(x)))
        adv = self.noisy_value_adv2(nd.relu(self.noisy_value_adv1(x)))

        value = value.reshape(batch_size, 1, self.num_atoms)
        adv = adv.reshape(batch_size, self.num_actions, self.num_atoms)

        # Dueling aggregation per atom, then a softmax over the atoms of each action.
        out = value + adv - nd.mean(adv, axis=1, keepdims=True)
        return nd.softmax(out, axis=-1)

    def reset_noise(self):
        """Resample the noise of every noisy layer (parameters are left untouched)."""
        for layer in self._noisy_layers():
            layer.reset_noise()

    def act(self, state):
        """Greedy action for one observation under the current noise sample.

        Returns:
            int: index of the action with the largest expected return.
        """
        state = nd.array(np.float32(state), ctx=self.ctx).expand_dims(0)
        dist = self.forward(state)
        support = nd.array(np.linspace(self.Vmin, self.Vmax, self.num_atoms), ctx=self.ctx)
        q_values = (dist * support).sum(axis=2)
        # int() so gym's env.step() can index its action set with it.
        return int(nd.argmax(q_values, axis=1).asscalar())

In [572]:
def _as_numpy(x):
    """Return ``x`` as a NumPy array (MXNet NDArrays are copied to host via ``asnumpy``)."""
    return x.asnumpy() if hasattr(x, 'asnumpy') else np.asarray(x)


def proj_distribution(next_state, reward, dones, target_model, num_atoms, ctx,
                      Vmin=None, Vmax=None, gamma=0.99):
    """Project the Bellman-shifted target distribution back onto the fixed C51 support.

    The greedy next action is chosen by expected value under ``target_model``. Its atom
    probabilities are moved to ``Tz = reward + (1 - done) * gamma * z`` (clipped to
    ``[Vmin, Vmax]``), and each atom's mass is split linearly between its two neighbouring
    support points.

    Args:
        next_state: batch of next observations accepted by ``target_model``.
        reward: per-sample rewards, length batch.
        dones: per-sample terminal flags (1 = terminal), length batch.
        target_model: callable returning probabilities of shape
            (batch, num_actions, num_atoms).
        num_atoms: number of support points.
        ctx: MXNet context of the returned NDArray; if ``None`` a NumPy array is returned.
        Vmin, Vmax: support bounds; default to ``target_model.Vmin`` / ``target_model.Vmax``.
        gamma: discount factor.

    Returns:
        float32 array of shape (batch, num_atoms); each row sums to 1.
    """
    if Vmin is None:
        Vmin = target_model.Vmin
    if Vmax is None:
        Vmax = target_model.Vmax
    batch_size = next_state.shape[0]
    delta_z = float(Vmax - Vmin) / (num_atoms - 1)
    support = np.linspace(Vmin, Vmax, num_atoms)

    # Keep probabilities; only the action choice uses the expected value (prob * z).
    next_prob = _as_numpy(target_model(next_state)).astype(np.float64)
    next_prob = next_prob.reshape(batch_size, -1, num_atoms)
    next_action = (next_prob * support).sum(axis=2).argmax(axis=1)
    next_prob = next_prob[np.arange(batch_size), next_action]  # (batch, num_atoms)

    reward = _as_numpy(reward).astype(np.float64).reshape(batch_size, 1)
    dones = _as_numpy(dones).astype(np.float64).reshape(batch_size, 1)

    Tz = np.clip(reward + (1.0 - dones) * gamma * support, Vmin, Vmax)
    b = (Tz - Vmin) / delta_z
    lower = np.floor(b).astype(np.int64)
    upper = np.ceil(b).astype(np.int64)
    # When Tz lands exactly on an atom, lower == upper and both weights below are 0,
    # so the mass would vanish; move one end so the full weight goes to that atom.
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < num_atoms - 1) & (lower == upper)] += 1

    offset = (np.arange(batch_size) * num_atoms)[:, np.newaxis]
    proj_dist = np.zeros(batch_size * num_atoms)
    # np.add.at accumulates repeated indices; np.put kept only the last write and
    # dropped mass whenever several atoms projected onto the same support point.
    np.add.at(proj_dist, (lower + offset).ravel(), (next_prob * (upper - b)).ravel())
    np.add.at(proj_dist, (upper + offset).ravel(), (next_prob * (b - lower)).ravel())
    proj_dist = proj_dist.reshape(batch_size, num_atoms).astype(np.float32)

    if ctx is None:
        return proj_dist
    return nd.array(proj_dist, ctx=ctx)

In [None]:
def compute_td_loss(batch_size, current_model, target_model, loss_fn, ctx,
                    gamma=0.99, buffer=None):
    """Double-DQN temporal-difference loss on a minibatch from the replay buffer.

    NOTE(review): this is the scalar-Q loss. ``RainbowDQN`` returns distributions of shape
    (batch, num_actions, num_atoms), for which ``nd.argmax(..., axis=1)`` and the gathers
    below do not yield Q-values; a Rainbow agent needs the categorical loss against
    ``proj_distribution`` instead — TODO.

    Args:
        batch_size: number of transitions to sample.
        current_model: online network; selects the next action and is being trained.
        target_model: target network; evaluates the selected next action.
        loss_fn: callable ``loss_fn(prediction, target)``, e.g. a Gluon loss.
        ctx: MXNet context for the batch tensors.
        gamma: discount factor (the original read an undefined global ``gamma``).
        buffer: replay buffer to sample from; defaults to the global ``replay_buffer``.

    Returns:
        Whatever ``loss_fn`` returns for the batch.
    """
    if buffer is None:
        buffer = replay_buffer
    state, action, reward, next_state, done = buffer.sample(batch_size)

    state = nd.array(np.float32(state), ctx=ctx)
    next_state = nd.array(np.float32(next_state), ctx=ctx)
    action = nd.array(action, ctx=ctx)
    reward = nd.array(reward, ctx=ctx)
    done = nd.array(done, ctx=ctx)

    q_values = current_model(state)
    next_q_values = current_model(next_state)
    next_q_state_values = target_model(next_state)

    # Indices of shape (2, batch) pick element [i, a_i] per row; unlike the former
    # (2, batch, 1) + squeeze(), this keeps a (batch,) shape even for batch_size == 1.
    rows = nd.arange(batch_size, ctx=ctx)
    next_action = nd.argmax(next_q_values, axis=1)
    q_value = nd.gather_nd(q_values, nd.stack(rows, action))
    next_q_value = nd.gather_nd(next_q_state_values, nd.stack(rows, next_action))

    expected_q_value = reward + gamma * next_q_value * (1 - done)
    return loss_fn(q_value, expected_q_value)

In [507]:
# Scratch: stepping through the distribution projection by hand on a 3-transition batch.
# NOTE(review): relies on replay_buffer, net, Vmin/Vmax/num_atoms that are created in cells
# further down (In [88], In [129]); this section fails on Restart & Run All.
state, action, reward, next_state, done = replay_buffer.sample(3) 

In [508]:
    state      = nd.array((np.float32(state)), ctx=ctx)
    next_state = nd.array(np.float32(next_state), ctx=ctx)
    action     = nd.array((action), ctx=ctx)
    reward     = nd.array((reward), ctx=ctx)
    done       = nd.array((done), ctx=ctx)

In [509]:
# NOTE(review): batch_size is read before the next cell assigns it (hidden kernel state).
batch_size

3

In [510]:
batch_size = next_state.shape[0]
    
delta_z = float(Vmax - Vmin) / (num_atoms - 1)
support = nd.array(np.linspace(Vmin, Vmax, num_atoms), ctx=ctx)

In [511]:
next_state.shape

(3, 4, 84, 84)

In [512]:
# NOTE(review): next_dist is probabilities * support values, not probabilities, so the mass
# projected below is not a distribution (hence the negative entries printed for proj_dist).
next_dist = net(nd.array(next_state, ctx=ctx)) * support
next_action = nd.argmax(next_dist.sum(2),1)

In [513]:
next_action


[4. 4. 4.]
<NDArray 3 @gpu(0)>

In [514]:
    # Pick next_dist[i, next_action[i], :] for each row i.
    action_tmp = nd.stack(nd.arange(next_action.shape[0], ctx=ctx).expand_dims(-1),next_action.expand_dims(-1), axis=0)
    next_dist = nd.gather_nd(next_dist, action_tmp)
    next_dist = next_dist.squeeze()

In [515]:
next_dist.shape

(3, 51)

In [516]:
# Broadcast reward / done to one value per atom.
reward = nd.repeat(reward.expand_dims(1), axis=1, repeats=next_dist.shape[-1])
done = nd.repeat(done.expand_dims(1), axis=1, repeats=next_dist.shape[-1])

In [517]:
support = nd.repeat(nd.expand_dims(support,0), axis=0, repeats=next_dist.shape[0])

In [518]:
reward.shape

(3, 51)

In [519]:
done.shape

(3, 51)

In [520]:
support.shape

(3, 51)

In [521]:
# NOTE(review): discount 0.99 is hard-coded.
Tz = reward + (1 - done) * 0.99 * support

In [522]:
Tz = nd.clip(Tz, a_min=Vmin, a_max=Vmax)

In [523]:
# Fractional atom position b and its lower / upper neighbouring atoms.
b  = (Tz - Vmin) / delta_z
l = b.floor()
u = b.ceil()

In [524]:
offset = nd.expand_dims(nd.array(np.linspace(0, (batch_size - 1) * num_atoms, batch_size), ctx=ctx), 1)

In [526]:
# NOTE(review): 51 hard-coded; should be num_atoms.
offset = nd.repeat(offset,repeats=51, axis=1)

In [564]:
proj_dist = np.zeros(next_dist.shape)  

In [565]:
proj_dist.shape

(3, 51)

In [566]:
index = (l + offset).reshape(-1).asnumpy()
index = index.astype(int)

In [567]:
value = (next_dist * (u - b)).reshape(-1).asnumpy()

In [568]:
# NOTE(review): np.put overwrites repeated indices instead of accumulating them (np.add.at
# would); when b is an integer, l == u and both weights are 0, so that mass is dropped too.
np.put(proj_dist.reshape(-1), index, value)

In [569]:
index = (u + offset).reshape(-1).asnumpy()
index = index.astype(int)
value = (next_dist * (b - l)).reshape(-1).asnumpy()

In [570]:
np.put(proj_dist.reshape(-1), index, value)

In [571]:
proj_dist

array([[-0.1427325 , -0.04757774, -0.04118025, -0.0368303 , -0.03808421,
        -0.02929653, -0.02871083, -0.02964238, -0.02361904, -0.02395556,
        -0.018809  , -0.0177003 , -0.013663  , -0.01415718, -0.01149903,
        -0.00924269, -0.00756872, -0.00629589, -0.00516141, -0.00413557,
        -0.00298855, -0.00205739, -0.00134245, -0.0006875 , -0.00030758,
         0.        ,  0.00868624,  0.01657625,  0.02186595,  0.0273362 ,
         0.0374762 ,  0.04330807,  0.05161782,  0.06513779,  0.05527223,
         0.07028253,  0.07834616,  0.0876225 ,  0.09846728,  0.09592575,
         0.10465291,  0.09772366,  0.10961386,  0.11379286,  0.12091429,
         0.12848528,  0.13614564,  0.1392462 ,  0.14301001,  0.1323646 ,
         0.14848903],
       [-0.14273714, -0.04757929, -0.04121738, -0.03676648, -0.03803771,
        -0.02936137, -0.02872591, -0.02958586, -0.0236678 , -0.02397388,
        -0.01877357, -0.0177041 , -0.01368735, -0.01414644, -0.01149197,
        -0.009237  , -0.00757

In [465]:
# Scratch: PyTorch reference of the projection, for comparison with the MXNet port.
# NOTE(review): fails (see AttributeError) because these names already held torch tensors;
# torch itself is only imported further down (In [107]).
l = torch.from_numpy(l.asnumpy())
u = torch.from_numpy(u.asnumpy())
b = torch.from_numpy(b.asnumpy())
offset = torch.from_numpy(offset.asnumpy())

AttributeError: 'Tensor' object has no attribute 'asnumpy'

In [452]:
offset = torch.linspace(0, (batch_size - 1) * num_atoms, batch_size)\
                    .unsqueeze(1).expand(batch_size, num_atoms)

In [453]:
next_dist = torch.from_numpy(next_dist.asnumpy())

AttributeError: 'Tensor' object has no attribute 'asnumpy'

In [455]:
(l + offset).view(-1).long()

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  25,  26,
         27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
         41,  42,  43,  44,  45,  46,  47,  48,  49,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  76,  77,  78,  79,  80,  81,  82,
         83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,
         97,  98,  99, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
        139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151])

In [464]:
    # PyTorch reference: index_add_ accumulates values that share an index.
    # NOTE(review): fails here because b is a NumPy array (see TypeError below).
    proj_dist = torch.zeros(next_dist.shape)    
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1).long(), (next_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1).long(), (next_dist * (b - l.float())).view(-1))

TypeError: sub() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other, Number alpha)
 * (Number other, Number alpha)


In [443]:
(next_dist * (u.float() - b)).view(-1)

TypeError: sub() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other, Number alpha)
 * (Number other, Number alpha)


In [420]:
proj_dist = torch.zeros(next_dist.shape)

In [430]:
# index_add_ works in place on the flat view, so aa shares storage with proj_dist.
aa = proj_dist.view(-1).index_add_(0, (l + offset).view(-1).long(), (next_dist * (u.float() - b)).view(-1))

In [434]:
aa.reshape(proj_dist.shape)

tensor([[-0.2833, -0.2513, -0.2806, -0.2768, -0.2520, -0.2558, -0.2061, -0.2520,
         -0.2088, -0.1877, -0.2118, -0.2021, -0.1984, -0.1570, -0.1419, -0.1371,
         -0.1482, -0.1268, -0.1143, -0.0901, -0.0747, -0.0621, -0.0438, -0.0275,
         -0.0157,  0.0002,  0.0007,  0.0014,  0.0025,  0.0038,  0.0060,  0.0079,
          0.0106,  0.0121,  0.0159,  0.0212,  0.0237,  0.0225,  0.0311,  0.0367,
          0.0386,  0.0417,  0.0591,  0.0558,  0.0641,  0.0616,  0.0801,  0.0865,
          0.0859,  0.0886,  0.0000],
        [-0.2834, -0.2518, -0.2803, -0.2777, -0.2515, -0.2560, -0.2057, -0.2515,
         -0.2083, -0.1877, -0.2122, -0.2025, -0.1990, -0.1573, -0.1420, -0.1373,
         -0.1483, -0.1268, -0.1144, -0.0900, -0.0744, -0.0621, -0.0438, -0.0275,
         -0.0157,  0.0002,  0.0007,  0.0014,  0.0025,  0.0038,  0.0061,  0.0079,
          0.0106,  0.0122,  0.0159,  0.0212,  0.0237,  0.0226,  0.0312,  0.0367,
          0.0385,  0.0417,  0.0591,  0.0559,  0.0641,  0.0614,  0.0802, 

In [None]:
# NOTE(review): incomplete expression (autocomplete probe) — SyntaxError; delete.
nd.

In [432]:
proj_dist

tensor([[-0.2833, -0.2513, -0.2806, -0.2768, -0.2520, -0.2558, -0.2061, -0.2520,
         -0.2088, -0.1877, -0.2118, -0.2021, -0.1984, -0.1570, -0.1419, -0.1371,
         -0.1482, -0.1268, -0.1143, -0.0901, -0.0747, -0.0621, -0.0438, -0.0275,
         -0.0157,  0.0002,  0.0007,  0.0014,  0.0025,  0.0038,  0.0060,  0.0079,
          0.0106,  0.0121,  0.0159,  0.0212,  0.0237,  0.0225,  0.0311,  0.0367,
          0.0386,  0.0417,  0.0591,  0.0558,  0.0641,  0.0616,  0.0801,  0.0865,
          0.0859,  0.0886,  0.0000],
        [-0.2834, -0.2518, -0.2803, -0.2777, -0.2515, -0.2560, -0.2057, -0.2515,
         -0.2083, -0.1877, -0.2122, -0.2025, -0.1990, -0.1573, -0.1420, -0.1373,
         -0.1483, -0.1268, -0.1144, -0.0900, -0.0744, -0.0621, -0.0438, -0.0275,
         -0.0157,  0.0002,  0.0007,  0.0014,  0.0025,  0.0038,  0.0061,  0.0079,
          0.0106,  0.0122,  0.0159,  0.0212,  0.0237,  0.0226,  0.0312,  0.0367,
          0.0385,  0.0417,  0.0591,  0.0559,  0.0641,  0.0614,  0.0802, 

In [401]:
# Mass each atom sends to its lower neighbour, flattened (PyTorch scratch).
(next_dist * (u.float() - b)).view(-1)

tensor([-0.1417, -0.1257, -0.1403, -0.1384, -0.1260, -0.1279, -0.1030, -0.1260,
        -0.1044, -0.0938, -0.1059, -0.1010, -0.0992, -0.0785, -0.0710, -0.0685,
        -0.0741, -0.0634, -0.0572, -0.0450, -0.0374, -0.0310, -0.0219, -0.0138,
        -0.0079,  0.0000,  0.0001,  0.0003,  0.0007,  0.0012,  0.0019,  0.0030,
         0.0039,  0.0053,  0.0061,  0.0080,  0.0106,  0.0118,  0.0113,  0.0156,
         0.0184,  0.0193,  0.0209,  0.0296,  0.0279,  0.0321,  0.0308,  0.0400,
         0.0432,  0.0429,  0.0443, -0.1417, -0.1259, -0.1402, -0.1388, -0.1258,
        -0.1280, -0.1029, -0.1257, -0.1042, -0.0939, -0.1061, -0.1012, -0.0995,
        -0.0786, -0.0710, -0.0687, -0.0742, -0.0634, -0.0572, -0.0450, -0.0372,
        -0.0310, -0.0219, -0.0138, -0.0078,  0.0000,  0.0001,  0.0003,  0.0007,
         0.0012,  0.0019,  0.0030,  0.0039,  0.0053,  0.0061,  0.0080,  0.0106,
         0.0118,  0.0113,  0.0156,  0.0183,  0.0193,  0.0209,  0.0295,  0.0280,
         0.0321,  0.0307,  0.0401,  0.04

In [436]:
proj_dist.shape

torch.Size([3, 51])

In [492]:
# Demo: write 0..152 at the projected lower indices to see how np.put treats duplicates.
a = np.zeros((3,51))

In [494]:
# NOTE(review): rebinds b (fractional atom positions) to integer indices.
b = (l + offset).view(-1).numpy()
b = b.astype(int)
c = np.arange(3*51).reshape(-1)

In [497]:
a.shape

(3, 51)

In [498]:
# For a repeated index only the last value survives (the output skips 25 -> 26).
np.put(a,b,c)

In [499]:
a

array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
         22.,  23.,  24.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,
         34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,
         45.,  46.,  47.,  48.,  49.,  50.,   0.],
       [ 51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,  60.,  61.,
         62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,  71.,  72.,
         73.,  74.,  75.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,
         85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,
         96.,  97.,  98.,  99., 100., 101.,   0.],
       [102., 103., 104., 105., 106., 107., 108., 109., 110., 111., 112.,
        113., 114., 115., 116., 117., 118., 119., 120., 121., 122., 123.,
        124., 125., 126., 128., 129., 130., 131., 132., 133., 134., 135.,
        136., 137., 138., 139., 140., 141., 142., 143., 144., 145., 146.,
        14

In [485]:
# Integer lower-atom indices (b was rebound in In [494]); note the repeated entries.
b

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,
        38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  76,
        77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 127, 128,
       129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
       142, 143, 144, 145, 146, 147, 148, 149, 150, 151])

In [486]:
# Values 0..152 that np.put wrote into `a`.
c

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152])

In [463]:
# NOTE(review): fails because b is a NumPy array, not a tensor (see TypeError).
(u.float() - b)

TypeError: sub() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other, Number alpha)
 * (Number other, Number alpha)


In [442]:
(next_dist * (u.float() - b)).view(-1)

TypeError: sub() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other, Number alpha)
 * (Number other, Number alpha)


In [403]:
next_dist.shape

torch.Size([3, 51])

In [None]:
# NOTE(review): incomplete expression (autocomplete probe) — SyntaxError; delete.
nd.utils.

In [402]:
proj_dist.shape

torch.Size([3, 51])

In [404]:
3 * 51

153

In [405]:
(l + offset).view(-1).long()

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  25,  26,
         27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
         41,  42,  43,  44,  45,  46,  47,  48,  49,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  76,  77,  78,  79,  80,  81,  82,
         83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,
         97,  98,  99, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
        139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151])

In [413]:
# Scratch: how torch index_add_ behaves on a small example.
x = torch.zeros(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 4, 2])

In [414]:
x.shape

torch.Size([5, 3])

In [419]:
# Fails: x.view(-1) is 1-D but t is 2-D, and the recorded error shows the index dtype must be LongTensor.
x.view(-1).index_add_(0,index,t)

RuntimeError: Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #3 'index'

In [422]:
x = torch.zeros(next_dist.shape)
t = (next_dist * (u.float() - b)).view(-1)
index = (l + offset).view(-1).long()

In [425]:
x.view(-1).shape

torch.Size([153])

In [426]:
x.view(-1).index_add_(0,index,t).shape

torch.Size([153])

In [362]:
temp = next_dist * (u - b)

In [363]:
temp.shape

(3, 51)

In [342]:
x.index_add_(0, index, t)

tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]])

In [354]:
(l+offset).reshape(-1)


[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.
  14.  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  25.  26.
  27.  28.  29.  30.  31.  32.  33.  34.  35.  36.  37.  38.  39.  40.
  41.  42.  43.  44.  45.  46.  47.  48.  49.  51.  52.  53.  54.  55.
  56.  57.  58.  59.  60.  61.  62.  63.  64.  65.  66.  67.  68.  69.
  70.  71.  72.  73.  74.  75.  76.  76.  77.  78.  79.  80.  81.  82.
  83.  84.  85.  86.  87.  88.  89.  90.  91.  92.  93.  94.  95.  96.
  97.  98.  99. 100. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.
 112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.
 126. 127. 127. 128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138.
 139. 140. 141. 142. 143. 144. 145. 146. 147. 148. 149. 150. 151.]
<NDArray 153 @gpu(0)>

In [357]:
# Per-row lower-atom indices into the flattened projection (row i shifted by i * num_atoms).
(l+offset)


[[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.
   14.  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  25.  26.
   27.  28.  29.  30.  31.  32.  33.  34.  35.  36.  37.  38.  39.  40.
   41.  42.  43.  44.  45.  46.  47.  48.  49.]
 [ 51.  52.  53.  54.  55.  56.  57.  58.  59.  60.  61.  62.  63.  64.
   65.  66.  67.  68.  69.  70.  71.  72.  73.  74.  75.  76.  76.  77.
   78.  79.  80.  81.  82.  83.  84.  85.  86.  87.  88.  89.  90.  91.
   92.  93.  94.  95.  96.  97.  98.  99. 100.]
 [102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112. 113. 114. 115.
  116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126. 127. 127. 128.
  129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140. 141. 142.
  143. 144. 145. 146. 147. 148. 149. 150. 151.]]
<NDArray 3x51 @gpu(0)>

In [355]:
# Row offsets (0, 51, 102) used to index the flattened (batch * num_atoms) projection.
offset


[[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.]
 [ 51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.
   51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.
   51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.  51.
   51.  51.  51.  51.  51.  51.  51.  51.  51.]
 [102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102.
  102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102.
  102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102. 102.
  102. 102. 102. 102. 102. 102. 102. 102. 102.]]
<NDArray 3x51 @gpu(0)>

In [350]:
# NOTE(review): shape mismatch (3x51 vs 153) raises MXNetError; (l + offset).reshape(-1) was meant.
l + offset.reshape(-1)

MXNetError: [16:30:08] src/operator/tensor/./elemwise_binary_broadcast_op.h:68: Check failed: l == 1 || r == 1 operands could not be broadcast together with shapes [3,51] [153]

Stack trace returned 10 entries:
[bt] (0) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x382d4a) [0x7f0214f3cd4a]
[bt] (1) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x383381) [0x7f0214f3d381]
[bt] (2) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(+0xea8c27) [0x7f0215a62c27]
[bt] (3) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2b8316a) [0x7f021773d16a]
[bt] (4) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2b8d2a9) [0x7f02177472a9]
[bt] (5) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2aa4a49) [0x7f021765ea49]
[bt] (6) /opt/venv/lib/python3.5/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7f021765f03f]
[bt] (7) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7f02fb48ae20]
[bt] (8) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call+0x2eb) [0x7f02fb48a88b]
[bt] (9) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(_ctypes_callproc+0x49a) [0x7f02fb48501a]



In [344]:
x.shape

torch.Size([5, 3])

In [343]:
x

tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]])

In [253]:
# NOTE(review): truncated name (presumably nd.ones / nd.one_hot); fails on a fresh run.
xx = nd.on


[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
  2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
  2. 2. 2.]]
<NDArray 3x51 @cpu(0)>

In [None]:
# NOTE(review): np.exp2 needs an argument; leftover probe, delete.
np.exp2()

In [230]:
aa = torch.from_numpy(np.arange(3))

In [321]:
torch.linspace(0, (batch_size - 1) * num_atoms, batch_size).long()\
                    .unsqueeze(1).expand(batch_size, num_atoms).shape

torch.Size([3, 51])

In [238]:
torch.from_numpy(next_dist.asnumpy()).shape

torch.Size([3, 51])

In [241]:
aa.unsqueeze(1).shape

torch.Size([3, 1])

In [None]:
# NOTE(review): incomplete calls below (expand_as needs a tensor; `aa.ex` is a truncated name).
aa.expand_as()

In [None]:
aa.ex

In [246]:
aa.unsqueeze(1).expand_as(torch.from_numpy(next_dist.asnumpy())).shape

torch.Size([3, 51])

In [243]:
torch.from_numpy(next_dist.asnumpy()).shape

torch.Size([3, 51])

In [183]:
na = nd.stack(nd.arange(next_action.shape[0], ctx=ctx).expand_dims(-1),next_action.expand_dims(-1), axis=0)

In [204]:
# NOTE(review): stores the gathered distribution in `next_action`, shadowing the action indices.
next_action = nd.gather_nd(next_dist, na)

In [207]:
next_action = next_action.squeeze()

In [None]:
# NOTE(review): incomplete call; leftover probe, delete.
nd.expand_dims()

In [213]:
nd.array(reward, ctx=ctx).expand_dims(1).shape

(3, 1)

In [214]:
next_dist.shape

(3, 6, 51)

In [198]:
# PyTorch reference: gather the chosen action's atoms along axis 1.
x = torch.from_numpy(next_dist.asnumpy())
a = torch.from_numpy(next_action.asnumpy()).long()

In [199]:
a = a.unsqueeze(1).unsqueeze(1).expand(x.size(0), 1, x.size(2))

In [200]:
type(a)

torch.Tensor

In [202]:
x.gather(1, a)

tensor([[[-0.1861, -0.1757, -0.1683, -0.1685, -0.1743, -0.1356, -0.1661,
          -0.1331, -0.1156, -0.1103, -0.1309, -0.1067, -0.1055, -0.0976,
          -0.0795, -0.0660, -0.0740, -0.0688, -0.0576, -0.0502, -0.0343,
          -0.0313, -0.0229, -0.0154, -0.0077,  0.0000,  0.0078,  0.0145,
           0.0247,  0.0351,  0.0351,  0.0450,  0.0511,  0.0706,  0.0729,
           0.0838,  0.0848,  0.0972,  0.1049,  0.1100,  0.1070,  0.1374,
           0.1574,  0.1621,  0.1479,  0.1736,  0.1554,  0.1749,  0.1818,
           0.1620,  0.2137]],

        [[-0.1863, -0.1755, -0.1682, -0.1684, -0.1743, -0.1356, -0.1661,
          -0.1330, -0.1154, -0.1104, -0.1308, -0.1068, -0.1054, -0.0975,
          -0.0795, -0.0660, -0.0740, -0.0688, -0.0577, -0.0502, -0.0343,
          -0.0313, -0.0229, -0.0154, -0.0077,  0.0000,  0.0078,  0.0145,
           0.0247,  0.0352,  0.0352,  0.0449,  0.0511,  0.0706,  0.0729,
           0.0839,  0.0848,  0.0972,  0.1049,  0.1100,  0.1069,  0.1375,
           0.1574,  

In [215]:
x.shape

torch.Size([3, 6, 51])

In [216]:
a.shape

torch.Size([3, 1, 51])

In [218]:
a.expand_as(x).shape

torch.Size([3, 6, 51])

In [121]:
    # Duplicate of the proj_distribution preamble, run at top level.
    batch_size = next_state.shape[0]
    
    delta_z = float(Vmax - Vmin) / (num_atoms - 1)
    support = nd.array(np.linspace(Vmin, Vmax, num_atoms), ctx=ctx)

In [136]:
# NOTE(review): the recorded NameError does not match this code (cell edited after running);
# the call itself is incomplete.
nd.expand_dims()

NameError: name 'target_model' is not defined

In [None]:
# NOTE(review): target_model is never defined at notebook level.
next_dist = target_model(next_state) * support

In [87]:
env.action_space.n

6

In [88]:
# NOTE(review): these settings and `net` are used by cells above (In [5xx]) but defined
# here; move them to a configuration cell near the top so Run All works.
num_atoms = 51
Vmin = -10
Vmax = 10

net = RainbowDQN(env.observation_space.shape, env.action_space.n, num_atoms, Vmin, Vmax, ctx)
net.initialize(ctx=ctx)

In [94]:
# env.reset()'s observation is discarded; `state` comes from the first step instead.
env.reset()
state, reward, done, _ = env.step(0)

In [91]:
num_atoms

51

In [100]:
next_dist = net(nd.array(np.float32(state), ctx=ctx).expand_dims(0))

In [105]:
next_dist.shape

(1, 6, 51)

In [103]:
next_action = nd.argmax(next_dist.sum(2),1)

In [106]:
# Sanity check: each action's atom probabilities sum to ~1 (softmax over atoms).
next_dist.sum(2)


[[1.         0.99999994 1.         1.         1.         0.9999999 ]]
<NDArray 1x6 @gpu(0)>

In [107]:
# NOTE(review): mid-notebook import; move to the import cell at the top.
import torch

In [108]:
# PyTorch reference for picking the greedy action's atoms, on random data.
x = torch.rand(size=(1,6,51))

In [112]:
a = x.sum(2).max(1)[1]

In [117]:
a = a.unsqueeze(1).unsqueeze(1).expand(next_dist.shape[0], 1, next_dist.shape[2])

In [118]:
a.shape

torch.Size([1, 1, 51])

### Test setup: fill a small replay buffer with transitions

In [129]:
# Stop filling once the buffer holds more than replay_initial transitions (loop below).
replay_initial = 100
replay_buffer  = ReplayBuffer(100000)

In [130]:
len(replay_buffer)

0

In [131]:
state = env.reset()
# Step the env with a fixed action until the buffer holds more than replay_initial transitions.
for frame_idx in range(1, 1000000 + 1):
    action = 1

    next_state, reward, done, _ = env.step(action)
    replay_buffer.push(state, action, reward, next_state, done)
    # Advance the observation (previously the initial state was pushed every time) and
    # start a new episode instead of stepping a finished one.
    state = env.reset() if done else next_state
    if len(replay_buffer) > replay_initial:
        break

In [132]:
# ReplayBuffer defines no __repr__, so this only shows the object address; use len(replay_buffer).
replay_buffer

<__main__.ReplayBuffer at 0x7f02b483e898>