In [1]:
!pip install gymnasium
!pip install tetris_gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/958.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m46.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0
Collecting tetris_gymnasium
  Downloading tetris_gymnasium-0.2.1-py3-none-any.whl.metadata (6.2 kB)
Downloading tetris_gymnasium-0.2.1-py3-none-any.whl (24 kB)
Installing collected packages: tetris_gymnasium
Successfully installed tetris_gymnasium-0.2.1


In [2]:
import gymnasium as gym
from tetris_gymnasium.envs import Tetris
import matplotlib
import matplotlib.pyplot as plt

# Set up Colab

In [3]:
env = gym.make('tetris_gymnasium/Tetris', render_mode='rgb_array')

if 'inline' in matplotlib.get_backend():
  from IPython import display
plt.ion()

<contextlib.ExitStack at 0x7b7a7d2e9f10>

# Setting up for Training

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque

class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()

        self.fc1_1 = nn.Linear(24 * 18, 512)
        self.fc1_2 = nn.Linear(24 * 18, 512)
        self.fc1_3 = nn.Linear(4 * 4, 64)
        self.fc1_4 = nn.Linear(4 * 16, 64)

        self.fc2 = nn.Linear(512 + 512 + 64 + 64, 512)
        self.fc3 = nn.Linear(512, 8)

    def forward(self, board, active_tetromino_mask, holder, queue):
        x1 = torch.relu(self.fc1_1(board.view(1, 432)))
        x2 = torch.relu(self.fc1_2(active_tetromino_mask.view(1, 432)))
        x3 = torch.relu(self.fc1_3(holder.view(1, 16)))
        x4 = torch.relu(self.fc1_4(queue.view(1, 64)))

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class DQNAgent:
    def __init__(self, action_size, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01, lr=1e-4, batch_size=32, memory_size=10000):
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size

        self.model = DQN()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

        self.memory = ReplayMemory(memory_size)

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)

        board = torch.tensor(state['board'], dtype=torch.float32)
        active_tetromino_mask = torch.tensor(state['active_tetromino_mask'], dtype=torch.float32)
        holder = torch.tensor(state['holder'], dtype=torch.float32)
        queue = torch.tensor(state['queue'], dtype=torch.float32)

        q_values = self.model(board, active_tetromino_mask, holder, queue)
        return torch.argmax(q_values).item()

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        batch = self.memory.sample(self.batch_size)
        for state, action, reward, next_state, done in batch:
            board, active_tetromino_mask, holder, queue = state.values()
            next_board, next_active_tetromino_mask, next_holder, next_queue = next_state.values()

            board_tensor = torch.tensor(board, dtype=torch.float32)
            active_tetromino_mask_tensor = torch.tensor(active_tetromino_mask, dtype=torch.float32)
            holder_tensor = torch.tensor(holder, dtype=torch.float32)
            queue_tensor = torch.tensor(queue, dtype=torch.float32)

            next_board_tensor = torch.tensor(next_board, dtype=torch.float32)
            next_active_tetromino_mask_tensor = torch.tensor(next_active_tetromino_mask, dtype=torch.float32)
            next_holder_tensor = torch.tensor(next_holder, dtype=torch.float32)
            next_queue_tensor = torch.tensor(next_queue, dtype=torch.float32)

            next_q_values = self.model(next_board_tensor, next_active_tetromino_mask_tensor, next_holder_tensor, next_queue_tensor)
            next_q_value = torch.max(next_q_values).item()

            target_q_value = reward + (1 - done) * self.gamma * next_q_value

            current_q_values = self.model(board_tensor, active_tetromino_mask_tensor, holder_tensor, queue_tensor)
            current_q_value = current_q_values[0][action]

            loss = self.loss_fn(current_q_value, torch.tensor(target_q_value, dtype=torch.float32))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def train(self, env, num_episodes=1000):
        for e in range(num_episodes):
            state, _ = env.reset()
            done = False
            total_reward = 0

            while not done:
                action = self.act(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                self.memory.push(state, action, reward, next_state, done)

                self.replay()

                state = next_state
                total_reward += reward

            print(f"Episode {e}/{num_episodes}, Total Reward: {total_reward}, Epsilon: {self.epsilon}")

# Training

In [5]:
agent = DQNAgent(action_size=8)
agent.train(env, num_episodes=20)

Episode 0/20, Total Reward: 10, Epsilon: 0.9046104802746175
Episode 1/20, Total Reward: 10, Epsilon: 0.7666961448653229
Episode 2/20, Total Reward: 10, Epsilon: 0.5790496471185967
Episode 3/20, Total Reward: 8, Epsilon: 0.500708706245853
Episode 4/20, Total Reward: 8, Epsilon: 0.37627099809304654
Episode 5/20, Total Reward: 13, Epsilon: 0.25578670228422234
Episode 6/20, Total Reward: 11, Epsilon: 0.1473726336968319
Episode 7/20, Total Reward: 12, Epsilon: 0.07416156859737154
Episode 8/20, Total Reward: 16, Epsilon: 0.027350922279685573
Episode 9/20, Total Reward: 13, Epsilon: 0.011263082556340478
Episode 10/20, Total Reward: 15, Epsilon: 0.00998645168764533
Episode 11/20, Total Reward: 11, Epsilon: 0.00998645168764533
Episode 12/20, Total Reward: 16, Epsilon: 0.00998645168764533
Episode 13/20, Total Reward: 10, Epsilon: 0.00998645168764533
Episode 14/20, Total Reward: 10, Epsilon: 0.00998645168764533
Episode 15/20, Total Reward: 15, Epsilon: 0.00998645168764533
Episode 16/20, Total Rew

# Testing

In [6]:
from pathlib import Path
import numpy as np

done = False
total_reward = 0
frames = []
state, _ = env.reset()

while not done:
  action = agent.act(state)
  next_state, reward, terminated, truncated, _ = env.step(action)
  agent.memory.push(state, action, reward, next_state, done)
  agent.replay()
  frame = env.render()
  frames.append(frame)
  done = terminated or truncated
  total_reward+=reward
  state = next_state

env.close()
print(f'Total rewards: {total_reward}')

import imageio
video_path = 'tetris_test.mp4'
imageio.mimsave(video_path, frames, fps=30)

from IPython.display import HTML
from base64 import b64encode
mp4 = open(video_path,'rb').read()
video_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML(f'<video width="640" height="480" controls><source src="{video_url}" type="video/mp4"></video>')

Total rewards: 13


