### Load the dataset

In [None]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()    

In [138]:
# Asset root for rendering. You can change this if you want to use custom game assets.
ASSET_ROOT = "../data/assets/"

In [139]:
# Rendering and display
from grid_universe.renderer.texture import TextureRenderer
from IPython.display import display

In [140]:
# Default renderer used throughout the notebook unless overridden in a cell
renderer = TextureRenderer(resolution=240, asset_root=ASSET_ROOT)
renderer_large = TextureRenderer(resolution=480, asset_root=ASSET_ROOT)

In [141]:
import os
import numpy as np
import torch, torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from typing import List
from PIL import Image, ImageDraw
from IPython.display import display
from sklearn.model_selection import train_test_split
import time
import random

In [142]:
### INSPIRED BY VGG16 architecture
def get_model(num_classes: int) -> nn.Module:
    res = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.LeakyReLU(0.1),
        nn.BatchNorm2d(8),
        nn.MaxPool2d(2),
        nn.Conv2d(8, 16, kernel_size=3, padding=1),
        nn.LeakyReLU(0.1),
        nn.BatchNorm2d(16),
        nn.MaxPool2d(2),

        nn.Flatten(),

        nn.Linear(3136, 70),
        nn.LeakyReLU(0.1),
        nn.Dropout(p=0.3),
        nn.Linear(70, num_classes)
    )
    return res



### The augmentation needs to randomly draw the directional triangle into the picture classes where there could be a direction

In [98]:

def draw_direction_triangles_on_image(
    image: Image.Image, size: int, dx: int, dy: int, count: int
) -> Image.Image:
    """
    Draw 'count' filled triangles pointing (dx, dy) on the given RGBA image.
    Triangles are centered: the centroid of each triangle is symmetrically arranged
    around the image center. Spacing is between triangle centroids.
    """
    if count <= 0 or (dx, dy) == (0, 0):
        return image

    draw = ImageDraw.Draw(image)
    cx, cy = size // 2, size // 2

    # Triangle geometry (relative to size)
    tri_height = max(4, int(size * 0.16))
    tri_half_base = max(3, int(size * 0.10))
    spacing = max(2, int(size * 0.12))  # distance between triangle centroids

    # Axis-aligned direction and perpendicular
    ux, uy = dx, dy  # points toward the triangle tip
    px, py = -uy, ux  # perpendicular (for base width)

    # Offsets for centroids: 1 -> [0], 2 -> [-0.5s, +0.5s], 3 -> [-s, 0, +s], ...
    offsets = [(i - (count - 1) / 2.0) * spacing for i in range(count)]

    # For an isosceles triangle, the centroid lies 1/3 of the height from the base toward the tip.
    # If C is the centroid, then:
    #   tip = C + (2/3)*tri_height * u
    #   base_center = C - (1/3)*tri_height * u
    tip_offset = (2.0 / 3.0) * tri_height
    base_offset = (1.0 / 3.0) * tri_height

    for off in offsets:
        # Centroid position
        Cx = cx + int(round(ux * off))
        Cy = cy + int(round(uy * off))

        # Tip and base-center positions
        tip_x = int(round(Cx + ux * tip_offset))
        tip_y = int(round(Cy + uy * tip_offset))
        base_x = int(round(Cx - ux * base_offset))
        base_y = int(round(Cy - uy * base_offset))

        # Base vertices around base center along the perpendicular
        p1 = (tip_x, tip_y)
        p2 = (
            int(round(base_x + px * tri_half_base)),
            int(round(base_y + py * tri_half_base)),
        )
        p3 = (
            int(round(base_x - px * tri_half_base)),
            int(round(base_y - py * tri_half_base)),
        )

        draw.polygon([p1, p2, p3], fill=(255, 255, 255, 220), outline=(0, 0, 0, 220))

    return image



In [86]:
labels = ['boots', 'box', 'coin', 'dragon', 'exit', 'floor', 'gem', 'ghost', 'human', 'key', 'lava', 'locked', 'metalbox', 'opened', 'portal', 'robot', 'shield', 'sleeping', 'spike', 'wall', 'wolf', 'dragon']

class RandomDirection:
    def __init__(self):
        pass

    def __call__(self, image: Image.Image, label: int):
        if labels[label] == 'box' or labels[label] == 'metalbox' or labels[label] == 'robot':
            directions = [(1, 0), (-1, 0), (0, 1), (0, -1)]
            random.seed(time.time())
            decision = random.randint(-1, 3)

            if decision < 0:
                return image

            dx, dy = directions[decision]
            new_image: Image.Image = draw_direction_triangles_on_image(image=image, size=image.size[0], dx=dx, dy=dy, count=1)       

            return new_image
        return  image



In [87]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, transform):
        self.base_dataset = base_dataset
        self.random_direction = RandomDirection()
        self.targets = base_dataset.targets
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, label = self.base_dataset[idx]
        new_image: Image.Image = self.random_direction(image, label)
        image = self.transform(new_image)
        return image, label
    



In [88]:
def get_augmentations():
    T = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(58),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ])
    return T

In [89]:
ASSET_DIR = "../data/assets/imagen1"

base_dataset = ImageFolder(root=ASSET_DIR)
dataset = CustomDataset(base_dataset, transform=get_augmentations())

In [120]:
targets = dataset.targets
indices = list(range(len(targets)))

train_idx, test_idx = train_test_split(
    indices,
    test_size=0.2,
    stratify=targets,
    random_state=int(time.time())
)

train_data = Subset(dataset, train_idx)
test_data = Subset(dataset, test_idx)

In [121]:
train_labels = [dataset[i][1] for i in train_idx]
test_labels = [dataset[i][1] for i in test_idx]

In [122]:
train_loader = DataLoader(train_data, batch_size=256, shuffle=True)
test_loader = DataLoader(test_data, batch_size=len(test_data))

In [123]:
def get_accuracy(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    y_pred = torch.argmax(pred, dim=1).long()
    label = label.view(-1).long() 
    return (y_pred == label).float().mean()

def get_model_accuracy(model: nn.Module):
    with torch.no_grad():
        model.eval()

        sum_acc = 0
        cnt = 0

        for x, y in test_loader:
            pred = model(x)

            sum_acc += get_accuracy(pred, y)
            cnt += 1
        
        return float(sum_acc / cnt)

        

In [124]:
def train_model(loader: torch.utils.data.DataLoader, model: nn.Module):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()  

    epoch_losses = []
    for i in range(50):
        epoch_loss = 0
        model.train()
        for _, (x, y) in enumerate(loader):
            optimiser.zero_grad()
            # x, y = x.cuda(), y.cuda()
            y_pred = model(x)
            
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()
            loss.backward()
            optimiser.step()

        epoch_loss = epoch_loss / len(loader)
        epoch_losses.append(epoch_loss)
        print("Epoch: {}, Loss: {}, Accuracy: {}".format(i, epoch_loss, get_model_accuracy(model)))
        

    return model, epoch_losses

In [125]:
model = get_model(21)
train_model(train_loader, model)

Epoch: 0, Loss: 2.8652151823043823, Accuracy: 0.10891088843345642
Epoch: 1, Loss: 2.105020046234131, Accuracy: 0.1782178282737732
Epoch: 2, Loss: 1.6250371932983398, Accuracy: 0.3465346395969391
Epoch: 3, Loss: 1.2851938009262085, Accuracy: 0.5742574334144592
Epoch: 4, Loss: 1.0332130193710327, Accuracy: 0.6732673048973083
Epoch: 5, Loss: 0.8370908200740814, Accuracy: 0.6930692791938782
Epoch: 6, Loss: 0.7628727853298187, Accuracy: 0.7029703259468079
Epoch: 7, Loss: 0.5791459679603577, Accuracy: 0.7425742745399475
Epoch: 8, Loss: 0.4744180589914322, Accuracy: 0.7722772359848022
Epoch: 9, Loss: 0.37492232024669647, Accuracy: 0.8118811845779419
Epoch: 10, Loss: 0.31579071283340454, Accuracy: 0.8316831588745117
Epoch: 11, Loss: 0.30646124482154846, Accuracy: 0.8514851331710815
Epoch: 12, Loss: 0.27616460621356964, Accuracy: 0.8514851331710815
Epoch: 13, Loss: 0.22948893159627914, Accuracy: 0.8811880946159363
Epoch: 14, Loss: 0.2030021846294403, Accuracy: 0.8811880946159363
Epoch: 15, Loss

(Sequential(
   (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (1): LeakyReLU(negative_slope=0.1)
   (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (5): LeakyReLU(negative_slope=0.1)
   (6): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (8): Flatten(start_dim=1, end_dim=-1)
   (9): Linear(in_features=3136, out_features=70, bias=True)
   (10): LeakyReLU(negative_slope=0.1)
   (11): Dropout(p=0.3, inplace=False)
   (12): Linear(in_features=70, out_features=21, bias=True)
 ),
 [2.8652151823043823,
  2.105020046234131,
  1.6250371932983398,
  1.2851938009262085,
  1.0332130193710327,
  0.8370908200740814,
  0.7628727853298187,
  0.5791459679603577,
  0.47

In [135]:
torch.save(model.state_dict(), "image-classification-stratified-80-20")

In [136]:
def get_accuracy_for_each_label(model: nn.Module):
    for epoch in range(0, 3):
        with torch.no_grad():
            model.eval()
            
            acc: List[int] = [0 for i in range(21)]
            cnt: List[int] = [0 for i in range(21)]

            for x, y in test_loader:
                for label in y:
                    cnt[label] += 1
                
                pred = model(x)
                y_pred = torch.argmax(pred, dim=1).long()
                label = y.view(-1).long()

                for i in range(len(y_pred)):
                    if y_pred[i] == label[i]:
                        acc[label[i]] += 1
                    else:
                        print("WRONG LABEL, expected: {}, got: {}".format(label[i], y_pred[i]))

            for i in range(21):
                print("Total count: {}, label: {}, Accuracy: {}".format(cnt[i], i, float(acc[i] / cnt[i])))

get_accuracy_for_each_label(model)

WRONG LABEL, expected: 2, got: 17
Total count: 4, label: 0, Accuracy: 1.0
Total count: 4, label: 1, Accuracy: 1.0
Total count: 5, label: 2, Accuracy: 0.8
Total count: 5, label: 3, Accuracy: 1.0
Total count: 5, label: 4, Accuracy: 1.0
Total count: 5, label: 5, Accuracy: 1.0
Total count: 5, label: 6, Accuracy: 1.0
Total count: 5, label: 7, Accuracy: 1.0
Total count: 6, label: 8, Accuracy: 1.0
Total count: 5, label: 9, Accuracy: 1.0
Total count: 5, label: 10, Accuracy: 1.0
Total count: 4, label: 11, Accuracy: 1.0
Total count: 4, label: 12, Accuracy: 1.0
Total count: 5, label: 13, Accuracy: 1.0
Total count: 5, label: 14, Accuracy: 1.0
Total count: 5, label: 15, Accuracy: 1.0
Total count: 4, label: 16, Accuracy: 1.0
Total count: 5, label: 17, Accuracy: 1.0
Total count: 4, label: 18, Accuracy: 1.0
Total count: 5, label: 19, Accuracy: 1.0
Total count: 6, label: 20, Accuracy: 1.0
WRONG LABEL, expected: 2, got: 8
Total count: 4, label: 0, Accuracy: 1.0
Total count: 4, label: 1, Accuracy: 1.0
To

In [None]:
model = get_model(21)
model.load_state_dict(torch.load("image-classification-stratified-80-20", weights_only=True))

In [86]:
# Autoâ€‘generate a loader snippet for the trained PyTorch model
from utils import generate_torch_loader_snippet

example_input = torch.tensor(
    [[0.0, 0.0]], dtype=torch.float32
)  # minimal example for tracing if needed
snippet = generate_torch_loader_snippet(
    model=get_model(21), prefer="auto"
, compression="zlib")

with open("image_model.py", "w") as f:
    f.write(snippet)