In [71]:
import time
from collections import deque

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

import gym

from replay_buffer import ReplayBuffer

In [72]:
# Compute device for all tensors and networks. Falls back to CPU when no CUDA
# device is available, so the notebook still runs on GPU-less machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [73]:
class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    Args:
        N_s: size of the (flat) state vector.
        N_a: number of discrete actions (width of the output layer).
        hidden: width of both hidden layers (defaults to the original 200).
    """

    def __init__(self, N_s, N_a, hidden=200):
        super().__init__()
        self.N_s = N_s
        self.N_a = N_a

        self.layer1 = nn.Linear(N_s, hidden)
        self.layer2 = nn.Linear(hidden, hidden)
        self.layer3 = nn.Linear(hidden, N_a)

    def forward(self, s):
        """Return Q-values; the last dimension of ``s`` (size N_s) maps to N_a."""
        Q = F.leaky_relu(self.layer1(s))
        Q = F.leaky_relu(self.layer2(Q))
        Q = self.layer3(Q)

        return Q

    def epsilon_greedy(self, s, eps=1e-1):
        """Sample an action for one state ``s`` of shape ``(N_s,)``.

        ``eps`` is the probability given to *each* action, so the total
        exploration mass is ``eps * N_a``. That is why it must not exceed 1.
        The greedy action additionally receives the remaining
        ``1 - eps * N_a``.

        Returns:
            A numpy integer in ``[0, N_a)``.
        """
        assert s.shape == (self.N_s,)
        assert eps * self.N_a <= 1

        with torch.no_grad():
            Q = self.forward(s)

        # Convert the (possibly CUDA) 0-d tensor to a plain Python index.
        a = int(torch.argmax(Q).item())
        p = [eps] * self.N_a
        p[a] += 1 - eps * self.N_a
        # Use this network's own action count, not a notebook-global N_a.
        return np.random.choice(self.N_a, p=p)

In [74]:
# Toy problem sizes used by the synthetic-batch cells below.
# NOTE(review): the CartPole observations printed further down have 4 entries,
# so N_s = 5 does not match that environment. Confirm before wiring in gym.
N_s, N_a, N_batch = 5, 3, 30

# Online network plus a target network initialised as an exact copy of it.
dqn = DQN(N_s, N_a).to(device)
target_dqn = DQN(N_s, N_a).to(device)
target_dqn.load_state_dict(dqn.state_dict())

<All keys matched successfully>

In [None]:
def run_episode(env, buffer, policy):
    """Roll out one episode, appending every transition to ``buffer``.

    Transitions are stored as ``[state, next_state, action, reward, done]``.

    Args:
        env: environment with the classic gym API
            (``reset() -> obs``, ``step(a) -> (obs, reward, done, info)``).
        buffer: any object with an ``append(transition)`` method.
        policy: callable mapping a float32 state tensor on ``device`` to an action.

    Returns:
        The undiscounted sum of rewards collected during the episode.
    """
    s = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        # The policy sees a tensor; the buffer keeps the raw observation.
        s_t = torch.tensor(s, dtype=torch.float32, device=device)
        a = policy(s_t)
        ns, r, done, info = env.step(a)
        buffer.append([s, ns, a, r, done])
        total_reward += r
        # Advance the state; otherwise the policy would keep seeing the first one.
        s = ns
    return total_reward

In [27]:
env = gym.make('CartPole-v0')

# Smoke test: one episode that always takes action 0, printing each
# observation until the environment reports the episode is over.
for episode in range(1):
    s = env.reset()
    done = False
    while not done:
        ns, r, done, info = env.step(0)
        s = ns.copy()
        print(ns)

[ 0.00160841 -0.18543708  0.01611761  0.28593032]
[-0.00210033 -0.38078514  0.02183622  0.58365272]
[-0.00971603 -0.57620608  0.03350927  0.88313359]
[-0.02124016 -0.77176669  0.05117194  1.18615971]
[-0.03667549 -0.96751356  0.07489514  1.49443365]
[-0.05602576 -1.1634623   0.10478381  1.80953138]
[-0.07929501 -1.35958492  0.14097444  2.13285186]
[-0.1064867  -1.55579494  0.18363147  2.4655565 ]
[-0.1376026  -1.75193015  0.2329426   2.80849766]


In [69]:
# Sample an action for the first synthetic state.
# NOTE(review): S is created in a later cell (S = np.random.uniform(...)), so
# this only works if that cell has already run (out-of-order execution).
dqn.epsilon_greedy(S[0])

tensor(2, device='cuda:0')


2

<All keys matched successfully>

In [54]:
# Synthetic states: uniform in [0, 1), float32, on `device`.
S = torch.tensor(np.random.uniform(size=(N_batch, N_s)), dtype=torch.float32, device=device)

In [55]:
# Synthetic next states, drawn the same way as S.
NS = torch.tensor(np.random.uniform(size=(N_batch, N_s)), dtype=torch.float32, device=device)

In [56]:
# Random action index in [0, N_a) per row; kept as a numpy array.
A = np.random.randint(low=0, high=N_a, size=N_batch)

In [57]:
# Constant reward of 1 for every transition in the synthetic batch.
R = torch.ones(N_batch, dtype=torch.float32, device=device)

In [58]:
# Random terminal flags (True = episode ended) as a bool tensor.
D = torch.tensor(np.random.randint(0, 2, size=N_batch).astype(bool), device=device)

In [59]:
# Display the online-network Q-values.
# NOTE(review): Q is assigned further down (Q = dqn(S)); this cell only works
# after that cell has run.
Q

tensor([[-5.9198e-02,  8.6597e-02, -1.0483e-01],
        [-7.1537e-02,  5.2607e-02, -1.5029e-01],
        [-6.7024e-02, -1.4482e-02, -1.9680e-01],
        [-8.9053e-02,  4.5654e-03, -2.1490e-01],
        [-5.2742e-02,  1.6722e-02, -1.2532e-01],
        [-5.8985e-02, -1.8056e-03, -1.8036e-01],
        [-6.3753e-02, -8.7462e-05, -1.6380e-01],
        [-6.3945e-02,  2.9904e-02, -1.6545e-01],
        [-3.9703e-02,  8.0359e-02, -8.8068e-02],
        [-6.6624e-02, -6.1610e-04, -1.6017e-01],
        [-6.4011e-02,  1.7999e-02, -1.5048e-01],
        [-6.3608e-02,  7.4373e-02, -1.6744e-01],
        [-5.3363e-02,  9.2721e-02, -1.1153e-01],
        [-3.3932e-02,  2.1760e-02, -1.1974e-01],
        [-5.2181e-02, -1.2668e-02, -1.3614e-01],
        [-7.3365e-02, -1.7393e-02, -1.8671e-01],
        [-7.0357e-02,  3.4960e-02, -1.4288e-01],
        [-6.5572e-02,  9.2293e-02, -1.3510e-01],
        [-5.7824e-02,  3.6502e-02, -1.2296e-01],
        [-7.1165e-02,  6.2452e-03, -1.9132e-01],
        [-6.2032e-02

In [44]:
# Greedy action (index of the largest Q-value) for each row of the batch.
# NOTE(review): relies on Q from the `Q = dqn(S)` cell below (out-of-order).
torch.argmax(Q,dim=-1)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0], device='cuda:0')

In [45]:
# Online-network Q-values for the synthetic states: one row per state, one
# column per action. Autograd tracks this result, so a loss built on it can
# be backpropagated.
Q = dqn(S)
Q.shape

torch.Size([30, 3])

In [46]:
# Target-network Q-values for the synthetic next states. Computed under
# no_grad, so no autograd graph is recorded for them.
with torch.no_grad():
    NQ = target_dqn(NS)
NQ.shape

torch.Size([30, 3])

In [13]:
# Q(s, a) for the action actually taken in each row (row i, column A[i]).
QA = Q[np.arange(N_batch),A]

In [14]:
# Greedy next-state value: the maximum target-network Q-value in each row.
Q_target = NQ.max(dim=1).values

In [15]:
# Display Q(s, a) for the sampled actions (still attached to the autograd graph).
QA

tensor([ 0.0647,  0.0839,  0.1244,  0.0551,  0.1619,  0.0505,  0.1600,  0.0074,
         0.1549,  0.0695,  0.0614,  0.1065,  0.0962,  0.0619,  0.0641,  0.1877,
         0.0688,  0.0783,  0.0745,  0.1017, -0.0111,  0.1005, -0.0189, -0.0024,
         0.0627,  0.0991,  0.0816,  0.0767,  0.1172, -0.0202], device='cuda:0',
       grad_fn=<IndexBackward>)

In [16]:
# Display Q_target. Before the TD-target cell runs, this is the per-row max of NQ.
Q_target

tensor([0.1578, 0.1067, 0.1233, 0.0711, 0.1504, 0.1137, 0.0991, 0.1717, 0.1301,
        0.1638, 0.1575, 0.1141, 0.0769, 0.1374, 0.1277, 0.1003, 0.1057, 0.1040,
        0.1208, 0.1615, 0.0832, 0.0760, 0.0760, 0.1204, 0.1752, 0.1454, 0.1262,
        0.0995, 0.1052, 0.1031], device='cuda:0')

In [17]:
discount = 0.99  # discount factor applied to the bootstrapped next-state value

In [18]:
# TD target: r + discount * (not done) * max_a' NQ(s', a').
# Recomputed from NQ rather than updating Q_target in place, so re-running
# this cell does not compound the discounting.
Q_target = R + (~D) * discount * torch.max(NQ, dim=1)[0]

In [19]:
# Display the TD target r + discount * (not done) * max_a' NQ(s', a') per row.
Q_target

tensor([1.0000, 1.0000, 1.1220, 1.0704, 1.0000, 1.0000, 1.0000, 1.1700, 1.1288,
        1.0000, 1.0000, 1.0000, 1.0762, 1.0000, 1.0000, 1.0000, 1.0000, 1.1030,
        1.1196, 1.0000, 1.0823, 1.0752, 1.0753, 1.1192, 1.0000, 1.1439, 1.0000,
        1.0000, 1.1041, 1.1021], device='cuda:0')

In [20]:
# Mean-squared error between Q(s, a) and the TD target (the DQN loss for this batch).
F.mse_loss(QA,Q_target)

tensor(0.9529, device='cuda:0', grad_fn=<MseLossBackward>)

In [23]:
def loss(batch, Q, NQ, discount):
    """One-step TD (DQN) loss for a sampled batch.

    Args:
        batch: mapping with 'action', 'reward' and 'done' entries, one per row
            (array-likes or tensors).
        Q: online-network Q-values for the batch states, shape (B, N_a).
        NQ: target-network Q-values for the next states, shape (B, N_a).
        discount: discount factor applied to the bootstrapped value.

    Returns:
        Scalar MSE between Q(s, a) and r + discount * (not done) * max_a' NQ(s', a').
    """
    # Put everything on Q's device instead of relying on a global `device`.
    dev = Q.device
    R = torch.as_tensor(batch['reward'], dtype=torch.float32, device=dev)
    D = torch.as_tensor(batch['done'], dtype=torch.bool, device=dev)
    # Actions may arrive as float arrays; tensor indices must be integers.
    A = torch.as_tensor(batch['action'], dtype=torch.long, device=dev)

    # Index rows by the actual batch size, not a global N_batch.
    QA = Q[torch.arange(A.shape[0], device=dev), A]
    # The target is a constant: no gradient flows back through NQ.
    Q_target = NQ.detach().max(dim=1).values
    Q_target = R + (~D) * discount * Q_target

    return F.mse_loss(QA, Q_target)
    
    
    
    

In [22]:
# Sanity check of ReplayBuffer: push 10 transitions, then sample 2. The sample
# below maps list positions to keys in this order:
# [state, next_state, action, reward, done].
# NOTE(review): max_len=5 presumably keeps only the most recent 5 entries —
# confirm against ReplayBuffer.
buffer = ReplayBuffer(max_len=5)
for i in range(10):
    T = [np.full(2, i), np.full(2, -i), np.full(3, 2 * i), i, i % 2]
    buffer.append(T)
buffer.sample(2)

{'state': array([[9., 9.],
        [9., 9.]]),
 'next_state': array([[-9., -9.],
        [-9., -9.]]),
 'action': array([[18., 18., 18.],
        [18., 18., 18.]]),
 'reward': array([9., 9.]),
 'done': array([1., 1.])}