[te] Fix casting of unsigned char, and abs(int) #44157

Closed
wants to merge 7 commits
49 changes: 40 additions & 9 deletions test/test_jit_fuser_te.py
@@ -87,7 +87,7 @@ def get_nodes_and_parents_recursively(block, kind, acc):
acc[block].append(node)
elif node.kind() == 'prim::DifferentiableGraph':
get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or
elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or
node.inputs().__next__().node().kind() == 'prim::TypeCheck'):
get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc)
else:
@@ -1151,7 +1151,7 @@ def rand(dtype, device="cuda"):
return torch.rand(shape, dtype=dtype, device=device)
else:
# dtype is an integer.
return torch.randint(0, 100, shape, dtype=dtype, device=device)
return torch.randint(1, 4, shape, dtype=dtype, device=device)
raise RuntimeError("Unhandled dtype")

dtypes = [
@@ -1160,12 +1160,12 @@ def rand(dtype, device="cuda"):
torch.int16,
torch.int32,
torch.int64,
torch.float16,
# torch.float16,

Nit: maybe add a comment on whether we want these to be supported at all?

Contributor Author

Oh, yeah, float16 should be coming in soon-ish. The commented-out lines are a transient state for this stack :-)

# torch.bfloat16,
torch.float32,
torch.float64,
torch.bfloat16,
torch.bool,
torch.complex32,
# torch.bool,
# torch.complex32,
# torch.complex64,
# torch.complex128,
# torch.qint8,
@@ -1175,6 +1175,32 @@ def rand(dtype, device="cuda"):
unary_ops = [
torch.sigmoid,
torch.reciprocal,
torch.neg,
torch.relu,
torch.log,
torch.log10,
torch.log2,
torch.exp,
torch.expm1,
torch.erf,
torch.erfc,
torch.cos,
torch.sin,
torch.tan,
torch.acos,
torch.asin,
torch.cosh,
torch.sinh,
torch.atan,
torch.tanh,
torch.sqrt,
torch.rsqrt,
torch.abs,
torch.ceil,
torch.floor,
torch.round,
torch.trunc,
torch.frac,
]
devices = [
"cuda",
@@ -1189,9 +1215,14 @@ def rand(dtype, device="cuda"):
# neither does the fuser. Catch everything to avoid needing to
# guess what errors might be thrown by eager.
continue
t = torch.jit.trace(fn, (x,))
self.assertEqual(ref, t(x))
self.assertAllFused(t.graph_for(x))
try:
t = torch.jit.trace(fn, (x,))
torch.testing.assert_allclose(ref, t(x))
self.assertAllFused(t.graph_for(x))
except Exception as e:
raise RuntimeError(" ".join([
"Failed:", str(dtype), op.__name__, device
]))

if __name__ == '__main__':
run_tests()
7 changes: 5 additions & 2 deletions torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -110,7 +110,7 @@ std::string cudaDtypeCppString(const Dtype& dtype) {
case ScalarType::Short:
return "short";
case ScalarType::Long:
return "long";
return "long long";
default:
return dtype.ToCppString();
}
@@ -198,7 +198,7 @@ void CudaPrinter::visit(const Cast* v) {
return;
}

os() << cudaDtypeCppString(v->dtype());
os() << "(" << cudaDtypeCppString(v->dtype()) << ")";

Ouch, how did it work? :)

Contributor Author

I don't know when it became a thing, but apparently you can cast by writing float(x), as if it were a constructor. That doesn't parse if you write unsigned char(x), though. Clearly we need whitespace in identifiers! (I hope this reddit post is a joke: https://www.reddit.com/r/ProgrammerHumor/comments/4on81i/til_c_allows_u200b_zero_width_space_in_identifiers/ )
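
For context, here's a minimal host-C++ sketch (illustration only, not part of this PR) of the syntax issue: functional-style casts only accept single-word type names, so the codegen switches to C-style casts, which work for any type name.

```cpp
// Illustration only (not PR code): functional-style casts need a single-word
// type name, which is why the codegen now emits C-style casts instead.
int main() {
  int x = 300;

  float f = float(x);                    // OK: functional-style cast
  // unsigned char c = unsigned char(x); // ill-formed: multi-word type name
  unsigned char c = (unsigned char)(x);  // OK: C-style cast works for any type
  unsigned char c2 = static_cast<unsigned char>(x);  // also fine in host C++

  return (f == 300.0f && c == c2) ? 0 : 1;
}
```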

os() << "(";
v->src_value()->accept(this);
os() << ")";
@@ -221,6 +221,9 @@ void CudaPrinter::visit(const Intrinsics* v) {
if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
func_name = func_name + "f";
}
if (v->op_type() == IntrinsicsOp::kFabs && is_integral(returnType)) {

Collaborator

Does nvrtc not handle std::abs or ::abs? All this exp/expf business is so last decade (and it isn't used in the eager part of the codebase at all).

Contributor Author

I think it does, actually. I don't know why we went the + "f" route, except that maybe the old fuser uses it? I'll take a note to use std:: in a follow-up.
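
For reference, a small sketch (illustration only, plain host C++ rather than NVRTC, and not part of this PR) of the overload issue under discussion: fabs is floating-point only, the C-style abs handles int, and std::abs resolves to the right overload for both, which is what the kFabs-to-abs mapping above works around.

```cpp
// Illustration only (not PR code): the C-style math functions are split by
// type, while std::abs is overloaded and picks the right version itself.
#include <cmath>    // fabs, std::abs(double)
#include <cstdlib>  // abs(int), std::abs(int)
#include <iostream>

int main() {
  int i = -3;
  double d = -3.5;

  std::cout << fabs(i) << "\n";      // 3   (int silently converted to double)
  std::cout << abs(i) << "\n";       // 3   (integer abs)
  std::cout << std::abs(i) << "\n";  // 3   (overload resolution picks int)
  std::cout << std::abs(d) << "\n";  // 3.5 (overload resolution picks double)
  return 0;
}
```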

func_name = "abs";
}

os() << func_name << "(";
for (int i = 0; i < v->nparams(); i++) {