In [9]:
import torch

In [41]:
class FullBinaryTree:
    """
    Full binary tree whose nodes are randomly initialized feature tensors.

    Nodes are stored level by level in ``self.nodes``: ``self.nodes[k]`` is a
    list of the ``2 ** k`` node tensors at depth ``k`` (the root is
    ``self.nodes[0][0]``). The children of node ``i`` at depth ``k`` are nodes
    ``2 * i`` and ``2 * i + 1`` at depth ``k + 1``.
    """

    def __init__(self, max_depth, num_features):
        """
        Initialize the full binary tree.

        Every node is drawn independently with ``torch.randn(num_features)``.

        Args:
            max_depth (int): Number of levels in the tree. The root level is
                always created, even if ``max_depth`` is less than 1.
            num_features (int or list of ints): Size passed to ``torch.randn``
                for each node, i.e. the shape of every node tensor.
        """
        self.max_depth = max_depth
        self.num_features = num_features
        # Levels are built root-first, left to right, so the random draws
        # happen in breadth-first order.
        self.nodes = [
            [torch.randn(num_features) for _ in range(2 ** k)]
            for k in range(max(max_depth, 1))
        ]

    def get_paths(self, leaf_indices):
        """
        Find the path(s) from leaf index to base.

        Args:
            leaf_indices (list of ints): Indices into the deepest level
                (``self.nodes[self.max_depth - 1]``). Negative indices count
                from the end of that level, as with list indexing.

        Returns:
            torch.Tensor: Stacked paths of shape
            ``(len(leaf_indices), depth, *feature_shape)``. Along the second
            dimension each path is ordered from the leaf (position 0) up to
            the root (last position).

        Raises:
            IndexError: If a leaf index is outside the deepest level.
        """
        leaf_indices = list(leaf_indices)
        root = self.nodes[0][0]
        if not leaf_indices:
            # torch.stack rejects an empty list; return an empty batch that
            # still carries the per-path shape.
            return root.new_empty((0, len(self.nodes), *root.shape))

        paths = []
        for idx in leaf_indices:
            path = [self.nodes[self.max_depth - 1][idx]]
            c_idx = idx
            for k in range(self.max_depth - 2, -1, -1):
                # Floor division maps a node to its parent one level up; it
                # also keeps negative (from-the-end) indices consistent.
                c_idx = c_idx // 2
                path.append(self.nodes[k][c_idx])
            paths.append(torch.stack(path))
        # torch.Tensor(nested list of tensors) fails for multi-element node
        # tensors, so the batch is assembled with torch.stack instead.
        return torch.stack(paths)

In [3]:
# Row and column coordinates for a 100 x 100 index grid.
i = torch.arange(0, 100)
j = torch.arange(0, 100)
# indexing="ij" is passed explicitly: calling torch.meshgrid without it relies
# on a deprecated implicit default and emits a UserWarning on recent PyTorch.
# With "ij" indexing, grid[a, b] holds the pair [a, b]; shape is (100, 100, 2).
grid = torch.stack(torch.meshgrid(i, j, indexing="ij"), -1)
print(grid)

tensor([[[ 0,  0],
         [ 0,  1],
         [ 0,  2],
         ...,
         [ 0, 97],
         [ 0, 98],
         [ 0, 99]],

        [[ 1,  0],
         [ 1,  1],
         [ 1,  2],
         ...,
         [ 1, 97],
         [ 1, 98],
         [ 1, 99]],

        [[ 2,  0],
         [ 2,  1],
         [ 2,  2],
         ...,
         [ 2, 97],
         [ 2, 98],
         [ 2, 99]],

        ...,

        [[97,  0],
         [97,  1],
         [97,  2],
         ...,
         [97, 97],
         [97, 98],
         [97, 99]],

        [[98,  0],
         [98,  1],
         [98,  2],
         ...,
         [98, 97],
         [98, 98],
         [98, 99]],

        [[99,  0],
         [99,  1],
         [99,  2],
         ...,
         [99, 97],
         [99, 98],
         [99, 99]]])
