In [1]:
import torch
from sqrll.sqrllm import SqrLLM, StatefulWrapper
from tqdm import tqdm
import math

# Build the network to export and wrap it for stateful (step-by-step) use.
model = SqrLLM(
    n_embed = 256,
    n_mem = 512,
    n_ffn = 256,
    ffn_rate = 4,
    n_layer = 6,
)
smodel = StatefulWrapper(model)

# The total parameter count doubles as the checkpoint identifier below.
params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

# Best-effort checkpoint load: a missing or incompatible file leaves the
# freshly initialised weights in place, but the reason is reported instead of
# being silently discarded (and KeyboardInterrupt is no longer swallowed).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
ckpt_path = f'./model{params}.pt'
try:
    model.load_state_dict(torch.load(ckpt_path))
    print('loaded')
except FileNotFoundError:
    print(f'no checkpoint at {ckpt_path}; using freshly initialised weights')
except Exception as err:
    print(f'could not load {ckpt_path}: {err!r}')

model.eval();


params=4,468,736
loaded


In [2]:
from sqrll import sqrll, sqrllm

class CustomTracer(torch.fx.Tracer):
    """Tracer that never treats a submodule as a leaf.

    Every submodule is reported as non-leaf, so no module is kept as an
    opaque ``call_module`` node on the basis of this hook.
    """

    def is_leaf_module(self, mod, name):
        """Return ``False`` for every module, regardless of ``mod`` or ``name``."""
        treat_as_leaf = False
        return treat_as_leaf

# Functions listed here are passed to the tracer as ``autowrap_functions``.
tracer = CustomTracer(
    autowrap_functions=[sqrll.sqrll_kernel, sqrllm.rms_norm],
)

# Symbolically trace the model into a torch.fx graph.
graph = tracer.trace(model)

In [51]:
from torch.fx import map_arg
import operator

class Interpreter(torch.fx.Interpreter):
    inputs = {}
    outputs = {}
    weights = {}

    varmap = {}
    tempvars = []

    fwds = []

    # def map_nodes_to_values(self, args, n):
    #     def load_arg(n_arg):
    #         if n_arg in self.varmap:
    #             self.tempvars[self.varmap[n_arg]][1] -= 1
    #         else:
    #             print('load', n_arg, {k: self.tempvars[v] for k,v in self.varmap.items()})
    #         return self.env[n_arg]
    #     return map_arg(args, load_arg)

    def run_node(self, n):
        with self._set_current_node(n):

            args, kwargs = self.fetch_args_kwargs_from_env(n)
            val = getattr(self, n.op)(n.target, args, kwargs)

            # print(n.name, ':', n.op, n.target, n.args, n.kwargs)
            # print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))

            def flatten(x):
                if isinstance(x, (list, tuple)):
                    return [a for xi in x for a in flatten(xi)]
                return [x]

            if 'call' in n.op:
                shape = tuple(val.squeeze().shape)
                n_users = len(n.users)

                for i, v in enumerate(self.tempvars):
                    if v[0] == shape and v[1] == 0:
                        self.varmap[n] = i
                        v[1] = n_users
                        break
                else:
                    self.varmap[n] = len(self.tempvars)
                    self.tempvars += [[shape, n_users]]

            # deref after alloc to be safe against aliasing

            def deref_arg(n_arg):
                if n_arg in self.varmap:
                    self.tempvars[self.varmap[n_arg]][1] -= 1
            map_arg(n.args, deref_arg)
            map_arg(n.kwargs, deref_arg)



            def argstr(n):
                return ', '.join([getattr(a,'name',str(a)) for a in n.args if a is not None])

            if n.op == 'get_attr':
                self.weights[n.name] = val.squeeze()
            elif n.op == 'placeholder':
                # TODO arbitrary nesting
                if isinstance(val, list):
                    for i, v in enumerate(val):
                        shape = tuple(v.squeeze().shape)
                        self.inputs[f'{n.name}{i}'] = shape
                else:
                    shape = tuple(val.squeeze().shape)
                    self.inputs[n.name] = shape
            elif n.op == 'output':
                out_vals = flatten(args)
                out_vars = flatten(n.args)
                for i, x in enumerate(zip(out_vals, out_vars)):
                    dst, src = x
                    shape = tuple(dst.squeeze().shape)
                    name = f'{n.name}{i}' if len(out_vals)>1 else n.name
                    self.outputs[name] = shape
                    self.fwds += [f'ml::copy({name}, {src})']
            elif n.op == 'call_method':
                if n.target == 'clone':
                    self.fwds += [f'ml::copy({n.name}, {argstr(n)})']
                elif n.target == 'detach':
                    self.fwds += [f'ml::copy({n.name}, {argstr(n)})']
                elif n.target == 'sigmoid':
                    self.fwds += [f'ml::sigmoid({n.name}, {argstr(n)})']
                else:
                    print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                    print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))
            elif n.op == 'call_function':
                if n.target == operator.add:
                    self.fwds += [f'ml::add({n.name}, {argstr(n)})']
                elif n.target == operator.mul:
                    self.fwds += [f'ml::mul({n.name}, {argstr(n)})']
                elif n.target == operator.getitem:
                    if isinstance(args[0], list):
                        src = f'{n.args[0]}{args[1]}'
                        self.fwds += [f'ml::copy({n.name}, {src})']
                    elif isinstance(args[0], torch.Tensor):
                        # idx = []
                        # for i in args[1]:
                        #     idx += [str(i)]
                        # idx = ', '.join(idx)
                        # self.fwds += [f'ml::slice({n.name}, {n.args[0]}, {idx})']
                        # XXX in sqrll 1-step mode this is a no-op slice
                        self.fwds += [f'ml::copy({n.name}, {n.args[0]})']
                    else:
                        print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                        print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))
                elif n.target == torch.nn.functional.linear:
                    self.fwds += [f'ml::linear({n.name}, {argstr(n)})']
                elif n.target == torch.nn.functional.softsign:
                    self.fwds += [f'ml::softsign({n.name}, {argstr(n)})']
                elif n.target == torch.nn.functional.embedding:
                    self.fwds += [f'ml::embedding({n.name}, {argstr(n)})']
                elif n.target == torch.nn.functional.dropout:
                    self.fwds += [f'ml::copy({n.name}, {argstr(n)})']
                elif n.target == sqrll.sqrll_kernel:
                    self.fwds += [f'ml::sqrll({n.name}, {argstr(n)})']
                elif n.target == sqrllm.rms_norm:
                    self.fwds += [f'ml::rmsnorm({n.name}, {argstr(n)})']
                else:
                    print(n.target, n.target.__name__, n.target.__qualname__, n.target.__module__)
                    print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                    print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))
            elif n.op == 'call_module':
                if n.target.__class__ == sqrllm.RmsNorm:
                    self.fwds += [f'ml::rmsnorm({n.name}, {argstr(n)})']
                else:
                    print(n.target, n.target.__class__)
                    print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                    print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))
            else:
                print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))
            # print("")
            return val


# One dummy token id in a 1x1 tensor drives the recording run.
inputs = torch.tensor([[0]])

# Run the model once to obtain its memory structure, then replace every entry
# with zeros of the same shape and dtype.
_, mem = model(inputs)
mem = list(map(torch.zeros_like, mem))

# Execute the traced graph node by node, recording the code-generation plan.
interp = Interpreter(model, graph=graph)
out = interp.run(inputs, mem)

In [52]:
# Code-generation accumulators: raw 16-bit weight words, C++ member
# declarations, and the statements of the generated step() body.
blob, ivars, fwds = [], [], []

def shapetype(shape, const=False):
    """Return the C++ type name for a tensor of the given shape.

    Rank 0 maps to ``Vec<1>``, rank 1 to ``Vec<N>`` and rank 2 to
    ``Mat<R, C>``; ``const=True`` appends `` const``. Any higher rank trips
    an assertion carrying the offending shape.
    """
    suffix = ' const' if const else ''
    rank = len(shape)
    if rank > 2:
        assert 0, shape
    elif rank == 2:
        rows, cols = shape
        return f'Mat<{rows}, {cols}>{suffix}'
    elif rank == 1:
        (size,) = shape
        return f'Vec<{size}>{suffix}'
    else:
        return f'Vec<1>{suffix}'
    
def store(name, val, const=True):
    """Append a tensor to the weight blob and declare a member that views it.

    ``val`` is converted to bfloat16 and its 16-bit patterns are appended to
    the global ``blob``; a declaration ``<type> <name> { blob, <offset> }`` is
    appended to the global ``ivars``, where ``offset`` is the blob length
    before the append.

    Returns ``(name, *shape)``.
    """
    offset = len(blob)
    words = val.bfloat16().view(torch.uint16).flatten().tolist()
    blob.extend(words)
    dims = list(val.shape)
    ivars.append(f'{shapetype(dims, const)} {name} {{ blob, {offset} }}')
    return (name, *dims)

def addvar(name, shape, const=False):
    """Declare a zero-initialised member ``name`` of the given shape in ``ivars``."""
    declaration = f'{shapetype(shape, const)} {name} {{ Zeros{{}} }}'
    ivars.append(declaration)


# 1. Every recorded weight goes into the constant blob.
for name, value in interp.weights.items():
    store(name, value, const=True)

# 2. One zero-initialised scratch buffer per pooled temporary.
for i, v in enumerate(interp.tempvars):
    name = f'tmp{i}'
    addvar(name, shape=v[0], const=False)

# 3. Each call node becomes a reference onto the scratch buffer it was assigned.
for k, i in interp.varmap.items():
    shape = interp.tempvars[i][0]
    stype = shapetype(shape)
    ivars.append(f'{stype} & {k} = tmp{i}')

# 4. Graph inputs are mutable, zero-initialised members.
for name, shape in interp.inputs.items():
    stype = shapetype(shape)
    ivars.append(f'{stype} {name} {{ Zeros{{}} }}')

# 5. Graph outputs are declared const.
# NOTE(review): outputs are also the destination of the recorded
# ml::copy(<output>, ...) statements -- confirm the C++ side permits that.
for name, shape in interp.outputs.items():
    stype = shapetype(shape, const=True)
    ivars.append(f'{stype} {name} {{ Zeros{{}} }}')

# The body of step() is the recorded statement list ...
fwds.extend(interp.fwds)

# ivars += [f'Sampler<256> sampler {{ out.out }}']
# ... followed by a result that currently just echoes the previous token.
fwds.append('int out = prev')
fwds.append('return out')


class Writer:
    """Indentation-aware writer around a text stream.

    Calling the writer emits the current indentation followed by the given
    text and returns the writer, so ``with w('header {'):`` writes the header,
    then (on entry) a newline, and indents everything written inside the
    block by four more spaces.
    """

    def __init__(self, f):
        self.f, self.indent = f, 0

    def __call__(self, s):
        """Write ``s`` prefixed by the current indentation; return ``self``."""
        padding = ' ' * self.indent
        self.f.writelines((padding, s))
        return self

    def __enter__(self):
        """Terminate the current line and open a deeper indentation level."""
        self.f.write('\n')
        self.indent = self.indent + 4
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Return to the previous indentation level; exceptions propagate."""
        self.indent = self.indent - 4


# Emit the generated C++ translation unit: the weight blob, a Model struct
# holding every declaration plus a step() method, and one Model instance.
# The loop variables get their own names so the open file handle is never
# rebound while it is still in use.
with open('model.cpp', 'w') as cpp_file:
    w = Writer(cpp_file)
    w('#include "cgen_head.h"\n\n')

    # Weight data as a comma-separated list of hex 16-bit words.
    with w(f'BFloat16Blob<{len(blob)}> blob = {{'):
        w(','.join(map(hex, blob)) + '\n')
    w('};\n\n')

    with w('struct Model {\n'):
        for decl in ivars:
            w(decl + ';\n')

        w('\n')
        with w('int step(int prev) {'):
            for stmt in fwds:
                w(stmt + ';\n')
        w('}\n')

    w('};\n\n')
    w('Model model;\n\n')
    w('#include "cgen_main.h"\n')

In [None]:
# Fresh accumulators for this cell: weight words, member declarations and
# step() statements.
blob, ivars, fwds = ([] for _ in range(3))

def store(name, val):
    """Append a 1-D or 2-D tensor to the weight blob and declare a view of it.

    The tensor is converted to bfloat16 and its 16-bit patterns are appended
    to the global ``blob``; a ``Vec<N>`` or ``Mat<R, C>`` declaration pointing
    at the offset where the data starts is appended to the global ``ivars``.
    Any other rank trips an assertion.

    Returns ``(name, *shape)``.
    """
    offset = len(blob)
    blob.extend(val.bfloat16().view(torch.uint16).flatten().tolist())
    shape = list(val.shape)
    rank = len(shape)
    if rank == 1:
        ivars.append(f'Vec<{shape[0]}> {name} {{ blob, {offset} }}')
    elif rank == 2:
        ivars.append(f'Mat<{shape[0]}, {shape[1]}> {name} {{ blob, {offset} }}')
    else:
        assert 0
    return (name, *shape)


# Hand-written emission of the network, layer by layer. Each stage stores its
# weights in the blob, declares a C++ member wired to the previous stage's
# output, and appends the call that runs it to the step() statement list.
# NOTE(review): `w`, `h`, `b` and `f` are rebound several times below, so the
# statement order matters.

# Input embedding: h and w are the two dimensions of the embedding matrix.
name, h, w = store('embed_w', model.w_in.weight);
ivars += [f'Embed<{h},{w}> embed {{ {name} }}']
fwds += ['embed(prev)']

# `last_out` names the C++ expression holding the current stage's input.
last_out = 'embed.out'
for i, lay in enumerate(model.sqrll.blocks):

    # Norm scale and bias; `w` becomes the length of the scale vector.
    # NOTE(review): assumes `lay.norm` exposes both `weight` and `bias` -- confirm.
    s, w = store(f'l{i}_norm_s', lay.norm.weight)
    b, _ = store(f'l{i}_norm_b', lay.norm.bias)
    ivars += [f'VecMulAdd<{w}> l{i}_norm {{ {last_out}, {s}, {b} }}']
    fwds += [f'l{i}_norm()']

    # Residual member combining the stage input with the norm output.
    ivars += [f'VecAdd<{w}> l{i}_res {{ {last_out}, l{i}_norm.out }}']
    fwds += [f'l{i}_res()']

    if hasattr(lay, 'ffn'):
        f = lay.ffn

        # Second norm for the feed-forward part. It reads `last_out`, and no
        # later member in this cell consumes `l{i}_norm2.out`.
        s, w = store(f'l{i}_norm2_s', f.norm.weight)
        b, _ = store(f'l{i}_norm2_b', f.norm.bias)
        ivars += [f'VecMulAdd<{w}> l{i}_norm2 {{ {last_out}, {s}, {b} }}']
        fwds += [f'l{i}_norm2()']




    # The next block reads from this block's residual output.
    last_out = f'l{i}_res.out'



# Output projection: h and w are the two dimensions of the output matrix.
m, h, w = store('out_w', model.w_out.weight)
b, _ = store('out_b', model.w_out.bias)
ivars += [f'VecMatBias<{h},{w}> out {{ {last_out}, {m}, {b} }}']
fwds += ['out()']

# NOTE(review): the sampler is sized by `w`, the second dimension of `out_w`
# -- confirm that this, rather than `h`, is the length of `out.out`.
ivars += [f'Sampler<{w}> sampler {{ out.out }}']
fwds += ['int out = sampler()']
fwds += ['return out']


class Writer:
    """Small helper that writes indented text to a stream.

    ``w(text)`` writes the current indentation and then ``text``, returning
    the writer itself. Used as a context manager, the writer first ends the
    current line and then indents all nested writes by four extra spaces
    until the block is left.
    """

    def __init__(self, f):
        self.f = f
        self.indent = 0

    def __call__(self, s):
        """Emit the indentation prefix followed by ``s``; return the writer."""
        out = self.f
        out.write(' ' * self.indent)
        out.write(s)
        return self

    def __enter__(self):
        """Finish the current line and increase the indentation by four."""
        self.f.write('\n')
        self.indent = self.indent + 4
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Decrease the indentation by four; exceptions are not suppressed."""
        self.indent = self.indent - 4


# Write the generated C++ source: weight blob, Model struct (declarations and
# step()), and a global Model instance. Loop variables use names of their own
# so that neither the open file handle nor the writer is rebound mid-write.
with open('model.cpp', 'w') as cpp_file:
    w = Writer(cpp_file)
    w('#include "cgen_head.h"\n\n')

    # Weight data as a comma-separated list of hex 16-bit words.
    with w(f'BFloat16Blob<{len(blob)}> blob = {{'):
        w(','.join(map(hex, blob)) + '\n')
    w('};\n\n')

    with w('struct Model {\n'):
        for decl in ivars:
            w(decl + ';\n')

        w('\n')
        with w('int step(int prev) {'):
            for stmt in fwds:
                w(stmt + ';\n')
        w('}\n')

    w('};\n\n')
    w('Model model;\n\n')
    w('#include "cgen_main.h"\n')