In [576]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from dataclasses import dataclass
import time

import importlib
import random

import tetris
importlib.reload(tetris);
from tetris import Tetris

from IPython import display
from IPython.display import clear_output

In [597]:
class Experiance:

    def __init__(self, state, new_state, action, reward, n_actions):
        self.state = state
        self.new_state = new_state
        self.action = action
        self.reward = reward
        self.done = None
        self.n_actions = n_actions

    def __str__(self):
        return f"Experiance ( {self.state}, {self.new_state}, {self.action}, {self.reward}, {n_actions}"

    def __repr__(self):
        return self.__str__()

In [598]:
class ReplayMemory:

    def __init__(self, capacity, batch_size):
        self.memory = []
        self.capacity = capacity
        self.batch_size = batch_size
        self.current_index = 0

    def push(self, exp):
        if self.current_index == self.capacity:
            self.current_index = 0

        if len(self.memory) == self.capacity:
            self.memory[self.current_index] = exp
        else:
            self.memory.append(exp)
        self.current_index += 1
        
    def sample(self):
        return [self.memory[i] for i in np.random.permutation(len(self.memory))[:self.batch_size]]

    def can_provide_sample(self):
        if len(self.memory) >= self.batch_size:
            return True
        return False
    
    @staticmethod
    def stack_memory(memory_list):
        actions = []
        states = []
        new_states = []
        rewards = []
        n_actions = []
        
        for exp in memory_list:
            states.append(exp.state)
            new_states.append(exp.new_state)
            actions.append(exp.action)
            rewards.append(exp.reward)
            n_actions.append(exp.n_actions)
            
        return (torch.stack(states), torch.stack(new_states), 
                torch.tensor(actions), torch.tensor(rewards), 
                torch.tensor(n_actions))

In [599]:
class DQN(nn.Module):

    def __init__(self, width, height):
        super().__init__()

        self.fc1 = nn.Linear(in_features=width*height, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=64)
        self.out = nn.Linear(in_features=64, out_features=9*4)

    def forward(self, t):
        # expected t shape (batch_size, width, height)
        t = t.reshape(t.shape[0], -1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

In [600]:
@dataclass(init=True, repr=True)
class Agent:
    
    lr : float
    discount_factor : float
    ep : float
    ep_decay : float
    ep_min : float

#     def __post_init__(self):

    def take_action(self, state, possible_actions):
        self.ep = max(self.ep_min, self.ep * (1-self.ep_decay))
        if self.ep > random.random():
            # explore
            return np.random.randint(len(possible_actions))
        else:
            # exploit
            action_pred = QValueEstimator.policy_predict(state.unsqueeze(0)).argmax().item()
            if action_pred < len(possible_actions):
                return action_pred
            return np.random.randint(len(possible_actions))

In [601]:
class QValueEstimator:
    
    policy_network = None
    target_network = None
    optimizer = None
    
    def __init__(self, width, height):
        QValueEstimator.policy_network = DQN(width, height)
        QValueEstimator.target_network = DQN(width, height)
        QValueEstimator.target_network.eval()
        
        QValueEstimator.update_target_network()
        
        QValueEstimator.optimizer = torch.optim.Adam(
            QValueEstimator.policy_network.parameters(),
            lr=0.01
        )
        
    @staticmethod
    def update_target_network():
        QValueEstimator.target_network.load_state_dict(QValueEstimator.policy_network.state_dict())
        
    @staticmethod
    def policy_predict(t):
        return QValueEstimator.policy_network(t)
    
    @staticmethod
    def update_policy_network(agent, memory_exps):
        states, new_states, actions, rewards, n_actions = ReplayMemory.stack_memory(memory_exps)
        next_state_pred = QValueEstimator.target_network(new_states).max(dim=1)[0]
        next_state_pred[n_actions] = 0
        optimal_actions = torch.tensor(rewards + agent.discount_factor * next_state_pred, requires_grad=True)
        pred_actions = torch.tensor([p[a] for a, p in zip(actions, QValueEstimator.policy_network(states))], requires_grad=True)
#         print(pred_actions, optimal_actions)
        QValueEstimator.optimizer.zero_grad()
#         loss = torch.sum((optimal_actions - pred_actions)**2)
        loss = F.mse_loss(pred_actions, optimal_actions)
        print(loss)
        loss.backward()
        QValueEstimator.optimizer.step()

In [602]:
class TetrisEnvManager:
    
    def __init__(self):
        self.tetris = Tetris()
        self.tetris.reset()
        
        self.action_space_n = 9 * 4
        
    def reset(self):
        self.tetris.reset()
        
    def step(self, action):
        return self.tetris.step(action, render=False)
    
    def get_possible_actions(self):
        return self.tetris.get_next_states().keys()
        
    def get_current_state(self):
        numpy_inp = np.array(self.tetris.get_current_board_state()).astype(np.float32).flatten()
        numpy_inp[numpy_inp != 0] = 1
        input_tensor = torch.tensor(numpy_inp, requires_grad=True)
        return input_tensor.reshape(self.height, self.width)
        
    def render(self):
        s = ""
        for row in self.tetris.get_current_board_state():
            for char in row:
                if char == 0:
                    s += " ."
                else:
                    s += " o"
            s += "\n"
        print(s)
        
    @property
    def width(self):
        return self.tetris.width
    
    @property
    def height(self):
        return self.tetris.height

In [603]:
tetris = Tetris()

In [604]:
tetris.width, 

(10,)

In [605]:
tetris_env = TetrisEnvManager()

In [606]:
score, done = tetris_env.step((0, 0))

In [607]:
QValueEstimator.policy_predict(tetris_env.get_current_state().unsqueeze(0)).argmax()

TypeError: 'NoneType' object is not callable

In [608]:
tetris_env.render()

 . . . . o o . . . .
 . . . . . o o . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . . . . . . . . . .
 . o . . . . . . . .
 o o o . . . . . . .



**Train loop.**

In [609]:
n_episodes = 10000
done = False
target_net_update = 30

agent = Agent(
    lr=0.001,
    discount_factor=0.99,
    ep=1,
    ep_decay=0.001,
    ep_min=0.2
)

QValueEstimator(tetris_env.width, tetris_env.height)

replay_memory = ReplayMemory(10000, 256)

all_rewards = []

for n in range(1, n_episodes+1):
    
    tetris_env.reset()
    
    total_reward = 0
    
    current_state = tetris_env.get_current_state()
    
    step = 1
    
    total_reward = 0
    
    while True:
        
        possilble_actions = list(tetris_env.get_possible_actions())
        
        action = agent.take_action(current_state, possilble_actions)
        
        reward, done = tetris_env.step(possilble_actions[action])
        
        next_state = tetris_env.get_current_state()

        clear_output(wait=True)
#         tetris_env.render()
#         plt.plot(all_rewards)
#         display.display(plt.show())
        print(f"***********Ep {n}***********")
        print("Action", possilble_actions[action])
        total_reward += reward
        print("Total reward", total_reward)
        print("Ep", agent.ep)
        
#         time.sleep(1)
        
        replay_memory.push(Experiance(current_state, next_state, action, reward, len(possilble_actions)))
        
        current_state = next_state.detach().clone()
        
        if replay_memory.can_provide_sample():
            
            QValueEstimator.update_policy_network(agent, replay_memory.sample())
        
        if done:
            all_rewards.append(total_reward)
            break
        
#         time.sleep(1)
        step += 1
    
    if n % target_net_update == 0:
        QValueEstimator.update_target_network()

***********Ep 14***********
Action (6, 0)
Total reward 10
Ep 0.7740428188605086


ValueError: not enough values to unpack (expected 5, got 4)

In [None]:
states, new_states, actions, rewards = ReplayMemory.stack_memory(replay_memory.memory)

In [573]:
states.shape, new_states.shape, actions.shape, rewards.shape

(torch.Size([74790, 20, 10]),
 torch.Size([74790, 20, 10]),
 torch.Size([74790]),
 torch.Size([74790]))

In [532]:
rewards.unique()

tensor([ 1, 11])

In [435]:
states

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1., 