In [1]:
import torch
import torch.nn as nn
import numpy as np

from flash_ansr.models.encoders import SAB, MAB

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
X = torch.randn(111, 3)

In [3]:
mins = X.min(dim=0).values
maxs = X.max(dim=0).values
ranges = maxs - mins
print(mins)
print(maxs)
print(ranges)

tensor([-2.4810, -2.4991, -2.8691])
tensor([2.0126, 2.4688, 2.2511])
tensor([4.4936, 4.9679, 5.1202])


In [4]:
def split_mask(X: torch.Tensor) -> torch.Tensor:
    mins = X.min(dim=0).values
    maxs = X.max(dim=0).values
    ranges = maxs - mins
    max_range_dim = ranges.argmax()
    
    split_value = torch.median(X[:, max_range_dim])
    # print(f"Splitting {X.shape=} on dimension {max_range_dim} at value {split_value}")

    return X[:, max_range_dim] < split_value

In [5]:
split_mask(X)

tensor([False, False,  True, False,  True,  True, False, False, False,  True,
         True,  True,  True,  True, False, False, False, False,  True,  True,
         True, False, False,  True, False, False, False, False, False, False,
         True,  True,  True,  True,  True, False,  True,  True,  True, False,
         True,  True, False,  True, False, False, False,  True,  True, False,
        False,  True, False,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True, False, False,  True,  True,  True,  True,
         True, False, False,  True,  True,  True, False, False, False, False,
        False,  True, False, False, False,  True,  True,  True,  True,  True,
        False, False, False, False, False, False,  True, False,  True, False,
        False, False, False, False, False, False, False,  True,  True,  True,
        False])

In [6]:
X = torch.randn(128, 111, 3)

In [7]:
def split_mask_batched(X: torch.Tensor) -> torch.Tensor:
    if X.ndim != 3:
        raise ValueError(f"Expected 3D tensor, got {X.ndim}D tensor of shape {X.shape}")

    mins = X.min(dim=1).values
    maxs = X.max(dim=1).values
    ranges = maxs - mins
    max_range_dim = ranges.argmax(dim=1)

    split_value = torch.median(X[np.arange(X.shape[0]), :, max_range_dim], dim=1).values

    return X[np.arange(X.shape[0]), :, max_range_dim] < split_value.unsqueeze(1)

In [8]:
def split_mask_batched(X: torch.Tensor) -> torch.Tensor:
    if X.ndim != 3:
        raise ValueError(f"Expected 3D tensor, got {X.ndim}D tensor of shape {X.shape}")

    batch_size, num_points, _ = X.shape

    # Find dimension with max range for each batch
    mins = X.min(dim=1).values
    maxs = X.max(dim=1).values
    ranges = maxs - mins
    max_range_dim = ranges.argmax(dim=1)

    # Initialize result mask
    result_mask = torch.zeros((batch_size, num_points), dtype=torch.bool, device=X.device)

    # Process each batch independently
    for i in range(batch_size):
        # Extract values along the max range dimension for this batch
        values = X[i, :, max_range_dim[i]]

        # Calculate target size for the first half
        target_size = num_points // 2

        # Find median value (equivalent to finding kth element)
        median_value = torch.kthvalue(values, target_size + 1).values

        # Create masks for each category
        less_mask = values < median_value
        equal_mask = values == median_value

        # Count elements in each category
        less_count = torch.sum(less_mask).item()
        equal_count = torch.sum(equal_mask).item()

        # Calculate how many median elements go to first half to maintain balance
        median_in_first = max(0, target_size - less_count)

        # If we need to split the median elements
        if 0 < median_in_first < equal_count:
            # Get indices of the median elements
            equal_indices = torch.nonzero(equal_mask, as_tuple=True)[0]

            # Add the required number of median elements to the less mask
            result_mask[i, less_mask] = True
            result_mask[i, equal_indices[:median_in_first]] = True
        else:
            # Simple case: all less than median, and possibly some equal to median
            result_mask[i] = values <= median_value if less_count < target_size else less_mask

    return result_mask


In [9]:
mask = split_mask_batched(X)
mask.shape

torch.Size([128, 111])

In [10]:
X[mask].shape

torch.Size([7040, 3])

In [11]:
%%timeit
X = torch.randn(128, 512, 3)
mask = split_mask_batched(X)

3.99 ms ± 186 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%%timeit
X = torch.randn(128, 512, 3, device=device)
mask = split_mask_batched(X)

36.8 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
class Node:
    def __init__(self, X: torch.Tensor, left=None, right=None):
        self.X = X
        self.left = left
        self.right = right

    @property
    def shape(self):
        return self.X.shape

In [13]:
# def create_tree(X: torch.Tensor, max_depth: int = 5, min_n_leaf: int | None = None, depth: int = 0) -> torch.Tensor:
#     if min_n_leaf is None:
#         min_n_leaf = np.sqrt(X.shape[0])

#     if depth + 1 >= max_depth or X.shape[0] <= min_n_leaf:
#         return Node(X)

#     mask = split_mask(X)

#     return Node(X, create_tree(X[mask], max_depth, min_n_leaf, depth + 1), create_tree(X[~mask], max_depth, min_n_leaf, depth + 1))

In [11]:
def create_tree(X: torch.Tensor, max_depth: int = 5, depth: int = 0) -> torch.Tensor:
    if depth + 1 >= max_depth:
        return Node(X)

    if X.shape[0] > 1:
        mask = split_mask(X)
        return Node(X, create_tree(X[mask], max_depth, depth + 1), create_tree(X[~mask], max_depth, depth + 1))
    
    return Node(X, create_tree(X, max_depth, depth + 1), create_tree(X, max_depth, depth + 1)) 

In [12]:
X = torch.randn(512, 3)
tree = create_tree(X)

In [13]:
X = torch.randn(4, 3)
tree = create_tree(X)

In [14]:
X = torch.randn(152, 3)
tree = create_tree(X)

In [17]:
# %%timeit
# X = torch.randn(512, 3)
# tree = create_tree(X)

In [18]:
# can make this batched!

In [15]:
def create_tree_batched(X: torch.Tensor, max_depth: int = 5, depth: int = 0) -> torch.Tensor:
    if X.ndim != 3:
        raise ValueError(f"Expected 3D tensor, got {X.ndim}D tensor of shape {X.shape}")

    if depth + 1 >= max_depth:
        return Node(X)

    if X.shape[1] > 1:
        mask_batched = split_mask_batched(X)
        X_left = torch.stack([X[i][mask_batched[i]] for i in range(X.shape[0])], dim=0)
        X_right = torch.stack([X[i][~mask_batched[i]] for i in range(X.shape[0])], dim=0)    
        # print(X_left.shape)
        # print(X_right.shape)    
        return Node(X, create_tree_batched(X_left, max_depth, depth + 1), create_tree_batched(X_right, max_depth, depth + 1))
    
    return Node(X, create_tree_batched(X, max_depth, depth + 1), create_tree_batched(X, max_depth, depth + 1)) 

In [16]:
X = torch.randn(128, 111, 3)
tree_batched = create_tree_batched(X)

In [17]:
X = torch.randn(128, 152, 4)
tree_batched = create_tree_batched(X)

In [21]:
%%timeit
X = torch.randn(128, 512, 3)
tree_batched = create_tree_batched(X)

33.6 ms ± 925 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
%%timeit
X = torch.randn(128, 512, 3, device=device)
tree_batched = create_tree_batched(X)

432 ms ± 48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
tree_batched.left.left.left.shape

torch.Size([128, 13, 3])

In [24]:
class TreeSetEncoder(nn.Module):
    def __init__(self, depth: int, features_in: int, features_hidden: int = 512, n_heads: int = 8, clean_path: bool = True):
        super().__init__()
        self.depth = depth
        self.features_in = features_in
        self.features_hidden = features_hidden

        self.linear_in = nn.Linear(features_in, features_hidden)

        self.sabs = nn.ModuleList([SAB(features_hidden, n_heads=n_heads, clean_path=clean_path) for _ in range(depth)])
        self.mabs_lr = nn.ModuleList([MAB(features_hidden,n_heads=n_heads, clean_path=clean_path) for _ in range(depth)])
        self.mabs_rl = nn.ModuleList([MAB(features_hidden, n_heads=n_heads, clean_path=clean_path) for _ in range(depth)])

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if X.ndim > 3:
            raise ValueError(f"Expected 2D or 3D tensor, got {X.ndim}D tensor of shape {X.shape}")
        elif X.ndim == 2:
            X = X.unsqueeze(0)
        elif X.ndim != 3:
            raise ValueError(f"Expected 2D or 3D tensor, got {X.ndim}D tensor of shape {X.shape}")

        tree_batched = create_tree_batched(X.cpu(), max_depth=self.depth)
        return self.encode_tree(tree_batched)

    def encode_tree(self, tree: Node, depth: int = 0) -> list[tuple[int, int]]:
        if tree.left is None and tree.right is None:
            # Leaf node
            return self.linear_in(tree.X.to(device))

        X_left = self.encode_tree(tree.left, depth + 1)
        X_right = self.encode_tree(tree.right, depth + 1)

        X_left = self.sabs[depth](X_left)
        X_right = self.sabs[depth](X_right)

        X_lr = self.mabs_lr[depth](X_left, X_right)

        X_lr_rl = self.mabs_rl[depth](X_right, X_lr)

        return X_lr_rl
    
    @property
    def n_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

In [25]:
encoder = TreeSetEncoder(depth=5, features_in=3).to(device)
print(f'{encoder.n_params:,} params')

15,761,408 params


In [26]:
%%timeit
X = torch.randn(128, 512, 3, device=device)
z = encoder.forward(X)

72.8 ms ± 3.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
%%timeit
X = torch.randn(128, 512, 3, device=device)
with torch.no_grad():
    z = encoder.forward(X)

74.3 ms ± 1.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
X = torch.randn(128, 512, 3, device=device)
z = encoder.forward(X.to(device))
z.shape

torch.Size([128, 33, 512])

In [29]:
X = torch.randn(128, 128, 3, device=device)
z = encoder.forward(X.to(device))
z.shape

torch.Size([128, 9, 512])

In [30]:
X = torch.randn(128, 10, 3, device=device)
z = encoder.forward(X.to(device))
z.shape

torch.Size([128, 2, 512])

In [31]:
X = torch.randn(128, 1, 3, device=device)
z = encoder.forward(X.to(device))
z.shape

torch.Size([128, 1, 512])