diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index d0db3bcba83bb1..4b4362c77fd97b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -68,6 +68,7 @@ cc_library( "passes/restore_function_name.cc", "passes/unfuse_mhlo_batch_norm.cc", "passes/unwrap_xla_call_module_op.cc", + "passes/xla_call_module_to_call.cc", ], hdrs = [ "passes/passes.h", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 63f6f822dbebdf..6e76fe15307f23 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -130,6 +130,13 @@ def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> { ]; } +def XlaCallModuleToCallPass : Pass<"stablehlo-xla-call-module-to-call", "ModuleOp"> { + let summary = "Convert XlaCallModuleOp to func.call op"; + let dependentDialects = [ + "TF::TensorFlowDialect", + ]; +} + def UnwrapXlaCallModuleOpPass : Pass<"stablehlo-unwrap-xla-call-module-op", "ModuleOp"> { let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns."; let dependentDialects = ["TF::TensorFlowDialect"]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index f3cf92dde359d1..8d3290713f2cc7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -98,6 +98,11 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { pm.addPass(createQuantizePass(quantize_options)); pm.addNestedPass(createPostQuantizePass()); + // Convert XlaCallModuleOps lifted but not quantized to func.call op. + // The reasons these ops are not quantized may be: + // 1. Disabled due to selective quantization. + // 2. Not supported, e.g. add op for server. + pm.addPass(createXlaCallModuleToCallPass()); ModuleOp module_op = getOperation(); if (const absl::Status pm_run_status = RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc new file mode 100644 index 00000000000000..123244db3b7dbb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { + +// Converts XlaCallModuleOps to func.call. +class XlaCallModuleToCallPass + : public impl::XlaCallModuleToCallPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass) + + explicit XlaCallModuleToCallPass() = default; + + private: + void runOnOperation() override; +}; + +// Converts XlaCallModuleOps to func.call. +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { + auto module_op = op->getParentOfType(); + SymbolTable symbol_table(module_op); + + auto entry_func_op = dyn_cast_or_null( + symbol_table.lookup(GetEntryFunctionName(op))); + if (!entry_func_op) return failure(); + + // Replace the XlaCallModuleOp with a new CallOp. + rewriter.replaceOpWithNewOp(op, entry_func_op, op.getArgs()); + return success(); + } +}; + +void XlaCallModuleToCallPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(&getContext()); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 1d6c4b3ea219b5..31c2a987209f1d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -1179,6 +1179,126 @@ def test_conv_weight_only_model( 0.35, ) + @parameterized.parameters( + testing.parameter_combinations([{ + 'shape_dynamic': ( + False, + True, + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_add_ptq_model( + self, + shape_dynamic: bool, + ): + input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3) + self._create_add_model( + input_shape, + self._input_saved_model_path, + ) + + # Generate model input data. + rng = np.random.default_rng(seed=42) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ], + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + self.assertEqual( + self._get_num_xla_call_module_op(self._output_saved_model_path), 1 + ) + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Check add is not quantized. + self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str)) + + @parameterized.parameters( + testing.parameter_combinations([{ + 'shape_dynamic': ( + False, + True, + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_add_weight_only_model( + self, + shape_dynamic: bool, + ): + input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3) + self._create_add_model( + input_shape, + self._input_saved_model_path, + ) + + # Generate model input data. + rng = np.random.default_rng(seed=42) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + weight_only_ptq_preset=qc.WeightOnlyPtqPreset(), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + self.assertEqual( + self._get_num_xla_call_module_op(self._output_saved_model_path), 1 + ) + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Check add is not quantized. + self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str), module_str) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index 05d1676114765a..15c141681cb917 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -73,6 +73,20 @@ def _extract_first_xla_call_module_op( return str(stablehlo_module) raise ValueError('No XlaCallModule found in saved model.') + def _get_num_xla_call_module_op(self, output_saved_model_path: str) -> int: + """Gets the number of XlaCallModule ops in the output saved model.""" + root = load.load(output_saved_model_path) + tf_graph_def = root.signatures['serving_default'].graph.as_graph_def() + count = 0 + for node_def in tf_graph_def.node: + if node_def.op == 'XlaCallModule': + count += 1 + for function in tf_graph_def.library.function: + for node_def in function.node_def: + if node_def.op == 'XlaCallModule': + count += 1 + return count + def _create_matmul_model( self, input_shape: Sequence[int], @@ -339,6 +353,42 @@ def __call__( return GatherModel(use_variable) + def _create_add_model( + self, + shape: Sequence[int], + saved_model_path: str, + ) -> module.Module: + class AddModel(module.Module): + """A simple model with a single add.""" + + def __init__(self): + pass + + @def_function.function + def add(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs an add operation. + + Args: + input_tensor: Input tensor to perform add on. + + Returns: + A map of: output key -> output result. + """ + out = math_ops.add(input_tensor, input_tensor) + return {'output': out} + + model = AddModel() + saved_model_save.save( + model, + saved_model_path, + signatures=model.add.get_concrete_function( + tensor_spec.TensorSpec( + shape=shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model + # Prepares sample einsum input data shapes. # This function returns: # 1. Shape for input 1 diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir index 317da0b762e60d..a34ac1a1c65ffc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir @@ -69,11 +69,6 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: } // CHECK-LABEL: func.func @main // CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32> - // CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{.*}}> : tensor<1024x3xf32> -// CHECK: "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) - -// CHECK: func.func private @composite_dot_general_fn_1 -// CHECK-SAME: attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general -// CHECK-SAME: contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> +// CHECK: stablehlo.dot_general %[[ARG_0]], %[[CONST_0]] +// CHECK-NOT: tf.XlaCallModule diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir index eba63a1aacaddc..4fb8e6a9dc6c43 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir @@ -715,7 +715,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. +// Tests that XlaCallModule op is not quantized and converted to func.call without the quantfork.stats ops. module attributes {tf_saved_model.semantics} { func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { @@ -728,8 +728,8 @@ module attributes {tf_saved_model.semantics} { // CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} // CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] +// CHECK: %[[CALL:.+]] = call @composite_dot_general_fn(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir new file mode 100644 index 00000000000000..f0330d0266d56d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/xla_call_module_to_call.mlir @@ -0,0 +1,23 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-xla-call-module-to-call | FileCheck %s + +// ----- + +// Tests composite tf.XlaCallModule is converted to func.call. + +module { + // CHECK-LABEL: func.func @main + func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { + // CHECK: call @composite_dot_general_fn_1 + // CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + // CHECK-NOT: tf.XlaCallModule + %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_fn_1 + // CHECK-SAME: -> tensor<1x3xf32> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 5dada7297e8544..23c801fd912ecf 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -167,10 +167,10 @@ HloComputation::~HloComputation() { CHECK(async_start_->async_wrapped_computation() == this); async_start_->ClearCalledComputations(); } + Cleanup(); for (const auto& i : instructions_) { delete i.inst(); } - Cleanup(); } void HloComputation::SetInstruction(HloInstruction* instruction, @@ -472,10 +472,65 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, info->inst_ = nullptr; // Leave a hole: this is no longer part of "instructions()" instruction_indices_.erase(inst_it); - instruction->index_in_parent_ = ~0u; + DCHECK_EQ(instructions_.size() - to_be_deleted_.size(), + instruction_indices_.size()) + << "instructions_.size(): " << instructions_.size() + << ", to_be_deleted_.size(): " << to_be_deleted_.size(); return OkStatus(); } +void HloComputation::Cleanup() { + if (to_be_deleted_.empty()) return; + + // Given that there are instructions to be deleted, there must be at least one + // instruction not marked for deletion. Otherwise we have deleted *all* + // instructions, which is probably a bug. + DCHECK(!instruction_indices_.empty()); + + // Replacement, i.e. the rightmost "unmarked" (a.k.a. not marked for deletion) + // entry in the vector. + HloInstructionInfo* replacement = &instructions_.back(); + for (HloInstruction* marked_instruction : to_be_deleted_) { + int marked_index = marked_instruction->index_in_parent_; + HloInstructionInfo* marked = &instructions_[marked_index]; + DCHECK(marked->inst() == nullptr); + + delete marked_instruction; + + // Find the first unmarked entry to the left of 'replacement', if needed. + while (replacement >= instructions_.data() && + replacement->inst() == nullptr) { + --replacement; + } + DCHECK_GE(replacement, instructions_.data()); + + // Nothing to do if 'marked' is already to the right of 'replacement'. + if (marked > replacement) continue; + + // Replace the marked entry with the unmarked one. + HloInstruction* unmarked_instruction = replacement->inst(); + int unmarked_index = marked_index; + // Small optimization: instead of std::swap(), just overwrite *marked. This + // requires us to also decrement 'replacement' to avoid reusing the + // unmarked entry we just copied. + *marked = *replacement; + --replacement; + + // Update reverse mapping. + auto it = instruction_indices_.find(unmarked_instruction); + DCHECK(it != instruction_indices_.end()); + it->second = unmarked_index; + unmarked_instruction->index_in_parent_ = unmarked_index; + } + + DCHECK_EQ(instructions_.size() - to_be_deleted_.size(), + instruction_indices_.size()) + << "instructions_.size(): " << instructions_.size() + << ", to_be_deleted_.size(): " << to_be_deleted_.size(); + to_be_deleted_.clear(); + instructions_.resize(instruction_indices_.size()); +} + void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index ec7d4b31c4dccf..1a27a8c4378b12 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -841,16 +841,13 @@ class HloComputation { return execution_thread_ == HloInstruction::kMainExecutionThread; } - // Deallocate instructions that are marked by "RemoveInstruction". The two - // stage clean up process is designed such that HloPass can have stable - // internal pointers to HloInstructions while we create and remove + // Deallocates instructions that are marked by "RemoveInstruction" and + // compacts the instructions_ vector by removing the deleted instructions' + // entries (a.k.a. tombstones). + // This two-stage clean up process is designed such that HloPass can have + // stable internal pointers to HloInstructions while we create and remove // HloInstructions in a pass. - void Cleanup() { - for (HloInstruction* it : to_be_deleted_) { - delete it; - } - to_be_deleted_.clear(); - } + void Cleanup(); // Returns true if a given instruction is marked dead in this computation. bool IsMarkedAsDead(const HloInstruction* inst); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 5a808ce2f6d7c2..115c27989a38e5 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -49,7 +49,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 253 +_version = 254 # Version number for MLIR:Python components. mlir_api_version = 55 diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index af3cc23abd5725..4f69d8d53d2486 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -974,21 +974,65 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComplexLog(op, operand_value); } case HloOpcode::kLog1p: { - // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) - // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) - // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) + // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) + // + // that is accurate only when |a| is relatively small while + // large |a| and |b| lead to multiplication overflow in the real + // part. + // + // The following expression for the real part: + // + // log1p(a+bi).real = log(hypot(a+1, b)) + // = log(max(|a+1|, |b|) * sqrt(1 + (min(|a+1|, |b|) / + // max(|a+1|, b))^2)) [to fix overflow for maximal values + // of |a+1| and |b|] = log(max(|a+1|, |b|)) + log(sqrt(1 + // + (min(|a+1|, |b|) / max(|a+1|, b))^2)) = + // log(max(|a+1|, |b|)) + 0.5*log1p((min(|a+1|, |b|) / + // max(|a+1|, b))^2) [to fix inaccuracies for small a, + // we'll use log1p] = log1p((1 + a > |b| ? a : max(|a+1|, + // |b|) - 1) + 0.5*log1p((min(|a+1|, |b|) / max(|a+1|, + // b))^2) + // + // is accurate on the whole complex plane except when |b| is + // small and a is very close to -|b|^2/2 that leads to + // substraction errors when adding the two log1p values as in + // log1p(-|b|^2) + log1p(|b|^2) + // TODO: improve the accuracy for the case above. + auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto two = llvm::ConstantFP::get(llvm_ty, 2.0); - auto a_plus_one = FAdd(a, one); - auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b)); - TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq)); - TF_ASSIGN_OR_RETURN(auto angle, - EmitAtan2(component_type, b, a_plus_one, "")); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); + auto half = llvm::ConstantFP::get(llvm_ty, 0.5); + + auto a1 = FAdd(a, one); + auto abs_a1 = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a1}, + {llvm_ty}, b_); + auto abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b}, + {llvm_ty}, b_); + + auto max_abs_of_a1_and_b = EmitFloatMax(abs_a1, abs_b, ""); + auto min_abs_of_a1_and_b = EmitFloatMin(abs_a1, abs_b, ""); + + auto max_abs_of_a1_and_b_minus_one = + Select(FCmpOGT(a1, abs_b), a, FSub(max_abs_of_a1_and_b, one)); + auto min_max_ratio = FDiv(min_abs_of_a1_and_b, max_abs_of_a1_and_b); + + TF_ASSIGN_OR_RETURN( + auto log_of_max_abs_of_a1_and_b, + EmitLog1p(component_type, max_abs_of_a1_and_b_minus_one)); + TF_ASSIGN_OR_RETURN( + auto log_of_sqrt_part, + EmitLog1p(component_type, FMul(min_max_ratio, min_max_ratio))); + + auto r = FAdd(FMul(half, log_of_sqrt_part), log_of_max_abs_of_a1_and_b); + auto real_part = Select(FCmpUNO(r, r), min_abs_of_a1_and_b, + r); // handles nan and inf values correctly + + TF_ASSIGN_OR_RETURN(auto imag_part, EmitAtan2(component_type, b, a1, "")); + return EmitComposeComplex(op, real_part, imag_part); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index d7e22827e26ddf..a914b894a93e44 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -700,6 +700,29 @@ xla_test( ], ) +xla_test( + name = "complex_unary_op_test", + srcs = [ + "complex_unary_op_samples.h", + "complex_unary_op_test.cc", + ], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":client_library_test_base", + ":literal_test_util", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:xla_data_proto_cc", + "//xla/client:global_data", + "//xla/client:local_client", + "//xla/client:xla_builder", + "@local_tsl//tsl/platform:test", + ], +) + xla_test( name = "scalar_computations_test", srcs = ["scalar_computations_test.cc"], diff --git a/third_party/xla/xla/tests/complex_unary_op_samples.h b/third_party/xla/xla/tests/complex_unary_op_samples.h new file mode 100644 index 00000000000000..edceb7829c44c5 --- /dev/null +++ b/third_party/xla/xla/tests/complex_unary_op_samples.h @@ -0,0 +1,1448 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + This file is generated using xla/tests/generate_complex_unary_op_samples.py. + Do not edit! + */ + +#include +#include +#include +#include +#include + +#ifndef XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ +#define XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ + +namespace complex_unary_op_samples { + +template +struct Log1p { + typedef std::complex InputType; + typedef std::complex OutputType; + typedef T FloatType; + using TableType = std::vector>; + static constexpr int dps_deficiency = default_dps_deficiency; + const TableType get() { + const T inf = std::numeric_limits::infinity(); + const T min = std::numeric_limits::min(); + const T max = std::numeric_limits::max(); + if constexpr (std::is_same_v) { + const T pi = 3.1415927f; + const T pi_4 = 0.7853982f; + const T pi_2 = 1.5707964f; + const T pi3_4 = 2.3561945f; + const T zero = 0.0f; + const TableType table{ + /* 0 */ {{-inf, -inf}, {inf, -pi3_4}, 1.e+00f}, + /* 1 */ {{-max, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 2 */ {{-6.14096e+25f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 3 */ {{-1.108238e+13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 4 */ {{-2.e+00f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 5 */ {{-3.609332e-13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 6 */ {{-6.513639e-26f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 7 */ {{-min, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 8 */ {{zero, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 9 */ {{min, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 10 */ {{6.513639e-26f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 11 */ {{3.609332e-13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 12 */ {{2.e+00f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 13 */ {{1.108238e+13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 14 */ {{6.14096e+25f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 15 */ {{max, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 16 */ {{inf, -inf}, {inf, -pi_4}, 1.e+00f}, + /* 17 */ {{-inf, -max}, {inf, -pi}, 1.e+00f}, + /* 18 */ {{-max, -max}, {8.906941e+01f, -pi3_4}, 7.8125e-03f}, + /* 19 */ {{-6.14096e+25f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 20 */ + {{-1.108238e+13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 21 */ {{-2.e+00f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 22 */ + {{-3.609332e-13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 23 */ + {{-6.513639e-26f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 24 */ {{-min, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 25 */ {{zero, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 26 */ {{min, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 27 */ {{6.513639e-26f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 28 */ {{3.609332e-13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 29 */ {{2.e+00f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 30 */ {{1.108238e+13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 31 */ {{6.14096e+25f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 32 */ {{max, -max}, {8.906941e+01f, -pi_4}, 7.8125e-03f}, + /* 33 */ {{inf, -max}, {inf, zero}, 1.e+00f}, + /* 34 */ {{-inf, -6.14096e+25f}, {inf, -pi}, 1.e+00f}, + /* 35 */ {{-max, -6.14096e+25f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 36 */ + {{-6.14096e+25f, -6.14096e+25f}, + {5.972618e+01f, -pi3_4}, + 1.5625e-02f}, + /* 37 */ + {{-1.108238e+13f, -6.14096e+25f}, + {5.937961e+01f, -pi_2}, + 1.5625e-02f}, + /* 38 */ + {{-2.e+00f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 39 */ + {{-3.609332e-13f, -6.14096e+25f}, + {5.937961e+01f, -pi_2}, + 1.5625e-02f}, + /* 40 */ + {{-6.513639e-26f, -6.14096e+25f}, + {5.937961e+01f, -pi_2}, + 1.5625e-02f}, + /* 41 */ {{-min, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 42 */ {{zero, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 43 */ {{min, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 44 */ + {{6.513639e-26f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 45 */ + {{3.609332e-13f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 46 */ + {{2.e+00f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 47 */ + {{1.108238e+13f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 48 */ + {{6.14096e+25f, -6.14096e+25f}, {5.972618e+01f, -pi_4}, 1.5625e-02f}, + /* 49 */ + {{max, -6.14096e+25f}, {8.872284e+01f, -1.804666e-13f}, 7.8125e-03f}, + /* 50 */ {{inf, -6.14096e+25f}, {inf, zero}, 1.e+00f}, + /* 51 */ {{-inf, -1.108238e+13f}, {inf, -pi}, 1.e+00f}, + /* 52 */ {{-max, -1.108238e+13f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 53 */ + {{-6.14096e+25f, -1.108238e+13f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 54 */ + {{-1.108238e+13f, -1.108238e+13f}, + {3.038295e+01f, -pi3_4}, + 3.125e-02f}, + /* 55 */ + {{-2.e+00f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 56 */ + {{-3.609332e-13f, -1.108238e+13f}, + {3.003638e+01f, -pi_2}, + 3.125e-02f}, + /* 57 */ + {{-6.513639e-26f, -1.108238e+13f}, + {3.003638e+01f, -pi_2}, + 3.125e-02f}, + /* 58 */ {{-min, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 59 */ {{zero, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 60 */ {{min, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 61 */ + {{6.513639e-26f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 62 */ + {{3.609332e-13f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 63 */ + {{2.e+00f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 64 */ + {{1.108238e+13f, -1.108238e+13f}, {3.038295e+01f, -pi_4}, 3.125e-02f}, + /* 65 */ + {{6.14096e+25f, -1.108238e+13f}, + {5.937961e+01f, -1.804666e-13f}, + 1.5625e-02f}, + /* 66 */ + {{max, -1.108238e+13f}, {8.872284e+01f, -3.25682e-26f}, 7.8125e-03f}, + /* 67 */ {{inf, -1.108238e+13f}, {inf, zero}, 1.e+00f}, + /* 68 */ {{-inf, -2.e+00f}, {inf, -pi}, 1.e+00f}, + /* 69 */ {{-max, -2.e+00f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 70 */ + {{-6.14096e+25f, -2.e+00f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 71 */ + {{-1.108238e+13f, -2.e+00f}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 72 */ + {{-2.e+00f, -2.e+00f}, {8.04719e-01f, -2.034444e+00f}, 2.5e-01f}, + /* 73 */ + {{-3.609332e-13f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 74 */ + {{-6.513639e-26f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 75 */ {{-min, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 76 */ {{zero, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 77 */ {{min, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 78 */ + {{6.513639e-26f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 79 */ + {{3.609332e-13f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 80 */ + {{2.e+00f, -2.e+00f}, {1.282475e+00f, -5.880026e-01f}, 5.e-01f}, + /* 81 */ + {{1.108238e+13f, -2.e+00f}, + {3.003638e+01f, -1.804666e-13f}, + 3.125e-02f}, + /* 82 */ + {{6.14096e+25f, -2.e+00f}, + {5.937961e+01f, -3.25682e-26f}, + 1.5625e-02f}, + /* 83 */ + {{max, -2.e+00f}, {8.872284e+01f, -5.877472e-39f}, 7.8125e-03f}, + /* 84 */ {{inf, -2.e+00f}, {inf, zero}, 1.e+00f}, + /* 85 */ {{-inf, -3.609332e-13f}, {inf, -pi}, 1.e+00f}, + /* 86 */ {{-max, -3.609332e-13f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 87 */ + {{-6.14096e+25f, -3.609332e-13f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 88 */ + {{-1.108238e+13f, -3.609332e-13f}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 89 */ {{-2.e+00f, -3.609332e-13f}, {6.513639e-26f, -pi}, 2.5e-01f}, + /* 90 */ + {{-3.609332e-13f, -3.609332e-13f}, + {-3.609332e-13f, -3.609332e-13f}, + 1.099512e+12f}, + /* 91 */ + {{-6.513639e-26f, -3.609332e-13f}, + {-2.843711e-33f, -3.609332e-13f}, + 2.199023e+12f}, + /* 92 */ + {{-min, -3.609332e-13f}, + {6.513639e-26f, -3.609332e-13f}, + 2.199023e+12f}, + /* 93 */ + {{zero, -3.609332e-13f}, + {6.513639e-26f, -3.609332e-13f}, + 2.199023e+12f}, + /* 94 */ + {{min, -3.609332e-13f}, + {6.513639e-26f, -3.609332e-13f}, + 2.199023e+12f}, + /* 95 */ + {{6.513639e-26f, -3.609332e-13f}, + {1.302728e-25f, -3.609332e-13f}, + 2.199023e+12f}, + /* 96 */ + {{3.609332e-13f, -3.609332e-13f}, + {3.609332e-13f, -3.609332e-13f}, + 1.099512e+12f}, + /* 97 */ + {{2.e+00f, -3.609332e-13f}, {1.098612e+00f, -1.203111e-13f}, 5.e-01f}, + /* 98 */ + {{1.108238e+13f, -3.609332e-13f}, + {3.003638e+01f, -3.256819e-26f}, + 3.125e-02f}, + /* 99 */ + {{6.14096e+25f, -3.609332e-13f}, + {5.937961e+01f, -5.877472e-39f}, + 1.5625e-02f}, + /* 100 */ {{max, -3.609332e-13f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 101 */ {{inf, -3.609332e-13f}, {inf, zero}, 1.e+00f}, + /* 102 */ {{-inf, -6.513639e-26f}, {inf, -pi}, 1.e+00f}, + /* 103 */ {{-max, -6.513639e-26f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 104 */ + {{-6.14096e+25f, -6.513639e-26f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 105 */ + {{-1.108238e+13f, -6.513639e-26f}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 106 */ {{-2.e+00f, -6.513639e-26f}, {zero, -pi}, 2.5e-01f}, + /* 107 */ + {{-3.609332e-13f, -6.513639e-26f}, + {-3.609332e-13f, -6.513639e-26f}, + 2.199023e+12f}, + /* 108 */ + {{-6.513639e-26f, -6.513639e-26f}, + {-6.513639e-26f, -6.513639e-26f}, + 9.671407e+24f}, + /* 109 */ + {{-min, -6.513639e-26f}, {-min, -6.513639e-26f}, 9.671407e+24f}, + /* 110 */ + {{zero, -6.513639e-26f}, {zero, -6.513639e-26f}, 9.671407e+24f}, + /* 111 */ + {{min, -6.513639e-26f}, {min, -6.513639e-26f}, 9.671407e+24f}, + /* 112 */ + {{6.513639e-26f, -6.513639e-26f}, + {6.513639e-26f, -6.513639e-26f}, + 9.671407e+24f}, + /* 113 */ + {{3.609332e-13f, -6.513639e-26f}, + {3.609332e-13f, -6.513639e-26f}, + 2.199023e+12f}, + /* 114 */ + {{2.e+00f, -6.513639e-26f}, {1.098612e+00f, -2.171213e-26f}, 5.e-01f}, + /* 115 */ + {{1.108238e+13f, -6.513639e-26f}, + {3.003638e+01f, -5.877472e-39f}, + 3.125e-02f}, + /* 116 */ + {{6.14096e+25f, -6.513639e-26f}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 117 */ {{max, -6.513639e-26f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 118 */ {{inf, -6.513639e-26f}, {inf, zero}, 1.e+00f}, + /* 119 */ {{-inf, -min}, {inf, -pi}, 1.e+00f}, + /* 120 */ {{-max, -min}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 121 */ {{-6.14096e+25f, -min}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 122 */ {{-1.108238e+13f, -min}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 123 */ {{-2.e+00f, -min}, {zero, -pi}, 2.5e-01f}, + /* 124 */ + {{-3.609332e-13f, -min}, {-3.609332e-13f, -min}, 2.199023e+12f}, + /* 125 */ + {{-6.513639e-26f, -min}, {-6.513639e-26f, -min}, 9.671407e+24f}, + /* 126 */ {{-min, -min}, {-min, -min}, 4.25353e+37f}, + /* 127 */ {{zero, -min}, {zero, -min}, 4.25353e+37f}, + /* 128 */ {{min, -min}, {min, -min}, 4.25353e+37f}, + /* 129 */ + {{6.513639e-26f, -min}, {6.513639e-26f, -min}, 9.671407e+24f}, + /* 130 */ + {{3.609332e-13f, -min}, {3.609332e-13f, -min}, 2.199023e+12f}, + /* 131 */ {{2.e+00f, -min}, {1.098612e+00f, zero}, 5.e-01f}, + /* 132 */ {{1.108238e+13f, -min}, {3.003638e+01f, zero}, 3.125e-02f}, + /* 133 */ {{6.14096e+25f, -min}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 134 */ {{max, -min}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 135 */ {{inf, -min}, {inf, zero}, 1.e+00f}, + /* 136 */ {{-inf, zero}, {inf, pi}, 1.e+00f}, + /* 137 */ {{-max, zero}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 138 */ {{-6.14096e+25f, zero}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 139 */ {{-1.108238e+13f, zero}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 140 */ {{-2.e+00f, zero}, {zero, pi}, 2.5e-01f}, + /* 141 */ + {{-3.609332e-13f, zero}, {-3.609332e-13f, zero}, 2.199023e+12f}, + /* 142 */ + {{-6.513639e-26f, zero}, {-6.513639e-26f, zero}, 9.671407e+24f}, + /* 143 */ {{-min, zero}, {-min, zero}, 4.25353e+37f}, + /* 144 */ {{zero, zero}, {zero, zero}, 1.e+00f}, + /* 145 */ {{min, zero}, {min, zero}, 4.25353e+37f}, + /* 146 */ + {{6.513639e-26f, zero}, {6.513639e-26f, zero}, 9.671407e+24f}, + /* 147 */ + {{3.609332e-13f, zero}, {3.609332e-13f, zero}, 2.199023e+12f}, + /* 148 */ {{2.e+00f, zero}, {1.098612e+00f, zero}, 5.e-01f}, + /* 149 */ {{1.108238e+13f, zero}, {3.003638e+01f, zero}, 3.125e-02f}, + /* 150 */ {{6.14096e+25f, zero}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 151 */ {{max, zero}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 152 */ {{inf, zero}, {inf, zero}, 1.e+00f}, + /* 153 */ {{-inf, min}, {inf, pi}, 1.e+00f}, + /* 154 */ {{-max, min}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 155 */ {{-6.14096e+25f, min}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 156 */ {{-1.108238e+13f, min}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 157 */ {{-2.e+00f, min}, {zero, pi}, 2.5e-01f}, + /* 158 */ + {{-3.609332e-13f, min}, {-3.609332e-13f, min}, 2.199023e+12f}, + /* 159 */ + {{-6.513639e-26f, min}, {-6.513639e-26f, min}, 9.671407e+24f}, + /* 160 */ {{-min, min}, {-min, min}, 4.25353e+37f}, + /* 161 */ {{zero, min}, {zero, min}, 4.25353e+37f}, + /* 162 */ {{min, min}, {min, min}, 4.25353e+37f}, + /* 163 */ {{6.513639e-26f, min}, {6.513639e-26f, min}, 9.671407e+24f}, + /* 164 */ {{3.609332e-13f, min}, {3.609332e-13f, min}, 2.199023e+12f}, + /* 165 */ {{2.e+00f, min}, {1.098612e+00f, zero}, 5.e-01f}, + /* 166 */ {{1.108238e+13f, min}, {3.003638e+01f, zero}, 3.125e-02f}, + /* 167 */ {{6.14096e+25f, min}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 168 */ {{max, min}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 169 */ {{inf, min}, {inf, zero}, 1.e+00f}, + /* 170 */ {{-inf, 6.513639e-26f}, {inf, pi}, 1.e+00f}, + /* 171 */ {{-max, 6.513639e-26f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 172 */ + {{-6.14096e+25f, 6.513639e-26f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 173 */ + {{-1.108238e+13f, 6.513639e-26f}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 174 */ {{-2.e+00f, 6.513639e-26f}, {zero, pi}, 2.5e-01f}, + /* 175 */ + {{-3.609332e-13f, 6.513639e-26f}, + {-3.609332e-13f, 6.513639e-26f}, + 2.199023e+12f}, + /* 176 */ + {{-6.513639e-26f, 6.513639e-26f}, + {-6.513639e-26f, 6.513639e-26f}, + 9.671407e+24f}, + /* 177 */ + {{-min, 6.513639e-26f}, {-min, 6.513639e-26f}, 9.671407e+24f}, + /* 178 */ + {{zero, 6.513639e-26f}, {zero, 6.513639e-26f}, 9.671407e+24f}, + /* 179 */ {{min, 6.513639e-26f}, {min, 6.513639e-26f}, 9.671407e+24f}, + /* 180 */ + {{6.513639e-26f, 6.513639e-26f}, + {6.513639e-26f, 6.513639e-26f}, + 9.671407e+24f}, + /* 181 */ + {{3.609332e-13f, 6.513639e-26f}, + {3.609332e-13f, 6.513639e-26f}, + 2.199023e+12f}, + /* 182 */ + {{2.e+00f, 6.513639e-26f}, {1.098612e+00f, 2.171213e-26f}, 5.e-01f}, + /* 183 */ + {{1.108238e+13f, 6.513639e-26f}, + {3.003638e+01f, 5.877472e-39f}, + 3.125e-02f}, + /* 184 */ + {{6.14096e+25f, 6.513639e-26f}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 185 */ {{max, 6.513639e-26f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 186 */ {{inf, 6.513639e-26f}, {inf, zero}, 1.e+00f}, + /* 187 */ {{-inf, 3.609332e-13f}, {inf, pi}, 1.e+00f}, + /* 188 */ {{-max, 3.609332e-13f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 189 */ + {{-6.14096e+25f, 3.609332e-13f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 190 */ + {{-1.108238e+13f, 3.609332e-13f}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 191 */ {{-2.e+00f, 3.609332e-13f}, {6.513639e-26f, pi}, 2.5e-01f}, + /* 192 */ + {{-3.609332e-13f, 3.609332e-13f}, + {-3.609332e-13f, 3.609332e-13f}, + 1.099512e+12f}, + /* 193 */ + {{-6.513639e-26f, 3.609332e-13f}, + {-2.843711e-33f, 3.609332e-13f}, + 2.199023e+12f}, + /* 194 */ + {{-min, 3.609332e-13f}, + {6.513639e-26f, 3.609332e-13f}, + 2.199023e+12f}, + /* 195 */ + {{zero, 3.609332e-13f}, + {6.513639e-26f, 3.609332e-13f}, + 2.199023e+12f}, + /* 196 */ + {{min, 3.609332e-13f}, {6.513639e-26f, 3.609332e-13f}, 2.199023e+12f}, + /* 197 */ + {{6.513639e-26f, 3.609332e-13f}, + {1.302728e-25f, 3.609332e-13f}, + 2.199023e+12f}, + /* 198 */ + {{3.609332e-13f, 3.609332e-13f}, + {3.609332e-13f, 3.609332e-13f}, + 1.099512e+12f}, + /* 199 */ + {{2.e+00f, 3.609332e-13f}, {1.098612e+00f, 1.203111e-13f}, 5.e-01f}, + /* 200 */ + {{1.108238e+13f, 3.609332e-13f}, + {3.003638e+01f, 3.256819e-26f}, + 3.125e-02f}, + /* 201 */ + {{6.14096e+25f, 3.609332e-13f}, + {5.937961e+01f, 5.877472e-39f}, + 1.5625e-02f}, + /* 202 */ {{max, 3.609332e-13f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 203 */ {{inf, 3.609332e-13f}, {inf, zero}, 1.e+00f}, + /* 204 */ {{-inf, 2.e+00f}, {inf, pi}, 1.e+00f}, + /* 205 */ {{-max, 2.e+00f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 206 */ + {{-6.14096e+25f, 2.e+00f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 207 */ + {{-1.108238e+13f, 2.e+00f}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 208 */ + {{-2.e+00f, 2.e+00f}, {8.04719e-01f, 2.034444e+00f}, 2.5e-01f}, + /* 209 */ + {{-3.609332e-13f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 210 */ + {{-6.513639e-26f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 211 */ {{-min, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 212 */ {{zero, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 213 */ {{min, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 214 */ + {{6.513639e-26f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 215 */ + {{3.609332e-13f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 216 */ + {{2.e+00f, 2.e+00f}, {1.282475e+00f, 5.880026e-01f}, 5.e-01f}, + /* 217 */ + {{1.108238e+13f, 2.e+00f}, + {3.003638e+01f, 1.804666e-13f}, + 3.125e-02f}, + /* 218 */ + {{6.14096e+25f, 2.e+00f}, {5.937961e+01f, 3.25682e-26f}, 1.5625e-02f}, + /* 219 */ + {{max, 2.e+00f}, {8.872284e+01f, 5.877472e-39f}, 7.8125e-03f}, + /* 220 */ {{inf, 2.e+00f}, {inf, zero}, 1.e+00f}, + /* 221 */ {{-inf, 1.108238e+13f}, {inf, pi}, 1.e+00f}, + /* 222 */ {{-max, 1.108238e+13f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 223 */ + {{-6.14096e+25f, 1.108238e+13f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 224 */ + {{-1.108238e+13f, 1.108238e+13f}, {3.038295e+01f, pi3_4}, 3.125e-02f}, + /* 225 */ + {{-2.e+00f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 226 */ + {{-3.609332e-13f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 227 */ + {{-6.513639e-26f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 228 */ {{-min, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 229 */ {{zero, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 230 */ {{min, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 231 */ + {{6.513639e-26f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 232 */ + {{3.609332e-13f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 233 */ + {{2.e+00f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 234 */ + {{1.108238e+13f, 1.108238e+13f}, {3.038295e+01f, pi_4}, 3.125e-02f}, + /* 235 */ + {{6.14096e+25f, 1.108238e+13f}, + {5.937961e+01f, 1.804666e-13f}, + 1.5625e-02f}, + /* 236 */ + {{max, 1.108238e+13f}, {8.872284e+01f, 3.25682e-26f}, 7.8125e-03f}, + /* 237 */ {{inf, 1.108238e+13f}, {inf, zero}, 1.e+00f}, + /* 238 */ {{-inf, 6.14096e+25f}, {inf, pi}, 1.e+00f}, + /* 239 */ {{-max, 6.14096e+25f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 240 */ + {{-6.14096e+25f, 6.14096e+25f}, {5.972618e+01f, pi3_4}, 1.5625e-02f}, + /* 241 */ + {{-1.108238e+13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 242 */ + {{-2.e+00f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 243 */ + {{-3.609332e-13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 244 */ + {{-6.513639e-26f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 245 */ {{-min, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 246 */ {{zero, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 247 */ {{min, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 248 */ + {{6.513639e-26f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 249 */ + {{3.609332e-13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 250 */ + {{2.e+00f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 251 */ + {{1.108238e+13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 252 */ + {{6.14096e+25f, 6.14096e+25f}, {5.972618e+01f, pi_4}, 1.5625e-02f}, + /* 253 */ + {{max, 6.14096e+25f}, {8.872284e+01f, 1.804666e-13f}, 7.8125e-03f}, + /* 254 */ {{inf, 6.14096e+25f}, {inf, zero}, 1.e+00f}, + /* 255 */ {{-inf, max}, {inf, pi}, 1.e+00f}, + /* 256 */ {{-max, max}, {8.906941e+01f, pi3_4}, 7.8125e-03f}, + /* 257 */ {{-6.14096e+25f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 258 */ {{-1.108238e+13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 259 */ {{-2.e+00f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 260 */ {{-3.609332e-13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 261 */ {{-6.513639e-26f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 262 */ {{-min, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 263 */ {{zero, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 264 */ {{min, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 265 */ {{6.513639e-26f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 266 */ {{3.609332e-13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 267 */ {{2.e+00f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 268 */ {{1.108238e+13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 269 */ {{6.14096e+25f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 270 */ {{max, max}, {8.906941e+01f, pi_4}, 7.8125e-03f}, + /* 271 */ {{inf, max}, {inf, zero}, 1.e+00f}, + /* 272 */ {{-inf, inf}, {inf, pi3_4}, 1.e+00f}, + /* 273 */ {{-max, inf}, {inf, pi_2}, 1.e+00f}, + /* 274 */ {{-6.14096e+25f, inf}, {inf, pi_2}, 1.e+00f}, + /* 275 */ {{-1.108238e+13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 276 */ {{-2.e+00f, inf}, {inf, pi_2}, 1.e+00f}, + /* 277 */ {{-3.609332e-13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 278 */ {{-6.513639e-26f, inf}, {inf, pi_2}, 1.e+00f}, + /* 279 */ {{-min, inf}, {inf, pi_2}, 1.e+00f}, + /* 280 */ {{zero, inf}, {inf, pi_2}, 1.e+00f}, + /* 281 */ {{min, inf}, {inf, pi_2}, 1.e+00f}, + /* 282 */ {{6.513639e-26f, inf}, {inf, pi_2}, 1.e+00f}, + /* 283 */ {{3.609332e-13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 284 */ {{2.e+00f, inf}, {inf, pi_2}, 1.e+00f}, + /* 285 */ {{1.108238e+13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 286 */ {{6.14096e+25f, inf}, {inf, pi_2}, 1.e+00f}, + /* 287 */ {{max, inf}, {inf, pi_2}, 1.e+00f}, + /* 288 */ {{inf, inf}, {inf, pi_4}, 1.e+00f}}; + return table; + } else if constexpr (std::is_same_v) { + const T pi = 3.141592653589793; + const T pi_4 = 0.7853981633974483; + const T pi_2 = 1.5707963267948966; + const T pi3_4 = 2.356194490192345; + const T zero = 0.0; + const TableType table{ + /* 0 */ {{-inf, -inf}, {inf, -pi3_4}, 1.e+00}, + /* 1 */ {{-max, -inf}, {inf, -pi_2}, 1.e+00}, + /* 2 */ {{-4.013165208090075e+205, -inf}, {inf, -pi_2}, 1.e+00}, + /* 3 */ {{-8.958978968710456e+102, -inf}, {inf, -pi_2}, 1.e+00}, + /* 4 */ {{-1.999999999999869e+00, -inf}, {inf, -pi_2}, 1.e+00}, + /* 5 */ {{-4.464794497196183e-103, -inf}, {inf, -pi_2}, 1.e+00}, + /* 6 */ {{-9.967194951097309e-206, -inf}, {inf, -pi_2}, 1.e+00}, + /* 7 */ {{-min, -inf}, {inf, -pi_2}, 1.e+00}, + /* 8 */ {{zero, -inf}, {inf, -pi_2}, 1.e+00}, + /* 9 */ {{min, -inf}, {inf, -pi_2}, 1.e+00}, + /* 10 */ {{9.967194951097309e-206, -inf}, {inf, -pi_2}, 1.e+00}, + /* 11 */ {{4.464794497196183e-103, -inf}, {inf, -pi_2}, 1.e+00}, + /* 12 */ {{1.999999999999869e+00, -inf}, {inf, -pi_2}, 1.e+00}, + /* 13 */ {{8.958978968710456e+102, -inf}, {inf, -pi_2}, 1.e+00}, + /* 14 */ {{4.013165208090075e+205, -inf}, {inf, -pi_2}, 1.e+00}, + /* 15 */ {{max, -inf}, {inf, -pi_2}, 1.e+00}, + /* 16 */ {{inf, -inf}, {inf, -pi_4}, 1.e+00}, + /* 17 */ {{-inf, -max}, {inf, -pi}, 1.e+00}, + /* 18 */ + {{-max, -max}, {7.101292864836639e+02, -pi3_4}, 9.765625e-04}, + /* 19 */ + {{-4.013165208090075e+205, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 20 */ + {{-8.958978968710456e+102, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 21 */ + {{-1.999999999999869e+00, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 22 */ + {{-4.464794497196183e-103, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 23 */ + {{-9.967194951097309e-206, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 24 */ {{-min, -max}, {7.09782712893384e+02, -pi_2}, 9.765625e-04}, + /* 25 */ {{zero, -max}, {7.09782712893384e+02, -pi_2}, 9.765625e-04}, + /* 26 */ {{min, -max}, {7.09782712893384e+02, -pi_2}, 9.765625e-04}, + /* 27 */ + {{9.967194951097309e-206, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 28 */ + {{4.464794497196183e-103, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 29 */ + {{1.999999999999869e+00, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 30 */ + {{8.958978968710456e+102, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 31 */ + {{4.013165208090075e+205, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 32 */ {{max, -max}, {7.101292864836639e+02, -pi_4}, 9.765625e-04}, + /* 33 */ {{inf, -max}, {inf, zero}, 1.e+00}, + /* 34 */ {{-inf, -4.013165208090075e+205}, {inf, -pi}, 1.e+00}, + /* 35 */ + {{-max, -4.013165208090075e+205}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 36 */ + {{-4.013165208090075e+205, -4.013165208090075e+205}, + {4.737660979127225e+02, -pi3_4}, + 1.953125e-03}, + /* 37 */ + {{-8.958978968710456e+102, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 38 */ + {{-1.999999999999869e+00, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 39 */ + {{-4.464794497196183e-103, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 40 */ + {{-9.967194951097309e-206, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 41 */ + {{-min, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 42 */ + {{zero, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 43 */ + {{min, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 44 */ + {{9.967194951097309e-206, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 45 */ + {{4.464794497196183e-103, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 46 */ + {{1.999999999999869e+00, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 47 */ + {{8.958978968710456e+102, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 48 */ + {{4.013165208090075e+205, -4.013165208090075e+205}, + {4.737660979127225e+02, -pi_4}, + 1.953125e-03}, + /* 49 */ + {{max, -4.013165208090075e+205}, + {7.09782712893384e+02, -2.23239724859796e-103}, + 9.765625e-04}, + /* 50 */ {{inf, -4.013165208090075e+205}, {inf, zero}, 1.e+00}, + /* 51 */ {{-inf, -8.958978968710456e+102}, {inf, -pi}, 1.e+00}, + /* 52 */ + {{-max, -8.958978968710456e+102}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 53 */ + {{-4.013165208090075e+205, -8.958978968710456e+102}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 54 */ + {{-8.958978968710456e+102, -8.958978968710456e+102}, + {2.374029093417812e+02, -pi3_4}, + 3.90625e-03}, + /* 55 */ + {{-1.999999999999869e+00, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 56 */ + {{-4.464794497196183e-103, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 57 */ + {{-9.967194951097309e-206, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 58 */ + {{-min, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 59 */ + {{zero, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 60 */ + {{min, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 61 */ + {{9.967194951097309e-206, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 62 */ + {{4.464794497196183e-103, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 63 */ + {{1.999999999999869e+00, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 64 */ + {{8.958978968710456e+102, -8.958978968710456e+102}, + {2.374029093417812e+02, -pi_4}, + 3.90625e-03}, + /* 65 */ + {{4.013165208090075e+205, -8.958978968710456e+102}, + {4.734195243224426e+02, -2.232397248598237e-103}, + 1.953125e-03}, + /* 66 */ + {{max, -8.958978968710456e+102}, + {7.09782712893384e+02, -4.983597475548361e-206}, + 9.765625e-04}, + /* 67 */ {{inf, -8.958978968710456e+102}, {inf, zero}, 1.e+00}, + /* 68 */ {{-inf, -1.999999999999869e+00}, {inf, -pi}, 1.e+00}, + /* 69 */ + {{-max, -1.999999999999869e+00}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 70 */ + {{-4.013165208090075e+205, -1.999999999999869e+00}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 71 */ + {{-8.958978968710456e+102, -1.999999999999869e+00}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 72 */ + {{-1.999999999999869e+00, -1.999999999999869e+00}, + {8.047189562169719e-01, -2.034443935795677e+00}, + 2.5e-01}, + /* 73 */ + {{-4.464794497196183e-103, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 74 */ + {{-9.967194951097309e-206, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 75 */ + {{-min, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 76 */ + {{zero, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 77 */ + {{min, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 78 */ + {{9.967194951097309e-206, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 79 */ + {{4.464794497196183e-103, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 80 */ + {{1.999999999999869e+00, -1.999999999999869e+00}, + {1.282474678730718e+00, -5.880026035475575e-01}, + 5.e-01}, + /* 81 */ + {{8.958978968710456e+102, -1.999999999999869e+00}, + {2.370563357515012e+02, -2.232397248598237e-103}, + 3.90625e-03}, + /* 82 */ + {{4.013165208090075e+205, -1.999999999999869e+00}, + {4.734195243224426e+02, -4.98359747554898e-206}, + 1.953125e-03}, + /* 83 */ + {{max, -1.999999999999869e+00}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 84 */ {{inf, -1.999999999999869e+00}, {inf, zero}, 1.e+00}, + /* 85 */ {{-inf, -4.464794497196183e-103}, {inf, -pi}, 1.e+00}, + /* 86 */ + {{-max, -4.464794497196183e-103}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 87 */ + {{-4.013165208090075e+205, -4.464794497196183e-103}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 88 */ + {{-8.958978968710456e+102, -4.464794497196183e-103}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 89 */ + {{-1.999999999999869e+00, -4.464794497196183e-103}, + {-1.305622276959269e-13, -pi}, + 2.5e-01}, + /* 90 */ + {{-4.464794497196183e-103, -4.464794497196183e-103}, + {-4.464794497196183e-103, -4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 91 */ + {{-9.967194951097309e-206, -4.464794497196183e-103}, + {-6.506695883473837e-219, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 92 */ + {{-min, -4.464794497196183e-103}, + {9.967194951096658e-206, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 93 */ + {{zero, -4.464794497196183e-103}, + {9.967194951096658e-206, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 94 */ + {{min, -4.464794497196183e-103}, + {9.967194951096658e-206, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 95 */ + {{9.967194951097309e-206, -4.464794497196183e-103}, + {1.993438990219397e-205, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 96 */ + {{4.464794497196183e-103, -4.464794497196183e-103}, + {4.464794497196183e-103, -4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 97 */ + {{1.999999999999869e+00, -4.464794497196183e-103}, + {1.098612288668066e+00, -1.488264832398792e-103}, + 5.e-01}, + /* 98 */ + {{8.958978968710456e+102, -4.464794497196183e-103}, + {2.370563357515012e+02, -4.98359747554898e-206}, + 3.90625e-03}, + /* 99 */ + {{4.013165208090075e+205, -4.464794497196183e-103}, + {4.734195243224426e+02, -1.112536929253666e-308}, + 1.953125e-03}, + /* 100 */ + {{max, -4.464794497196183e-103}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 101 */ {{inf, -4.464794497196183e-103}, {inf, zero}, 1.e+00}, + /* 102 */ {{-inf, -9.967194951097309e-206}, {inf, -pi}, 1.e+00}, + /* 103 */ + {{-max, -9.967194951097309e-206}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 104 */ + {{-4.013165208090075e+205, -9.967194951097309e-206}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 105 */ + {{-8.958978968710456e+102, -9.967194951097309e-206}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 106 */ + {{-1.999999999999869e+00, -9.967194951097309e-206}, + {-1.305622276959269e-13, -pi}, + 2.5e-01}, + /* 107 */ + {{-4.464794497196183e-103, -9.967194951097309e-206}, + {-4.464794497196183e-103, -9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 108 */ + {{-9.967194951097309e-206, -9.967194951097309e-206}, + {-9.967194951097309e-206, -9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 109 */ + {{-min, -9.967194951097309e-206}, + {-min, -9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 110 */ + {{zero, -9.967194951097309e-206}, + {zero, -9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 111 */ + {{min, -9.967194951097309e-206}, + {min, -9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 112 */ + {{9.967194951097309e-206, -9.967194951097309e-206}, + {9.967194951097309e-206, -9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 113 */ + {{4.464794497196183e-103, -9.967194951097309e-206}, + {4.464794497196183e-103, -9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 114 */ + {{1.999999999999869e+00, -9.967194951097309e-206}, + {1.098612288668066e+00, -3.322398317032581e-206}, + 5.e-01}, + /* 115 */ + {{8.958978968710456e+102, -9.967194951097309e-206}, + {2.370563357515012e+02, -1.112536929253666e-308}, + 3.90625e-03}, + /* 116 */ + {{4.013165208090075e+205, -9.967194951097309e-206}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 117 */ + {{max, -9.967194951097309e-206}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 118 */ {{inf, -9.967194951097309e-206}, {inf, zero}, 1.e+00}, + /* 119 */ {{-inf, -min}, {inf, -pi}, 1.e+00}, + /* 120 */ {{-max, -min}, {7.09782712893384e+02, -pi}, 9.765625e-04}, + /* 121 */ + {{-4.013165208090075e+205, -min}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 122 */ + {{-8.958978968710456e+102, -min}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 123 */ + {{-1.999999999999869e+00, -min}, + {-1.305622276959269e-13, -pi}, + 2.5e-01}, + /* 124 */ + {{-4.464794497196183e-103, -min}, + {-4.464794497196183e-103, -min}, + 2.239744742177804e+102}, + /* 125 */ + {{-9.967194951097309e-206, -min}, + {-9.967194951097309e-206, -min}, + 1.003291302022624e+205}, + /* 126 */ {{-min, -min}, {-min, -min}, 2.247116418577895e+307}, + /* 127 */ {{zero, -min}, {zero, -min}, 2.247116418577895e+307}, + /* 128 */ {{min, -min}, {min, -min}, 2.247116418577895e+307}, + /* 129 */ + {{9.967194951097309e-206, -min}, + {9.967194951097309e-206, -min}, + 1.003291302022624e+205}, + /* 130 */ + {{4.464794497196183e-103, -min}, + {4.464794497196183e-103, -min}, + 2.239744742177804e+102}, + /* 131 */ + {{1.999999999999869e+00, -min}, + {1.098612288668066e+00, zero}, + 5.e-01}, + /* 132 */ + {{8.958978968710456e+102, -min}, + {2.370563357515012e+02, zero}, + 3.90625e-03}, + /* 133 */ + {{4.013165208090075e+205, -min}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 134 */ {{max, -min}, {7.09782712893384e+02, zero}, 9.765625e-04}, + /* 135 */ {{inf, -min}, {inf, zero}, 1.e+00}, + /* 136 */ {{-inf, zero}, {inf, pi}, 1.e+00}, + /* 137 */ {{-max, zero}, {7.09782712893384e+02, pi}, 9.765625e-04}, + /* 138 */ + {{-4.013165208090075e+205, zero}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 139 */ + {{-8.958978968710456e+102, zero}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 140 */ + {{-1.999999999999869e+00, zero}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 141 */ + {{-4.464794497196183e-103, zero}, + {-4.464794497196183e-103, zero}, + 2.239744742177804e+102}, + /* 142 */ + {{-9.967194951097309e-206, zero}, + {-9.967194951097309e-206, zero}, + 1.003291302022624e+205}, + /* 143 */ {{-min, zero}, {-min, zero}, 2.247116418577895e+307}, + /* 144 */ {{zero, zero}, {zero, zero}, 1.e+00}, + /* 145 */ {{min, zero}, {min, zero}, 2.247116418577895e+307}, + /* 146 */ + {{9.967194951097309e-206, zero}, + {9.967194951097309e-206, zero}, + 1.003291302022624e+205}, + /* 147 */ + {{4.464794497196183e-103, zero}, + {4.464794497196183e-103, zero}, + 2.239744742177804e+102}, + /* 148 */ + {{1.999999999999869e+00, zero}, + {1.098612288668066e+00, zero}, + 5.e-01}, + /* 149 */ + {{8.958978968710456e+102, zero}, + {2.370563357515012e+02, zero}, + 3.90625e-03}, + /* 150 */ + {{4.013165208090075e+205, zero}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 151 */ {{max, zero}, {7.09782712893384e+02, zero}, 9.765625e-04}, + /* 152 */ {{inf, zero}, {inf, zero}, 1.e+00}, + /* 153 */ {{-inf, min}, {inf, pi}, 1.e+00}, + /* 154 */ {{-max, min}, {7.09782712893384e+02, pi}, 9.765625e-04}, + /* 155 */ + {{-4.013165208090075e+205, min}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 156 */ + {{-8.958978968710456e+102, min}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 157 */ + {{-1.999999999999869e+00, min}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 158 */ + {{-4.464794497196183e-103, min}, + {-4.464794497196183e-103, min}, + 2.239744742177804e+102}, + /* 159 */ + {{-9.967194951097309e-206, min}, + {-9.967194951097309e-206, min}, + 1.003291302022624e+205}, + /* 160 */ {{-min, min}, {-min, min}, 2.247116418577895e+307}, + /* 161 */ {{zero, min}, {zero, min}, 2.247116418577895e+307}, + /* 162 */ {{min, min}, {min, min}, 2.247116418577895e+307}, + /* 163 */ + {{9.967194951097309e-206, min}, + {9.967194951097309e-206, min}, + 1.003291302022624e+205}, + /* 164 */ + {{4.464794497196183e-103, min}, + {4.464794497196183e-103, min}, + 2.239744742177804e+102}, + /* 165 */ + {{1.999999999999869e+00, min}, {1.098612288668066e+00, zero}, 5.e-01}, + /* 166 */ + {{8.958978968710456e+102, min}, + {2.370563357515012e+02, zero}, + 3.90625e-03}, + /* 167 */ + {{4.013165208090075e+205, min}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 168 */ {{max, min}, {7.09782712893384e+02, zero}, 9.765625e-04}, + /* 169 */ {{inf, min}, {inf, zero}, 1.e+00}, + /* 170 */ {{-inf, 9.967194951097309e-206}, {inf, pi}, 1.e+00}, + /* 171 */ + {{-max, 9.967194951097309e-206}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 172 */ + {{-4.013165208090075e+205, 9.967194951097309e-206}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 173 */ + {{-8.958978968710456e+102, 9.967194951097309e-206}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 174 */ + {{-1.999999999999869e+00, 9.967194951097309e-206}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 175 */ + {{-4.464794497196183e-103, 9.967194951097309e-206}, + {-4.464794497196183e-103, 9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 176 */ + {{-9.967194951097309e-206, 9.967194951097309e-206}, + {-9.967194951097309e-206, 9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 177 */ + {{-min, 9.967194951097309e-206}, + {-min, 9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 178 */ + {{zero, 9.967194951097309e-206}, + {zero, 9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 179 */ + {{min, 9.967194951097309e-206}, + {min, 9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 180 */ + {{9.967194951097309e-206, 9.967194951097309e-206}, + {9.967194951097309e-206, 9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 181 */ + {{4.464794497196183e-103, 9.967194951097309e-206}, + {4.464794497196183e-103, 9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 182 */ + {{1.999999999999869e+00, 9.967194951097309e-206}, + {1.098612288668066e+00, 3.322398317032581e-206}, + 5.e-01}, + /* 183 */ + {{8.958978968710456e+102, 9.967194951097309e-206}, + {2.370563357515012e+02, 1.112536929253666e-308}, + 3.90625e-03}, + /* 184 */ + {{4.013165208090075e+205, 9.967194951097309e-206}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 185 */ + {{max, 9.967194951097309e-206}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 186 */ {{inf, 9.967194951097309e-206}, {inf, zero}, 1.e+00}, + /* 187 */ {{-inf, 4.464794497196183e-103}, {inf, pi}, 1.e+00}, + /* 188 */ + {{-max, 4.464794497196183e-103}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 189 */ + {{-4.013165208090075e+205, 4.464794497196183e-103}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 190 */ + {{-8.958978968710456e+102, 4.464794497196183e-103}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 191 */ + {{-1.999999999999869e+00, 4.464794497196183e-103}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 192 */ + {{-4.464794497196183e-103, 4.464794497196183e-103}, + {-4.464794497196183e-103, 4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 193 */ + {{-9.967194951097309e-206, 4.464794497196183e-103}, + {-6.506695883473837e-219, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 194 */ + {{-min, 4.464794497196183e-103}, + {9.967194951096658e-206, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 195 */ + {{zero, 4.464794497196183e-103}, + {9.967194951096658e-206, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 196 */ + {{min, 4.464794497196183e-103}, + {9.967194951096658e-206, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 197 */ + {{9.967194951097309e-206, 4.464794497196183e-103}, + {1.993438990219397e-205, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 198 */ + {{4.464794497196183e-103, 4.464794497196183e-103}, + {4.464794497196183e-103, 4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 199 */ + {{1.999999999999869e+00, 4.464794497196183e-103}, + {1.098612288668066e+00, 1.488264832398792e-103}, + 5.e-01}, + /* 200 */ + {{8.958978968710456e+102, 4.464794497196183e-103}, + {2.370563357515012e+02, 4.98359747554898e-206}, + 3.90625e-03}, + /* 201 */ + {{4.013165208090075e+205, 4.464794497196183e-103}, + {4.734195243224426e+02, 1.112536929253666e-308}, + 1.953125e-03}, + /* 202 */ + {{max, 4.464794497196183e-103}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 203 */ {{inf, 4.464794497196183e-103}, {inf, zero}, 1.e+00}, + /* 204 */ {{-inf, 1.999999999999869e+00}, {inf, pi}, 1.e+00}, + /* 205 */ + {{-max, 1.999999999999869e+00}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 206 */ + {{-4.013165208090075e+205, 1.999999999999869e+00}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 207 */ + {{-8.958978968710456e+102, 1.999999999999869e+00}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 208 */ + {{-1.999999999999869e+00, 1.999999999999869e+00}, + {8.047189562169719e-01, 2.034443935795677e+00}, + 2.5e-01}, + /* 209 */ + {{-4.464794497196183e-103, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 210 */ + {{-9.967194951097309e-206, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 211 */ + {{-min, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 212 */ + {{zero, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 213 */ + {{min, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 214 */ + {{9.967194951097309e-206, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 215 */ + {{4.464794497196183e-103, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 216 */ + {{1.999999999999869e+00, 1.999999999999869e+00}, + {1.282474678730718e+00, 5.880026035475575e-01}, + 5.e-01}, + /* 217 */ + {{8.958978968710456e+102, 1.999999999999869e+00}, + {2.370563357515012e+02, 2.232397248598237e-103}, + 3.90625e-03}, + /* 218 */ + {{4.013165208090075e+205, 1.999999999999869e+00}, + {4.734195243224426e+02, 4.98359747554898e-206}, + 1.953125e-03}, + /* 219 */ + {{max, 1.999999999999869e+00}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 220 */ {{inf, 1.999999999999869e+00}, {inf, zero}, 1.e+00}, + /* 221 */ {{-inf, 8.958978968710456e+102}, {inf, pi}, 1.e+00}, + /* 222 */ + {{-max, 8.958978968710456e+102}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 223 */ + {{-4.013165208090075e+205, 8.958978968710456e+102}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 224 */ + {{-8.958978968710456e+102, 8.958978968710456e+102}, + {2.374029093417812e+02, pi3_4}, + 3.90625e-03}, + /* 225 */ + {{-1.999999999999869e+00, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 226 */ + {{-4.464794497196183e-103, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 227 */ + {{-9.967194951097309e-206, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 228 */ + {{-min, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 229 */ + {{zero, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 230 */ + {{min, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 231 */ + {{9.967194951097309e-206, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 232 */ + {{4.464794497196183e-103, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 233 */ + {{1.999999999999869e+00, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 234 */ + {{8.958978968710456e+102, 8.958978968710456e+102}, + {2.374029093417812e+02, pi_4}, + 3.90625e-03}, + /* 235 */ + {{4.013165208090075e+205, 8.958978968710456e+102}, + {4.734195243224426e+02, 2.232397248598237e-103}, + 1.953125e-03}, + /* 236 */ + {{max, 8.958978968710456e+102}, + {7.09782712893384e+02, 4.983597475548361e-206}, + 9.765625e-04}, + /* 237 */ {{inf, 8.958978968710456e+102}, {inf, zero}, 1.e+00}, + /* 238 */ {{-inf, 4.013165208090075e+205}, {inf, pi}, 1.e+00}, + /* 239 */ + {{-max, 4.013165208090075e+205}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 240 */ + {{-4.013165208090075e+205, 4.013165208090075e+205}, + {4.737660979127225e+02, pi3_4}, + 1.953125e-03}, + /* 241 */ + {{-8.958978968710456e+102, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 242 */ + {{-1.999999999999869e+00, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 243 */ + {{-4.464794497196183e-103, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 244 */ + {{-9.967194951097309e-206, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 245 */ + {{-min, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 246 */ + {{zero, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 247 */ + {{min, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 248 */ + {{9.967194951097309e-206, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 249 */ + {{4.464794497196183e-103, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 250 */ + {{1.999999999999869e+00, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 251 */ + {{8.958978968710456e+102, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 252 */ + {{4.013165208090075e+205, 4.013165208090075e+205}, + {4.737660979127225e+02, pi_4}, + 1.953125e-03}, + /* 253 */ + {{max, 4.013165208090075e+205}, + {7.09782712893384e+02, 2.23239724859796e-103}, + 9.765625e-04}, + /* 254 */ {{inf, 4.013165208090075e+205}, {inf, zero}, 1.e+00}, + /* 255 */ {{-inf, max}, {inf, pi}, 1.e+00}, + /* 256 */ {{-max, max}, {7.101292864836639e+02, pi3_4}, 9.765625e-04}, + /* 257 */ + {{-4.013165208090075e+205, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 258 */ + {{-8.958978968710456e+102, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 259 */ + {{-1.999999999999869e+00, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 260 */ + {{-4.464794497196183e-103, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 261 */ + {{-9.967194951097309e-206, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 262 */ {{-min, max}, {7.09782712893384e+02, pi_2}, 9.765625e-04}, + /* 263 */ {{zero, max}, {7.09782712893384e+02, pi_2}, 9.765625e-04}, + /* 264 */ {{min, max}, {7.09782712893384e+02, pi_2}, 9.765625e-04}, + /* 265 */ + {{9.967194951097309e-206, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 266 */ + {{4.464794497196183e-103, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 267 */ + {{1.999999999999869e+00, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 268 */ + {{8.958978968710456e+102, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 269 */ + {{4.013165208090075e+205, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 270 */ {{max, max}, {7.101292864836639e+02, pi_4}, 9.765625e-04}, + /* 271 */ {{inf, max}, {inf, zero}, 1.e+00}, + /* 272 */ {{-inf, inf}, {inf, pi3_4}, 1.e+00}, + /* 273 */ {{-max, inf}, {inf, pi_2}, 1.e+00}, + /* 274 */ {{-4.013165208090075e+205, inf}, {inf, pi_2}, 1.e+00}, + /* 275 */ {{-8.958978968710456e+102, inf}, {inf, pi_2}, 1.e+00}, + /* 276 */ {{-1.999999999999869e+00, inf}, {inf, pi_2}, 1.e+00}, + /* 277 */ {{-4.464794497196183e-103, inf}, {inf, pi_2}, 1.e+00}, + /* 278 */ {{-9.967194951097309e-206, inf}, {inf, pi_2}, 1.e+00}, + /* 279 */ {{-min, inf}, {inf, pi_2}, 1.e+00}, + /* 280 */ {{zero, inf}, {inf, pi_2}, 1.e+00}, + /* 281 */ {{min, inf}, {inf, pi_2}, 1.e+00}, + /* 282 */ {{9.967194951097309e-206, inf}, {inf, pi_2}, 1.e+00}, + /* 283 */ {{4.464794497196183e-103, inf}, {inf, pi_2}, 1.e+00}, + /* 284 */ {{1.999999999999869e+00, inf}, {inf, pi_2}, 1.e+00}, + /* 285 */ {{8.958978968710456e+102, inf}, {inf, pi_2}, 1.e+00}, + /* 286 */ {{4.013165208090075e+205, inf}, {inf, pi_2}, 1.e+00}, + /* 287 */ {{max, inf}, {inf, pi_2}, 1.e+00}, + /* 288 */ {{inf, inf}, {inf, pi_4}, 1.e+00}}; + return table; + } else { + static_assert(false); /* unreachable */ + } + } +}; + +} // namespace complex_unary_op_samples + +#endif // XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ diff --git a/third_party/xla/xla/tests/complex_unary_op_test.cc b/third_party/xla/xla/tests/complex_unary_op_test.cc new file mode 100644 index 00000000000000..fefe96cb59c69d --- /dev/null +++ b/third_party/xla/xla/tests/complex_unary_op_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/client/global_data.h" +#include "xla/client/local_client.h" +#include "xla/client/xla_builder.h" +#include "xla/tests/client_library_test_base.h" +#include "xla/tests/complex_unary_op_samples.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/test_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +class ComplexUnaryOpTest : public ClientLibraryTestBase { + protected: + template + std::vector get_column(const std::vector>& table) { + std::vector column; + std::transform( + table.cbegin(), table.cend(), std::back_inserter(column), + [](const auto& item) { return static_cast(std::get(item)); }); + return column; + } + + template + void scale_column(std::vector& column, const std::vector& scales) { + std::transform(column.begin(), column.end(), scales.begin(), column.begin(), + [](const T& lhs, const S& rhs) { return lhs * rhs; }); + } + + template + void UnaryTestHelper(XlaOp (*Op)(const XlaOp operand)) { + using InputType = typename C::InputType; + using OutputType = typename C::OutputType; + using FloatType = typename C::FloatType; + + float atol; + // log(10)/log(2) = 3.3219... + constexpr int precision_deficiency = + static_cast(C::dps_deficiency * 3.3219280948873626); + // precision_deficiency defines a slack allowed when comparing a + // result value against expected value that is known to be + // inaccurate to some extent. + if constexpr (std::is_same_v) { + atol = std::ldexp(1e-6f, precision_deficiency); + } else if constexpr (std::is_same_v) { + atol = std::ldexp(1e-15f, precision_deficiency); + } else { + static_assert(false); // unreachable + } + + XlaBuilder builder(TestName()); + auto table = C().get(); + auto inputs_vec = get_column(table); + auto expected_vec = get_column(table); + auto scales_vec = get_column(table); + scale_column(expected_vec, scales_vec); + + auto inputs = ConstantR1(&builder, inputs_vec); + auto scales = ConstantR1(&builder, scales_vec); + Literal expected = LiteralUtil::CreateR1(expected_vec); + + if constexpr (std::is_same_v) { + auto results = Op(inputs); + Mul(results, scales); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(atol)); + } else { + auto results = Op(inputs); + auto re = Mul(Real(results), scales); + auto im = Mul(Imag(results), scales); + Complex(re, im); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(atol)); + } + } +}; + +XLA_TEST_F(ComplexUnaryOpTest, Log1pTest) { + UnaryTestHelper>(Log1p); + UnaryTestHelper>(Log1p); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/tests/generate_complex_unary_op_samples.py b/third_party/xla/xla/tests/generate_complex_unary_op_samples.py new file mode 100644 index 00000000000000..5dcb848f0eae3b --- /dev/null +++ b/third_party/xla/xla/tests/generate_complex_unary_op_samples.py @@ -0,0 +1,231 @@ +"""A script to generate the complex_unary_op_samples.h file. + +The generated file contains samples and reference values of complex unary +functions used by the complex_unary_op_test program. + +Prerequisites: + jax version 0.4.26 or newer + mpmath 1.3 + numpy + +Usage: + Running + python /path/to/generate_complex_unary_op_samples.py + will create + /path/to/generate_complex_unary_op_samples.h +""" + +import os +import re +import sys +import jax._src.test_util as jtu +import mpmath +import numpy as np + + +def disable(op, real, imag): + del op, real, imag + # Return True to disable samples (real, imag) that are know to be + # problematic for the op. + return False + + +def main(): + default_size = 7 + nmp = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1) + blocks = [] + for opname in ['Log1p']: + mpmath_op = opname.lower() + size_re, size_im = dict(Log1p=(7, 7)).get( + opname, (default_size, default_size) + ) + ifblocks = [] + input_ttype = 'std::complex' + output_ttype = 'TBD' + for dtype in [np.complex64, np.complex128]: + float_dtype = {np.complex64: np.float32, np.complex128: np.float64}[dtype] + ctype = {np.float32: 'float', np.float64: 'double'}[float_dtype] + cnan = {np.float32: 'std::nanf("")', np.float64: 'std::nan("")'}[ + float_dtype + ] + pi = float_dtype(np.pi) + h_pi = float_dtype(np.pi / 2) + q_pi = float_dtype(np.pi / 4) + tq_pi = float_dtype(3 * np.pi / 4) + cfloat_suffix = 'f' if float_dtype == np.float32 else '' + cpi = str(pi) + cfloat_suffix + cpi_2 = str(h_pi) + cfloat_suffix + cpi_4 = str(q_pi) + cfloat_suffix + cpi3_4 = str(tq_pi) + cfloat_suffix + czero = str(float_dtype(0)) + cfloat_suffix + + sample = jtu.complex_plane_sample(dtype, size_re=size_re, size_im=size_im) + values = getattr(nmp, mpmath_op)(sample) + finfo = np.finfo(float_dtype) + + # pylint: disable=cell-var-from-loop + def _tostr(v): + if v == pi: + return 'pi' + if v == -pi: + return '-pi' + if v == h_pi: + return 'pi_2' + if v == -h_pi: + return '-pi_2' + if v == q_pi: + return 'pi_4' + if v == -q_pi: + return '-pi_4' + if v == tq_pi: + return 'pi3_4' + if v == -tq_pi: + return '-pi3_4' + if v == finfo.max: + return 'max' + if v == -finfo.max: + return '-max' + if v == finfo.tiny: + return 'min' + if v == -finfo.tiny: + return '-min' + if np.isnan(v): + return 'nan' + if np.isneginf(v): + return '-inf' + if np.isposinf(v): + return 'inf' + if v == 0.0: + return 'zero' + if float_dtype == np.float32: + s = f'{v:.6e}f' + elif float_dtype == np.float64: + s = f'{v:.15e}' + else: + assert 0 # unreachable + return re.sub(r'0+e', 'e', s) + + used_constants = set() + + def tostr(v): + r = _tostr(v) + used_constants.add(r.removeprefix('-')) + return r + + rows = [] + counter = 0 + for x, y in zip(sample.flatten(), values.flatten()): + re_x, im_x = tostr(x.real), tostr(x.imag) + if disable(opname, re_x, im_x): + prefix = '// ' + else: + # to ease tracking mismatching cases: + prefix = f'/* {counter} */ ' + counter += 1 + if values.dtype.kind == 'c': + output_ttype = 'std::complex' + re_y, im_y = tostr(y.real), tostr(y.imag) + scale = tostr(np.ldexp(1.0, -np.frexp(abs(y))[1])) + rows.append( + f'{prefix}{{ {{ {re_x}, {im_x} }}, {{ {re_y}, {im_y} }},' + f' {scale} }}' + ) + else: + assert values.dtype.kind == 'f' + output_ttype = 'T' + # Scale is power of 2 so that multiplication with + # it has minimal effect to the binary mantissa + # part of other operand. + scale = tostr(np.ldexp(1.0, -np.frexp(abs(y))[1])) + rows.append( + f'{prefix}{{ {{ {re_x}, {im_x} }}, {tostr(y)}, {scale} }}' + ) + rows = ',\n '.join(rows) + + constants = [] + for name, value in dict( + nan=cnan, + pi=cpi, + pi_4=cpi_4, + pi_2=cpi_2, + pi3_4=cpi3_4, + zero=czero, + ).items(): + if name in used_constants: + constants.append(f'const T {name} = {value};') + constants = '\n '.join(constants) + + ifblocks.append(f"""\ +if constexpr (std::is_same_v) {{ + {constants} + const TableType table{{ + {rows} + }}; + return table; + }}""") + ifblocks.append('{ static_assert(false); /* unreachable */ }') + ifblocks = ' else '.join(ifblocks) + blocks.append(f""" + template + struct {opname} {{ + typedef {input_ttype} InputType; + typedef {output_ttype} OutputType; + typedef T FloatType; + using TableType = std::vector>; + static constexpr int dps_deficiency = default_dps_deficiency; + const TableType get() {{ + const T inf = std::numeric_limits::infinity(); + const T min = std::numeric_limits::min(); + const T max = std::numeric_limits::max(); + {ifblocks} + }} + }}; +""") + blocks = '\n'.join(blocks) + + output_filename = os.path.join( + os.path.dirname(__file__), 'complex_unary_op_samples.h' + ) + output = open(output_filename, 'w') + + output.write(f"""\ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + This file is generated using xla/tests/{os.path.basename(__file__)}. Do not edit! + */ + +#include +#include +#include +#include +#include + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ + +namespace complex_unary_op_samples {{ +{blocks} +}} + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ +""") + output.close() + sys.stdout.write(f'Created {output_filename}\n') + + +if __name__ == '__main__': + main()