In [1]:
import time
from collections import deque

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

import gym

from replay_buffer import ReplayBuffer

In [2]:
# Compute device for networks and batches; fall back to CPU so the notebook
# also runs (slowly) on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class DQN(nn.Module):
    """Q-network: an MLP mapping a state vector to one Q-value per action.

    Architecture: Linear(N_s, hidden) -> leaky_relu -> Linear(hidden, hidden)
    -> leaky_relu -> Linear(hidden, N_a). Layer names are unchanged, so
    state_dicts remain compatible with the original class.

    Args:
        N_s: length of the (flat) state vector.
        N_a: number of discrete actions.
        hidden: width of both hidden layers (default 200, as originally hard-coded).
    """

    def __init__(self, N_s, N_a, hidden=200):
        super().__init__()
        self.N_s = N_s
        self.N_a = N_a

        self.layer1 = nn.Linear(N_s, hidden)
        self.layer2 = nn.Linear(hidden, hidden)
        self.layer3 = nn.Linear(hidden, N_a)

    def forward(self, s):
        """Return Q-values with last dimension N_a for states with last dimension N_s."""
        Q = F.leaky_relu(self.layer1(s))
        Q = F.leaky_relu(self.layer2(Q))
        return self.layer3(Q)

    def epsilon_greedy(self, s, eps=1e-1):
        """Sample an action from an epsilon-greedy distribution over this net's Q-values.

        Every action receives probability ``eps`` and the greedy action receives
        the remaining ``1 - eps * N_a`` on top. The greedy action is therefore
        chosen with probability ``1 - eps * (N_a - 1)``, and ``eps = 1 / N_a``
        gives a uniform choice. Note that ``eps`` is a *per-action* mass here,
        not the total exploration probability.

        Args:
            s: a single state tensor of shape (N_s,), on the network's device.
            eps: per-action exploration probability; requires eps * N_a <= 1.

        Returns:
            One element drawn from ``np.arange(N_a)`` (a NumPy integer).

        Raises:
            AssertionError: if ``s`` is not a single state or ``eps`` is too large.
        """
        assert s.shape == (self.N_s,)
        assert eps * self.N_a <= 1

        with torch.no_grad():
            Q = self.forward(s)

        # int() moves the index to the host once, instead of indexing a Python
        # list with a (possibly CUDA) tensor.
        best = int(torch.argmax(Q))
        p = [eps] * self.N_a
        p[best] += 1 - eps * self.N_a
        # Use the instance's action count, not a notebook-global N_a.
        return np.random.choice(np.arange(self.N_a), p=p)

In [4]:
def run_episode(env, buffer, policy, render = False):
    """Roll out one episode with ``policy``, appending every transition to ``buffer``.

    Args:
        env: environment using the classic gym API used here
            (``reset() -> obs``, ``step(a) -> (obs, reward, done, info)``).
        buffer: any object with an ``append`` method; receives lists
            ``[state, next_state, action, reward, terminal]``.
        policy: callable mapping a float32 state tensor on ``device`` to an action.
        render: if True, call ``env.render()`` after the reset and after each step.

    Returns:
        The undiscounted sum of rewards collected in the episode
        (previously the function returned None, so existing callers are unaffected).
    """
    s = env.reset()
    done = False
    total_reward = 0.0
    if render:
        env.render()
    while not done:
        s_t = torch.tensor(s, dtype=torch.float32, device=device)
        a = np.array(policy(s_t))
        ns, r, done, info = env.step(a)
        # A time-limit cut-off ends the episode but is not a true terminal
        # state. In gym versions whose TimeLimit wrapper reports it via
        # info['TimeLimit.truncated'], store terminal=False so the TD target
        # keeps its bootstrap term. Other episode ends are stored as before.
        terminal = bool(done and not info.get('TimeLimit.truncated', False))
        buffer.append([s, ns, a, r, terminal])
        total_reward += r
        s = ns.copy()
        if render:
            env.render()
    return total_reward

In [5]:
def get_loss(batch, Q, NQ, discount):
    """One-step Q-learning (TD) loss for a sampled batch.

    Args:
        batch: mapping with per-transition ``'action'``, ``'reward'`` and
            ``'done'`` entries (presumably array-likes from
            ``ReplayBuffer.sample``; TODO confirm). Column-shaped (B, 1) entries
            are flattened to (B,).
        Q: online-network Q-values for the batch states, shape (B, N_a).
        NQ: target-network Q-values for the next states, shape (B, N_a).
        discount: discount factor gamma.

    Returns:
        Scalar MSE between Q(s, a) and r + gamma * (1 - done) * max_a' NQ(s', a').
        Gradients flow only into ``Q``; the target is detached.
    """
    dev = Q.device  # follow the caller's tensors instead of a global device
    A = torch.as_tensor(np.asarray(batch['action']), dtype=torch.long, device=dev).reshape(-1)
    R = torch.as_tensor(np.asarray(batch['reward']), dtype=torch.float32, device=dev).reshape(-1)
    D = torch.as_tensor(np.asarray(batch['done']), dtype=torch.bool, device=dev).reshape(-1)

    # Pick Q(s, a) for the taken action in each row. This works for any batch
    # size, unlike indexing with a fixed global N_batch.
    QA = Q.gather(1, A.unsqueeze(1)).squeeze(1)

    not_done = (~D).to(Q.dtype)
    Q_target = R + not_done * discount * NQ.max(dim=1).values
    # Detach so the loss never back-propagates into the target network.
    return F.mse_loss(QA, Q_target.detach())

In [6]:
# Hyperparameters, environment, networks, optimizer and replay buffer.
SEED = 0
N_batch = 64   # transitions per gradient step
tau = 0.01     # Polyak coefficient for the soft target-network update

torch.manual_seed(SEED)   # network initialisation
np.random.seed(SEED)      # epsilon_greedy samples actions with np.random.choice

env = gym.make('CartPole-v0')
if hasattr(env, 'seed'):  # guarded: not every gym version exposes Env.seed
    env.seed(SEED)

# Derive the dimensions from the environment instead of hard-coding 4 / 2.
N_s = env.observation_space.shape[0]
N_a = int(env.action_space.n)

dqn = DQN(N_s, N_a).to(device)
optimizer = torch.optim.Adam(dqn.parameters(), lr=1e-3)
target_dqn = DQN(N_s, N_a).to(device)
target_dqn.load_state_dict(dqn.state_dict())  # start the target as an exact copy

buffer = ReplayBuffer(max_len=int(1e6))  # capacity is a count: pass an int, not the float 1e6

In [7]:
# Warm-up: fill the replay buffer using uniformly random actions.
# epsilon_greedy gives each action probability eps plus 1 - eps * N_a to the
# greedy one, so eps = 1 / N_a makes the choice uniform (0.5 for CartPole's 2 actions).
N_WARMUP_EPISODES = 1000
random_policy = lambda s: target_dqn.epsilon_greedy(s, eps=1.0 / N_a)
for episode in tqdm(range(N_WARMUP_EPISODES)):
    run_episode(env, buffer, random_policy, render=False)

100%|██████████| 1000/1000 [00:07<00:00, 132.23it/s]


In [8]:
# Training: collect one episode of experience, then take gradient step(s) on
# a replay batch and softly move the target network towards the online one.
N_EPISODES = 3000
UPDATES_PER_EPISODE = 1   # gradient steps per collected episode (1 = original setting)
GAMMA = 0.99              # discount factor
EPS_TRAIN = 5e-2          # per-action exploration mass for epsilon_greedy

# Act with the online network, as in standard DQN. The lagging target network
# is used only for the bootstrap values inside the loss.
train_policy = lambda s: dqn.epsilon_greedy(s, eps=EPS_TRAIN)
losses = []  # one entry per gradient step, for inspection/plotting

for episode in tqdm(range(N_EPISODES)):
    run_episode(env, buffer, train_policy, render=False)

    for _ in range(UPDATES_PER_EPISODE):
        batch = buffer.sample(N_batch)
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays (if that is what the buffer returns) is very slow in PyTorch.
        S = torch.as_tensor(np.asarray(batch['state']), dtype=torch.float32, device=device)
        NS = torch.as_tensor(np.asarray(batch['next_state']), dtype=torch.float32, device=device)
        Q = dqn(S)
        with torch.no_grad():
            NQ = target_dqn(NS)

        optimizer.zero_grad()
        loss = get_loss(batch, Q, NQ, discount=GAMMA)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Soft (Polyak) update: target <- tau * online + (1 - tau) * target.
        with torch.no_grad():
            for param, target_param in zip(dqn.parameters(), target_dqn.parameters()):
                target_param.mul_(1 - tau).add_(param, alpha=tau)
    

100%|██████████| 3000/3000 [03:22<00:00, 14.79it/s]


In [10]:
# Watch the trained agent act greedily (eps=0) for a few rendered episodes.
N_EVAL_EPISODES = 3
greedy_policy = lambda s: target_dqn.epsilon_greedy(s, eps=0.)
try:
    for episode in tqdm(range(N_EVAL_EPISODES)):
        # Throwaway list (run_episode only calls .append on it), so evaluation
        # transitions do not leak into the training replay buffer.
        run_episode(env, [], greedy_policy, render=True)
finally:
    env.close()  # close the render window even if an episode raises

100%|██████████| 3/3 [00:08<00:00,  2.69s/it]
