Skip to content

Commit

Permalink
Emit more simplified LLVM IR for a number of vector expansions. (pyto…
Browse files Browse the repository at this point in the history
  • Loading branch information
resistor committed Mar 13, 2020
1 parent e8c1167 commit 341a4be
Showing 1 changed file with 54 additions and 18 deletions.
72 changes: 54 additions & 18 deletions torch/csrc/jit/tensorexpr/llvm_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,18 @@ void LLVMCodeGenImpl::visit(const Ramp* v) {
auto stride = this->value_;
int lanes = v->lanes();

if (llvm::ConstantInt* const_stride = llvm::dyn_cast<llvm::ConstantInt>(stride)) {
std::vector<llvm::Constant*> vals = { llvm::ConstantInt::get(base->getType(), 0) };
for (int i = 1; i < lanes; ++i) {
vals.push_back(llvm::ConstantExpr::getAdd(vals.back(), const_stride));
}

llvm::Value* offsets = llvm::ConstantVector::get(vals);
llvm::Value* splat = irb_.CreateVectorSplat(lanes, base);
value_ = irb_.CreateAdd(splat, offsets);
return;
}

llvm::Type* vecType = nullptr;
switch (v->dtype().scalar_type()) {
#define TYPE_CASE(_1, Name) \
Expand Down Expand Up @@ -798,18 +810,18 @@ llvm::Value* LLVMCodeGenImpl::emitMaskedLoad(
}

void LLVMCodeGenImpl::visit(const Load* v) {
v->base_handle()->accept(this);
auto base = this->value_;
v->index()->accept(this);
auto idx = this->value_;
v->mask()->accept(this);
auto mask = this->value_;

if (v->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
v->index()->accept(this);
auto idx = this->value_;

auto* maskimm = dynamic_cast<const IntImm*>(v->mask());
if (maskimm && maskimm->value() == 1) {
value_ = emitUnmaskedLoad(base, idx);
} else {
v->mask()->accept(this);
auto mask = this->value_;
value_ = emitMaskedLoad(base, idx, mask);
}
return;
Expand Down Expand Up @@ -843,7 +855,11 @@ void LLVMCodeGenImpl::visit(const Load* v) {
if (unmasked_load && idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
auto first_idx = irb_.CreateExtractElement(idx, uint64_t{0ULL});
v->base_handle()->accept(this);
auto base = this->value_;
idx_ramp->base()->accept(this);
auto first_idx = this->value_;

auto addr = irb_.CreateGEP(base, first_idx);
auto vaddr = irb_.CreateBitOrPointerCast(
addr, llvm::PointerType::get(loadType, 0));
Expand All @@ -853,6 +869,13 @@ void LLVMCodeGenImpl::visit(const Load* v) {
}

// Fallback to a scalar implementation
v->base_handle()->accept(this);
auto base = this->value_;
v->index()->accept(this);
auto idx = this->value_;
v->mask()->accept(this);
auto mask = this->value_;

llvm::Value* load = llvm::UndefValue::get(loadType);
for (int i = 0; i < v->dtype().lanes(); ++i) {
auto sub_idx = irb_.CreateExtractElement(idx, i);
Expand Down Expand Up @@ -952,22 +975,23 @@ void LLVMCodeGenImpl::emitMaskedStore(
}

void LLVMCodeGenImpl::visit(const Store* v) {
v->base_handle()->accept(this);
auto base = this->value_;
v->index()->accept(this);
auto idx = this->value_;
v->mask()->accept(this);
auto mask = this->value_;
v->value()->accept(this);
auto val = this->value_;

value_ = llvm::ConstantInt::get(IntTy_, 0);

if (v->value()->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
v->index()->accept(this);
auto idx = this->value_;
v->value()->accept(this);
auto val = this->value_;

auto* maskimm = dynamic_cast<const IntImm*>(v->mask());
if (maskimm && maskimm->value() == 1) {
emitUnmaskedStore(base, idx, val);
} else {
v->mask()->accept(this);
auto mask = this->value_;

emitMaskedStore(base, idx, mask, val);
}
return;
Expand All @@ -983,12 +1007,19 @@ void LLVMCodeGenImpl::visit(const Store* v) {
}
}

v->base_handle()->accept(this);
auto base = this->value_;
v->value()->accept(this);
auto val = this->value_;

// Handle the case where the store is contiguous and unmasked efficiently
auto* idx_ramp = dynamic_cast<const Ramp*>(v->index());
if (unmasked_store && idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
auto first_idx = irb_.CreateExtractElement(idx, uint64_t{0});
idx_ramp->base()->accept(this);
auto first_idx = value_;

auto addr = irb_.CreateGEP(base, first_idx);
auto vaddr = irb_.CreateBitOrPointerCast(
addr, llvm::PointerType::get(val->getType(), 0));
Expand All @@ -997,6 +1028,11 @@ void LLVMCodeGenImpl::visit(const Store* v) {
}
}

v->index()->accept(this);
auto idx = this->value_;
v->mask()->accept(this);
auto mask = this->value_;

// Fallback to a scalar implementation
for (int i = 0; i < v->value()->dtype().lanes(); ++i) {
auto sub_idx = irb_.CreateExtractElement(idx, i);
Expand Down

0 comments on commit 341a4be

Please sign in to comment.