We start with all of the necessary imports and library settings.

In [4]:
import numpy as np
import torch
from torch import nn
import pickle
import copy
import cv2
import datetime
import os
import shutil
from abc import ABC, abstractmethod

# Register the numpy helpers/dtypes that appear inside our saved objects with
# torch's serialization allowlist, so that they can be
# saved w/ pickle w/out throwing a warning
# (we have to save tensors with pickle instead of
# torch.save bc it's much faster)
# NOTE(review): per the PyTorch docs this allowlist governs
# torch.load(..., weights_only=True); the save/load code visible in this file
# goes through the pickle module directly — confirm this call is still needed.
# NOTE(review): np._core is a private numpy namespace — confirm it exists in
# the numpy version this notebook is pinned to.
torch.serialization.add_safe_globals(
    [
        np._core.multiarray._reconstruct,
        np.ndarray,
        np.dtypes.ObjectDType,
        np._core.multiarray.scalar,
        np.dtypes.Int32DType,
        np.dtypes.Float64DType,
    ]
)

Now, we set up the Tetris environment.

In [5]:
# each of the 7 tetrominoes is representable in three ways:
# -an int in [0, 6]
# -a 4x4 binary array
# -a 4x2 array of coordinates for the 4x4 array
# we will be changing between these representations throughout these
# functions that handle the game and pieces themselves
# but, generally, pieces go int -> coordinates -> binary array

# a note on coordinates and rotations:
# the specific positions of each piece in the 4x4 array are based on
# how the pieces act in Tetris for the NES
# that is, each piece has specific coordinates in the 4x4 array such that when
# the 4x4 array is stamped on to the board,
# each piece is consistent in behavior.
# additionally, this setup means that not all rotations are valid;
# for example, the O piece cannot be rotated at all, since it would cause the
# piece to move and lose its consistency in behavior to the NES game
# in the NES game, an invalid rotation input is sanitized to produce a
# valid output
# but, here, there is no expectation that
# the game env will get any invalid input,
# since only valid inputs are shown to the network when it decides how to move

# this translates int -> 4x2 array of coordinates of the piece
# each coordinate is a (row, col) pair into the piece's 4x4 array
# (coordtopiecearray indexes the array as [coords[:, 0], coords[:, 1]])
int_to_piececoords = [
    # 0: cells (2,1) (2,2) (2,3) (3,2) -> T shape
    [(2, i + 1) for i in range(3)] + [(3, 2)],
    # 1: cells (2,1) (2,2) (2,3) (3,3) -> J shape
    [(2, i + 1) for i in range(3)] + [(3, 3)],
    # 2: cells (2,1) (2,2) (3,2) (3,3) -> Z shape
    [(j + 2, i + 1 + j) for j in range(2) for i in range(2)],
    # 3: cells (2,1) (2,2) (3,1) (3,2) -> O shape (2x2 square)
    [(i + 2, j + 1) for i in range(2) for j in range(2)],
    # 4: cells (2,2) (2,3) (3,1) (3,2) -> S shape
    [(j + 2, i + 2 - j) for j in range(2) for i in range(2)],
    # 5: cells (2,1) (2,2) (2,3) (3,1) -> L shape
    [(2, i + 1) for i in range(3)] + [(3, 1)],
    # 6: cells (2,0) (2,1) (2,2) (2,3) -> I shape (full row 2)
    [(2, i) for i in range(4)],
]
# convert each list of 4 (row, col) tuples into a (4, 2) integer ndarray
int_to_piececoords = [np.array(p) for p in int_to_piececoords]

# as mentioned, each piece only has so many valid rotations that
# do not break consistency
# this translates piece int -> list of valid rotations
# where 1 indicates a rotation to the right, -1 indicates a left rotation,
# and 0 indicates no rotation
# (2 is a half turn: piece.rotate_coords treats dir % 4 == 2 as "4 - coords")
# the list index matches the index used by int_to_piececoords
valid_rotations = [
    [-1, 0, 1, 2],  # 0: all four orientations
    [-1, 0, 1, 2],  # 1: all four orientations
    [-1, 0],  # 2: two orientations
    [0],  # 3: never rotated
    [-1, 0],  # 4: two orientations
    [-1, 0, 1, 2],  # 5: all four orientations
    [0, 1],  # 6: two orientations
]


class piece:
    # a class for a single tetromino in a game of Tetris
    # initialized with a int, stores the coordinates,
    # binary array, and position of the piece

    def __init__(self, pieceint):
        self.pieceint = pieceint
        self.coords = int_to_piececoords[pieceint]
        self.piecearray = self.coordtopiecearray(self.coords)
        self.rotations = valid_rotations[pieceint]
        self.orient = 0
        self.pos = {"x": 3, "y": -2}

    # rotate this teromino in place by dir
    def rotate_piece(self, dir: int) -> None:
        self.coords = self.rotate_coords(self.coords, dir)
        self.orient += dir
        self.orient %= 4
        self.piecearray = self.coordtopiecearray(self.coords)

    # helper, does the array computation on the coordinates for rotation
    def rotate_coords(self, piececoords: np.ndarray, dir: int) -> np.ndarray:
        match dir % 4:
            case 1:
                piececoords = piececoords[:, ::-1]
                piececoords[:, 1] = 4 - piececoords[:, 1]
            case 2:
                piececoords = 4 - piececoords
            case 3:
                piececoords = piececoords[:, ::-1]
                piececoords[:, 0] = 4 - piececoords[:, 0]
            case _:
                pass

        return piececoords

    # helper, does coordinate array -> binary array
    def coordtopiecearray(self, coords: np.ndarray) -> np.ndarray:
        ret = np.zeros((4, 4), dtype=np.int8)
        ret[coords[:, 0], coords[:, 1]] = 1
        return ret


# here, we set up the randomizers that Tetris can use
# n_pieces is the number of distinct piece types, taken from the length of
# the valid_rotations table (one entry per piece int)
n_pieces = len(valid_rotations)


# superclass for other randomizers
class randomizer(ABC):
    """Abstract interface shared by every piece randomizer.

    Subclasses must provide ``reset`` and ``get_new_piece``; constructing a
    randomizer simply performs a first ``reset`` with the given seed.
    """

    def __init__(self, seed=None):
        # all state lives in reset(), so construction is just a first reset
        self.reset(seed=seed)

    @abstractmethod
    def reset(self, seed=None):
        """(Re)initialise the randomizer state from ``seed``."""

    @abstractmethod
    def get_new_piece(self):
        """Produce the next piece of the sequence."""


# NES randomizer class
# the NES randomizer works as follows:
# -roll an 8-sided die
# -if the roll is the same as the result of the previous randomization or
#  an 8, roll a 7-sided die and return the result
# -otherwise, just return the first roll
class NESrandom(randomizer):
    """Randomizer following the NES scheme described above.

    Construction is inherited from ``randomizer`` (it calls ``reset``).
    """

    def reset(self, seed=None):
        """Forget the previous piece and create a fresh seeded generator."""
        # -1 can never equal a roll, so the first draw is never "a repeat"
        self.last_piece = -1
        self.gen = np.random.default_rng(seed=seed)

    def get_new_piece(self):
        """Draw the next piece and remember it as the previous result."""
        # first roll: n_pieces + 1 outcomes, the extra one being n_pieces
        candidate = self.gen.integers(0, n_pieces, endpoint=True)
        repeat_or_extra = candidate == self.last_piece or candidate == n_pieces
        if repeat_or_extra:
            # second roll: plain n_pieces outcomes, taken whatever it is
            candidate = self.gen.integers(0, n_pieces, endpoint=False)

        self.last_piece = candidate
        return piece(candidate)


# bag randomizer class, the randomizer used
# by most (all?) modern Tetris implementations
# the bag randomizer works as follows:
# -start with a bag with one of each tetromino in it
# -randomly pull from the bag until it is empty, and then create a new bag
class Bagrandom(randomizer):
    """Randomizer that deals pieces out of a shuffled bag of every piece int.

    Construction is inherited from ``randomizer`` (it calls ``reset``).
    """

    def reset(self, seed=None):
        """Create a fresh seeded generator and fill the first bag."""
        self.gen = np.random.default_rng(seed=seed)
        self.newbag()

    def newbag(self):
        """Replace the bag with a shuffled array of all piece ints."""
        fresh = np.arange(n_pieces)
        self.gen.shuffle(fresh)
        self.bag = fresh

    def get_new_piece(self):
        """Take the front of the bag; refill as soon as the bag runs out."""
        drawn, remaining = self.bag[0], self.bag[1:]
        self.bag = remaining
        if remaining.shape[0] == 0:
            self.newbag()
        return piece(drawn)


# class for the actual tetris game implementation
# the board is represented by a 20x10 numpy array, so
# each "tile" in the game is just one spot in the array
# (this is, of course, upscaled for rendering)
# plays via the tetris_board.step method
# the board used internally has buffers on the edges to prevent
# indexing errors when a piece hits the wall
class tetris_board:
    """The Tetris game itself.

    ``board`` is a float array of shape
    ``(height + 2 * buffer, width + 2 * buffer)``: the inner region is the
    playable field (0 = empty, 1 = filled) and the surrounding ``buffer``
    cells are filled with ones so that a piece touching a wall or the floor
    shows up as an overlap (a value > 1) once it is stamped on.
    """

    def __init__(
        self, height: int = 20, width: int = 10, seed=None, randomizer: str = "bag"
    ):
        self.buffer = 3
        self.line_clears = 0
        self.n_pieces = 0
        self.width = width
        self.height = height
        # "nes" picks the NES randomizer; any other value picks the bag one
        if randomizer == "nes":
            self.piecegen = NESrandom(seed=seed)
        else:
            self.piecegen = Bagrandom(seed=seed)
        padding = 2 * self.buffer
        self.board = np.ones((padding + height, padding + width))

        self.reset(seed)

    # reset the env back to the start of the game
    def reset(self, seed=None) -> None:
        """Reseed the randomizer, empty the field, deal two pieces and zero
        every counter."""
        b = self.buffer
        self.piecegen.reset(seed=seed)
        # only the playable region is emptied; the buffer is left alone
        self.board[b:-b, b:-b] = 0
        self.curr_piece = self.piecegen.get_new_piece()
        self.next_piece = self.piecegen.get_new_piece()
        self.line_clears = 0
        self.n_pieces = 0
        # number of placements so far that cleared 1, 2, 3 or 4 lines
        self.clear_types = dict.fromkeys(("1", "2", "3", "4"), 0)

    # "working board", just the board without the buffers
    def get_wboard(self) -> np.ndarray:
        """Return the playable region of ``board`` (a view, not a copy)."""
        b = self.buffer
        return self.board[b:-b, b:-b]

    # given a piece (including its position), place it on the board
    def stamp_piece(
        self, board: np.ndarray, piece: piece, copy: bool = True
    ) -> np.ndarray:
        """Add the piece's 4x4 array onto ``board`` at the piece's position.

        With ``copy=True`` the stamping happens on a copy of ``board``;
        otherwise ``board`` itself is modified.  The stamped array is
        returned either way.  Cells that end up > 1 are overlaps.
        """
        target = board.copy() if copy else board

        # the piece position is relative to the playable region, so shift it
        # by the buffer to get indices into the padded array
        top = self.buffer + piece.pos["y"]
        left = self.buffer + piece.pos["x"]
        target[top : top + 4, left : left + 4] += piece.piecearray
        return target

    # "render board", the array object that will be displayed when
    # rendering the game
    def rboard(self) -> np.ndarray:
        """Build the (not yet upscaled) array that ``render`` displays.

        Field cells are doubled so pieces are distinct from the border,
        every border cell is set to 1, and a 6-column panel showing the next
        piece is appended on the right.
        """
        border_rows = 2
        b = self.buffer
        stamped = self.stamp_piece(self.board, self.curr_piece)
        # drop most of the buffer, keeping a thin strip of it as the border
        frame = stamped[b - border_rows : 1 - b, b - 1 : 1 - b]

        # True exactly on the cells of `frame` that belong to the field
        in_field = np.zeros(frame.shape, dtype=bool)
        in_field[border_rows:-1, 1:-1] = True

        frame[in_field] *= 2
        frame[~in_field] = 1  # keep the whole border one colour

        # next-piece display
        preview = np.ones((frame.shape[0], 6))
        preview[border_rows : border_rows + 4, 1:5] = self.next_piece.piecearray * 2
        return np.concatenate([frame, preview], axis=1)

    # uses cv2 to render the game in a separate window
    def render(self, wait: int = 1) -> None:
        """Show the game in a cv2 window named "tetris".

        ``wait`` is forwarded to ``cv2.waitKey`` (milliseconds per frame).
        """
        upscale = 20  # upscale factor
        hf = 1.5  # vertical space per text line, in units of one tile

        frame = self.rboard().repeat(upscale, 1).repeat(upscale, 0)
        header = np.ones((4 * int(hf * upscale), frame.shape[1]))
        frame = np.concatenate([header, frame], axis=0)
        frame /= np.max(frame)  # for cv2 rgb in [0, 1]

        # one caption per clear type, then the total number of lines
        types = ["singles", "doubles", "triples", "tetrises"]
        captions = [
            f"{label}: {self.clear_types[str(k + 1)]}"
            for k, label in enumerate(types)
        ]
        captions.append(f"total lines: {self.line_clears}")

        for row, caption in enumerate(captions):
            frame = cv2.putText(
                frame,
                caption,
                (int(1 * upscale), int((hf * row + 1.5) * upscale)),
                0,
                fontScale=0.7,
                color=255,
                thickness=1,
            )
        cv2.imshow("tetris", frame)
        cv2.waitKey(wait)

    def new_piece(self) -> None:
        """Promote the next piece to current, draw a new next piece and
        count the piece."""
        self.curr_piece, self.next_piece = (
            self.next_piece,
            self.piecegen.get_new_piece(),
        )
        self.n_pieces += 1

    # plays one frame of the game
    def step(self, act: list) -> bool:
        """Advance the game by one frame.

        ``act`` holds ``[x-movement, y-movement, rotation]``.  The action is
        tried first (and dropped if it overlaps anything), then gravity moves
        the piece down by one.  If gravity cannot move it, the piece is
        locked into the board, full lines are removed and a new piece is
        dealt.

        Returns True when the piece was placed this frame, False otherwise.
        """
        b = self.buffer

        def overlaps(candidate) -> bool:
            # the top buffer rows are excluded from the overlap test
            return bool(np.any(self.stamp_piece(self.board, candidate)[b:] > 1))

        if overlaps(self.curr_piece):
            # the current piece already intersects something: a loss
            self.reset()

        # try the requested action on a throwaway copy of the current piece
        # (at most 2 of the 3 entries are used per frame in practice)
        trial = copy.deepcopy(self.curr_piece)
        trial.pos["x"] += act[0]
        trial.pos["y"] += act[1]
        trial.rotate_piece(act[2])

        if overlaps(trial):
            # intersects another piece or a wall: the action is discarded
            trial = self.curr_piece

        trial.pos["y"] += 1  # gravity
        if not overlaps(trial):
            self.curr_piece = trial
            return False  # piece was not placed

        # gravity failed, so the piece locks where it was before falling.
        # note that a piece can therefore still be acted on during the frame
        # in which it first rests on something; it only locks once gravity
        # fails on it.
        trial.pos["y"] -= 1
        self.stamp_piece(self.board, trial, copy=False)
        self.new_piece()

        # find and remove full lines
        field = self.get_wboard()
        full_rows = np.sum(field, 1) == self.width
        clears = np.sum(full_rows)
        kept_rows = field[~full_rows]  # boolean indexing gives a copy

        # shift the remaining rows down under `clears` new empty rows
        field[:] = np.concatenate(
            [np.zeros((clears, self.width)), kept_rows], axis=0
        )

        self.line_clears += clears
        if clears > 0:
            # a placement yields exactly one clear type
            self.clear_types[str(clears)] += 1

        return True  # piece was placed

Now, we set up the agent and its network. For more details on the network training:

DQN: https://arxiv.org/pdf/1312.5602

Double DQN: https://arxiv.org/pdf/1509.06461

Prioritized Experience Replay: https://arxiv.org/pdf/1511.05952

In [6]:
# helper for printing results of training
def sixnum(data, r: int = 3) -> list:
    """Return [min, 1st quartile, mean, median, 3rd quartile, max] of data.

    Each statistic is rounded to ``r`` decimals and converted to a plain
    Python float.
    """
    minimum = np.min(data)
    first_quartile = np.quantile(data, 0.25)
    average = np.mean(data)
    middle = np.median(data)
    third_quartile = np.quantile(data, 0.75)
    maximum = np.max(data)

    summary = (minimum, first_quartile, average, middle, third_quartile, maximum)
    return [float(stat.round(r)) for stat in summary]


# class for the nn using pytorch
class qmodel(nn.Module):
    """Fully connected network: 12 inputs -> 400 -> 300 -> 1 output.

    ReLU follows each of the two hidden layers, and all parameters are
    converted to double precision.
    """

    def __init__(self):
        super().__init__()
        n_in, n_hidden_a, n_hidden_b, n_out = 12, 400, 300, 1
        stack = [
            nn.Linear(n_in, n_hidden_a),
            nn.ReLU(),
            nn.Linear(n_hidden_a, n_hidden_b),
            nn.ReLU(),
            nn.Linear(n_hidden_b, n_out),
        ]
        # the attribute name and layer order fix the state-dict keys
        self.model = nn.Sequential(*stack)
        self.double()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        out = self.model(x)
        return out


# agent class, including caching experiences and training
class tbot:
    def __init__(
        self, load_path: str = "currmodel", load_cache: bool = True, device: str = "cpu"
    ):
        self.device = device

        # instead of storing the objects themselves, we store the state dict of
        # the networks and the optimizer, so we first create blank objects...
        self.online = qmodel().to(self.device)
        self.target = qmodel().to(self.device)
        self.target.load_state_dict(self.online.state_dict())
        self.optimizer = torch.optim.RAdam(self.online.parameters())
        st = 0
        cs = int(3e5)

        cache = {
            "cache_size": cs,
            "cached": np.zeros(cs, dtype="object"),
            "cached_probs": np.full(cs, -1, dtype=np.float64),
            "n_cached": 0,
        }

        if load_path is not None:
            # ...and then load the state dict for each one
            # (or the obj itself, for the cache)
            with open(f"saves/{load_path}/online.pickle", "rb") as f:
                self.online.load_state_dict(pickle.load(f))

            with open(f"saves/{load_path}/target.pickle", "rb") as f:
                self.target.load_state_dict(pickle.load(f))

            with open(f"saves/{load_path}/params.pickle", "rb") as f:
                params = pickle.load(f)
            st = params["step"]

            with open(f"saves/{load_path}/optim.pickle", "rb") as f:
                self.optimizer.load_state_dict(pickle.load(f))

            if load_cache:
                with open(f"saves/{load_path}/cache.pickle", "rb") as f:
                    cache = pickle.load(f)

        for p in self.target.parameters():
            p.requires_grad = False

        self.optimizer.param_groups[0]["lr"] = 2.5e-4

        #RAdamW optim
        self.optimizer.param_groups[0]["decoupled_weight_decay"] = True

        self.cache_size = cache["cache_size"]
        self.cached = cache["cached"]
        self.cached_probs = cache["cached_probs"]
        self.n_cached = cache["n_cached"]

        # lots of hyperparams below :)
        # steps to wait for cache burn in
        self.wait_to_train = 1e4

        # total times the agent has made a decision about a move and
        # cached the result
        self.step = st

        # epsilon controls exploration vs exploitation, starts high and decays
        # linearly
        self.eps_min = 0.1
        self.eps_max = 1
        self.eps_steps = 1e6 + self.wait_to_train
        self.epsilon = max(
            (
                self.eps_max
                - (self.eps_max - self.eps_min) * (self.step / self.eps_steps)
            ),
            self.eps_min,
        )

        # beta controls the bias-annealing for prioritized experience replay
        self.beta_max = 1
        self.beta_min = 0.4
        self.beta_steps = 5e6 + self.wait_to_train
        self.correction_beta = min(
            (
                self.beta_min
                + (self.beta_max - self.beta_min) * (self.step / self.beta_steps)
            ),
            self.beta_max,
        )

        # other hyperparams for prioritized experience replay
        self.delta_epsilon = 1e-3
        self.stoch_alpha = 0.6
        self.do_prioritization = True

        # this is used for recording which cache entries
        # have not been assigned a priority measure
        self.not_measured_inds = []

        # how often (in steps) and how large of a batch to sample
        # from the cache
        self.train_every = 4
        self.batch_size = 32

        # steps between saves to currmodel or to create a new checkpoint folder
        self.checkpoint_every = 1e5
        self.save_every = 1e4

        # rate of syncing and discount rate for double dqn
        self.sync_every = 1e4
        self.discount_rate = 1 - 5e-4

        # just used for recording results during training to get an idea of
        # what's going on
        self.prev_loss = []
        self.prev_est = []
        self.prev_target = []
        self.prev_rew = []

        # loss fun for optimization
        self.loss_fn = nn.L1Loss(reduction="sum")

    # eps and beta will both decay linearly (though in opposite directions)
    def update_eps(self) -> None:
        if self.n_cached > self.wait_to_train:
            self.epsilon = max(
                (
                    self.eps_max
                    - (self.eps_max - self.eps_min) * (self.step / self.eps_steps)
                ),
                self.eps_min,
            )
            self.correction_beta = min(
                (
                    self.beta_min
                    + (self.beta_max - self.beta_min) * (self.step / self.beta_steps)
                ),
                self.beta_max,
            )

    # given features for various states, return the index of the state with the
    # greatest estimated q value
    # (or act randomly, contingent on epsilon)
    # force_model just only uses the network, and ensures that the agent
    # does not update its parameters as if it is training; essentially
    # "freezes" the agent's training progress
    def act(self, Xps: torch.Tensor, force_model: bool = False) -> int:
        if force_model or (np.random.rand() > self.epsilon):
            with torch.no_grad():
                self.online.eval()
                a_id = int(torch.argmax(self.online(Xps).flatten()).item())
        else:
            a_id = int(np.random.rand() * (Xps.shape[0]))

        if not force_model:
            self.update_eps()
            self.step += 1

        return a_id

    # an experience tuple consists of
    # (state_1 features, possible_next_states features,
    # reward for state_1, game_end (binary))
    # new entries are added to empty spots until full, and then the oldest
    # entries are replaced first (like a circular array)
    def cache(self, experience: tuple[torch.Tensor, torch.Tensor, int, int]) -> None:
        newind = (self.n_cached) % self.cache_size
        self.cached[newind] = experience
        self.cached_probs[newind] = self.delta_epsilon
        if self.do_prioritization:
            self.not_measured_inds.append(newind)

        self.n_cached += 1

    # for prioritization
    def compute_cache_probs(self) -> np.ndarray:
        n = min(self.cache_size, self.n_cached)
        stoch_probs = np.power(self.cached_probs[:n], self.stoch_alpha)
        return stoch_probs / np.sum(stoch_probs)

    # randomly sample cache entries (according to cache sampling scheme),
    # return indices
    def recall(self) -> np.ndarray:
        p = np.cumsum(self.compute_cache_probs())
        inds = np.searchsorted(p, np.random.rand(self.batch_size))
        return inds

    # save curr model as a checkpoint
    def save_checkpoint(self) -> None:
        d = datetime.datetime.today()
        path = (
            "checkpoints/"
            + f"{d.year}-{d.month}-{d.day}-{d.hour}-{d.minute}-{d.second}"
        )
        self.save_models(path=path)

    # save a snapshot of the current model status
    # (including networks, optim, steps, and cache) to path in saves folder
    # note that this requires that a "saves" folder exists in the same
    # dir as this file
    # save cache is optional because larger caches can take >1 minute to save
    # (...on my machine that was saving to a SSD, so it
    # may be excruciatingly long for a HDD)
    def save_models(self, path: str = "currmodel", save_cache: bool = True) -> None:
        if "saves" not in os.listdir():
            os.mkdir("saves")

        path_parts = path.split("/")
        curr = "saves"
        for p in path_parts:
            if p not in os.listdir(curr):
                os.mkdir(f"{curr}/{p}")
            curr += "/" + p

        for p in [
            "online.pickle",
            "target.pickle",
            "optim.pickle",
            "params.pickle",
            "cache.pickle",
        ]:
            if p not in os.listdir(f"saves/{path}"):
                open(f"saves/{path}/{p}", "x").close()

        with open(f"./saves/{path}/online.pickle", "wb") as f:
            pickle.dump(self.online.state_dict(), f, pickle.HIGHEST_PROTOCOL)

        with open(f"./saves/{path}/target.pickle", "wb") as f:
            pickle.dump(self.target.state_dict(), f, pickle.HIGHEST_PROTOCOL)

        with open(f"./saves/{path}/optim.pickle", "wb") as f:
            pickle.dump(self.optimizer.state_dict(), f, pickle.HIGHEST_PROTOCOL)

        with open(f"./saves/{path}/params.pickle", "wb") as f:
            pickle.dump({"step": self.step}, f, pickle.HIGHEST_PROTOCOL)

        if save_cache:
            with open(f"./saves/{path}/cache.pickle", "wb") as f:
                pickle.dump(
                    {
                        "n_cached": self.n_cached,
                        "cached": self.cached,
                        "cached_probs": self.cached_probs,
                        "cache_size": self.cache_size,
                    },
                    f,
                    pickle.HIGHEST_PROTOCOL,
                )

    # batch together entries of the cache for use in training
    def cachetodict(
        self, cache_entries: torch.Tensor
    ) -> dict[torch.Tensor, np.ndarray, torch.Tensor, torch.Tensor, torch.Tensor]:
        return {
            "s": torch.concat([c[0].to(self.device) for c in cache_entries], dim=0),

            # this gives us indices where the s' for one state ends and
            # another begins, since they are all batched together
            "divs": np.cumsum([0] + [c[1].shape[0] for c in cache_entries]),
            
            "s'": torch.concat([c[1].to(self.device) for c in cache_entries], dim=0),
            "r": torch.tensor([c[2] for c in cache_entries], device=self.device),
            "d": torch.tensor([c[3] for c in cache_entries], device=self.device),
        }

    # do a backpropogation step
    def update_online(self) -> None:
        if self.do_prioritization:
            self.get_new_measures()

        recall_inds = self.recall()
        r = self.cachetodict(self.cached[recall_inds])

        with torch.no_grad():
            self.online.eval()
            self.target.eval()
            lines = self.online(r["s'"]).flatten()

            # inds for the s' with max q value as
            # determined by online
            inds = torch.tensor(  
                [
                    torch.argmax(lines[r["divs"][j] : r["divs"][j + 1]]).item()
                    + r["divs"][j]
                    for j in range(r["divs"].shape[0] - 1)
                ]
            )

            ap = r["s'"][inds]
            Qpt = self.target(ap).flatten()  # target q-values for those s' from above

            # bellman equation :)
            # use 1-r["d"] to ensure that a terminal state's reward is
            # calculated correctly;
            # if s is a terminal state, its cumulative reward is
            # just its reward alone
            td_target =  r["r"] + (1 - r["d"]) * self.discount_rate * Qpt

        self.online.train()
        self.optimizer.zero_grad()
        td_est = self.online(r["s"]).flatten()
        delta_torch = torch.abs((td_target - td_est).detach())
        delta = delta_torch.cpu().numpy()

        if self.do_prioritization:
            # bias anealling, see paper for details
            N = min(self.cache_size, self.n_cached)
            computed_cached_probs = self.compute_cache_probs()
            weight_corrections = np.power(
                computed_cached_probs[recall_inds] * N, -self.correction_beta
            )
            weight_corrections = weight_corrections / np.max(weight_corrections)
            weight_corrections = torch.from_numpy(weight_corrections).to(self.device)

            # (we can do this multiplication on the values
            # themselves instead of the losses
            # since we use L1 loss; they are mathematically equivalent
            # since the weights are > 0)
            td_est_corr = td_est * weight_corrections
            td_target_corr = td_target * weight_corrections

            loss = self.loss_fn(td_est_corr, td_target_corr)
            
            self.cached_probs[recall_inds] =  delta + self.delta_epsilon
        else:
            loss = self.loss_fn(td_est, td_target)

        # recording some results from this training step for information
        self.prev_loss.append(np.mean(delta))
        self.prev_est.append(td_est.mean().item())
        self.prev_target.append(td_target.mean().item())
        self.prev_rew.append(r["r"].double().mean().item())

        # print summary of training results for information
        if self.step % (40 * self.train_every) == 0:
            print("est:", sixnum(self.prev_est))
            print("target:", sixnum(self.prev_target))
            print("loss:", sixnum(self.prev_loss))
            print("rew:", sixnum(self.prev_rew))
            self.prev_loss = []
            self.prev_est = []
            self.prev_target = []
            self.prev_rew = []

        # update weights, finally
        loss.backward()
        self.optimizer.step()

    # for entries that do not have prioritization measures, get the measures by
    # getting the td error
    # (basically a truncated version of update_online)
    def get_new_measures(self) -> None:
        notmeasured = np.array(self.not_measured_inds)
        r = self.cachetodict(self.cached[notmeasured])

        with torch.no_grad():
            self.online.eval()
            self.target.eval()
            lines = self.online(r["s'"]).flatten()
            inds = torch.tensor(
                [
                    torch.argmax(lines[r["divs"][j] : r["divs"][j + 1]]).item()
                    + r["divs"][j]
                    for j in range(r["divs"].shape[0] - 1)
                ]
            )
            ap = r["s'"][inds]
            Qpt = self.target(ap).flatten()
            td_target = r["r"] + (1 - r["d"]) * self.discount_rate * Qpt
            td_est = self.online(r["s"]).flatten()

        delta = np.abs((td_target - td_est).cpu().detach().numpy())
        self.cached_probs[notmeasured] = delta + self.delta_epsilon

        self.not_measured_inds = []

    # this will be run at every step to determine what to do with the model
    def learn(self) -> None:
        if self.n_cached > self.wait_to_train:
            if self.step % self.train_every == 0:
                self.update_online()

            if (self.step % self.sync_every == 0): 
                self.target.load_state_dict(self.online.state_dict())
                for p in self.target.parameters():
                    p.requires_grad = False

                # since target model has updated, throw away previous training
                # values, since they dont represent the current model
                self.prev_loss = []
                self.prev_est = []
                self.prev_target = []
                self.prev_rew = []

        if self.step % self.checkpoint_every == 0:
            self.save_checkpoint()

        if self.step % self.save_every == 0:
            # we can save a lot of time by copying over the most recent
            # checkpoint cache, since saving the current cache every time would
            # be too slow
            try:
                most_recent = sorted(os.listdir("saves/checkpoints"))[-1]
            except FileNotFoundError or IndexError:
                self.save_models(save_cache=True)
            else:
                self.save_models(save_cache=False)
                recent_cache = f"saves/checkpoints/{most_recent}/cache.pickle"
                curr_cache = "saves/currmodel/cache.pickle"
                shutil.copyfile(recent_cache, curr_cache)

Now, we set up various helper functions for the training loop.

In [7]:
# given a rotation and movement to do, creates a list of actions that
# can be performed one-by-one in the game to achieve it
# this also makes sure to make the list of actions as short as possible
# an action is a 3-int list where
# [x-movement, y-movement, rotation]
# where +1 rotates right and -1 rotates left
# and similarly for the movements
def construct_actionlist(rotation: int, move: int) -> list:
    """Expand a total rotation and a total x-movement into per-frame actions.

    Returns a list of ``[x-movement, y-movement, rotation]`` lists of
    length ``max(abs(rotation), abs(move))``: the unit moves and the unit
    rotations both start at the first action, so the list is as short as
    possible.  The y-movement entry is always 0.
    """
    rot_step = int(np.sign(rotation))
    mov_step = int(np.sign(move))
    n_rot = abs(rotation)
    n_mov = abs(move)

    return [
        [mov_step if k < n_mov else 0, 0, rot_step if k < n_rot else 0]
        for k in range(max(n_rot, n_mov))
    ]


# given starting states and a block to place, returns all of the possible
# next states from each of the starting states
# as well as their corresponding actions
# this seemingly simple function is so monstrous because it's vectorized;
# if it was made shorter using python loops it would run too slow
def get_sps(
    states: np.ndarray, block: piece
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Enumerate every reachable placement of `block` on each board in `states`.

    Args:
        states: batch of boards indexed as (board, row, column). Row 0 is
            treated as the top: column heights are counted from the first
            filled cell downwards.
        block: the piece to place. It is rotated via
            `block.rotate_piece(-block.orient)` on entry and again after the
            rotated arrays have been collected (the existing comment calls
            this the default position).
            NOTE(review): assumes rotate_piece(k) offsets block.orient by k;
            confirm against the piece class.

    Returns:
        A 3-tuple of arrays with one entry per valid placement:
        - the after-states, i.e. the starting board with the piece added,
          with the wall buffer stripped off again (completed lines are NOT
          removed here);
        - an (n, 2) array of [movement, rotation], where movement is
          expressed relative to the centre column (buffer + middle has been
          subtracted);
        - the index into `states` of the board each after-state came from.
    """
    # extra rows of fall required on top of the number of actions (see minl)
    safety = 0

    # make sure piece is in default position
    block.rotate_piece(-block.orient)

    # per-column height: number of rows at or below the topmost filled cell
    heights = np.sum(np.cumsum(states, axis=1) > 0, axis=1)

    # thickness of the border of filled cells added on every side of a board
    buffer = 3

    # starting states still batched, but with buffer
    c = np.ones(
        (states.shape[0], states.shape[1] + 2 * buffer, states.shape[2] + 2 * buffer)
    )
    c[:, buffer:-buffer, buffer:-buffer] = states

    # from very left==0 the "middle" of the board is at index 3
    # (when placing piece via binary array)
    middle = 3

    # given that the array is 4x4, there are (boardwidth+2*buffer)-4 places
    # to put it on the board without indexing out of bounds
    movements = np.arange(c.shape[2] - 4)
    rotations = np.array(block.rotations)
    mesh = np.meshgrid(movements, rotations)

    # all combinations of movement + rotation in a nx2 array
    # (column 0 is the movement, column 1 is the rotation)
    mr = np.concatenate(
        [mesh[0].flatten()[:, None], mesh[1].flatten()[:, None]], axis=1
    )
    longmr = np.tile(mr, (c.shape[0], 1))

    # now, there is a starting state for each movement+rotation
    # combination for each starting state
    longc = c.repeat(mr.shape[0], axis=0)

    # retain original indices of each starting state in the original batch
    c_inds = np.arange(c.shape[0]).repeat(mr.shape[0])

    # now, have the associated rotated array of the block alongside
    # each starting state + movement + rotation combination
    longblocks = np.zeros((longmr.shape[0], 4, 4))
    for i in block.rotations:
        block.rotate_piece(i - block.orient)
        longblocks[longmr[:, 1] == i] = block.piecearray[None]
    # undo the rotations applied in the loop above
    block.rotate_piece(-block.orient)

    # now, start working on minimum falling distance for each tetromino
    # (mid_move holds the movement relative to the centre column)
    mid_move = longmr.copy()
    mid_move[:, 0] -= buffer + middle

    # if going to be doing no movement and no rotation, then
    # min-falling is just 0, so leave it alone
    mask = np.sum(np.abs(mid_move), axis=1) != 0
    minl = np.zeros(longc.shape[0])

    # if you are going to be doing some movement, then min fall is equal to
    # the number of actions you have to do on the way down
    # (since gravity acts once per frame and a movement
    # and rotation can be done together in one frame)
    # (plus a safety factor)
    minl[mask] = np.max(np.abs(mid_move[mask]), axis=1) + safety

    p = 2  # pieces start 2 tiles above the ceiling

    # this gives us the starting height of each tile of the 4x4 tetromino array
    # (empty cells of the array come out as 0)
    ha = (
        longblocks * np.arange(
            states.shape[1] + p, 
            states.shape[1] + p - 4, 
            step=-1
            )[None, :, None]
    )

    # this is the minimum falling dist for each tetromino to make sure that
    # none of the tetromino is in the ceiling, which would be bad
    ha_max = np.max(ha, axis=(1, 2)) - states.shape[1]
    minl[ha_max > minl] = ha_max[ha_max > minl]

    # empty cells get a huge height so they never win the per-column min below
    h_large = 100
    ha[ha == 0] = h_large

    # lowest starting height for each column of the tetromino
    ha = np.min(ha, axis=1)

    # include the buffer in the heights to make sure that it doesnt think
    # it can place blocks in the wall; make the height of the buffer as
    # high as the wall
    heights = np.concatenate(
        [
            np.full((states.shape[0], buffer), states.shape[1]),
            heights,
            np.full((states.shape[0], buffer), states.shape[1]),
        ],
        axis=1,
    ) 

    # make sure heights line up with repeated starting states
    heightsr = heights.repeat(mr.shape[0], axis=0)

    # for each movement option, we have to consider each of the 4
    # columns of the tetromino array
    indx = longmr[:, 0:1] + np.arange(4)[None]

    # take the distance between the bottom of each column of the tetromino
    # and the blocks below (after it has been moved and rotated),
    # and take the min of them;
    # that is the amount of tiles that the tetromino can fall before being
    # placed, and so it's how much time it has to act and move it and such
    abletofall = (
        np.min(ha - heightsr[np.arange(heightsr.shape[0])[:, None], indx], axis=1) - 1
    )

    # if the amount it can fall is not enough, then that
    # movement + rotation is invalid
    mask = abletofall < minl

    # the cube's rotation axis covers the values -1..2; rotation values are
    # shifted by +1 whenever they are used to index it
    p_rot = np.arange(4) - 1
    inds = np.arange(c.shape[0])

    # the coord cube:
    # a "cube" boolean mask where each dimension corresponds to a rotation,
    # movement, or starting-state index
    # there are a few extra restrictions on movement that we solve with the
    # coord cube:
    coord_cube = np.ones(
        (p_rot.shape[0], inds.shape[0], movements.shape[0]), dtype=bool
    )

    # -anything where abletofall < min_fall is obv invalid
    coord_cube[longmr[:, 1][mask] + 1, c_inds[mask], longmr[:, 0][mask]] = False

    # -if a given movement+rotation+starting state is invalid, then
    #  any movement going further in the same direction should be discarded
    #  (cumulative product outwards from the centre column, first to the
    #  right and then, on a reversed view, to the left)
    coord_cube[:, :, buffer + middle :] = np.cumprod(
        coord_cube[:, :, buffer + middle :], axis=2
    )
    coord_cube[:, :, : buffer + middle + 1] = np.cumprod(
        coord_cube[:, :, : buffer + middle + 1][:, :, ::-1], axis=2
    )[:, :, ::-1]

    # -if the no-action is invalid for a given starting state, then
    #  make everything for that starting state invalid
    #  (index 1 on the rotation axis is rotation 0, and buffer + middle is
    #  the centre column)
    coord_cube = (
        coord_cube * coord_cube[1, :, buffer + middle][None, :, None]
    )

    # now, match up coord cube mask to
    # starting states + movement + rotation combinations
    retmask = coord_cube[longmr[:, 1] + 1, c_inds, longmr[:, 0]]

    # return values
    retc = longc[retmask]
    retblocks = longblocks[retmask]
    retmr = longmr[retmask]
    # row of the buffered board that the top row of the 4x4 array lands on
    # (it is used as the base row index when stamping below)
    retfall = abletofall[retmask] - p + buffer
    retcind = c_inds[retmask]

    # however, we now need to actually place the tetrominoes on the
    # starting states that are valid
    # the blocks are already rotated, so just need to move and fall
    mr_rep = retmr[:, 0:1] + np.arange(4)[None]
    fall_rep = retfall[:, None] + np.arange(4)[None]
    mr_matr = mr_rep[:, None].repeat(4, axis=1)
    fall_matr = fall_rep[:, None].repeat(4, axis=1).transpose((0, 2, 1))
    """
    above, essentially doing:
    1 1 1 1     1 2 3 4
    2 2 2 2 and 1 2 3 4
    3 3 3 3     1 2 3 4
    4 4 4 4     1 2 3 4 
    for the coordinate offsets to get placement coordinates for each tile
    (notice how theyre transposes)
    """

    # add every cell of each 4x4 array (zeros included) onto its board
    retc[
        # 16 tiles to place for each state
        np.arange(retc.shape[0]).repeat(4 * 4),
        fall_matr.flatten().astype(int),
        mr_matr.flatten().astype(int),
    ] += retblocks.flatten()

    retmr[:, 0] -= buffer + middle  # since the very left was 0 before

    # strip the wall buffer so the after-states have the input board shape
    retc = retc[:, buffer:-buffer, buffer:-buffer]

    # we return the after-states, the movement+rotations to get to those
    # after-states, and the indices of each after-state corresponding to
    # their starting-state in the "states" argument
    return (retc, retmr, retcind,)

# get features for a batch of tetris board states
# features are
# heights of each column, number of line clears, and number of holes
def batchedfeatures(
    states: np.ndarray,
    device: str = "cpu",
    clip_features: bool = True,
    exlines: np.ndarray = None,
    exlines_operation: str = "max",
) -> torch.Tensor:
    states = torch.from_numpy(states).to(device)

    clearmask = torch.sum(states, dim=2) == 10
    clears = torch.sum(clearmask, dim=1, keepdim=True, dtype=torch.double)

    clearsums = torch.cumsum(clearmask, dim=1)[:, :, None]
    statesums = torch.cumsum(states, dim=1)
    colholes = torch.sum(((statesums - clearsums) > 0) & (states == 0), dim=1)
    holes = torch.sum(colholes, dim=1, keepdim=True, dtype=torch.double)

    numblocks = torch.sum(states, dim=1) - clears
    heights = (numblocks + colholes).double()

    # when we do lookahead, the features for a state do not include the lines
    # achieved after the first block was placed, so we pass those clears in
    # manually
    if exlines is not None:
        exlines = torch.from_numpy(exlines).to(device).double()
        clears = clears[:, 0]

        match exlines_operation:
            case "sum":
                clears += exlines
            case _:
                clears[clears < exlines] = exlines[clears < exlines]

        clears = clears[:, None]

    if clip_features:
        maxholes = 60
        maxclears = 4
        maxheight = 20

        heights = heights / maxheight
        clears = clears / maxclears
        holes = holes / maxholes

    return torch.cat([heights, holes, clears], dim=1)


# given some states, remove the line clears and return the cleared states
# and the number of lines from each one
def remove_lines(states: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Delete completed rows from every board and shift the rest down.

    Args:
        states: boards indexed as (board, row, column) with row 0 at the
            top; a row is complete when its cells sum to the board width.

    Returns:
        A tuple of (boards with the same shape as `states` where completed
        rows are removed and empty rows are added at the top, number of
        completed rows found in each board). An empty batch is returned
        unchanged together with an empty count array.
    """
    n_boards, n_rows, n_cols = states.shape

    # True for every row that survives (i.e. is not a completed line)
    keep = np.sum(states, axis=2) != n_cols
    lines = np.sum(~keep, axis=1)

    # initial=0 keeps the reduction defined for an empty batch, which is
    # what get_sps produces when a piece has no valid placement
    maxlines = np.max(lines, initial=0)

    # stack enough empty rows on top of every board to replace the most
    # lines any single board can lose
    padded = np.concatenate(
        [np.zeros((n_boards, maxlines, n_cols)), states], axis=1
    )
    keep_padded = np.concatenate(
        [np.ones((n_boards, maxlines), dtype=bool), keep], axis=1
    )

    # counting kept rows from the bottom up, retain only the lowest n_rows
    # of them so that every board ends up with its original height
    keep_padded *= np.cumsum(keep_padded[:, ::-1], axis=1)[:, ::-1] <= n_rows

    flat_rows = padded.reshape(padded.shape[0] * padded.shape[1], n_cols)
    cleared = flat_rows[keep_padded.flatten()].reshape(states.shape)

    return (cleared, lines)


# wrapper for get_sps that mirrors the structure of get_sps_two_blocks
def get_sps_one_block(
    state: np.ndarray, block: piece
) -> dict[str, np.ndarray]:
    """Return every after-state reachable by placing `block` on one board.

    Args:
        state: a single, un-batched board; a leading batch axis is added
            before it is handed to get_sps.
        block: the piece to place.

    Returns:
        A dict with the keys
        - "sps": the after-states,
        - "movement_rotation": the movement + rotation leading to each one,
        - "old_state_inds": the starting-state index of each after-state,
        - "exlines": an all-zero array with one entry per after-state
          (with a single placement there are no earlier line clears).
    """
    after_states, move_rot, state_inds = get_sps(state[None], block)
    return {
        "sps": after_states,
        "movement_rotation": move_rot,
        "old_state_inds": state_inds,
        "exlines": np.zeros((after_states.shape[0],)),
    }


# given two blocks, get all of the possible states after placing both blocks
# by calling get_sps twice
# this is how we do lookahead; consider all ways to place both blocks,
# compute features for all after-states, and then choose the best
# ensures that the states are accurate by removing lines that are cleared
# after placing the first block
# also returns the line clears that were achieved after placing the first block
def get_sps_two_blocks(
    state: np.ndarray, block1: piece, block2: piece
) -> dict[str, np.ndarray]:
    """Return every after-state reachable by placing `block1` then `block2`.

    Args:
        state: a single, un-batched board.
        block1: the piece placed first.
        block2: the piece placed on each result of the first placement.

    Returns:
        A dict with the keys
        - "sps": the after-states once both pieces are placed,
        - "movement_rotation": the movement + rotation of the FIRST piece
          that leads to each after-state,
        - "old_state_inds": for each after-state, the index of the
          first-placement result it was built from,
        - "exlines": the lines cleared by the first placement, per
          after-state.
        If the first piece cannot be placed anywhere, every array is empty.
    """
    first_states, first_moves, first_inds = get_sps(state[None], block1)

    # no valid placement for the first piece: report "no after-states"
    # directly instead of passing an empty batch on to remove_lines, so the
    # caller can detect the loss from the empty "sps" array
    if first_states.shape[0] == 0:
        return {
            "sps": first_states,
            "movement_rotation": first_moves,
            "old_state_inds": first_inds,
            "exlines": np.zeros((0,), dtype=int),
        }

    # clear the lines completed by the first piece before placing the second
    cleared, first_lines = remove_lines(first_states)
    second_states, _, parent_inds = get_sps(cleared, block2)

    return {
        "sps": second_states,
        # map each final state back to the first-piece move that started it
        "movement_rotation": first_moves[parent_inds],
        "old_state_inds": parent_inds,
        "exlines": first_lines[parent_inds],
    }

Now, for the actual training loop.

In [None]:
# NOTE: i have only tested loading and saving networks on a windows machine,
# i dont see why it shouldnt work elsewhere, but there is no guarantee

# TO TRAIN YOUR OWN NETWORK:
# -either set usecurr to False or set load_path to None, and run this cell
# -when the bot eventually saves, in this current directory, if it doesnt exist,
#  a folder called "saves" will be created, and another folder called "currmodel"
#  will be created inside of it. Additionally,
#  a folder called "checkpoints" will periodically save
#  timestamped model checkpoints
# -the saved bot can be run by setting load_path to "currmodel", or
#  run a checkpoint by setting load_path to "checkpoints/{checkpoint name}"

# For training:
# set force_model to False (so that it actually trains)
# for a large training speedup, set render and mid_render to False


usecurr = True  # use a saved model or start fresh
force_model = True  # force network eval or do training
use_queue = True  # use lookahead by considering next block
render = True  # render the game as it's played
mid_render = True  # render the blocks falling, or skip the falling animation
fps = 1e3  # playing speed in frames per second, maximum 1e3
# (fps may be inaccurate at higher values due to computation speed)

# examples of "good" models from training
# 2mil is most safe, 6mil is most reckless (goes for full tetrises),
# and 3mil is in-between
paths = ["steps_2mil", "steps_3mil", "steps_6mil"]

# *****change this value here to use one of the other pre-trained models*****
load_path = paths[2]

# the cache is only loaded when training (force_model is False)
bot = tbot(load_path=load_path if usecurr else None, load_cache=not force_model)
env = tetris_board()

# value handed to env.render for every frame; never below 1
# NOTE(review): presumably a wait time in milliseconds - confirm in env.render
render_wait = max(int(1e3 / fps), 1)

# note: the game will play and the training/testing loop will
# run until the cell is interrupted
while True:
    done = False

    print("======================================================")
    print(f"steps: {bot.step}")
    print(f"eps: {bot.epsilon}")

    env.reset()

    prev_lines = env.line_clears

    # counter of states seen in this game, starting at 1
    st = 1

    while True:
        # line clears for this state; clears that happened between
        # last state and this one
        curr_lines = env.line_clears - prev_lines

        prev_lines = env.line_clears

        # reward function; 1 for each piece placed plus a quadratic
        # factor of the line clears (similar to NES tetris scoring)
        state_reward = 1 + curr_lines**2

        curr_state = env.get_wboard()

        # two-piece lookahead is only used in eval mode; training always
        # considers the current piece alone
        if force_model and use_queue:
            sps = get_sps_two_blocks(curr_state, env.curr_piece, env.next_piece)
        else:
            sps = get_sps_one_block(curr_state, env.curr_piece)

        # this indicates that there is nowhere to place the
        # current tetromino (or two tetrominoes), so a loss
        if sps["sps"].shape[0] == 0:  
            done = True

            # this doesnt really matter since the model value for this
            # will never be used
            # NOTE(review): 12 presumably matches the feature count of
            # batchedfeatures for a 10-wide board (10 heights + holes +
            # clears) - confirm
            sps_features = torch.zeros(
                (1, 12), device=bot.device, dtype=torch.float64
            )
        else:
            # features of the upcoming states;
            # read sps as "s-primes", multiple of s'
            sps_features = batchedfeatures(
                sps["sps"], device=bot.device, exlines=sps["exlines"]
            )

            # index of the state in sps with the greatest estimated q-value;
            # we will use this index to access everything else
            # about this state from sps
            bot_act_ind: int = bot.act(sps_features, force_model=force_model)
            move_rot = sps["movement_rotation"]
            # column 1 of move_rot is the rotation, column 0 the movement
            actions: list = construct_actionlist(
                move_rot[bot_act_ind, 1], move_rot[bot_act_ind, 0]
            )
            # after the moves and rotations, keep stepping straight down
            # (y-movement of 1) until the environment reports a placement
            push_down = np.zeros((30, 3), dtype=int)
            push_down[:, 1] = 1
            actions += push_down.tolist()

            placed = False

            # perform each actual action and then move down until placed
            for m in actions:
                placed = env.step(m)

                if placed:
                    break

                if render and mid_render:  # rendering block on its way down
                    env.render(render_wait)

            if not force_model:  # force model is essentially eval mode
                bot.learn()

        # at st==1, the curr state was not "reached" by any means,
        # so the bot doesnt have to know how to estimate its q-value
        if st != 1 and not force_model:
            curr_features = batchedfeatures(
                curr_state[None], device=bot.device, exlines=np.array(curr_lines)[None]
            )
            # stored transition: (features of this state, features of its
            # candidate after-states, reward, done flag as an int)
            bot.cache((curr_features, sps_features, state_reward, int(done)))

        st += 1
        if render:
            env.render(render_wait)

        if done:
            # same formula as state_reward, summed over the whole game:
            # one per piece plus (lines in a clear)**2 for every clear
            print(
                "total reward:",
                env.n_pieces
                + np.sum(
                    [env.clear_types[str(i + 1)] * (i + 1) ** 2 for i in range(4)]
                ),
            )
            print("total lines:", env.line_clears)
            print("total pieces:", env.n_pieces)
            break