In [86]:
import os 
import re
import datetime
from random import randint, random
import readline
from time import sleep
import json
from textwrap import dedent

In [409]:
import numpy as np
from math import ceil

In [533]:
DOC_PATH = os.path.join('..','rvv-intrinsic-doc-master','intrinsic_funcs')
VLEN = 128

def random_line():
    md_files = os.listdir(DOC_PATH)
    filename = [f for f in md_files if f.startswith("05")][0]
    filename = os.path.join(DOC_PATH,filename)
    filesize = os.stat(filename).st_size
    with open(filename,'r') as f:
        ret = ""
        while not re.findall(';\s+$',ret):
            file_pointer = randint(0,filesize)
            f.seek(file_pointer)
            f.readline()
            ret =  f.readline()
        return ret

In [786]:
class RvvType():
    SIZE_DICT = {'mf8':1/8,'mf4':1/4,'mf2':1/2,'m1':1,'m2':2,'m4':4,'m8':8,}
    attr = ['basetype','sew','lmul','vl','bytesize','abbr']
    def __init__(self, string):
        if not string.startswith('v'):
            raise RuntimeError("%s not a rvv type" %string)        
        self.string = string
        self._parse()
        assert all([getattr(self,attr,None) for attr in self.attr]), "not all attr defined"
        
    def __eq__(self,other):
        return self.string == other.string
    def __hash__(self):
        return hash(self.string)
    
    def __repr__(self):
        dtype = 'rvv vector type %s:\n' %self.string
        attrs = '\n'.join([name+':'+ str(getattr(self,name,None)) for name in self.attr])
        return dtype+attrs
        
    def _abbr_rule(self,match):
        dtype = match.group(0)
        if dtype=='uint':
            return 'u'
        elif dtype=='int':
            return 'i'
        elif dtype=='float':
            return 'f'     
        
    def _parse(self):
        match = re.search('(?P<basetype>(u?int\d+)|(float\d+))(?P<lmul>mf?\d)',self.string)
        self.lmul = lmul = match.group('lmul')
        regsize = self.SIZE_DICT.get(lmul)*VLEN
        regsize = int(regsize)
        self.bytesize = regsize//8
        basetype = match.group('basetype')
        self.abbr = re.sub('(?P<match>(u?int)|(float))',self._abbr_rule,basetype)+lmul
        self.basetype = basetype+'_t'
        self.sew = basesize = re.search('(?P<size>\d+)',basetype).group('size')
        self.vl = int(regsize)//int(basesize)
    
    @property
    def declare(self):
        return dedent('''\
        {basetype} *{{op}}_base = ({basetype}*){{addr}};
        '''.format(basetype=self.basetype))   
    
    @property
    def context(self):
        return dedent('''\
        //context for {{op}}
        {} {{op}} = vle{}_v_{}({{op}}_base,{});
        '''.format(self.string, self.sew, self.abbr,self.vl))
     

In [787]:
class RvvBool(RvvType):
    LMUL_DICT = {v:k for k,v in RvvType.SIZE_DICT.items()}
    attr = ['basetype','n','vl','bytesize','abbr']
    def __init__(self,string):
        RvvType.__init__(self,string)
        
    def _parse(self):
        self.n = n = re.search('(?P<n>\d+)',self.string).group('n')
        self.abbr = 'b%s' %n
        self.vl = vlmax = VLEN//int(n)
        self.bytesize = vlmax//8 + int((vlmax%8)!=0)
        self.basetype = 'uint8_t'
    
    @property
    def context(self):
        lmul = self.vl*8/VLEN
        lmul = self.LMUL_DICT[lmul]
        return dedent('''\
        //context for {{op}}
        {vbool_t} {{op}} = vlm_v_b{n}({{op}}_base,{vl});   
        uint8{lmul}_t vec_{{op}} = vmv_v_x_u8{lmul}(0,{vl});
        vec_{{op}} = vmerge_vxm({{op}},vec_{{op}},1,{vl});
        _Bool {{op}}_bool[{vl}];
        vse8_v_u8{lmul}({{op}}_bool,vec_{{op}},{vl});
        '''.format(lmul=lmul,vl=self.vl,vbool_t = self.string,n=self.n))
    

In [795]:
class ScalarType():
    TYPEDEF = {'size_t':'uint32_t',}
    def __init__(self,string):
        if string.startswith('v'):
            raise RuntimeError("%s is not a scalar type" %string)
        self.string = string
        self.basetype = string
        self.alias = self.TYPEDEF.get(string,string)
        self.bytesize = self._sizeof(self.alias)
        self.attr = ['bytesize','alias']
        
    def __repr__(self):
        dtype = 'scalar type %s:\n' %self.string
        attrs = '\n'.join([name+':'+ str(getattr(self,name,None)) for name in self.attr])
        return dtype+attrs
    
    def _sizeof(self,typename):
        return np.dtype(typename.rstrip('_t')).itemsize
    
    @property
    def declare(self):
        return dedent('''\
        {dtype} *{{op}} = ({dtype}*){{addr}};
        '''.format(dtype=self.string))
    @property
    def context(self):
        return "//nocontext for {op}\n"

In [788]:
def parse_dtype(dtype:str):
    if dtype.startswith('vbool'):
        return RvvBool(dtype)
    elif dtype.startswith('v'):
        return RvvType(dtype)
    else:
        return ScalarType(dtype)
    
def parse_declaration(line):
    ret, func, rest_of_line = line.split(" ", 2)
    match = re.search('(?P<ops>(?<=\().+(?=\)))',rest_of_line)
    op_pairs = match.group('ops').split(',')
    op_pairs = [op.split() for op in op_pairs]
    operands = {arg:dtype for dtype,arg in op_pairs}
    #return {'ret':ret,'func':func,'ops':operands}
    return [ret,func,operands]

In [810]:
line = random_line()
line

'vbool8_t vmseq_vv_u32m4_b8 (vuint32m4_t op1, vuint32m4_t op2, size_t vl);\n'

In [811]:
ret,func,ops = parse_declaration(line)

In [812]:
class AddrDispensor():
    def __init__(self,addr_begin:int):
        self.begin = addr_begin
        self.addr = addr_begin
        
    def get_addr(self,len_btye):
        assert (len_btye>0), 'addr increment cannot <=0'
        addr = self.addr
        self.addr += len_btye
        self.addr = ceil(self.addr/8)*8
        return addr

In [813]:
ad = AddrDispensor(0x2000)

In [814]:
ret = parse_dtype(ret)
str_declare = ''
str_context = ''
str_declare += dedent('''\
    {dtype} *golden = ({dtype}*){addr_g};
    {dtype} *actual = ({dtype}*){addr_a};
    '''.format(dtype=ret.basetype,
               addr_g=addr.get_addr(ret.bytesize),
               addr_a=addr.get_addr(ret.bytesize)))
str_context += ret.context.format(op='vec_actual')

In [815]:
for name,dtype in ops.items():
    dtype = parse_dtype(dtype)
    str_declare += dtype.declare.format(op=name,addr=ad.get_addr(dtype.bytesize))
    str_context += dtype.context.format(op=name)


In [816]:
print(line)
print(str_declare)
print(str_context)

vbool8_t vmseq_vv_u32m4_b8 (vuint32m4_t op1, vuint32m4_t op2, size_t vl);

uint8_t *golden = (uint8_t*)8912;
uint8_t *actual = (uint8_t*)8920;
uint32_t *op1_base = (uint32_t*)8192;
uint32_t *op2_base = (uint32_t*)8256;
size_t *vl = (size_t*)8320;

//context for vec_actual
vbool8_t vec_actual = vlm_v_b8(vec_actual_base,16);   
uint8m1_t vec_vec_actual = vmv_v_x_u8m1(0,16);
vec_vec_actual = vmerge_vxm(vec_actual,vec_vec_actual,1,16);
_Bool vec_actual_bool[16];
vse8_v_u8m1(vec_actual_bool,vec_vec_actual,16);
//context for op1
vuint32m4_t op1 = vle32_v_u32m4(op1_base,16);
//context for op2
vuint32m4_t op2 = vle32_v_u32m4(op2_base,16);
//nocontext for vl



In [820]:
C_TEMPLATE=dedent('''
    #include "string.h"
    #include "riscv_vector.h"
    int main(){{
        {}{}{}
        return 0;
    }}
    ''')
print(C_TEMPLATE.format(str_declare,str_context,func))


#include "string.h"
#include "riscv_vector.h"
int main(){
    uint8_t *golden = (uint8_t*)8912;
uint8_t *actual = (uint8_t*)8920;
uint32_t *op1_base = (uint32_t*)8192;
uint32_t *op2_base = (uint32_t*)8256;
size_t *vl = (size_t*)8320;
//context for vec_actual
vbool8_t vec_actual = vlm_v_b8(vec_actual_base,16);   
uint8m1_t vec_vec_actual = vmv_v_x_u8m1(0,16);
vec_vec_actual = vmerge_vxm(vec_actual,vec_vec_actual,1,16);
_Bool vec_actual_bool[16];
vse8_v_u8m1(vec_actual_bool,vec_vec_actual,16);
//context for op1
vuint32m4_t op1 = vle32_v_u32m4(op1_base,16);
//context for op2
vuint32m4_t op2 = vle32_v_u32m4(op2_base,16);
//nocontext for vl
vmseq_vv_u32m4_b8
    return 0;
}

