### Load the dataset

In [83]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()    

In [84]:
# Asset root for rendering. You can change this if you want to use custom game assets.
ASSET_ROOT = "../data/assets/"

In [85]:
# Rendering and display
from grid_universe.renderer.texture import TextureRenderer
from IPython.display import display

In [86]:
# Default renderer used throughout the notebook unless overridden in a cell
renderer = TextureRenderer(resolution=240, asset_root=ASSET_ROOT)
renderer_large = TextureRenderer(resolution=480, asset_root=ASSET_ROOT)

In [87]:
import os
import numpy as np
import torch, torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from typing import List
from PIL import Image, ImageDraw
from IPython.display import display
from sklearn.model_selection import train_test_split
import time
import random

In [None]:
### INSPIRED BY VGG16 architecture
### WHY I USED BATCHNORM: https://towardsdatascience.com/exploring-the-superhero-role-of-2d-batch-normalization-in-deep-learning-architectures-b4eb869e8b60/
def get_model(num_classes: int) -> nn.Module:
    res = nn.Sequential(
        nn.Conv2d(4, 8, kernel_size=3, padding=1),
        nn.LeakyReLU(0.1),
        nn.BatchNorm2d(8),
        nn.MaxPool2d(2),
        nn.Conv2d(8, 16, kernel_size=3, padding=1),
        nn.LeakyReLU(0.1),
        nn.BatchNorm2d(16),
        nn.MaxPool2d(2),

        nn.Flatten(),

        nn.Linear(2304, 64),
        nn.LeakyReLU(0.1),
        nn.Linear(64, num_classes)
    )
    return res



### The augmentation needs to randomly draw the directional triangle into the picture classes where there could be a direction

In [None]:
### TAKEN FROM GRID UNIVERSE LIBRARY
def draw_direction_triangles_on_image(
    image: Image.Image, size: int, dx: int, dy: int, count: int
) -> Image.Image:
    """
    Draw 'count' filled triangles pointing (dx, dy) on the given RGBA image.
    Triangles are centered: the centroid of each triangle is symmetrically arranged
    around the image center. Spacing is between triangle centroids.
    """
    if count <= 0 or (dx, dy) == (0, 0):
        return image

    draw = ImageDraw.Draw(image)
    cx, cy = size // 2, size // 2

    # Triangle geometry (relative to size)
    tri_height = max(4, int(size * 0.16))
    tri_half_base = max(3, int(size * 0.10))
    spacing = max(2, int(size * 0.12))  # distance between triangle centroids

    # Axis-aligned direction and perpendicular
    ux, uy = dx, dy  # points toward the triangle tip
    px, py = -uy, ux  # perpendicular (for base width)

    # Offsets for centroids: 1 -> [0], 2 -> [-0.5s, +0.5s], 3 -> [-s, 0, +s], ...
    offsets = [(i - (count - 1) / 2.0) * spacing for i in range(count)]

    # For an isosceles triangle, the centroid lies 1/3 of the height from the base toward the tip.
    # If C is the centroid, then:
    #   tip = C + (2/3)*tri_height * u
    #   base_center = C - (1/3)*tri_height * u
    tip_offset = (2.0 / 3.0) * tri_height
    base_offset = (1.0 / 3.0) * tri_height

    for off in offsets:
        # Centroid position
        Cx = cx + int(round(ux * off))
        Cy = cy + int(round(uy * off))

        # Tip and base-center positions
        tip_x = int(round(Cx + ux * tip_offset))
        tip_y = int(round(Cy + uy * tip_offset))
        base_x = int(round(Cx - ux * base_offset))
        base_y = int(round(Cy - uy * base_offset))

        # Base vertices around base center along the perpendicular
        p1 = (tip_x, tip_y)
        p2 = (
            int(round(base_x + px * tri_half_base)),
            int(round(base_y + py * tri_half_base)),
        )
        p3 = (
            int(round(base_x - px * tri_half_base)),
            int(round(base_y - py * tri_half_base)),
        )

        draw.polygon([p1, p2, p3], fill=(255, 255, 255, 220), outline=(0, 0, 0, 220))

    return image



In [90]:
labels = ['boots', 'box', 'coin', 'exit', 'floor', 'gem', 'ghost', 'human', 'key', 'lava', 'locked', 'metalbox', 'opened', 'portal', 'robot', 'shield', 'sleeping', 'spike', 'wall', 'wolf', 'dragon']

ASSET_DIR = "../data/assets/imagen1"
class RandomDirection:
    def __init__(self):
        pass

    def __call__(self, image: Image.Image, label: int):
        # make image overlay with a background
        if labels[label] != 'floor':
            floor_dir = os.path.join(ASSET_DIR, "floor")
            floor_files = [
                f for f in os.listdir(floor_dir)
                if f.lower().endswith((".png"))
            ]

            if len(floor_files) > 0:
                fname = random.choice(floor_files)
                floor_path = os.path.join(floor_dir, fname)
                with open(floor_path, "rb") as f:
                    floor_img = Image.open(f).convert("RGBA")

                if floor_img.size != image.size:
                    floor_img = floor_img.resize(image.size)

                bg = floor_img.copy()
                bg.paste(image, (0, 0), image)
                image = bg

        if labels[label] == 'box' or labels[label] == 'metalbox' or labels[label] == 'robot':
            directions = [(1, 0), (-1, 0), (0, 1), (0, -1)]
            random.seed(time.time())
            decision = random.randint(-1, 3)

            if decision < 0:
                return image

            dx, dy = directions[decision]
            new_image: Image.Image = draw_direction_triangles_on_image(image=image, size=image.size[0], dx=dx, dy=dy, count=1)       

            return new_image
        return  image



In [91]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, transform):
        self.base_dataset = base_dataset
        self.random_direction = RandomDirection()
        self.targets = base_dataset.targets
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, label = self.base_dataset[idx]
        new_image = self.random_direction(image, label)
        image = self.transform(new_image)
        return image, label
    



In [None]:
### PARTS TAKEN FROM CHATGPT, conversation link: https://chatgpt.com/share/69115bb7-ae24-8008-a13d-26cafe249446 . PLEASE SCROLL TO THE BOTTOM

def get_augmentations():
    return transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(48),
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.5,
            saturation=0.3,
            hue=0.02,
        ),
        transforms.ToTensor(),
    ])

In [93]:
ASSET_DIR = "../data/assets/imagen1"

def rgba_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert("RGBA") 

base_dataset = ImageFolder(root=ASSET_DIR, loader=rgba_loader)
dataset = CustomDataset(base_dataset, transform=get_augmentations())

In [94]:
base_dataset[0][0].mode

'RGBA'

In [95]:
base_dataset.class_to_idx

{'boots': 0,
 'box': 1,
 'coin': 2,
 'exit': 3,
 'floor': 4,
 'gem': 5,
 'ghost': 6,
 'human': 7,
 'key': 8,
 'lava': 9,
 'locked': 10,
 'metalbox': 11,
 'opened': 12,
 'portal': 13,
 'robot': 14,
 'shield': 15,
 'sleeping': 16,
 'spike': 17,
 'wall': 18}

In [96]:
targets = dataset.targets
indices = list(range(len(targets)))

train_idx, test_idx = train_test_split(
    indices,
    test_size=0.4,
    random_state=int(time.time())
)

train_data = Subset(dataset, train_idx)
test_data = Subset(dataset, test_idx)

In [97]:
train_loader = DataLoader(train_data, batch_size=256, shuffle=True)
test_loader = DataLoader(test_data, batch_size=256, shuffle=True)

In [98]:
from matplotlib import pyplot as plt

In [99]:
def get_accuracy(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    y_pred = torch.argmax(pred, dim=1).long()
    label = label.view(-1).long() 
    return (y_pred == label).float().mean()

def get_model_accuracy(model: nn.Module):
    with torch.no_grad():
        model.eval()

        sum_acc = 0
        cnt = 0

        for x, y in test_loader:
            pred = model(x)

            sum_acc += get_accuracy(pred, y)
            cnt += 1
        
        return float(sum_acc / cnt)

        

In [100]:
def train_model(loader: torch.utils.data.DataLoader, model: nn.Module):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()  

    epoch_losses = []
    for i in range(50):
        epoch_loss = 0
        model.train()
        for _, (x, y) in enumerate(loader):
            optimiser.zero_grad()
            # x, y = x.cuda(), y.cuda()
            y_pred = model(x)
            
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()
            loss.backward()
            optimiser.step()

        epoch_loss = epoch_loss / len(loader)
        epoch_losses.append(epoch_loss)
        print("Epoch: {}, Loss: {}, Accuracy: {}".format(i, epoch_loss, get_model_accuracy(model)))
        

    return model, epoch_losses

In [101]:
model = get_model(21)
train_model(train_loader, model)

Epoch: 0, Loss: 2.906214714050293, Accuracy: 0.10497237741947174
Epoch: 1, Loss: 2.081867814064026, Accuracy: 0.14364640414714813
Epoch: 2, Loss: 1.5715546607971191, Accuracy: 0.24861878156661987
Epoch: 3, Loss: 1.3893590569496155, Accuracy: 0.22651933133602142
Epoch: 4, Loss: 1.134418249130249, Accuracy: 0.23204420506954193
Epoch: 5, Loss: 0.8313997089862823, Accuracy: 0.20994475483894348
Epoch: 6, Loss: 0.6465904116630554, Accuracy: 0.24309392273426056
Epoch: 7, Loss: 0.6601065993309021, Accuracy: 0.27624309062957764
Epoch: 8, Loss: 0.4334820806980133, Accuracy: 0.2983425557613373
Epoch: 9, Loss: 0.591965064406395, Accuracy: 0.33701658248901367
Epoch: 10, Loss: 0.25143951177597046, Accuracy: 0.45303866267204285
Epoch: 11, Loss: 0.3218632936477661, Accuracy: 0.5138121843338013
Epoch: 12, Loss: 0.34115029871463776, Accuracy: 0.6353591084480286
Epoch: 13, Loss: 0.2238486409187317, Accuracy: 0.6629834175109863
Epoch: 14, Loss: 0.21913491934537888, Accuracy: 0.7513812184333801
Epoch: 15, 

(Sequential(
   (0): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (1): LeakyReLU(negative_slope=0.1)
   (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (5): LeakyReLU(negative_slope=0.1)
   (6): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (8): Flatten(start_dim=1, end_dim=-1)
   (9): Linear(in_features=2304, out_features=64, bias=True)
   (10): LeakyReLU(negative_slope=0.1)
   (11): Linear(in_features=64, out_features=21, bias=True)
 ),
 [2.906214714050293,
  2.081867814064026,
  1.5715546607971191,
  1.3893590569496155,
  1.134418249130249,
  0.8313997089862823,
  0.6465904116630554,
  0.6601065993309021,
  0.4334820806980133,
  0.591965064406395,
  0.

In [102]:
get_model_accuracy(model)

0.9668508172035217

In [103]:
torch.save(model.state_dict(), "image_models/crop-center")

In [104]:
def get_accuracy_for_each_label(model: nn.Module):
    for epoch in range(0, 3):
        with torch.no_grad():
            model.eval()
            
            acc: List[int] = [0 for i in range(21)]
            cnt: List[int] = [0 for i in range(21)]

            for x, y in test_loader:
                for label in y:
                    cnt[label] += 1
                
                pred = model(x)
                y_pred = torch.argmax(pred, dim=1).long()
                label = y.view(-1).long()

                for i in range(len(y_pred)):
                    if y_pred[i] == label[i]:
                        acc[label[i]] += 1
                    else:
                        print("WRONG LABEL, expected: {}, got: {}".format(label[i], y_pred[i]))

            for i in range(21):
                print("Total count: {}, label: {}, Accuracy: {}".format(cnt[i], i, float(acc[i] / cnt[i])))

get_accuracy_for_each_label(model)

WRONG LABEL, expected: 13, got: 4
WRONG LABEL, expected: 10, got: 16
WRONG LABEL, expected: 10, got: 18
WRONG LABEL, expected: 14, got: 7
WRONG LABEL, expected: 15, got: 7
WRONG LABEL, expected: 6, got: 17
WRONG LABEL, expected: 10, got: 16
Total count: 13, label: 0, Accuracy: 1.0
Total count: 7, label: 1, Accuracy: 1.0
Total count: 13, label: 2, Accuracy: 1.0
Total count: 8, label: 3, Accuracy: 1.0
Total count: 8, label: 4, Accuracy: 1.0
Total count: 10, label: 5, Accuracy: 1.0
Total count: 7, label: 6, Accuracy: 0.8571428571428571
Total count: 13, label: 7, Accuracy: 1.0
Total count: 9, label: 8, Accuracy: 1.0
Total count: 11, label: 9, Accuracy: 1.0
Total count: 7, label: 10, Accuracy: 0.5714285714285714
Total count: 13, label: 11, Accuracy: 1.0
Total count: 10, label: 12, Accuracy: 1.0
Total count: 8, label: 13, Accuracy: 0.875
Total count: 10, label: 14, Accuracy: 0.9
Total count: 10, label: 15, Accuracy: 0.9
Total count: 7, label: 16, Accuracy: 1.0
Total count: 8, label: 17, Accu

ZeroDivisionError: division by zero

In [28]:
model = get_model(21)
model.load_state_dict(torch.load("image_models/crop-center"))

<All keys matched successfully>

In [29]:
# Autoâ€‘generate a loader snippet for the trained PyTorch model
from utils import generate_torch_loader_snippet

example_input = torch.tensor(
    [[0.0, 0.0]], dtype=torch.float32
)  # minimal example for tracing if needed
snippet = generate_torch_loader_snippet(
    model=model, prefer="auto"
, compression="zlib")

with open("exported_models/image_model.py", "w") as f:
    f.write(snippet)