# Train Full Tree

🔴 WiP

(Currently it trains poorly, so there's some error)

The first paper version suggests using the WHOLE tree during training, and ONLY
during inference do we choose a single path thru the tree.

Notice the `p` and `1-p` in the diagram.

So let's do it!

If we just train the whole tree for each input item, and use the whole tree
for inference, we get FF-level accuracy:
```
Accuracy of the network over test images: 97.370 %
```

But that's distributing the computation over the whole tree.
Every node is pulling weight.

So we introduce a penalty that, for a given input, pushes each node
to channel its work thru either one branch or t'other.

i.e. an even split (corresponding to a p of 0.5) will be penalized.

This pushes the network to "harden" the decisions. It pushes it to use
p values close to 0 or close to 1.

And then when we do inference, we just pick the winning path, based
on the assunmption that THAT path is doing the bulk of the work.

Simultaneously it's pushing the y-projections to accomodate this
"hardened" decision-process.



# 🔹 Dataloaders

In [1]:
NBATCH, NBATCH_TEST = 128, -1


In [2]:
# Load MNIST
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_MNIST():
    # Transformations
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # MNIST dataset
    trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Data loaders
    nTest = NBATCH_TEST if NBATCH_TEST > 0 else len(testset)
    trainloader = DataLoader(trainset, batch_size=NBATCH, shuffle=True)
    testloader = DataLoader(testset, batch_size=nTest, shuffle=False)

    return trainloader, testloader

# trainloader, testloader = load_MNIST()

In [3]:
def load_CIFAR10():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # https://github.com/microsoft/debugpy/issues/1166
    # ^ set nWorkers=1 to avoid this
    def create_dataloader(train, nBatch, shuffle, nWorkers=1):
        dataset = datasets.CIFAR10(
            root='./data', train=train, download=True, transform=transform
        )
        return DataLoader(
            dataset, batch_size=nBatch, shuffle=shuffle, num_workers=nWorkers
        )

    nTest = NBATCH_TEST if NBATCH_TEST > 0 else 1024
    trainloader = create_dataloader(train=True, nBatch=NBATCH, shuffle=True)
    testloader = create_dataloader(train=False, nBatch=1024, shuffle=False)

    # Class labels
    # classes = 'plane car bird cat deer dog frog horse ship truck'.split()
    
    return trainloader, testloader  #, classes

# trainloader, testloader, classes = load_CIFAR10()


In [4]:
trainloader, testloader = load_CIFAR10()

batch = next(iter(trainloader))
images, labels = batch
labels[:10]

Files already downloaded and verified
Files already downloaded and verified


tensor([2, 9, 9, 8, 7, 0, 3, 5, 1, 4])

# 🔹 Test Harness

In [5]:
NEPOCH = 10
EVERY_N = -1

import torch
from tqdm import tqdm

# Test harness
# We'll use this later
# def orthogonality_loss(basis_vectors):
#     # Compute pairwise dot products
#     dot_products = torch.matmul(basis_vectors, basis_vectors.T)
    
#     # Zero out diagonal elements (self dot products)
#     eye = torch.eye(dot_products.size(0)).to(dot_products.device)
#     dot_products = dot_products * (1 - eye)
    
#     # Sum of squares of off-diagonal elements (which should be close to zero)
#     loss = (dot_products ** 2).sum()
#     return loss

def train_and_test(net, dataloader):
    trainloader, testloader = dataloader()

    # Loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

    # Training the network
    for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if hasattr(net, 'loss'):
                loss += net.loss()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % EVERY_N == EVERY_N - 1:  # print EVERY_N mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

        # Testing the network on the test data
        correct, total = 0, 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'[{epoch}] Accuracy of the network over test images: {100 * correct / total:.3f} %')


# 🔹 use MNIST / FF to check it works

In [6]:
# Neural network architecture
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(in_features=28*28, out_features=500)
        self.fc2 = torch.nn.Linear(in_features=500, out_features=10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    # def loss(self):
    #     # Calculate orthogonality loss for each PiSlice layer
    #     loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
    #     loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
    #     return loss1 + loss2

train_and_test(Net(), dataloader=load_MNIST)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:07<01:09,  7.75s/it]

[0] Accuracy of the network over test images: 94.360 %


 20%|██        | 2/10 [00:15<01:01,  7.75s/it]

[1] Accuracy of the network over test images: 96.230 %


 30%|███       | 3/10 [00:23<00:54,  7.80s/it]

[2] Accuracy of the network over test images: 95.660 %


 40%|████      | 4/10 [00:31<00:46,  7.81s/it]

[3] Accuracy of the network over test images: 96.680 %


 50%|█████     | 5/10 [00:38<00:38,  7.71s/it]

[4] Accuracy of the network over test images: 96.940 %


 60%|██████    | 6/10 [00:46<00:31,  7.76s/it]

[5] Accuracy of the network over test images: 97.290 %


 70%|███████   | 7/10 [00:54<00:23,  7.77s/it]

[6] Accuracy of the network over test images: 97.790 %


 80%|████████  | 8/10 [01:02<00:15,  7.78s/it]

[7] Accuracy of the network over test images: 97.720 %


 90%|█████████ | 9/10 [01:09<00:07,  7.63s/it]

[8] Accuracy of the network over test images: 97.560 %


100%|██████████| 10/10 [01:16<00:00,  7.70s/it]

[9] Accuracy of the network over test images: 97.190 %





# 🔸 FFF (with full-tree training)

In [7]:
INIT_STRAT = 'hyperspherical-shell'
DEPTH = 8

In [8]:
import torch
from torch.nn.functional import normalize

from typing import Optional
from math import floor, log2, sqrt

class FFF(torch.nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return torch.nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return torch.nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

            self.p_loss = None

    def forward(self, x: torch.Tensor):
        nBatch, nOut, nNodes = x.shape[0], self.Y.shape[-1], 2 ** self.depth - 1
        is_training = x.requires_grad
    
        # Walk the tree, assembling y piecemeal
        y = torch.zeros((nBatch, nOut), dtype=torch.float, device=x.device, requires_grad=True)
        
        loss = torch.tensor(0.0, requires_grad=True)  # scalar tensor
        # this would be wrong (1D tensor not same as scalar)
        # loss = torch.zeros(1, requires_grad=True)

        if is_training:
            def process_node(nodeIndex, p_in):
                nonlocal y
                if nodeIndex >= nNodes:
                    return

                # Project x onto the current node's INPUT basis vector
                #   λ = x DOT currNode.X
                # (nBatch, nIn) nIn -> nBatch
                λ = torch.einsum('bi, i -> b', x, self.X[nodeIndex])

                # Project this contribution into OUTPUT space:
                #   y += λ currNode.Y
                # nBatch, nOut -> (nBatch, nOut)
                # y += torch.einsum("b, j -> b j", λ, self.Y[nodeIndex])
                dy = torch.einsum('b,j -> bj', λ, self.Y[nodeIndex])
                y = y + torch.einsum('b, bj -> bj', p_in, dy)

                # ChatGPT says we can elide:
                # y += torch.einsum("b, j, b -> b j", p_in, self.Y[nodeIndex], λ)

                # e.g.
                #   λ = -10, so p ~ 0.0, so 1-p is HIGH
                #   ... and we want to give HIGH weight to the LEFT node
                #   ... so THAT gets the 1-p
                p = torch.sigmoid(λ)

                # 0 if p is 0 or 1, 1 if p=0.5
                nonlocal loss
                loss = loss + torch.mean(4 * p * (1 - p), dim=0)

                # ^ alternative way:
                #     EPSILON = 1e-6
                #     p = torch.clamp(p, min=EPSILON, max=1-EPSILON)
                #     minus_p = torch.clamp(minus_p, min=EPSILON, max=1-EPSILON)
                #     return -p * torch.log(p) - minus_p * torch.log(minus_p)

                process_node((nodeIndex * 2) + 1, 1-p)
                process_node((nodeIndex * 2) + 2, p)

            process_node(
                nodeIndex=0,
                p_in = torch.ones((nBatch), dtype=torch.float, requires_grad=True)
            )

        else:
            def process_node(depth, nodeIndices):
                if depth >= self.depth:
                    return
                
                nonlocal y
                # Project x onto the current node's INPUT basis vector
                #   λ = x DOT currNode.X
                # (nBatch, nIn) (nBatch, nIn) -> nBatch
                λ = torch.einsum('bi, bi -> b', x, self.X[nodeIndices])

                # Project this contribution into OUTPUT space:
                #   y += λ currNode.Y
                # nBatch, (nBatch, nOut) -> (nBatch, nOut)
                y = y + torch.einsum('n, nj -> nj', λ, self.Y[nodeIndices])

                # We'll branch right if x is "sunny-side" of the
                # hyperplane defined by node.x (else left)
                branch_choices = (λ > 0).long()  # 1 if λ > 0, else 0, so -ve λ => LEFT, +ve λ => RIGHT

                # figure out index of node in next layer to visit
                process_node(
                    depth+1,
                    nodeIndices=(nodeIndices * 2) + 1 + branch_choices
                )

            process_node(
                depth=0,
                nodeIndices=torch.zeros(nBatch, dtype=torch.int, device=x.device)
            )

        self.loss = loss / nNodes

        return y

    def __repr__(self):
        return f"FFF({self.X.shape[-1]}, {self.Y.shape[-1]}, depth={self.depth})"


# sanity check

fff = FFF(nIn=10, nOut=10, depth=2)
nBatch = 2

print('Testing training:')
x = torch.randn((nBatch, 10), requires_grad=True)
y = fff(x)
print('loss:', fff.loss)
cost = torch.norm(y)
cost.backward()
print('x.grad:', x.grad)

print('Testing inference:')
x = torch.randn((nBatch, 10), requires_grad=False)
y = fff(x)
print('loss:', fff.loss)
print('y:', y)


Testing training:
loss: tensor(0.8769, grad_fn=<DivBackward0>)
x.grad: tensor([[ 0.5445,  0.3354,  0.0194, -0.4392, -0.3159,  0.0731,  0.1984, -0.1836,
         -0.2570, -0.1346],
        [-0.1014, -0.0053, -0.1769,  0.0853,  0.0147, -0.0228, -0.0308,  0.0714,
          0.0155,  0.1200]])
Testing inference:
loss: tensor(0., grad_fn=<DivBackward0>)
y: tensor([[ 0.2908,  0.2454, -0.0561,  0.0537,  0.2312,  0.5513,  0.1997, -0.0057,
          0.6941, -0.3613],
        [-0.3041, -0.5350, -0.4723,  0.8024, -0.7482,  0.6833,  0.4123,  0.2983,
         -0.5228,  0.9592]], grad_fn=<AddBackward0>)


# 🧪 Test it out on MNIST

In [12]:
DEPTH = 8

# Neural network architecture
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFF(nIn=28*28, nOut=500, depth=DEPTH)
        self.fc2 = FFF(nIn=500, nOut=10, depth=DEPTH)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    def loss(self):
        # Calculate orthogonality loss for each PiSlice layer
        # loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
        # loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
        # return loss1 + loss2
        return 0.01 * (self.fc1.loss + self.fc2.loss)

train_and_test(Net(), load_MNIST)
# train_and_test(Net(), ortho=True)

 10%|█         | 1/10 [02:01<18:12, 121.38s/it]

[0] Accuracy of the network over test images: 50.190 %


 20%|██        | 2/10 [03:59<15:56, 119.58s/it]

[1] Accuracy of the network over test images: 51.080 %


 30%|███       | 3/10 [05:59<13:57, 119.60s/it]

[2] Accuracy of the network over test images: 45.970 %


 40%|████      | 4/10 [07:58<11:55, 119.28s/it]

[3] Accuracy of the network over test images: 40.900 %


 40%|████      | 4/10 [08:16<12:25, 124.24s/it]


KeyboardInterrupt: 

# 🧪 Test it out on CIFAR10

In [10]:
import torch.nn as nn
from torch.nn.functional import relu

class Net_regular(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Net_FFF(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = FFF(nIn=16 * 5 * 5, nOut=120)
        self.fc2 = FFF(nIn=120, nOut=84)
        self.fc3 = FFF(nIn=84, nOut=10)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [11]:
# net = Net_regular()
train_and_test(Net_regular(), dataloader=load_CIFAR10)

Files already downloaded and verified
Files already downloaded and verified


 10%|█         | 1/10 [00:53<08:03, 53.70s/it]

[0] Accuracy of the network over test images: 44.140 %


 20%|██        | 2/10 [01:47<07:08, 53.51s/it]

[1] Accuracy of the network over test images: 48.360 %


 20%|██        | 2/10 [02:17<09:09, 68.64s/it]


KeyboardInterrupt: 

In [None]:
train_and_test(Net_FFF(), dataloader=load_CIFAR10)


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1/1 [02:30<00:00, 150.02s/it]

[0] Accuracy of the network over test images: 12.010 %



