In [100]:
import gym
import time
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from pprint import pprint
import utils
import torch
from torch.utils.tensorboard import SummaryWriter
import os
from collections import deque
import random
from myrl.buffers import ReplayBuffer
from myrl.visualizer import showit
from myrl.utils import ExperimentWriter
from myrl.value_functions import polyak
from myrl.policies import RandomPolicy
from myrl.environments import Envs


wll = ExperimentWriter('tb/dqn')

# Seed every library RNG the notebook draws from (epsilon coin flips use
# `random`, network init / exploration use `torch`) so runs are repeatable.
# NOTE(review): the gym environments keep their own RNGs, which are not seeded here.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Environment under study. Alternatives tried previously: 'BipedalWalker-v3',
# 'LunarLander-v2', 'CartPole-v0', 'MountainCar-v0'. Only the final assignment
# ever took effect, so the active choice is stated exactly once.
envname = 'Pendulum-v0'
env = gym.make(envname)   # single env: reset in the training loop, passed to `showit` and `RandomPolicy`
envs = Envs(envname, 1)   # source of training rollouts; second arg presumably the number of copies — confirm

In [101]:
rbuff = ReplayBuffer(10_000)  # replay memory; the argument is presumably its capacity in transitions — confirm

In [102]:
class DQN(nn.Module):
    """Two-layer MLP Q-network over a small discrete action set.

    ``forward`` maps observations ``(..., idim)`` to Q-values ``(..., odim)``.
    ``act`` picks an action index epsilon-greedily and returns it shifted by
    ``-action_offset`` (default 2), so with ``odim=5`` the emitted actions are
    the integers -2..2.
    """

    def __init__(self, idim, hdim, odim, action_offset=2):
        """
        Args:
            idim: observation dimensionality.
            hdim: hidden-layer width.
            odim: number of discrete actions (Q-value outputs).
            action_offset: subtracted from the chosen index in ``act``; code
                that indexes Q-values with stored actions must add it back.
        """
        super().__init__()
        self.lin1 = nn.Linear(idim, hdim)
        self.lin2 = nn.Linear(hdim, odim)
        self.odim = odim
        self.action_offset = action_offset

    def forward(self, x):
        """Return Q-values for every action: ``(..., idim) -> (..., odim)``."""
        h = F.relu(self.lin1(x))
        return self.lin2(h)

    def act(self, obs, epsilon=0.1, debug=True):
        """Select actions epsilon-greedily.

        Args:
            obs: array-like or tensor of shape ``(..., idim)``.
            epsilon: probability of taking a uniformly random action. A single
                coin flip is shared by the whole batch.
            debug: unused; kept for interface compatibility.

        Returns:
            ``(actions, (dummy, dummy, dummy))``. ``actions`` is an int64
            tensor of shape ``(..., 1)`` holding ``index - action_offset``.
            Each ``dummy`` is a ``[[1.]]`` float tensor, presumably a
            placeholder for extra per-step outputs the rollout code unpacks
            (confirm).
        """
        obs = torch.as_tensor(obs, dtype=torch.float32)
        # Action selection never needs gradients; skip building a graph.
        with torch.no_grad():
            qs = self.forward(obs)
        if random.uniform(0, 1) > epsilon:
            idx = torch.argmax(qs, dim=-1).unsqueeze(-1)
        else:
            # `high` is exclusive, so `self.odim` makes every action reachable
            # (`self.odim - 1` could never pick the last one). The shape follows
            # the batch dims of `qs` for any rank, matching the greedy branch.
            idx = torch.randint(0, self.odim, qs.shape[:-1] + (1,))
        dummy = torch.tensor([[1]]).float()
        return idx - self.action_offset, (dummy, dummy, dummy)
    
import copy

# The env's action space is a 1-D continuous Box (see printed output); the
# network instead scores 5 discrete actions, which DQN.act maps to `index - 2`.
print(env.action_space, env.observation_space)
sdim = env.observation_space.shape[0]   # observation dimensionality
adim = 5                                # discrete action count, deliberately not env.action_space.shape[0]
dqn = DQN(sdim, 64, adim)
tdqn = copy.deepcopy(dqn)               # target network, initialised as an exact copy of the online net
dqn

Box(1,) Box(3,)


DQN(
  (lin1): Linear(in_features=3, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=5, bias=True)
)

In [108]:
# Greedy (epsilon=0) rollout of the current Q-network. The policy is written
# inline because `shower` is only defined in the training cell below, so
# referring to it here fails on a fresh Restart & Run All.
showit(env, lambda obs: dqn.act(obs, epsilon=0))

198 /2000

In [104]:
# Optimise only the online network; the target network `tdqn` follows it via polyak averaging.
opt = torch.optim.Adam(params=dqn.parameters(), lr=0.001)

In [105]:
wll.new()
writer = wll.writer
bsize = 128            # minibatch size sampled from the replay buffer
warmup = 100           # NOTE(review): unused; warm-up is governed by the buffer-size checks below
gamma = 0.97           # discount factor
action_offset = 2      # DQN.act emits `index - 2`; add it back to index Q-value columns
n_opt_steps = 3        # gradient steps taken on each sampled minibatch
polyak_coef = 1 - 1/100  # passed to polyak(); see myrl.value_functions for its exact semantics

collector = lambda obs: dqn.act(obs, epsilon=0.15)   # epsilon-greedy behaviour policy
shower    = lambda obs: dqn.act(obs, epsilon=0)      # greedy policy used for rendering
random_policy = RandomPolicy(env).act

batch_index = torch.arange(bsize)  # row indices for gathering Q(s, a); loop invariant

# Each iteration collects one rollout step (length=1) from `envs`, so `ep` counts steps.
for ep in range(10000):
    # NOTE(review): return value unused (rollouts come from `envs`); kept in case `showit` relies on it.
    env.reset()
    # Act randomly until the buffer holds more than 5 minibatches' worth of data.
    pi = collector if len(rbuff) > bsize * 5 else random_policy
    oldobs, a, r, obs, d, _, _, _ = envs.rollout(pi, length=1)
    rbuff.add(oldobs, a, r, obs, d)
    writer.add_scalar('dqn/reward', r.mean(), ep)
    rew = r.mean().item()

    # Don't train until the buffer holds at least 10 minibatches.
    if len(rbuff) < bsize * 10:
        continue

    oldobs, a, r, obs, done = rbuff.get(bsize)
    # The bootstrapped target comes from `tdqn`, which is not touched by the
    # optimiser steps below, so compute it once per minibatch.
    target = torch.max(tdqn(obs), dim=-1)[0].detach().unsqueeze(-1) * (1 - done)
    # Assumes r and done are (bsize, 1) like `calc`; a (bsize,) reward would
    # silently broadcast the loss to (bsize, bsize) — TODO confirm ReplayBuffer.get shapes.
    td_target = r + gamma * target
    action_idx = a.long().squeeze(-1) + action_offset
    for opt_step in range(n_opt_steps):
        calc = dqn(oldobs)[batch_index, action_idx].unsqueeze(-1)
        loss = ((td_target - calc) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    writer.add_scalar('dqn/model', loss.item(), ep)
    polyak(dqn, tdqn, polyak_coef)

    if ep % 10 == 0:
        print(ep, loss.item(), rew)
    if ep % 50 == 0:
        showit(env, shower)

1280 46.40803527832031 -5.026594161987305
1290 29.97684669494629 -14.123114585876465
1300 23.70888900756836 -1.9319820404052734
198 /20001310 17.432514190673828 -14.062799453735352
1320 12.013439178466797 -3.46389102935791
1330 9.182476997375488 -11.882407188415527
1340 7.755334854125977 -3.478599786758423
1350 6.43708610534668 -8.215517044067383
198 /20001360 6.627577304840088 -3.4475889205932617
1370 5.665998458862305 -7.576043128967285
1380 6.242976188659668 -0.6482446193695068
1390 3.8189520835876465 -8.679797172546387
1400 3.7721896171569824 -2.1802902221679688
198 /20001410 3.205242156982422 -8.24006175994873
1420 3.1337807178497314 -5.695961952209473
1430 2.2555627822875977 -8.612106323242188
1440 2.1644387245178223 -7.61298942565918
1450 1.919003963470459 -7.443322658538818
198 /20001460 1.514763593673706 -6.496237754821777
1470 2.141718626022339 -8.006667137145996
1480 1.5749154090881348 -8.265135765075684
1490 1.4167715311050415 -7.2437520027160645
1500 1.0324926376342773 -7.

KeyboardInterrupt: 

In [71]:
# Release the resources held by the single `gym.make` environment.
# NOTE(review): `envs` (the Envs rollout wrapper) is not closed here; check whether Envs exposes close().
env.close()