In [47]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distributions

import matplotlib.pyplot as plt
import tqdm


In [48]:
'''
Implementing the backgammon game

White pieces are represented by 1 and black pieces by -1

White starts the game

'''

class Board():
    def __init__(self):
        self.setup()
        self.turn = 0
        self.startpool = [0,0]

    def print_board(self):
        print(self.board)

    def setup(self):
        self.board = [2,0,0,0,0,-5,0,-3,0,0,0,5,-5,0,0,0,3,0,5,0,0,0,0,-2]

    def bear_off(self):
        if self.turn % 2 == 0:
            if self.startpool[0] == 0:
                bull_count = 0
                for i in self.board[:-6]:
                    bull_count += max([0,i])
                if bull_count == 0:
                    return True
        elif self.turn % 2 == 1:
            if self.startpool[1] == 0:
                bull_count = 0
                for i in self.board[6:]:
                    bull_count += min([0,i])
                if bull_count == 0:
                    return True
        else:
            return False


    def move_(self, start, roll):
        if start+roll > len(self.board):
            if not self.bear_off():
                return False
        if start > (-1)**self.turn*(start+roll):
            print("You can only move pieces forward")
            return False
        elif np.sign(self.board[start]) != (-1)**self.turn:
            print("You can only move pieces of your own colour")
            return False
        elif np.sign(self.board[(start+roll)]) == np.sign(self.board[start]):
            self.board[(start+roll)] += np.sign(self.board[start])
            self.board[start] -= np.sign(self.board[start])
            return True

        elif np.abs(self.board[(start+roll)]) > 1:
            print("You can't move to a space occupied by two or more opposing pieces")
            return False

        else:
            self.startpool[self.turn % 2 + 1] += self.board[(start+roll)]
            self.board[(start+roll)] = 0
            self.board[(start+roll)] += np.sign(self.board[start])
            self.board[start] -= np.sign(self.board[start])

        return True


    def roll(self):
        roll_1 = np.random.randint(1, 7)
        roll_2 = np.random.randint(1, 7)

        if roll_1 == roll_2:
            return [roll_1, roll_2, roll_1, roll_2]
        else:
            return [roll_1, roll_2, 0, 0]

    def get_state(self):
        return self.board + self.startpool  + self.roll()

    def take_turn(self, agent):
        state = self.get_state()
        for move in  agent.take_turn(state):
            self.move_(move[0], move[1])
            self.check_victory()
        return True


    def check_victory(self):
        if self.turn % 2 == 0:
            if self.startpool[0] == 0:
                bull_count = 0
                for i in self.board:
                    bull_count += max([0, i])
                if bull_count == 0:
                    return True
        elif self.turn % 2 == 1:
            if self.startpool[1] == 0:
                bull_count = 0
                for i in self.board:
                    bull_count += min([0,i])
                if bull_count == 0:
                    return True
        else:
            return False  


In [49]:
class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc_1(x)
        x = F.relu(x)
        x = self.fc_2(x)
        return x


class ActorCritic(nn.Module):
    def __init__(self, actor, critic):
        super().__init__()

        self.actor = actor
        self.critic = critic

    def forward(self, state):

        action_pred = self.actor(state)
        value_pred = self.critic(state)

        return action_pred, value_pred


In [50]:
board = Board()

input_dim = np.shape(board.get_state())
hidden_dim = 32
output_dim = 8


In [51]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)


In [52]:
def train(env, policy, optimizer, discount_factor, device):

    policy.train()

    log_prob_actions = []
    entropies = []
    value_preds = []
    rewards = []
    done = False
    episode_reward = 0

    state = env.reset()

    while not done:

        state = torch.FloatTensor(state).unsqueeze(0).to(device)

        action_pred, value_pred = policy(state)

        action_prob = F.softmax(action_pred, dim=-1)

        dist = distributions.Categorical(action_prob)

        action = dist.sample()
        log_prob_action = dist.log_prob(action)

        entropy = dist.entropy()

        state, reward, done, _ = env.step(action.item())

        log_prob_actions.append(log_prob_action)
        entropies.append(entropy)
        value_preds.append(value_pred.squeeze(0))
        rewards.append(reward)

        episode_reward += reward

    log_prob_actions = torch.cat(log_prob_actions)
    entropies = torch.cat(entropies)
    value_preds = torch.cat(value_preds)

    returns = calculate_returns(rewards, discount_factor, device)
    advantages = calculate_advantages(returns, value_preds)

    loss = update_policy(advantages, log_prob_actions,
                         returns, value_preds, entropies, optimizer)

    return loss, episode_reward


def calculate_returns(rewards, discount_factor, device, normalize=True):

    returns = []
    R = 0

    for r in reversed(rewards):
        R = r + R * discount_factor
        returns.insert(0, R)

    returns = torch.tensor(returns).to(device)

    if normalize:
        returns = (returns - returns.mean()) / returns.std()

    return returns


def calculate_advantages(returns, pred_values, normalize=True):

    advantages = returns - pred_values

    if normalize:

        advantages = (advantages - advantages.mean()) / advantages.std()

    return advantages


def update_policy(advantages, log_prob_actions, returns, value_preds, entropies, optimizer):

    returns = returns.detach()

    policy_loss = -(advantages * log_prob_actions).mean()
    value_loss = F.smooth_l1_loss(returns, value_preds)

    optimizer.zero_grad()

    loss = policy_loss + value_loss * 0.5 - entropies.mean() * 0.01

    loss.backward()

    optimizer.step()

    return loss.item()


def evaluate(env, policy, device):

    policy.eval()

    done = False
    episode_reward = 0

    state = env.reset()

    while not done:

        state = torch.FloatTensor(state).unsqueeze(0).to(device)

        with torch.no_grad():

            action_pred, _ = policy(state)

            action_prob = F.softmax(action_pred, dim=-1)

        action = torch.argmax(action_prob, dim=-1)

        state, reward, done, _ = env.step(action.item())

        episode_reward += reward

    return episode_reward


In [53]:
n_runs = 5
max_episodes = 300
discount_factor = 0.99

train_rewards = torch.zeros(n_runs, max_episodes)
test_rewards = torch.zeros(n_runs, max_episodes)
device = torch.device('cpu')

for run in range(n_runs):

    actor = MLP(input_dim, hidden_dim, output_dim)
    critic = MLP(input_dim, hidden_dim, 1)
    actor_critic = ActorCritic(actor, critic)
    actor_critic = actor_critic.to(device)
    actor_critic.apply(init_weights)
    optimizer = optim.Adam(actor_critic.parameters(), lr=1e-2)

    for episode in tqdm.tqdm(range(max_episodes), desc=f'Run: {run}'):

        loss, train_reward = train(
            train_env, actor_critic, optimizer, discount_factor, device)

        test_reward = evaluate(test_env, actor_critic, device)

        train_rewards[run][episode] = train_reward
        test_rewards[run][episode] = test_reward


idxs = range(max_episodes)
fig, ax = plt.subplots(1, figsize=(10, 6))
ax.plot(idxs, test_rewards.mean(0))
ax.fill_between(idxs, test_rewards.min(0).values,
                test_rewards.max(0).values, alpha=0.1)
ax.set_xlabel('Steps')
ax.set_ylabel('Rewards')

x = torch.randn(2, 10)
y = torch.randn(2, 10)
print(F.smooth_l1_loss(x, y))
print(F.mse_loss(x, y))


TypeError: empty(): argument 'size' must be tuple of ints, but found element of type tuple at pos 2