In [39]:
import sys
import requests
from bs4 import BeautifulSoup
import unicodedata
import itertools
import itertools
import pprint
import re

# Source page: the ACLE MVE intrinsics reference table.
URL = "https://arm-software.github.io/acle/mve_intrinsics/mve.html"
# Without a timeout, requests waits indefinitely on a stalled connection.
page = requests.get(URL, timeout=30)
# Fail loudly on an HTTP error instead of scraping an error page into empty lists.
page.raise_for_status()

soup = BeautifulSoup(page.content, "html.parser")
# <code> text from the first and fifth table columns; NFKD normalisation folds
# compatibility characters (e.g. non-breaking spaces) into their plain forms.
intrinsics = [unicodedata.normalize('NFKD', code.text) for code in soup.select("tbody > tr > td:nth-child(1) > code")]
archs = [unicodedata.normalize('NFKD', code.text) for code in soup.select("tbody > tr > td:nth-child(5) > code")]

# Collapse every whitespace run to a single space and drop the space after '('.
decls = [' '.join(entry.split()).replace('( ', '(') for entry in intrinsics]

# Pair each declaration with the entry at the same position in `archs`.
decl_archs = list(zip(decls, archs))


In [40]:
# Dump every scraped (declaration, architecture) pair to eyeball the result.
print(decl_archs)

[('float16x8_t [__arm_]vcreateq_f16(uint64_t a, uint64_t b)', 'MVE'), ('float32x4_t [__arm_]vcreateq_f32(uint64_t a, uint64_t b)', 'MVE'), ('int8x16_t [__arm_]vcreateq_s8(uint64_t a, uint64_t b)', 'MVE'), ('int16x8_t [__arm_]vcreateq_s16(uint64_t a, uint64_t b)', 'MVE'), ('int32x4_t [__arm_]vcreateq_s32(uint64_t a, uint64_t b)', 'MVE'), ('int64x2_t [__arm_]vcreateq_s64(uint64_t a, uint64_t b)', 'MVE'), ('uint8x16_t [__arm_]vcreateq_u8(uint64_t a, uint64_t b)', 'MVE'), ('uint16x8_t [__arm_]vcreateq_u16(uint64_t a, uint64_t b)', 'MVE'), ('uint32x4_t [__arm_]vcreateq_u32(uint64_t a, uint64_t b)', 'MVE'), ('uint64x2_t [__arm_]vcreateq_u64(uint64_t a, uint64_t b)', 'MVE'), ('uint8x16_t [__arm_]vddupq[_n]_u8(uint32_t a, const int imm)', 'MVE'), ('uint16x8_t [__arm_]vddupq[_n]_u16(uint32_t a, const int imm)', 'MVE'), ('uint32x4_t [__arm_]vddupq[_n]_u32(uint32_t a, const int imm)', 'MVE'), ('uint8x16_t [__arm_]vddupq[_wb]_u8(uint32_t *a, const int imm)', 'MVE'), ('uint16x8_t [__arm_]vddupq[_wb

In [41]:
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    """Split `iterable` into consecutive, non-overlapping tuples of length `n`.

    `incomplete` selects what happens to a short final chunk:
      'fill'   -- pad it with `fillvalue`     grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
      'strict' -- the iterator raises          grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
      'ignore' -- drop it                      grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF

    Raises ValueError for any other `incomplete` value.
    """
    # The *same* iterator repeated n times: zipping it pulls n consecutive
    # items into each output tuple.
    shared = [iter(iterable)] * n
    if incomplete == 'ignore':
        return zip(*shared)
    elif incomplete == 'strict':
        return zip(*shared, strict=True)
    elif incomplete == 'fill':
        return itertools.zip_longest(*shared, fillvalue=fillvalue)
    raise ValueError('Expected fill, strict, or ignore')

In [79]:
# Bit widths, kept as strings so they can be concatenated onto a type letter.
lengths = ["8", "16", "32", "64", "128", "256"]

def prefix_all(prefix, list):
  """Return a new list in which `prefix` is prepended to every element of `list`."""
  prefixed = []
  for elem in list:
    prefixed.append(prefix + elem)
  return prefixed

# Every element-type token recognised in an intrinsic name, e.g. 'u8' or 'f32'.
types = set().union(
  prefix_all('u', lengths),  # unsigned integers
  prefix_all('s', lengths),  # signed integers
  prefix_all('f', lengths),  # floating point
  prefix_all('p', lengths),  # polynomial
)


class MVEIdent:
  """Decodes an intrinsic name such as ``vddupq_u8`` into its components.

  After construction the instance exposes:
    name    -- the underscore-joined tokens left once the leading 'v', a
               trailing 'q', the type tokens and the flag tokens are removed
    suffix  -- last character of the first token if it is in `suffixes`, else None
    types   -- {'out': ..., 'in': ...} built from the type tokens in the name
    scalar, merging, dontcare, zero, lane, predicated, write_back
            -- booleans, True when the matching token ('n', 'm', 'x', 'z',
               'lane', 'p', 'wb') appears in the name

  Anything inside square brackets is deleted before the name is split, so a
  token that only appears in a bracketed part (e.g. ``[_n]``) is not detected.
  """
  suffixes = [
    'q',  # Q-register
    'w',  # Widening, i.e one of the inputs is longer than the other and the output will be placed in a longer type
    'l',  # Long, i.e. the output type is twice as long as the inputs
  ]

  def __init__(self, func_name):
    # All variant flags start cleared; the token scan below raises them.
    self.scalar = False
    self.merging = False
    self.dontcare = False
    self.zero = False
    self.lane = False
    self.predicated = False
    self.write_back = False

    without_brackets = re.sub(r'\[.*?\]', '', func_name)
    tokens = without_brackets.lstrip('_').split('_')

    # The first token carries the leading 'v' and the register suffix.
    head = tokens[0]
    if head[0] == 'v':
      head = head[1:]
    if head[-1] in MVEIdent.suffixes:
      self.suffix = head[-1]
      # Only 'q' is stripped; 'w' and 'l' stay part of the base name.
      if self.suffix == 'q':
        head = head[:-1]
    else:
      self.suffix = None
    tokens[0] = head

    # Separate element-type tokens from everything else, keeping their order.
    name_types = []
    remaining = []
    for token in tokens:
      (name_types if token in types else remaining).append(token)

    # One type token serves as both output and input; two are (out, in);
    # any other count leaves both unknown.
    if len(name_types) == 1:
      out_type = in_type = name_types[0]
    elif len(name_types) == 2:
      out_type, in_type = name_types
    else:
      out_type = in_type = None
    self.types = {'out' : out_type, 'in': in_type}

    # Each flag token is removed from the name (first occurrence only) and
    # recorded on the matching attribute.
    flag_tokens = (
      ('n', 'scalar'),
      ('m', 'merging'),
      ('x', 'dontcare'),
      ('z', 'zero'),
      ('p', 'predicated'),
      ('wb', 'write_back'),
      ('lane', 'lane'),
    )
    for token, attribute in flag_tokens:
      if token in remaining:
        remaining.remove(token)
        setattr(self, attribute, True)

    self.name = '_'.join(remaining)


  def __repr__(self):
    return f'NEON:{{ name: "{self.name}", suffix: {self.suffix}, types: {self.types}}}'

In [None]:
# Maps a decoded intrinsic base name (MVEIdent.name) to the readable wrapper
# name used in the generated header. Function.__init__ falls back to the
# decoded name only when a key is absent from this dict.
name_map = {
    # NEON
    "add": "add",
    "addl": "add_long",
    "addw": "add",
    "hadd": "add_halve",
    "rhadd": "add_halve_round",
    "addhn": "add_narrow_high",
    "raddhn": "add_round_narrow_high",
    "qadd": "add_saturate",
    "mul": "multiply",
    "mla": "multiply_add",
    "mls": "multiply_subtract",
    "mlal": "multiply_add_long",
    "mlsl": "multiply_subtract_long",
    "fma": "multiply_add_fused",
    "fms": "multiply_subtract_fused",
    "qdmulh": "multiply_double_saturate_high",
    "qrdmulh": "multiply_double_round_saturate_high",
    "qdmull": "multiply_double_saturate_long",
    "qdmlal": "multiply_double_add_saturate_long",
    "qdmlsl": "multiply_double_subtract_saturate_long",
    "mull": "multiply_long",
    "sub": "subtract",
    "subl": "subtract_long",
    "subw": "subtract",
    "hsub": "subtract_high",
    "subhn": "subtract_narrow_high",
    "rsubhn": "subtract_round_narrow_high",
    "qsub": "subtract_saturate",
    "abd": "subtract_abs",
    "abdl": "subtract_abs_long",
    "aba": "subtract_abs_add",
    "abal": "subtract_abs_add",
    "abs": "abs",
    "qabs": "abs_saturate",
    "max": "max",
    "min": "min",
    "recpe": "reciprocal_estimate",
    "recps": "reciprocal_step",
    "rsqrte": "reciprocal_sqrt_estimate",
    "rsqrts": "reciprocal_sqrt_step",
    "padd": "pairwise_add",
    "paddl": "pairwise_add_long",
    "padal": "pairwise_add_accumulate_long",
    "pmax": "pairwise_max",
    "pmin": "pairwise_min",
    "ce": "equal",
    "ceq": "equal",
    "cge": "greater_than_or_equal",
    "cle": "less_than_or_equal",
    "cgt": "greater_than",
    "clt": "less_than",
    "cage": "absolute_greater_than_or_equal",
    "cale": "absolute_less_than_or_equal",
    "cagt": "absolute_greater_than",
    "calt": "absolute_less_than",
    "tst": "compare_test_nonzero",
    "shl": "shift_left",
    "qshl": "shift_left_saturate",
    "qshlu": "shift_left_unsigned_saturate",
    "rshl": "shift_left_round",
    "qrshl": "shift_left_round_saturate",
    "shll": "shift_left_long",
    "sli": "shift_left_insert",
    "shr": "shift_right",
    "rshr": "shift_right_round",
    "sra": "shift_right_accumulate",
    "rsra": "shift_right_accumulate_round",
    "shrn": "shift_right_narrow",
    "qshrun": "shift_right_saturate_narrow_unsigned",
    "qshrn": "shift_right_saturate_narrow",
    "qrshrun": "shift_right_unsigned_saturate_narrow",
    "qrshrn": "shift_right_saturate_narrow",
    "rshrn": "shift_right_round_saturate_narrow",
    "sri": "shift_right_insert",
    "cvt": "convert",
    "reinterpret": "reinterpret",
    "movn": "move_narrow",
    "movn_high": "move_high_narrow",
    "movl": "move_long",
    "qmovn": "move_saturate_narrow",
    "qmovun": "move_unsigned_saturate_narrow",
    "neg": "negate",
    "qneg": "negate_saturate",
    "mvn": "bitwise_not",
    "and": "bitwise_and",
    "orr": "bitwise_or",
    "eor": "bitwise_xor",
    "orn": "bitwise_or_not",
    "cls": "count_leading_sign_bits",
    "clz": "count_leading_zero_bits",
    "cnt": "count_active_bits",
    "bic": "bitwise_clear",
    "bsl": "bitwise_select",
    "create": "create",
    "dup": "duplicate",
    "mov": "move",
    "combine": "combine",
    "get_high": "get_high",
    "get_low": "get_low",
    "get": "get",
    "ext": "extract",
    "rev64": "reverse_64bit",
    "rev32": "reverse_32bit",
    "rev16": "reverse_16bit",
    "zip": "zip",
    "uzp": "unzip",
    "trn": "transpose",
    "set": "set",
    "ld1": "load1",
    "ld1_dup": "load1_duplicate",
    "ld2": "load2",
    "ld3": "load3",
    "ld4": "load4",
    "ld2_dup": "load2_duplicate",
    "ld3_dup": "load3_duplicate",
    "ld4_dup": "load4_duplicate",
    "ld1_x2": "load1_x2",
    "ld1_x3": "load1_x3",
    "ld1_x4": "load1_x4",
    "st1": "store1",
    "st2": "store2",
    "st3": "store3",
    "st4": "store4",
    "st1_x2": "store1_x2",
    "st1_x3": "store1_x3",
    "st1_x4": "store1_x4",
    "tbl1": "table_lookup1",
    "tbl2": "table_lookup2",
    "tbl3": "table_lookup3",
    "tbl4": "table_lookup4",
    "tbx1": "table_extend1",
    "tbx2": "table_extend2",
    "tbx3": "table_extend3",
    "tbx4": "table_extend4",
    # A64
    "cadd_rot270": "complex_add_rotate_270",
    "cadd_rot90": "complex_add_rotate_90",
    "cmla": "complex_multiply_add",
    "cmla_rot270": "complex_multiply_add_rotate_270",
    "cmla_rot180": "complex_multiply_add_rotate_180",
    "cmla_rot90": "complex_multiply_add_rotate_90",
    "cvta": "convert_round_to_nearest_with_ties_away_from_zero",
    "cvtm": "convert_round_toward_negative_infinity",
    "cvtn": "convert_round_to_nearest_with_ties_to_even",
    "cvtp": "convert_round_toward_positive_infinity",
    "maxnm": "max",
    "minnm": "min",
    "rnd": "round",
    "rnda": "round_to_nearest_with_ties_away_from_zero",
    "rndm": "round_toward_negative_infinity",
    "rndn": "round_to_nearest_with_ties_to_even",
    "rndp": "round_toward_positive_infinity",
    "rndx": "round_inexact",
    # Helium only
    # NOTE(review): from "cmpcs" onward the values look like unfinished
    # placeholders ("cmpcs" maps to "compare " with a trailing space, the rest
    # to ""). Because the keys exist, a lookup returns that value instead of
    # raising KeyError -- confirm the intended names before relying on them.
    "abav": "absolute_subtract_add",
    "adc": "add_carry",
    "adci": "add_carry_initialized",
    "addlv": "reduce_add_long",
    "addlva": "reduce_add_long",
    "addv": "reduce_add",
    "addva": "reduce_add",
    "asrl": "shift_right_long_arithmetic",
    "brsr": "bit_reverse_shift_right",
    "cmpcs": "compare ",
    "cmpeq": "",
    "cmpge": "",
    "cmpgt": "",
    "cmphi": "",
    "cmple": "",
    "cmplt": "",
    "cmpne": "",
    "ctp16": "",
    "ctp32": "",
    "ctp64": "",
    "ctp8": "",
    "ddup": "",
    "dwdup": "",
    "hcadd_rot270": "",
    "hcadd_rot90": "",
    "idup": "",
    "iwdup": "",
    "ldrb": "",
    "ldrb_gather_offset": "",
    "ldrd_gather_base": "",
    "ldrd_gather_offset": "",
    "ldrd_gather_shifted_offset": "",
    "ldrh": "",
    "ldrh_gather_offset": "",
    "ldrh_gather_shifted_offset": "",
    "ldrw": "",
    "ldrw_gather_base": "",
    "ldrw_gather_offset": "",
    "ldrw_gather_shifted_offset": "",
    "lsll": "",
    "maxa": "",
    "maxav": "",
    "maxv": "",
    "mina": "",
    "minav": "",
    "minv": "",
    "mladav": "",
    "mladava": "",
    "mladavax": "",
    "mladavx": "",
    "mlaldav": "",
    "mlaldava": "",
    "mlaldavax": "",
    "mlaldavx": "",
    "mlas": "",
    "mlsdav": "",
    "mlsdava": "",
    "mlsdavax": "",
    "mlsdavx": "",
    "mlsldav": "",
    "mlsldava": "",
    "mlsldavax": "",
    "mlsldavx": "",
    "movlb": "",
    "movlt": "",
    "movnb": "",
    "movnt": "",
    "mulh": "",
    "mullb_int": "",
    "mullb_poly": "",
    "mullt_int": "",
    "mullt_poly": "",
    "pnot": "",
    "psel": "",
    "qdmladh": "",
    "qdmladhx": "",
    "qdmlah": "",
    "qdmlash": "",
    "qdmlsdh": "",
    "qdmlsdhx": "",
    "qdmullb": "",
    "qdmullt": "",
    "qmovnb": "",
    "qmovnt": "",
    "qmovunb": "",
    "qmovunt": "",
    "qrdmladh": "",
    "qrdmladhx": "",
    "qrdmlah": "",
    "qrdmlash": "",
    "qrdmlsdh": "",
    "qrdmlsdhx": "",
    "qrshrnb": "",
    "qrshrnt": "",
    "qrshrunb": "",
    "qrshrunt": "",
    "qshl_r": "",
    "qshrnb": "",
    "qshrnt": "",
    "qshrunb": "",
    "qshrunt": "",
    "rmlaldavh": "",
    "rmlaldavha": "",
    "rmlaldavhax": "",
    "rmlaldavhx": "",
    "rmlsldavh": "",
    "rmlsldavha": "",
    "rmlsldavhax": "",
    "rmlsldavhx": "",
    "rmulh": "",
    "rshrnb": "",
    "rshrnt": "",
    "sbc": "",
    "sbci": "",
    "shl_r": "",
    "shlc": "",
    "shllb": "",
    "shllt": "",
    "shrnb": "",
    "shrnt": "",
    "sqrshr": "",
    "sqrshrl": "",
    "sqrshrl_sat48": "",
    "sqshl": "",
    "sqshll": "",
    "srshr": "",
    "srshrl": "",
    "strb": "",
    "strb_scatter_offset": "",
    "strd_scatter_base": "",
    "strd_scatter_offset": "",
    "strd_scatter_shifted_offset": "",
    "strh": "",
    "strh_scatter_offset": "",
    "strh_scatter_shifted_offset": "",
    "strw": "",
    "strw_scatter_base": "",
    "strw_scatter_offset": "",
    "strw_scatter_shifted_offset": "",
    "uninitialized": "",
    "uqrshl": "",
    "uqrshll": "",
    "uqrshll_sat48": "",
    "uqshl": "",
    "uqshll": "",
    "urshr": "",
    "urshrl": "",
}

In [81]:
import functools


# Preferred ordering of vector types when sorting overloads: a type's rank is
# its position in this list. Each entry must appear exactly once -- Var.type_map
# is built with enumerate(), so a repeated entry would silently take the later
# position. (The original list repeated the uint32 pair in the slot where the
# signed 64-bit pair belongs, which ranked uint32 after uint64 and left the
# int64 vectors unranked.)
type_order = [
  "uint8x8_t",
  "uint8x16_t",
  "int8x8_t",
  "int8x16_t",
  "uint16x4_t",
  "uint16x8_t",
  "int16x4_t",
  "int16x8_t",
  "uint32x2_t",
  "uint32x4_t",
  "int32x2_t",
  "int32x4_t",
  "uint64x1_t",
  "uint64x2_t",
  "int64x1_t",
  "int64x2_t",
  "float16x4_t",
  "float16x8_t",
  "float32x2_t",
  "float32x4_t",
  "poly8x8_t",
  "poly16x4_t",
  'mve_pred16_t',
]

class Var:
  """A single parameter of an intrinsic declaration, split into type and name.

  The last whitespace-separated token of the declaration becomes `ident` and
  everything before it becomes `type`, so "const int imm" gives type
  "const int" / ident "imm", and "uint32_t *a" gives type "uint32_t" /
  ident "*a" (the pointer star stays attached to the identifier).

  Comparisons look only at the rank of `type` in `type_order`; identifiers are
  ignored.
  """
  # type name -> rank (index in type_order)
  type_map = { key:value for (value,key) in enumerate(type_order) }

  def __init__(self, string):
    components = string.split()
    self.ident = components.pop(-1)
    self.type = ' '.join(components)

  def __str__(self):
    return f"{self.type} {self.ident}"

  def __repr__(self):
    return f'Var:{{ type: "{self.type}", ident: "{self.ident}"}}'

  def __eq__(self, other):
    """Equal when both types are ranked and share a rank; otherwise False.

    Two variables whose type is not in `type_order` never compare equal, even
    if the type strings are identical.
    """
    if self.type in Var.type_map and other.type in Var.type_map:
      return Var.type_map[self.type] == Var.type_map[other.type]
    else:
      return False

  def __lt__(self, other):
    """Order by rank; any ranked type sorts before any unranked one."""
    if self.type in Var.type_map and other.type in Var.type_map:
      return Var.type_map[self.type] < Var.type_map[other.type]
    elif self.type in Var.type_map and other.type not in Var.type_map:
      return True
    else:
      return False

class Function:
  """One intrinsic declaration parsed into the pieces the generator needs.

  Built from a ``(declaration, architectures)`` pair. Attributes:
    decl        -- the declaration string as given
    return_type -- first whitespace-separated token of the declaration
    intrinsic   -- the intrinsic name with the ``[__arm_]`` prefix removed and
                   any remaining square brackets dropped (their content kept)
    args        -- list of Var for the parameters; empty for ``(void)``.
                   A trailing ``const int`` parameter is moved to `const`.
    const       -- identifier of that trailing ``const int`` parameter, or None
    decoded     -- MVEIdent for the intrinsic name
    archs       -- the architecture string split on '/'
    name        -- wrapper name: the non-empty name_map entry for
                   `decoded.name`, otherwise `decoded.name` itself, with
                   ``_lane`` appended when `decoded.lane` is set
  """

  def __init__(self, decl_arch):
    (decl, arch) = decl_arch
    self.decl = decl
    in_parens = r"\((.)+\)"
    args = re.search(in_parens, decl).group().removeprefix('(').removesuffix(')')
    decl = re.sub(in_parens, '', decl).split()
    self.return_type = decl.pop(0)
    self.intrinsic = decl.pop(0)
    self.intrinsic = self.intrinsic.replace('[__arm_]', '')
    # A C parameter list of just `void` means "no parameters"; parsing it as a
    # Var would yield a nameless-type variable called "void" and the generated
    # wrapper would call the intrinsic as f(void).
    if args.strip() == 'void':
      self.args = []
    else:
      self.args = [Var(arg) for arg in args.split(',')]
    # Decode before the brackets are stripped: MVEIdent discards bracketed parts.
    self.decoded = MVEIdent(self.intrinsic)
    self.intrinsic = re.sub(r'[\[\]]', '', self.intrinsic)
    self.archs = arch.split('/')
    # Guard against an empty list before peeking at the last parameter.
    if self.args and self.args[-1].type == "const int":
      self.const = self.args[-1].ident
      self.args = self.args[:-1]
    else:
      self.const = None
    # name_map holds "" placeholders for names that have not been chosen yet;
    # treat those like a missing key so the wrapper never gets an empty name.
    self.name = name_map.get(self.decoded.name) or self.decoded.name
    if self.decoded.lane:
      self.name += '_lane'


  def __repr__(self):
    return f'Function:{{ intrinsic: "{self.intrinsic}", decoded: {self.decoded}, return_type: "{self.return_type}", args: "{self.args}" }}'

  def __equal__(self, other):
    """True when every positionally paired argument compares equal.

    Note: this is an ordinary method, not the ``__eq__`` hook, so ``==`` and
    ``in`` on Function objects still compare by identity (the generator's
    ``func in needs_template_funcs`` check relies on that). Arguments beyond
    the shorter list are ignored; two empty argument lists compare equal.
    """
    return all(arg1 == arg2 for (arg1, arg2) in zip(self.args, other.args))

  def __lt__(self, other):
    """Order by the first pair of arguments that does not compare equal.

    Returns False when every paired argument is equal (or there are none).
    """
    for (arg1, arg2) in zip(self.args, other.args):
      if arg1 == arg2:
        continue
      else:
        return arg1 < arg2
    return False


In [82]:
# Parse every scraped (declaration, architecture) pair and show the result.
funcs = list(map(Function, decl_archs))
import pprint
pprint.pp(funcs)

[Function:{ intrinsic: "vcreateq_f16", decoded: NEON:{ name: "create", suffix: q, types: {'out': 'f16', 'in': 'f16'}}, return_type: "float16x8_t", args: "[Var:{ type: "uint64_t", ident: "a"}, Var:{ type: "uint64_t", ident: "b"}]" },
 Function:{ intrinsic: "vcreateq_f32", decoded: NEON:{ name: "create", suffix: q, types: {'out': 'f32', 'in': 'f32'}}, return_type: "float32x4_t", args: "[Var:{ type: "uint64_t", ident: "a"}, Var:{ type: "uint64_t", ident: "b"}]" },
 Function:{ intrinsic: "vcreateq_s8", decoded: NEON:{ name: "create", suffix: q, types: {'out': 's8', 'in': 's8'}}, return_type: "int8x16_t", args: "[Var:{ type: "uint64_t", ident: "a"}, Var:{ type: "uint64_t", ident: "b"}]" },
 Function:{ intrinsic: "vcreateq_s16", decoded: NEON:{ name: "create", suffix: q, types: {'out': 's16', 'in': 's16'}}, return_type: "int16x8_t", args: "[Var:{ type: "uint64_t", ident: "a"}, Var:{ type: "uint64_t", ident: "b"}]" },
 Function:{ intrinsic: "vcreateq_s32", decoded: NEON:{ name: "create", suff

In [83]:
def _mentions_float(func):
  """True when the return type or any argument type of `func` contains "float"."""
  return "float" in func.return_type or any("float" in arg.type for arg in func.args)


# Functions whose architecture list includes "NEON", split by float usage.
neon_funcs = [f for f in funcs if "NEON" in f.archs]
neon_int_funcs = [f for f in neon_funcs if not _mentions_float(f)]
neon_float_funcs = [f for f in neon_funcs if _mentions_float(f)]

# Everything else, split the same way.
helium_funcs = [f for f in funcs if "NEON" not in f.archs]
helium_int_funcs = [f for f in helium_funcs if not _mentions_float(f)]
helium_float_funcs = [f for f in helium_funcs if _mentions_float(f)]

In [84]:
import pprint
# Wrapper names of the integer Helium functions, in encounter order.
names = [func.name for func in helium_int_funcs]
# dict.fromkeys drops repeats while keeping the first-seen order.
pprint.pp(list(dict.fromkeys(names)))

['create',
 'ddup',
 'dwdup',
 'idup',
 'iwdup',
 'duplicate',
 'reverse_16bit',
 'reverse_32bit',
 'reverse_64bit',
 'uninitialized',
 'cmpeq',
 'cmpne',
 'cmpge',
 'cmpcs',
 'cmpgt',
 'cmphi',
 'cmple',
 'cmplt',
 'min',
 'mina',
 'minv',
 'minav',
 'max',
 'maxa',
 'maxv',
 'maxav',
 'abav',
 'subtract_abs',
 'abs',
 'abs_saturate',
 'adci',
 'adc',
 'add',
 'addlva',
 'addlv',
 'addva',
 'addv',
 'add_halve',
 'add_halve_round',
 'add_saturate',
 'mulh',
 'mullb_poly',
 'mullb_int',
 'mullt_poly',
 'mullt_int',
 'multiply',
 'rmulh',
 'qdmladh',
 'qdmladhx',
 'qrdmladh',
 'qrdmladhx',
 'qdmlah',
 'qrdmlah',
 'qdmlash',
 'qrdmlash',
 'qdmlsdh',
 'qdmlsdhx',
 'qrdmlsdh',
 'qrdmlsdhx',
 'multiply_double_saturate_high',
 'multiply_double_round_saturate_high',
 'qdmullb',
 'qdmullt',
 'mladava',
 'mladav',
 'mladavax',
 'mladavx',
 'mlaldava',
 'mlaldav',
 'mlaldavax',
 'mlaldavx',
 'multiply_add',
 'mlas',
 'mlsdava',
 'mlsdav',
 'mlsdavax',
 'mlsdavx',
 'mlsldava',
 'mlsldav',
 'mlsld

In [85]:
# Keep one Function per distinct declaration string; when a declaration
# repeats, the last one wins (vshll_n duped for some reason).
uniq_funcs = dict((f.decl, f) for f in helium_int_funcs).values()

In [86]:
# NOTE: rebinds `funcs`. From here on it is the de-duplicated view built from
# helium_int_funcs, not the full list parsed from the scraped declarations.
funcs = uniq_funcs

In [87]:
from pprint import pprint

sorted_names = [
  f"{name_map[func.decoded.name]} : {func.decl}" for func in sorted(funcs) if "v7" in func.archs
]

# name_args: "wrapper_name(arg list)" -> every function that would be emitted
#            with exactly that signature.
# missing:   decoded names that have no usable wrapper name yet.
name_args = {}
missing = {}
for func in funcs:
  args = ', '.join([str(arg) for arg in func.args])
  # name_map contains "" placeholders for names still to be chosen. Treat them
  # like absent keys; otherwise the key collapses to "(args)" and unrelated
  # intrinsics with the same parameters get grouped together.
  mapped_name = name_map.get(func.decoded.name)
  if not mapped_name:
    missing[func.decoded.name] = ''
    continue
  name_args.setdefault(f"{mapped_name}({args})", []).append(func)

pprint(missing)

# Signatures shared by more than one function cannot be plain overloads, so
# they are emitted as specialisations of a template declared per signature.
needs_template = {k:v for (k,v) in name_args.items() if len(v) > 1}
needs_template_sig = needs_template.keys()
needs_template_funcs = [f for v in needs_template.values() for f in v]

{'abav': '',
 'adc': '',
 'adci': '',
 'addlv': '',
 'addlva': '',
 'addv': '',
 'addva': '',
 'asrl': '',
 'brsr': '',
 'cmpcs': '',
 'cmpeq': '',
 'cmpge': '',
 'cmpgt': '',
 'cmphi': '',
 'cmple': '',
 'cmplt': '',
 'cmpne': '',
 'ctp16': '',
 'ctp32': '',
 'ctp64': '',
 'ctp8': '',
 'ddup': '',
 'dwdup': '',
 'hcadd_rot270': '',
 'hcadd_rot90': '',
 'idup': '',
 'iwdup': '',
 'ldrb': '',
 'ldrb_gather_offset': '',
 'ldrd_gather_base': '',
 'ldrd_gather_offset': '',
 'ldrd_gather_shifted_offset': '',
 'ldrh': '',
 'ldrh_gather_offset': '',
 'ldrh_gather_shifted_offset': '',
 'ldrw': '',
 'ldrw_gather_base': '',
 'ldrw_gather_offset': '',
 'ldrw_gather_shifted_offset': '',
 'lsll': '',
 'maxa': '',
 'maxav': '',
 'maxv': '',
 'mina': '',
 'minav': '',
 'minv': '',
 'mladav': '',
 'mladava': '',
 'mladavax': '',
 'mladavx': '',
 'mlaldav': '',
 'mlaldava': '',
 'mlaldavax': '',
 'mlaldavx': '',
 'mlas': '',
 'mlsdav': '',
 'mlsdava': '',
 'mlsdavax': '',
 'mlsdavx': '',
 'mlsldav': ''

In [88]:

# Wrapper names that the emitting cell skips entirely.
blacklist = ["move_high_narrow", "multiply_add_fused", "multiply_subtract_fused"]

# Wrapper names that generate_function emits as plain `inline` instead of
# with the `nce` macro.
no_constexpr = (
  [f"load{count}" for count in range(1, 5)]
  + [f"load{count}_duplicate" for count in range(1, 5)]
  + [f"load1_x{count}" for count in range(2, 5)]
  + [f"store{count}" for count in range(1, 5)]
  + [f"store1_x{count}" for count in range(2, 5)]
)

# Attribute text placed in front of every generated wrapper (note the trailing space).
always_inline = "[[gnu::always_inline]] "

def simplify_type(type):
  """Return the text of `type` before its first 'x', followed by "_v".

  For example "uint8x16_t" becomes "uint8_v"; a string without an 'x' is
  used whole.
  """
  base, _, _ = type.partition('x')
  return f"{base}_v"

def generate_function(func):
  """Build the one-line C++ wrapper that forwards `func`'s arguments to its intrinsic.

  The line is prefixed with "template <> " when `func` is listed in
  needs_template_funcs, and uses "inline " instead of the "nce " macro when
  the wrapper name is in no_constexpr.
  """
  params = ', '.join(str(arg) for arg in func.args)
  # '*' is part of a pointer parameter's identifier; drop it for the call.
  call_args = ', '.join(arg.ident.replace('*','') for arg in func.args)
  template_prefix = "template <> " if func in needs_template_funcs else ""
  qualifier = "inline " if func.name in no_constexpr else "nce "
  wrapper = f"{func.return_type} {func.name}({params}) {{ return {func.intrinsic}({call_args}); }}"
  return template_prefix + always_inline + qualifier + wrapper

def generate_templated_function(func):
  """Build the C++ wrapper for an intrinsic whose last parameter is a ``const int``.

  That parameter (`func.const`) becomes an ``int`` template parameter of the
  wrapper and is passed back to the intrinsic as its final argument. The
  remaining parameters come from `func.args`, which no longer contains the
  constant (Function.__init__ already moved it to `func.const`), so no
  filtering of the parameter list is needed here.
  """
  const_name = func.const
  params = ', '.join([str(arg) for arg in func.args])
  # Join runtime identifiers and the constant in one pass so that a function
  # with no runtime arguments yields "f(imm)" rather than "f(, imm)".
  # '*' is part of a pointer parameter's identifier; drop it for the call.
  call_args = ', '.join([arg.ident.replace('*','') for arg in func.args] + [const_name])
  # Trailing space keeps the template header separate from the attribute.
  definition = f"template <int {const_name}> "
  definition += always_inline
  definition += f"nce {func.return_type} {func.name}({params}) {{ return {func.intrinsic}({call_args}); }}"
  return definition


In [89]:
#with open('neon.hpp', 'w') as sys.stdout:
# Header preamble: include guard, MVE header, and the `nce` macro
# (constexpr under clang, inline otherwise).
print('#pragma once')
print('#include <arm_mve.h>')
print('#ifdef __cplusplus')
print('''#ifdef __clang__
#define nce constexpr
#else
#define nce inline
#endif
''')
print('namespace helium {')
print('// clang-format off')
# One primary template declaration per signature shared by several functions.
for name_arg in needs_template:
  print(f"template <typename T> nce T {name_arg};")
# One wrapper per function, in sorted order, skipping blacklisted names.
for func in sorted(funcs):
  if func.name in blacklist:
    continue
  # Functions with a trailing compile-time constant get the templated form.
  emit = generate_function if func.const is None else generate_templated_function
  print(emit(func))
print('// clang-format on')
print('}  // namespace helium')
print('#endif')

#pragma once
#include <arm_mve.h>
#ifdef __cplusplus
#ifdef __clang__
#define nce constexpr
#else
#define nce inline
#endif

namespace helium {
// clang-format off
template <typename T> nce T create(uint64_t a, uint64_t b);
[[gnu::always_inline]] nce uint8x16_t reverse_16bit(uint8x16_t inactive, uint8x16_t a, mve_pred16_t p) { return vrev16q_m_u8(inactive, a, p); }
[[gnu::always_inline]] nce uint8x16_t reverse_32bit(uint8x16_t inactive, uint8x16_t a, mve_pred16_t p) { return vrev32q_m_u8(inactive, a, p); }
[[gnu::always_inline]] nce uint8x16_t reverse_64bit(uint8x16_t inactive, uint8x16_t a, mve_pred16_t p) { return vrev64q_m_u8(inactive, a, p); }
[[gnu::always_inline]] nce mve_pred16_t cmpeq(uint8x16_t a, uint8x16_t b) { return vcmpeqq_u8(a, b); }
[[gnu::always_inline]] nce mve_pred16_t cmpeq(uint8x16_t a, uint8x16_t b, mve_pred16_t p) { return vcmpeqq_m_u8(a, b, p); }
[[gnu::always_inline]] nce mve_pred16_t cmpne(uint8x16_t a, uint8x16_t b) { return vcmpneq_u8(a, b); }
[[gnu::always_