Skip to content

Commit

Permalink
Add fp to smt printer.
Browse files Browse the repository at this point in the history
  • Loading branch information
sjudson authored and ccadar committed Feb 15, 2022
1 parent e091e23 commit aa90765
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 65 deletions.
14 changes: 7 additions & 7 deletions include/klee/util/ExprSMTLIBPrinter.h
Expand Up @@ -116,8 +116,8 @@ class ExprSMTLIBPrinter {
ABBR_NAMED ///< Abbreviate with :named annotations.
};

/// Different supported SMTLIBv2 sorts (a.k.a type) in QF_AUFBV
enum SMTLIB_SORT { SORT_BITVECTOR, SORT_BOOL };
/// Different supported SMTLIBv2 sorts (a.k.a type) in QF_FPABV
enum SMTLIB_SORT { SORT_BITVECTOR, SORT_BOOL, SORT_FP };

/// Allows the way Constant bitvectors are printed to be changed.
/// This setting is persistent across queries.
Expand Down Expand Up @@ -238,7 +238,7 @@ class ExprSMTLIBPrinter {
SMTLIB_SORT getSort(const ref<Expr> &e);

/// Print an expression but cast it to a particular SMTLIBv2 sort first.
void printCastToSort(const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT sort);
void printCastToSort(const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT sort, ExprSMTLIBPrinter::SMTLIB_SORT constSort = SORT_BITVECTOR);

// Resets various internal objects for a new query
void reset();
Expand Down Expand Up @@ -278,14 +278,14 @@ class ExprSMTLIBPrinter {

/// Print a Constant in the format specified by the current "Constant Display
/// Mode"
void printConstant(const ref<ConstantExpr> &e);
void printConstant(const ref<ConstantExpr> &e, ExprSMTLIBPrinter::SMTLIB_SORT constSort = SORT_BITVECTOR);

/// Recursively print expression
/// \param e is the expression to print
/// \param expectedSort is the sort we want. If "e" is not of the right type a
/// cast will be performed.
/// \param abbrMode the abbreviation mode to use for this expression
void printExpression(const ref<Expr> &e, SMTLIB_SORT expectedSort);
void printExpression(const ref<Expr> &e, SMTLIB_SORT expectedSort, ExprSMTLIBPrinter::SMTLIB_SORT constSort = SORT_BITVECTOR);

/// Scan Expression recursively for Arrays in expressions. Found arrays are
/// added to
Expand Down Expand Up @@ -323,7 +323,7 @@ class ExprSMTLIBPrinter {

// For the set of operators that take sort "s" arguments
void printSortArgsExpr(const ref<Expr> &e,
ExprSMTLIBPrinter::SMTLIB_SORT s);
ExprSMTLIBPrinter::SMTLIB_SORT s, ExprSMTLIBPrinter::SMTLIB_SORT c = SORT_BITVECTOR);

/// For the set of operators that come in two sorts (e.g. (and () ()) (bvand
/// () ()) )
Expand Down Expand Up @@ -369,7 +369,7 @@ class ExprSMTLIBPrinter {
getSMTLIBOptionString(ExprSMTLIBPrinter::SMTLIBboolOptions option);

/// Print expression without top-level abbreviations
void printFullExpression(const ref<Expr> &e, SMTLIB_SORT expectedSort);
void printFullExpression(const ref<Expr> &e, SMTLIB_SORT expectedSort, ExprSMTLIBPrinter::SMTLIB_SORT constSort = SORT_BITVECTOR);

/// Print an assert statement for the given expr.
void printAssert(const ref<Expr> &e);
Expand Down
225 changes: 167 additions & 58 deletions lib/Expr/ExprSMTLIBPrinter.cpp
Expand Up @@ -10,6 +10,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "klee/util/ExprSMTLIBPrinter.h"
#include "klee/util/APFloatEval.h"

#include <stack>

Expand Down Expand Up @@ -106,7 +107,7 @@ bool ExprSMTLIBPrinter::setConstantDisplayMode(ConstantDisplayMode cdm) {
return true;
}

void ExprSMTLIBPrinter::printConstant(const ref<ConstantExpr> &e) {
void ExprSMTLIBPrinter::printConstant(const ref<ConstantExpr> &e, ExprSMTLIBPrinter::SMTLIB_SORT c) {
/* Handle simple boolean constants */

if (e->isTrue()) {
Expand All @@ -130,6 +131,27 @@ void ExprSMTLIBPrinter::printConstant(const ref<ConstantExpr> &e) {
*/
unsigned int zeroPad = 0;

/* Special case to print FP constants */
if (c == SORT_FP) {
e->toString(value, 10);

std::string delimit = "E+";
if (value.find(delimit) == std::string::npos) {
unsigned long long i = std::strtoull(value.c_str(), NULL, 10);
double d;

memcpy(&d, &i, sizeof(&d));

llvm::APFloat asF(d);
llvm::SmallVector<char, 16> result;
asF.toString(result, /*FormatPrecision=*/0, /*FormatMaxPadding=*/0);
value = std::string(result.begin(), result.end());
}

*p << "((_ to_fp 11 53) RNE " << value.substr(0, value.find(delimit)) << " " << value.substr(value.find(delimit) + 2) << ")";
return;
}

switch (cdm) {
case BINARY:
e->toString(value, 2);
Expand All @@ -156,6 +178,7 @@ void ExprSMTLIBPrinter::printConstant(const ref<ConstantExpr> &e) {

case DECIMAL:
e->toString(value, 10);

*p << "(_ bv" << value << " " << e->getWidth() << ")";
break;

Expand All @@ -164,11 +187,10 @@ void ExprSMTLIBPrinter::printConstant(const ref<ConstantExpr> &e) {
}
}

void ExprSMTLIBPrinter::printExpression(
const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT expectedSort) {
void ExprSMTLIBPrinter::printExpression(const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT expectedSort, ExprSMTLIBPrinter::SMTLIB_SORT constSort) {
// check if casting might be necessary
if (getSort(e) != expectedSort) {
printCastToSort(e, expectedSort);
printCastToSort(e, expectedSort, constSort);
return;
}

Expand All @@ -190,7 +212,7 @@ void ExprSMTLIBPrinter::printExpression(
if (i != bindings.end()) {
if (i->second > 0) {
*p << "(! ";
printFullExpression(e, expectedSort);
printFullExpression(e, expectedSort, constSort);
*p << " :named ?B" << i->second << ")";
i->second = -i->second;
} else {
Expand All @@ -202,14 +224,13 @@ void ExprSMTLIBPrinter::printExpression(
}
}

printFullExpression(e, expectedSort);
printFullExpression(e, expectedSort, constSort);
}

void ExprSMTLIBPrinter::printFullExpression(
const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT expectedSort) {
void ExprSMTLIBPrinter::printFullExpression(const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT expectedSort, ExprSMTLIBPrinter::SMTLIB_SORT constSort) {
switch (e->getKind()) {
case Expr::Constant:
printConstant(cast<ConstantExpr>(e));
printConstant(cast<ConstantExpr>(e), constSort);
return; // base case

case Expr::NotOptimized:
Expand Down Expand Up @@ -261,6 +282,24 @@ void ExprSMTLIBPrinter::printFullExpression(
printAShrExpr(cast<AShrExpr>(e));
return;

case Expr::FOEq:
case Expr::FOLt:
case Expr::FOLe:
case Expr::FOGt:
case Expr::FOGe:
case Expr::FAdd:
case Expr::FSub:
case Expr::FMul:
case Expr::FDiv:
case Expr::IsNaN:
case Expr::IsInfinite:
case Expr::IsNormal:
case Expr::IsSubnormal:
case Expr::FSqrt:
case Expr::FAbs:
printSortArgsExpr(e, SORT_FP, SORT_FP); // if one of the operands to an FP op is a constant, we need to print it as an FP
return;

default:
/* The remaining operators (Add,Sub...,Ult,Ule,..)
* Expect SORT_BITVECTOR arguments
Expand Down Expand Up @@ -468,6 +507,39 @@ const char *ExprSMTLIBPrinter::getSMTLIBKeyword(const ref<Expr> &e) {
case Expr::Sge:
return "bvsge";

case Expr::FOEq:
return "fp.eq";
case Expr::FOLt:
return "fp.lt";
case Expr::FOLe:
return "fp.leq";
case Expr::FOGt:
return "fp.gt";
case Expr::FOGe:
return "fp.geq";

case Expr::FAdd:
return "fp.add RNE";
case Expr::FSub:
return "fp.sub RNE";
case Expr::FMul:
return "fp.mul RNE";
case Expr::FDiv:
return "fp.div RNE";

case Expr::IsNaN:
return "fp.isNaN";
case Expr::IsInfinite:
return "fp.isInfinite";
case Expr::IsNormal:
return "fp.isNormal";
case Expr::IsSubnormal:
return "fp.isSubnormal";
case Expr::FSqrt:
return "fp.sqrt RNE";
case Expr::FAbs:
return "fp.abs RNE";

default:
llvm_unreachable("Conversion from Expr to SMTLIB keyword failed");
}
Expand Down Expand Up @@ -541,10 +613,10 @@ void ExprSMTLIBPrinter::printSetLogic() {
*o << "(set-logic ";
switch (logicToUse) {
case QF_ABV:
*o << "QF_ABV";
*o << "ALL";
break;
case QF_AUFBV:
*o << "QF_AUFBV";
*o << "ALL";
break;
}
*o << " )\n";
Expand Down Expand Up @@ -946,8 +1018,26 @@ ExprSMTLIBPrinter::SMTLIB_SORT ExprSMTLIBPrinter::getSort(const ref<Expr> &e) {
case Expr::Ule:
case Expr::Ugt:
case Expr::Uge:
case Expr::FOEq:
case Expr::FOLt:
case Expr::FOLe:
case Expr::FOGt:
case Expr::FOGe:
case Expr::IsNaN:
case Expr::IsInfinite:
case Expr::IsNormal:
case Expr::IsSubnormal:
return SORT_BOOL;

// Float ops return floats.
case Expr::FAdd:
case Expr::FSub:
case Expr::FMul:
case Expr::FDiv:
case Expr::FSqrt:
case Expr::FAbs:
return SORT_FP;

// These may be bitvectors or bools depending on their width (see
// printConstant and printLogicalOrBitVectorExpr).
case Expr::Constant:
Expand All @@ -964,56 +1054,76 @@ ExprSMTLIBPrinter::SMTLIB_SORT ExprSMTLIBPrinter::getSort(const ref<Expr> &e) {
}

void ExprSMTLIBPrinter::printCastToSort(const ref<Expr> &e,
ExprSMTLIBPrinter::SMTLIB_SORT sort) {
ExprSMTLIBPrinter::SMTLIB_SORT sort, ExprSMTLIBPrinter::SMTLIB_SORT constSort) {

switch (sort) {
case SORT_BITVECTOR:
if (humanReadable) {
p->breakLineI();
*p << ";Performing implicit bool to bitvector cast";
p->breakLine();
case SORT_FP: {
if (e->getKind() == SORT_BITVECTOR) {
// if the internal expression is a constant we'll let the constSort take care of it so we can skip the cast
if ( e->getKind() != Expr::Constant ) *p << "((_ to_fp 11 53) RNE ";
printExpression(e, SORT_BITVECTOR, constSort);
if ( e->getKind() != Expr::Constant ) *p << ")";
} else { // SORT_BOOL
llvm_unreachable("Unsupported cast");
}
// We assume the e is a bool that we need to cast to a bitvector sort.
*p << "(ite";
p->pushIndent();
printSeperator();
printExpression(e, SORT_BOOL);
printSeperator();
*p << "(_ bv1 1)";
printSeperator(); // printing the "true" bitvector
*p << "(_ bv0 1)";
p->popIndent();
printSeperator(); // printing the "false" bitvector
*p << ")";
break;
case SORT_BOOL: {
/* We make the assumption (might be wrong) that any bitvector whose unsigned
* decimal value is is zero is interpreted as "false", otherwise it is
* true.
*
* This may not be the interpretation we actually want!
*/
Expr::Width bitWidth = e->getWidth();
if (humanReadable) {
p->breakLineI();
*p << ";Performing implicit bitvector to bool cast";
p->breakLine();
} break;
case SORT_BITVECTOR: {
if (e->getKind() == SORT_BOOL) {
if (humanReadable) {
p->breakLineI();
*p << ";Performing implicit bool to bitvector cast";
p->breakLine();
}
// We assume the e is a bool that we need to cast to a bitvector sort.
*p << "(ite";
p->pushIndent();
printSeperator();
printExpression(e, SORT_BOOL, constSort);
printSeperator();
*p << "(_ bv1 1)";
printSeperator(); // printing the "true" bitvector
*p << "(_ bv0 1)";
p->popIndent();
printSeperator(); // printing the "false" bitvector
*p << ")";
} else { // SORT_FP
*p << "((_ fp.to_ubv 64) RNE ";
printExpression(e, SORT_FP, constSort);
*p << ")";
}
*p << "(bvugt";
p->pushIndent();
printSeperator();
// We assume is e is a bitvector
printExpression(e, SORT_BITVECTOR);
printSeperator();
*p << "(_ bv0 " << bitWidth << ")";
p->popIndent();
printSeperator(); // Zero bitvector of required width
*p << ")";
} break;
case SORT_BOOL: {
if (e->getKind() == SORT_BITVECTOR) {
/* We make the assumption (might be wrong) that any bitvector whose unsigned
* decimal value is is zero is interpreted as "false", otherwise it is
* true.
*
* This may not be the interpretation we actually want!
*/
Expr::Width bitWidth = e->getWidth();
if (humanReadable) {
p->breakLineI();
*p << ";Performing implicit bitvector to bool cast";
p->breakLine();
}
*p << "(bvugt";
p->pushIndent();
printSeperator();
// We assume is e is a bitvector
printExpression(e, SORT_BITVECTOR, constSort);
printSeperator();
*p << "(_ bv0 " << bitWidth << ")";
p->popIndent();
printSeperator(); // Zero bitvector of required width
*p << ")";

if (bitWidth != Expr::Bool)
llvm::errs()
if (bitWidth != Expr::Bool)
llvm::errs()
<< "ExprSMTLIBPrinter : Warning. Casting a bitvector (length "
<< bitWidth << ") to bool!\n";

} else { // SORT_FP
llvm_unreachable("Unsupported cast");
}
} break;
default:
llvm_unreachable("Unsupported cast");
Expand Down Expand Up @@ -1048,15 +1158,14 @@ void ExprSMTLIBPrinter::printSelectExpr(const ref<SelectExpr> &e,
*p << ")";
}

void ExprSMTLIBPrinter::printSortArgsExpr(const ref<Expr> &e,
ExprSMTLIBPrinter::SMTLIB_SORT s) {
void ExprSMTLIBPrinter::printSortArgsExpr(const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT s, ExprSMTLIBPrinter::SMTLIB_SORT c) {
*p << "(" << getSMTLIBKeyword(e) << " ";
p->pushIndent(); // add indent for recursive call

// loop over children and recurse into each expecting they are of sort "s"
for (unsigned int i = 0; i < e->getNumKids(); i++) {
printSeperator();
printExpression(e->getKid(i), s);
printExpression(e->getKid(i), s, c);
}

p->popIndent(); // pop indent added for recursive call
Expand Down

0 comments on commit aa90765

Please sign in to comment.