In [28]:
import torch
import torch.nn as nn

In [None]:
number_of_populations = 3

# Fix the RNG so the mask and the grid below are reproducible.
torch.manual_seed(0)

# A cell counts as empty only when the random draw is True for every population channel.
empty_cells = torch.all(
    torch.randint(0, 2, (2, number_of_populations, 5, 5), dtype=torch.bool),
    dim=1,
    keepdim=True,
)

# Random population values, zeroed wherever the cell is empty.
grid = torch.randn(2, number_of_populations, 5, 5).masked_fill(empty_cells, 0)

# 8-neighbour averaging kernel: ones around a zero centre, normalised to sum to 1.
kernel = torch.ones(3, 3, dtype=torch.float32)
kernel[1, 1] = 0
kernel = kernel / kernel.sum()

# One copy per channel for a depthwise conv: [out_channels, in_channels/groups, kH, kW] = [3, 1, 3, 3]
kernel = kernel[None, None].repeat(number_of_populations, 1, 1, 1)

# Circular (wrap-around) padding: border cells see the opposite edge as their neighbours.
grid_padded = torch.nn.functional.pad(grid, (1, 1, 1, 1), mode="circular")

# Depthwise convolution: groups == number of channels, so populations never mix.
new_fitness = torch.nn.functional.conv2d(
    grid_padded,
    kernel,
    stride=1,
    padding=0,
    groups=grid.size(1),
)

new_fitness.shape, empty_cells.shape

(torch.Size([2, 3, 5, 5]), torch.Size([2, 1, 5, 5]))

In [116]:
# Broadcast the single-channel emptiness mask to one copy per population channel.
empty_cells_expanded = empty_cells.expand(-1, number_of_populations, -1, -1)

# Only empty cells keep a fitness value; every occupied cell is zeroed in place.
new_fitness[torch.logical_not(empty_cells_expanded)] = 0

In [117]:
# Step 1: per-cell maximum over the population axis, plus the index of the winning channel
max_vals, max_indices = new_fitness.max(dim=1, keepdim=True)

# Step 2: boolean mask that is True only on the winning channel of each cell
mask = max_indices.eq(
    torch.arange(new_fitness.shape[1], device=new_fitness.device).reshape(1, -1, 1, 1)
)

# Step 3: keep the winning channel's fitness and zero every other channel
result = mask.float() * new_fitness

In [118]:
# Add the winning-population fitness to the grid; the result goes to a new tensor
new_grid = torch.add(grid, result)
# Display the first batch element of the original (un-updated) grid
grid[0]

tensor([[[-0.2011,  0.0000,  0.0000, -1.4073,  1.6268],
         [ 0.1723, -1.6115, -0.4794,  0.1574,  0.0000],
         [ 0.0000,  0.9979,  0.5436,  0.0788,  0.8629],
         [-0.0195,  0.7611,  0.6183, -0.2994, -0.1878],
         [ 1.9159,  0.0000, -2.3217, -1.1964,  0.2408]],

        [[-1.3962,  0.0000,  0.0000, -1.3952,  0.4751],
         [-0.8137,  0.9242,  1.5735,  0.7814,  0.0000],
         [ 0.0000,  0.5867,  0.1583,  0.1102, -0.8188],
         [-1.1894, -1.1959,  1.3119, -0.2098,  0.7817],
         [ 0.9897,  0.0000, -1.5090, -0.2871,  1.0216]],

        [[-0.5111,  0.0000,  0.0000, -0.4749, -0.6334],
         [-1.4677,  0.6074, -0.5472, -1.1005,  0.0000],
         [ 0.0000,  0.3398, -0.2635,  1.2805, -0.4947],
         [-1.2830,  0.4386, -0.0107,  1.3384, -0.2794],
         [-0.5518,  0.0000, -1.0619, -0.1144,  0.1954]]])

## Test


In [299]:
# Define 8 directional kernels for a 10% value shift.
# kernels has shape [8, 1, 3, 3]; kernel k is all zeros except for a single 0.1
# weight at (1 + dy, 1 + dx), where (dy, dx) is row k of `offsets`.
kernels = torch.zeros((8, 1, 3, 3), dtype=torch.float32)
# NOTE(review): conv2d is a cross-correlation, so a weight at (1 + dy, 1 + dx)
# samples the input at (row + dy, col + dx). With image-style indexing (row 0 at
# the top) dy = +1 therefore reads the cell BELOW. Confirm that the labels
# below match the intended convention before relying on them.
offsets = torch.tensor([
    [1,  0],  # 1: up
    [1,  1],  # 2: up-right
    [0,  1],  # 3: right
    [-1,  1],  # 4: down-right
    [-1,  0],  # 5: down
    [-1, -1],  # 6: down-left
    [ 0, -1],  # 7: left
    [1, -1],  # 8: up-left
], dtype=torch.long)

# Vectorised assignment of the eight 0.1 weights. This deliberately avoids a
# loop variable named `kernel`, which would rebind the normalised neighbour
# kernel defined in an earlier cell to an unrelated (1, 3, 3) tensor.
kernels[torch.arange(offsets.shape[0]), 0, 1 + offsets[:, 0], 1 + offsets[:, 1]] = 0.1

action = torch.randint(1, 9, (2, 1, 5, 5), device=grid.device)  # Random action code in [1, 8] for each cell

In [300]:
# Assume grid has shape (batch, nb_pop, n, m) and we take max over nb_pop:
max_vals, max_indices = torch.max(grid, dim=1, keepdim=True)  # shape: (batch, 1, n, m)

# One-hot encode each cell's action. `action` stores codes 1-8, so shift them to
# 0-7 before encoding. max_indices is the winning population index, not an
# action, so it is not a valid selector for the 8 directional kernels.
action_mask = torch.nn.functional.one_hot(action.squeeze(1) - 1, num_classes=8)  # shape: (batch, n, m, 8)
action_mask = action_mask.permute(0, 3, 1, 2).float()  # shape: (batch, 8, n, m)

# Pad max_vals (single channel) circularly so neighbours wrap around the borders
max_vals_padded = nn.CircularPad2d(1)(max_vals)  # shape: (batch, 1, n+2, m+2)

# 'kernels' should be defined as before with shape (8, 1, 3, 3)
# Perform convolution: input has 1 channel, kernels have 1 input channel, output 8 channels.
# torch.nn.functional is spelled out in full because no `F` alias is imported.
new_vals = torch.nn.functional.conv2d(max_vals_padded, kernels, stride=1, padding=0)  # shape: (batch, 8, n, m)

# Select contributions using the action mask: keep only the kernel corresponding to each cell's action
new_vals_masked = new_vals * action_mask

# Sum over the 8 directional contributions for each cell
new_vals_sum = new_vals_masked.sum(dim=1, keepdim=True)  # shape: (batch, 1, n, m)

In [301]:
# 1: up, 2: up-right, 3: right, 4: down-right, 5: down, 6: down-left, 7: left, 8: up-left


In [302]:
# Zero-based direction id per cell (action codes are 1-8).
shifted_action = action - 1

# For each direction id, keep max_vals only on the cells that chose it, then roll
# that layer by (row_step, col_step) with wrap-around. Table order follows the ids:
# 0: up, 1: up-right, 2: right, 3: down-right, 4: down, 5: down-left, 6: left, 7: up-left
up, up_right, right, down_right, down, down_left, left, up_left = (
    torch.roll(
        torch.where(shifted_action == direction_id, max_vals, 0),
        shifts=step,
        dims=(2, 3),
    )
    for direction_id, step in enumerate(
        ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1))
    )
)

# Combine all contributions (same left-to-right order as the direction ids)
contributions = up + up_right + right + down_right + down + down_left + left + up_left

# new_grid holds 10% of the summed contributions
new_grid = contributions * 0.1

In [303]:
# Inspect the per-cell maximum over populations for the first batch element.
max_vals[0]

tensor([[[-0.2011,  0.0000,  0.0000, -0.4749,  1.6268],
         [ 0.1723,  0.9242,  1.5735,  0.7814,  0.0000],
         [ 0.0000,  0.9979,  0.5436,  1.2805,  0.8629],
         [-0.0195,  0.7611,  1.3119,  1.3384,  0.7817],
         [ 1.9159,  0.0000, -1.0619, -0.1144,  1.0216]]])

In [None]:
# Inspect the zero-based direction ids (action - 1) for the first batch element.
shifted_action[0]
# 0: up, 1: up-right, 2: right, 3: down-right, 4: down, 5: down-left, 6: left, 7: up-left

tensor([[[6, 5, 3, 6, 4],
         [7, 5, 2, 4, 6],
         [5, 2, 1, 4, 0],
         [1, 0, 2, 0, 3],
         [5, 2, 0, 2, 3]]])

In [311]:
# Show max_vals again (first batch element) to compare against the rolled `up` layer.
max_vals[0]

tensor([[[-0.2011,  0.0000,  0.0000, -0.4749,  1.6268],
         [ 0.1723,  0.9242,  1.5735,  0.7814,  0.0000],
         [ 0.0000,  0.9979,  0.5436,  1.2805,  0.8629],
         [-0.0195,  0.7611,  1.3119,  1.3384,  0.7817],
         [ 1.9159,  0.0000, -1.0619, -0.1144,  1.0216]]])

In [306]:
# First batch element of `up`: max_vals of the cells whose direction id is 0,
# rolled one row upward (with wrap-around); every other cell is 0.
up[0]

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.8629],
         [ 0.0000,  0.7611,  0.0000,  1.3384,  0.0000],
         [ 0.0000,  0.0000, -1.0619,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]])

In [555]:
# Reseed so the draws below are reproducible; actions are sampled before the
# grid, and that order determines which random numbers each tensor receives.
torch.manual_seed(0)

action_grid = torch.randint(low=9, high=17, size=(2, 1, 5, 5))
grid = torch.randn(2, 3, 5, 5)

# Collapse the population axis to its per-cell maximum -> shape: (batch, 1, n, m)
flat_grid = grid.max(dim=1, keepdim=True).values
# Independent copy of the starting state, kept for later comparison
initial_flat_grid = torch.clone(flat_grid)

In [556]:
# Replace the random actions with an all-zero grid of the same shape, dtype and device
action_grid = torch.zeros(action_grid.shape, dtype=action_grid.dtype, device=action_grid.device)
# Plant a single non-zero action code (8 + 8 = 16) at batch 0, row 3, col 2
action_grid[0, 0, 3, 2] = 16
action_grid[0]

tensor([[[ 0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0],
         [ 0,  0, 16,  0,  0],
         [ 0,  0,  0,  0,  0]]])

In [557]:
# Resolve "attack" actions one direction at a time: an attacker and the neighbour
# it targets both lose min(attacker, defender), and the amount lost by any cell
# that drops to (near) zero is spread as a bonus over its 8 wrap-around neighbours.

# Get the directional index (0 to 7) for attack actions:
# codes 9..16 map to 0..7; any other code (e.g. 0 -> -9) matches no direction below.
shifted_action = (action_grid - 9)

# Placeholder so `bonus` exists even if no direction is processed.
# NOTE(review): this is not an accumulator — the loop re-creates `bonus` for every processed direction.
bonus = torch.zeros_like(flat_grid)

# Emptiness snapshot taken BEFORE any attack is applied; it is not refreshed inside the loop.
empty_cells = (flat_grid <= 1e-6).all(dim=1)  # shape: [batch, rows, cols]

# Ones around a zero centre: convolving an occupancy map with it counts occupied neighbours.
circ_kernel = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=flat_grid.dtype, device=flat_grid.device).view(1, 1, 3, 3)
# Number of non-empty cells among the 8 wrap-around neighbours of every cell.
# The 3-D mask is padded on its last two dims, then given a channel dim for conv2d.
nb_existing_neighbors = torch.nn.functional.conv2d(
    nn.CircularPad2d(1)((~empty_cells).float()).unsqueeze(1),
    circ_kernel
)

# Define the 8 directional shifts (row, col) corresponding to neighbors:
# 0: up, 1: up-right, 2: right, 3: down-right,
# 4: down, 5: down-left, 6: left, 7: up-left.
shifts = [(-1, 0), (-1, 1), (0, 1), (1, 1),
        (1, 0), (1, -1), (0, -1), (-1, -1)]

# Process each direction separately.
# flat_grid is reassigned inside the loop, so later directions see the values
# already reduced by earlier ones.
for i, (shift_row, shift_col) in enumerate(shifts):
    # Create mask for cells whose shifted_action equals this direction (attack events)
    mask = (shifted_action == i)
    if mask.sum() == 0:
        continue  # No attack events in this direction
    
    # Get attacker values from flat_grid where attack occurs.
    attacker_vals = torch.where(mask, flat_grid, torch.zeros_like(flat_grid))

    # Get defender values from the neighbor cell by rolling the flat_grid.
    # Rolling by the negated shift brings the targeted neighbour's value onto the attacker's position.
    defender_mapped_to_attacker_vals = torch.roll(flat_grid, shifts=(-shift_row, -shift_col), dims=(2, 3))
    defender_mapped_to_attacker_vals = torch.where(mask, defender_mapped_to_attacker_vals, torch.zeros_like(flat_grid))
    
    # Compute the value to subtract: the minimum of the two values.
    # (Zero everywhere outside the attacker positions, since both operands are masked.)
    vals = torch.min(attacker_vals, defender_mapped_to_attacker_vals)
    
    # Subtract val from the attacker cell.
    flat_grid = torch.where(mask, flat_grid - vals, flat_grid)

    # For the defender, roll the mask to align with neighbor positions.
    # The same roll is applied to vals so each loss lands on the defender's cell.
    defender_mask = torch.roll(mask, shifts=(shift_row, shift_col), dims=(2, 3))
    vals_mapped_to_defender = torch.roll(vals, shifts=(shift_row, shift_col), dims=(2, 3))
    flat_grid = torch.where(defender_mask, flat_grid - vals_mapped_to_defender, flat_grid)

    # Cells involved in this direction's attacks that are now at or below the emptiness threshold.
    attacker_dying_cells = (flat_grid <= 1e-6) & mask
    defender_dying_cells = (flat_grid <= 1e-6) & defender_mask

    bonus = torch.zeros_like(flat_grid)  # Reset bonus for each direction
    bonus[attacker_dying_cells] = vals[attacker_dying_cells]  # Assign bonus to dying cells
    bonus[defender_dying_cells] = vals_mapped_to_defender[defender_dying_cells]  # Assign bonus to dying cells

    # Split each dying cell's bonus evenly over its occupied neighbours (counted before the attacks).
    bonus = bonus / nb_existing_neighbors.clamp(min=1e-6)  # Avoid division by zero

    # Distribute bonus using convolution with circular padding.
    distributed_bonus = torch.nn.functional.conv2d(
        torch.nn.functional.pad(bonus, (1, 1, 1, 1), mode='circular'),
        circ_kernel
    )

    # Remove bonus from empty cells
    distributed_bonus[empty_cells.unsqueeze(1)] = 0
    # NOTE(review): the update below is disabled, so flat_grid never receives the bonus,
    # and distributed_bonus only keeps the result of the last processed direction.
    # flat_grid += distributed_bonus  # Update flat_grid with distributed bonus
        


In [558]:
# Snapshot of flat_grid taken before the attack loop (first batch element).
initial_flat_grid[0]

tensor([[[ 1.0554e+00,  1.7784e-01,  1.1149e+00,  2.7995e-01,  8.0575e-01],
         [ 1.1133e+00,  3.3801e-01,  4.5440e-01,  1.5210e+00,  3.4105e+00],
         [ 7.8131e-01,  1.0395e+00,  1.8197e+00, -3.3039e-03, -7.2915e-02],
         [ 1.8855e-01,  1.1108e+00,  1.2899e+00, -9.2146e-01,  2.5672e+00],
         [ 7.1009e-01,  1.0367e+00,  1.9218e+00,  2.0820e+00,  5.1987e-01]]])

In [560]:
# flat_grid after the attack loop (first batch element); compare with initial_flat_grid[0].
flat_grid[0]

tensor([[[ 1.0554e+00,  1.7784e-01,  1.1149e+00,  2.7995e-01,  8.0575e-01],
         [ 1.1133e+00,  3.3801e-01,  4.5440e-01,  1.5210e+00,  3.4105e+00],
         [ 7.8131e-01,  0.0000e+00,  1.8197e+00, -3.3039e-03, -7.2915e-02],
         [ 1.8855e-01,  1.1108e+00,  2.5041e-01, -9.2146e-01,  2.5672e+00],
         [ 7.1009e-01,  1.0367e+00,  1.9218e+00,  2.0820e+00,  5.1987e-01]]])

In [564]:
# Bonus spread computed for the last direction processed by the loop (first batch element).
distributed_bonus[0]

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1299, 0.1299, 0.1299, 0.0000, 0.0000],
         [0.1299, 0.0000, 0.1299, 0.0000, 0.0000],
         [0.1299, 0.1299, 0.1299, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])