In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm.auto import tqdm
import gymnasium as gym
from dqn import DQN
from replay_buffer import ReplayBuffer

# Environment used for every rollout in this notebook. The batches sampled further
# down show 4-dimensional observations and two discrete actions (0 and 1).
env = gym.make('CartPole-v1')

In [2]:
# Replay buffer shared by the collection loop (rb.add) and the training cells (rb.sample).
# NOTE(review): the three positional arguments are opaque from here. The batch sampled
# further down has 8 rows, so the first is presumably the batch size; the other two look
# like a capacity and a readiness threshold — confirm against replay_buffer.py and
# prefer keyword arguments. 1e5 is a float literal; use 100_000 if an int is expected.
rb = ReplayBuffer(8,1e5,5)

In [3]:
# Network dimensions are read from the environment so the models match its spaces.
N_ACTIONS = env.action_space.n  # number of discrete actions; passed to DQN as action_space
STATE_DIM = env.observation_space.shape[0]  # observation vector length; passed to DQN as state_space
HIDDEN_SIZE = 128  # passed to DQN as hidden_size

def update_target_model(q_online, q_target):
    """Copy the online network's parameters into the target network.

    Args:
        q_online: Network whose ``state_dict()`` is read.
        q_target: Network that receives those parameters via ``load_state_dict``.

    Returns:
        The same ``(q_online, q_target)`` pair; ``q_target`` is updated in place.
    """
    q_target.load_state_dict(q_online.state_dict())
    return q_online, q_target

# Two networks built with identical constructor arguments: q_online is queried for
# actions and current Q-values, q_target only for the bootstrap targets.
q_online = DQN(state_space=STATE_DIM, action_space=N_ACTIONS, hidden_size=HIDDEN_SIZE)
q_target = DQN(state_space=STATE_DIM, action_space=N_ACTIONS, hidden_size=HIDDEN_SIZE)
# Start both networks from the same weights.
q_online, q_target = update_target_model(q_online, q_target)


def epsilon_greedy(state, eps, verbose=True):
    """Pick an action with an epsilon-greedy policy over the online Q-network.

    Args:
        state: Current observation as a 1-D array-like of numbers (no batch axis).
        eps: Probability of taking a random action from ``env.action_space``
            instead of the greedy one.
        verbose: When True (the default, matching the original behaviour) print
            which branch was taken; pass False to keep long rollouts quiet.

    Returns:
        The chosen action index.

    Notes:
        Reads the notebook-level ``env`` and ``q_online``. The network's
        train/eval mode is restored to whatever it was before the call, even if
        the forward pass raises.
    """
    if np.random.uniform() < eps:  # explore
        action = env.action_space.sample()
        if verbose:
            print('random action')
        return action

    # exploit: add a leading batch axis so the network sees shape (1, state_dim)
    state_batch = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    was_training = q_online.training
    q_online.eval()
    try:
        with torch.no_grad():
            action_q_values = q_online(state_batch)
    finally:
        # Restore the previous mode rather than forcing train(): the caller may
        # have deliberately put the model in eval mode.
        q_online.train(was_training)
    if verbose:
        print('model action')
    return action_q_values.argmax(dim=1).item()  # action with the highest Q-value

In [4]:
# Initial exploration rate: with eps = 1 every action is random; the rollout loop decays it.
eps=1.

In [5]:
# Roll out a single episode (capped at 1000 steps) and store every transition in the buffer.
state, info = env.reset()
steps_taken = 0
for _ in range(1000):
    action = epsilon_greedy(state, eps)
    next_state, reward, terminated, truncated, info = env.step(action)
    # Store only `terminated` as the done flag. The Bellman target multiplies the
    # bootstrap term by (1 - done), which is correct for a true terminal state but
    # wrong for a time-limit truncation, where next_state still has future value.
    rb.add(state, action, reward, next_state, terminated)
    state = next_state
    # Multiplicative per-step decay of the exploration rate, floored at 0.1.
    eps = max(0.1, eps*0.995)
    steps_taken+=1
    # The episode is over in either case, so stop stepping the environment.
    if truncated or terminated:
        break
print(steps_taken)

random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
model action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
random action
32


In [6]:
# Query the buffer's readiness flag before sampling (True after the 32-step episode above).
rb.is_ready()

True

In [7]:
# Draw one batch of stored transitions from the replay buffer.
batch = rb.sample()

In [8]:
# Inspect the sample: a dict of tensors keyed by state / action / reward / next_state / done.
batch

{'state': tensor([[ 0.0542,  0.2186, -0.0408, -0.0484],
         [-0.0379, -0.3572,  0.0180,  0.6181],
         [ 0.1487,  1.2012, -0.1417, -1.6802],
         [ 0.1286,  1.0049, -0.1146, -1.3540],
         [-0.0558,  0.6111,  0.0686, -0.6847],
         [-0.0483, -0.3579,  0.0370,  0.6333],
         [ 0.0460,  0.4132, -0.0342, -0.3301],
         [-0.0588, -0.3591,  0.0567,  0.6604]]),
 'action': tensor([0, 1, 1, 1, 1, 1, 0, 1]),
 'reward': tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
 'next_state': tensor([[ 0.0586,  0.0241, -0.0417,  0.2311],
         [-0.0451, -0.1623,  0.0304,  0.3312],
         [ 0.1727,  1.3977, -0.1753, -2.0135],
         [ 0.1487,  1.2012, -0.1417, -1.6802],
         [-0.0436,  0.8052,  0.0549, -0.9550],
         [-0.0555, -0.1633,  0.0497,  0.3525],
         [ 0.0542,  0.2186, -0.0408, -0.0484],
         [-0.0659, -0.1648,  0.0700,  0.3862]]),
 'done': tensor([0., 0., 0., 0., 0., 0., 0., 0.])}

In [9]:
# Unpack the sampled batch into one tensor per transition field.
states, actions, rewards, next_states, dones = (
    batch[field] for field in ('state', 'action', 'reward', 'next_state', 'done')
)

In [10]:
# The regression targets are constants for the optimiser, so no gradients are tracked here.
with torch.no_grad():
    # Highest Q-value the target network assigns to any action in each next state
    # (max over dim 1 returns a (values, indices) pair; only the values are needed).
    next_q_values = q_target(next_states).max(dim=1).values
    # Bellman backup: reward plus the 0.99-discounted bootstrap value; the bootstrap
    # term is multiplied by (1 - done), so it vanishes where done is 1.
    next_q_targets = rewards + 0.99 * (1 - dones) * next_q_values

In [12]:
# Inspect the target-network bootstrap values, one per sampled next state.
next_q_values

tensor([0.0916, 0.1100, 0.1773, 0.1499, 0.1050, 0.1119, 0.0712, 0.1143])

In [15]:
# Inspect the done flags: all zero here, so no bootstrap term is masked out in this batch.
dones

tensor([0., 0., 0., 0., 0., 0., 0., 0.])

In [16]:
# Inspect the rewards: every sampled step earned a reward of 1.
rewards

tensor([1., 1., 1., 1., 1., 1., 1., 1.])

In [14]:
# With reward 1 and done 0 everywhere, each target is 1 + 0.99 * next_q_value
# (e.g. 1 + 0.99 * 0.0916 ≈ 1.0907 for the first row).
next_q_targets

tensor([1.0907, 1.1089, 1.1755, 1.1484, 1.1040, 1.1107, 1.0705, 1.1131])

In [11]:
# Q-value the online network assigns to the action actually taken in each sampled
# transition: gather along dim 1 picks column actions[i] from row i, then the
# resulting (batch, 1) column is flattened to a 1-D tensor.
current_q_values = torch.gather(q_online(states), 1, actions.unsqueeze(1)).flatten()

In [21]:
# Online-network Q-values for the sampled states: one row per state, one column per action.
qos = q_online(states)
qos

tensor([[0.0712, 0.0407],
        [0.1139, 0.0441],
        [0.1499, 0.0267],
        [0.1203, 0.0264],
        [0.0950, 0.0165],
        [0.1149, 0.0441],
        [0.0773, 0.0364],
        [0.1162, 0.0451]], grad_fn=<AddmmBackward0>)

In [20]:
# Action indices reshaped into a single column so they can serve as the gather index along dim 1.
actions.unsqueeze(1)

tensor([[0],
        [1],
        [1],
        [1],
        [1],
        [1],
        [0],
        [1]])

In [22]:
# Row i keeps qos[i, actions[i]]: e.g. row 0 took action 0 -> 0.0712, row 1 took action 1 -> 0.0441.
torch.gather(qos,1,actions.unsqueeze(1))

tensor([[0.0712],
        [0.0441],
        [0.0267],
        [0.0264],
        [0.0165],
        [0.0441],
        [0.0773],
        [0.0451]], grad_fn=<GatherBackward0>)

In [25]:
# Re-display of the target-network bootstrap values (unchanged from the earlier inspection).
next_q_values

tensor([0.0916, 0.1100, 0.1773, 0.1499, 0.1050, 0.1119, 0.0712, 0.1143])