In [None]:
import torch, torch.functional as F
from torch import nn, tensor
from fastcore.all import *
from math import sqrt

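## FFF (fast feedforward)

Tree-routed feedforward layers: a binary tree of learned hyperplanes sends each input to one of `2**depth` leaf MLPs — soft top-k mixing of leaves during training, a single hard root-to-leaf path at inference.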

In [None]:
class FFF(nn.Module):
    """Fast feedforward layer: a binary tree of routing hyperplanes over `2**depth` leaf MLPs.

    Every leaf is a bias-free MLP `in_dim -> h_dim -> ReLU -> out_dim` whose weights live in
    the stacked tensors `w1` (n, h_dim, in_dim) and `w2` (n, out_dim, h_dim).

    Training: each internal node j scores `<x, nodes[j]>`; the left edge gets `tree_act(+score)`
    and the right edge `tree_act(-score)`. A leaf's score is the sum over the edges on its
    root-to-leaf path. The `topk` best leaves are evaluated and mixed with softmax weights
    renormalised over those k leaves.
    Eval: each row walks the tree greedily (right child when the score is < 0) and uses
    exactly one leaf.

    Args:
        in_dim, out_dim: input / output feature sizes; `x` must be 2-D `(batch, in_dim)`.
        h_dim: hidden width of every leaf MLP.
        depth: tree depth -> `2**depth` leaves and `2**depth - 1` internal nodes.
        tree_act: applied to the signed node scores before summing along paths.
        topk: leaves mixed per row during training (`None`/0 -> all leaves).
        save_probs: if True, the training forward stores the softmax over all leaf scores
            in `self.probs`.
    """
    def __init__(self, in_dim, out_dim, h_dim, depth, tree_act=nn.LogSigmoid(), topk=None, save_probs=True):
        super().__init__()
        store_attr('save_probs,tree_act,depth')
        self.n = 2**depth                 # number of leaves
        self.topk = topk or self.n        # None/0 -> mix every leaf
        def uniform(shape, scale):
            return nn.Parameter(torch.empty(shape).uniform_(-scale,scale))
        # One routing hyperplane per internal node, in heap order (children of j are 2j+1, 2j+2).
        self.nodes = uniform((self.n-1, in_dim), scale=1/sqrt(in_dim))
        self.w1 = uniform((self.n, h_dim, in_dim), scale=1/sqrt(in_dim))
        self.w2 = uniform((self.n, out_dim, h_dim), scale=1/sqrt(h_dim))
        self.act = nn.ReLU()
        self.t = self.init_t_()
        self.s = self.init_s_()

    def init_t_(self):
        """Fixed 0/1 path matrix of shape (n, 2*(n-1)): row = leaf, column = tree edge.

        Columns are grouped by level (root's two edges first, then the next level's four, ...),
        left to right within a level; entry is 1 when the edge lies on the leaf's path.
        """
        tree, res = torch.eye(self.n), []
        for _ in range(self.depth):
            res.append(tree)
            # Merge adjacent leaf groups pairwise: one level closer to the root.
            tree = tree.view(self.n, -1, 2).sum(-1)
        return nn.Parameter(torch.cat(list(reversed(res)),dim=1), False)

    def init_s_(self):
        """Fixed (n-1, 2*(n-1)) sign matrix: node j's score goes to column 2j as +score and to 2j+1 as -score."""
        s = torch.eye(self.n-1)
        return nn.Parameter(torch.stack([s,-s], dim=2).view(self.n-1,2*(self.n-1)), False)

    def forward(self, x):
        bs = x.shape[0]
        if self.training:
            # Signed edge scores (bs, 2*(n-1)): +<x, node_j> for the left edge, its negation for the right.
            z = x.matmul(self.nodes.T).matmul(self.s)
            # Per-leaf score (bs, n): sum of tree_act over the edges on that leaf's path.
            z = self.tree_act(z).matmul(self.t.T)
            if self.save_probs: self.probs = torch.softmax(z,-1)
            probs, indices = z.topk(self.topk)
            # Renormalise over the selected leaves only.
            probs = torch.softmax(probs, dim=-1)
        else:
            # Hard routing through the heap-ordered tree: right child (2i+2) when <x, node_i> < 0.
            indices = torch.zeros(bs, dtype=torch.long, device=x.device)
            for _ in range(self.depth):
                indices = indices*2 + 1 + (torch.einsum("b i, b i -> b", x, self.nodes[indices])<0).long()
            # Heap index of a leaf -> leaf number 0..n-1, shape (bs, 1).
            indices = indices[:,None] - self.n+1
            # Weight of 1 per row, allocated on x's device and with x's dtype.
            probs = x.new_ones(bs, 1)
        x = torch.einsum('bx, bkyx -> bky', x, self.w1[indices])
        x = torch.einsum('bkx, bkyx -> bky', self.act(x), self.w2[indices])
        # NOTE: with a single leaf per row (eval, or topk == 1) the output does not depend on
        # `probs`, so the routing nodes receive no gradient from it.
        return torch.einsum('bky, bk -> by', x, probs) if probs.shape[1]>1 else x[:,0]

In [None]:
class Experts(nn.ModuleList):
    """List of expert modules mixed per row by routing weights.

    `forward(x, routing_ws, selected_exps)`:
        x: input rows; each expert is applied to the subset of rows routed to it.
        routing_ws: (rows, k) weights; column j weights the expert named in `selected_exps[:, j]`.
        selected_exps: (rows, k) integer expert indices in `[0, len(self))`.
    Returns a tensor with `x.shape[0]` rows, row b being
    `sum_j routing_ws[b, j] * self[selected_exps[b, j]](x[b])`.
    """
    def forward(self, x, routing_ws, selected_exps):
        # Use torch.nn.functional explicitly: the module-level `F` is `torch.functional`, which has no `one_hot`.
        # one_hot gives (rows, k, E); permuted to (E, k, rows) so mask[i] marks the (slot, row) pairs for expert i.
        mask = nn.functional.one_hot(selected_exps, num_classes=len(self)).permute(2, 1, 0)
        out = None
        for i, expert in enumerate(self):
            idx, top_x = torch.where(mask[i])   # idx: slot within k, top_x: row index
            if top_x.shape[0] == 0: continue
            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            res = expert(x[top_x_list]) * routing_ws[top_x_list, idx.tolist(), None]
            # Allocate lazily (output width is only known once an expert has run),
            # matching res's dtype/device so index_add_ never sees mismatched types.
            if out is None: out = res.new_zeros((x.shape[0], *res.shape[1:]))
            out.index_add_(0, top_x, res)
        if out is None:
            # Nothing was routed (e.g. an empty batch): probe an expert on zero rows to get the
            # trailing output shape instead of failing on an undefined `out`.
            probe = self[0](x[:0])
            out = probe.new_zeros((x.shape[0], *probe.shape[1:]))
        return out
    
def binary(x, bits):
    """Expand every integer of `x` into its `bits` lowest bits, least-significant bit first.

    `x` must be an integer tensor; the result has shape `(*x.shape, bits)` and dtype uint8 (0/1).
    """
    exponents = torch.arange(bits, device=x.device, dtype=x.dtype)
    place_values = torch.pow(2, exponents)
    selected = torch.bitwise_and(x[..., None], place_values)
    return (selected != 0).to(torch.uint8)

class FFFSeq(nn.Module):
    """Fast feedforward layer whose `2**depth` leaves are separate `nn.Sequential` MLPs.

    Each leaf is `Linear(in_dim, h_dim) -> ReLU -> Linear(h_dim, out_dim)` (no biases). Rows
    are dispatched to leaves through `Experts`, which runs each leaf only on the rows routed
    to it and sums the weighted results.

    Training: node j scores `<x, nodes[j]>`; the left edge gets `tree_act(+score)` and the
    right edge `tree_act(-score)`. A leaf's score is the sum over its root-to-leaf path. The
    `topk` best leaves are mixed with softmax weights renormalised over those k.
    Eval: each row walks the tree greedily (right child when the score is < 0) to a single leaf.

    Args:
        in_dim, out_dim: input / output feature sizes; `x` is indexed as `(batch, in_dim)`.
        h_dim: hidden width of every leaf MLP.
        depth: tree depth -> `2**depth` leaves and `2**depth - 1` internal nodes.
        tree_act: applied to the signed node scores before summing along paths.
        topk: leaves mixed per row during training (`None`/0 -> all leaves).
        save_probs: if True, the training forward stores the softmax over all leaf scores
            in `self.probs`.
    """
    def __init__(self, in_dim, out_dim, h_dim, depth, tree_act=nn.LogSigmoid(), topk=None, save_probs=True):
        super().__init__()
        store_attr('save_probs,tree_act,depth')
        self.n = 2**depth                 # number of leaves
        self.topk = topk or self.n        # None/0 -> mix every leaf
        def uniform(shape, scale):
            return nn.Parameter(torch.empty(shape).uniform_(-scale,scale))
        # One routing hyperplane per internal node, in heap order (children of j are 2j+1, 2j+2).
        self.nodes = uniform((self.n-1, in_dim), scale=1/sqrt(in_dim))
        self.leaves = Experts(
            nn.Sequential(
            nn.Linear(in_dim, h_dim, bias=False), nn.ReLU(),
            nn.Linear(h_dim, out_dim, bias=False)) for _ in range(self.n))
        self.t = self.init_t_()
        self.s = self.init_s_()

    def init_t_(self):
        """Fixed 0/1 path matrix of shape (n, 2*(n-1)): row = leaf, column = tree edge.

        Columns are grouped by level (root's two edges first, then the next level's four, ...),
        left to right within a level; entry is 1 when the edge lies on the leaf's path.
        """
        tree, res = torch.eye(self.n), []
        for _ in range(self.depth):
            res.append(tree)
            # Merge adjacent leaf groups pairwise: one level closer to the root.
            tree = tree.view(self.n, -1, 2).sum(-1)
        return nn.Parameter(torch.cat(list(reversed(res)),dim=1), False)

    def init_s_(self):
        """Fixed (n-1, 2*(n-1)) sign matrix: node j's score goes to column 2j as +score and to 2j+1 as -score."""
        s = torch.eye(self.n-1)
        return nn.Parameter(torch.stack([s,-s], dim=2).view(self.n-1,2*(self.n-1)), False)

    def forward(self, x):
        bs = x.shape[0]
        if self.training:
            # Signed edge scores (bs, 2*(n-1)), then per-leaf path scores (bs, n).
            z = x.matmul(self.nodes.T).matmul(self.s)
            z = self.tree_act(z).matmul(self.t.T)
            if self.save_probs: self.probs = torch.softmax(z,-1)
            probs, indices = z.topk(self.topk)
            # Renormalise over the selected leaves only.
            probs = torch.softmax(probs, dim=-1)
        else:
            # Hard routing through the heap-ordered tree: right child (2i+2) when <x, node_i> < 0.
            indices = torch.zeros(bs, dtype=torch.long, device=x.device)
            for _ in range(self.depth):
                indices = indices*2 + 1 + (torch.einsum("b i, b i -> b", x, self.nodes[indices])<0).long()
            # Heap index of a leaf -> leaf number 0..n-1, shape (bs, 1).
            indices = indices[:,None] - self.n+1
            # Must live on x's device/dtype: Experts multiplies these weights with leaf outputs,
            # and a CPU tensor there fails when the model runs on GPU.
            probs = x.new_ones(bs, 1)
        return self.leaves(x, probs, indices)

In [None]:
from fastai.callback.wandb import *
from FastFF.utils import *
from fastai.vision.all import *

bs = 512
params = dict(
    in_dim=28*28,   # MNIST image size; assumes get_mnist_dls yields flattened images — TODO confirm
    out_dim=10,     # digit classes
    h_dim=16,
    depth=3,        # 2**3 = 8 leaves
    topk=2)
n_epochs, lr_max = 5, 7e-3

fff = FFF(**params)
cbs = [ShowGraphCallback(), ProbsDistrCB(), GetGradCB([fff.nodes, fff.w1])]
dls = get_mnist_dls(bs)
# Use torch.nn.functional explicitly. The top-of-file `F` is `torch.functional`, which has no
# `cross_entropy`; `F.cross_entropy` only resolves if a star import has rebound `F`.
Learner(dls, fff, loss_func=nn.functional.cross_entropy, metrics=accuracy, cbs=cbs).fit_one_cycle(n_epochs, lr_max=lr_max)

In [None]:
bs = 512
params = dict(
    in_dim=28*28,   # MNIST image size; assumes get_mnist_dls yields flattened images — TODO confirm
    out_dim=10,     # digit classes
    h_dim=16,
    depth=3,        # 2**3 = 8 leaves
    topk=2)
n_epochs, lr_max = 5, 7e-3

fff = FFFSeq(**params)
# FFFSeq has no `w1` attribute (its leaf weights live in `fff.leaves`), so no GetGradCB here.
cbs = [ShowGraphCallback(), ProbsDistrCB()]
dls = get_mnist_dls(bs)
# Use torch.nn.functional explicitly. The top-of-file `F` is `torch.functional`, which has no
# `cross_entropy`; `F.cross_entropy` only resolves if a star import has rebound `F`.
Learner(dls, fff, loss_func=nn.functional.cross_entropy, metrics=accuracy, cbs=cbs).fit_one_cycle(n_epochs, lr_max=lr_max)