In [None]:
import torch, torch.nn.functional as F, math, wandb, os
from torch import nn, tensor
from fastcore.all import *
from math import sqrt
import matplotlib.pyplot as plt
from fastprogress import progress_bar, master_bar
from torch.utils import benchmark

# Presumably silences wandb's console logging (WANDB_SILENT) for the many wandb.init/finish calls made by do_benchmark — confirm against wandb's environment-variable docs
os.environ["WANDB_SILENT"]='true'

In [None]:
def benchmark_model(model, x, cuda=False, min_run_time=0.5):
    '''Time `model(x)` with `torch.utils.benchmark.Timer.blocked_autorange`.

    Args:
        model: module to benchmark; it is called as `model(x)`.
        x: input passed to the model.
        cuda: if True, move both `model` and `x` to CUDA (`.cuda()`) before timing.
        min_run_time: minimum total measuring time in seconds (previously fixed at 0.5).

    Returns:
        A `torch.utils.benchmark` Measurement labelled with the model's class name;
        `.mean` is the mean time per call in seconds.
    '''
    if cuda: model, x = model.cuda(), x.cuda()
    return benchmark.Timer(
        stmt='model(x)',
        # expose exactly the names `stmt` needs rather than the whole local frame
        globals={'model': model, 'x': x},
        label=model._get_name(),
        description='time',
    ).blocked_autorange(min_run_time=min_run_time)

def benchmark_with_params(func, x, params: list, cuda=False, parent=None):
    '''Benchmark one model per element of `params`; return mean times in milliseconds.

    Args:
        func: must accept a single element of `params` and return a model for `benchmark_model`.
        x: input batch passed to every model.
        params: values to sweep, e.g. hidden sizes.
        cuda: forwarded to `benchmark_model`.
        parent: fastprogress master bar to nest the progress bar under. When None, falls back
            to a notebook-global `mb` if one exists (the previous, implicit behaviour).

    Returns:
        fastcore `L` with one mean time (ms) per element of `params`, in order.
    '''
    if parent is None: parent = globals().get('mb', None)
    results = L(benchmark_model(func(p), x, cuda) for p in progress_bar(params, parent=parent))
    # take mean (seconds per call) and convert to ms
    return results.map(lambda m: m.mean * 1e3)

def keys_to_str(dict, keys):
    '''Render the items of `dict` whose key appears in `keys` as space-separated `key=value` pairs.

    Pairs follow `dict`'s iteration order, not the order of `keys`.
    '''
    pairs = []
    for key, val in dict.items():
        if key not in keys:
            continue
        pairs.append(f'{key}={val}')
    return ' '.join(pairs)

# Models

### MoE with sequential expert computations

In [None]:
class Experts(nn.ModuleList):
    '''List of expert modules applied sequentially, each only to the rows routed to it.'''
    def forward(self, x, routing_ws, selected_exps):
        '''Weighted sum of the selected experts' outputs for every row of `x`.

        Args:
            x: 2-D input indexed by row (assumes (batch, in_dim)).
            routing_ws: (batch, topk) weights, aligned with `selected_exps`.
            selected_exps: (batch, topk) integer expert ids in [0, len(self)),
                e.g. the indices returned by `torch.topk` in `MoE.forward`.

        Returns:
            Tensor of shape (batch, *expert_output_shape[1:]) where
            out[b] = sum_j routing_ws[b, j] * self[selected_exps[b, j]](x[b]).
            Its dtype/device follow the expert outputs.
        '''
        # (n_exp, topk, batch): mask[i, j, b] == 1 when row b chose expert i in slot j
        mask = F.one_hot(selected_exps, num_classes=len(self)).permute(2, 1, 0)
        out = None
        for i, expert in enumerate(self):
            slot, rows = torch.where(mask[i])
            if rows.shape[0] == 0: continue
            # in torch it is faster to index using lists than torch tensors
            rows_list = rows.tolist()
            res = expert(x[rows_list]) * routing_ws[rows_list, slot.tolist(), None]
            # allocate lazily: the output width is only known once an expert has run.
            # new_zeros matches res's dtype/device, so index_add_ works for non-float32 experts too
            if out is None: out = res.new_zeros((x.shape[0], *res.shape[1:]))
            out.index_add_(0, rows, res)
        if out is None:
            # no row was routed to any expert (e.g. empty batch): return zeros of the right shape.
            # Running an expert on an empty slice yields the output width without real compute.
            probe = self[0](x[:0])
            out = probe.new_zeros((x.shape[0], *probe.shape[1:]))
        return out

class MoE(nn.Module):
    '''Top-k mixture of experts; each expert is Linear -> act -> Linear (no biases),
    evaluated one expert at a time by `Experts`.'''
    def __init__(self, in_dim, out_dim, h_dim, n_exp=4, topk=2, act=nn.ReLU):
        super().__init__()
        self.topk=topk
        self.gate = nn.Linear(in_dim, n_exp, bias=False)
        self.experts = Experts(
            nn.Sequential(
                nn.Linear(in_dim, h_dim, bias=False), act(),
                nn.Linear(h_dim, out_dim, bias=False))
            for _ in range(n_exp))

    def forward(self, x):
        '''Route each row of `x` (shape (..., in_dim)) to its top-k experts; returns (..., out_dim).

        Leading dimensions are flattened into rows for routing and restored on the output.
        '''
        lead = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])
        # softmax over the expert axis (last dim), the same axis topk selects along
        probs = F.softmax(self.gate(x), dim=-1)
        probs, selected_exps = torch.topk(probs, self.topk, dim=-1)
        # renormalise the kept top-k weights so they sum to 1 per row
        probs /= probs.sum(dim=-1, keepdim=True)
        out = self.experts(x, probs, selected_exps)
        return out.reshape(*lead, *out.shape[1:])

### MoE with parallel expert computation by materializing a tensor of size b\*k\*x\*y

In [None]:
class MoeEinops(nn.Module):
    '''Top-k MoE that evaluates all selected experts at once with einsum.

    Indexing `w1`/`w2` with the routing indices materialises per-row copies of the chosen
    expert weights — tensors of shape (batch, topk, h_dim, in_dim) and
    (batch, topk, out_dim, h_dim) — trading memory for a single batched computation.
    '''
    def __init__(self, in_dim, out_dim, h_dim, n_exp=4, topk=2, act=nn.ReLU):
        super().__init__()
        self.topk = topk

        def uniform_param(shape, fan_in):
            # weights drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in))
            bound = 1/sqrt(fan_in)
            return nn.Parameter(torch.empty(shape).uniform_(-bound, bound))

        # creation order (gate, w1, w2) fixes both RNG consumption and parameter order
        self.gate = uniform_param((n_exp, in_dim), in_dim)
        self.w1 = uniform_param((n_exp, h_dim, in_dim), in_dim)
        self.w2 = uniform_param((n_exp, out_dim, h_dim), h_dim)
        self.act = act()

    def forward(self, x):
        # choose top-k experts from the raw gate logits, then softmax over only those k
        top_logits, indices = (x @ self.gate.T).topk(self.topk)
        weights = top_logits.softmax(dim=-1)
        hidden = self.act(torch.einsum('bx,bkyx -> bky', x, self.w1[indices]))
        expert_out = torch.einsum('bkx,bkyx -> bky', hidden, self.w2[indices])
        return torch.einsum('bky,bk -> by', expert_out, weights)

In [None]:
def get_vram(device=None, total_cubes=24):
    '''One-line summary of GPU memory use with a text bar.

    Args:
        device: passed through to `torch.cuda.mem_get_info` (None uses torch's default device).
        total_cubes: width of the bar in characters.

    Returns:
        A string like 'VRAM: <used>/<total>GB\\t VRAM:[▮▮▯…]', where ▮ marks used
        and ▯ marks free memory (the free share is rounded down).
    '''
    # a single query, so `free` and `total` come from the same snapshot
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    free, total = free_bytes / 1024 ** 3, total_bytes / 1024 ** 3
    free_cubes = int(total_cubes * free / total)
    used_cubes = total_cubes - free_cubes
    return f'VRAM: {total - free:.2f}/{total:.2f}GB\t VRAM:[' + used_cubes * '▮' + free_cubes * '▯' + ']'

## topk=1, similar to inference

In [None]:
# Benchmark grid and default model configuration.
SEED = 0
# seed so `x` (and later RNG draws, when cells run in order) are reproducible on Restart & Run All
torch.manual_seed(SEED)

bs=32
h_dims = [1,2,4,8,16,32,64,128,256,512,1024]  # expert hidden sizes to sweep
n_exps = [1,2,3,4,8,16,32,64,128]              # expert counts to sweep
# defaults for the single-model cells; do_benchmark sets n_exp/topk per run itself
params = dict(
    in_dim=28*28,
    out_dim=10,
    n_exp=128,
    topk=128)
x = torch.randn(bs, params['in_dim'])

In [None]:
# Profile only the gating step (gate logits + top-k) on the CPU batch `x`.
# No model exists yet at this point in the notebook, so borrow the gate weights of a fresh
# MoeEinops (the gate's shape does not depend on h_dim).
gate = MoeEinops(h_dim=h_dims[0], **params).gate
with torch.profiler.profile() as p:
    _, indices = torch.matmul(x, gate.T).topk(params['topk'])

# `x` lives on the CPU here, so sort by CPU time; print so the table's newlines render
print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))


In [None]:
import time
def profile_model(model, x, out_dir = './tracing', n_iters=5):
    '''Profile `n_iters` no-grad forward passes of `model(x)` and write the results to `out_dir`.

    Creates `out_dir` if needed and writes two files sharing the prefix
    `<ModelClassName>_<unix timestamp>`:
      - `..._trace.json.gz`: Chrome trace from `export_chrome_trace`
      - `..._memory.html`: memory timeline from `export_memory_timeline`
    Shapes, memory and stacks are recorded because torch's memory timeline expects them.
    '''
    from pathlib import Path  # explicit, instead of relying on the fastcore star import
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) as p, torch.no_grad():
        for _ in range(n_iters): model(x)
    stem = f"{model._get_name()}_{int(time.time())}"
    p.export_chrome_trace(str(out_dir/f'{stem}_trace.json.gz'))
    p.export_memory_timeline(str(out_dir/f'{stem}_memory.html'))

In [None]:
# One instance of each MoE variant on the GPU, sized from the config cell.
# NOTE: this rebinds the notebook-global `x` to a CUDA tensor; later cells that use `x` see it.
H_DIM = 1024
x = torch.randn(bs, params['in_dim']).cuda()

moe1 = MoeEinops(h_dim=H_DIM, **params).cuda()
moe2 = MoE(h_dim=H_DIM, **params).cuda()

In [None]:
# One forward pass of the einops MoE on the CUDA batch.
# NOTE(review): MoeEinops.forward as defined above prints nothing, so the VRAM lines in the saved output look stale (from an earlier version) — re-run to refresh.
moe1(x)

VRAM: 1.87/6.00GB	 VRAM:[▮▮▮▮▮▮▮▮▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯]
VRAM: 1.87/6.00GB	 VRAM:[▮▮▮▮▮▮▮▮▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯▯]
VRAM: 6.00/6.00GB	 VRAM:[▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮▮]


tensor([[ 0.0687,  0.0018,  0.0147,  0.0180,  0.0042, -0.0061,  0.0039, -0.0313,
         -0.0343,  0.0185],
        [ 0.0147,  0.0136,  0.0317, -0.0256,  0.0145, -0.0191, -0.0766, -0.0188,
         -0.0352,  0.0211],
        [ 0.0219,  0.0048,  0.0178,  0.0030,  0.0388, -0.0516, -0.0235, -0.0017,
          0.0270,  0.0173],
        [-0.0002, -0.0012, -0.0017,  0.0201, -0.0248, -0.0227, -0.0117, -0.0080,
         -0.0079, -0.0459],
        [-0.0037,  0.0371,  0.0374, -0.0031, -0.0044, -0.0186,  0.0012,  0.0122,
          0.0044, -0.0092],
        [-0.0230,  0.0049, -0.0164, -0.0175,  0.0309, -0.0090, -0.0230, -0.0600,
          0.0093,  0.0154],
        [-0.0348,  0.0199,  0.0182, -0.0442,  0.0100, -0.0137, -0.0299, -0.0215,
         -0.0003, -0.0371],
        [ 0.0299,  0.0113,  0.0158, -0.0520,  0.0088,  0.0127,  0.0525, -0.0184,
         -0.0172,  0.0123],
        [-0.0037, -0.0178,  0.0240,  0.0033, -0.0014, -0.0332,  0.0384, -0.0072,
         -0.0680,  0.0155],
        [ 0.0119, -

In [None]:
# mean time per call of the einops MoE on CUDA, converted to ms
benchmark_model(moe1, x, True).mean*1e3

1571.7456319998746

In [None]:
# mean time per call of the sequential MoE on CUDA, converted to ms
benchmark_model(moe2, x, True).mean*1e3

204.85662466671783

In [None]:
# writes a Chrome trace and a memory timeline for the sequential MoE under ./tracing
profile_model(moe2, x)

In [None]:
# writes a Chrome trace and a memory timeline for the einops MoE under ./tracing
profile_model(moe1, x)

In [None]:
def do_benchmark(topk=1, cuda=False):
    '''Benchmark both MoE variants over every `n_exps` x `h_dims` combination and log to wandb.

    One wandb run (project 'FFF') is created per (n_exp, model) pair; it logs the mean
    time in ms against `h_dim`. The notebook-level `params` is not modified.

    Args:
        topk: experts per row; clamped to `n_exp` for each expert count.
        cuda: benchmark on CUDA (models and inputs moved there) instead of the CPU.

    Returns:
        dict {model name: {n_exp: mean times in ms, one per entry of `h_dims`}}.
    '''
    # `benchmark_with_params` nests its progress bar under this global master bar
    global mb
    mb = master_bar(n_exps, total=len(n_exps))
    dev = 'cuda' if cuda else 'cpu'
    # the global `x` may have been moved to the GPU by another cell; CPU runs need it on the CPU
    inp = x if cuda else x.cpu()
    models = {"MoE einops": MoeEinops, "MoE sequential": MoE}
    results = {name: {} for name in models}
    for n_exp in mb:
        # per-run copy, so the notebook-level `params` stays untouched
        run_params = {**params, 'n_exp': n_exp, 'topk': min(topk, n_exp)}
        for name, cls in models.items():
            wandb.init(project='FFF', config = run_params | {'model':name}, name=f"{dev} {keys_to_str(run_params, ['n_exp','topk'])} {name}")
            try:
                times = benchmark_with_params(lambda h: cls(h_dim=h, **run_params), inp, h_dims, cuda)
                for h,t in zip(h_dims, times): wandb.log({"time": t, "h_dim": h})
            finally:
                # close the run even if a benchmark fails, so the next wandb.init starts clean
                wandb.finish()
            results[name][n_exp] = times
    return results

Run the benchmarks on the CPU first, then on CUDA.

In [None]:
# topk=1 on the CPU
do_benchmark()

In [None]:
# topk=1 on CUDA
do_benchmark(cuda=True)

## topk=n, similar to training

In [None]:
# 1024 exceeds every entry of n_exps, so min(topk, n_exp) in do_benchmark routes each row to all experts; CPU
do_benchmark(topk=1024)

In [None]:
# 1024 exceeds every entry of n_exps, so min(topk, n_exp) in do_benchmark routes each row to all experts; CUDA
do_benchmark(topk=1024, cuda=True)