# Train Full Tree

The first paper version suggests using the WHOLE tree during training, and ONLY
during inference do we choose a single path thru the tree.

Notice the `p` and `1-p` in the diagram.

So let's do it!

```
Accuracy of the network over test images: 97.370 %
```


In [1]:
EVERY_N = 200
NEPOCH = 1
BATCH_SIZE = 128
INIT_STRAT = 'hyperspherical-shell'
DEPTH = 8


In [2]:
import torch as torch
import torch.nn as nn
from typing import Optional
from math import floor, log2, sqrt
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [3]:
class FFF(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

    def forward(self, x: torch.Tensor):
        nBatch, nOut, nNodes = x.shape[0], self.Y.shape[-1], 2 ** self.depth - 1
        is_training = x.requires_grad
    
        # Walk the tree, assembling y piecemeal
        y = torch.zeros((nBatch, nOut), dtype=torch.float, device=x.device, requires_grad=True)
        
        if is_training:
            def process_node(nodeIndex, p_in):
                nonlocal y
                if nodeIndex >= nNodes:
                    return

                # Project x onto the current node's INPUT basis vector
                #   λ = x DOT currNode.X
                # (nBatch, nIn) nIn -> nBatch
                λ = torch.einsum('bi, i -> b', x, self.X[nodeIndex])

                # Project this contribution into OUTPUT space:
                #   y += λ currNode.Y
                # nBatch, nOut -> (nBatch, nOut)
                # y += torch.einsum("b, j -> b j", λ, self.Y[nodeIndex])
                dy = torch.einsum('b,j -> bj', λ, self.Y[nodeIndex])
                y = y + torch.einsum('b, bj -> bj', p_in, dy)

                # ChatGPT says we can elide:
                # y += torch.einsum("b, j, b -> b j", p_in, self.Y[nodeIndex], λ)

                # e.g.
                #   λ = -10, so p ~ 0.0, so 1-p is HIGH
                #   ... and we want to give HIGH weight to the LEFT node
                #   ... so THAT gets the 1-p
                p = torch.sigmoid(λ)
                process_node((nodeIndex * 2) + 1, 1-p)
                process_node((nodeIndex * 2) + 2, p)

                # TODO: penalize p close to 0.5:
                #   cost = .5**2 - (p - 0.5)**2

            process_node(
                nodeIndex=0,
                p_in = torch.ones((nBatch), dtype=torch.float, requires_grad=True)
            )

        else:
            def process_node(depth, nodeIndices):
                if depth >= self.depth:
                    return
                
                nonlocal y
                # Project x onto the current node's INPUT basis vector
                #   λ = x DOT currNode.X
                # (nBatch, nIn) (nBatch, nIn) -> nBatch
                λ = torch.einsum('bi, bi -> b', x, self.X[nodeIndices])

                # Project this contribution into OUTPUT space:
                #   y += λ currNode.Y
                # nBatch, (nBatch, nOut) -> (nBatch, nOut)
                y = y + torch.einsum('n, nj -> nj', λ, self.Y[nodeIndices])

                # We'll branch right if x is "sunny-side" of the
                # hyperplane defined by node.x (else left)
                branch_choices = (λ > 0).long()  # 1 if λ > 0, else 0, so -ve λ => LEFT, +ve λ => RIGHT

                # figure out index of node in next layer to visit
                process_node(
                    depth+1,
                    nodeIndices=(nodeIndices * 2) + 1 + branch_choices
                )

            process_node(
                depth=0,
                nodeIndices=torch.zeros(nBatch, dtype=torch.int, device=x.device)
            )

        return y

    def __repr__(self):
        return f"FFF({self.X.shape[-1]}, {self.Y.shape[-1]}, depth={self.depth})"


fff = FFF(nIn=10, nOut=10, depth=2)
nBatch = 2

print('Testing training:')
x = torch.randn((nBatch, 10), requires_grad=True)
y = fff(x)
cost = torch.norm(y)
cost.backward()
print('x.grad:', x.grad)

print('Testing inference:')
x = torch.randn((nBatch, 10), requires_grad=False)
y = fff(x)
print('y:', y)


Testing training:
x.grad: tensor([[-0.0915, -0.4196, -0.4384,  0.3333,  0.4608,  0.0976,  0.5762, -0.6377,
         -0.4291,  0.0791],
        [-0.0241, -0.0702, -0.0911,  0.0158,  0.0300, -0.0136,  0.0750, -0.0916,
          0.0019,  0.0147]])
Testing inference:
y: tensor([[ 0.0445,  0.3534, -0.2324,  0.1784, -0.3330, -0.0412, -0.6443,  0.3417,
          0.1268, -0.1882],
        [-0.1466, -0.0334, -0.0481,  0.0902, -0.3356, -0.3017,  0.3522, -0.0384,
          0.4485,  0.2918]], grad_fn=<AddBackward0>)


In [4]:
# Load MNIST
import torchvision
import torchvision.transforms as transforms

def load_MNIST():
    # Transformations
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # MNIST dataset
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Data loaders
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)

    return trainloader, testloader

trainloader, testloader = load_MNIST()

In [13]:

# Test harness
# We'll use this later
def orthogonality_loss(basis_vectors):
    # Compute pairwise dot products
    dot_products = torch.matmul(basis_vectors, basis_vectors.T)
    
    # Zero out diagonal elements (self dot products)
    eye = torch.eye(dot_products.size(0)).to(dot_products.device)
    dot_products = dot_products * (1 - eye)
    
    # Sum of squares of off-diagonal elements (which should be close to zero)
    loss = (dot_products ** 2).sum()
    return loss

def train_and_test(net, ortho=False):
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

    # Training the network
    for epoch in range(NEPOCH):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if ortho:  # hasattr(net, 'orthogonality_penalty'):
                loss += .001 * net.orthogonality_penalty()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % EVERY_N == EVERY_N - 1:  # print EVERY_N mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

        # Testing the network on the test data
        correct, total = 0, 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'[{epoch}] Accuracy of the network over test images: {100 * correct / total:.3f} %')


In [6]:
# Neural network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFF(nIn=28*28, nOut=500, depth=DEPTH)
        self.fc2 = FFF(nIn=500, nOut=10, depth=DEPTH)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    def orthogonality_penalty(self):
        # Calculate orthogonality loss for each PiSlice layer
        loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
        loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
        return loss1 + loss2

In [7]:
train_and_test(Net())
# train_and_test(Net(), ortho=True)

  0%|          | 0/1 [00:00<?, ?it/s]

[1,   200] loss: 0.759
[1,   400] loss: 0.406


100%|██████████| 1/1 [01:35<00:00, 95.79s/it]

[0] Accuracy of the network over test images: 52.150 %





# 🔸 CIFAR10

In [8]:
import torchvision
import torchvision.transforms as transforms
import torch

def load_CIFAR10():
    # https://github.com/microsoft/debugpy/issues/1166
    # ^ set nWorkers=1 to avoid this
    def create_dataloader(train, nBatch, shuffle, nWorkers=1):
        dataset = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True, transform=transform
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=nBatch, shuffle=shuffle, num_workers=nWorkers
        )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainloader = create_dataloader(train=True, nBatch=BATCH_SIZE, shuffle=True)
    testloader = create_dataloader(train=False, nBatch=1024, shuffle=False)

    # Class labels
    classes = 'plane car bird cat deer dog frog horse ship truck'.split()
    
    return trainloader, testloader, classes

trainloader, testloader, classes = load_CIFAR10()


Files already downloaded and verified
Files already downloaded and verified


In [9]:
import torch.nn as nn

class Net_regular(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Net_FFF(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = FFF(nIn=16 * 5 * 5, nOut=120)
        self.fc2 = FFF(nIn=120, nOut=84)
        self.fc3 = FFF(nIn=84, nOut=10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [15]:
net = Net_regular()
train_and_test(net)

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
ERROR:tornado.general:SEND Error: Host unreachable


[1,   200] loss: 1.851


0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


In [11]:
train_and_test(Net_FFF())


  0%|          | 0/1 [00:00<?, ?it/s]

[1,   200] loss: 1.791


100%|██████████| 1/1 [01:59<00:00, 119.52s/it]

[0] Accuracy of the network over test images: 11.070 %



