
Commit

Enable complex python tests (#1667)
zasdfgbnm committed May 21, 2022
1 parent 4ceeee5 commit f4d3630
Showing 7 changed files with 133 additions and 9 deletions.
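Complex support in nvFuser is opt-in: the tests below turn it on through the PYTORCH_NVFUSER_ENABLE environment variable, and partition.cpp only admits complex tensors into fusion groups when that flag is set. For reference, a minimal sketch of the opt-in (not part of this commit; assumes a CUDA build with nvFuser and the profiling graph executor):

import os
os.environ["PYTORCH_NVFUSER_ENABLE"] = "complex"  # must be set before nvFuser first checks the option

import torch

@torch.jit.script
def mul_add(x: torch.Tensor, y: torch.Tensor):
    return torch.mul(x, y) + 1

if torch.cuda.is_available():
    a = torch.randn(4, 8, dtype=torch.complex64, device="cuda")
    b = torch.randn(4, 8, dtype=torch.complex64, device="cuda")
    for _ in range(5):
        out = mul_add(a, b)           # warm up the profiling executor
    print(mul_add.graph_for(a, b))    # a CudaFusionGuard/CudaFusionGroup node is expected once fusion kicks in

With the flag unset, the same script goes through the fallback path and no fusion group appears in the graph.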
47 changes: 43 additions & 4 deletions test/test_jit_cuda_fuser.py
@@ -43,6 +43,7 @@

os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng'
os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
os.environ['PYTORCH_NVFUSER_ENABLE'] = 'complex'

if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_texpr_fuser_enabled(False)
@@ -159,7 +160,9 @@ def setUp(self):
torch.float16,
torch.float32,
torch.float64,
torch.bool,
torch.complex64,
torch.complex128,
]
if TEST_BF16:
self.support_tensor_dtypes.append(torch.bfloat16)
@@ -609,7 +612,9 @@ def test_unary_ops(self):
*self.int_types,
torch.float16,
torch.float32,
torch.float64,
torch.cfloat,
torch.cdouble,
]
if TEST_BF16:
data_types.append(torch.bfloat16)
@@ -654,7 +659,10 @@ def test_unary_ops(self):
torch.tan,
torch.tanh,
torch.nn.functional.silu]
skip_complex = {torch.rsqrt, torch.reciprocal}
for op, dtype in itertools.product(operations, data_types):
if dtype.is_complex and op in skip_complex:
continue
self._unary_test_helper(op, dtype, False) # test special numbers
self._unary_test_helper(op, dtype, True) # test random data

@@ -760,12 +768,20 @@ def t_doublex_tensory(x: float, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o

def t_cdoublex_tensory(x: complex, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o

# Omit both scalar cases and swap cases
assert category1 == "scalar" and category2 != "scalar"
if dtype_arg1.is_floating_point:
return t_doublex_tensory
if dtype_arg1 == torch.int64 or dtype_arg1 == torch.int32:
return t_intx_tensory
if dtype_arg1.is_complex:
return t_cdoublex_tensory
raise NotImplementedError

def _binary_test_helper(self, operation, dtypes, random_data, categories="ndim"):
@@ -905,6 +921,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
self.assertEqual(o.dtype, jit_o.dtype)
if test_value:
self.assertEqual(o, jit_o)
print(t_jit.graph_for(x, y))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
except Exception as e:
print("failing test for op: ", operation.__name__)
@@ -917,13 +934,12 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops(self):
data_types = [
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
]
if TEST_BF16:
data_types.append(torch.bfloat16)
@@ -958,6 +974,29 @@ def test_binary_ops(self):
for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, False) # special numbers

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_complex(self):
data_types = [torch.cfloat, torch.cdouble]
operations = [torch.mul, torch.div, torch.pow, torch.eq, torch.ne]

category_types = [
"scalar",
"0dim",
"0dimcpu",
"ndim"
]

binary_dtype_combinations = list(itertools.combinations(data_types, 2))
category_combinations = list(itertools.combinations(category_types, 2))

for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations):
self._binary_test_helper(op, dtypes, True, categories) # random data

for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, False) # special numbers

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp
@@ -125,6 +125,16 @@ void KernelArgumentHolder::push(const at::Tensor& tensor) {
changed_ = true;
if (is_cpu_scalar(tensor)) {
switch (tensor.scalar_type()) {
case c10::ScalarType::ComplexDouble:
arguments_.push_back(std::make_unique<CpuScalarTensorArg<
CpuScalarTensorCodegen<c10::complex<double>>>>(
tensor.data_ptr<c10::complex<double>>()[0]));
break;
case c10::ScalarType::ComplexFloat:
arguments_.push_back(std::make_unique<CpuScalarTensorArg<
CpuScalarTensorCodegen<c10::complex<float>>>>(
tensor.data_ptr<c10::complex<float>>()[0]));
break;
case c10::ScalarType::Double:
arguments_.push_back(
std::make_unique<
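The two new cases above let a zero-dimensional CPU complex tensor be passed to the generated kernel by value, which is what the "0dimcpu" category in the Python tests exercises. A hedged sketch of that input shape (names are illustrative, not from the commit):

import torch

if torch.cuda.is_available():
    cpu_scalar = torch.tensor(1.0 + 2.0j, dtype=torch.complex64)           # 0-dim tensor on the CPU
    dev_tensor = torch.randn(4, 8, dtype=torch.complex64, device="cuda")
    out = cpu_scalar * dev_tensor  # eager reference; under nvFuser the scalar's value is pushed as a kernel argument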
13 changes: 12 additions & 1 deletion torch/csrc/jit/codegen/cuda/parser.cpp
@@ -3254,7 +3254,18 @@ class IrParser {
}

bool registerScalar(const JitValue* val) {
if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(FloatType::get()))) {
if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(ComplexType::get()))) {
CgValue cg_val = nullptr;
if (auto ival = constant_as<c10::complex<double>>(val)) {
cg_val = IrBuilder::create<ComplexDouble>(ival.value());
} else {
cg_val = IrBuilder::create<ComplexDouble>();
}
value_map_.emplace(val->unique(), cg_val);
return true;
} else if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(FloatType::get()))) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CgValue cg_val;
if (auto ival = constant_as<double>(val)) {
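In practice, the new ComplexType branch in registerScalar() is reached when a scripted function takes a Python complex scalar, as t_cdoublex_tensory in the test file does. A minimal sketch of such a function (not part of this commit):

import torch

@torch.jit.script
def scale(x: complex, y: torch.Tensor):
    o = torch.mul(x, y)
    return 2 + o

if torch.cuda.is_available():
    out = scale(1.0 + 1.0j, torch.randn(8, dtype=torch.cdouble, device="cuda"))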
11 changes: 7 additions & 4 deletions torch/csrc/jit/codegen/cuda/partition.cpp
@@ -171,10 +171,13 @@ bool compatibleType(const torch::jit::Value* val) {
DataType::Null) {
return false;
}
// Complex is disabled until its support is completely added
// TODO: remove this logic
if (isComplexType(aten_to_data_type(tensor_type->scalarType().value()))) {
return false;
if (!isEnabled(EnableOption::Complex)) {
// Complex is disabled by default until its support is completely added
// TODO: remove this logic
if (isComplexType(
aten_to_data_type(tensor_type->scalarType().value()))) {
return false;
}
}
}
// magic number 8 here since our kernel argument only supports rank <= 8
18 changes: 18 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
@@ -760,34 +760,52 @@ static const char* supported_casts2string(
case supported_switch_pair(DataType::Double, DataType::Float):
case supported_switch_pair(DataType::Bool, DataType::Float):
return "(float)";
case supported_switch_pair(DataType::ComplexFloat, DataType::Float):
case supported_switch_pair(DataType::ComplexDouble, DataType::Float):
return "(float)std::real";
case supported_switch_pair(DataType::Index, DataType::Int):
case supported_switch_pair(DataType::Int32, DataType::Int):
case supported_switch_pair(DataType::Float, DataType::Int):
case supported_switch_pair(DataType::Double, DataType::Int):
case supported_switch_pair(DataType::Bool, DataType::Int):
return "(int64_t)";
case supported_switch_pair(DataType::ComplexFloat, DataType::Int):
case supported_switch_pair(DataType::ComplexDouble, DataType::Int):
return "(int64_t)std::real";
case supported_switch_pair(DataType::Index, DataType::Int32):
case supported_switch_pair(DataType::Int, DataType::Int32):
case supported_switch_pair(DataType::Float, DataType::Int32):
case supported_switch_pair(DataType::Double, DataType::Int32):
case supported_switch_pair(DataType::Bool, DataType::Int32):
return "(int32_t)";
case supported_switch_pair(DataType::ComplexFloat, DataType::Int32):
case supported_switch_pair(DataType::ComplexDouble, DataType::Int32):
return "(int32_t)std::real";
case supported_switch_pair(DataType::Int, DataType::Index):
case supported_switch_pair(DataType::Int32, DataType::Index):
case supported_switch_pair(DataType::Float, DataType::Index):
case supported_switch_pair(DataType::Double, DataType::Index):
return "(nvfuser_index_t)";
case supported_switch_pair(DataType::ComplexFloat, DataType::Index):
case supported_switch_pair(DataType::ComplexDouble, DataType::Index):
return "(nvfuser_index_t)std::real";
case supported_switch_pair(DataType::Index, DataType::Double):
case supported_switch_pair(DataType::Int, DataType::Double):
case supported_switch_pair(DataType::Int32, DataType::Double):
case supported_switch_pair(DataType::Float, DataType::Double):
case supported_switch_pair(DataType::Bool, DataType::Double):
return "(double)";
case supported_switch_pair(DataType::ComplexFloat, DataType::Double):
case supported_switch_pair(DataType::ComplexDouble, DataType::Double):
return "(double)std::real";
case supported_switch_pair(DataType::Float, DataType::Bool):
case supported_switch_pair(DataType::Double, DataType::Bool):
case supported_switch_pair(DataType::Int32, DataType::Bool):
case supported_switch_pair(DataType::Int, DataType::Bool):
return "(bool)";
case supported_switch_pair(DataType::ComplexFloat, DataType::Bool):
case supported_switch_pair(DataType::ComplexDouble, DataType::Bool):
return "(bool)std::real";
case supported_switch_pair(DataType::Index, DataType::ComplexDouble):
case supported_switch_pair(DataType::Int, DataType::ComplexDouble):
case supported_switch_pair(DataType::Int32, DataType::ComplexDouble):
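The new complex-to-real cast strings route through std::real, i.e. the imaginary part is dropped. That mirrors PyTorch's eager behavior when a complex tensor is cast to a real dtype; a small sketch:

import torch

z = torch.tensor([1.0 + 2.0j, 3.0 - 4.0j], dtype=torch.complex64)
print(z.real)                  # tensor([1., 3.])
print(z.to(torch.float32))     # same values; PyTorch warns that the imaginary part is discarded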
33 changes: 33 additions & 0 deletions torch/csrc/jit/codegen/cuda/utils.cpp
@@ -147,6 +147,34 @@ auto parseDisableOptions() {
return options_map;
}

auto parseEnableOptions() {
std::unordered_map<EnableOption, bool> options_map = {
{EnableOption::Complex, false}};

if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_ENABLE")) {
c10::string_view options_view(dump_options);
while (!options_view.empty()) {
const auto end_pos = options_view.find_first_of(',');
const auto token = options_view.substr(0, end_pos);
if (token == "complex") {
options_map[EnableOption::Complex] = true;
} else {
TORCH_CHECK(
false,
"Invalid disable option: '",
token,
"'\nAvailable options:\n",
"\tcomplex");
}
options_view = (end_pos != c10::string_view::npos)
? options_view.substr(end_pos + 1)
: "";
}
}

return options_map;
}

} // namespace

#pragma clang diagnostic push
@@ -240,6 +268,11 @@ bool isDisabled(DisableOption option) {
return options.at(option);
}

bool isEnabled(EnableOption option) {
const static auto options = parseEnableOptions();
return options.at(option);
}

bool useFallback() {
// Keep this env var for compatibility
const char* disable_fb_env = getenv("PYTORCH_NVFUSER_DISABLE_FALLBACK");
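For reference, parseEnableOptions() above treats PYTORCH_NVFUSER_ENABLE as a comma-separated list in which "complex" is currently the only valid token. A rough Python equivalent of that loop (a sketch, not the actual implementation):

import os

def parse_enable_options():
    options = {"complex": False}
    value = os.getenv("PYTORCH_NVFUSER_ENABLE", "")
    for token in filter(None, value.split(",")):
        if token not in options:
            raise ValueError("Invalid enable option: '%s'\nAvailable options:\n\tcomplex" % token)
        options[token] = True
    return options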
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/utils.h
@@ -63,6 +63,16 @@ enum class DisableOption {

TORCH_CUDA_CU_API bool isDisabled(DisableOption option);

//! Types of features to enable
//!
//! These can be set through the `PYTORCH_NVFUSER_ENABLE` environment variable
//!
enum class EnableOption {
Complex, //! Enable complex support in Python
};

TORCH_CUDA_CU_API bool isEnabled(EnableOption option);

// Check if the fallback path should be used, which will dispatch to eager mode if any
// errors are encountered. Helpful for debugging.
bool useFallback();
