In [1]:
import gym
import numpy as np
import sys
import math
import ipdb
from six import StringIO
from gym import spaces
from gym.utils import seeding
import random
from copy import deepcopy
from collections import namedtuple
from numba import jitclass, int64, float32, float64, bool_

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import optim

In [2]:
def weighted_mse_loss(input, target, weight):
    return torch.sum(weight * (input - target) ** 2)


def optimize_model(
    dqn_net,
    target_net,
    memory,
    learning_rate,
    batch_size,
    size_board,
    gamma,
    optimizer,
    device,
):

    tree_indexes, memory_batch, batch_ISWeights = memory.sample(batch_size)

    samples = Transition(*zip(*memory_batch))

    states_batch = samples.state
    actions_batch = samples.action
    rewards_batch = samples.reward
    next_states_batch = samples.next_state
    dones_batch = samples.done

    target_qs_batch = []

    torch_next_states_batch = (
        torch.from_numpy(np.asarray(next_states_batch)).float().to(device)
    )

    # Get Q values for next state
    q_next_state = dqn_net(torch_next_states_batch, batch_size, size_board)

    q_target_next_state = (
        target_net(torch_next_states_batch, batch_size, size_board).cpu().detach()
    )

    for i in range(0, len(memory_batch)):
        terminal = dones_batch[i]

        action = np.argmax(q_next_state[i].cpu().detach().numpy())

        if terminal:
            target_qs_batch.append(rewards_batch[i])
        else:
            target = rewards_batch[i] + gamma * q_target_next_state[i][action]
            target_qs_batch.append(target)

    targets_batch = np.array([each for each in target_qs_batch])

    torch_states_batch = torch.from_numpy(np.asarray(states_batch)).float().to(device)

    output = dqn_net(torch_states_batch, batch_size, size_board)

    torch_actions_batch = torch.from_numpy(np.asarray(actions_batch))
    torch_actions_batch = torch_actions_batch.unsqueeze(0)
    torch_actions_batch = torch_actions_batch.view(batch_size, 1)

    q_values = output.gather(1, torch_actions_batch.to(device))
    q_values = q_values.float()

    absolute_errors = (
        torch.abs(
            q_values
            - torch.from_numpy(targets_batch).view(batch_size, 1).float().to(device)
        )
        .cpu()
        .detach()
        .numpy()
    )

    torch_batch_ISWeights = torch.from_numpy(batch_ISWeights).to(device)

    diff_target = q_values - torch.from_numpy(targets_batch).view(
        batch_size, 1
    ).float().to(device)
    squared_diff = diff_target ** 2
    weighted_squared_diff = squared_diff * torch_batch_ISWeights

    loss = torch.mean(weighted_squared_diff)

    optimizer.zero_grad()

    loss.backward()
    optimizer.step()

    absolute_errors = np.squeeze(absolute_errors, 1)

    memory.batch_update(tree_indexes, absolute_errors)

    return loss.cpu().detach().numpy()


def pre_train(env, pre_train_len, memory):
    print("Starting pretrain...")
    board, valid_movements = env.reset()
    state = to_power_two_matrix(board)

    eps_threshold = 1

    for i in range(pre_train_len):
        action = selection_action(
            eps_threshold, valid_movements, None, None, None, None
        )

        new_board, reward, done, info = env.step(action)

        if done:
            next_state = np.zeros(state.shape)
            memory.store(state, action, reward, next_state, done)

            board, valid_movements = env.reset()

        else:
            next_state = to_power_two_matrix(new_board)
            memory.store(state, action, reward, next_state, done)

            state = next_state

            valid_movements = info["valid_movements"]


def train(
    dqn_net,
    target_net,
    env,
    memory,
    batch_size,
    size_board,
    episodes,
    ep_update_target,
    decay_rate,
    explore_start,
    explore_stop,
    learning_rate,
    gamma,
    interval_mean,
):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dqn_net.to(device)
    target_net.to(device)

    print("Starting training...")
    decay_step = 0

    total_steps_per_episode = []
    total_rewards_per_episode = []
    total_loss_per_episode = []
    total_score_per_episode = []

    best_board = None
    best_reward = 0
    best_score = 0
    best_steps = 0
    best_ep = -1

    optimizer = optim.RMSprop(dqn_net.parameters(), lr=learning_rate)

    for ep in range(episodes):
        step = 0
        episode_rewards = []
        board, valid_movements = env.reset()
        state = to_power_two_matrix(board)
        done = False
        loss_ep = []

        while True:
            step += 1

            decay_step += 1

            eps_threshold = explore_stop + (explore_start - explore_stop) * np.exp(
                -decay_rate * decay_step
            )
            action = selection_action(
                eps_threshold, valid_movements, dqn_net, state, size_board, device
            )
            new_board, reward, done, info = env.step(action)

            episode_rewards.append(reward)

            if done:
                total_steps_per_episode.append(step)

                next_state = np.zeros((1, size_board, size_board, 16))

                total_reward = np.sum(episode_rewards)

                total_rewards_per_episode.append(total_reward)

                memory.store(state, action, reward, next_state, done)

                loss_total_ep = np.sum(loss_ep) / step
                total_loss_per_episode.append(loss_total_ep)

                total_score_per_episode.append(info["total_score"])

                print("Episode:", ep)
                print("Total Reward:", total_reward)
                print("Total episodes", step)
                print("Eps_threshold:", eps_threshold)
                print("Loss ep:", loss_total_ep)
                env.render()
                print("---------------------------")

                if info["total_score"] > best_score:
                    best_score = info["total_score"]
                    best_reward = total_reward
                    best_ep = ep
                    best_board = deepcopy(new_board)
                    best_steps = step

            else:
                next_state = to_power_two_matrix(new_board)

                memory.store(state, action, reward, next_state, done)

                state = deepcopy(next_state)

                valid_movements = info["valid_movements"]

                board = deepcopy(new_board)

            loss = optimize_model(
                dqn_net,
                target_net,
                memory,
                learning_rate,
                batch_size,
                size_board,
                gamma,
                optimizer,
                device,
            )

            loss_ep.append(loss)

            if done:
                break

        if ep % ep_update_target == 0:
            target_net = deepcopy(dqn_net)

    print("***********************")
    print("Best ep", best_ep)
    print("Best Board:")
    print(best_board)
    print("Best step", best_steps)
    print("Best score", best_score)
    print("***********************")

    plot_info(
        total_steps_per_episode,
        total_rewards_per_episode,
        total_loss_per_episode,
        total_score_per_episode,
        interval_mean,
        episodes,
    )

In [3]:
!pip install ipdb --user



In [4]:
spec = [
    ("__size_board", int64),
    ("__seed", int64),
    ("__board", float64[:, :]),
    ("__total_score", int64),
    ("__merged", int64),
    ("__scores_move", int64),
    ("__temp_board", float64[:, :]),
    ("__valid_movements", float64[:]),
]


@jitclass(spec)
class Game2048:

    def __init__(self, size_board, seed):

        self.__size_board = size_board
        self.__seed = seed

        self.__board = self.__init_board()
        self.__total_score = 0
        self.__merged = 0
        self.__scores_move = 0
        self.__temp_board = np.zeros((size_board, size_board))

        self.__valid_movements = np.zeros(4)

        if self.__seed:
            random.seed(self.__seed)

        self.__add_two_or_four()
        self.__add_two_or_four()

    def __init_board(self):
        return np.zeros((self.__size_board, self.__size_board))

    def __get_empty_spaces_index(self):
        return np.where(self.__board == 0)

    def __add_two_or_four(self):
        indexes = self.__get_empty_spaces_index()

        index = np.random.choice(np.arange(len(indexes[0])))

        sample = np.random.rand(1)

        if sample[0] >= 0.9:
            self.__board[indexes[0][index]][indexes[1][index]] = 4
        else:
            self.__board[indexes[0][index]][indexes[1][index]] = 2

    def __reverse_array(self, array):
        temp_array = np.zeros(len(array))
        for cell in range(len(array)):
            temp_array[cell] = array[len(array) - cell - 1]

        return temp_array

    def __merge(self, array, reverse):
        array = array[array != 0]

        temp_array = np.zeros(self.__size_board)
        if reverse:
            count_index = self.__size_board - 1
        else:
            count_index = 0
        i = 0
        while True:
            if i >= (len(array) - 1):
                if i == (len(array) - 1):
                    temp_array[count_index] = array[i]
                return temp_array

            if (array[i] == array[i + 1]) and array[i] != 0:
                temp_array[count_index] = array[i] + array[i + 1]
                self.__scores_move += temp_array[count_index]
                self.__merged += 1
                i = i + 2
                if reverse:
                    count_index -= 1
                else:
                    count_index += 1
            else:
                if array[i] != 0:
                    temp_array[count_index] = array[i]
                    if reverse:
                        count_index -= 1
                    else:
                        count_index += 1
                i = i + 1

    def __up(self):
        self.__temp_board = np.zeros((self.__size_board, self.__size_board))
        for column in range(self.__size_board):
            self.__temp_board[:, column] = self.__merge(
                self.__board[:, column].copy(), False
            )

    def __down(self):
        self.__temp_board = np.zeros((self.__size_board, self.__size_board))
        for column in range(self.__size_board):
            self.__temp_board[:, column] = self.__merge(
                self.__reverse_array(self.__board[:, column].copy()), True
            )

    def __right(self):
        self.__temp_board = np.zeros((self.__size_board, self.__size_board))
        for line in range(self.__size_board):
            self.__temp_board[line, :] = self.__merge(
                self.__reverse_array(self.__board[line, :].copy()), True
            )

    def __left(self):
        self.__temp_board = np.zeros((self.__size_board, self.__size_board))
        for line in range(self.__size_board):
            self.__temp_board[line, :] = self.__merge(
                self.__board[line, :].copy(), False
            )

    def __array_equal(self, a, b):
        for value_a, value_b in zip(a.flat, b.flat):
            if value_a != value_b:
                return False
        return True

    def __check_available_moves(self):
        self.__valid_movements = np.zeros(4)
        for i in range(4):
            self.make_move(i)
            if self.__array_equal(self.__board, self.__temp_board) is False:
                self.__valid_movements[i] = 1

    def make_move(self, move):
        self.__merged = 0
        self.__scores_move = 0
        if move == 0:
            self.__up()
        if move == 1:
            self.__down()
        if move == 2:
            self.__right()
        if move == 3:
            self.__left()

    def confirm_move(self):
        self.__board = self.__temp_board.copy()
        self.__total_score += self.__scores_move
        returned_move_scores = self.__scores_move
        returned_merged = self.__merged
        self.__add_two_or_four()
        self.__check_available_moves()

        return returned_move_scores, returned_merged, self.__valid_movements

    def get_board(self):
        return self.__board

    def get_total_score(self):
        return self.__total_score

    def reset(self):
        self.__board = self.__init_board()
        self.__total_score = 0
        self.__add_two_or_four()
        self.__add_two_or_four()
        return self.get_board()

In [5]:
class InvalidMove(Exception):
    pass


class Game2048Env(gym.Env):
    metadata = {"render.modes": ["human", "ansi"]}

    def __init__(self, size_board, seed=None):
        self.__size_board = size_board
        self.__game = Game2048(size_board, seed)

        self.action_space = spaces.Discrete(4)

        self.observation_space = spaces.Box(
            0, 2 ** 16, (size_board * size_board,), dtype=np.int
        )

        self.reward_range = (0., np.inf)

        self.np_random, seed = seeding.np_random(seed)

        self.__actions_legends = {0: "UP", 1: "DOWN", 2: "RIGHT", 3: "LEFT"}

        self.__old_max = 0

        self.__last_action = None
        self.__last_scores_move = None

        print("Environment initialised...")

    def __reward_calculation(self, merged):
        reward = 0
        max_board = self.__game.get_board().max()
        if max_board > self.__old_max:
            self.__old_max = max_board
            reward += math.log(self.__old_max, 2) * 0.1

        reward += merged

        return reward

    def reset(self):
        self.__game.reset()
        valid_movements = np.ones(4)
        return (self.__game.get_board(), valid_movements)

    def step(self, action):
        done = False
        reward = 0
        try:
            self.__last_action = self.__actions_legends[action]

            self.__game.make_move(action)
            returned_move_scores, returned_merged, valid_movements = (
                self.__game.confirm_move()
            )

            reward = self.__reward_calculation(returned_merged)

            if len(np.nonzero(valid_movements)[0]) == 0:
                done = True

            self.__last_scores_move = returned_move_scores

            info = dict()
            info["valid_movements"] = valid_movements
            info["total_score"] = self.__game.get_total_score()
            info["last_action"] = self.__actions_legends[action]
            info["scores_move"] = returned_move_scores
            return self.__game.get_board(), reward, done, info

        except InvalidMove as e:
            print("Invalid move")
            done = False
            reward = 0

    def render(self, mode="human"):
        outfile = StringIO() if mode == "ansi" else sys.stdout
        info_render = "Score: {}\n".format(self.__game.get_total_score())
        info_render += "Highest: {}\n".format(self.__game.get_board().max())
        npa = np.array(self.__game.get_board())
        grid = npa.reshape((self.__size_board, self.__size_board))
        info_render += "{}\n".format(grid)
        info_render += "Last action: {}\n".format(self.__last_action)
        info_render += "Last scores move: {}".format(self.__last_scores_move)
        info_render += "\n"
        outfile.write(info_render)
        return outfile

    def get_actions_legends(self):
        return self.__actions_legends

In [6]:
spec_sum_tree = [
    ("__capacity", int64),
    ("__data_pointer", int64),
    ("__tree", float64[:]),
    ("__state", float64[:, :, :, :, :]),
    ("__action", int64[:]),
    ("__reward", float64[:]),
    ("__next_state", float64[:, :, :, :, :]),
    ("__done", bool_[:]),
]


@jitclass(spec_sum_tree)
class SumTree:
    def __init__(self, capacity, size_board=4):
        self.__data_pointer = 0
        self.__capacity = capacity
        self.__tree = np.zeros(2 * capacity - 1)

        self.__state = np.zeros((capacity, 1, size_board, size_board, 16))
        self.__action = np.zeros(capacity, dtype=np.int64)
        self.__reward = np.zeros(capacity)
        self.__next_state = np.zeros((capacity, 1, size_board, size_board, 16))
        self.__done = np.zeros(capacity, dtype=np.bool_)

    def update(self, tree_index, priority):
        change = priority - self.__tree[tree_index]
        self.__tree[tree_index] = priority
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            self.__tree[tree_index] += change

    def add(self, priority, state, action, reward, next_state, done):
        self.__state[self.__data_pointer] = state
        self.__action[self.__data_pointer] = action
        self.__reward[self.__data_pointer] = reward
        self.__next_state[self.__data_pointer] = next_state
        self.__done[self.__data_pointer] = done

        tree_index = self.__data_pointer + self.__capacity - 1
        self.update(tree_index, priority)

        self.__data_pointer += 1

        if self.__data_pointer >= self.__capacity:
            self.__data_pointer = 0

    def get_leaf(self, value):
        parent_index = 0

        while True:
            left_child_index = 2 * parent_index + 1
            right_child_index = left_child_index + 1
            if left_child_index >= len(self.__tree):
                leaf_index = parent_index
                break

            else: 
                if value <= self.__tree[left_child_index]:
                    parent_index = left_child_index
                else:
                    value -= self.__tree[left_child_index]
                    parent_index = right_child_index

        data_index = leaf_index - self.__capacity + 1

        return (
            leaf_index,
            self.__tree[leaf_index],
            self.__state[data_index],
            self.__action[data_index],
            self.__reward[data_index],
            self.__next_state[data_index],
            self.__done[data_index],
        )

    def total_priority(self):
        return self.__tree[0]

    def get_priotiry(self):
        return self.__tree[-self.__capacity:]

    def get_all_tree(self):
        return self.__tree


spec_memory = [
    ("__per_e", float64),
    ("__per_a", float64),
    ("__per_b", float64),
    ("__per_b_increment_per_sampling", float64),
    ("__absolute_error_uper", float64),
    ("__tree", SumTree),
]


Transition = namedtuple(
    "Transition", ("state", "action", "reward", "next_state", "done")
)


class Memory:
    def __init__(self, size_board, capacity):
        self.__capacity = capacity
        self.__per_e = 0.01
        self.__per_a = 0.6
        self.__per_b = 0.4
        self.__per_b_increment_per_sampling = 0.001
        self.__absolute_error_upper = 1.
        self.__tree = SumTree(capacity, size_board)
    spec_store = [("max_priority", float64)]

    def store(self, state, action, reward, next_state, done):
        max_priority = np.max(self.__tree.get_priotiry())
        if max_priority == 0:
            max_priority = self.__absolute_error_upper
        self.__tree.add(max_priority, state, action, reward, next_state, done)

    def sample(self, batch_size):
        memory_batch = []

        batch_idx, batch_ISWeights = (
            np.empty((batch_size,), dtype=np.int32),
            np.empty((batch_size, 1), dtype=np.float32),
        )
        priority_segment = self.__tree.total_priority() / batch_size
        self.__per_b = np.min(
            [1., self.__per_b + self.__per_b_increment_per_sampling]
        )
        p_min = np.min(self.__tree.get_priotiry()) / self.__tree.total_priority()
        max_weight = (p_min * batch_size) ** (-self.__per_b)

        for i in range(batch_size):
            limit_a, limit_b = priority_segment * i, priority_segment * (i + 1)
            value = np.random.uniform(limit_a, limit_b)

            index, priority, state, action, reward, next_state, done = self.__tree.get_leaf(
                value
            )

            sampling_probabilities = priority / self.__tree.total_priority()

            batch_ISWeights[i, 0] = (
                np.power(batch_size * sampling_probabilities, -self.__per_b)
                / max_weight
            )

            batch_idx[i] = index
            memory_batch.append(Transition(state, action, reward, next_state, done))

        return batch_idx, memory_batch, batch_ISWeights

    def batch_update(self, tree_indexes, abs_errors):
        abs_errors += self.__per_e
        clipped_errors = np.minimum(abs_errors, self.__absolute_error_upper)
        priorities = np.power(clipped_errors, self.__per_a)

        for tree_index, priority in zip(tree_indexes, priorities):
            self.__tree.update(tree_index, priority)

In [7]:
import argparse
import sys
import matplotlib.pyplot as plt
from numba import jit

In [8]:
@jit(nopython=True)
def to_power_two_matrix(matrix):
    power_matrix = np.zeros(
        shape=(1, matrix.shape[0], matrix.shape[1], 16), dtype=np.float32
    )
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if matrix[i][j] == 0:
                power_matrix[0][i][j][0] = 1.0
            else:
                power = int(np.log(matrix[i][j]) / np.log(2))
                power_matrix[0][i][j][power] = 1.0

    return power_matrix


def selection_action(
    eps_threshold, valid_movements, dqn_net, state, size_board, device
):
    sample = np.random.rand(1)

    if sample > eps_threshold:
        with torch.no_grad():
            output = dqn_net(torch.from_numpy(state).float().to(device), 1, size_board)
            output_np = output.cpu().detach().numpy()
            ordered = np.flip(np.argsort(output_np), axis=1)[0]
            for x in ordered:
                if valid_movements[x] != 0:
                    return x

    else:
        return np.random.choice(np.nonzero(valid_movements)[0])


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--capacity", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--size_board", type=int, default=4)
    parser.add_argument("--num_episodes", type=int, default=1000)
    parser.add_argument("--learning_rate", type=float, default=0.00025)
    parser.add_argument("--ep_update_target", type=int, default=10)
    parser.add_argument("--decay_rate", type=float, default=0.00005)
    parser.add_argument("--interval_mean", type=int, default=5)

    args = parser.parse_args()

    return args


def get_mean_interval(array, interval_mean):
    interval_mean_list = []
    for x in range(interval_mean):
        interval_mean_list.append(0)

    for i in range(len(array)):
        if i + interval_mean == len(array):
            break
        else:
            interval_mean_list.append(np.mean(array[i: interval_mean + i]))

    return interval_mean_list


def plot_info(
    total_steps_per_episode,
    total_rewards_per_episode,
    total_loss_per_episode,
    total_score_per_episode,
    interval_mean,
    episodes,
):

    interval_steps = get_mean_interval(total_steps_per_episode, interval_mean)
    plt.plot(range(episodes), total_steps_per_episode)
    plt.plot(range(episodes), interval_steps)
    plt.ylabel("Episode durations")
    plt.xlabel("Episodes")
    plt.savefig("episodes_durations.png", bbox_inches="tight")
    plt.close()

    interval_rewards = get_mean_interval(total_rewards_per_episode, interval_mean)
    plt.plot(range(episodes), total_rewards_per_episode)
    plt.plot(range(episodes), interval_rewards)
    plt.ylabel("Reward")
    plt.xlabel("Episodes")
    plt.savefig("episodes_rewards.png", bbox_inches="tight")
    plt.close()

    interval_score = get_mean_interval(total_score_per_episode, interval_mean)
    plt.plot(range(episodes), total_score_per_episode)
    plt.plot(range(episodes), interval_score)
    plt.ylabel("Score")
    plt.xlabel("Episodes")
    plt.savefig("episodes_scores.png", bbox_inches="tight")
    plt.close()

    interval_loss = get_mean_interval(total_loss_per_episode, interval_mean)
    plt.plot(range(episodes), total_loss_per_episode)
    plt.plot(range(episodes), interval_loss)
    plt.ylabel("Loss")
    plt.xlabel("Episodes")
    plt.savefig("episodes_losses.png", bbox_inches="tight")
    plt.close()

In [9]:
class CNN_2048_MODEL(nn.Module):
    def __init__(self, c_in_1, c_in_2, c_out_1, c_out_2):
        super(CNN_2048_MODEL, self).__init__()
        self.__c_in_1 = c_in_1
        self.__c_in_2 = c_in_2
        self.__c_out_1 = c_out_1
        self.__c_out_2 = c_out_2

        self.__expanded_size = (
            2 * 4 * c_out_2 * 2 + 3 * 3 * c_out_2 * 2 + 4 * 3 * c_out_1 * 2
        )

        self.__dense_value_1 = nn.Linear(self.__expanded_size, 256)
        self.__dense_value_2 = nn.Linear(256, 1)
        self.__dense_advantage_1 = nn.Linear(self.__expanded_size, 256)
        self.__dense_advantage_2 = nn.Linear(256, 4)

        self.__cnn_1 = nn.Conv2d(
            c_in_1,
            c_out_1,
            kernel_size=(1, 2),
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
        )

        self.__cnn_1_2 = nn.Conv2d(
            c_out_1,
            c_out_2,
            kernel_size=(1, 2),
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
        )

        self.__cnn_2 = nn.Conv2d(
            c_in_2,
            c_out_2,
            kernel_size=(2, 1),
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
        )

        self.__cnn_2_2 = nn.Conv2d(
            c_out_1,
            c_out_2,
            kernel_size=(2, 1),
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
        )

    def forward(self, features, batch_size, size_board):
        features_view = features.view(batch_size, 16, size_board, size_board)
        conv1_output = F.elu(self.__cnn_1(features_view))
        conv2_output = F.elu(self.__cnn_2(features_view))
        conv1_2_1_output = F.elu(self.__cnn_1_2(conv1_output))
        conv1_2_2_output = F.elu(self.__cnn_1_2(conv2_output))
        conv2_2_1_output = F.elu(self.__cnn_2_2(conv1_output))
        conv2_2_2_output = F.elu(self.__cnn_2_2(conv2_output))

        conv1_output_shape = list(conv1_output.shape)
        conv2_output_shape = list(conv2_output.shape)
        conv1_2_1_output_shape = list(conv1_2_1_output.shape)
        conv1_2_2_output_shape = list(conv1_2_2_output.shape)
        conv2_2_1_output_shape = list(conv2_2_1_output.shape)
        conv2_2_2_output_shape = list(conv2_2_2_output.shape)

        hidden1 = conv1_output.view(
            batch_size,
            (conv1_output_shape[1] * conv1_output_shape[2] * conv1_output_shape[3]),
        )

        hidden2 = conv2_output.view(
            batch_size,
            (conv2_output_shape[1] * conv2_output_shape[2] * conv2_output_shape[3]),
        )

        hidden1_2_1 = conv1_2_1_output.view(
            batch_size,
            (
                conv1_2_1_output_shape[1]
                * conv1_2_1_output_shape[2]
                * conv1_2_1_output_shape[3]
            ),
        )

        hidden1_2_2 = conv1_2_2_output.view(
            batch_size,
            (
                conv1_2_2_output_shape[1]
                * conv1_2_2_output_shape[2]
                * conv1_2_2_output_shape[3]
            ),
        )

        hidden2_2_1 = conv2_2_1_output.view(
            batch_size,
            (
                conv2_2_1_output_shape[1]
                * conv2_2_1_output_shape[2]
                * conv2_2_1_output_shape[3]
            ),
        )

        hidden2_2_2 = conv2_2_2_output.view(
            batch_size,
            (
                conv2_2_2_output_shape[1]
                * conv2_2_2_output_shape[2]
                * conv2_2_2_output_shape[3]
            ),
        )

        hidden = torch.cat(
            (hidden1, hidden2, hidden1_2_1, hidden1_2_2, hidden2_2_1, hidden2_2_2), 1
        )

        hidden_value_1 = F.elu(self.__dense_value_1(hidden))
        hidden_value_2 = self.__dense_value_2(hidden_value_1)

        advantage_action_1 = F.elu(self.__dense_advantage_1(hidden))
        advantage_action_2 = self.__dense_advantage_2(advantage_action_1)

        reduced_mean = torch.mean(advantage_action_2, dim=1, keepdim=True)
        output = hidden_value_2 + (advantage_action_2 - reduced_mean)
        return output

In [10]:
import time

In [None]:
!pip install numba==0.47.0 --user

In [12]:
def main():
    seed = 10
    capacity = 100000
    size_board = 4
    batch_size = 128
    episodes = 1000
    ep_update_target = 100
    learning_rate = 0.00025
    decay_rate = 0.00005
    interval_mean = 5
    explore_start = 1.
    explore_stop = 0.01
    gamma = 0.95

    env = Game2048Env(size_board, seed)

    memory = Memory(size_board, capacity)

    c_in_1 = c_in_2 = size_board * size_board
    c_out_1 = c_out_2 = 128
    dqn_net = CNN_2048_MODEL(c_in_1, c_in_2, c_out_1, c_out_2)
    target_net = deepcopy(dqn_net)

    start = time.time()
    pre_train(env, capacity, memory)
    print("Execution pre-train (in seconds):", time.time() - start)

    start = time.time()
    train(
        dqn_net,
        target_net,
        env,
        memory,
        batch_size,
        size_board,
        episodes,
        ep_update_target,
        decay_rate,
        explore_start,
        explore_stop,
        learning_rate,
        gamma,
        interval_mean,
    )
    print("Execution train (in seconds)", time.time() - start)

In [None]:
main()

Environment initialised...
Starting pretrain...
Execution pre-train (in seconds): 18.706619262695312
Starting training...
Episode: 0
Total Reward: 108
Total episodes 122
Eps_threshold: 0.9939793815551794
Loss ep: 0.40106879687700114
Score: 1012
Highest: 64.0
[[ 2. 16. 32.  4.]
 [ 8. 64.  8. 64.]
 [ 4. 16.  2.  8.]
 [ 2.  4. 32.  2.]]
Last action: UP
Last scores move: 0
---------------------------
Update target_net
Episode: 1
Total Reward: 171
Total episodes 185
Eps_threshold: 0.9849195386477164
Loss ep: 0.11639946602486276
Score: 2200
Highest: 256.0
[[  2.   4.   8.   2.]
 [  4.  64.  16.   4.]
 [  8.   4. 256.   8.]
 [  4.   8.   4.   2.]]
Last action: DOWN
Last scores move: 4
---------------------------
Episode: 2
Total Reward: 73
Total episodes 87
Eps_threshold: 0.9806878492518778
Loss ep: 0.0816317700791633
Score: 632
Highest: 64.0
[[ 2.  8.  4.  2.]
 [ 4. 16.  8.  4.]
 [16. 64. 32.  8.]
 [ 8.  2. 16.  4.]]
Last action: DOWN
Last scores move: 4
---------------------------
Episode: 