# Optimization prototype

The problem with FFF:

- ~50% of profiled time is spent in lookups
- each batch-item takes a different route thru the tree
- this forces many lookups: e.g. tree depth 8, nBatch=1024 -> 8k lookups

# Proposal
- sort our input vectors as we go
    - so we end up with all the input vectors that map to 0000 (left left left left), then 0001 0010 etc
    - then we can process each of these 16 chunks simul on CUDA

We'll try to get a naive Python impl working (using a recursion) to check the logic.

TODO: If the same weights are used, we can `assert FFF(x) == FFFF(x)`

# Status
- takes 5x longer to train and gives 66% accuracy from 1 epoch (FFF gives 88%)

Nevertheless, the idea looks solid

In [1]:
from typing import Optional
from math import floor, log2, sqrt
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
import lovely_tensors as lt
lt.monkey_patch()

In [3]:
NEPOCH = 10
BATCH_SIZE = 64
EVERY_N = 200
INIT_STRAT = 'hyperspherical-shell'


In [4]:

def set_random_seed(random_seed=42):
    # PyTorch
    torch.manual_seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_seed)

    # Numpy
    np.random.seed(random_seed)

    # Python's `random` module
    random.seed(random_seed)

    # If you are using cudnn, set this to True to make computation deterministic
    # Note: This might reduce performance
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [5]:
class FFF(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

    def forward(self, x: torch.Tensor):
        nBatch, nIn, nOut = x.shape[0], self.X.shape[-1], self.Y.shape[-1]

        current_node = torch.zeros(nBatch, dtype=torch.long, device=x.device)

        # Walk the tree, assembling y piecemeal
        y = torch.zeros((nBatch, nOut), dtype=torch.float, device=x.device)
        for _ in range(self.depth):
            # Project x onto the current node's INPUT basis vector
            #   λ = x DOT currNode.X
            # (nBatch, nIn,) (nBatch, nIn) -> (nBatch,)
            λ = torch.einsum("b i, b i -> b", x, self.X[current_node])

            # Project this contribution into OUTPUT space:
            #   y += λ currNode.Y
            # (nBatch,) (nBatch, nOut) -> (nBatch, nOut)
            y += torch.einsum("b, b j -> b j", λ, self.Y[current_node])

            # We'll branch right if x is "sunny-side" of the
            # hyperplane defined by node.x (else left)
            branch_choice = (λ > 0).long()

            # figure out index of node in next layer to visit
            current_node = (current_node * 2) + 1 + branch_choice

        return y

    def __repr__(self):
        return f"FFF({self.X.shape[-1]}, {self.Y.shape[-1]}, depth={self.depth})"


In [6]:
class FFFF(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

    # assuming batch_size >> treeSize
    def forward(self, x: torch.Tensor):
        nBatch, nOut = x.shape[0], self.Y.shape[-1]

        y = torch.zeros((nBatch, nOut), dtype=torch.float)

        def process_node(curr_node, depth, indices):
            if depth == self.depth or len(indices) == 0:
                return

            # Project x[indices] onto the current node's INPUT basis vector
            #   λ = x[indices] DOT currNode.X
            λ = torch.einsum("ni, i -> n", x[indices], self.X[curr_node])  # (nInd, nIn), nIn -> nInd

            # Update y-values for our indices
            #   y[indices] += λ * self.Y[indices]
            y[indices] += torch.einsum("n, j -> nj", λ, self.Y[curr_node])  # nInd, nOut -> nInd, nOut

            # Split indices based on λ >= 0
            indices_left = torch.where(λ < 0.0)[0]
            indices_right = torch.where(λ >= 0.0)[0]

            process_node(2*curr_node + 1, depth+1, indices_left)
            process_node(2*curr_node + 2, depth+1, indices_right)

        # our root node will process ALL batch-items
        all_indices = torch.arange(nBatch, dtype=torch.int)

        process_node(curr_node=0, depth=0, indices=all_indices)
        return y


In [7]:
def test_forward_pass(NBATCH, NIN, NOUT, DEPTH):
    set_random_seed()
    x1 = torch.rand((NBATCH, NIN))
    fff = FFF(NIN, NOUT, DEPTH)
    y1 = fff(x1)

    set_random_seed()
    x2 = torch.rand((NBATCH, NIN))
    ffff = FFFF(NIN, NOUT, DEPTH)
    y2 = ffff(x2)

    assert torch.allclose(x1, x2)

    # print('diff:', float(torch.norm(y1)), float(torch.norm(y2)), float(torch.norm(y1 - y2)))

    print(f'  {NBATCH:4d} {NIN:6d} {NOUT:6d} {DEPTH:3d}     ', '✅' if torch.allclose(y1, y2) else '❌', float(torch.norm(y1 - y2)))

print('nBatch    nIn   nOut   depth')   
test_forward_pass(99, 8, 8, 2)
test_forward_pass(100, 8, 8, 2)

test_forward_pass(3, 126, 64, 3)
test_forward_pass(3, 127, 64, 3)

test_forward_pass(3, 126, 1, 3)
test_forward_pass(3, 127, 1, 3)

test_forward_pass(1, 10000, 10000, 8)


nBatch    nIn   nOut   depth
    99      8      8   2      ✅ 5.882773166376865e-07
   100      8      8   2      ❌ 5.738617687711667e-07
     3    126     64   3      ✅ 0.0
     3    127     64   3      ❌ 0.9158244729042053
     3    126      1   3      ✅ 0.0
     3    127      1   3      ❌ 1.2935090065002441
     1  10000  10000   8      ✅ 0.0


In [8]:
# Assuming FFF and FFFF are defined and have the same parameters for a fair comparison

def set_random_seed(seed=42):
    torch.manual_seed(seed)

NBATCH = 10
NIN, NOUT = 512, 1024

# Set random seed and initialize models
set_random_seed()
fff = FFF(NIN, NOUT, depth=4)
set_random_seed()
ffff = FFFF(NIN, NOUT, depth=4)

# Forward pass
x = torch.rand((NBATCH, NIN))
y1 = fff(x)
y2 = ffff(x)

assert torch.allclose(y1, y2)


# Create a dummy target
target = torch.rand_like(y1)

# Compute loss
loss1 = F.mse_loss(y1, target)
loss2 = F.mse_loss(y2, target)

# Backward pass
loss1.backward()
loss2.backward()

# Check gradients
for (param1, param2) in zip(fff.parameters(), ffff.parameters()):
    assert torch.allclose(param1.grad, param2.grad, atol=1e-6), "Gradients do not match"

print("Backward pass test passed!")


AssertionError: 

# Load MNIST

In [9]:
import torchvision
import torchvision.transforms as transforms

# Transformations
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)


# Test harness

In [10]:
from torch import nn, functional as F
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [11]:
# We'll use this later
def orthogonality_loss(basis_vectors):
    # Compute pairwise dot products
    dot_products = torch.matmul(basis_vectors, basis_vectors.T)
    
    # Zero out diagonal elements (self dot products)
    eye = torch.eye(dot_products.size(0)).to(dot_products.device)
    dot_products = dot_products * (1 - eye)
    
    # Sum of squares of off-diagonal elements (which should be close to zero)
    loss = (dot_products ** 2).sum()
    return loss

In [12]:
def train_and_test(net, ortho=False):
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

    # Training the network
    for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if ortho:  # hasattr(net, 'orthogonality_penalty'):
                loss += .001 * net.orthogonality_penalty()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % EVERY_N == EVERY_N - 1:  # print EVERY_N mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

    print('Finished Training')

    # Testing the network on the test data
    correct, total = 0, 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network over test images: {100 * correct / total:.3f} %')


In [13]:
# Neural network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFFF(nIn=28*28, nOut=500, depth=8)
        self.fc2 = FFFF(nIn=500, nOut=10, depth=8)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    def orthogonality_penalty(self):
        # Calculate orthogonality loss for each PiSlice layer
        loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
        loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
        return loss1 + loss2

train_and_test(Net())
# train_and_test(Net(), ortho=True)

  0%|          | 0/10 [00:00<?, ?it/s]

[1,   200] loss: 1.773
[1,   400] loss: 1.276
[1,   600] loss: 1.135
[1,   800] loss: 1.037


 10%|█         | 1/10 [00:47<07:06, 47.37s/it]

[2,   200] loss: 0.958
[2,   400] loss: 0.931
[2,   600] loss: 0.915
[2,   800] loss: 0.893


 20%|██        | 2/10 [01:38<06:36, 49.56s/it]

[3,   200] loss: 0.866
[3,   400] loss: 0.871
[3,   600] loss: 0.829
[3,   800] loss: 0.863


 30%|███       | 3/10 [02:30<05:54, 50.64s/it]

[4,   200] loss: 0.801
[4,   400] loss: 0.824
[4,   600] loss: 0.818
[4,   800] loss: 0.788


 40%|████      | 4/10 [03:23<05:09, 51.55s/it]

[5,   200] loss: 0.781
[5,   400] loss: 0.807
[5,   600] loss: 0.808
[5,   800] loss: 0.792


 50%|█████     | 5/10 [04:15<04:18, 51.61s/it]

[6,   200] loss: 0.778
[6,   400] loss: 0.791
[6,   600] loss: 0.787
[6,   800] loss: 0.790


 60%|██████    | 6/10 [05:07<03:26, 51.73s/it]

[7,   200] loss: 0.794
[7,   400] loss: 0.769
[7,   600] loss: 0.733
[7,   800] loss: 0.772


 70%|███████   | 7/10 [05:59<02:35, 51.83s/it]

[8,   200] loss: 0.741
[8,   400] loss: 0.744
[8,   600] loss: 0.749
[8,   800] loss: 0.752


 80%|████████  | 8/10 [06:52<01:44, 52.30s/it]

[9,   200] loss: 0.740
[9,   400] loss: 0.738
[9,   600] loss: 0.737
[9,   800] loss: 0.754


 90%|█████████ | 9/10 [07:45<00:52, 52.69s/it]

[10,   200] loss: 0.719
[10,   400] loss: 0.727
[10,   600] loss: 0.729
[10,   800] loss: 0.730


100%|██████████| 10/10 [08:40<00:00, 52.10s/it]


Finished Training
Accuracy of the network over test images: 75.460 %
