## Flat structure for X projections

I calculated all the lambdas in parallel. This lambdas now represent value function which are only approximations of the true lambda values.
The main advantage is that the values can be calculated in parallel. This is equivalent as calculating lambdas for the whole tree in parallel.

The approximation between Value functions and lambdas is now done with Y projections.
The Y projections take Value functions, and approximate lambda values that are then used for Y projections. Since this is linear operations they can be combined into single matrix Y projection.

### Results
- There is no degradation in quality for CIFAR10
- Network uses half memory bandwidth
- Network is 2x-3x faster

## Future work
- Check if having tree like structure for Y projections gives any advantage over flat structure

In [1]:
from typing import Optional
from math import floor, log2, sqrt

import torch as torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

In [7]:
INIT_STRAT = 'hyperspherical-shell'


class FFF(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        # each node "holds" a basis-vector in INPUT space (.X) and in OUTPUT space (.Y)

        if INIT_STRAT == 'gaussian':
            # This from orig authors; scaling looks off for self.Y
            def create_basis_vectors_of(length, scaling):
                return nn.Parameter(torch.empty(nNodes, length).uniform_(-scaling, scaling))
            self.X = create_basis_vectors_of(length=nIn, scaling=1/sqrt(nIn))
            self.Y = create_basis_vectors_of(length=nOut, scaling=1/sqrt(self.depth + 1))

        elif INIT_STRAT == 'hyperspherical-shell':
            # Initialize vectors on INPUT/OUTPUT space unit hypersphere
            #   (idea: basis vectors should be of unit length).
            def create_random_unit_vectors_of(length):
                weights = torch.randn(nNodes, length)  # Initialize weights randomly
                weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
                return nn.Parameter(weights)
            self.X = create_random_unit_vectors_of(length=nIn)
            self.Y = create_random_unit_vectors_of(length=nOut)

    def push_away_loss(self, x: torch.Tensor):
        return torch.pow((1 - torch.abs(x)), 2).sum()

    def forward(self, x: torch.Tensor):
        nBatch, nIn, nOut = x.shape[0], self.X.shape[-1], self.Y.shape[-1]

        current_node = torch.zeros(nBatch, dtype=torch.long, device=x.device)

        # Walk the tree, assembling y piecemeal
        y = torch.zeros((nBatch, nOut), dtype=torch.float, device=x.device)
        for depth in range(self.depth):
            # Project x onto the current node's INPUT basis vector
            #   λ = x DOT currNode.X
            # (nBatch, nIn,) (nBatch, nIn) -> (nBatch,)
            λ = torch.einsum("b i, b i -> b", x, self.X[current_node])

            y += torch.einsum("b, b j -> b j", λ, self.Y[current_node])

            branch_choice = (λ > 0.0).long()
            current_node = (current_node * 2) + 1 + branch_choice
        return y

class FFF_v2(nn.Module):
    def __init__(self, nIn: int, nOut: int, depth: Optional[int] = None):
        super().__init__()
        self.depth = depth or int(floor(log2(nIn)))  # depth is the number of decision boundaries
        nNodes = 2 ** self.depth - 1

        def create_random_unit_vectors_of(length):
            weights = torch.randn(nNodes, length)  # Initialize weights randomly
            weights = F.normalize(weights, p=2, dim=-1)  # L2-Normalize along the last dimension
            return nn.Parameter(weights)

        self.X = nn.Linear(nIn, self.depth, bias=False)
        self.Y = create_random_unit_vectors_of(length=nOut)

    def forward_old_code(self, x: torch.Tensor):
        nBatch, nOut = x.shape[0], self.Y.shape[-1]
        current_node = torch.zeros(nBatch, dtype=torch.long, device=x.device)
        y = torch.zeros((nBatch, nOut), dtype=torch.float, device=x.device)

        λ = self.X(x)
        for i in range(self.depth):
            y += torch.einsum("b, b j -> b j", λ[:, i], self.Y[current_node])
            branch_choice = (λ[:,i] > 0).long()
            current_node = (current_node * 2) + 1 + branch_choice
        return y

    def forward(self, x: torch.Tensor):
        nBatch, nOut = x.shape[0], self.Y.shape[-1]

        λ = self.X(x)
        branch_choice = (λ > 0).long()

        indenes = torch.empty((nBatch, self.depth), dtype=torch.long, device=x.device)
        current_node = torch.zeros(nBatch, dtype=torch.long, device=x.device)
        for i in range(self.depth):
            indenes[:, i] = current_node
            current_node = (current_node * 2) + 1 + branch_choice[:, i]

        y = torch.einsum("b i, b i j -> b j", λ, self.Y[indenes])
        return y

In [8]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [9]:
class Net_FFF(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = FFF(16 * 5 * 5, 120, depth=None)
        self.fc2 = FFF(120, 84, depth=None)
        self.fc3 = FFF(84, 10, depth=None)

    def forward(self, x):
        train_key = False
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x= F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Net_FFF_v2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = FFF_v2(16 * 5 * 5, 120, depth=None)
        self.fc2 = FFF_v2(120, 84, depth=None)
        self.fc3 = FFF_v2(84, 10, depth=None)

    def forward(self, x):
        train_key = False
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x= F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [10]:
def evaluate(net: nn.Module, data_loader: torch.utils.data.DataLoader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in data_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
    return 100 * correct / total

def train(net: nn.Module,
          trainloader: torch.utils.data.DataLoader,
          testloader: torch.utils.data.DataLoader,
          epochs: int):

    optimizer = optim.AdamW(net.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch} loss: {running_loss / len(trainloader)}")

In [14]:
train_for = 40

net_FFF = Net_FFF().to(device)
net_FFF_v2 = Net_FFF_v2().to(device)

In [16]:
train(net_FFF, trainloader, testloader, train_for)

Epoch 0 loss: 2.0214164122901
Epoch 1 loss: 1.6917507355780248
Epoch 2 loss: 1.5582671805720805
Epoch 3 loss: 1.466966018347484
Epoch 4 loss: 1.397243906469906
Epoch 5 loss: 1.342637888001054
Epoch 6 loss: 1.2993970338036032
Epoch 7 loss: 1.2574183681736821
Epoch 8 loss: 1.224845184237146
Epoch 9 loss: 1.1952471472418216
Epoch 10 loss: 1.167301650699752
Epoch 11 loss: 1.1404021249707703
Epoch 12 loss: 1.122293352928308
Epoch 13 loss: 1.0975981932466903
Epoch 14 loss: 1.081290222647245
Epoch 15 loss: 1.0580935062045027
Epoch 16 loss: 1.0382005909214849
Epoch 17 loss: 1.0193680747390708
Epoch 18 loss: 1.003220223862192
Epoch 19 loss: 0.989182462777628
Epoch 20 loss: 0.9743737374120356
Epoch 21 loss: 0.9651489597757149
Epoch 22 loss: 0.9503091465481712
Epoch 23 loss: 0.9410525194519316
Epoch 24 loss: 0.9349830432621109
Epoch 25 loss: 0.9194280529571006
Epoch 26 loss: 0.9148033989969727
Epoch 27 loss: 0.9084660581615575
Epoch 28 loss: 0.8975272242675352
Epoch 29 loss: 0.8923117654097964
Ep

In [15]:
train(net_FFF_v2, trainloader, testloader, train_for)

Epoch 0 loss: 2.049093236398819
Epoch 1 loss: 1.7930712974284921
Epoch 2 loss: 1.6349740842419207
Epoch 3 loss: 1.5265829590580346
Epoch 4 loss: 1.448452637018755
Epoch 5 loss: 1.3829183194338512
Epoch 6 loss: 1.3330699976752787
Epoch 7 loss: 1.290956637743489
Epoch 8 loss: 1.2561685651769419
Epoch 9 loss: 1.2253639742236613
Epoch 10 loss: 1.20175103534518
Epoch 11 loss: 1.180387809453413
Epoch 12 loss: 1.161542066680196
Epoch 13 loss: 1.144530805023125
Epoch 14 loss: 1.131169364885296
Epoch 15 loss: 1.11852028668689
Epoch 16 loss: 1.105437661375841
Epoch 17 loss: 1.0944929507077503
Epoch 18 loss: 1.0825332180618326
Epoch 19 loss: 1.073117415465967
Epoch 20 loss: 1.0629768208469577
Epoch 21 loss: 1.0559150125364514
Epoch 22 loss: 1.0503833335074013
Epoch 23 loss: 1.0432353771251182
Epoch 24 loss: 1.0371619969072854
Epoch 25 loss: 1.033444506737887
Epoch 26 loss: 1.0287578622703357
Epoch 27 loss: 1.0248553832168774
Epoch 28 loss: 1.0205961662485166
Epoch 29 loss: 1.016515857423358
Epoch

In [17]:
evaluate(net_FFF, testloader), evaluate(net_FFF_v2, testloader)

(tensor(54.2500, device='cuda:0'), tensor(60.2300, device='cuda:0'))