In [1]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()    

In [2]:
ASSET_DIR = "../data/assets/with_direction"
# Asset root for rendering. You can change this if you want to use custom game assets.
ASSET_ROOT = "../data/assets/"

In [3]:
# Rendering and display
from grid_universe.renderer.texture import TextureRenderer
from IPython.display import display
# Default renderer used throughout the notebook unless overridden in a cell
renderer = TextureRenderer(resolution=240, asset_root=ASSET_ROOT)
renderer_large = TextureRenderer(resolution=480, asset_root=ASSET_ROOT)

In [4]:
import os
import numpy as np
import torch, torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from typing import List
from PIL import Image, ImageDraw
from IPython.display import display
from sklearn.model_selection import train_test_split
import time
import random

In [52]:
def get_model(num_classes: int) -> nn.Module:
    res = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.LeakyReLU(0.1),
        nn.BatchNorm2d(8),
        nn.MaxPool2d(2),

        nn.Flatten(),

        nn.Linear(1800, 64),
        nn.LeakyReLU(0.1),
        nn.Linear(64, num_classes),
    )
    return res    

In [53]:
def draw_direction_triangles_on_image(
    image: Image.Image, size: int, dx: int, dy: int, count: int
) -> Image.Image:
    """
    Draw 'count' filled triangles pointing (dx, dy) on the given RGBA image.
    Triangles are centered: the centroid of each triangle is symmetrically arranged
    around the image center. Spacing is between triangle centroids.
    """
    if count <= 0 or (dx, dy) == (0, 0):
        return image

    draw = ImageDraw.Draw(image)
    cx, cy = size // 2, size // 2

    # Triangle geometry (relative to size)
    tri_height = max(4, int(size * 0.16))
    tri_half_base = max(3, int(size * 0.10))
    spacing = max(2, int(size * 0.12))  # distance between triangle centroids

    # Axis-aligned direction and perpendicular
    ux, uy = dx, dy  # points toward the triangle tip
    px, py = -uy, ux  # perpendicular (for base width)

    # Offsets for centroids: 1 -> [0], 2 -> [-0.5s, +0.5s], 3 -> [-s, 0, +s], ...
    offsets = [(i - (count - 1) / 2.0) * spacing for i in range(count)]

    # For an isosceles triangle, the centroid lies 1/3 of the height from the base toward the tip.
    # If C is the centroid, then:
    #   tip = C + (2/3)*tri_height * u
    #   base_center = C - (1/3)*tri_height * u
    tip_offset = (2.0 / 3.0) * tri_height
    base_offset = (1.0 / 3.0) * tri_height

    for off in offsets:
        # Centroid position
        Cx = cx + int(round(ux * off))
        Cy = cy + int(round(uy * off))

        # Tip and base-center positions
        tip_x = int(round(Cx + ux * tip_offset))
        tip_y = int(round(Cy + uy * tip_offset))
        base_x = int(round(Cx - ux * base_offset))
        base_y = int(round(Cy - uy * base_offset))

        # Base vertices around base center along the perpendicular
        p1 = (tip_x, tip_y)
        p2 = (
            int(round(base_x + px * tri_half_base)),
            int(round(base_y + py * tri_half_base)),
        )
        p3 = (
            int(round(base_x - px * tri_half_base)),
            int(round(base_y - py * tri_half_base)),
        )

        draw.polygon([p1, p2, p3], fill=(255, 255, 255, 220), outline=(0, 0, 0, 220))

    return image

In [54]:
labels = ["no_direction", "right", "left", "down", "up"]
directions = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]

class RandomDirection:
    def __init__(self):
        pass

    def __call__(self, image: Image.Image):
        random.seed(time.time())
        decision = random.randint(0, 4)

        if decision == 0:
            return image, 0

        dx, dy = directions[decision]
        new_image: Image.Image = draw_direction_triangles_on_image(image=image, size=image.size[0], dx=dx, dy=dy, count=1)       

        return new_image, decision


In [55]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, transform):
        self.base_dataset = base_dataset
        self.random_direction = RandomDirection()
        self.targets = base_dataset.targets
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, _ = self.base_dataset[idx]
        new_image, new_label = self.random_direction(image)
        image = self.transform(new_image)
        return image, new_label


In [56]:
def get_augmentations():
    T = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop((30, 30)),
        transforms.ToTensor()
    ])
    return T

In [57]:
ASSET_DIR = "../data/assets/with_direction"

base_dataset = ImageFolder(root=ASSET_DIR)
dataset = CustomDataset(base_dataset, transform=get_augmentations())

(image, label) = dataset[11]

print(labels[label])
display(image)

right


tensor([[[0.7255, 0.7333, 0.6980,  ..., 0.7098, 0.7725, 0.7843],
         [0.7255, 0.7412, 0.7255,  ..., 0.7608, 0.7373, 0.7412],
         [0.6824, 0.7216, 0.7373,  ..., 0.7451, 0.7765, 0.7412],
         ...,
         [0.6941, 0.7451, 0.7373,  ..., 0.6863, 0.6706, 0.7020],
         [0.7216, 0.7137, 0.6588,  ..., 0.6784, 0.6353, 0.7373],
         [0.6863, 0.6784, 0.6863,  ..., 0.6980, 0.6863, 0.7176]],

        [[0.4078, 0.4078, 0.3725,  ..., 0.3490, 0.4196, 0.4510],
         [0.4118, 0.4196, 0.3961,  ..., 0.4549, 0.4314, 0.4078],
         [0.3686, 0.4000, 0.4118,  ..., 0.4078, 0.4627, 0.4157],
         ...,
         [0.3882, 0.4275, 0.4235,  ..., 0.3725, 0.3569, 0.3765],
         [0.4196, 0.3922, 0.3373,  ..., 0.3529, 0.3294, 0.4000],
         [0.3686, 0.3451, 0.3569,  ..., 0.3686, 0.3647, 0.3961]],

        [[0.1373, 0.1490, 0.1059,  ..., 0.0471, 0.1333, 0.1725],
         [0.1490, 0.1451, 0.1294,  ..., 0.1725, 0.1333, 0.1451],
         [0.0980, 0.1255, 0.1333,  ..., 0.1412, 0.1569, 0.

In [58]:
indices = list(range(len(dataset)))
targets = [dataset[i][1] for i in range(len(dataset))]

train_idx, test_idx = train_test_split(
    indices,
    test_size=0.3,
    random_state=int(time.time()),
    stratify=targets
)

train_data = Subset(dataset, train_idx)
test_data = Subset(dataset, test_idx)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16)

In [59]:
def get_accuracy(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    y_pred = torch.argmax(pred, dim=1).long()
    label = label.view(-1).long() 
    return (y_pred == label).float().mean()

def get_model_accuracy(model: nn.Module):
    with torch.no_grad():
        model.eval()

        sum_acc = 0
        cnt = 0

        for x, y in test_loader:
            pred = model(x)

            sum_acc += get_accuracy(pred, y)
            cnt += 1
        
        return float(sum_acc / cnt)

        

In [60]:
def train_model(loader: torch.utils.data.DataLoader, model: nn.Module):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()  

    epoch_losses = []
    for i in range(50):
        epoch_loss = 0
        model.train()
        for _, (x, y) in enumerate(loader):
            optimiser.zero_grad()
            # x, y = x.cuda(), y.cuda()
            y_pred = model(x)
            
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()
            loss.backward()
            optimiser.step()

        epoch_loss = epoch_loss / len(loader)
        epoch_losses.append(epoch_loss)
        print("Epoch: {}, Loss: {}, Accuracy: {}".format(i, epoch_loss, get_model_accuracy(model)))
        

    return model, epoch_losses

In [61]:
model = get_model(5)
train_model(train_loader, model)

Epoch: 0, Loss: 1.2412742376327515, Accuracy: 0.5625
Epoch: 1, Loss: 0.5429351925849915, Accuracy: 0.875
Epoch: 2, Loss: 0.24108202010393143, Accuracy: 0.84375
Epoch: 3, Loss: 0.1644803285598755, Accuracy: 0.96875
Epoch: 4, Loss: 0.1363201141357422, Accuracy: 1.0
Epoch: 5, Loss: 0.048028115183115005, Accuracy: 1.0
Epoch: 6, Loss: 0.018022849806584418, Accuracy: 1.0
Epoch: 7, Loss: 0.0075950766913592815, Accuracy: 1.0
Epoch: 8, Loss: 0.0048752520233392715, Accuracy: 1.0
Epoch: 9, Loss: 0.005527986562810838, Accuracy: 1.0
Epoch: 10, Loss: 0.003145370981656015, Accuracy: 1.0
Epoch: 11, Loss: 0.002376986318267882, Accuracy: 1.0
Epoch: 12, Loss: 0.0015924777835607529, Accuracy: 1.0
Epoch: 13, Loss: 0.0016375085397157818, Accuracy: 1.0
Epoch: 14, Loss: 0.0019424646743573248, Accuracy: 1.0
Epoch: 15, Loss: 0.0030713254818692803, Accuracy: 1.0
Epoch: 16, Loss: 0.0009081304888240993, Accuracy: 1.0
Epoch: 17, Loss: 0.001038972637616098, Accuracy: 1.0
Epoch: 18, Loss: 0.0008256497822003439, Accur

(Sequential(
   (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (1): LeakyReLU(negative_slope=0.1)
   (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (4): Flatten(start_dim=1, end_dim=-1)
   (5): Linear(in_features=1800, out_features=64, bias=True)
   (6): LeakyReLU(negative_slope=0.1)
   (7): Linear(in_features=64, out_features=5, bias=True)
 ),
 [1.2412742376327515,
  0.5429351925849915,
  0.24108202010393143,
  0.1644803285598755,
  0.1363201141357422,
  0.048028115183115005,
  0.018022849806584418,
  0.0075950766913592815,
  0.0048752520233392715,
  0.005527986562810838,
  0.003145370981656015,
  0.002376986318267882,
  0.0015924777835607529,
  0.0016375085397157818,
  0.0019424646743573248,
  0.0030713254818692803,
  0.0009081304888240993,
  0.001038972637616098,
  0.0008256497822003439,
  0.0005106575263198465,
  0.0003867987252306193,


In [64]:
torch.save(model.state_dict(), "direction_models/direction-classification-v0")

In [67]:
# Autoâ€‘generate a loader snippet for the trained PyTorch model
from utils import generate_torch_loader_snippet

example_input = torch.tensor(
    [[0.0, 0.0]], dtype=torch.float32
)  # minimal example for tracing if needed
snippet = generate_torch_loader_snippet(
    model=get_model(21), prefer="auto"
, compression="zlib")

with open("exported_models/direction_model.py", "w") as f:
    f.write(snippet)