In [6]:
import torch
import numpy as np
import random
from Model import actor, get_network_input
from Game import GameEnvironment
import torch.nn as nn

In [7]:
class ReplayMemory(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = []
        
    def push(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)
        
    def sample(self, batch_size):
        state_batch = []
        action_batch = []
        reward_batch = []
        next_state_batch = []
        done_batch = []
        
        batch = random.sample(self.buffer, batch_size)
        
        for experience in batch:
            state, action, reward, next_state, done = experience
            state_batch.append(state)
            action_batch.append(action)
            reward_batch.append(reward)
            next_state_batch.append(next_state)
            done_batch.append(done)
        
        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)
    
    def truncate(self):
        self.buffer = self.buffer[-self.max_size:]
    
    def __len__(self):
        return len(self.buffer)

In [9]:
model = actor(10, 20, 5)
epsilon = 0.1
reward_nothing = 0.
reward_dead = -100.
reward_apple = 100.
gridsize = 13
gamma = 0.6

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
board = GameEnvironment(gridsize, reward_nothing, reward_dead, reward_apple)
memory = ReplayMemory(100)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-5)

def play(num_games):
    run = True
    move=0
    games_played = 0
    while run:
        inp = get_network_input(board.snake, board.apple)
        out = model(inp)
        rand = np.random.uniform(0,1)
        if rand > epsilon:
            move = torch.argmax(out)
        else:
            move = np.random.randint(0,5)

        reward, done = board.update_boardstate(move)
        newinp = get_network_input(board.snake, board.apple)
        
        memory.push(inp, move, reward, newinp, done)
        
        
        if board.game_over == True:
            games_played += 1
            board.resetgame()
            
            if num_games == games_played:
                run = False
                
MSE = nn.MSELoss()
def lossf(batch):
    states, actions, rewards, next_states, dones = batch
    states = torch.cat([x.unsqueeze(0) for x in states], dim=0).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.cat([x.unsqueeze(0) for x in next_states]).to(device)
    dones = torch.FloatTensor(dones).to(device)
    
    curr_Q = model.forward(states)
    next_Q = model.forward(next_states)
    #print(curr_Q)
    #print(actions)

    curr_Q = curr_Q.gather(1, actions.unsqueeze(0).transpose(0,1)).transpose(0,1).squeeze(0)    
    max_next_Q = torch.max(next_Q, 1)[0]*(torch.ones(dones.size()) - dones)
    expected_Q = rewards + gamma*max_next_Q
    loss = MSE(curr_Q, expected_Q)
    
    '''
    print('STATE', states)
    print('ACTIONS', actions)
    print('REWARDS', rewards)
    print('NEXT STATES', next_states)
    print('DONES', dones)
    print('CURR_Q', curr_Q)
    print('NEXT_Q', next_Q)
    print('LOSS', loss)
    #'''
    
    return loss

def learn(times, batch_size):
    total_loss = 0
    for i in range(times):
        optimizer.zero_grad()
        sample = memory.sample(batch_size)
        loss = lossf(sample)
        total_loss += loss
        loss.backward()
        optimizer.step()
    print('LOSS {}'.format(total_loss))
    
def train_bot(games_per_epoch, batch_size):
    iterations = 0
    while True:
        iterations += 1
        play(games_per_epoch)
        learn(400, batch_size)
        memory.truncate()
        if iterations % 50 == 0:
            torch.save(model.state_dict(), './Models/DQN6{}'.format(iterations))

In [11]:
model.load_state_dict(torch.load('./Models/DQN54400'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [12]:
train_bot(30, 2)

LOSS 152707.046875
LOSS 199464.015625
LOSS 192258.140625
LOSS 192524.390625
LOSS 155373.828125
LOSS 199996.921875
LOSS 161602.421875
LOSS 146083.125
LOSS 172936.28125
LOSS 171062.25
LOSS 148278.15625
LOSS 169637.40625
LOSS 162093.59375
LOSS 164380.71875
LOSS 196366.859375
LOSS 196530.671875
LOSS 156461.0
LOSS 171363.859375
LOSS 136449.75
LOSS 191536.65625
LOSS 199664.640625
LOSS 163504.484375
LOSS 148677.953125
LOSS 168064.65625
LOSS 134306.78125
LOSS 173443.5625
LOSS 140269.453125
LOSS 142777.96875
LOSS 156912.1875
LOSS 157188.609375
LOSS 146023.953125
LOSS 182303.0
LOSS 189908.625
LOSS 176445.3125
LOSS 175774.71875
LOSS 210155.53125
LOSS 235730.828125
LOSS 161053.359375
LOSS 164991.015625
LOSS 160058.78125
LOSS 187948.296875
LOSS 191258.203125
LOSS 170342.625
LOSS 200699.71875
LOSS 156693.609375
LOSS 142003.546875
LOSS 189022.5625
LOSS 194153.0
LOSS 158567.1875
LOSS 154513.171875
LOSS 201409.015625
LOSS 185481.53125
LOSS 191969.15625
LOSS 171531.015625
LOSS 184937.359375
LOSS 175771.

KeyboardInterrupt: 