### DQN

In [1]:
import cv2,gym,time,psutil,random
from gym import envs
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
from moviepy.video.io.bindings import mplfig_to_npimage
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
np.set_printoptions(precision=2,suppress=True)
plt.rc('xtick',labelsize=8); plt.rc('ytick',labelsize=8)
print ("gym:[%s]"%(gym.__version__))
print ("numpy:[%s]"%(np.__version__))
print ("matplotlib:[%s]"%(matplotlib.__version__))
print ("Pytorch:[%s]"%(torch.__version__))

gym:[0.26.2]
numpy:[1.22.4]
matplotlib:[3.7.1]
Pytorch:[2.0.1]


### Define DQN class

In [2]:
class DQNNetwork(nn.Module):
    def __init__(self,odim,adim,hdims=128):
        super(DQNNetwork, self).__init__()
        self.odim = odim
        self.adim = adim
        self.hdims = hdims
        self.fc1 = nn.Linear(self.odim, self.hdims)
        self.fc2 = nn.Linear(self.hdims, self.hdims)
        self.fc3 = nn.Linear(self.hdims, self.adim)
        
    def forward(self, x):
        x = torch.FloatTensor(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class ReplayBuffer:
    def __init__(self, buffer_size, odim, adim, batch_size=32):
        self.buffer_size = buffer_size
        self.buffer = deque(maxlen=self.buffer_size)
        self.batch_size = batch_size
        self.odim = odim
        self.adim = adim

    def append(self, o, a, r, o1, d):
        self.buffer.append([o, a, r, o1, d])

    def sample(self):
        size = self.batch_size if len(self.buffer) > self.batch_size else len(self.buffer)
        minibatch = np.transpose(random.sample(self.buffer, size))
        b = []
        for a in range(5):  
            if type(minibatch[a][0]) == np.ndarray:
                b.append(np.concatenate(minibatch[a]).reshape(size, -1))
            else:
                b.append(minibatch[a])
        return b[0], b[1], b[2], b[3], b[4]

def get_envs():
    env = gym.make('CartPole-v1',render_mode='rgb_array')
    eval_env = gym.make('CartPole-v1',render_mode='rgb_array')
    return env,eval_env

def plot_env(env,figsize=(4,4),title_str=None,title_fs=10,
             PLOT_IMG=True,RETURN_IMG=False):
    img = env.render()
    fig = plt.figure(figsize=figsize)
    plt.imshow(img)
    plt.axis('off')
    if title_str is not None:
        plt.title(title_str,fontsize=title_fs)
    if PLOT_IMG:
        plt.show()
    # (Optional) Get image
    if RETURN_IMG:
        img = mplfig_to_npimage(fig)
        plt.close()
        return img
    
class Agent(object):
    def __init__(self,hdims=128,
                 gamma=0.98 ,buffer_size=50000,batch_size=64,learning_rate=0.0005):
        
        # Environment
        self.env, self.eval_env = get_envs()
        self.odim = self.env.observation_space.shape[0]
        self.adim = self.env.action_space.n
        print ("odim:[%d] adim:[%d]"%(self.odim,self.adim))

        # Network
        self.main_network = DQNNetwork(self.odim,self.adim,hdims=hdims)
        self.target_network = DQNNetwork(self.odim,self.adim,hdims=hdims)
        self.gamma = gamma
        self.optimizer = optim.Adam(self.main_network.parameters(),lr=learning_rate)
        self.eps = 1.0

        # Buffer (Memory)
        self.buffer = ReplayBuffer(
            buffer_size=buffer_size,odim=self.odim,adim=self.adim,batch_size=batch_size)

    def getQ(self, obs):
        Q = self.main_network(obs)
        return Q

    def update_main_network(self, o_batch, a_batch, r_batch, o1_batch, d_batch):
        o1_q = self.target_network(o1_batch)
        max_o1_q = o1_q.max(1)[0].detach().numpy()
        d_batch = d_batch.astype(int)
        expected_q = r_batch + self.gamma*max_o1_q*(1.0-d_batch)
        expected_q = expected_q.astype(np.float64) # R + gamma*max(Q)
        expected_q = torch.from_numpy(expected_q)
        main_q = self.main_network(o_batch).max(1)[0]
        loss = F.smooth_l1_loss(main_q.float(), expected_q.float())
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    def update_target_network(self):
        self.target_network.load_state_dict(self.main_network.state_dict()) # simple copy
    
    def train(self,max_epoch=200,warmup=2000,
              update_every=10,evaluate_every=20,max_ep_len=500):
        start_time = time.time()
        o,_ = self.env.reset()
        r, d, ep_ret, ep_len, n_env_step = 0, False, 0, 0, 0
        for epoch in range(max_epoch):
            o,_ = self.env.reset()
            d, ep_ret, ep_len = False, 0, 0
              
            while not d:
                if np.random.uniform(0,1,1).item() < self.eps:
                    a = self.env.action_space.sample() # random action
                else:
                    Q = self.getQ(o.reshape(1, -1))
                    a = Q.argmax().item()
                o1, r, d, _, _ = self.env.step(a)
                ep_len += 1
                ep_ret += r
                n_env_step += 1
                # Maximum length handler
                if ep_len >= max_ep_len:
                    break
                # Save the Experience to our buffer
                self.buffer.append(o, a, r, o1, d)
                o = o1
                if(len(self.buffer.buffer) > warmup):
                    # Update the main network
                    if self.eps > 0.01:
                        self.eps *= 0.9995 # reduce epsilon
                    o_batch, a_batch, r_batch, o1_batch, d_batch = self.buffer.sample()
                    self.update_main_network(o_batch,a_batch,r_batch,o1_batch,d_batch)
                    
            # Update target network
            if ((epoch + 1) % update_every):
                self.update_target_network()
                
            # Evaluate
            if (epoch==0) or (((epoch+1) % evaluate_every)==0):
                ram_percent = psutil.virtual_memory().percent  # memory usage
                print("[Eval. start] step:[%d/%d][%.1f%%] #step:[%.1e] buffer:[%d] "\
                      " eps:[%.3f] time:[%s] ram:[%.1f%%]." %
                      (epoch + 1, max_epoch, (epoch+1)/max_epoch*100.0,
                       n_env_step,len(self.buffer.buffer),self.eps,
                       time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time)),
                       ram_percent)
                      )
                o,_ = self.eval_env.reset()
                d, ep_ret, ep_len = False, 0, 0
                frames = []
                while not d:
                    Q = self.getQ(o.reshape(1, -1))
                    a = torch.argmax(torch.squeeze(Q),0).numpy() # argmax eval policy
                    o, r, d, _, _ = self.eval_env.step(a)
                    # frame = self.eval_env.render()
                    title_str = 'epoch:[%d/%d] tick:[%d/%d]'%(epoch,max_epoch,ep_len,max_ep_len)
                    frame = plot_env(self.eval_env,figsize=(4,3),title_str=title_str,title_fs=8,
                                     PLOT_IMG=False,RETURN_IMG=True)
                    frames.append(frame)
                    ep_ret += r # return
                    ep_len += 1 # length 
                    if ep_len >= max_ep_len: break
                print("[Eval. done] ep_ret:[%.1f] ep_len:[%d/%d]"%
                      (ep_ret,ep_len,max_ep_len))
                # Display
                media.show_video(frames, fps=10)
print ("Ready.")

Ready.


### Run DQN

In [3]:
np.random.seed(seed=0)
torch.manual_seed(0)
DQN = Agent(hdims=128,gamma=0.99,buffer_size=20000,batch_size=128,learning_rate=0.0005)    
DQN.train(max_epoch=1000,warmup=2000,update_every=5,evaluate_every=100)
print ("Done.")

odim:[4] adim:[2]
[Eval. start] step:[1/1000][0.1%] #step:[1.4e+01] buffer:[14]  eps:[1.000] time:[00:00:00] ram:[38.3%].
[Eval. done] ep_ret:[84.0] ep_len:[84/500]


0
This browser does not support the video tag.


  result = getattr(asarray(obj), method)(*args, **kwds)


[Eval. start] step:[100/1000][10.0%] #step:[2.3e+03] buffer:[2319]  eps:[0.853] time:[00:00:02] ram:[38.4%].
[Eval. done] ep_ret:[10.0] ep_len:[10/500]


0
This browser does not support the video tag.


[Eval. start] step:[200/1000][20.0%] #step:[4.3e+03] buffer:[4333]  eps:[0.311] time:[00:00:04] ram:[38.4%].
[Eval. done] ep_ret:[10.0] ep_len:[10/500]


0
This browser does not support the video tag.


[Eval. start] step:[300/1000][30.0%] #step:[1.1e+04] buffer:[10812]  eps:[0.012] time:[00:00:10] ram:[38.3%].
[Eval. done] ep_ret:[138.0] ep_len:[138/500]


0
This browser does not support the video tag.


[Eval. start] step:[400/1000][40.0%] #step:[3.0e+04] buffer:[20000]  eps:[0.010] time:[00:00:29] ram:[38.6%].
[Eval. done] ep_ret:[330.0] ep_len:[330/500]


0
This browser does not support the video tag.


[Eval. start] step:[500/1000][50.0%] #step:[7.0e+04] buffer:[20000]  eps:[0.010] time:[00:01:11] ram:[38.9%].
[Eval. done] ep_ret:[500.0] ep_len:[500/500]


0
This browser does not support the video tag.


[Eval. start] step:[600/1000][60.0%] #step:[1.2e+05] buffer:[20000]  eps:[0.010] time:[00:02:06] ram:[39.8%].
[Eval. done] ep_ret:[500.0] ep_len:[500/500]


0
This browser does not support the video tag.


[Eval. start] step:[700/1000][70.0%] #step:[1.7e+05] buffer:[20000]  eps:[0.010] time:[00:03:04] ram:[40.4%].
[Eval. done] ep_ret:[500.0] ep_len:[500/500]


0
This browser does not support the video tag.


[Eval. start] step:[800/1000][80.0%] #step:[2.2e+05] buffer:[20000]  eps:[0.010] time:[00:04:03] ram:[40.5%].
[Eval. done] ep_ret:[500.0] ep_len:[500/500]


0
This browser does not support the video tag.


[Eval. start] step:[900/1000][90.0%] #step:[2.7e+05] buffer:[20000]  eps:[0.010] time:[00:05:01] ram:[42.3%].
[Eval. done] ep_ret:[500.0] ep_len:[500/500]


0
This browser does not support the video tag.


[Eval. start] step:[1000/1000][100.0%] #step:[3.2e+05] buffer:[20000]  eps:[0.010] time:[00:05:54] ram:[43.0%].
[Eval. done] ep_ret:[500.0] ep_len:[500/500]


0
This browser does not support the video tag.


Done.
