Skip to content

Commit

Permalink
[te] Fix casting of unsigned char, and abs(int)
Browse files Browse the repository at this point in the history
ghstack-source-id: 9f939c4a78b832da338d851725a3b3b77c409b9e
Pull Request resolved: #44157
  • Loading branch information
bertmaher committed Sep 3, 2020
1 parent b457879 commit a109d40
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
46 changes: 37 additions & 9 deletions test/test_jit_fuser_te.py
Expand Up @@ -1152,20 +1152,48 @@ def rand(dtype, device="cuda"):
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.bfloat16,
# torch.float16,
# torch.bfloat16,
torch.float32,
torch.float64,
torch.bool,
torch.complex32,
torch.complex64,
torch.complex128,
torch.qint8,
torch.quint8,
torch.qint32,
# torch.bool,
# torch.complex32,
# torch.complex64,
# torch.complex128,
# torch.qint8,
# torch.quint8,
# torch.qint32,
]
unary_ops = [
torch.sigmoid,
torch.reciprocal,
torch.neg,
torch.relu,
torch.log,
torch.log10,
torch.log2,
torch.exp,
torch.expm1,
torch.erf,
torch.erfc,
torch.cos,
torch.sin,
torch.tan,
torch.acos,
torch.asin,
torch.cosh,
torch.sinh,
torch.atan,
torch.tanh,
torch.sqrt,
torch.rsqrt,
torch.abs,
torch.ceil,
torch.floor,
torch.round,
torch.trunc,
torch.frac,
torch.lgamma,
]
devices = [
"cuda",
Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/jit/tensorexpr/cuda_codegen.cpp
Expand Up @@ -186,7 +186,7 @@ void CudaPrinter::visit(const For* v) {
}

void CudaPrinter::visit(const Cast* v) {
os() << cudaDtypeCppString(v->dtype());
os() << "(" << cudaDtypeCppString(v->dtype()) << ")";
os() << "(";
v->src_value()->accept(this);
os() << ")";
Expand All @@ -209,6 +209,9 @@ void CudaPrinter::visit(const Intrinsics* v) {
if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
func_name = func_name + "f";
}
if (v->op_type() == IntrinsicsOp::kFabs && is_integral(returnType)) {
func_name = "abs";
}

os() << func_name << "(";
for (int i = 0; i < v->nparams(); i++) {
Expand Down

0 comments on commit a109d40

Please sign in to comment.