In [1]:
import torch
from torch import nn
from torchvision.transforms import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# This notebook only performs inference, so autograd is switched off for the
# whole session instead of wrapping each call in torch.no_grad().
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2d00a499250>

In [2]:
# Constants
# Size of a single board cell, in pixels of the downscaled screenshot.
PUYO_W = 16
PUYO_H = 15

# Label for each classifier output index; the argmax of a prediction is used
# to index into this array.
# NOTE(review): presumably '0' = empty cell, 'J' = garbage puyo and the other
# letters are colour initials (Red/Green/Blue/Yellow/Purple) -- confirm.
class_names = np.array(list("0JRGBYP"))

In [3]:
# Helper functions
def crop(img, x, y, width, height):
    """Return the `height` x `width` window of `img` whose top-left is (x, y).

    `img` is indexed as img[row, column], so `y` selects rows and `x` selects
    columns. Ordinary slice semantics apply: a window that runs past the edge
    of the array is silently clipped rather than raising.
    """
    rows = slice(y, y + height)
    cols = slice(x, x + width)
    return img[rows, cols]

def crop_field_cells(img):
    """Split a board image into its 12 x 6 grid of cell crops.

    Cells are PUYO_H pixels tall and PUYO_W pixels wide. The returned list is
    in row-major order: all six columns of the top row first, then the next
    row down, and so on (72 crops in total).
    """
    return [
        crop(img, col * PUYO_W, row * PUYO_H, PUYO_W, PUYO_H)
        for row in range(12)
        for col in range(6)
    ]

In [4]:
# Classifier
class PuyoClassifier(nn.Module):
    """Small fully connected network that classifies one puyo cell image.

    Input is flattened to 720 features per sample (3 * PUYO_H * PUYO_W with
    the 15 x 16 cell size used in this notebook) and mapped to 7 outputs,
    one per entry of `class_names`.
    """

    def __init__(self):
        super().__init__()
        # The layers must stay in this order inside `fc`: the parameter names
        # in the state dict (fc.0.*, fc.2.*) are derived from these positions.
        layers = [
            nn.Linear(720, 50),
            nn.ReLU(),
            nn.Linear(50, 7),
            nn.ReLU(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample in `x` and return the network's 7 outputs per sample."""
        flat = x.view(-1, 3 * PUYO_H * PUYO_W)
        return self.fc(flat)

# Build the network and restore the trained weights.
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# that was saved from a GPU still loads when this notebook falls back to CPU.
model = PuyoClassifier()
model.load_state_dict(
    torch.load("puyo-classifier-feb8-2021.pt", map_location=device)
)
model.to(device)
# eval() puts the module in inference mode and returns it, so leaving it as
# the last expression also echoes the architecture in the cell output.
model.eval()

PuyoClassifier(
  (fc): Sequential(
    (0): Linear(in_features=720, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=7, bias=True)
    (3): ReLU()
  )
)

In [20]:
# Top-left corner of each player's board, in pixels of the downscaled image
# that predict_field crops from.
p1 = dict(x=70, y=40)
p2 = dict(x=314, y=40)

def predict_field(img: str, player):
    """Run the classifier on every cell of one player's board.

    Args:
        img: Path of the screenshot image file to read.
        player: Mapping with "x" and "y" keys giving the top-left corner of
            the player's board in the downscaled (480 x 270) image.

    Returns:
        The model's raw output tensor with one row per board cell, in
        row-major cell order (72 rows of 7 scores for the 12 x 6 board).
    """
    # Open inside a context manager so the file handle is closed promptly,
    # and force three RGB channels: an RGBA or palette PNG would otherwise
    # produce a channel count that does not match the 3 * H * W model input.
    with Image.open(img) as source:
        small = source.convert("RGB").resize((1920 // 4, 1080 // 4))
    im = np.array(small)

    field = crop(im, player["x"], player["y"], PUYO_W * 6, PUYO_H * 12)
    cells = crop_field_cells(field)

    # Stack the cell crops into one array: (cell, height, width, channel).
    puyos = np.array(cells)

    # Put in NCHW order
    puyos = puyos.transpose((0, 3, 1, 2))

    # Convert np array to tensor
    puyo_t = torch.tensor(puyos, device=device, dtype=torch.float).contiguous()

    # Normalize to how the data was setup during model training:
    # scale 0..255 to 0..1, then shift/scale to -1..1.
    puyo_t = ((puyo_t / 255) - 0.5) / 0.5

    return model(puyo_t)

def get_max_pred(pred: torch.Tensor):
    """Turn per-cell class scores into a 12 x 6 grid of class indices.

    `pred` holds one row of scores per cell (72 rows). The index of the
    highest score in each row is taken and the result is returned as a NumPy
    array laid out like the board: 12 rows by 6 columns.
    """
    best = torch.argmax(pred, dim=1)
    grid = best.cpu().numpy()
    return grid.reshape(12, 6)


## Gummy Puyos
![Gummy](./images/regions/p1field.png)

In [21]:
# Classify both players' boards from the same image file and print each
# grid of class labels. The board offsets come from p1 / p2.
for label, player in (("Player 1", p1), ("Player 2", p2)):
    pred = predict_field("./images/regions/p1field.png", player)
    results = get_max_pred(pred)
    print(label)
    print(class_names[results])
    print()

Player 1
[['R' 'P' 'P' 'B' 'Y' 'B']
 ['P' 'R' 'Y' 'P' 'R' 'Y']
 ['R' 'P' 'Y' 'R' 'P' 'Y']
 ['P' 'P' 'R' 'R' 'B' 'R']
 ['R' 'R' 'Y' 'B' 'R' 'B']
 ['P' 'Y' 'P' 'B' 'R' 'R']
 ['P' 'B' 'R' 'Y' 'P' 'B']
 ['Y' 'P' 'P' 'B' 'P' 'P']
 ['R' 'R' 'R' 'P' 'Y' 'B']
 ['B' 'P' 'P' 'Y' 'B' 'Y']
 ['B' 'B' 'Y' 'Y' 'B' 'B']
 ['R' 'P' 'P' 'B' 'Y' 'P']]

Player 2
[['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' 'R']
 ['P' 'B' '0' '0' 'R' 'Y']
 ['P' 'R' 'R' '0' 'R' 'P']
 ['B' 'P' 'P' 'Y' 'P' 'P']
 ['B' 'P' 'B' 'Y' 'R' 'R']
 ['Y' 'Y' 'Y' 'B' 'P' 'B']
 ['P' 'B' 'B' 'R' 'P' 'B']
 ['P' 'Y' 'Y' 'R' 'Y' 'Y']
 ['B' 'P' 'P' 'B' 'B' 'R']]



## Sonic Puyos
![Sonic](./images/regions/sonic.png)

In [22]:
# Classify both players' boards from the same image file and print each
# grid of class labels. The board offsets come from p1 / p2.
for label, player in (("Player 1", p1), ("Player 2", p2)):
    pred = predict_field("./images/regions/sonic.png", player)
    results = get_max_pred(pred)
    print(label)
    print(class_names[results])
    print()

Player 1
[['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['Y' 'B' '0' '0' '0' 'R']
 ['R' 'R' '0' '0' '0' 'Y']
 ['B' 'B' '0' '0' '0' '0']
 ['G' 'G' '0' '0' '0' '0']
 ['Y' 'G' '0' '0' 'Y' '0']
 ['B' 'Y' 'B' 'B' 'R' '0']
 ['G' 'G' 'Y' 'Y' 'Y' 'R']
 ['G' 'B' 'B' 'B' 'R' 'G']
 ['B' 'R' 'G' 'R' 'B' 'G']
 ['R' 'G' 'G' 'R' 'Y' 'G']]

Player 2
[['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['G' '0' '0' '0' '0' '0']
 ['G' '0' '0' '0' '0' 'B']
 ['R' 'R' 'Y' '0' '0' 'B']
 ['Y' 'Y' 'Y' '0' '0' 'G']
 ['B' 'B' '0' '0' 'G' 'G']
 ['G' 'G' 'G' '0' 'Y' 'B']
 ['B' 'B' 'B' '0' 'Y' 'R']
 ['R' 'Y' 'R' 'Y' 'R' 'R']
 ['G' 'G' 'G' 'B' 'B' 'G']]



## Game Gear Puyos
![Game Gear](./images/regions/gamegear.png)

In [24]:
# Classify both players' boards from the same image file and print each
# grid of class labels. The board offsets come from p1 / p2.
for label, player in (("Player 1", p1), ("Player 2", p2)):
    pred = predict_field("./images/regions/gamegear.png", player)
    results = get_max_pred(pred)
    print(label)
    print(class_names[results])
    print()

Player 1
[['G' '0' '0' '0' '0' 'G']
 ['R' '0' '0' '0' '0' '0']
 ['G' '0' '0' '0' '0' '0']
 ['G' 'B' '0' '0' 'B' 'G']
 ['Y' 'B' 'P' 'G' 'P' 'R']
 ['G' 'P' 'G' 'Y' 'G' 'P']
 ['P' 'B' 'Y' 'B' 'P' 'B']
 ['Y' 'G' 'Y' 'B' 'B' 'G']
 ['P' 'Y' 'P' 'G' 'P' 'P']
 ['Y' 'R' 'B' 'P' 'R' 'P']
 ['R' 'Y' 'P' 'B' 'G' 'Y']
 ['G' 'Y' 'R' 'G' 'G' 'B']]

Player 2
[['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['P' '0' '0' '0' '0' '0']
 ['P' '0' '0' '0' '0' '0']
 ['R' '0' '0' '0' '0' 'B']
 ['R' '0' '0' '0' '0' 'P']
 ['Y' 'Y' '0' '0' '0' 'P']
 ['R' 'R' '0' '0' '0' 'B']
 ['Y' 'Y' 'Y' '0' '0' 'B']
 ['G' 'G' 'G' '0' '0' 'G']]



## MSX Puyos
![MSX](./images/regions/msx.png)

In [25]:
# Classify both players' boards from the same image file and print each
# grid of class labels. The board offsets come from p1 / p2.
for label, player in (("Player 1", p1), ("Player 2", p2)):
    pred = predict_field("./images/regions/msx.png", player)
    results = get_max_pred(pred)
    print(label)
    print(class_names[results])
    print()

Player 1
[['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' 'Y' '0' '0' '0' '0']
 ['0' 'G' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' 'B' '0' 'B' '0']
 ['0' '0' 'R' 'Y' 'P' '0']
 ['B' '0' 'B' 'R' 'R' 'P']
 ['P' 'B' 'G' 'Y' 'Y' 'P']]

Player 2
[['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' 'Y']
 ['0' '0' '0' '0' '0' 'G']
 ['0' 'G' 'Y' '0' 'R' 'G']
 ['0' 'P' 'Y' '0' 'Y' 'G']
 ['0' 'Y' 'P' 'R' 'Y' 'R']
 ['J' 'Y' 'P' 'B' 'B' 'R']
 ['P' 'P' 'J' 'J' 'G' 'J']
 ['B' 'Y' 'Y' 'P' 'B' 'B']]

