Skip to content

Commit

Permalink
Merge pull request #8 from federicorossifr/feat_logtab_backend
Browse files Browse the repository at this point in the history
Feat logtab backend
  • Loading branch information
federicorossifr committed Aug 5, 2023
2 parents 689f1ce + 0b4e6ef commit 2a38536
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 25 deletions.
22 changes: 16 additions & 6 deletions examples/09_tabulated/09_tabulated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,30 @@ using namespace posit;
#include <iomanip>


template <class T,int nbits, int esbits>
template <class T,int nbits, int esbits, bool log>
struct PositTable;

#include "logptab4_0.h"
#include "ptab4_0.h"



int main() {
using PPEMU = Posit<int32_t, 4, 0, uint32_t, PositSpec::WithInfs>;
using PP = Posit<int8_t, 4, 0, TabulatedBackend<int8_t,PPEMU,PositTable<int8_t,4,0>>, PositSpec::WithInfs>;
using PTT = PositTableTrait<int8_t,4,0,false>;
using LogPTT = PositTableTrait<int8_t,4,0,true>;

using PTable = PositTable<int8_t,4,0,false>;
using LogPTable = PositTable<int8_t,4,0,true>;

using PP = Posit<int8_t, 4, 0, TabulatedBackend<PTT,PPEMU,PTable>, PositSpec::WithInfs>;
using LogPP = Posit<int8_t, 4, 0, TabulatedBackend<LogPTT,PPEMU,LogPTable>, PositSpec::WithInfs>;

PP a(0.5f), b(0.25f);
std::cout << a+b << std::endl;
std::cout << a*b << std::endl;
std::cout << a-b << std::endl;
std::cout << a/b << std::endl;
LogPP a2(0.5f), b2(0.25f);
std::cout << a+b << " " << a2+b2 << std::endl;
std::cout << a*b << " " << a2*b2 << std::endl;
std::cout << a-b << " " << a2-b2 << std::endl;
std::cout << a/b << " " << a2/b2 << std::endl;
return 0;
}
10 changes: 9 additions & 1 deletion examples/09_tabulated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@ add_custom_command(
DEPENDS ${PROJECT_SOURCE_DIR}/scripts/generatePositCppTables.py
)

add_executable(example_09_tabulated ${CMAKE_CURRENT_SOURCE_DIR}/09_tabulated.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ptab4_0.h)
add_custom_command(
OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/logptab4_0.h
COMMAND python ${PROJECT_SOURCE_DIR}/scripts/generatePositCppTables.py
-n 4 -e 0 --log
-o ${CMAKE_CURRENT_SOURCE_DIR}/logptab4_0.h
DEPENDS ${PROJECT_SOURCE_DIR}/scripts/generatePositCppTables.py
)

add_executable(example_09_tabulated ${CMAKE_CURRENT_SOURCE_DIR}/09_tabulated.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ptab4_0.h ${CMAKE_CURRENT_SOURCE_DIR}/logptab4_0.h)
45 changes: 37 additions & 8 deletions include/backends/tabback.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,28 @@
* --
*/


#include <type_traits>

namespace posit {

template <class T>
struct is_posit_backend;

template <class T, class PositEmu, class PositTable>
template <class T, int n, int e, bool log>
struct PositTableTrait {
using type = T;
using nbits = std::integral_constant<int,n>;
using esbits = std::integral_constant<int,e>;
using isLogtab = std::integral_constant<bool,log>;
};


template <class PTableTrait, class PositEmu, class PTable>
struct TabulatedBackend: public HwBaseBackend
{
struct single_tag{};
constexpr static T indexMask = (1<<PositEmu::vtotalbits)-1;
using T = PTableTrait::type;
constexpr static T indexMask = (1<<PTableTrait::nbits::value)-1;
TabulatedBackend() {}
TabulatedBackend(single_tag, T x): v(x) {}

Expand All @@ -34,20 +44,39 @@ namespace posit {
constexpr operator long () const {return (long)PositEmu::from_sraw(v);}

TabulatedBackend operator + (TabulatedBackend o) const {
return TabulatedBackend{{},PositTable::add[v & indexMask][o.v & indexMask]};
return TabulatedBackend{{},PTable::add[v & indexMask][o.v & indexMask]};
}
TabulatedBackend operator * (TabulatedBackend o) const {
return TabulatedBackend{{},PositTable::mul[v & indexMask][o.v & indexMask]};
if constexpr (PTableTrait::isLogtab::value) {
T idxa = v & indexMask, idxb = o.v & indexMask;
T sign = (idxa >> (PTableTrait::nbits::value-1)) ^ (idxb >> (PTableTrait::nbits::value-1));
T logA = PTable::log[idxa], logB = PTable::log[idxb];
T logAB = PTable::add[logA & indexMask][logB & indexMask];
T expAB = PTable::exp[logAB];
if(sign) expAB = - expAB;
return TabulatedBackend{{},expAB};
} else {
return TabulatedBackend{{},PTable::mul[v & indexMask][o.v & indexMask]};
}
}
TabulatedBackend operator / (TabulatedBackend o) const {
return TabulatedBackend{{},PositTable::div[v & indexMask][o.v & indexMask]};
if constexpr (PTableTrait::isLogtab::value) {
T idxa = v & indexMask, idxb = o.v & indexMask;
T sign = (idxa >> (PTableTrait::nbits::value-1)) ^ (idxb >> (PTableTrait::nbits::value-1));
T logA = PTable::log[idxa], logB = -PTable::log[idxb];
T logAB = PTable::add[logA & indexMask][logB & indexMask];
T expAB = PTable::exp[logAB];
if(sign) expAB = - expAB;
return TabulatedBackend{{},expAB};
} else
return TabulatedBackend{{},PTable::div[v & indexMask][o.v & indexMask]};
}
T v;

};

template <class T, class PositEmu, class PositTable>
struct is_posit_backend<TabulatedBackend<T,PositEmu,PositTable> >: public std::true_type
template <class T, class PositEmu, class PTable>
struct is_posit_backend<TabulatedBackend<T,PositEmu,PTable> >: public std::true_type
{
};

Expand Down
89 changes: 79 additions & 10 deletions scripts/generatePositCppTables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
from hardposit import from_bits, from_double
import argparse
import operator
import math

def demologexp(idxa,idxb,nbits,esbits):
logtab = generate1DLogTable(6,2)
exptab = generate1DExpTable(6,2)

a = from_bits(idxa,nbits,esbits)
b = from_bits(idxb,nbits,esbits)
sign = 1 if (a.eval() > 0 and b.eval() > 0) or (a.eval() < 0 and b.eval() < 0) else -1
print("a*b=",(a*b).eval())

alog = logtab[a.to_bits()]
blog = logtab[b.to_bits()]

print("log(a) = ",alog.eval()," log(b) = ",blog.eval())

ablog = alog + blog

print("log(a) + log(b) = ",ablog.eval())

expablog = exptab[ablog.to_bits()]

print("exp(log(a)+log(b))=",expablog.eval())

print("sign * exp(log(a)+log(b))=",sign*expablog.eval())

def generateOpTable(nbits, esbits,op):
posit_range = range(2**nbits)
Expand All @@ -14,8 +38,36 @@ def generateOpTable(nbits, esbits,op):

return table


def formatToCpp(table_2d, table_name, base_type):

def generate1DLogTable(nbits, esbits):
posit_range = range(2**nbits)
table = [0]*(2**nbits)

for i in posit_range:
posit_i = from_bits(i,nbits,esbits)
if posit_i.eval() == 0:
table[i] = from_bits(2**(nbits-1),nbits,esbits)
continue
sign = -1 if posit_i.eval() < 0 else 1
val = posit_i.eval() if sign > 0 else -posit_i.eval()
table[i] = from_double(math.log(val),nbits,esbits)
return table

def generate1DExpTable(nbits, esbits):
posit_range = range(2**nbits)
table = [0]*(2**nbits)
max_pos = from_bits(2**(nbits-1) - 1,nbits,esbits)
max_log = from_double(2*math.log(max_pos.eval()),nbits,esbits)
for i in posit_range:
posit_i = from_bits(i,nbits,esbits)
if math.fabs(posit_i.eval()) > max_log.eval():
table[i] = from_bits(2**(nbits-1),nbits,esbits)
continue
val = posit_i.eval()
table[i] = from_double(math.exp(val),nbits,esbits)
return table

def format2DToCpp(table_2d, table_name, base_type):
linear_size = len(table_2d)
declaration = f'static constexpr {base_type} {table_name}[{linear_size}][{linear_size}] = ';
content = "{";
Expand All @@ -33,6 +85,17 @@ def formatToCpp(table_2d, table_name, base_type):
content = content + "}"
return declaration+content

def format1DToCpp(table_1d,table_name, base_type):
linear_size = len(table_1d)
declaration = f'static constexpr {base_type} {table_name}[{linear_size}] = ';
content = "{";
for i in range(linear_size):
p_res = int(table_1d[i].to_hex(),base=16)
content = content + str(p_res)
if i < linear_size - 1:
content = content+","
content = content+"}"
return declaration+content

def getType(nbits):
if nbits <= 8:
Expand All @@ -41,13 +104,17 @@ def getType(nbits):
return "int16_t"


def formatStructOps(nbits,esbits):
declaration = f'template<> struct PositTable<{getType(nbits)},{nbits},{esbits}>'
def formatStructOps(nbits,esbits,logtab):
declaration = f'template<> struct PositTable<{getType(nbits)},{nbits},{esbits},{"true" if logtab else "false"}>'
declaration = declaration+"{"
declaration = declaration+formatToCpp(generateOpTable(nbits,esbits,operator.mul),"mul",getType(nbits))+";\n"
declaration = declaration+formatToCpp(generateOpTable(nbits,esbits,operator.add),"add",getType(nbits))+";\n"
declaration = declaration+formatToCpp(generateOpTable(nbits,esbits,operator.sub),"sub",getType(nbits))+";\n"
declaration = declaration+formatToCpp(generateOpTable(nbits,esbits,operator.truediv),"div",getType(nbits))+";\n"
declaration = declaration+format2DToCpp(generateOpTable(nbits,esbits,operator.add),"add",getType(nbits))+";\n"
declaration = declaration+format2DToCpp(generateOpTable(nbits,esbits,operator.sub),"sub",getType(nbits))+";\n"
if not logtab:
declaration = declaration+format2DToCpp(generateOpTable(nbits,esbits,operator.truediv),"div",getType(nbits))+";\n"
declaration = declaration+format2DToCpp(generateOpTable(nbits,esbits,operator.mul),"mul",getType(nbits))+";\n"
else:
declaration = declaration+format1DToCpp(generate1DLogTable(nbits,esbits),"log",getType(nbits))+";\n"
declaration = declaration+format1DToCpp(generate1DExpTable(nbits,esbits),"exp",getType(nbits))+";\n"

declaration = declaration+"};"
return declaration
Expand All @@ -58,15 +125,17 @@ def formatStructOps(nbits,esbits):

parser.add_argument('-n',help='Number of Posit bits',metavar='nbits',type=int,required=True)
parser.add_argument('-e',help="Number of exponent bits",metavar='esbits',type=int,required=True)
parser.add_argument('--log',help="Number of exponent bits",action='store_true',required=False,default=False)
parser.add_argument('-o',help="Output file",metavar="out.cpp",required=True)

args = parser.parse_args()

nbits = args.n
esbits = args.e
outfile = args.o
logtab = args.log

print(f'Generating source header for table: n={nbits}, esbits={esbits}')
out=formatStructOps(nbits,esbits)
out=formatStructOps(nbits,esbits,logtab)

with open(outfile,"w") as f:
f.write(out)
Expand Down

0 comments on commit 2a38536

Please sign in to comment.