# HyperCube

🔴 Fail
Trains well on MNIST (~96% test accuracy) but refuses to train above random chance on CIFAR10 (~10%)

The idea here is to have nHidden hidden nodes.

Each has a .x, a .y1 and a .y2.

We project the input x onto hidden.x, giving λ.

If λ ≥ 0 we use .y1, so y += λ^2 hidden.y1; otherwise y += λ^2 hidden.y2.
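A minimal loop-based sketch of that rule for a single input vector, assuming the hidden.x, hidden.y1 and hidden.y2 vectors are stacked into tensors X, Y1 and Y2 (hypothetical names; the Hypercube class below keeps both output halves in one (2, nHidden, nOut) tensor). Ties at λ = 0 go to .y1, matching the `(λ >= 0)` test in that class.

```python
import torch

def hypercube_reference(x, X, Y1, Y2):
    """Hypothetical loop version of the Hypercube rule for ONE input.

    x:  (nIn,)           input vector
    X:  (nHidden, nIn)   hidden.x for every hidden node
    Y1: (nHidden, nOut)  hidden.y1, used when λ >= 0
    Y2: (nHidden, nOut)  hidden.y2, used when λ < 0
    """
    y = torch.zeros(Y1.shape[-1])
    for h in range(X.shape[0]):
        lam = torch.dot(x, X[h])              # project x onto hidden.x
        chosen = Y1[h] if lam >= 0 else Y2[h]
        y = y + lam**2 * chosen               # λ^2 is never negative
    return y

torch.manual_seed(0)
nIn, nHidden, nOut = 5, 3, 4
x = torch.randn(nIn)
X = torch.randn(nHidden, nIn)
Y1, Y2 = torch.randn(nHidden, nOut), torch.randn(nHidden, nOut)
print(hypercube_reference(x, X, Y1, Y2))
```

The loop makes the per-node decision explicit; the class below does the same thing for a whole batch at once with einsum and gather.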

So suppose nHidden = 8. Then we're making 8 binary decisions,
which gives 2^8 = 256 possible combinations of .y1/.y2 choices.
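To make the counting concrete, here is a short sketch that enumerates the 2^nHidden sign patterns and checks how many of them random inputs actually land in. The input size (nIn = 32), the sample count and the Gaussian inputs are arbitrary choices for illustration.

```python
import itertools
import torch

nHidden = 8
# Each hidden node picks .y1 or .y2, so the patterns are all of {0, 1}^nHidden
patterns = list(itertools.product((0, 1), repeat=nHidden))
print(len(patterns))  # 256

# How many patterns do random inputs actually land in? (nIn, nSamples are arbitrary)
torch.manual_seed(0)
nIn, nSamples = 32, 10_000
X = torch.nn.functional.normalize(torch.randn(nHidden, nIn), dim=-1)  # unit hidden.x
x = torch.randn(nSamples, nIn)
bits = (x @ X.T >= 0).long()                    # (nSamples, nHidden), 1 where λ >= 0
place_values = torch.tensor([2 ** h for h in range(nHidden)])
codes = (bits * place_values).sum(dim=1)        # read each row as a binary number
print(torch.unique(codes).numel())
```

When nIn ≥ nHidden and the hidden.x directions are linearly independent (true with probability 1 for a random init), every one of the 2^nHidden patterns corresponds to a non-empty region of input space, so all of them are reachable in principle; how many a finite sample actually hits depends on the data.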

I'm using λ^2 because:
- It's always positive, and we want a POSITIVE coefficient whether we use .y1 or .y2.
- λ^2 * {something bounded} has a unique gradient at λ = 0 (zero from both sides), whereas just bolting 2 vectors together at 0 does NOT.
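A small numeric check of the second bullet, assuming "bolting 2 vectors together" means using the coefficient |λ| (still positive) and switching from .y2 to .y1 at 0. It uses `torch.autograd.functional.jacobian` to take d/dλ just to the right and just to the left of 0; y1 and y2 are made-up stand-ins.

```python
import torch
from torch.autograd.functional import jacobian

y1 = torch.tensor([1.0, 0.0])  # stand-in for hidden.y1
y2 = torch.tensor([0.0, 1.0])  # stand-in for hidden.y2

def switched(lam, power):
    # |λ|^power times .y1 when λ >= 0, times .y2 otherwise
    chosen = y1 if lam >= 0 else y2
    return lam.abs() ** power * chosen

def d_dlam(power, lam):
    # Derivative of switched(., power) with respect to the scalar λ, evaluated at lam
    return jacobian(lambda l: switched(l, power), torch.tensor(lam))

eps = 1e-3
for power in (1, 2):
    print(f"power={power}: right of 0 {d_dlam(power, eps)}, left of 0 {d_dlam(power, -eps)}")
```

With power = 1 the one-sided derivatives are .y1 and -.y2, which disagree however small eps gets. With power = 2 both are 2*eps times a bounded vector, so they shrink to zero together.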

For CIFAR10, I've tried 2 configurations:
```python
self.fc1 = Hypercube(nIn=16 * 5 * 5, nOut=120, nHidden=12)
self.fc2 = Hypercube(nIn=120, nOut=84, nHidden=9)
self.fc3 = Hypercube(nIn=84, nOut=10, nHidden=6)
```
I also tried nHidden = (32, 16, 8), which I think is too high: that's
2^32 possible sign patterns in the first layer alone. How could that ever train?
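Some back-of-the-envelope numbers for the fc stacks, read off the parameter shapes in the Hypercube class below (X is (nHidden, nIn), Y is (2, nHidden, nOut), no bias). The (4, 3, 2) entry is the configuration used in Net_Hypercube further down; Net_regular's three nn.Linear layers are counted with their biases.

```python
# Shapes from the Hypercube class: X is (nHidden, nIn), Y is (2, nHidden, nOut), no bias
def hypercube_params(nIn, nOut, nHidden):
    return nHidden * nIn + 2 * nHidden * nOut

def linear_params(nIn, nOut):
    return nIn * nOut + nOut  # torch.nn.Linear weight + bias

fc_layers = [(16 * 5 * 5, 120), (120, 84), (84, 10)]  # (nIn, nOut) of fc1..fc3
configs = [(12, 9, 6), (32, 16, 8), (4, 3, 2)]         # nHidden per layer

print("nn.Linear stack:", sum(linear_params(i, o) for i, o in fc_layers))  # 59134
for hidden in configs:
    n_params = sum(hypercube_params(i, o, h) for (i, o), h in zip(fc_layers, hidden))
    patterns = [2 ** h for h in hidden]
    print(f"Hypercube {hidden}: {n_params} params, sign patterns per layer {patterns}")
```

Even the largest variant has under half the parameters of the nn.Linear stack (25,920 vs 59,134). And because CIFAR10's training set has 50,000 images, one pass over it can land in at most 50,000 of the 2^32 first-layer patterns of the (32, 16, 8) configuration.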

In [13]:
import lovely_tensors as lt
lt.monkey_patch()

# 🔹 Dataloaders

In [1]:
NBATCH, NBATCH_TEST = 128, -1


In [2]:
# Load MNIST
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_MNIST():
    # Transformations
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # MNIST dataset
    trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Data loaders
    nTest = NBATCH_TEST if NBATCH_TEST > 0 else len(testset)
    trainloader = DataLoader(trainset, batch_size=NBATCH, shuffle=True)
    testloader = DataLoader(testset, batch_size=nTest, shuffle=False)

    return trainloader, testloader

# trainloader, testloader = load_MNIST()

In [3]:
def load_CIFAR10():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # https://github.com/microsoft/debugpy/issues/1166
    # ^ set nWorkers=1 to avoid this
    def create_dataloader(train, nBatch, shuffle, nWorkers=1):
        dataset = datasets.CIFAR10(
            root='./data', train=train, download=True, transform=transform
        )
        return DataLoader(
            dataset, batch_size=nBatch, shuffle=shuffle, num_workers=nWorkers
        )

    nTest = NBATCH_TEST if NBATCH_TEST > 0 else 1024
    trainloader = create_dataloader(train=True, nBatch=NBATCH, shuffle=True)
    testloader = create_dataloader(train=False, nBatch=nTest, shuffle=False)

    # Class labels
    # classes = 'plane car bird cat deer dog frog horse ship truck'.split()
    
    return trainloader, testloader  #, classes

# trainloader, testloader, classes = load_CIFAR10()


In [4]:
trainloader, testloader = load_CIFAR10()

batch = next(iter(trainloader))
images, labels = batch
labels[:10]

Files already downloaded and verified
Files already downloaded and verified


tensor([5, 4, 5, 9, 6, 9, 4, 6, 9, 3])
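The `classes` line commented out in load_CIFAR10 gives the label order, so the labels printed above decode as follows (the train loader shuffles, so a rerun will print different labels).

```python
import torch

# Label order from the commented-out `classes` line in load_CIFAR10
classes = 'plane car bird cat deer dog frog horse ship truck'.split()
first_labels = torch.tensor([5, 4, 5, 9, 6, 9, 4, 6, 9, 3])  # copied from the output above
print([classes[i] for i in first_labels.tolist()])
# ['dog', 'deer', 'dog', 'truck', 'frog', 'truck', 'deer', 'frog', 'truck', 'cat']
```

torchvision's CIFAR10 dataset also exposes the full class names through its `classes` attribute, in the same order.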

# 🔹 Test Harness

In [7]:
NEPOCH = 10
EVERY_N = -1  # print running loss every EVERY_N mini-batches; -1 disables it

import torch
from tqdm import tqdm

# Test harness
# We'll use this later
# def orthogonality_loss(basis_vectors):
#     # Compute pairwise dot products
#     dot_products = torch.matmul(basis_vectors, basis_vectors.T)
    
#     # Zero out diagonal elements (self dot products)
#     eye = torch.eye(dot_products.size(0)).to(dot_products.device)
#     dot_products = dot_products * (1 - eye)
    
#     # Sum of squares of off-diagonal elements (which should be close to zero)
#     loss = (dot_products ** 2).sum()
#     return loss

def train_and_test(net, dataloader, nEpoch=NEPOCH):
    trainloader, testloader = dataloader()

    # Loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(net.parameters(), lr=0.001)

    # Training the network
    for epoch in tqdm(range(nEpoch)):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if hasattr(net, 'loss'):
                loss += net.loss()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if EVERY_N > 0 and i % EVERY_N == EVERY_N - 1:  # print every EVERY_N mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

        # Testing the network on the test data
        correct, total = 0, 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'[{epoch}] Accuracy of the network over test images: {100 * correct / total:.3f} %')


# 🔹 Use MNIST and a plain feed-forward (FF) net to check the harness works

In [8]:
# Neural network architecture
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(in_features=28*28, out_features=500)
        self.fc2 = torch.nn.Linear(in_features=500, out_features=10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat
    # def loss(self):
    #     # Calculate orthogonality loss for each PiSlice layer
    #     loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
    #     loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
    #     return loss1 + loss2

train_and_test(Net(), dataloader=load_MNIST, nEpoch=1)

100%|██████████| 1/1 [00:03<00:00,  3.15s/it]

[0] Accuracy of the network over test images: 93.380 %





# 🔸 FFF (with full-tree training)

In [9]:
INIT_STRAT = 'hyperspherical-shell'
# DEPTH = 8

In [42]:
import torch
from torch.nn.functional import normalize

from typing import Optional
from math import floor, log2, sqrt

class Hypercube(torch.nn.Module):
    def __init__(self, nIn:int, nOut:int, nHidden:int):
        super().__init__()
        # self.nDim = nDim or int(floor(log2(nIn)))  # depth is the number of decision boundaries

        # each hidden node "holds" one basis vector in INPUT space (.X)
        # and two in OUTPUT space (.Y[0] used when λ < 0, .Y[1] when λ >= 0)

        # if INIT_STRAT == 'gaussian':
        #     # This from orig authors; scaling looks off for self.Y
        #     def create_basis_vectors_of(length, scaling):
        #         return torch.nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
        #     self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
        #     self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        if INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def unit_vectors(shape):
                weights = torch.randn(shape)  # Initialize weights randomly
                weights = normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return torch.nn.Parameter(weights)
            self.X = unit_vectors((nHidden, nIn))
            self.Y = unit_vectors((2, nHidden, nOut))

            # self.loss = None

    def forward(self, x:torch.Tensor):
        nBatch, nOut, nHidden = len(x), self.Y.shape[-1], len(self.X)
        # is_training = x.requires_grad
    
        # loss = torch.tensor(0.0, requires_grad=True)  # scalar tensor
        # this would be wrong (1D tensor not same as scalar)
        # loss = torch.zeros(1, requires_grad=True)

        # Project x onto every hidden node's INPUT basis vector
        #   λ[b, h] = x[b] DOT self.X[h]
        # (nBatch, nIn), (nHidden, nIn) -> (nBatch, nHidden)
        λ = torch.einsum('bi, hi -> bh', x, self.X)
        λλ = λ**2

        splitter_indices = (λ >= 0).long()  # (nBatch, nHidden)

        # (nBatch, nHidden, nOut); a 1 picks self.Y[1] (λ >= 0), a 0 picks self.Y[0]
        splitter_indices_expanded = splitter_indices.unsqueeze(-1).expand(-1, -1, nOut)

        # self.Y: (2, nHidden, nOut)
        selected_Y = torch.gather(self.Y, 0, splitter_indices_expanded)

        scaled_chosen_Ys = λλ.unsqueeze(-1) * selected_Y

        y = torch.sum(scaled_chosen_Ys, dim=1)
        return y

    def __repr__(self):
        return f"Hypercube({self.X.shape[-1]}, {self.Y.shape[-1]}, nHidden={len(self.X)})"


# sanity check
nBatch, nIn, nHidden, nOut = 3, 10, 8, 12 
fff = Hypercube(nIn, nOut, nHidden)

print('Testing training:')
x = torch.randn((nBatch, nIn), requires_grad=True)
y = fff(x)
# print('loss:', fff.loss)
cost = torch.norm(y)
cost.backward()
print('x.grad:', x.grad)

print('Testing inference:')
x = torch.randn((nBatch, nIn), requires_grad=False)
y = fff(x)
# print('loss:', fff.loss)
print('y:', y)


Testing training:
x.grad: tensor[3, 10] n=30 x∈[-1.903, 0.784] μ=-0.476 σ=0.641
Testing inference:
y: tensor[3, 12] n=36 x∈[-4.254, 3.159] μ=-0.221 σ=1.782 grad SumBackward1
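The sanity check above only shows that gradients flow back to x. A stricter sketch uses `torch.autograd.gradcheck`, which compares autograd's gradients with finite differences. gradcheck wants double-precision inputs and a plain function, so this uses `hypercube_forward`, a hypothetical functional copy of Hypercube.forward (not part of the notebook), with random rather than unit-normalized X and Y.

```python
import torch
from torch.autograd import gradcheck

def hypercube_forward(x, X, Y):
    # Hypothetical functional copy of Hypercube.forward, for gradient checking only
    lam = torch.einsum('bi, hi -> bh', x, X)                           # (nBatch, nHidden)
    idx = (lam >= 0).long().unsqueeze(-1).expand(-1, -1, Y.shape[-1])  # (nBatch, nHidden, nOut)
    selected_Y = torch.gather(Y, 0, idx)                               # (nBatch, nHidden, nOut)
    return torch.sum(lam.pow(2).unsqueeze(-1) * selected_Y, dim=1)     # (nBatch, nOut)

torch.manual_seed(0)
nBatch, nIn, nHidden, nOut = 3, 10, 8, 12
x = torch.randn(nBatch, nIn, dtype=torch.double, requires_grad=True)
X = torch.randn(nHidden, nIn, dtype=torch.double, requires_grad=True)
Y = torch.randn(2, nHidden, nOut, dtype=torch.double, requires_grad=True)

# Compares analytical gradients w.r.t. x, X and Y against finite differences;
# returns True on success and raises on a mismatch
print(gradcheck(hypercube_forward, (x, X, Y)))
```

The hard switch at λ = 0 would normally worry a finite-difference check, but every term is scaled by λ^2, so a sign flip right at the boundary changes the output by only about λ^2.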


In [43]:
import torch
import numpy as np

# Manual initialization of smaller sample tensor shapes for easier output inspection
nBatch = 2
nHidden = 3
nOut = 2

# Manually create sample data in PyTorch with specific values for easy verification
Y = torch.tensor([[[1, 2], [3, 4], [5, 6]],  # left
                  [[7, 8], [9, 10], [11, 12]]])  # right
lambd = torch.tensor([[1, 2, 3],  # batch item 0
                      [4, 5, 6]])  # batch item 1
splitter = torch.tensor([[True, False, True],  # batch item 0
                         [False, True, False]])  # batch item 1

# Convert boolean splitter to integer indices (0 or 1) for PyTorch
splitter_indices = splitter.long()

# Explicit looping
result_loop = torch.zeros((nBatch, nHidden, nOut))
for b in range(nBatch):
    for h in range(nHidden):
        selected_Y_slice = Y[splitter_indices[b, h], h, :]
        result_loop[b, h, :] = lambd[b, h] * selected_Y_slice

# Without explicit looping using torch.gather and element-wise multiplication
splitter_indices_expanded = splitter_indices.unsqueeze(-1).expand(-1, -1, Y.size(-1))
selected_Y = torch.gather(Y, 0, splitter_indices_expanded)
result_no_loop = lambd.unsqueeze(-1) * selected_Y

print(splitter_indices_expanded.numpy())
print(Y.numpy())
print(selected_Y.numpy())

torch.allclose(result_loop, result_no_loop.float())



[[[1 1]
  [0 0]
  [1 1]]

 [[0 0]
  [1 1]
  [0 0]]]
[[[ 1  2]
  [ 3  4]
  [ 5  6]]

 [[ 7  8]
  [ 9 10]
  [11 12]]]
[[[ 7  8]
  [ 3  4]
  [11 12]]

 [[ 1  2]
  [ 9 10]
  [ 5  6]]]


True
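An equivalent way to do the selection, sketched here on the same toy tensors, is `torch.where` with a broadcast mask: it reads the two halves of Y directly and skips building the expanded index tensor. The final einsum folds the weighting and the sum over hidden nodes into one call (the toy cell uses lambd as the weight; Hypercube uses λ^2).

```python
import torch

# Same toy tensors as the cell above
Y = torch.tensor([[[1, 2], [3, 4], [5, 6]],      # Y[0]: picked where splitter is False
                  [[7, 8], [9, 10], [11, 12]]])  # Y[1]: picked where splitter is True
lambd = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])
splitter = torch.tensor([[True, False, True],
                         [False, True, False]])

# gather version, as in the cell above
idx = splitter.long().unsqueeze(-1).expand(-1, -1, Y.size(-1))
via_gather = torch.gather(Y, 0, idx)

# torch.where version: the (nBatch, nHidden, 1) mask broadcasts against (nHidden, nOut)
via_where = torch.where(splitter.unsqueeze(-1), Y[1], Y[0])
print(torch.equal(via_gather, via_where))  # True

# Weight by lambd and sum over hidden nodes in one einsum (cast to float first)
y = torch.einsum('bh, bho -> bo', lambd.float(), via_where.float())
print(y)  # tensor([[46., 52.], [79., 94.]])
```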

# 🧪 Test it out on MNIST

In [49]:
# Neural network architecture
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = Hypercube(nIn=28*28, nOut=500, nHidden=32)
        self.fc2 = Hypercube(nIn=500, nOut=10, nHidden=8)

    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc2(self.fc1(x))
        return y_hat

    # def loss(self):
    #     # Calculate orthogonality loss for each PiSlice layer
    #     # loss1 = orthogonality_loss(self.fc1.X) + orthogonality_loss(self.fc1.Y)
    #     # loss2 = orthogonality_loss(self.fc2.X) + orthogonality_loss(self.fc2.Y)
    #     # return loss1 + loss2
    #     return 0.01 * (self.fc1.loss + self.fc2.loss)

train_and_test(Net(), load_MNIST, nEpoch=10)
# train_and_test(Net(), ortho=True)

  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:04<00:42,  4.75s/it]

[0] Accuracy of the network over test images: 93.780 %


 20%|██        | 2/10 [00:09<00:37,  4.69s/it]

[1] Accuracy of the network over test images: 94.590 %


 30%|███       | 3/10 [00:14<00:32,  4.66s/it]

[2] Accuracy of the network over test images: 94.870 %


 40%|████      | 4/10 [00:18<00:28,  4.68s/it]

[3] Accuracy of the network over test images: 95.280 %


 50%|█████     | 5/10 [00:23<00:23,  4.68s/it]

[4] Accuracy of the network over test images: 94.660 %


 60%|██████    | 6/10 [00:28<00:18,  4.70s/it]

[5] Accuracy of the network over test images: 96.160 %


 70%|███████   | 7/10 [00:32<00:14,  4.69s/it]

[6] Accuracy of the network over test images: 96.760 %


 80%|████████  | 8/10 [00:37<00:09,  4.71s/it]

[7] Accuracy of the network over test images: 95.290 %


 90%|█████████ | 9/10 [00:42<00:04,  4.71s/it]

[8] Accuracy of the network over test images: 96.530 %


100%|██████████| 10/10 [00:46<00:00,  4.70s/it]

[9] Accuracy of the network over test images: 96.680 %





# 🧪 Test it out on CIFAR10

In [61]:
import torch.nn as nn
from torch.nn.functional import relu

class Net_regular(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Net_Hypercube(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = Hypercube(nIn=16 * 5 * 5, nOut=120, nHidden=4)
        self.fc2 = Hypercube(nIn=120, nOut=84, nHidden=3)
        self.fc3 = Hypercube(nIn=84, nOut=10, nHidden=2)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    # def loss(self):
    #     return 0.01 * (self.fc1.loss + self.fc2.loss + self.fc3.loss)
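The `16 * 5 * 5` input size of fc1 comes from the conv/pool stack shared by both nets. Tracing a dummy 32x32 CIFAR10-sized image through the same layer definitions confirms it.

```python
import torch
import torch.nn as nn

# Same conv/pool stack as Net_regular and Net_Hypercube
conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
pool = nn.MaxPool2d(kernel_size=2, stride=2)
conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

x = torch.zeros(1, 3, 32, 32)  # one CIFAR10-sized image
shapes = [tuple(x.shape)]
for layer in (conv1, pool, conv2, pool):
    x = layer(x)
    shapes.append(tuple(x.shape))
print(shapes)
# [(1, 3, 32, 32), (1, 6, 28, 28), (1, 6, 14, 14), (1, 16, 10, 10), (1, 16, 5, 5)]
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([1, 400])
```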


In [59]:
# Silence these debugger warnings:
'''
Files already downloaded and verified
Files already downloaded and verified
  0%|          | 0/10 [00:00<?, ?it/s]0.00s - Debugger warning: It seems that frozen modules are being used, which may
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - Debugger warning: It seems that frozen modules are being used, which may
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
'''

import os
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'

In [54]:
# net = Net_regular()
train_and_test(Net_regular(), dataloader=load_CIFAR10, nEpoch=10)

Files already downloaded and verified
Files already downloaded and verified


 10%|█         | 1/10 [00:25<03:47, 25.31s/it]

[0] Accuracy of the network over test images: 45.200 %


 20%|██        | 2/10 [00:49<03:19, 24.91s/it]

[1] Accuracy of the network over test images: 49.520 %


 30%|███       | 3/10 [01:14<02:54, 24.91s/it]

[2] Accuracy of the network over test images: 50.890 %


 40%|████      | 4/10 [01:38<02:27, 24.55s/it]

[3] Accuracy of the network over test images: 55.410 %


 50%|█████     | 5/10 [02:03<02:02, 24.45s/it]

[4] Accuracy of the network over test images: 57.620 %


 60%|██████    | 6/10 [02:27<01:38, 24.56s/it]

[5] Accuracy of the network over test images: 59.460 %


 70%|███████   | 7/10 [02:52<01:13, 24.59s/it]

[6] Accuracy of the network over test images: 59.540 %


 80%|████████  | 8/10 [03:17<00:49, 24.72s/it]

[7] Accuracy of the network over test images: 60.220 %


 90%|█████████ | 9/10 [03:41<00:24, 24.60s/it]

[8] Accuracy of the network over test images: 61.560 %


100%|██████████| 10/10 [04:06<00:00, 24.63s/it]

[9] Accuracy of the network over test images: 62.020 %





In [62]:
train_and_test(Net_Hypercube(), dataloader=load_CIFAR10, nEpoch=10)


Files already downloaded and verified
Files already downloaded and verified


 10%|█         | 1/10 [00:24<03:40, 24.48s/it]

[0] Accuracy of the network over test images: 10.380 %


 20%|██        | 2/10 [00:48<03:14, 24.30s/it]

[1] Accuracy of the network over test images: 10.380 %


 30%|███       | 3/10 [01:12<02:50, 24.30s/it]

[2] Accuracy of the network over test images: 10.380 %


 40%|████      | 4/10 [01:37<02:26, 24.41s/it]

[3] Accuracy of the network over test images: 10.340 %


 50%|█████     | 5/10 [02:02<02:02, 24.45s/it]

[4] Accuracy of the network over test images: 10.350 %


 50%|█████     | 5/10 [02:11<02:11, 26.20s/it]


KeyboardInterrupt: 
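The runs above don't say why the Hypercube net stalls near 10%. One hypothesis worth checking, not something these runs establish: for any c > 0, Hypercube(c*x) = c^2 * Hypercube(x), because scaling x doesn't change the signs of λ, and ReLU is positively homogeneous too, so the three-layer fc stack in Net_Hypercube scales like the 8th power of its input. A cheap way to look for exploding or collapsed activations is a hypothetical `activation_report` helper that records each leaf module's output RMS with `register_forward_hook` and counts the predicted classes.

```python
import torch
import torch.nn as nn

def activation_report(net, inputs):
    """Hypothetical diagnostic: per-leaf-module output RMS and predicted-class counts."""
    stats, handles = {}, []
    for name, module in net.named_modules():
        if name and not list(module.children()):  # leaf submodules only
            def hook(mod, args, out, name=name):
                # Root-mean-square of this module's output on the batch
                stats[name] = out.detach().float().pow(2).mean().sqrt().item()
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            outputs = net(inputs)
    finally:
        for handle in handles:
            handle.remove()  # leave the network without hooks afterwards
    # How often each class is predicted; one dominant class suggests collapse
    counts = torch.bincount(outputs.argmax(dim=1), minlength=outputs.shape[1])
    return stats, counts

# Demo on a small stand-in model
torch.manual_seed(0)
demo_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))
stats, counts = activation_report(demo_net, torch.randn(32, 8))
print(stats)
print(counts)
```

In the notebook this would be called as `activation_report(Net_Hypercube(), images)` with the CIFAR10 batch from In [4]. Because Net_Hypercube reuses `self.pool` twice, the `'pool'` entry only keeps the second call, and the functional `relu` calls are not modules, so they don't appear.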