In [1]:
import torch
import gym
import random
import numpy as np

In [741]:
def init_weights(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

In [126]:
env = gym.make("CartPole-v1")
observation = env.reset()
for _ in range(10000):
    env.render()
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    print(observation , reward, done)
    
    if done:
        observation = env.reset()
        break
        
env.close()

[-0.00576908 -0.17838304 -0.00627325  0.32392704] 1.0 False
[-0.00933674  0.01682768  0.00020529  0.0292724 ] 1.0 False
[-0.00900019  0.21194668  0.00079074 -0.26334575] 1.0 False
[-0.00476126  0.01681346 -0.00447617  0.02958649] 1.0 False
[-0.00442499 -0.17824402 -0.00388444  0.32085377] 1.0 False
[-0.00798987  0.01693304  0.00253263  0.02694838] 1.0 False
[-0.00765121  0.21201858  0.0030716  -0.2649344 ] 1.0 False
[-0.00341084  0.40709656 -0.00222709 -0.55664694] 1.0 False
[ 0.0047311   0.6022497  -0.01336003 -0.85003066] 1.0 False
[ 0.01677609  0.7975513  -0.03036064 -1.1468846 ] 1.0 False
[ 0.03272711  0.99305624 -0.05329833 -1.4489316 ] 1.0 False
[ 0.05258824  1.1887914  -0.08227696 -1.7577796 ] 1.0 False
[ 0.07636406  0.9946922  -0.11743256 -1.491778  ] 1.0 False
[ 0.09625791  0.801179   -0.14726812 -1.2379532 ] 1.0 False
[ 0.11228149  0.99785334 -0.17202719 -1.5729119 ] 1.0 False
[ 0.13223855  0.8051501  -0.20348541 -1.3384504 ] 1.0 False
[ 0.14834157  1.0021685  -0.23025443 -1.

In [814]:
class ReplayMemory(object):
    def __init__(self, capacity=5800):
        self.memory = []
        
    def add_transition(self, state, policy, reward, new_state, is_final):
        transition = [state, policy, reward, new_state, is_final]
        self.memory.append(transition)
        
    def get_batch(self, batch_size):                                        
        if batch_size > len(self.memory):
            return self.memory
        return random.sample(self.memory, k=batch_size)
    
    def __len__(self):
        return len(self.memory)

In [981]:
class DQN(torch.nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.linear1 = torch.nn.Linear(4,10)
        #self.linear2 = torch.nn.Linear(10, 20)
        self.linear3 = torch.nn.Linear(10, 20)
        self.linear4 = torch.nn.Linear(20, 10)
        self.linear5 = torch.nn.Linear(10, 5)
        self.linear6 = torch.nn.Linear(5, 2)
        
        self.relu = torch.nn.functional.relu
        
    def forward(self, x):
        #print(x)
        x = self.relu(self.linear1(x))
        #print(x)
        #x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.relu(self.linear4(x))
        x = self.relu(self.linear5(x))
        x = self.relu(self.linear6(x))
        return x #torch.tanh(x)
    
   
    
    def predict(self, x, mask=None):
        if mask == None:
            return self.forward(x)
        x = self.forward(x) * torch.tensor(mask).type(torch.float).unsqueeze(1)
        return x     

In [1070]:
class Game():
    def __init__(self):
        self.env = gym.make("CartPole-v1")
        self.observation_history = []
        
    def get_state(self, state_length=5):
        if len(self.observation_history) < state_length:
            ret = [[0]*4]*(state_length-len(self.observation_history)) + self.observation_history
        else:
            ret = self.observation_history[-state_length:]
        return np.array(ret)
        
    def play(self, model=None, render=True, debug_log=False):
        state = torch.tensor(env.reset(), dtype=torch.float)
        #self.observation_history.append(observation)
        #state = torch.tensor(observation, dtype=torch.float)
        
        while True:
            if render:
                env.render()
        
            #state = self.get_state()
                     
            if model != None:
                #state = torch.as_tensor(state, dtype=torch.float).flatten()
                action = int(torch.argmax(model(state)))
                print(action)
            else:
                action = env.action_space.sample()
            
            state, reward, done, _ = env.step(action)
            state = torch.tensor(state, dtype=torch.float)
            
            if debug_log:
                print(state, reward, done)
                
            #self.observation_history.append(observation)
            
            if done:
                state = torch.tensor(env.reset(), dtype=torch.float)
                break
        env.close()

In [1100]:
class Trainer(object):
    def __init__(self, model, epochs=3, batch_size=40):
        #self.epochs = epochs
        self.batch_size = batch_size
        self.game = Game()
        self.memory = ReplayMemory()
        self.model = model
        self.state_length = 5

        self.prev_model = DQN()
        self.prev_model.apply(init_weights)
        
        self.criterion = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), 0.01)
    
    def optimize(self, gamma=0.999):
        if len(self.memory) <= self.batch_size:
            return
        #optimizer.zero_grad()
        batch = self.memory.get_batch(batch_size=self.batch_size)
        
        q_values = torch.zeros(len(batch))
        prev_q_values = torch.zeros(len(batch))
        rewards = torch.zeros(len(batch))
        
        for i in range(len(batch)):
            #s = torch.as_tensor(batch[i][0], dtype=torch.float).flatten()
            
            q_values[i] = self.model(batch[i][0])[batch[i][1]]
            if batch[i][3] != None:
                #sp = torch.as_tensor(batch[i][3], dtype=torch.float).flatten()
                
                prev_q_values[i] = torch.max(self.prev_model(batch[i][3]))
            rewards[i] = batch[i][2]
        
        #print(prev_q_values)

        prev_q_values = prev_q_values * gamma + rewards
        
        loss = self.criterion(q_values, prev_q_values)
        
        self.optimizer.zero_grad()
        loss.backward()
        
        #torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        
        self.optimizer.step()
        #print(loss.cpu().detach())
        #print(float(loss.cpu().detach()))
        #return float(loss.cpu().detach())
            
        
    def train(self, epochs, epsilon=0.4, decay=0.998, gamma=0.999):
        tester = torch.randn(4, dtype=torch.float)
        state = torch.tensor(env.reset(), dtype=torch.float)
        
        self.model.train()
        losses = []

        for e in tqdm(range(epochs)):
            #print(self.model.predict(tester))
            
            for i in range(300000):
                #state = self.game.get_state(state_length=self.state_length)
                use_model = random.random() > epsilon
            
                if use_model:
                    #state = torch.as_tensor(state, dtype=torch.float).flatten()
                    #policy = self.model.predict(state)
                    policy = self.model(state)
                    action = int(torch.argmax(policy))
                else:
                    action = random.randint(0,1)
                
                new_state, reward, done, _ = env.step(action)
                #self.game.observation_history.append(observation)
            
                #new_state = self.game.get_state()
                #prev_state = self.game.get_state(state_length=self.state_length+1)[:-1]
                
                if done:
                    new_state = None
                else:
                    new_state = torch.tensor(new_state, dtype=torch.float)

                self.memory.add_transition(state, action, reward, new_state, done)
                self.optimize()
                
                state = new_state
                                    
                if done:
                    break
            env.close()
            state = torch.tensor(env.reset(), dtype=torch.float)
            epsilon *= decay
            env.reset()
            if e % 5 == 0:
                self.prev_model.load_state_dict(self.model.state_dict())
            
        self.model.eval()
        return losses
        

In [1110]:
model = DQN()
model.apply(init_weights)

DQN(
  (linear1): Linear(in_features=4, out_features=10, bias=True)
  (linear3): Linear(in_features=10, out_features=20, bias=True)
  (linear4): Linear(in_features=20, out_features=10, bias=True)
  (linear5): Linear(in_features=10, out_features=5, bias=True)
  (linear6): Linear(in_features=5, out_features=2, bias=True)
)

In [1106]:
from tqdm import tqdm
trainer = Trainer(model)

In [1116]:
loss = trainer.train(epochs=300)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [02:34<00:00,  1.94it/s]


In [1117]:
game = Game()
game.play(model=model, debug_log=True)

1
tensor([ 0.0307,  0.2375,  0.0347, -0.3079]) 1.0 False
0
tensor([ 0.0354,  0.0419,  0.0285, -0.0044]) 1.0 False
1
tensor([ 0.0363,  0.2366,  0.0284, -0.2880]) 1.0 False
0
tensor([0.0410, 0.0411, 0.0227, 0.0135]) 1.0 False
1
tensor([ 0.0418,  0.2359,  0.0229, -0.2719]) 1.0 False
0
tensor([0.0465, 0.0405, 0.0175, 0.0279]) 1.0 False
1
tensor([ 0.0473,  0.2353,  0.0181, -0.2592]) 1.0 False
0
tensor([0.0521, 0.0400, 0.0129, 0.0391]) 1.0 False
1
tensor([ 0.0529,  0.2349,  0.0137, -0.2494]) 1.0 False
0
tensor([0.0575, 0.0396, 0.0087, 0.0475]) 1.0 False
1
tensor([ 0.0583,  0.2346,  0.0096, -0.2424]) 1.0 False
0
tensor([0.0630, 0.0393, 0.0048, 0.0533]) 1.0 False
1
tensor([ 0.0638,  0.2344,  0.0058, -0.2379]) 1.0 False
0
tensor([0.0685, 0.0392, 0.0011, 0.0566]) 1.0 False
1
tensor([ 0.0693,  0.2343,  0.0022, -0.2357]) 1.0 False
0
tensor([ 0.0740,  0.0391, -0.0025,  0.0577]) 1.0 False
1
tensor([ 0.0748,  0.2343, -0.0013, -0.2358]) 1.0 False
0
tensor([ 0.0794,  0.0392, -0.0061,  0.0565]) 1.0 Fals

In [307]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
plt.plot(loss)

In [1114]:
model(torch.randn(4))

tensor([0.1712, 0.0000], grad_fn=<ReluBackward0>)

In [539]:
b = model

In [635]:
torch.max(model(torch.randn(4*5)))

tensor(0.7092, grad_fn=<MaxBackward1>)

In [794]:
env.close()

In [917]:
trainer.optimizer

RMSprop (
Parameter Group 0
    alpha: 0.99
    centered: False
    eps: 1e-08
    lr: 0.01
    momentum: 0
    weight_decay: 0
)