April 15th, 2021

In [1]:
import gym
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque()
        self.batch_size = 32
        self.size_limit = 50000
    
    def put(self, transition):
        self.buffer.append(transition)
        if len(self.buffer) > self.size_limit:
            self.buffer.popleft() # First in - out
    
    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []
        
        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
               torch.tensor(done_mask_lst)
        
    def size(self):
        return len(self.buffer)

In [3]:
class Qnet(nn.Module):
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(4, 64) # Fully connected, Relu
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 2) # output: 2 (right, left actions)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x) # no relu since Q-value can be negative
        return x
    
    def sample_action(self, obs, epsilon): # epsilon-greedy 구현
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0,1)
        else:
            return out.argmax().item() # return the one with largest Q-val

In [4]:
def train(q, q_target, memory, gamma, optimizer, batch_size):
    for i in range(10):
        s,a,r,s_prime,done_mask = memory.sample(batch_size)
        
        q_out = q(s) # s.shape: 32,4 / q.shape: 32,2
        q_a = q_out.gather(1,a) # Extracting q values only for the actions taken (32,2 -> 32,1)
                                # because action includes q val for both direction: right and left
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(target, q_a)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() 

In [5]:
def main():
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet() # create instance
    q_target.load_state_dict(q.state_dict()) # duplicate
    memory = ReplayBuffer()
    
    print_interval = 20
    score = 0.0
    render = False
    
    gamma = 0.98
    batch_size = 32
    optimizer = optim.Adam(q.parameters(), lr=0.0005)
    # we don't update Target Q's parameters
    
    for n_epi in range(10000):
        epsilon = max(0.01, 0.08 - 0.01*(n_epi/200))
        # epsilon decreases from 8% to 1% linearly throughout episodes
        s = env.reset()
        done = False
        if render:
            env.render()
        
        for t in range(600):
            a = q.sample_action(torch.from_numpy(s).float(), epsilon)
            s_prime, r, done, info = env.step(a)
            done_mask = 0.0 if done else 1.0 # will be multiplied to TD value later(?)
            memory.put((s,a,r/200.0, s_prime, done_mask))
            s = s_prime
            
            score += r
            if score > 1000:
                render = True
            if done:
                break
                
        if memory.size()>2000:
            train(q, q_target, memory, gamma, optimizer, batch_size)
            # We only stack memories until 2000 and then start training
            
        if n_epi%20==0 and n_epi!=0:
            q_target.load_state_dict(q.state_dict()) # update target Q every 20 episode
            print(f"Episode {n_epi}:  Buffer size: {memory.size()}, Score: {score}, EPS: {epsilon*100:.1f}")
            
            score = 0.0
    env.close()

if __name__ == '__main__':
    main()

Episode 20:  Buffer size: 328, Score: 328.0, EPS: 7.9
Episode 40:  Buffer size: 644, Score: 316.0, EPS: 7.8
Episode 60:  Buffer size: 938, Score: 294.0, EPS: 7.7
Episode 80:  Buffer size: 1276, Score: 338.0, EPS: 7.6
Episode 100:  Buffer size: 1583, Score: 307.0, EPS: 7.5
Episode 120:  Buffer size: 1887, Score: 304.0, EPS: 7.4
Episode 140:  Buffer size: 2385, Score: 498.0, EPS: 7.3
Episode 160:  Buffer size: 2739, Score: 354.0, EPS: 7.2
Episode 180:  Buffer size: 3101, Score: 362.0, EPS: 7.1
Episode 200:  Buffer size: 3428, Score: 327.0, EPS: 7.0
Episode 220:  Buffer size: 3945, Score: 517.0, EPS: 6.9
Episode 240:  Buffer size: 4389, Score: 444.0, EPS: 6.8
Episode 260:  Buffer size: 4795, Score: 406.0, EPS: 6.7
Episode 280:  Buffer size: 5262, Score: 467.0, EPS: 6.6
Episode 300:  Buffer size: 5697, Score: 435.0, EPS: 6.5
Episode 320:  Buffer size: 6197, Score: 500.0, EPS: 6.4
Episode 340:  Buffer size: 6882, Score: 685.0, EPS: 6.3
Episode 360:  Buffer size: 7466, Score: 584.0, EPS: 6.2

Episode 2840:  Buffer size: 50000, Score: 4279.0, EPS: 1.0
Episode 2860:  Buffer size: 50000, Score: 4114.0, EPS: 1.0
Episode 2880:  Buffer size: 50000, Score: 5043.0, EPS: 1.0
Episode 2900:  Buffer size: 50000, Score: 3615.0, EPS: 1.0
Episode 2920:  Buffer size: 50000, Score: 4635.0, EPS: 1.0
Episode 2940:  Buffer size: 50000, Score: 4386.0, EPS: 1.0
Episode 2960:  Buffer size: 50000, Score: 3877.0, EPS: 1.0
Episode 2980:  Buffer size: 50000, Score: 4287.0, EPS: 1.0
Episode 3000:  Buffer size: 50000, Score: 3476.0, EPS: 1.0
Episode 3020:  Buffer size: 50000, Score: 3550.0, EPS: 1.0
Episode 3040:  Buffer size: 50000, Score: 3687.0, EPS: 1.0
Episode 3060:  Buffer size: 50000, Score: 2279.0, EPS: 1.0
Episode 3080:  Buffer size: 50000, Score: 2026.0, EPS: 1.0
Episode 3100:  Buffer size: 50000, Score: 1461.0, EPS: 1.0
Episode 3120:  Buffer size: 50000, Score: 3685.0, EPS: 1.0
Episode 3140:  Buffer size: 50000, Score: 5493.0, EPS: 1.0
Episode 3160:  Buffer size: 50000, Score: 5126.0, EPS: 1

Episode 5620:  Buffer size: 50000, Score: 8765.0, EPS: 1.0
Episode 5640:  Buffer size: 50000, Score: 8935.0, EPS: 1.0
Episode 5660:  Buffer size: 50000, Score: 5918.0, EPS: 1.0
Episode 5680:  Buffer size: 50000, Score: 4637.0, EPS: 1.0
Episode 5700:  Buffer size: 50000, Score: 4715.0, EPS: 1.0
Episode 5720:  Buffer size: 50000, Score: 5736.0, EPS: 1.0
Episode 5740:  Buffer size: 50000, Score: 5858.0, EPS: 1.0
Episode 5760:  Buffer size: 50000, Score: 4773.0, EPS: 1.0
Episode 5780:  Buffer size: 50000, Score: 6584.0, EPS: 1.0
Episode 5800:  Buffer size: 50000, Score: 5789.0, EPS: 1.0
Episode 5820:  Buffer size: 50000, Score: 5922.0, EPS: 1.0
Episode 5840:  Buffer size: 50000, Score: 4809.0, EPS: 1.0
Episode 5860:  Buffer size: 50000, Score: 4221.0, EPS: 1.0
Episode 5880:  Buffer size: 50000, Score: 3837.0, EPS: 1.0
Episode 5900:  Buffer size: 50000, Score: 3789.0, EPS: 1.0
Episode 5920:  Buffer size: 50000, Score: 4198.0, EPS: 1.0
Episode 5940:  Buffer size: 50000, Score: 4108.0, EPS: 1

Episode 8400:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8420:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8440:  Buffer size: 50000, Score: 9371.0, EPS: 1.0
Episode 8460:  Buffer size: 50000, Score: 9694.0, EPS: 1.0
Episode 8480:  Buffer size: 50000, Score: 9448.0, EPS: 1.0
Episode 8500:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8520:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8540:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8560:  Buffer size: 50000, Score: 9319.0, EPS: 1.0
Episode 8580:  Buffer size: 50000, Score: 9158.0, EPS: 1.0
Episode 8600:  Buffer size: 50000, Score: 9654.0, EPS: 1.0
Episode 8620:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8640:  Buffer size: 50000, Score: 10000.0, EPS: 1.0
Episode 8660:  Buffer size: 50000, Score: 9591.0, EPS: 1.0
Episode 8680:  Buffer size: 50000, Score: 8386.0, EPS: 1.0
Episode 8700:  Buffer size: 50000, Score: 8350.0, EPS: 1.0
Episode 8720:  Buffer size: 50000, Score: 9590.0,

`q_target.load_state_dict(q.state_dict())`

- Copying Q to Target Q network
- `state_dict` includes model's weight informations in a form of dictionary

32 batches, 10 updates.
- This means we are performing 320 samplings for every episode