In [1]:
import torch, torch.nn.functional as F, math, wandb, os
from torch import nn, tensor
from fastcore.all import *
from math import sqrt
import matplotlib.pyplot as plt
from fastprogress import progress_bar, master_bar
from torch.utils import benchmark

# WANDB_SILENT is wandb's switch for suppressing its own log output; set it before any run is started.
os.environ["WANDB_SILENT"]='true'

In [2]:
def benchmark_model(model, x, cuda=False):
    '''Time the forward pass `model(x)` with `torch.utils.benchmark`.

    When `cuda` is true, both `model` and `x` are moved to the GPU before timing.
    Returns the measurement produced by `Timer.blocked_autorange`, labelled with the model's class name.
    '''
    if cuda:
        model = model.cuda()
        x = x.cuda()
    # names that the timed statement is allowed to see
    timer_globals = {'model': model, 'x': x, 'cuda': cuda}
    timer = benchmark.Timer(
        stmt='model(x)',
        globals=timer_globals,
        label=model._get_name(),
        description='time',
    )
    # keep sampling until at least 0.2 s of run time has been collected
    return timer.blocked_autorange(min_run_time=0.2)

def benchmark_with_params(func, x, params: list, cuda=False):
    '''Benchmark one model per entry of `params`.

    `func` must accept a single element of `params` and return a model for `benchmark_model`.
    Returns the mean time of every measurement, scaled by 1e3 (milliseconds).
    '''
    # attach to the module-level master bar `mb` when one has been created
    parent_bar = globals().get('mb', None)
    measurements = L(benchmark_model(func(p), x, cuda) for p in progress_bar(params, parent=parent_bar))
    # take mean and convert to ms
    return measurements.map(lambda m: m.mean*1e3)

def keys_to_str(dict, keys):
    '''Format the items of `dict` whose key is in `keys` as space-separated `key=value` pairs.

    The pairs appear in the iteration order of `dict`, not in the order of `keys`.
    '''
    parts = []
    for key, value in dict.items():
        if key in keys: parts.append(f'{key}={value}')
    return ' '.join(parts)

# Models

### MoE with sequential expert computations

In [3]:
class Experts(nn.ModuleList):
    '''List of expert modules that are evaluated one after another.

    Every expert only processes the rows routed to it; the weighted expert outputs are
    accumulated into one output row per input row.
    '''
    def forward(self, x, routing_ws, selected_exps):
        '''Mix the outputs of the selected experts.

        x:             input batch, indexed by row along dim 0
        routing_ws:    weight of every selected expert, indexed as [row, slot]
        selected_exps: integer tensor of expert indices, indexed as [row, slot]

        Returns a tensor with `x.shape[0]` rows; dtype and device follow the expert outputs.
        '''
        # one_hot gives (row, slot, expert); permute turns it into (expert, slot, row)
        mask = F.one_hot(selected_exps, num_classes=len(self)).permute(2, 1, 0)
        out = None
        for i, expert in enumerate(self):
            idx, top_x = torch.where(mask[i])
            if top_x.shape[0] == 0: continue
            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            res = expert(x[top_x_list]) * routing_ws[top_x_list, idx.tolist(), None]
            # Allocate lazily: the output width is only known once an expert has run.
            # new_zeros inherits dtype/device from `res`, so index_add_ also works for
            # experts that are not float32 (index_add_ requires matching dtypes).
            if out is None: out = res.new_zeros((x.shape[0], *res.shape[1:]))
            out.index_add_(0, top_x, res)
        if out is None:
            # Nothing was routed (e.g. an empty batch): run the first expert on an empty
            # slice only to learn the output shape and dtype, then return zeros.
            probe = self[0](x[:0])
            out = probe.new_zeros((x.shape[0], *probe.shape[1:]))
        return out

class MoE(nn.Module):
    '''Mixture-of-experts layer: a linear gate picks `topk` of `n_exp` two-layer MLP experts per input row.'''
    def __init__(self, in_dim, out_dim, h_dim, n_exp=4, topk=2, act=nn.ReLU):
        super().__init__()
        self.topk = topk
        self.gate = nn.Linear(in_dim, n_exp, bias=False)
        # every expert is in_dim -> h_dim -> out_dim without biases
        expert_mlps = [
            nn.Sequential(
                nn.Linear(in_dim, h_dim, bias=False),
                act(),
                nn.Linear(h_dim, out_dim, bias=False),
            )
            for _ in range(n_exp)
        ]
        self.experts = Experts(expert_mlps)

    def forward(self, x):
        '''Route every row of `x` to its `topk` experts and return what `self.experts` computes from them.'''
        gate_probs = F.softmax(self.gate(x), dim=1)
        routing_ws, selected_exps = gate_probs.topk(self.topk, dim=-1)
        # renormalise so the kept weights of every row sum to 1
        routing_ws /= routing_ws.sum(dim=-1, keepdim=True)
        return self.experts(x, routing_ws, selected_exps)

### MoE with parallel expert computation by materializing a matrix of size b\*k\*x\*y

In [4]:
class MoeEinops(nn.Module):
    '''Mixture of experts that evaluates all selected experts at once with einsum.

    The expert weights are stacked in the parameters `w1` and `w2` and gathered per (row, slot),
    which materializes a `bkyx` weight tensor for each of the two layers.
    '''
    def __init__(self, in_dim, out_dim, h_dim, n_exp=4, topk=2, act=nn.ReLU):
        super().__init__()
        self.topk = topk
        # parameters are created in the order gate, w1, w2
        self.gate = self._uniform_param((n_exp, in_dim), in_dim)
        self.w1 = self._uniform_param((n_exp, h_dim, in_dim), in_dim)
        self.w2 = self._uniform_param((n_exp, out_dim, h_dim), h_dim)
        self.act = act()

    @staticmethod
    def _uniform_param(shape, fan_in):
        '''Parameter of the given shape drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].'''
        bound = 1/sqrt(fan_in)
        return nn.Parameter(torch.empty(shape).uniform_(-bound, bound))

    def forward(self, x):
        '''Compute the routing-weighted sum of the `topk` best-scoring experts for every row of `x`.'''
        gate_logits = x @ self.gate.T
        top_logits, indices = gate_logits.topk(self.topk)
        # softmax over the kept logits only
        routing_ws = torch.softmax(top_logits, dim=-1)
        hidden = torch.einsum('bx,bkyx -> bky', x, self.w1[indices])
        expert_out = torch.einsum('bkx,bkyx -> bky', self.act(hidden), self.w2[indices])
        return torch.einsum('bky,bk -> by', expert_out, routing_ws)

## topk=1, similar speed to inference

In [5]:
# Benchmark configuration
bs = 32                                 # rows in the benchmark input
h_dims = [2**i for i in range(11)]      # hidden sizes 1, 2, 4, ..., 1024
n_exps = [1, 2, 3, 4, 8, 16, 32, 64, 128]
# keyword arguments passed to the model constructors
params = {
    'in_dim': 28*28,
    'out_dim': 10,
    'n_exp': 1,
    'topk': 1,
}
x = torch.randn(bs, params['in_dim'])

In [6]:
def do_benchmark(topk=1, cuda=False):
    '''Benchmark both MoE implementations for every `n_exps` x `h_dims` combination and log to wandb.

    topk: number of experts used per input row; clamped to `n_exp` for every configuration.
    cuda: passed on to `benchmark_with_params`; also selects the 'cuda'/'cpu' prefix of the run name.

    One wandb run (project 'FFF') is created per (n_exp, model) pair and logs one
    `time` / `h_dim` point per entry of `h_dims`. Returns None.
    '''
    # `benchmark_with_params` looks up the global `mb` to nest its progress bar under this one
    global mb
    mb = master_bar(n_exps, total=len(n_exps))
    dev = 'cuda' if cuda else 'cpu'
    for n_exp in mb:
        # Work on a copy so the module-level `params` is not changed by running this cell.
        run_params = params | {'n_exp': n_exp, 'topk': min(topk, n_exp)}
        # factories taking the hidden size; `p` is bound per iteration through the default argument
        models = {"MoE einops": lambda h_dim, p=run_params: MoeEinops(h_dim=h_dim, **p),
                  "MoE sequential": lambda h_dim, p=run_params: MoE(h_dim=h_dim, **p)}
        for name, make_model in models.items():
            wandb.init(project='FFF', config=run_params | {'model': name},
                       name=f"{dev} {keys_to_str(run_params, ['n_exp','topk'])} {name}")
            # Always close the run, and mark it as failed when the benchmark raises,
            # so an error does not leave a dangling run behind.
            exit_code = 1
            try:
                times = benchmark_with_params(make_model, x, h_dims, cuda)
                for h, t in zip(h_dims, times): wandb.log({"time": t, "h_dim": h})
                exit_code = 0
            finally:
                wandb.finish(exit_code=exit_code)

CPU, then CUDA

In [None]:
# defaults: topk=1, cuda=False
do_benchmark()

In [None]:
# same sweep with model and input moved to CUDA (topk stays at its default of 1)
do_benchmark(cuda=True)

## topk=n, similar to training

In [None]:
# do_benchmark clamps topk to n_exp, and 1024 exceeds every entry of n_exps, so all experts are used
do_benchmark(topk=1024)

In [None]:
# all experts per row (topk is clamped to n_exp inside do_benchmark), with model and input on CUDA
do_benchmark(topk=1024, cuda=True)