In [1]:
# Development convenience: with autoreload mode 2, imported modules are
# re-imported before each cell runs, so edits to the project source are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Intended number of training epochs.
# NOTE(review): no later cell shown in this notebook references NEPOCH
# (training_loop is called without it) — confirm where it is consumed.
NEPOCH = 10

In [3]:
from torch import nn, functional as F
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [4]:
cd ..

/Users/solar/Code/DL/FFF/FFF


In [5]:
from FFF.ffff import F4

In [6]:
# There are many ways to run a forest model.
# Inside one tree we branch simply based on the sign of the projection.
# If we have two trees, then we will have two projections.
# We can:
#   branch independently
#   branch jointly based on the largest projection ~ in terms of abs value!
#   but we cannot use abs due to non-differentiability
#   thus we need to rely on some sign-insensitive normalization

In [7]:
# Toy problem sizes used by the walkthrough cells below.
B = 7       # batch size
C_in = 20   # input features
C_out = 30  # output features
N_tr = 3    # number of trees in the forest
D = 6       # depth passed to the forest layer

# Forest layer instance that the following cells inspect.
M = F4(
    in_features=C_in,
    out_features=C_out,
    num_trees=N_tr,
    depth=D,
)

In [8]:
# Show the shapes of the forest parameters. With the sizes chosen above the
# result reads as (nodes, trees, features): 63 matches 2**D - 1 for D = 6,
# 3 is N_tr, and the last axis is C_in for keys and C_out for values.
M.keys.shape, M.values.shape

(torch.Size([63, 3, 20]), torch.Size([63, 3, 30]))

In [10]:
# Paired integer-array indexing: the two index lists are matched element by
# element, so each result row is keys[node, tree] for one (node, tree) pair.
# First: nodes 0, 1, 3 taken from trees 0, 1, 2 respectively.
# Second: node 0 taken from each of the trees 0, 1, 2.
# Both selections give one row of C_in values per tree.
M.keys[[0,1,3],[0,1,2]].shape, M.keys[[0,0,0],[0,1,2]].shape

(torch.Size([3, 20]), torch.Size([3, 20]))

#### Example of a forest model forward pass

In [15]:
# One step of the forest forward pass, written out by hand.

# Random input batch and a zero-initialised output accumulator.
x = torch.rand(B, C_in)
y = torch.zeros(B, C_out, dtype=torch.float)

# Every tree starts at node 0 for every sample in the batch.
current_nodes = torch.zeros(B, N_tr, dtype=torch.long)
# One index per tree: [0, 1, 2, ...]
tree_selector = torch.arange(N_tr, dtype=torch.long)

# "Block-diagonal" selection on the second dimension: for sample b and tree k
# only keys[current_nodes[b, k], k] is picked, then projected onto the input.
lambda_ = torch.einsum("bi, bki -> bk", x, M.keys[current_nodes, tree_selector])

# An optional normalization of lambda_ would go here.
# Softmax is a poor fit: its outputs are always positive, so the sign test
# below would always pick the same branch and part of every tree would never
# be visited. An L2 normalization keeps the sign.

# Accumulate each tree's value vector, weighted by that tree's projection.
y += torch.einsum("bk, bkj -> bj", lambda_, M.values[current_nodes, tree_selector])

# The sign of the projection decides the branch: 1 if positive, else 0.
plane_choice = (lambda_ > 0).long()

# Index of the node to visit in the next layer: node n has children 2n + 1 and 2n + 2.
current_nodes = 2 * current_nodes + plane_choice + 1

print(y.shape, lambda_.shape, current_nodes.shape)
print('Plane choices are independent for each tree:\n', plane_choice)

torch.Size([7, 30]) torch.Size([7, 3]) torch.Size([7, 3])
Plane choices are independent for each tree:
 tensor([[1, 0, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 1]])


In [12]:
# Node index each tree visits next, per sample. After a single step from
# node 0 every entry is 1 or 2 (0 * 2 + 1 + plane_choice).
# NOTE(review): the saved output of this cell comes from execution 12, older
# than the forward cell (execution 15), and does not match the plane choices
# printed there — re-run the notebook top to bottom before trusting it.
current_nodes

tensor([[2, 2, 1],
        [2, 1, 2],
        [2, 1, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2]])

#### Benchmark on MNIST

In [25]:
from models.F4_net import F4_mnist
from models.F3_net import F3_net
from math import log2, floor

In [21]:
from bench_mnist import training_loop, run_test

In [18]:
# Benchmark sizes: a flattened 28x28 image and a forest of four trees.
num_trees = 4
in_features = 28 * 28

# Largest depth d such that 2**d <= in_features (math.floor already returns an int).
depth = floor(log2(in_features))

# Reduced depth for a multi-tree forest.
# NOTE(review): this subtracts the tree count itself, not log2(num_trees) —
# confirm that is the intended notion of "fair". The value is not passed to
# any model constructor in the cells shown below.
depth_fair = depth - num_trees

In [19]:
# Display the single-tree depth next to the reduced depth computed above.
depth, depth_fair

(9, 5)

##### F4 net test on MNIST

In [20]:
# F4 forest classifier for MNIST. The input width and the tree count reuse
# the `in_features` and `num_trees` constants defined in the depth cell above
# instead of repeating the literals 28*28 and 4, so the model and the
# computed depths cannot silently drift apart.
Net = F4_mnist(
    in_features=in_features,
    hidden_features=500,
    out_classes=10,
    num_trees=num_trees,
)

In [22]:
# Train the F4 model. The displayed return value is a list with one number
# per progress-bar step (presumably the accumulated loss per epoch — confirm).
# NOTE(review): NEPOCH from the top of the notebook is not passed here; the
# 10 iterations presumably come from training_loop's own default.
training_loop(Net)

100%|██████████| 10/10 [02:58<00:00, 17.80s/it]

Finished Training





[413.70457230880857,
 207.66768068820238,
 161.6604362996295,
 141.80214098468423,
 125.46375386975706,
 113.23060313798487,
 105.03927972819656,
 96.17975116590969,
 86.58435740834102,
 80.72209519753233]

In [23]:
# Evaluate the trained F4 model on the test images. The call prints the
# accuracy and returns a pair that looks like (correct, total) — confirm.
run_test(Net)

Accuracy of the network on the 10000 test images: 95 %


(9578, 10000)

##### FFF net test on MNIST

In [26]:
# Baseline FFF model built with its default arguments.
# Note: this rebinds `Net`, so the trained F4 model from the previous section
# is no longer reachable under that name.
Net = F3_net()

In [27]:
# Train the FFF baseline with the same training_loop helper used for F4.
# NOTE(review): NEPOCH from the top of the notebook is not passed here either.
training_loop(Net)

100%|██████████| 10/10 [02:31<00:00, 15.11s/it]

Finished Training





[716.391123495996,
 345.879173733294,
 277.3338041976094,
 245.78006969578564,
 220.5651059858501,
 204.77407950535417,
 189.74800446256995,
 176.06425451952964,
 170.36389563046396,
 162.86117041297257]

In [28]:
# Evaluate the trained FFF baseline on the test images with the same
# run_test helper; the returned pair looks like (correct, total) — confirm.
run_test(Net)

Accuracy of the network on the 10000 test images: 93 %


(9394, 10000)