In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from tetris import Tetris
from tetris_master import TetrisMaster2
from Agent import DQNAgent

In [2]:
# --- Hyperparameters and run configuration -------------------------------------------
w, h = 10, 20      # passed to Tetris(w, h); presumably board width and height
game_len = 5       # max steps per *training* episode
epoch = 10000      # number of outer train -> evaluate -> checkpoint rounds
betta = 0.3        # forwarded to DQNAgent's `betta` parameter; meaning defined in Agent.py (TODO document)
gamma = 0.89       # forwarded to DQNAgent; presumably the discount factor
rf = 4             # decimal places used when rounding values for progress-bar display
lr = 1e-5          # forwarded to DQNAgent as its learning rate
# Forwarded to DQNAgent; presumably epsilon-greedy start value, floor and per-step decay.
epsilon, epsilon_min, epsilon_decay = 0.4, 0., 0.995
batch_size = 10    # passed to agent.replay(batch_size)

train_ep = 100     # training episodes per epoch
test_ep = 100      # evaluation episodes per epoch

# Seed every RNG the run may touch (Python, NumPy, PyTorch) so a
# Restart & Run All reproduces the same training trajectory.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
env = Tetris(w, h)

# --- Agent setup ----------------------------------------------------------------------
# NOTE(review): state_dim is a (board_shape, 1) tuple rather than a flat size;
# confirm this is the format DQNAgent expects.
state_dim = (env.board.shape, 1)
action_dim = 4  # size of the agent's action space; TODO confirm it matches Tetris.step

# Directory the training cell writes its checkpoint into.
cp_path = 'checkpoints'

# TODO: name TetrisMaster2's positional args (8, 32, 100); their meaning is not visible here.
# .double() casts the network's floating-point parameters to float64
# (assumes TetrisMaster2 is an nn.Module).
model = TetrisMaster2(8, 32, 100).double()

# Hyperparameters are passed by keyword. With nine positional parameters, a swapped
# or missing argument is otherwise easy to miss. Names follow DQNAgent.__init__'s
# signature (model, state_dim, action_dim, learning_rate, betta, gamma,
# epsilon, epsilon_min, epsilon_decay).
agent = DQNAgent(
    model=model,
    state_dim=state_dim,
    action_dim=action_dim,
    learning_rate=lr,
    betta=betta,
    gamma=gamma,
    epsilon=epsilon,
    epsilon_min=epsilon_min,
    epsilon_decay=epsilon_decay,
)

TypeError: __init__() missing 9 required positional arguments: 'model', 'state_dim', 'action_dim', 'learning_rate', 'betta', 'gamma', 'epsilon', 'epsilon_min', and 'epsilon_decay'

In [None]:
# --- Train / evaluate / checkpoint loop -----------------------------------------------
# Each of the `epoch` rounds:
#   1. plays `train_ep` training episodes (at most `game_len` steps each), storing
#      transitions with agent.remember and calling agent.replay once per episode;
#   2. plays `test_ep` evaluation episodes with agent.select_action(..., test=True);
#   3. overwrites the checkpoint file with the online network's weights.
# Interrupting the kernel saves the current weights before the interrupt propagates.

# Evaluation games are allowed to run longer than training games (capped at game_len).
test_game_len = 100

os.makedirs(cp_path, exist_ok=True)  # torch.save does not create missing directories
checkpoint_file = os.path.join(cp_path, 'test_model')


def save_checkpoint():
    """Write agent.online_net's state_dict to `checkpoint_file`."""
    torch.save({
        'model_state_dict': agent.online_net.state_dict(),
    }, checkpoint_file)


def play_episode(max_iters, train):
    """Play one episode of at most `max_iters` steps with gradient tracking disabled.

    Args:
        max_iters: step cap for the episode; it also ends early when env.step reports done.
        train: if True, actions come from agent.select_action(state) and every
            transition is stored via agent.remember; otherwise actions come from
            agent.select_action(state, test=True) and nothing is stored.

    Returns:
        (total_reward, iters): summed reward and number of steps actually taken.
    """
    with torch.no_grad():
        state = env.reset()
        total_reward = 0
        done = False
        iters = 0
        while iters < max_iters and not done:
            if train:
                action = agent.select_action(state)
            else:
                action = agent.select_action(state, test=True)
            next_state, reward, done = env.step(action)
            if train:
                agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward = total_reward + reward
            iters += 1
    return total_reward, iters


try:
    for e in range(epoch):
        # Training phase: one replay (learning) step per collected episode.
        model.train()
        train_iter = tqdm(range(train_ep), desc=f"epoch {e} train")
        for _ in train_iter:
            total_reward, iters = play_episode(game_len, train=True)
            # NOTE(review): assumes agent.replay always returns a tensor loss. If it can
            # return None (e.g. before the buffer holds batch_size samples), the
            # .item() calls below fail; confirm in Agent.py.
            loss = agent.replay(batch_size)
            train_iter.set_postfix_str(f"loss {loss.item():.5}, " + f"{iters=}, " +
                                       f"rwd: {round(total_reward, rf)}, " +
                                       f"eps {round(agent.epsilon, rf)}")

        # Evaluation phase; the loss shown is the last training loss of this epoch.
        model.eval()
        test_iter = tqdm(range(test_ep), desc=f"epoch {e} test")
        for _ in test_iter:
            total_reward, iters = play_episode(test_game_len, train=False)
            test_iter.set_postfix_str(f"Test Episode, loss {round(loss.item(), rf)}, " +
                                      f"Total Reward: {round(total_reward, rf)}")

        save_checkpoint()
except KeyboardInterrupt:
    # Keep the weights trained so far, then let the interrupt stop the cell as usual.
    save_checkpoint()
    raise


100%|██████████| 1000/1000 [1:37:33<00:00,  5.85s/it, str=Episode: 1000, loss 0.2568, Total Reward: -0.0086]
100%|██████████| 10/10 [02:48<00:00, 16.88s/it, str=Test Episode, loss 0.2568, Total Reward: -0.0165]
