In [132]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)
from typing import *

In [37]:
class ResidualModule(nn.Module):
    """Single-layer residual block computing ``x + relu(fc1(x))``.

    Args:
        dims: Size of the last dimension of the input; the linear layer is
            square so its output can be added back onto the input.
    """

    def __init__(self, dims):
        super().__init__()
        self.fc1 = nn.Linear(dims, dims)

    def forward(self, x):
        """Return ``x`` plus the ReLU-activated linear projection of ``x``."""
        return x + F.relu(self.fc1(x))

In [44]:
class Agent(nn.Module):
    """A residual processor paired with learnable head and tail position vectors.

    Args:
        task_dims: Feature width handled by the residual processor.
        env_dims: Length of the ``head_pos`` / ``tail_pos`` vectors.
    """

    def __init__(self, task_dims, env_dims):
        super().__init__()
        self.processor = ResidualModule(task_dims)
        # Both positions start as random row vectors and are rescaled to unit
        # L2 norm right after creation.
        self.head_pos = nn.Parameter(torch.randn(1, env_dims), requires_grad=True)
        self.tail_pos = nn.Parameter(torch.randn(1, env_dims), requires_grad=True)
        with torch.no_grad():
            self.normalize_positions()

    def forward(self, x):
        """Run ``x`` through the residual processor; positions are not used here."""
        return self.processor(x)

    def normalize_positions(self):
        """Rescale ``head_pos`` and ``tail_pos`` to unit L2 norm.

        The update is written through ``.data``, so autograd does not record it.
        """
        for position in (self.head_pos, self.tail_pos):
            position.data = position.data / position.data.norm(2)


In [None]:
# TODO:
# - M

In [373]:
class ManifoldWorms(nn.Module):
    """Set of residual agents that exchange per-slot state through head/tail positions.

    There are ``n_modules + 1`` head positions and ``n_modules + 1`` tail positions.
    The first ``n_modules`` heads belong to the agents; the last head is the
    entrance head and never consumes anything. External input is added to the
    last state slot, and the last tail column is used to weight what exits.

    Args:
        task_dims: Width of each state slot and of every agent's linear layer.
        env_dims: Dimensionality of the head/tail position vectors.
        n_modules: Number of agents.
        reach: Stored as ``reach_threshold = clip(1 - reach, -1, 1)``; a head
            interacts with a tail only when their cosine similarity is
            strictly greater than that threshold.
        garbage_decay: Stored as ``garbage_scale = clip(1 - garbage_decay, 0, 1)``,
            the factor applied to the garbage term subtracted from the state.
    """

    def __init__(self, task_dims: int, env_dims: int, n_modules: int, reach: float = 1, garbage_decay: float = 0.9):
        super(ManifoldWorms, self).__init__()
        self.agents = nn.ModuleList([ResidualModule(task_dims) for _ in range(n_modules)])
        # Heads are stored row-wise and tails column-wise so that
        # heads.mm(tails) directly yields the [H, T] dot-product matrix.
        self.head_positions = nn.Parameter(torch.randn(n_modules + 1, env_dims), requires_grad=True)
        self.tail_positions = nn.Parameter(torch.randn(env_dims, n_modules + 1), requires_grad=True)
        # Registered as a buffer so the state follows the module through
        # .to(device) / .double(); persistent=False keeps state_dict unchanged.
        self.register_buffer("state", torch.zeros(1, n_modules + 1, task_dims), persistent=False)
        self.reach_threshold = np.clip(1 - reach, -1, 1).item()
        self.garbage_scale = np.clip(1 - garbage_decay, 0, 1).item()
        with torch.no_grad():
            self.positions_normalization()

    def forward(self, x: Optional[torch.Tensor] = None):
        """Optionally inject ``x`` into the last state slot, then run one step.

        Args:
            x: Optional tensor added in place to ``self.state[:, -1]``.

        Returns:
            The ``(state, exited_resources, garbage)`` tuple produced by ``step``.

        NOTE(review): the stored ``self.state`` is only changed by the injection
        of ``x``; the updated state returned by ``step`` is not written back.
        Confirm whether the caller is expected to persist it.
        """
        if x is not None:
            self.state[:, -1] = self.state[:, -1] + x
        return self.step()

    def step(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route state from tails to heads, run the agents, and compute the next state.

        Returns:
            A tuple ``(state, exited_resources, garbage)`` with shapes
            ``[B, n_modules + 1, D]``, ``[B, 1, D]`` and ``[B, n_modules + 1, D]``.
            ``self.state`` itself is not modified.
        """
        self.positions_normalization()
        batch_size = self.state.shape[0]
        # Cosine similarity between every head and every tail, copied per batch item.
        similarities = self.positions_similarity().unsqueeze(0).repeat(batch_size, 1, 1)  # [B, H, T]
        # A head can reach a tail only if their similarity exceeds the threshold.
        closeness_mask = (similarities > self.reach_threshold).float()  # [B, H, T]
        closeness_mask[:, -1] = 0  # No state should be sent to the entrance head
        # Softmax over the reachable tails of each head. Unreachable entries are
        # set to -inf first; a head with no reachable tail yields NaN, mapped to 0.
        influence = similarities.masked_fill(closeness_mask == 0, float('-inf'))
        influence = F.softmax(influence, dim=2).nan_to_num(0)
        influence = influence * closeness_mask  # Re-apply mask to ensure hard cutoff
        # Resources each head pulls from the state slots.
        consumed_resources = influence.bmm(self.state)  # [B, H, D]
        # Only the agent heads process what they consumed; the entrance head
        # (last index) gets a zero row appended by the padding.
        tail_outputs = torch.stack(
            [agent(consumed_resources[:, i]) for i, agent in enumerate(self.agents)],
            dim=1,
        )  # [B, M, D]
        tail_outputs = F.pad(tail_outputs, (0, 0, 0, 1))  # [B, M + 1, D]
        # Exit output: agent outputs weighted by the last tail column of the influence.
        exit_influence = influence[..., -1].unsqueeze(1)  # [B, 1, H]
        exited_resources = exit_influence.bmm(tail_outputs)  # [B, 1, D]
        # Garbage: for each head, the summed state of the tails it cannot reach.
        garbage = (1 - closeness_mask).bmm(self.state)  # [B, H, D]
        # Next state: decay garbage, remove consumed, add agent outputs, remove
        # the exit output (broadcast over all slots).
        state = (
            self.state
            - garbage * self.garbage_scale
            - consumed_resources
            + tail_outputs
            - exited_resources
        )
        return state, exited_resources, garbage

    def positions_normalization(self):
        """Rescale every head (row) and tail (column) position to unit L2 norm.

        Written through ``.data`` so the rescaling is not recorded by autograd.
        """
        self.head_positions.data = self.head_positions.data / self.head_positions.data.norm(2, dim=1, keepdim=True)
        self.tail_positions.data = self.tail_positions.data / self.tail_positions.data.norm(2, dim=0, keepdim=True)

    def positions_similarity(self):
        """Return the ``[H, T]`` cosine-similarity matrix between heads and tails."""
        dot_products = self.head_positions.mm(self.tail_positions)
        l2_norms = self.head_positions.norm(2, dim=1, keepdim=True).mm(self.tail_positions.norm(2, dim=0, keepdim=True))
        return dot_products / l2_norms


In [374]:
# Small instance for interactive inspection: 4 agents, 3-d task state, 3-d positions.
test = ManifoldWorms(task_dims=3, env_dims=3, n_modules=4)

In [351]:
# Scratch check: a one-hot [1, 4, 1] selector broadcast against a random
# [1, 4, 3] tensor keeps row 0 and zeroes the remaining rows.
torch.eye(4)[0].unsqueeze(0).unsqueeze(-1) * torch.rand(1, 4, 3)

tensor([[[0.0793, 0.8151, 0.9445],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

In [375]:
# Run one routing step on the test instance.
# NOTE(review): the RuntimeError recorded under this cell looks like it comes
# from an earlier revision of step(); re-run against the current class to confirm.
test.step()

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1, 4] but got: [1, 5].

In [143]:
# Shape of the stored state: [1, n_modules + 1, task_dims], i.e. [1, 5, 3] for
# ManifoldWorms(3, 3, 4).
torch.rand_like(test.state).shape

torch.Size([1, 5, 3])

In [164]:
# NOTE(review): step() as defined in this notebook returns a
# (state, exited_resources, garbage) tuple, so calling .unsqueeze on its result
# would fail. This cell and its recorded output appear to predate that
# signature; update or delete it.
test.step().unsqueeze(0).bmm(torch.ones(1, 5, 3))[:, 1].shape

torch.Size([1, 3])

In [106]:
# Flatten the head/tail cosine-similarity matrix into a 1-D tensor of its
# non-zero entries (single-argument torch.where returns their indices).
test.positions_similarity()[torch.where(test.positions_similarity())]

tensor([-0.9843,  0.9439,  0.1510,  0.1296, -0.2182, -0.4939,  0.6764,  0.5787,
         0.9008, -0.0981,  0.7976, -0.7385, -0.7170, -0.5712, -0.1985,  0.9126,
        -0.9782,  0.0468, -0.1596,  0.4491, -0.8559,  0.5801,  0.6917,  0.1287,
         0.4573], grad_fn=<IndexBackward0>)

In [108]:
# Keep similarities above 0.2 and zero out the rest by multiplying with a 0/1 mask.
test.positions_similarity() * torch.where(test.positions_similarity() > 0.2, 1, 0)

tensor([[-0.0000, 0.9439, 0.0000, 0.0000, -0.0000],
        [-0.0000, 0.6764, 0.5787, 0.9008, -0.0000],
        [0.7976, -0.0000, -0.0000, -0.0000, -0.0000],
        [0.9126, -0.0000, 0.0000, -0.0000, 0.4491],
        [-0.0000, 0.5801, 0.6917, 0.0000, 0.4573]], grad_fn=<MulBackward0>)

In [111]:
# Softmax over dim 0 of the thresholded similarities, showing column 0.
# Entries zeroed by the mask still contribute exp(0), so they receive equal,
# non-zero weight instead of being excluded.
F.softmax(test.positions_similarity() * torch.where(test.positions_similarity() > 0.2, 1, 0), 0)[:, 0]

tensor([0.1297, 0.1297, 0.2879, 0.3230, 0.1297], grad_fn=<SelectBackward0>)

In [45]:
# Build a throwaway agent and display its head and tail positions, which the
# constructor rescales to unit L2 norm.
a = Agent(task_dims=10, env_dims=10)
(a.head_pos, a.tail_pos)


(Parameter containing:
 tensor([[ 0.2826,  0.0839, -0.5363,  0.3267,  0.1196, -0.0198,  0.6452,  0.1791,
          -0.1621,  0.1716]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.0271, -0.0178, -0.2622,  0.3032,  0.7373,  0.2175,  0.2847, -0.1851,
           0.1871,  0.3117]], requires_grad=True))

In [None]:
class Environment():
    """Holds the entrance and exit positions of the environment.

    The entrance is the first standard basis vector of the ``env_dims``-dimensional
    space and the exit is its negation. Position attributes are recognised by
    the ``__pos`` name suffix.

    Planned but not implemented yet (kept from the original design notes):
      - measure the distance between agents' tails;
      - for each agent, aggregate the outputs of all agents weighted by the
        distance to that agent, and return the agent's new position.

    Args:
        env_dims: Dimensionality of the position vectors.
    """

    def __init__(self, env_dims):
        self.entrance__pos = torch.eye(env_dims)[0]
        self.exit__pos = torch.eye(env_dims)[0] * -1

    def normalize_positions(self, directory: Optional[object] = None):
        """Rescale every ``*__pos`` attribute of ``directory`` to unit L2 norm.

        Args:
            directory: Object whose attributes are scanned; defaults to this
                environment itself.

        ``nn.Parameter`` attributes are rescaled in place under ``no_grad`` so
        they remain the same parameter object. Other tensors are replaced by
        their normalized copy. Zero-length vectors are left unchanged instead of
        turning into NaN.

        NOTE(review): ``Agent`` stores its positions as ``head_pos`` / ``tail_pos``
        (single underscore), which the ``__pos`` suffix does not match. Confirm
        which naming is intended.
        """
        target = self if directory is None else directory
        for attr_name in dir(target):
            if not attr_name.endswith('__pos'):
                continue
            position = getattr(target, attr_name)
            if isinstance(position, torch.nn.Parameter):
                with torch.no_grad():
                    length = position.norm(2)
                    if length > 0:
                        position.div_(length)
                continue
            length = torch.norm(position, 2)
            if length > 0:
                setattr(target, attr_name, position / length)
