In [32]:
import torch
import torch.nn as nn
import math

def positional_encoding_2d(d_model, height, width):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dimension (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each dimension use half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0., d_model, 2) *
                         -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe

In [33]:
class GameOfLifeTransformer(nn.Module):
    def __init__(self, grid_size, d_model, nhead, num_layers):
        super(GameOfLifeTransformer, self).__init__()
        self.grid_size = grid_size
        self.d_model = d_model
        
        self.embedding = nn.Linear(1, d_model)
        self.pos_encoder = positional_encoding_2d(d_model, grid_size, grid_size)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        
        self.output_layer = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, src):
        # src shape: (batch_size, num_frames, grid_size, grid_size)
        batch_size, num_frames, height, width = src.shape
        
        # Reshape and embed
        src = src.reshape(batch_size * num_frames, height * width, 1)
        src = self.embedding(src)  # (batch_size * num_frames, grid_size*grid_size, d_model)
        
        # Reshape positional encoding to match src
        pos_encoding = self.pos_encoder.view(1, height * width, self.d_model)
        pos_encoding = pos_encoding.repeat(batch_size * num_frames, 1, 1)
        
        src = src + pos_encoding.to(src.device)
        
        src = src.permute(1, 0, 2)  # (grid_size*grid_size, batch_size * num_frames, d_model)
        
        output = self.transformer_encoder(src)
        output = self.output_layer(output)
        output = self.sigmoid(output)
        
        return output.permute(1, 2, 0).reshape(batch_size, num_frames, height, width)

In [34]:
import numpy as np

def create_game_of_life_dataset(num_samples, grid_size, num_steps):
    def update(frame):
        # Game of Life update rules
        n = sum([np.roll(np.roll(frame, i, 0), j, 1)
                 for i in (-1, 0, 1) for j in (-1, 0, 1)
                 if (i != 0 or j != 0)])
        return ((n == 3) | ((frame == 1) & (n == 2))).astype(int)

    dataset = []
    for _ in range(num_samples):
        # Random initial state
        initial_state = np.random.choice([0, 1], size=(grid_size, grid_size))
        sequence = [initial_state]
        for _ in range(num_steps - 1):
            next_state = update(sequence[-1])
            sequence.append(next_state)
        dataset.append(np.array(sequence))
    
    return np.array(dataset)

In [35]:
import torch.optim as optim

# Hyperparameters
grid_size = 10
d_model = 64
nhead = 4
num_layers = 3
batch_size = 32
num_epochs = 50
learning_rate = 0.001

# Create dataset
dataset = create_game_of_life_dataset(1000, grid_size, 10)
dataset = torch.FloatTensor(dataset)

# Create model
model = GameOfLifeTransformer(grid_size, d_model, nhead, num_layers)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Training loop
for epoch in range(num_epochs):
    total_loss = 0
    for i in range(0, dataset.shape[0], batch_size):
        batch = dataset[i:i+batch_size]
        inputs = batch[:, :-1]  # All frames except the last
        targets = batch[:, 1:]  # All frames except the first
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

# Save the model
torch.save(model.state_dict(), "game_of_life_transformer.pth")

Using device: mps
Epoch 1/50, Loss: 17.9451
Epoch 2/50, Loss: 15.4759
Epoch 3/50, Loss: 15.0648
Epoch 4/50, Loss: 14.9679
Epoch 5/50, Loss: 14.9376
Epoch 6/50, Loss: 14.9256
Epoch 7/50, Loss: 14.9219
Epoch 8/50, Loss: 14.9182
Epoch 9/50, Loss: 14.9151
Epoch 10/50, Loss: 14.9101
Epoch 11/50, Loss: 14.9101
Epoch 12/50, Loss: 14.9091
Epoch 13/50, Loss: 14.9091
Epoch 14/50, Loss: 14.9070
Epoch 15/50, Loss: 14.9050
Epoch 16/50, Loss: 14.9040
Epoch 17/50, Loss: 14.9049
Epoch 18/50, Loss: 14.9052
Epoch 19/50, Loss: 14.9034
Epoch 20/50, Loss: 14.9037
Epoch 21/50, Loss: 14.9037
Epoch 22/50, Loss: 14.9043
Epoch 23/50, Loss: 14.9032
Epoch 24/50, Loss: 14.9033
Epoch 25/50, Loss: 14.9018
Epoch 26/50, Loss: 14.9016
Epoch 27/50, Loss: 14.9009
Epoch 28/50, Loss: 14.9007
Epoch 29/50, Loss: 14.9003
Epoch 30/50, Loss: 14.9014
Epoch 31/50, Loss: 14.9013
Epoch 32/50, Loss: 14.8993
Epoch 33/50, Loss: 14.8990
Epoch 34/50, Loss: 14.8964
Epoch 35/50, Loss: 14.8967
Epoch 36/50, Loss: 14.8968
Epoch 37/50, Loss: 