In [1]:
import torch
import numpy as np
import random
from Model import actor, get_network_input
from Carrot import GameEnvironment
import torch.nn as nn

In [2]:
class ReplayMemory(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = []
        
    def push(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)
        
    def sample(self, batch_size):
        state_batch = []
        action_batch = []
        reward_batch = []
        next_state_batch = []
        done_batch = []
        
        batch = random.sample(self.buffer, batch_size)
        
        for experience in batch:
            state, action, reward, next_state, done = experience
            state_batch.append(state)
            action_batch.append(action)
            reward_batch.append(reward)
            next_state_batch.append(next_state)
            done_batch.append(done)
        
        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)
    
    def truncate(self):
        self.buffer = self.buffer[-self.max_size:]
    
    def __len__(self):
        return len(self.buffer)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = actor(2, 10, 2)
board = GameEnvironment()
memory = ReplayMemory(100)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

In [4]:
epsilon = 0.1
gamma = 0.45

def play(num_games, printing=False):
    run = True
    move=0
    games_played = 0
    while run:
        inp = torch.from_numpy(board.get_boardstate())
        out = model(inp)
        rand = np.random.uniform(0,1)
        if rand > epsilon:
            move = torch.argmax(out)
        else:
            move = np.random.randint(0,2)
        if printing == True:
            print(inp, out)

        reward, done = board.update_boardstate(move)
        #print(inp, out, reward)
        newinp = torch.from_numpy(board.get_boardstate())
        
        memory.push(inp, move, reward, newinp, done)
        games_played += 1
        
        if num_games == games_played:
            run = False
            
MSE = nn.MSELoss()
def lossf(batch):
    states, actions, rewards, next_states, dones = batch
    states = torch.cat([x.unsqueeze(0) for x in states], dim=0).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.cat([x.unsqueeze(0) for x in next_states]).to(device)
    dones = torch.FloatTensor(dones).to(device)
    
    curr_Q = model.forward(states)
    next_Q = model.forward(next_states)
    
    curr_Qy = curr_Q.gather(1, actions.unsqueeze(0).transpose(0,1)).transpose(0,1).squeeze(0)
    max_next_Q = torch.max(next_Q, 1)[0]
    expected_Q = rewards + gamma*max_next_Q
    loss = MSE(curr_Qy, expected_Q)
    
    '''
    print('STATE', states)
    print('ACTIONS', actions)
    print('REWARDS', rewards)
    print('NEXT STATES', next_states)
    print('DONES', dones)
    print('CURR_Q', curr_Q)
    print('CURR_Q AFTER ACTIONS', curr_Qy)
    print('NEXT_Q', next_Q)
    print('EXPECTED Q', expected_Q)
    print('LOSS', loss)
    #'''
    
    return loss

def learn(times, batch_size):
    total_loss = 0
    for i in range(times):
        optimizer.zero_grad()
        sample = memory.sample(batch_size)
        loss = lossf(sample)
        total_loss += loss
        loss.backward()
        optimizer.step()
    print('LOSS {}'.format(total_loss))
    
def train_bot(games_per_epoch, batch_size):
    while True:
        play(games_per_epoch)
        learn(100, batch_size)
        memory.truncate()

In [5]:
train_bot(100, 2)

LOSS 1000005.5
LOSS 999381.125
LOSS 998806.75
LOSS 998668.3125
LOSS 998486.125
LOSS 998356.5
LOSS 998444.375
LOSS 998219.625
LOSS 998034.375
LOSS 997824.0625
LOSS 997674.875
LOSS 997495.6875
LOSS 997436.625
LOSS 997223.1875
LOSS 997048.1875
LOSS 996804.8125
LOSS 996595.5
LOSS 996385.6875
LOSS 996135.5625
LOSS 996025.0625
LOSS 995532.25
LOSS 995229.8125
LOSS 994770.125
LOSS 994300.5
LOSS 994052.6875
LOSS 993520.0
LOSS 992839.625
LOSS 992189.125
LOSS 991803.0625
LOSS 990952.875
LOSS 990254.6875
LOSS 989480.75
LOSS 988549.4375
LOSS 987201.3125
LOSS 986078.6875
LOSS 984996.8125
LOSS 984033.3125
LOSS 981917.75
LOSS 979698.125
LOSS 979235.8125
LOSS 977270.25
LOSS 973899.4375
LOSS 972735.0625
LOSS 969844.5625
LOSS 967786.0
LOSS 964043.75
LOSS 958649.5
LOSS 956674.3125
LOSS 958382.5625
LOSS 952482.5
LOSS 950182.0625
LOSS 945973.0
LOSS 948079.8125
LOSS 945931.375
LOSS 939537.75
LOSS 928968.25
LOSS 931281.25
LOSS 931081.6875
LOSS 909996.25
LOSS 907249.125
LOSS 897319.5625
LOSS 902729.1875
LOSS 8

KeyboardInterrupt: 

In [6]:
play(10, True)

tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
tensor([0, 1], dtype=torch.int32) tensor([11.8043, 17.9294], grad_fn=<AddBackward0>)
tensor([0, 1], dtype=torch.int32) tensor([11.8043, 17.9294], grad_fn=<AddBackward0>)
tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
tensor([0, 1], dtype=torch.int32) tensor([11.8043, 17.9294], grad_fn=<AddBackward0>)
tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
tensor([1, 0], dtype=torch.int32) tensor([27.4201,  7.4293], grad_fn=<AddBackward0>)
