Jupyterlab Shortcuts
- Shift + Enter : Run
- Enter/Esc : Switch mode (Edit/Command)
- Fn + Up/Down : Move between cells
- A / B : Insert Cell Above/Below
- D, D : Delete selected cell

In [6]:
# Create the CartPole environment and keep the observation returned by reset().
env = gym.make('CartPole-v0')
state = env.reset()

# For a discrete environment: first entry of the observation-space shape and
# the number of available actions.
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n

print(state_dim, action_dim)

4 2


In [4]:
from collections import deque
import numpy as np
import random
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [5]:
class Net(nn.Module):
    """Three-layer fully connected Q-network.

    Maps a state vector to one value per discrete action.

    Args:
        state_dim: Number of input features.
        action_dim: Number of outputs (one per action).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, action_dim)

    def forward(self, state):
        """Return the unnormalised output of the last layer for ``state``.

        Args:
            state: Float tensor whose last dimension is ``state_dim``; a
                leading batch dimension is optional.

        Returns:
            Tensor with the same leading dimensions and ``action_dim`` as the
            last dimension.
        """
        out = F.relu(self.fc1(state))
        out = F.relu(self.fc2(out))
        # No softmax here: Q-values are regressed onto reward-based targets,
        # so they must not be squashed into [0, 1] or forced to sum to 1.
        # Dropping the hard-coded dim=1 also lets unbatched 1-D states through.
        return self.fc3(out)

In [18]:
# Instantiate two networks of the same architecture: q first, then qtarget.
q, qtarget = Net(state_dim, action_dim), Net(state_dim, action_dim)

In [21]:
# Print row 1 of the chosen weight matrix for q and qtarget, copy q's weights
# into qtarget, then print the same rows again.
weight_name = 'fc1.weight'
print(*(net.state_dict()[weight_name][1] for net in (q, qtarget)), sep='\n')
qtarget.load_state_dict(q.state_dict())
print(*(net.state_dict()[weight_name][1] for net in (q, qtarget)), sep='\n')

tensor([-0.3967, -0.3701,  0.4617, -0.1940])
tensor([-0.3967, -0.3701,  0.4617, -0.1940])
tensor([-0.3967, -0.3701,  0.4617, -0.1940])
tensor([-0.3967, -0.3701,  0.4617, -0.1940])


In [None]:
class DQNAgent:
    """DQN agent holding a Q-network, a target Q-network and an optimizer.

    Args:
        env: Environment exposing ``observation_space.shape`` and
            ``action_space.n`` (i.e. a discrete action space).
        epsilon: Probability with which ``get_action`` returns a uniformly
            random action instead of the greedy one.
    """

    def __init__(self, env, epsilon):
        # if env is discrete,
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        self.epsilon = epsilon

        self.q = Net(state_dim, action_dim)
        self.qtarget = Net(state_dim, action_dim)

        # Only q is optimised; qtarget is changed solely by copying q into it.
        self.optimizer = optim.Adam(self.q.parameters())

    @staticmethod
    def _as_batch(state):
        """Convert a single state into a float32 tensor of shape (1, ...)."""
        return torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)

    def update_qtarget_parameter(self):
        """Copy the current parameters of ``q`` into ``qtarget``."""
        self.qtarget.load_state_dict(self.q.state_dict())

    def get_action(self, state):
        """Choose an action for one state with an epsilon-greedy rule.

        Args:
            state: A single observation (array-like or tensor, no batch axis).

        Returns:
            The action index as a plain ``int``.
        """
        # Action selection needs no gradients.
        with torch.no_grad():
            q_values = self.q(self._as_batch(state)).squeeze(0)
        if random.random() < self.epsilon:
            return random.randrange(q_values.numel())
        return int(torch.argmax(q_values).item())

    def learn(self, mini_batch):
        """Take one optimizer step on the summed squared error of a batch.

        Args:
            mini_batch: Sequence of dicts with keys ``'state'`` and
                ``'predict'`` (the regression target). If a dict also has an
                ``'action'`` key, only that action's output is compared with
                the target; otherwise every output is compared with it.

        An empty batch is a no-op.
        """
        if not mini_batch:
            return

        # Minimise the sum of the per-sample losses over the mini-batch.
        loss = 0
        for sample in mini_batch:
            q_values = self.q(self._as_batch(sample['state'])).squeeze(0)
            # The target is a constant: detach it so no gradient flows into
            # whatever graph produced it and a sample can be reused safely.
            target = torch.as_tensor(
                sample['predict'], dtype=torch.float32
            ).detach()
            if 'action' in sample:
                error = target - q_values[sample['action']]
            else:
                error = target - q_values
            # `**` squares; `^` is bitwise XOR and is invalid on float tensors.
            loss = loss + (error ** 2).sum()

        # Gradients accumulate across backward() calls, so they have to be
        # cleared before every update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _compute_dqn_loss(self):
        """Placeholder; not implemented (``learn`` computes the loss inline)."""
        pass

In [None]:
class ReplayBuffer:
    """Bounded FIFO store of transition dicts with a precomputed target.

    Args:
        buffer_size: Maximum number of stored samples; the oldest sample is
            dropped once the buffer is full.
        gamma: Discount factor applied to the best next-state value.
    """

    def __init__(self, buffer_size, gamma):
        # A deque with maxlen keeps only the newest `buffer_size` items (FIFO).
        self._memory = deque(maxlen=buffer_size)
        self.gamma = gamma

    def add(self, sample):
        """Attach the regression target to ``sample`` and store it.

        The target is written into the passed dict under ``'predict'``:
        ``reward`` for a terminal transition, otherwise
        ``reward + gamma * max(qtarget)``.

        Args:
            sample: Dict with keys ``'done'``, ``'reward'`` and, for
                non-terminal transitions, ``'qtarget'`` (the next-state
                action values as a sequence, array or tensor).
        """
        if sample['done']:
            sample['predict'] = sample['reward']
        else:
            q_next = sample['qtarget']
            # Arrays/tensors reduce over every element with .max(); builtin
            # max() on a 2-D input would compare whole rows instead.
            best = q_next.max() if hasattr(q_next, 'max') else max(q_next)
            # float() stores a plain number rather than a graph-attached value.
            sample['predict'] = sample['reward'] + self.gamma * float(best)
        self._memory.append(sample)

    def sample(self, batch_size):
        """Return up to ``batch_size`` distinct stored samples, chosen at random.

        Fewer samples are returned while the buffer holds fewer than
        ``batch_size`` items, instead of raising ``ValueError``.
        """
        count = min(batch_size, len(self._memory))
        return random.sample(list(self._memory), count)

In [None]:
# parameters
episode_len = 1000  # number of episodes to run
buffer_size = 1000
batch_size = 10
# NOTE(review): epsilon = 0.0 means the agent never explores -- confirm intended.
epsilon = 0.0
# NOTE(review): gamma = 0.0 removes the next-state term from every target --
# confirm intended.
gamma = 0.0
target_update_interval = 100  # environment steps between q -> qtarget copies

# init class
env = gym.make('CartPole-v0')
agent = DQNAgent(env, epsilon)
buffer = ReplayBuffer(buffer_size, gamma)

# Playing Atari with Deep Reinforcement Learning, Algorithm 1
# Q1. Why do q and qtarget have to be kept separate?
# Q2. Should learn be called at every step?
# Q3. When should the parameters of the networks (q, qtarget) be updated?
# Q4. DQN's action selection strategy? greedy? -- just epsilon-greedy
# Q5. zero_grad()

# Start with qtarget holding the same parameters as q.
agent.update_qtarget_parameter()
total_steps = 0
for episode in range(episode_len):
    state = env.reset()
    while True:
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        total_steps += 1

        # Evaluate the target network on a float32 batch of one, without
        # building a graph, and store the result as a plain list of floats.
        with torch.no_grad():
            next_q = agent.qtarget(
                torch.as_tensor(next_state, dtype=torch.float32).unsqueeze(0)
            ).squeeze(0).tolist()

        sample = {'state': state, 'action': action, 'reward': reward,
                  'done': done, 'qtarget': next_q}
        buffer.add(sample)

        # Learn only once the buffer can supply a full mini-batch.
        if min(total_steps, buffer_size) >= batch_size:
            mini_batch = buffer.sample(batch_size)
            agent.learn(mini_batch)

        # Copying q into qtarget on every step would make the two networks
        # identical at all times; sync periodically instead.
        if total_steps % target_update_interval == 0:
            agent.update_qtarget_parameter()

        if done:
            break
        # Advance to the next observation; without this the agent keeps
        # acting on the episode's first state.
        state = next_state