In [1]:
import torch, torchvision
import torch.nn.functional as F
from torch.optim.adamw import AdamW

In [2]:
def init_board(img_path: str, init_val: float = None) -> torch.Tensor:
    # Load RGBA → float32 ∈ [0,1]
    img = torchvision.io.read_image(img_path,
                           mode=torchvision.io.image.ImageReadMode.RGB_ALPHA).float() / 255.0
    # Binary alpha: 0 if fully transparent else 1
    img[3] = (img[3] > 0).float()
    C, H, W = img.shape
    if init_val is not None:
        features = torch.fill(torch.zeros(16, H, W), init_val)
    else:
        features = torch.rand(16, H, W)
        
    features[3:16, H // 2, W // 2] = 1 # set the seed
    
    return features


board = init_board("data/mudkip.png", init_val=0)

In [None]:
def get_perception(state_grid: torch.Tensor) -> torch.Tensor:
    """
    Applies Sobel filters to each channel of the input state_grid.
    """
    
    # Sobel kernels
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=state_grid.dtype, device=state_grid.device)
    sobel_y = sobel_x.t()

    # Prepare kernels for depthwise conv2d
    C, H, W = state_grid.shape
    sobel_x = sobel_x.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
    sobel_y = sobel_y.view(1, 1, 3, 3).repeat(C, 1, 1, 1)

    # Add batch dimension
    x_grid = F.conv2d(state_grid, sobel_x, padding=1, groups=C)
    y_grid = F.conv2d(state_grid, sobel_y, padding=1, groups=C)

    state_grid = torch.cat((state_grid, x_grid, y_grid), dim=0)
    return state_grid

a_perception = get_perception(board)

In [4]:
from torch import nn

class FancyCARule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(48, 128, 3, padding=1)
        self.conv2 = nn.Conv2d(128, 16, 3, padding=1)
        self.act = nn.ReLU()
    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        return x
    
class CARule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Linear(48, 128)
        self.conv2 = nn.Linear(128, 16)
        self.act = nn.ReLU()
    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        return x

In [107]:
EPS = 5

update = FancyCARule()

def fancy_update_board(state_grid: torch.Tensor) -> torch.Tensor:
    C, H, W = state_grid.shape
    rand_mask = (torch.abs(torch.randn(H, W)) > EPS) # boolean mask to update 
    features_to_update = state_grid[:, rand_mask]
    
    idxs = rand_mask.nonzero()
    padded_state_grid = F.pad(state_grid, (1, 1, 1, 1), value=0)
    patches = []
    for y, x in idxs:
        patches.append(padded_state_grid[:, y-1: y+1, x-1: x+1])
    
    feature_slices = torch.stack(patches, dim=0)
    
    feature_step = update(feature_slices)
    
    print(feature_step.shape)
    
    state_grid[:, rand_mask] = state_grid[:, rand_mask] + feature_step
    
    return state_grid

update = CARule()

def update_board(perception: torch.Tensor) -> torch.Tensor:
    _, H, W = perception.shape
    rand_mask = (torch.abs(torch.randn(H, W)) > EPS) # boolean mask to selectively update cells
    features_to_update = perception[:, rand_mask].transpose(1, 0)
    feature_step = update(features_to_update)
    rt = torch.zeros(16, H, W)
    rt[0:16, rand_mask] = feature_step.transpose(1,0)
    
    return perception[:16, :, :]

a_new_board = update_board(a_perception)

In [108]:
# train loop

STEPS = 10000
num_iters_per_update = (64, 96)

update_rule = CARule()

def train(img_path: str):
    board = init_board(img_path)
    target = torchvision.io.read_image(img_path,
                           mode=torchvision.io.image.ImageReadMode.RGB_ALPHA).float() / 255.0
    target = target[:3, :, :]
    opt = AdamW(update_rule.parameters())
    loss_fn = torch.nn.MSELoss()
    
    steps_till_opt = int((64 + (96 - 64) * torch.rand(1)).item())
    for steps in range(STEPS):
        alive_mask = (F.max_pool2d(board[3:4, None, :, :], 3, stride=1, padding=1) > 0.1).int()[0, 0]
        perception = get_perception(board)
        dboard = update_board(perception)
        board = board + dboard * alive_mask
        steps_till_opt -= 1
        
        if steps_till_opt == 0:
            steps_till_opt = int((64 + (96 - 64) * torch.rand(1)).item())
            board_img = board[:3, :, :]
            loss_val = loss_fn(target, board_img)
            loss_val.backward()
            opt.step()
            opt.zero_grad()
            board = board.detach()
            torchvision.utils.save_image(board_img, f"board_img_step_{steps}.png")
            print(loss_val)

In [109]:
train('data/mudkip.png')

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn