Skip to content
Permalink
Browse files

Add an optimization pass for fast-math propagation.

  • Loading branch information...
sheredom committed Jun 11, 2019
1 parent da94646 commit 514fd8e77b2f848fe08057385c6e26a49c09b6f8
Showing with 454 additions and 4 deletions.
  1. +1 −0 .travis.yml
  2. +3 −3 README.md
  3. +4 −1 driver/fpscev_opt.cpp
  4. +159 −0 fpscev.cpp
  5. +287 −0 test/fast_math_propagation.ll
@@ -24,4 +24,5 @@ script:
- cmake -DCMAKE_BUILD_TYPE=Release -DLLVM_DIR=$llvm_dir/lib/cmake/llvm ..
- cmake --build .
- ./driver/fpscev_opt ../test/fpscev.ll -o - -S -fpscev-analysis
- ./driver/fpscev_opt ../test/fast_math_propagation.ll -o - -S

@@ -21,8 +21,8 @@ if (i < 4) {
}
```

And lets say that via scalar evolution we know that `i` could _never_ be greater
than 3. That allows the compiler to remove the entire if branch as it will never
And let's say that via scalar evolution we know that `i` could _never_ be less
than 4. That allows the compiler to remove the entire if branch as it will never
be hit!

## Why Floating-Point Scalar Evolution?
@@ -38,7 +38,7 @@ A great use of this would be something like the following:
float f = ...; // definitely not NaN or Infinity
f = sin(f); // because f wasn't NaN or Infinity f is now in the range [-1..1]
if (isfinite(f)) {
if (!isfinite(f)) {
// Do a million lines of awful code that will bloat your executable!
}
@@ -35,7 +35,9 @@ using namespace llvm;

namespace llvm {
// Entry points of the two passes, declared by hand here instead of being pulled
// in from a header. NOTE(review): presumably defined in fpscev.cpp — confirm.
//
// The initialize* functions take the PassRegistry the pass is registered with;
// main() calls them before building the pass manager.
extern void initializeFPScalarEvolutionPassPass(PassRegistry&);
extern void initializeFastMathPropagationPassPass(PassRegistry &);
// The create* functions return a new pass instance; main() hands each result to
// legacy::PassManager::add.
extern Pass* createFPScalarEvolutionPass();
extern Pass *createFastMathPropagationPass();
}

static cl::opt<std::string> inputFilename(cl::Positional,
@@ -94,9 +96,10 @@ int main(const int argc, const char *const argv[]) {
PassRegistry* const passRegistry = PassRegistry::getPassRegistry();

initializeFPScalarEvolutionPassPass(*passRegistry);
initializeFastMathPropagationPassPass(*passRegistry);

legacy::PassManager passManager;
passManager.add(createFPScalarEvolutionPass());
passManager.add(createFastMathPropagationPass());
passManager.run(*module);

std::error_code error;
@@ -273,6 +273,14 @@ struct FPSCEV final {
errs() << "isInteger: ";
errs() << isInteger << "\n";
}

// Reports whether either bound (`min` or `max`) is NaN.
bool isNaN() const {
  if (min.isNaN()) {
    return true;
  }

  return max.isNaN();
}

// Reports whether both bounds (`min` and `max`) are finite.
bool isFinite() const {
  if (!min.isFinite()) {
    return false;
  }

  return max.isFinite();
}

// Reports whether the upper bound (`max`) is negative according to its
// isNegative() query; only `max` is consulted.
bool isAllNegative() const {
  const bool upperBoundIsNegative = max.isNegative();
  return upperBoundIsNegative;
}

// Reports whether the lower bound (`min`) is not negative according to its
// isNegative() query; only `min` is consulted.
bool isAllNonNegative() const {
  if (min.isNegative()) {
    return false;
  }

  return true;
}
};

struct FPScalarEvolution final {
@@ -292,6 +300,16 @@ struct FPScalarEvolution final {

return &map[value];
}

// Read-only lookup of the FPSCEV recorded for `value`.
//
// This const overload never inserts into the map: it returns a pointer to the
// stored entry when one exists and nullptr otherwise, so callers must check
// the result before dereferencing it.
const FPSCEV *getFPSCEV(Value *const value) const {
  const auto found = map.find(value);
  return (found != map.end()) ? &found->second : nullptr;
}
};

// Define our LLVM pass as inheriting from a FunctionPass.
@@ -660,6 +678,10 @@ struct FPScalarEvolutionPass final : FunctionPass,
// other than apply fast math flags.
FPSCEV fpscev(inst.getType());

// Record the arguments to the frem even if we don't actually need them yet.
fpse.getFPSCEV(inst.getOperand(0));
fpse.getFPSCEV(inst.getOperand(1));

const FastMathFlags flags = inst.getFastMathFlags();
fpscev.min = applyFastMathFlags(fpscev.min, flags);
fpscev.max = applyFastMathFlags(fpscev.max, flags);
@@ -1348,6 +1370,8 @@ struct FPScalarEvolutionPass final : FunctionPass,
// identify the pass.
static char ID;

// Read-only access to the FPScalarEvolution state held by this pass (the
// `fpse` member), for use by other passes.
const FPScalarEvolution &getFPSCEV() const {
  const FPScalarEvolution &result = fpse;
  return result;
}

private:
FPScalarEvolution fpse;
ScalarEvolution *scalarEvolution;
@@ -1356,14 +1380,149 @@ struct FPScalarEvolutionPass final : FunctionPass,

// Out-of-class definition of the pass's static ID member. Per the comment on
// its declaration, LLVM identifies the pass by this object's address.
char FPScalarEvolutionPass::ID;

namespace {
// Function pass that consumes the ranges recorded by FPScalarEvolutionPass and
// sets the no-infs / no-NaNs fast-math flags on floating-point compares,
// fneg, the binary FP arithmetic operators and intrinsic calls whenever every
// value involved is known to be finite / known not to be NaN.
struct FastMathPropagationPass final : public FunctionPass,
                                       InstVisitor<FastMathPropagationPass> {
  FastMathPropagationPass() : FunctionPass(ID) {}

  // Visits every instruction in `function`; returns true when at least one
  // fast-math flag was applied.
  bool runOnFunction(Function &function) override {
    fpse = &getAnalysis<FPScalarEvolutionPass>().getFPSCEV();
    modified = false;
    visit(function);
    return modified;
  }

  // Flags a compare when both operands are finite / not NaN.
  void visitFCmpInst(FCmpInst &inst) {
    const FPSCEV *const xFpscev = fpse->getFPSCEV(inst.getOperand(0));
    const FPSCEV *const yFpscev = fpse->getFPSCEV(inst.getOperand(1));

    // The const lookup returns nullptr for values the analysis never recorded;
    // with nothing known about an operand no flag can be justified.
    if (!xFpscev || !yFpscev) {
      return;
    }

    applyFlags(inst, xFpscev->isFinite() && yFpscev->isFinite(),
               !xFpscev->isNaN() && !yFpscev->isNaN());
  }

  // Flags an fneg when its operand is finite / not NaN. Other unary opcodes
  // are left untouched.
  void visitUnaryOperator(UnaryOperator &inst) {
    if (inst.getOpcode() != Instruction::FNeg) {
      return;
    }

    const FPSCEV *const fpscev = fpse->getFPSCEV(inst.getOperand(0));

    // No recorded range for the operand: leave the instruction alone.
    if (!fpscev) {
      return;
    }

    applyFlags(inst, fpscev->isFinite(), !fpscev->isNaN());
  }

  // Flags fadd/fsub/fmul/fdiv/frem when the result and both operands are
  // finite / not NaN. Non-floating-point binary operators are ignored.
  void visitBinaryOperator(BinaryOperator &inst) {
    switch (inst.getOpcode()) {
    default:
      return;
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul:
    case Instruction::FDiv:
    case Instruction::FRem:
      break;
    }

    const FPSCEV *const xFpscev = fpse->getFPSCEV(inst.getOperand(0));
    const FPSCEV *const yFpscev = fpse->getFPSCEV(inst.getOperand(1));
    const FPSCEV *const fpscev = fpse->getFPSCEV(&inst);

    // Any missing entry means the analysis knows nothing about that value, so
    // bail out instead of dereferencing a null pointer.
    if (!xFpscev || !yFpscev || !fpscev) {
      return;
    }

    applyFlags(
        inst,
        fpscev->isFinite() && xFpscev->isFinite() && yFpscev->isFinite(),
        !fpscev->isNaN() && !xFpscev->isNaN() && !yFpscev->isNaN());
  }

  // Flags an intrinsic call when every argument (and the result) that has a
  // recorded FPSCEV is finite / not NaN. Values without an entry are skipped;
  // at least one entry must exist for any flag to be applied.
  // NOTE(review): setHasNoInfs/setHasNoNaNs may require the call itself to be
  // a floating-point math operation — confirm this holds for intrinsics that
  // take FP arguments but return a non-FP value.
  void visitIntrinsicInst(IntrinsicInst &inst) {
    bool atLeastOneFP = false;
    bool allFinite = true;
    bool allNotNaN = true;

    // Folds one (possibly missing) range into the running facts.
    const auto accumulate = [&](const FPSCEV *const fpscev) {
      if (!fpscev) {
        return;
      }

      atLeastOneFP = true;
      allFinite = allFinite && fpscev->isFinite();
      allNotNaN = allNotNaN && !fpscev->isNaN();
    };

    for (Value *const arg : inst.args()) {
      accumulate(fpse->getFPSCEV(arg));
    }

    accumulate(fpse->getFPSCEV(&inst));

    if (atLeastOneFP) {
      applyFlags(inst, allFinite, allNotNaN);
    }
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<FPScalarEvolutionPass>();
  }

  // The ID of this pass - the address of which is used by LLVM to uniquely
  // identify the pass.
  static char ID;

private:
  // Sets the no-infs flag when `noInfs` holds and the no-NaNs flag when
  // `noNaNs` holds, recording that the function was modified.
  void applyFlags(Instruction &inst, const bool noInfs, const bool noNaNs) {
    if (noInfs) {
      inst.setHasNoInfs(true);
      modified = true;
    }

    if (noNaNs) {
      inst.setHasNoNaNs(true);
      modified = true;
    }
  }

  // Analysis results for the function currently being processed; assigned at
  // the start of runOnFunction.
  const FPScalarEvolution *fpse;
  // Whether the current runOnFunction invocation changed any instruction.
  bool modified;
};
} // namespace

// Out-of-class definition of the pass's static ID member. Per the comment on
// its declaration, LLVM identifies the pass by this object's address.
char FastMathPropagationPass::ID;

namespace llvm {
// Declarations of the per-pass initialization functions, each taking the
// PassRegistry to register with.
void initializeFPScalarEvolutionPassPass(PassRegistry &);
void initializeFastMathPropagationPassPass(PassRegistry &);

// Allocates a new FPScalarEvolutionPass; the caller receives the raw pointer.
Pass *createFPScalarEvolutionPass() {
  Pass *const pass = new FPScalarEvolutionPass();
  return pass;
}

// Allocates a new FastMathPropagationPass; the caller receives the raw
// pointer.
Pass *createFastMathPropagationPass() {
  Pass *const pass = new FastMathPropagationPass();
  return pass;
}
} // namespace llvm

// Registers FPScalarEvolutionPass under the argument name
// "fp-scalar-evolution" and declares its dependency on
// ScalarEvolutionWrapperPass.
// NOTE(review): the trailing `false, true` are presumably the cfg-only and
// is-analysis arguments of LLVM's INITIALIZE_PASS macros — confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(FPScalarEvolutionPass, "fp-scalar-evolution",
"Floating Point Scalar Evolution Analysis", false, true);
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
INITIALIZE_PASS_END(FPScalarEvolutionPass, "fp-scalar-evolution",
"Floating Point Scalar Evolution Analysis", false, true);

// Registers FastMathPropagationPass under the argument name
// "fast-math-propagation" and declares its dependency on
// FPScalarEvolutionPass, matching the addRequired call in its
// getAnalysisUsage.
// NOTE(review): the trailing `false, false` are presumably the cfg-only and
// is-analysis arguments of LLVM's INITIALIZE_PASS macros — confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(FastMathPropagationPass, "fast-math-propagation",
"Fast Math Propagation", false, false);
INITIALIZE_PASS_DEPENDENCY(FPScalarEvolutionPass);
INITIALIZE_PASS_END(FastMathPropagationPass, "fast-math-propagation",
"Fast Math Propagation", false, false);

0 comments on commit 514fd8e

Please sign in to comment.
You can’t perform that action at this time.