# Train Full Tree

The first paper version suggests using the WHOLE tree during training, and ONLY
during inference do we choose a single path thru the tree.

Notice the `p` and `1-p` in the diagram.

So let's do it!

```
Accuracy of the network over test images: 97.370 %
```


In [1]:
EVERY_N = 200
NEPOCH = 10
BATCH_SIZE = 256
INIT_STRAT = 'hyperspherical-shell'
DEPTH = 8


In [2]:
import torch as torch
import torch.nn as nn
from typing import Optional
from math import floor, log2, sqrt
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [3]:
class FFF(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

    def forward(self, x: torch.Tensor):
        nBatch, nOut, nNodes = x.shape[0], self.Y.shape[-1], 2 ** self.depth - 1

        # Walk the tree, assembling y piecemeal
        y = torch.zeros((nBatch, nOut), dtype=torch.float, device=x.device, requires_grad=True)
        
        def process_node(nodeIndex, p_in):
            nonlocal y
            if nodeIndex >= nNodes:
                return

            # Project x onto the current node's INPUT basis vector
            #   λ = x DOT currNode.X
            # (nBatch, nIn) nIn -> nBatch
            λ = torch.einsum("b i, i -> b", x, self.X[nodeIndex])

            # Project this contribution into OUTPUT space:
            #   y += λ currNode.Y
            # nBatch, nOut -> (nBatch, nOut)
            # y += torch.einsum("b, j -> b j", λ, self.Y[nodeIndex])
            dy = torch.einsum("b, j -> b j", λ, self.Y[nodeIndex])
            y = y + torch.einsum("b, b j -> b j", p_in, dy)

            p = torch.sigmoid(λ)
            process_node((nodeIndex * 2) + 1, p)
            process_node((nodeIndex * 2) + 2, 1-p)

        process_node(
            nodeIndex=0,
            p_in = torch.ones((nBatch), dtype=torch.float, requires_grad=True)
        )
        return y

    def __repr__(self):
        return f"FFF({self.X.shape[-1]}, {self.Y.shape[-1]}, depth={self.depth})"

fff = FFF(nIn=10, nOut=10, depth=2)
nBatch = 2
x = torch.randn((nBatch, 10), requires_grad=True)
y = fff(x)
print('y:', y)
cost = torch.norm(y)
cost.backward()
print('x.grad:', x.grad)



y: tensor([[-0.2318,  0.3069, -0.0949,  0.6865, -0.1751,  0.2108, -0.4833, -0.6652,
          0.5381,  0.0780],
        [ 0.1936, -0.3158,  0.1007, -0.4679,  0.3560, -0.1575,  0.1962,  0.5556,
         -0.6976, -0.3377]], grad_fn=<AddBackward0>)
x.grad: tensor([[-0.0090, -0.3167,  0.1280,  0.2915,  0.4237, -0.3709, -0.0853, -0.1341,
          0.1077, -0.0385],
        [-0.0338,  0.4347, -0.2762, -0.3213, -0.4947,  0.2217,  0.1488,  0.0905,
         -0.2188,  0.0957]])


In [4]:
# torch.Size([2, 10]) torch.Size([10])

p, q = torch.randn((2, 10)), torch.randn(10)
p.shape, q.shape

λ = torch.einsum("b i, i -> b", p, q)

# nBatch, nOut -> (nBatch, nOut)
Y = torch.randn((10))
y = torch.einsum("b, j -> b j", λ, Y)
y.shape


torch.Size([2, 10])

In [5]:
# Load MNIST
import torchvision
import torchvision.transforms as transforms

# Transformations
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)

# Test harness
# We'll use this later
def orthogonality_loss(basis_vectors):
    # Compute pairwise dot products
    dot_products = torch.matmul(basis_vectors, basis_vectors.T)
    
    # Zero out diagonal elements (self dot products)
    eye = torch.eye(dot_products.size(0)).to(dot_products.device)
    dot_products = dot_products * (1 - eye)
    
    # Sum of squares of off-diagonal elements (which should be close to zero)
    loss = (dot_products ** 2).sum()
    return loss
def train_and_test(net, ortho=False):
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

    # Training the network
    for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if ortho:  # hasattr(net, 'orthogonality_penalty'):
                loss += .001 * net.orthogonality_penalty()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % EVERY_N == EVERY_N - 1:  # print EVERY_N mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

    print('Finished Training')

    # Testing the network on the test data
    correct, total = 0, 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network over test images: {100 * correct / total:.3f} %')


In [6]:
# Neural network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFF(nIn=28*28, nOut=500, depth=DEPTH)
        self.fc2 = FFF(nIn=500, nOut=10, depth=DEPTH)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    def orthogonality_penalty(self):
        # Calculate orthogonality loss for each PiSlice layer
        loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
        loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
        return loss1 + loss2

train_and_test(Net())
# train_and_test(Net(), ortho=True)

  0%|          | 0/10 [00:00<?, ?it/s]

[1,   200] loss: 0.323


 10%|█         | 1/10 [00:47<07:08, 47.57s/it]

[2,   200] loss: 0.123


 20%|██        | 2/10 [01:36<06:27, 48.42s/it]

[3,   200] loss: 0.092


 30%|███       | 3/10 [02:23<05:32, 47.56s/it]

[4,   200] loss: 0.079


 40%|████      | 4/10 [03:08<04:40, 46.79s/it]

[5,   200] loss: 0.065


 50%|█████     | 5/10 [03:54<03:51, 46.28s/it]

[6,   200] loss: 0.072


 60%|██████    | 6/10 [04:42<03:08, 47.03s/it]

[7,   200] loss: 0.051


 70%|███████   | 7/10 [05:29<02:20, 46.95s/it]

[8,   200] loss: 0.052


 80%|████████  | 8/10 [06:14<01:33, 46.50s/it]

[9,   200] loss: 0.045


 90%|█████████ | 9/10 [07:01<00:46, 46.47s/it]

[10,   200] loss: 0.039


100%|██████████| 10/10 [07:49<00:00, 46.91s/it]


Finished Training
Accuracy of the network over test images: 97.370 %
