[te] Create TargetMachine only once with correct options to fix perf (#50406)

Summary:
Pull Request resolved: #50406

We were creating different TargetMachines in PytorchLLVMJIT and LLVMCodeGen. The
one in LLVMCodeGen had the right target-specific options to generate fast AVX2
code (with FMAs, vbroadcastss, etc.), and that's what was showing up in the
debug output, but the PytorchLLVMJIT TM was the one that actually generated the
runtime code, and it did not have those options, so the emitted code was slow.
ghstack-source-id: 119700110
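
The gist of the fix, as a minimal standalone sketch (condensed from the diff below): build one `JITTargetMachineBuilder` carrying the host CPU, its sub-target features, and fast FP op fusion, and hand that same builder to the LLJIT, so the TargetMachine that emits runtime code matches the one used for passes and debug output. Here `makeHostJTMB` is a hypothetical name for illustration; it mirrors the `makeTargetMachineBuilder` moved into llvm_jit.cpp, and `llvm::cantFail` stands in for the codebase's `assertSuccess` helper.

```
#include <llvm/ADT/StringMap.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/MC/SubtargetFeature.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/TargetSelect.h>

// Build one JTMB with the host triple, CPU name, and sub-target features
// (AVX2, FMA, ...), plus fast FP op fusion so FMAs can be formed.
llvm::orc::JITTargetMachineBuilder makeHostJTMB() {
  llvm::orc::JITTargetMachineBuilder JTMB(
      llvm::Triple(llvm::sys::getProcessTriple()));

  llvm::SubtargetFeatures Features;
  llvm::StringMap<bool> FeatureMap;
  llvm::sys::getHostCPUFeatures(FeatureMap);
  for (auto& F : FeatureMap) {
    Features.AddFeature(F.first(), F.second);
  }

  JTMB.setCPU(llvm::sys::getHostCPUName().str());
  JTMB.addFeatures(Features.getFeatures());
  JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast;
  return JTMB;
}

int main() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // The JIT is built from that same builder, so the TargetMachine that
  // compiles kernels at runtime has the same options as the one used for
  // the optimization pipeline and debug dumps.
  auto JIT = llvm::cantFail(llvm::orc::LLJITBuilder()
                                .setJITTargetMachineBuilder(makeHostJTMB())
                                .create());
  (void)JIT;
  return 0;
}
```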

Test Plan:
```
buck run mode/opt //caffe2/benchmarks/fb/tensorexpr:tensorexpr_bench
```

With this diff NNC gets reasonably close (within ~5%) to PyTorch with MKL,
at least for this one small-ish test case:

```
Run on (24 X 2394.67 MHz CPU s)
2021-01-11 15:57:27
----------------------------------------------------------------------------------------------------
Benchmark                                             Time           CPU Iterations UserCounters...
----------------------------------------------------------------------------------------------------
Gemm/Torch/128/128/128                            65302 ns      65289 ns      10734 GFLOPS=64.2423G/s
Gemm/TensorExprTile4x16VecUnroll/128/128/128      68602 ns      68599 ns      10256 GFLOPS=61.1421G/s
```
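
As a sanity check on the reported counters: a 128x128x128 GEMM is 2 * 128^3 ~= 4.19 MFLOP, so 65,302 ns works out to ~64.2 GFLOP/s and 68,602 ns to ~61.1 GFLOP/s, consistent with the ~5% gap quoted above.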

Reviewed By: bwasti

Differential Revision: D25877605

fbshipit-source-id: cd293bac94d025511f348eab5c9b8b16bf6505ec
bertmaher authored and facebook-github-bot committed Jan 12, 2021
1 parent 7d28f1c commit cb37709
Showing 2 changed files with 55 additions and 42 deletions.
47 changes: 7 additions & 40 deletions torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -102,7 +102,6 @@ class LLVMCodeGenImpl : public IRVisitor {
private:
std::unique_ptr<llvm::LLVMContext> context_;
llvm::IRBuilder<> irb_;
std::unique_ptr<llvm::TargetMachine> TM_;
std::unique_ptr<llvm::orc::PytorchLLVMJIT> jit_;
std::unique_ptr<llvm::Module> module_;
llvm::Function* fn_;
@@ -195,34 +194,6 @@ class LLVMCodeGenImpl : public IRVisitor {
} // namespace jit
} // namespace torch

static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() {
#if 0
// FIXME: Switch to using detectHost() rather than setting up the JTMB manually
// once LLVM 10 is available.
return assertSuccess(llvm::orc::JITTargetMachineBuilder::detectHost());
#else
llvm::orc::JITTargetMachineBuilder JTMB(
(llvm::Triple(llvm::sys::getProcessTriple())));

// Retrieve host CPU name and sub-target features and add them to builder.
// Relocation model, code model and codegen opt level are kept to default
// values.
llvm::SubtargetFeatures SubtargetFeatures;
llvm::StringMap<bool> FeatureMap;
llvm::sys::getHostCPUFeatures(FeatureMap);
for (auto& Feature : FeatureMap) {
SubtargetFeatures.AddFeature(Feature.first(), Feature.second);
}

JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default);
JTMB.setCPU(llvm::sys::getHostCPUName().str());
JTMB.addFeatures(SubtargetFeatures.getFeatures());
JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast;

return JTMB;
#endif
}

LLVMCodeGen::~LLVMCodeGen() = default;

LLVMCodeGen::LLVMCodeGen(Stmt* stmt)
@@ -307,13 +278,10 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();

auto JTMB = makeTargetMachineBuilder();
TM_ = assertSuccess(JTMB.createTargetMachine());

jit_ = std::make_unique<llvm::orc::PytorchLLVMJIT>();
module_ = std::make_unique<llvm::Module>("pytorch", getContext());
module_->setDataLayout(assertSuccess(JTMB.getDefaultDataLayoutForTarget()));
module_->setTargetTriple(JTMB.getTargetTriple().str());
module_->setDataLayout(jit_->getDataLayout());
module_->setTargetTriple(jit_->getTargetMachine().getTargetTriple().str());

// We support float16 ops by casting expr inputs to float32
// and then casting the result back to float16
@@ -536,7 +504,7 @@ void LLVMCodeGenImpl::emitKernel(
if (GRAPH_DEBUG_ENABLED) {
module_->print(asmStream, nullptr);
llvm::legacy::PassManager PM;
TM_->addPassesToEmitFile(
jit_->getTargetMachine().addPassesToEmitFile(
PM,
asmStream,
nullptr,
@@ -1863,16 +1831,15 @@ void LLVMCodeGenImpl::optimize(llvm::Module& M) {
llvm::legacy::PassManager PM;

// Add internal analysis passes from the target machine.
PM.add(
llvm::createTargetTransformInfoWrapperPass(TM_->getTargetIRAnalysis()));
FPM.add(
llvm::createTargetTransformInfoWrapperPass(TM_->getTargetIRAnalysis()));
auto& TM = jit_->getTargetMachine();
PM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));
FPM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));

llvm::PassManagerBuilder PMB;
PMB.OptLevel = 3;
PMB.LoopVectorize = true;
PMB.SLPVectorize = true;
TM_->adjustPassManager(PMB);
TM.adjustPassManager(PMB);

PMB.populateFunctionPassManager(FPM);
PMB.populateModulePassManager(PM);
50 changes: 48 additions & 2 deletions torch/csrc/jit/tensorexpr/llvm_jit.cpp
@@ -33,6 +33,34 @@ static llvm::JITTargetAddress toAddress(T* Ptr) {
return static_cast<llvm::JITTargetAddress>(reinterpret_cast<uintptr_t>(Ptr));
}

static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() {
#if 0
// FIXME: Switch to using detectHost() rather than setting up the JTMB manually
// once LLVM 10 is available.
return assertSuccess(llvm::orc::JITTargetMachineBuilder::detectHost());
#else
llvm::orc::JITTargetMachineBuilder JTMB(
(llvm::Triple(llvm::sys::getProcessTriple())));

// Retrieve host CPU name and sub-target features and add them to builder.
// Relocation model, code model and codegen opt level are kept to default
// values.
llvm::SubtargetFeatures SubtargetFeatures;
llvm::StringMap<bool> FeatureMap;
llvm::sys::getHostCPUFeatures(FeatureMap);
for (auto& Feature : FeatureMap) {
SubtargetFeatures.AddFeature(Feature.first(), Feature.second);
}

JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default);
JTMB.setCPU(llvm::sys::getHostCPUName().str());
JTMB.addFeatures(SubtargetFeatures.getFeatures());
JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast;

return JTMB;
#endif
}

static void registerIntrinsics(
llvm::orc::JITDylib& JD,
llvm::orc::MangleAndInterner& Mangle) {
@@ -189,10 +217,16 @@ namespace orc {
#if LLVM_VERSION_MAJOR >= 9 && LLVM_VERSION_MAJOR <= 12
class TORCH_API PytorchLLVMJITImpl {
private:
std::unique_ptr<TargetMachine> TM;
std::unique_ptr<LLJIT> LLJ;

public:
PytorchLLVMJITImpl() : LLJ(assertSuccess(LLJITBuilder().create())) {
PytorchLLVMJITImpl()
: TM(assertSuccess(makeTargetMachineBuilder().createTargetMachine())),
LLJ(assertSuccess(
LLJITBuilder()
.setJITTargetMachineBuilder(makeTargetMachineBuilder())
.create())) {
auto ProcSymbolsGenerator =
assertSuccess(DynamicLibrarySearchGenerator::GetForCurrentProcess(
LLJ->getDataLayout().getGlobalPrefix()));
@@ -222,6 +256,10 @@ class TORCH_API PytorchLLVMJITImpl {
return assertSuccess(LLJ->lookup(Name));
}

TargetMachine& getTargetMachine() {
return *TM;
}

const DataLayout& getDataLayout() {
return LLJ->getDataLayout();
}
@@ -242,6 +280,10 @@ JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) {
return impl_->findSymbol(std::move(Name));
}

TargetMachine& PytorchLLVMJIT::getTargetMachine() {
return impl_->getTargetMachine();
}

const DataLayout& PytorchLLVMJIT::getDataLayout() {
return impl_->getDataLayout();
}
@@ -278,7 +320,7 @@ class TORCH_API PytorchLLVMJITImpl {
[](Error Err) {
assertSuccess(std::move(Err), "lookupFlags failed");
})),
TM(EngineBuilder().selectTarget()),
TM(assertSuccess(makeTargetMachineBuilder().createTargetMachine())),
DL(TM->createDataLayout()),
ObjectLayer(
ES,
@@ -342,6 +384,10 @@ JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) {
return impl_->findSymbol(std::move(Name));
}

TargetMachine& PytorchLLVMJIT::getTargetMachine() {
return impl_->getTargetMachine();
}

const DataLayout& PytorchLLVMJIT::getDataLayout() {
return impl_->getDataLayout();
}
