# Optimization prototype

The problem with FFF:

- ~50% of profiled time is spent in lookups
- each batch-item takes a different route thru the tree
- this forces many lookups: e.g. tree depth 8, nBatch=1024 -> 8k lookups

# Proposal
- sort our input vectors as we go
    - so we end up with all the input vectors that map to 0000 (left left left left), then 0001 0010 etc
    - then we can process each of these 16 chunks simul on CUDA

We'll try to get a naive Python impl working (using a recursion) to check the logic.

TODO: If the same weights are used, we can `assert FFF(x) == FFFF(x)`

# Status
- takes 5x longer to train and gives 66% accuracy from 1 epoch (FFF gives 88%)

Nevertheless, the idea looks solid

In [1]:
import torch as torch
import torch.nn as nn
from typing import Optional
from math import floor, log2, sqrt
import torch.nn.functional as F


In [2]:
import lovely_tensors as lt
lt.monkey_patch()

In [3]:
NEPOCH = 10
BATCH_SIZE = 64
EVERY_N = 200

In [4]:
INIT_STRAT = 'hyperspherical-shell'

class FFFF(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

    # assuming batch_size >> treeSize
    def forward(self, x: torch.Tensor):
        nBatch, nOut = x.shape[0], self.Y.shape[-1]

        y = torch.zeros((nBatch, nOut), dtype=torch.float)

        def process_node(curr_node, depth, indices):
            if depth == self.depth or len(indices) == 0:
                return

            # Project x[indices] onto the current node's INPUT basis vector
            #   λ = x[indices] DOT currNode.X
            λ = torch.einsum("ni, i -> n", x[indices], self.X[curr_node])  # (nInd, nIn), nIn -> nInd

            # Update y-values for our indices
            #   y[indices] += λ * self.Y[indices]
            y[indices] += torch.einsum("n, j -> nj", λ, self.Y[curr_node])  # nInd, nOut -> nInd, nOut

            # Split indices based on λ >= 0
            indices_left = torch.where(λ < 0.0)[0]
            indices_right = torch.where(λ >= 0.0)[0]

            process_node(2*curr_node + 1, depth+1, indices_left)
            process_node(2*curr_node + 2, depth+1, indices_right)

        # our root node will process ALL batch-items
        all_indices = torch.arange(nBatch, dtype=torch.int)

        process_node(curr_node=0, depth=0, indices=all_indices)
        return y


# Load MNIST

In [5]:
import torchvision
import torchvision.transforms as transforms

# Transformations
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)


# Test harness

In [6]:
from torch import nn, functional as F
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [7]:
# We'll use this later
def orthogonality_loss(basis_vectors):
    # Compute pairwise dot products
    dot_products = torch.matmul(basis_vectors, basis_vectors.T)
    
    # Zero out diagonal elements (self dot products)
    eye = torch.eye(dot_products.size(0)).to(dot_products.device)
    dot_products = dot_products * (1 - eye)
    
    # Sum of squares of off-diagonal elements (which should be close to zero)
    loss = (dot_products ** 2).sum()
    return loss

In [8]:
def train_and_test(net, ortho=False):
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

    # Training the network
    for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if ortho:  # hasattr(net, 'orthogonality_penalty'):
                loss += .001 * net.orthogonality_penalty()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % EVERY_N == EVERY_N - 1:  # print EVERY_N mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

    print('Finished Training')

    # Testing the network on the test data
    correct, total = 0, 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network over test images: {100 * correct / total:.3f} %')


In [9]:
# Neural network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFFF(nIn=28*28, nOut=500, depth=8)
        self.fc2 = FFFF(nIn=500, nOut=10, depth=8)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    def orthogonality_penalty(self):
        # Calculate orthogonality loss for each PiSlice layer
        loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
        loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
        return loss1 + loss2

train_and_test(Net())
# train_and_test(Net(), ortho=True)

  0%|          | 0/10 [00:00<?, ?it/s]

[1,   200] loss: 1.857
[1,   400] loss: 1.397
[1,   600] loss: 1.250
[1,   800] loss: 1.131


 10%|█         | 1/10 [01:42<15:24, 102.76s/it]

[2,   200] loss: 1.106
[2,   400] loss: 1.097
[2,   600] loss: 1.152
[2,   800] loss: 1.041


 20%|██        | 2/10 [03:25<13:43, 103.00s/it]

[3,   200] loss: 0.984
[3,   400] loss: 0.957
[3,   600] loss: 0.907
[3,   800] loss: 0.921


 30%|███       | 3/10 [05:09<12:03, 103.42s/it]

[4,   200] loss: 0.899
[4,   400] loss: 0.855
[4,   600] loss: 0.852
[4,   800] loss: 0.852


 40%|████      | 4/10 [06:51<10:16, 102.83s/it]

[5,   200] loss: 0.843
[5,   400] loss: 0.839
[5,   600] loss: 0.831
[5,   800] loss: 0.838


 50%|█████     | 5/10 [08:35<08:35, 103.17s/it]

[6,   200] loss: 0.848
[6,   400] loss: 0.823
[6,   600] loss: 0.814
[6,   800] loss: 0.825


 60%|██████    | 6/10 [10:19<06:53, 103.29s/it]

[7,   200] loss: 0.812
[7,   400] loss: 0.808
[7,   600] loss: 0.811
[7,   800] loss: 0.790


 70%|███████   | 7/10 [12:04<05:11, 103.83s/it]

[8,   200] loss: 0.811
[8,   400] loss: 0.800
[8,   600] loss: 0.779
[8,   800] loss: 0.807


 80%|████████  | 8/10 [13:44<03:25, 102.72s/it]

[9,   200] loss: 0.787
[9,   400] loss: 0.789
[9,   600] loss: 0.788
[9,   800] loss: 0.788


 90%|█████████ | 9/10 [15:27<01:42, 102.77s/it]

[10,   200] loss: 0.783
[10,   400] loss: 0.784
[10,   600] loss: 0.785
[10,   800] loss: 0.783


100%|██████████| 10/10 [17:11<00:00, 103.18s/it]


Finished Training
Accuracy of the network over test images: 72.660 %
