Device constant (#1593)
Fixes #1331

Added an entry for device constants.
jjsjann123 committed Apr 18, 2022
1 parent d0cc32f commit 583bb01
Showing 2 changed files with 53 additions and 7 deletions.
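
For context, the pattern this commit makes fuseable can be reproduced with a short standalone script. This is a minimal sketch rather than part of the commit; it assumes a CUDA build of PyTorch with nvfuser available, and the function name is illustrative:

import torch

# enable nvfuser explicitly; depending on the build and defaults it may already
# be on (the test harness in test_jit_cuda_fuser.py handles this for the tests below)
torch._C._jit_set_nvfuser_enabled(True)

def rand_like_cuda(x):
    # the explicit device constant is what this commit teaches the parser to accept
    return torch.rand_like(x, device=torch.device("cuda"))

rand_like_cuda_jit = torch.jit.script(rand_like_cuda)
x = torch.randn(4, 2, device="cuda")
for _ in range(5):
    rand_like_cuda_jit(x)  # warm-up runs so the profiling executor can optimize
# with this change, a fusion group (prim::CudaFusionGroup) should appear in the graph
print(rand_like_cuda_jit.graph_for(x))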
test/test_jit_cuda_fuser.py (23 additions, 0 deletions)
@@ -4539,6 +4539,29 @@ def clamp_min(x):
         t_jit = torch.jit.script(t)
         self._run_helper(t_jit, t, x)
 
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_device_constant(self):
+        x = torch.randn(4, 2, device="cuda")
+
+        def t(x):
+            return torch.rand_like(x, device=torch.device(type='cuda'))
+
+        # cpu tensor shouldn't be fused
+        def t_cpu(x):
+            return torch.rand_like(x, device=torch.device(type='cpu'))
+
+        with nvfuser_singleton_fusion(True):
+            t_jit = torch.jit.script(t)
+            self._run_helper(t_jit, t, x)
+
+            t_cpu_jit = torch.jit.script(t_cpu)
+            for i in range(5):
+                t_cpu_jit(x)
+
+            self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0)
+
 
 class TestPassManagerCudaFuser(JitTestCase):
     def setUp(self):
torch/csrc/jit/codegen/cuda/parser.cpp (30 additions, 7 deletions)
@@ -1043,6 +1043,27 @@ class IrParser {
             auto operand = list_val.front();
             list_val.pop_front();
 
+            if (!node->input(3)->type()->isSubtypeOf(
+                    static_cast<c10::TypePtr>(NoneType::get()))) {
+              auto device = constant_as<c10::Device>(node->input(3));
+              TORCH_INTERNAL_ASSERT(
+                  device.has_value() && device->is_cuda(),
+                  "rand_like in nvfuser is not on cuda device");
+              auto input_tensor_type =
+                  node->input(0)->type()->cast<TensorType>();
+              // device->index() == -1 indicating that we don't change device
+              // index
+              if (device->index() != -1 && input_tensor_type) {
+                auto input_device = input_tensor_type->device();
+                // we expect device index to be consistent with input and it
+                // should have already been handled by partition
+                TORCH_INTERNAL_ASSERT(
+                    !input_device.has_value() ||
+                        input_device->index() == device->index(),
+                    "rand_like in nvfuser is not on cuda device");
+              }
+            }
+
             auto out = randlike(operand);
             value_map.emplace(node->output()->unique(), out);
           },
@@ -2623,14 +2644,14 @@ class IrParser {
"aten::amax/amin cannot be fused with dynamic keepdim");

TensorView* out = nullptr;
if (node->kind() ==
c10::Symbol::fromQualString("aten::amax")) {
if (node->kind() == c10::Symbol::fromQualString("aten::amax")) {
out = max(self->as<TensorView>(), dims, keepdim.value());
} else if (node->kind() ==
c10::Symbol::fromQualString("aten::amin")) {
} else if (
node->kind() == c10::Symbol::fromQualString("aten::amin")) {
out = min(self->as<TensorView>(), dims, keepdim.value());
} else {
TORCH_INTERNAL_ASSERT(false, "unrecognized operation in aten::amax/amin");
TORCH_INTERNAL_ASSERT(
false, "unrecognized operation in aten::amax/amin");
}
value_map.emplace(node->output()->unique(), out);
},
@@ -2903,10 +2924,12 @@ class IrParser {
   } else if (
       val->type()->isSubtypeOf(
           static_cast<c10::TypePtr>(StringType::get())) ||
+      val->type()->isSubtypeOf(
+          static_cast<c10::TypePtr>(DeviceObjType::get())) ||
       val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) {
     // TODO: should we consider adding support for NoneType;
-    // String scalars are only used in parsing rules;
-    // Do not register string with codegen IR.
+    // Note: String/Device scalars are only used in parsing rules, do not
+    // register string with codegen IR.
     return true;
   } else if (val->type()->cast<ListType>()) {
     // TODO: we don't support list type in codegen yet;
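
A side note on the device check above: in the aten::rand_like schema the device keyword is the fourth argument, which is consistent with the rule reading node->input(3). The following is a short, illustrative sketch (not part of the commit; names are hypothetical) of inspecting the scripted graph from Python to see the Device-typed value feeding that slot:

import torch

def t(x):
    return torch.rand_like(x, device=torch.device("cuda"))

graph = torch.jit.script(t).graph
print(graph)  # the device argument shows up as a Device-typed value
for node in graph.findAllNodes("aten::rand_like"):
    device_value = list(node.inputs())[3]
    print(device_value.type())  # expected to be the Device type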
