<a href="https://colab.research.google.com/github/oroojlooy/RL_pytorch/blob/master/general_dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import gym 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt 
import re
import os
import argparse
os.environ["CUDA_VISIBLE_DEVICES"]= "0, 1"


In [2]:
arg_lists = []
parser = argparse.ArgumentParser()


def add_argument_group(name):
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg

def str2bool(v):
    return v.lower() in ('true', '1')


arg_lists = []
res_arg = add_argument_group('prediction')
res_arg.add_argument('--task', type=str, default='pred', help='')
res_arg.add_argument('--lr0', type=float, default=0.001, help='')
res_arg.add_argument('--log_interval', type=int, default=200, help='')
res_arg.add_argument('--batch_size', type=int, default=128, help='')
res_arg.add_argument('--nodes', type=list, default=[350, 150], help='')
res_arg.add_argument('--output_dim', type=int, default=11, help='')
res_arg.add_argument('--input_dim', type=int, default=484, help='')


_StoreAction(option_strings=['--input_dim'], dest='input_dim', nargs=None, const=None, default=484, type=<class 'int'>, choices=None, help='', metavar=None)

In [3]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu_device = torch.device("cpu")

no_randomness = False
load_weights = False
semi_random = False

class replay_memory(object):
    def __init__(self, size, sd, b):
        self.max_size = size
        self.storage = []
        self.cur_size = 0
        self.batch_size = b
        self.index = 0 
        
    def add(self, s,a,r,ns,d):
        if self.cur_size < self.max_size:
            self.storage.append([{"s":s, "a":a, "r":r, "ns":ns, "d":d}])
            self.cur_size += 1
        else:
            self.storage.pop(0)
            self.storage.append([{"s":s, "a":a, "r":r, "ns":ns, "d":d}])
            
    def sample(self):
        s = []
        a = []
        r = []
        ns = []
        d = []
        for i in range(self.batch_size):
            indx = torch.randint(self.cur_size, size=(1,)).numpy()[0]
            if no_randomness:
                indx = i
            if semi_random:
                indx = self.index
                if self.index < self.cur_size-1 and self.index < self.max_size-1:
                    self.index += 1
                else:
                    self.index = 0 
            s += [self.storage[indx][0]["s"]]
            a += [self.storage[indx][0]["a"]]
            r += [self.storage[indx][0]["r"]]
            ns += [self.storage[indx][0]["ns"]]
            d += [self.storage[indx][0]["d"]]
            
        return {"s":s, "a":a, "r":r, "ns":ns, "d":d}
    

def get_action_(env, epsilon, action):
    if no_randomness:
        epsilon = 0 
    if semi_random:
        x=np.squeeze(action_ts.detach().numpy())
        if np.sum(x) < 10:
            if 2*np.log(x[0]) - 5 > 3*np.log(x[1]) - 7:
                a = 1
            else:
                a = 0 
        else:
            a = action.argmax().detach().numpy()            
    else:
        rnd = torch.rand((1)).numpy()[0]
        if rnd < epsilon:
            a = torch.randint(env.action_space.n, size=(1,)).numpy()[0]
        else:
            a = action.argmax().detach().to(cpu_device).numpy()
        
    return a
    
class linear_exploration(object):
    def __init__(self, max_,min_,num_eps):
        self.epsilon = max_
        self.min_eps = min_
        self.num_eps = num_eps
        self.eps_red = (max_ - min_)/num_eps
    
    def reduce(self):
        if (self.epsilon > self.min_eps):
            self.epsilon -= self.eps_red 

class multiplicative_exploration(object):
    def __init__(self, max_, min_, reduction):
        self.epsilon = max_
        self.min_eps = min_
        self.reduction = reduction
    
    def reduce(self):
        if (self.epsilon > self.min_eps):
            self.epsilon = self.epsilon * self.reduction

        
class DQN(nn.Module):
    def __init__(self, id, od, layers_d, activations=[], dropouts={}, 
                 batch_norm={}, fixed_weight=None, normal_weight_mu=None, 
                 normal_weight_std=None, uniform_weight_l=None, 
                 uniform_weight_u=None):
        super(DQN, self).__init__()
        self.id = id
        self.od = od
        self.layers_d = layers_d
        self.dropouts = dropouts
        self.batch_norm = batch_norm
        if len(activations) == 0:
            self.activations = [F.relu for _ in range(len(layers_d))] + [None]
        else:
            self.activations = activations

        for c in range(len(layers_d) + 1):
            # print('layer {:1} has {:2} nodes'.format(c, l) )
            if c == 0:
                module = nn.Linear(id, layers_d[c])
                if c in self.batch_norm:
                    self.batch_norm[c] = nn.BatchNorm1d(layers_d[c])
            elif c == len(layers_d):
                module = nn.Linear(layers_d[c - 1], od)
            else:
                module = nn.Linear(layers_d[c - 1], layers_d[c])
                if c in self.batch_norm:
                    self.batch_norm[c] = nn.BatchNorm1d(layers_d[c])
            self.add_module('Layer_' + str(c), module)

            if c in self.dropouts:
                self.dropouts[c] = nn.Dropout(self.dropouts[c])
                # print('dropout is added to layer {:1} '.format(c))

        if fixed_weight is not None:
            for p in self.parameters():
                torch.nn.init.constant_(p, fixed_weight)

        if normal_weight_std is not None:
            for p in self.parameters():
                torch.nn.init.normal_(p, normal_weight_mu, normal_weight_std)

        if uniform_weight_l is not None:
            for p in self.parameters():
                torch.nn.init.uniform_(p, uniform_weight_l, uniform_weight_u)

    def forward(self, x):

        out = x
        c = 0
        for c, (name, module) in enumerate(self.named_children()):
            if self.activations[c] is None:
                out = module(out)
            else:
                out = self.activations[c](module(out))
            if c in self.dropouts:
                out = self.dropouts[c](out)
            if c in self.batch_norm:
                if self.batch_norm[c]:
                    out = self.batch_norm[c](out)

        return out

    def get_norm(self):
        list_ = []
        for c, i in self.named_parameters():
            list_ += [i.grad.reshape(-1,)]
        norm = torch.norm(torch.cat(list_))
        return norm.item()
  
    
def test(period, avg_q, policy_net, env):
    
    rewards = 0
    for i in range(100):
        state = env.reset()    
        done = False
        while not done:
            action = policy_net(torch.tensor(state).float().to(device))
            action = get_action_(env, 0, action)
            next_state, reward, done, _ = env.step(action)

            state = next_state
            rewards += reward
            state = next_state
    print ("Test: episode={0:d}, Q-value={1:0.2f}, reward={2:0.2f}".format(period, avg_q, rewards/100.))
    return rewards/100.
    
def get_weight_norm(net):
    grad_norm=0
    for param in net.parameters():
    #     print(param)
        grad_norm += torch.norm(param)

    return grad_norm

def get_grad_norm(net):
    grad_norm=0
    for param in net.parameters():
    #     print(param)
        grad_norm += torch.norm(param.grad)

    return grad_norm


def get_grad_list(net):
    grads=np.array([])
    for param in net.parameters():
        grads = np.concatenate((grads, param.grad.data.view(-1).detach().numpy()))
    return grads

In [4]:
config, unparsed = parser.parse_known_args()
torch.manual_seed(298)
env_id = "MountainCar-v0"
env = gym.make(env_id)
config.nodes = [128, 128]
config.batch_size = 128
config.min_replay_buffer = 1000
config.max_replay_buffer = 1e6

config.target_update = 200 
config.num_episodes = 1500
config.show_detail = False
config.log_interval = 100 # will print the details of last xx train_steps

In [17]:

policy = DQN(env.observation_space.shape[0], env.action_space.n, config.nodes, 
             uniform_weight_l=-1, uniform_weight_u=1).to(device)
            #  fixed_weight=0.01).to(device) 
            #  normal_weight_mu=0.0, normal_weight_std=1.0).to(device) 
            #  =0.0, normal_weight_std=1.0).to(device)              

target_policy = DQN(env.observation_space.shape[0], env.action_space.n,                     
                    config.nodes, uniform_weight_l=-1, uniform_weight_u=1).to(
                        device)
                    # config.nodes, fixed_weight=0.01).to(device) 
                    # config.nodes, normal_weight_mu=0.0, normal_weight_std=1.0
                    # ).to(device) 

optimizer = optim.Adam(policy_net.parameters())
# optimizer = optim.SGD(policy_net.parameters(),lr=0.001)
# torch.nn.utils.clip_grad_norm(policy_net.parameters(),max_norm=10,norm_type=2)

rbm = replay_memory(config.max_replay_buffer, env.observation_space, config.batch_size)
exp = linear_exploration(0.9, 0.05, config.num_episodes)
# exp = multiplicative_exploration(0.9, 0.05, 0.98)


In [19]:
######################################################################
# Training loop

result = []

train_step = 0
for i_episode in range(config.num_episodes):
     # initialize state
    state = env.reset()

    # Select and perform an action    
    # keep going until get to the goal state
    cnt = 0
    done = False
    rewards = []
    while not done:
        cnt+=1
        action_ts = policy_net(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)).squeeze()
        action = get_action_(env, exp.epsilon, action_ts)
        next_state, reward, done, _ = env.step(action)
        rewards.append(reward)
        
        rbm.add(torch.tensor(state), torch.tensor(action), torch.tensor(reward), torch.tensor(next_state), 
               done)
            
        if config.show_detail:
            print (i_episode,"-",cnt,"-",rbm.cur_size, state, action_ts, action, reward, next_state, done)
        
        if rbm.cur_size >= config.min_replay_buffer:
            if rbm.cur_size == config.min_replay_buffer:
                print ("Started training")
            batch = rbm.sample()
            
            target_Q = target_net(torch.stack(batch["ns"]).float().to(device
                    )).squeeze().max(1)[0].detach()
            target = 0.99 * target_Q*(1-torch.tensor(batch["d"]).to(device).float()) + \
            torch.stack(batch["r"]).to(device)
            QValue = policy_net(torch.stack(batch["s"]).float().to(
                device)).squeeze().gather(1, torch.stack(batch["a"]).to(device).unsqueeze(1))
            
            loss = F.mse_loss(QValue, target.unsqueeze(1))

            if train_step % config.log_interval == 0:
                if config.show_detail:
                    print (i_episode, "-", train_step)
                    print (np.array(batch["d"],dtype=np.int))
                    print ([i for i in torch.stack(batch["a"]).numpy()])
                tmp = [i for i in target_Q.to(cpu_device).numpy()]
                if config.show_detail:
                    print ("target_Qvalue", tmp)
                tmp = [i for i in target.to(cpu_device).numpy()]
                if config.show_detail:
                    print ("target_value", tmp)
                tmp = [i for i in np.squeeze(QValue.to(cpu_device).detach().numpy())]
                if config.show_detail:
                    print ("Qvalue", tmp)
                            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_step+=1
                             
        state = next_state
        if train_step > 1 and train_step % config.target_update==0:
            target_net.load_state_dict(policy_net.state_dict())
            
    if train_step > 0 and i_episode % config.log_interval == 0:
        try:
            per = test(i_episode, QValue.mean().item(), policy_net, env)
        except:
            per = test(i_episode, 0, policy_net, env)
            
        if env_id == "MountainCar-v0":
          if per > -110:
            print("It is solved!")
            break
        result += [ per]
        #     print (cnt)
    exp.reduce()
        

Test: episode=0, Q-value=5.62, reward=-200.00
Test: episode=100, Q-value=-36.95, reward=-200.00
Test: episode=200, Q-value=-23.53, reward=-200.00
Test: episode=300, Q-value=-29.31, reward=-184.71
Test: episode=400, Q-value=-31.61, reward=-200.00
Test: episode=500, Q-value=-32.90, reward=-131.25
Test: episode=600, Q-value=-31.24, reward=-155.13
Test: episode=700, Q-value=-39.25, reward=-112.66
Test: episode=800, Q-value=-36.28, reward=-133.20
Test: episode=900, Q-value=-39.09, reward=-154.36
Test: episode=1000, Q-value=-37.95, reward=-124.95
Test: episode=1100, Q-value=-41.23, reward=-149.84
Test: episode=1200, Q-value=-44.28, reward=-155.86
Test: episode=1300, Q-value=-41.00, reward=-122.02
Test: episode=1400, Q-value=-42.60, reward=-124.19
