-
Notifications
You must be signed in to change notification settings - Fork 7
/
if_conversion_handler.cc
97 lines (79 loc) · 4.16 KB
/
if_conversion_handler.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include "if_conversion_handler.h"
#include <iostream>
#include "third_party/assert_exception.h"
#include "pkt_func_transform.h"
#include "clang_utility_functions.h"
using namespace clang;
using std::placeholders::_1;
using std::placeholders::_2;
std::string IfConversionHandler::transform(const TranslationUnitDecl * tu_decl) {
unique_identifiers_ = UniqueIdentifiers(identifier_census(tu_decl));
return pkt_func_transform(tu_decl, std::bind(& IfConversionHandler::if_convert_body, this, _1, _2));
}
/// Body transformer handed to pkt_func_transform: if-converts one function body.
///
/// @param function_body  Body to convert; must be non-null.
/// @param pkt_name       Name of the packet variable; condition temporaries are
///                       accessed as `<pkt_name>.<temp>`.
/// @return  First: the converted body wrapped in a pair of braces.
///          Second: one `<type> <name>;` declaration per condition temporary
///          created during conversion.
std::pair<std::string, std::vector<std::string>> IfConversionHandler::if_convert_body(const Stmt * function_body, const std::string & pkt_name) const {
  assert_exception(function_body);

  std::string converted_body;
  std::vector<std::string> temp_decls;

  // Top-level statements execute unconditionally; "1" is how C spells true.
  const std::string always_true = "1";
  if_convert(converted_body, temp_decls, always_true, function_body, pkt_name);

  const std::string braced_body = "{" + converted_body + "}";
  return std::make_pair(braced_body, temp_decls);
}
void IfConversionHandler::if_convert(std::string & current_stream,
std::vector<std::string> & current_decls,
const std::string & predicate,
const Stmt * stmt,
const std::string & pkt_name) const {
if (isa<CompoundStmt>(stmt)) {
for (const auto & child : stmt->children()) {
if_convert(current_stream, current_decls, predicate, child, pkt_name);
}
} else if (isa<IfStmt>(stmt)) {
const auto * if_stmt = dyn_cast<IfStmt>(stmt);
if (if_stmt->getConditionVariableDeclStmt()) {
throw std::logic_error("We don't yet handle declarations within the test portion of an if\n");
}
// Create temporary variable to hold the if condition
const auto condition_type_name = if_stmt->getCond()->getType().getAsString();
const auto cond_variable = unique_identifiers_.get_unique_identifier();
const auto cond_var_decl = condition_type_name + " " + cond_variable + ";";
// Add cond var decl to the packet structure, so that all decls accumulate there
current_decls.emplace_back(cond_var_decl);
// Add assignment to new packet temporary here,
// predicating it with the current predicate
const auto pkt_cond_variable = pkt_name + "." + cond_variable;
current_stream += pkt_cond_variable + " = (" + predicate + " ? (" + clang_stmt_printer(if_stmt->getCond()) + ") : 0);";
// Create predicates for if and else block
auto pred_within_if_block = "(" + predicate + " && " + pkt_cond_variable + ")";
auto pred_within_else_block = "(" + predicate + " && !" + pkt_cond_variable + ")";
// If convert statements within getThen block to ternary operators.
if_convert(current_stream, current_decls, pred_within_if_block, if_stmt->getThen(), pkt_name);
// If there is a getElse block, handle it recursively again
if (if_stmt->getElse() != nullptr) {
if_convert(current_stream, current_decls, pred_within_else_block, if_stmt->getElse(), pkt_name);
}
} else if (isa<BinaryOperator>(stmt)) {
current_stream += if_convert_atomic_stmt(dyn_cast<BinaryOperator>(stmt), predicate);
} else if (isa<DeclStmt>(stmt)) {
// Just append statement as is, but check that this only happens at the
// top level i.e. when predicate = "1" or true
assert_exception(predicate == "1");
current_stream += clang_stmt_printer(stmt);
return;
} else if (isa<NullStmt>(stmt)) {
// Do nothing
return;
} else {
throw std::logic_error("Cannot handle stmt " + clang_stmt_printer(stmt) + " of type " + std::string(stmt->getStmtClassName()));
assert_exception(false);
}
}
/// Rewrites a plain assignment `lhs = rhs` into `lhs = (predicate ? (rhs) : lhs);`.
///
/// When the predicate is false the left-hand side is re-assigned its current
/// value, so the statement needs no branch.
///
/// @param stmt       Assignment to rewrite; must be non-null and a simple `=`
///                   (compound forms such as `+=` are rejected).
/// @param predicate  C expression guarding the assignment.
/// @return           The predicated assignment, terminated with ';'.
std::string IfConversionHandler::if_convert_atomic_stmt(const BinaryOperator * stmt,
                                                        const std::string & predicate) const {
  assert_exception(stmt);
  assert_exception(stmt->isAssignmentOp());
  assert_exception(not stmt->isCompoundAssignmentOp());

  // `stmt` is already a BinaryOperator, so its operands are reachable directly.
  const std::string lhs = clang_stmt_printer(stmt->getLHS());
  const std::string rhs = clang_stmt_printer(stmt->getRHS());

  return lhs + " = (" + predicate + " ? (" + rhs + ") : " + lhs + ");";
}