In [119]:
import torch
from sqrll.sqrllm import SqrLLM, StatefulWrapper
from tqdm import tqdm
import math

# Active model configuration. The parameter count computed below is also
# used to pick the checkpoint file name.
model = SqrLLM(
    n_embed = 256,
    n_mem = 512,
    n_ffn = 256,
    ffn_rate = 4,
    n_layer = 6,
)
# Alternative, much smaller configuration (disabled).
# model = SqrLLM(
#     n_embed = 4,
#     n_mem = 4,
#     n_ffn = 4,
#     ffn_rate = 99,
#     n_layer = 1,
# )
smodel = StatefulWrapper(model)

params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

# Loading is best-effort: the notebook keeps going with the freshly
# constructed weights when no usable checkpoint exists, but it now says why
# instead of silently swallowing every exception.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_path = f'./model{params}.pt'
try:
    model.load_state_dict(torch.load(checkpoint_path))
    print('loaded')
except FileNotFoundError:
    print(f'no checkpoint at {checkpoint_path}; keeping initial weights')
except Exception as err:
    print(f'could not load {checkpoint_path}: {err!r}; keeping initial weights')

model.eval();


params=4,468,736
loaded


In [120]:
from sqrll import sqrll, sqrllm

class CustomTracer(torch.fx.Tracer):
    """Tracer whose leaf-module check always answers "no".

    Overriding ``is_leaf_module`` this way means no submodule is ever
    reported as a leaf to the base tracer.
    """

    def is_leaf_module(self, mod, name):
        """Return ``False`` for every module, regardless of type or name."""
        treat_as_leaf = False
        return treat_as_leaf

# The two listed functions are handed to the tracer as autowrap functions
# (presumably so each call shows up as a single node -- confirm against the
# torch.fx Tracer documentation).
tracer = CustomTracer(
    autowrap_functions=(sqrll.sqrll_kernel, sqrllm.rms_norm),
)

# Symbolically trace the model into an FX graph.
graph = tracer.trace(model)

In [121]:
from torch.fx import map_arg
import operator

def shape_dtype(shape, ref=False, const=False):
    """Build the C++ type string ``ml::tensor<d0,d1,...>`` for ``shape``.

    Args:
        shape: iterable of dimensions; each is rendered with ``str``.
        ref: when truthy, append ``&`` to the type.
        const: when truthy, prefix the type with ``const ``.

    Returns:
        The type string, e.g. ``const ml::tensor<1,256>&``.
    """
    dims = ','.join(map(str, shape))
    prefix = 'const ' if const else ''
    suffix = '&' if ref else ''
    return f'{prefix}ml::tensor<{dims}>{suffix}'

def val_dtype(val, ref=False, const=False):
    """Return the C++ type string describing ``val``.

    A list or tuple becomes ``std::tuple<...>`` of its members' types
    (recursively, with ``ref``/``const`` applied to the members rather than
    to the tuple itself). Anything else must expose ``.shape`` and is
    rendered through ``shape_dtype``.
    """
    if not isinstance(val, (list, tuple)):
        return shape_dtype(tuple(val.shape), ref, const)
    members = ','.join(val_dtype(item, ref, const) for item in val)
    return f'std::tuple<{members}>'

def flatten(val):
    """Flatten arbitrarily nested lists/tuples into one flat list of leaves.

    A value that is neither a list nor a tuple is returned as a one-element
    list.
    """
    if not isinstance(val, (list, tuple)):
        return [val]
    leaves = []
    for item in val:
        leaves.extend(flatten(item))
    return leaves

        
class Interpreter(torch.fx.Interpreter):
    """Run an FX graph while recording one C++-style statement per computed node.

    While the graph executes on real values the interpreter collects:

    * ``inputs``: placeholder name -> C++ type string.
    * ``weights``: ``get_attr`` node name -> ``(const C++ type string, value)``.
    * ``tmp_vars``: temporary name -> ``[shape, outstanding reads]``; a
      temporary whose count has dropped to zero is reused for a later result
      of the same shape.
    * ``node_vars``: graph node -> C++ expression naming its value.
    * ``fwds``: emitted statements, in execution order.
    * ``output_type`` / ``output_ref``: return type and return expression.

    All containers are created per instance, so separate interpreters do not
    share records. Running the same instance twice appends to its existing
    records.
    """

    # Only immutable defaults live on the class; mutable bookkeeping is
    # created in __init__.
    output_type = None
    output_ref = None

    # When true, every fully processed node and its value are printed.
    verbose = True

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``torch.fx.Interpreter`` and reset bookkeeping."""
        super().__init__(*args, **kwargs)
        self.inputs = {}
        self.weights = {}
        self.tmp_vars = {}
        self.node_vars = {}
        self.fwds = []

    def get_tmp(self, node, shape):
        """Bind ``node`` to a temporary of ``shape`` and return the temporary's name.

        An existing temporary with the same shape and no outstanding reads is
        reused; otherwise a new ``tmp<k>`` is created. The temporary's read
        count is set to the number of users of ``node``.
        """
        refcount = len(node.users)
        for name, info in self.tmp_vars.items():
            tshape, tref = info
            if shape == tshape and tref == 0:
                self.node_vars[node] = name
                info[1] = refcount
                return name
        name = f'tmp{len(self.tmp_vars)}'
        self.tmp_vars[name] = [shape, refcount]
        self.node_vars[node] = name
        return name

    def deref(self, node):
        """Return the C++ expression for one flattened call argument.

        ``None`` becomes ``nullptr`` and a full slice becomes ``slice<>()``;
        any other slice raises ``ValueError``. Non-node constants are
        rendered with ``str``. A node without a recorded variable falls back
        to its own name. For a node backed by a temporary, one outstanding
        read is consumed.
        """
        if node is None:
            return 'nullptr'
        if isinstance(node, slice):
            if node == slice(None, None, None):
                return 'slice<>()'
            raise ValueError('unsupported '+str(node))
        if not isinstance(node, torch.fx.Node):
            return str(node)
        if node not in self.node_vars:
            return node.name
        name = self.node_vars[node]
        if name in self.tmp_vars:
            self.tmp_vars[name][1] -= 1
            assert self.tmp_vars[name][1] >= 0
        return name

    def nested_refstr(self, arg):
        """Render a (possibly nested) list/tuple of nodes as a brace initializer."""
        if isinstance(arg, (list, tuple)):
            nest = [self.nested_refstr(a) for a in arg]
            return '{'+', '.join(nest)+'}'
        return self.node_vars[arg]

    def alias(self, node, src):
        """Make ``node`` share the variable already bound to ``src``.

        The aliasing node counts as one read of ``src`` and adds one read per
        user of ``node``. Sources that are not temporaries (inputs, weights,
        ``get<...>`` expressions) carry no read count and are aliased as-is.
        """
        name = self.node_vars[src]
        self.node_vars[node] = name
        if name in self.tmp_vars:
            self.tmp_vars[name][1] += len(node.users) - 1

    def run_node(self, n):
        """Execute node ``n``, record its C++ counterpart, and return its value."""
        with self._set_current_node(n):

            args, kwargs = self.fetch_args_kwargs_from_env(n)
            val = getattr(self, n.op)(n.target, args, kwargs)

            if n.op == 'placeholder':
                self.inputs[n.name] = val_dtype(val)
                self.node_vars[n] = n.name
            elif n.op == 'get_attr':
                self.weights[n.name] = (val_dtype(val, const=True), val)
                self.node_vars[n] = n.name
            elif n.op == 'call_function' or n.op == 'call_method':

                # Functions are emitted by __name__; method targets are
                # already strings.
                fname = n.target
                if 'fun' in n.op:
                    fname = fname.__name__

                # Indexing into a list/tuple value needs no storage of its
                # own; it is referenced as get<i>(source).
                if n.target == operator.getitem:
                    if isinstance(args[0], (list, tuple)):
                        src = self.node_vars[n.args[0]]
                        self.node_vars[n] = f'get<{args[1]}>({src})'
                        return val

                # Targets emitted as plain aliases of their first argument.
                no_ops = [
                    'detach',
                    'clone',
                    torch.nn.functional.dropout,
                ]
                if n.target in no_ops:
                    self.alias(n, n.args[0])
                    return val

                # The output temporary is claimed before the arguments are
                # dereferenced, so it never reuses the storage of one of this
                # call's own arguments.
                out_var = self.get_tmp(n, tuple(val.shape))
                # NOTE(review): only positional args are emitted; n.kwargs
                # are not part of the generated call.
                flat_arg_vars = [self.deref(arg) for arg in flatten(n.args)]
                fargs = ', '.join(flat_arg_vars)

                self.fwds.append(f'{out_var} = {fname}({fargs})')
            elif n.op == 'call_module':
                raise ValueError('call_module unsupported')
            elif n.op == 'output':
                self.output_type = val_dtype(val, ref=True)
                self.output_ref = self.nested_refstr(n.args[0])
            else:
                print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                # len() is evaluated only when there is no shape to report.
                described = val.shape if hasattr(val, 'shape') else f'{len(val)=}'
                print('->', described, len(n.users))

            if self.verbose:
                print(n.name, ':', n.op, n.target, n.args, n.kwargs)
                print('->', val)
            return val


# One single-token probe call; only the second return value is kept, and
# only as a shape/dtype template for the zeroed memory list built below.
inputs = torch.tensor([[ord('a')]])
_, mem = model(inputs)
mem = list(map(torch.zeros_like, mem))

# Replay the traced graph on the probe inputs to collect the code-generation
# records.
interp = Interpreter(model, graph=graph)

out = interp.run(inputs, mem)


x : placeholder x () {}
-> tensor([[97]])
mem : placeholder mem (None,) {}
-> [tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [122]:
# blob:  every weight, converted to bfloat16 and stored as raw 16-bit words.
# ivars: one C++ declaration per weight, pointing at its offset in blob.
blob = []
ivars = []

for name, (dtype, val) in interp.weights.items():
    offset = len(blob)
    blob.extend(val.bfloat16().view(torch.uint16).flatten().tolist())
    ivars.append(f'static {dtype} {name} {{ blob+{offset} }}')

# Statements recorded by the interpreter, copied into a fresh list.
fwds = list(interp.fwds)

class Writer:
    """Indenting text writer whose ``with`` block emits a braced scope.

    Calling the writer writes the current indentation followed by the given
    string and returns the writer, so ``with w('header'):`` writes the
    header, a newline and an opening brace, indents the body by four spaces,
    and writes the closing brace (without a newline) on exit.
    """

    def __init__(self, f):
        self.f = f
        self.indent = 0

    def __call__(self, s):
        """Write ``s`` prefixed by the current indentation; return ``self``."""
        sink = self.f
        sink.write(' ' * self.indent)
        sink.write(s)
        return self

    def __enter__(self):
        """Finish the header line, open a brace and indent one level."""
        self.f.write('\n')
        self('{\n')
        self.indent = self.indent + 4
        return self

    def __exit__(self, *exc_info):
        """Dedent one level and write the closing brace."""
        self.indent = self.indent - 4
        self('}')

class_name = 'Model'

# Emit model.cpp: the weight blob, one tensor declaration per weight, then a
# struct holding inputs/temporaries and an operator() that runs the recorded
# statements. The braces come from Writer's context manager.
with open('model.cpp', 'w') as f:
    w = Writer(f)
    w('#include "cgen_head.h"\n\n')

    w('// weight initializers\n')
    with w(f'static const ml::bfloat16 blob[{len(blob)}] = '):
        w(','.join([hex(x) for x in blob]) + '\n')
    w(';\n\n')

    w('// weight tensors\n')
    for decl in ivars:
        w(decl + ';\n')

    w('\n\n')

    with w(f'struct {class_name}'):

        w('// inputs\n')
        for name, dtype in interp.inputs.items():
            w(f'{dtype} {name};\n')

        w('// tmp vars\n')
        for name, info in interp.tmp_vars.items():
            dtype = shape_dtype(info[0])
            w(f'{dtype} {name};\n')

        w('\n')
        w(f'{interp.output_type}\n')
        with w('operator()()'):

            w('using std::get;\n')
            w('using namespace ml;\n')

            # Loop variable deliberately not named `f`: that name is the
            # open file handle of the enclosing `with`.
            for stmt in fwds:
                w(stmt + ';\n')

            w(f'return {interp.output_ref};\n')

        w('\n')

    w(';\n\n')
    w('#include "cgen_main.h"\n')

In [None]:
blob = []
ivars = []
fwds = []

def store(name, val):
    global blob
    global ivars
    offset = len(blob)
    blob += val.bfloat16().view(torch.uint16).flatten().tolist()
    shape = list(val.shape)
    if len(shape) == 1:
        ivars += [f'Vec<{shape[0]}> {name} {{ blob, {offset} }}']
    elif len(shape) == 2:
        ivars += [f'Mat<{shape[0]}, {shape[1]}> {name} {{ blob, {offset} }}']
    else:
        assert 0
    return name, *shape


# Hand-written emission of declarations (ivars) and step statements (fwds).
# The order of store() calls matters: each call appends to blob, so it fixes
# the offsets baked into the generated declarations.
# NOTE(review): this cell looks unfinished -- see the notes inside the loop.

# Embedding table; store() returns (name, rows, cols) for a 2-D tensor.
name, h, w = store('embed_w', model.w_in.weight);
ivars += [f'Embed<{h},{w}> embed {{ {name} }}']
fwds += ['embed(prev)']

# Expression naming the previous stage's output, threaded through the loop.
last_out = 'embed.out'
for i, lay in enumerate(model.sqrll.blocks):

    # For 1-D tensors store() returns (name, length): `s`/`b` are the
    # generated variable names, `w` the vector length.
    s, w = store(f'l{i}_norm_s', lay.norm.weight)
    b, _ = store(f'l{i}_norm_b', lay.norm.bias)
    ivars += [f'VecMulAdd<{w}> l{i}_norm {{ {last_out}, {s}, {b} }}']
    fwds += [f'l{i}_norm()']

    # Residual: adds the block input to the norm output.
    ivars += [f'VecAdd<{w}> l{i}_res {{ {last_out}, l{i}_norm.out }}']
    fwds += [f'l{i}_res()']

    if hasattr(lay, 'ffn'):
        f = lay.ffn

        # NOTE(review): l{i}_norm2 reads last_out (the block input, not
        # l{i}_res.out) and nothing below consumes l{i}_norm2.out -- the ffn
        # itself is not emitted yet.
        s, w = store(f'l{i}_norm2_s', f.norm.weight)
        b, _ = store(f'l{i}_norm2_b', f.norm.bias)
        ivars += [f'VecMulAdd<{w}> l{i}_norm2 {{ {last_out}, {s}, {b} }}']
        fwds += [f'l{i}_norm2()']




    last_out = f'l{i}_res.out'
    

    
# Output projection; `w` is rebound here to the second dimension of out_w
# and is what sizes the Sampler below.
m, h, w = store('out_w', model.w_out.weight)
b, _ = store('out_b', model.w_out.bias)
ivars += [f'VecMatBias<{h},{w}> out {{ {last_out}, {m}, {b} }}']
fwds += ['out()']

ivars += [f'Sampler<{w}> sampler {{ out.out }}']
fwds += ['int out = sampler()']
fwds += ['return out']


class Writer:
    """Indenting text writer; a ``with`` block indents its body by four spaces.

    Calling the writer writes the current indentation followed by the given
    string and returns the writer. Entering the context writes a newline and
    indents; leaving it dedents. No braces are written by the writer itself.
    """

    _STEP = 4

    def __init__(self, f):
        self.f = f
        self.indent = 0

    def __call__(self, s):
        """Write ``s`` prefixed by the current indentation; return ``self``."""
        sink = self.f
        sink.write(' ' * self.indent)
        sink.write(s)
        return self

    def __enter__(self):
        """Terminate the header line and indent one level."""
        self.f.write('\n')
        self.indent = self.indent + self._STEP
        return self

    def __exit__(self, *exc_info):
        """Dedent one level."""
        self.indent = self.indent - self._STEP


# Emit model.cpp from the hand-built declarations and step statements.
# Writer's context manager only indents here, so the braces are written
# explicitly in the strings.
with open('model.cpp', 'w') as f:
    w = Writer(f)
    w('#include "cgen_head.h"\n\n')

    with w(f'BFloat16Blob<{len(blob)}> blob = {{'):
        w(','.join([hex(x) for x in blob]) + '\n')
    w('};\n\n')

    with w('struct Model {\n'):
        for decl in ivars:
            w(decl + ';\n')

        w('\n')
        with w('int step(int prev) {'):
            # Loop variable deliberately not named `f`: that name is the
            # open file handle of the enclosing `with`.
            for stmt in fwds:
                w(stmt + ';\n')
        w('}\n')

    w('};\n\n')
    w('Model model;\n\n')
    w('#include "cgen_main.h"\n')