From f2230dfee3ab19974bfe818afdca7526bc81280e Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Mon, 25 Nov 2024 15:46:26 -0500 Subject: [PATCH 01/10] changes for check_differential_testing_function --- src/inline/inline.py | 16 ++++++ src/inline/plugin.py | 127 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/src/inline/inline.py b/src/inline/inline.py index 5766542..18606f0 100644 --- a/src/inline/inline.py +++ b/src/inline/inline.py @@ -10,6 +10,7 @@ def __init__( tag: List = [], disabled: bool = False, timeout: float = -1.0, + devices: List[str] = None, ): """ Initialize Inline object with test name / parametrized flag @@ -20,6 +21,8 @@ def __init__( :param tag: tags to group tests :param disabled: whether the test is disabled :param timeout: seconds to timeout the test, must be a float + :param devices: list of devices to run differential testing on (e.g., ["cpu", "cuda", "mps"]) + if None, differential testing is disabled """ def given(self, variable, value): @@ -42,6 +45,19 @@ def check_eq(self, actual_value, expected_value): :raises: AssertionError """ return self + + def check_differential_testing(self, outputs): + """ + Assert whether outputs are consistent across different devices. + This method compares the outputs from different devices specified in the constructor. + + :param outputs: a dictionary mapping device names to their outputs, or a single output value + if a single value is provided, the test will run the computation on all devices + and compare against this reference value + :returns: Inline object + :raises: AssertionError if outputs differ across devices + """ + return self def check_neq(self, actual_value, expected_value): """ diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 11c0774..60faf3b 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -171,6 +171,7 @@ def __init__(self): self.tag = [] self.disabled = False self.timeout = -1.0 + self.devices = None self.globs = {} def to_test(self): @@ -293,6 +294,8 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" + arg_devices_str = "devices" + check_differential_testing_str = "check_differential_testing" assume = "assume" inline_module_imported = False @@ -362,7 +365,7 @@ def parse_constructor(self, node): """ Parse a constructor call. """ - NUM_OF_ARGUMENTS = 6 + NUM_OF_ARGUMENTS = 7 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments if sys.version_info >= (3, 8, 0): @@ -394,6 +397,17 @@ def parse_constructor(self, node): and (isinstance(arg.value, float) or isinstance(arg.value, int)) ): self.cur_inline_test.timeout = arg.value + + elif index == 6 and isinstance(arg, ast.List): + devices = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException("devices can only be List of string") + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + self.cur_inline_test.devices = devices + else: raise MalformedException( f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" @@ -431,6 +445,16 @@ def parse_constructor(self, node): raise MalformedException(f"tag can only be List of string") tags.append(elt.value) self.cur_inline_test.tag = tags + # Add devices handling for keyword args + elif keyword.arg == self.arg_devices_str and isinstance(keyword.value, ast.List): + devices = [] + for elt in keyword.value.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException("devices can only be List of string") + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + self.cur_inline_test.devices = devices # check if "disabled" is a boolean elif ( keyword.arg == self.arg_disabled_str @@ -447,6 +471,16 @@ def parse_constructor(self, node): if keyword.value.value <= 0.0: raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") self.cur_inline_test.timeout = keyword.value.value + # Add devices handling for Python 3.7 + elif index == 6 and isinstance(arg, ast.List): + devices = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): # Note: ast.Str for Python 3.7 + raise MalformedException("devices can only be List of string") + if elt.s not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.s}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.s) + self.cur_inline_test.devices = devices else: raise MalformedException( f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" @@ -536,6 +570,19 @@ def parse_constructor(self, node): if keyword.value.n <= 0.0: raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") self.cur_inline_test.timeout = keyword.value.n + #keyword arg for devices + elif ( + keyword.arg == self.arg_devices_str + and isinstance(keyword.value, ast.List) + ): + devices = [] + for elt in keyword.value.elts: + if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): + raise MalformedException("devices can only be List of string") + if elt.s not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.s}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.s) + self.cur_inline_test.devices = devices else: raise MalformedException( f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" @@ -885,6 +932,80 @@ def parse_check_not_same(self, node): self.cur_inline_test.check_stmts.append(assert_node) else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") + + def parse_check_differential_testing(self, node): + + if self.cur_inline_test.devices is None: + raise MalformedException("check_differential_testing can only be used when devices parameter is provided") + + if len(node.args) == 1: + output_node = self.parse_group(node.args[0]) + + if self.cur_inline_test.parameterized: + self.parameterized_inline_tests_init(node.args[0]) + for index, _ in enumerate(node.args[0].elts): + # Compare outputs between consecutive devices + for i in range(len(self.cur_inline_test.devices)-1): + device1 = self.cur_inline_test.devices[i] + device2 = self.cur_inline_test.devices[i+1] + assert_node = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=output_node, attr='to'), + args=[ast.Constant(value=device1)], + keywords=[] + ), + attr='allclose' + ), + args=[ + ast.Call( + func=ast.Attribute(value=output_node, attr='to'), + args=[ast.Constant(value=device2)], + keywords=[] + ) + ], + keywords=[ + ast.keyword(arg='rtol', value=ast.Constant(value=1e-5)), + ast.keyword(arg='atol', value=ast.Constant(value=1e-5)) + ] + ), + ast.Constant(value=True) + ) + self.cur_inline_test.parameterized_inline_tests[index].check_stmts.append(assert_node) + else: + # Non-parameterized case + for i in range(len(self.cur_inline_test.devices)-1): + device1 = self.cur_inline_test.devices[i] + device2 = self.cur_inline_test.devices[i+1] + assert_node = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=output_node, attr='to'), + args=[ast.Constant(value=device1)], + keywords=[] + ), + attr='allclose' + ), + args=[ + ast.Call( + func=ast.Attribute(value=output_node, attr='to'), + args=[ast.Constant(value=device2)], + keywords=[] + ) + ], + keywords=[ + ast.keyword(arg='rtol', value=ast.Constant(value=1e-5)), + ast.keyword(arg='atol', value=ast.Constant(value=1e-5)) + ] + ), + ast.Constant(value=True) + ) + self.cur_inline_test.check_stmts.append(assert_node) + else: + raise MalformedException("check_differential_testing() accepts exactly 1 argument") + def build_fail(self): equal_node = ast.Compare( @@ -986,11 +1107,13 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) + elif call.func.attr == self.check_differential_testing_str: + self.parse_check_differential_testing(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()/check_differential_testing()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") From 54373f2d330198636632274c9018010b5b1e1d5e Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Sun, 15 Dec 2024 18:57:14 -0500 Subject: [PATCH 02/10] working changes for the check_diff() --- src/inline/inline.py | 2 +- src/inline/plugin.py | 156 ++++++++++++++++++++++++------------------- 2 files changed, 88 insertions(+), 70 deletions(-) diff --git a/src/inline/inline.py b/src/inline/inline.py index 18606f0..9673ce8 100644 --- a/src/inline/inline.py +++ b/src/inline/inline.py @@ -10,7 +10,7 @@ def __init__( tag: List = [], disabled: bool = False, timeout: float = -1.0, - devices: List[str] = None, + devices: List = None, ): """ Initialize Inline object with test name / parametrized flag diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 60faf3b..24715ab 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -935,77 +935,93 @@ def parse_check_not_same(self, node): def parse_check_differential_testing(self, node): - if self.cur_inline_test.devices is None: - raise MalformedException("check_differential_testing can only be used when devices parameter is provided") - - if len(node.args) == 1: - output_node = self.parse_group(node.args[0]) + if not self.cur_inline_test.devices: + raise MalformedException("check_differential_testing can only be used with the 'devices' parameter.") + + if len(node.args) != 1: + raise MalformedException("check_differential_testing() requires exactly 1 argument.") + + # Parse the tensor operation + output_node = self.parse_group(node.args[0]) + + # Get the original operation from the previous statements + original_op = None + for stmt in self.cur_inline_test.previous_stmts: + if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: + original_op = stmt.value + break + + if not original_op: + raise MalformedException("Could not find original operation for differential testing") + + device_statements = [] + device_outputs = [] + + # Generate device-specific tensors and operations + for device in self.cur_inline_test.devices: + # Create device-specific input tensor + input_var = self.cur_inline_test.given_stmts[0].targets[0].id + device_input_var = f"{input_var}_{device}" + device_input_stmt = ast.Assign( + targets=[ast.Name(id=device_input_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=self.cur_inline_test.given_stmts[0].value, + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + device_statements.append(device_input_stmt) + + # Create device-specific operation result + device_output_var = f"output_{device}" + # Copy the original operation but replace the input tensor with device-specific one + device_op = copy.deepcopy(original_op) + # Replace the input tensor reference in the operation + for node in ast.walk(device_op): + if isinstance(node, ast.Name) and node.id == input_var: + node.id = device_input_var + + device_output_stmt = ast.Assign( + targets=[ast.Name(id=device_output_var, ctx=ast.Store())], + value=device_op + ) + device_statements.append(device_output_stmt) + device_outputs.append(device_output_var) + + # Add the comparison across devices + comparisons = [] + for i in range(len(device_outputs) - 1): + device1_var = device_outputs[i] + device2_var = device_outputs[i + 1] - if self.cur_inline_test.parameterized: - self.parameterized_inline_tests_init(node.args[0]) - for index, _ in enumerate(node.args[0].elts): - # Compare outputs between consecutive devices - for i in range(len(self.cur_inline_test.devices)-1): - device1 = self.cur_inline_test.devices[i] - device2 = self.cur_inline_test.devices[i+1] - assert_node = self.build_assert_eq( - ast.Call( - func=ast.Attribute( - value=ast.Call( - func=ast.Attribute(value=output_node, attr='to'), - args=[ast.Constant(value=device1)], - keywords=[] - ), - attr='allclose' - ), - args=[ - ast.Call( - func=ast.Attribute(value=output_node, attr='to'), - args=[ast.Constant(value=device2)], - keywords=[] - ) - ], - keywords=[ - ast.keyword(arg='rtol', value=ast.Constant(value=1e-5)), - ast.keyword(arg='atol', value=ast.Constant(value=1e-5)) - ] - ), - ast.Constant(value=True) - ) - self.cur_inline_test.parameterized_inline_tests[index].check_stmts.append(assert_node) - else: - # Non-parameterized case - for i in range(len(self.cur_inline_test.devices)-1): - device1 = self.cur_inline_test.devices[i] - device2 = self.cur_inline_test.devices[i+1] - assert_node = self.build_assert_eq( - ast.Call( - func=ast.Attribute( - value=ast.Call( - func=ast.Attribute(value=output_node, attr='to'), - args=[ast.Constant(value=device1)], - keywords=[] - ), - attr='allclose' - ), - args=[ - ast.Call( - func=ast.Attribute(value=output_node, attr='to'), - args=[ast.Constant(value=device2)], - keywords=[] - ) - ], - keywords=[ - ast.keyword(arg='rtol', value=ast.Constant(value=1e-5)), - ast.keyword(arg='atol', value=ast.Constant(value=1e-5)) - ] - ), - ast.Constant(value=True) - ) - self.cur_inline_test.check_stmts.append(assert_node) - else: - raise MalformedException("check_differential_testing() accepts exactly 1 argument") + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=device1_var, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=device2_var, ctx=ast.Load()), + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-5)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-5)) + ] + ), + ast.Constant(value=True) + ) + comparisons.append(comparison) + # Update the inline test object + self.cur_inline_test.previous_stmts.extend(device_statements) + self.cur_inline_test.check_stmts.extend(comparisons) + + + + def build_fail(self): equal_node = ast.Compare( @@ -1254,6 +1270,8 @@ def _find(self, tests, obj, module, globs, seen): ###################################################################### class InlineTestRunner: def run(self, test: InlineTest, out: List) -> None: + test_str = test.to_test() + print(test_str) tree = ast.parse(test.to_test()) codeobj = compile(tree, filename="", mode="exec") start_time = time.time() From 4b1ae6b6a58367268f12e06c2dab11bafa814cc6 Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Sun, 15 Dec 2024 19:42:23 -0500 Subject: [PATCH 03/10] changes to fix runtime error due to tensors on different devices --- src/inline/plugin.py | 42 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 24715ab..42c72bb 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -992,19 +992,54 @@ def parse_check_differential_testing(self, node): device_outputs.append(device_output_var) # Add the comparison across devices + # Always convert tensors to CPU for comparison comparisons = [] for i in range(len(device_outputs) - 1): device1_var = device_outputs[i] device2_var = device_outputs[i + 1] + # Create CPU versions of the outputs for comparison + device1_cpu_var = f"{device1_var}_cpu_compare" + device2_cpu_var = f"{device2_var}_cpu_compare" + + # Add statements to move tensors to CPU + device_statements.append( + ast.Assign( + targets=[ast.Name(id=device1_cpu_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=device1_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + device_statements.append( + ast.Assign( + targets=[ast.Name(id=device2_cpu_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=device2_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Compare the CPU versions comparison = self.build_assert_eq( ast.Call( func=ast.Attribute( - value=ast.Name(id=device1_var, ctx=ast.Load()), + value=ast.Name(id=device1_cpu_var, ctx=ast.Load()), attr="allclose" ), args=[ - ast.Name(id=device2_var, ctx=ast.Load()), + ast.Name(id=device2_cpu_var, ctx=ast.Load()), ], keywords=[ ast.keyword(arg="rtol", value=ast.Constant(value=1e-5)), @@ -1018,8 +1053,7 @@ def parse_check_differential_testing(self, node): # Update the inline test object self.cur_inline_test.previous_stmts.extend(device_statements) self.cur_inline_test.check_stmts.extend(comparisons) - - + From d3f89413ab762d8fe34fbe9c310ac0c5830c5e58 Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Sun, 9 Feb 2025 13:54:02 -0500 Subject: [PATCH 04/10] changes for the name of the differential testing function --- src/inline/inline.py | 2 +- src/inline/plugin.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/inline/inline.py b/src/inline/inline.py index 9673ce8..847c253 100644 --- a/src/inline/inline.py +++ b/src/inline/inline.py @@ -46,7 +46,7 @@ def check_eq(self, actual_value, expected_value): """ return self - def check_differential_testing(self, outputs): + def diff_test(self, outputs): """ Assert whether outputs are consistent across different devices. This method compares the outputs from different devices specified in the constructor. diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 42c72bb..ec0a6b9 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -295,7 +295,7 @@ class ExtractInlineTest(ast.NodeTransformer): arg_disabled_str = "disabled" arg_timeout_str = "timeout" arg_devices_str = "devices" - check_differential_testing_str = "check_differential_testing" + diff_test_str = "diff_test" assume = "assume" inline_module_imported = False @@ -933,13 +933,13 @@ def parse_check_not_same(self, node): else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") - def parse_check_differential_testing(self, node): + def parse_diff_test(self, node): if not self.cur_inline_test.devices: - raise MalformedException("check_differential_testing can only be used with the 'devices' parameter.") + raise MalformedException("diff_test() can only be used with the 'devices' parameter.") if len(node.args) != 1: - raise MalformedException("check_differential_testing() requires exactly 1 argument.") + raise MalformedException("diff_test() requires exactly 1 argument.") # Parse the tensor operation output_node = self.parse_group(node.args[0]) @@ -952,7 +952,7 @@ def parse_check_differential_testing(self, node): break if not original_op: - raise MalformedException("Could not find original operation for differential testing") + raise MalformedException("Could not find original operation for diff_test") device_statements = [] device_outputs = [] @@ -1157,13 +1157,13 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) - elif call.func.attr == self.check_differential_testing_str: - self.parse_check_differential_testing(call) + elif call.func.attr == self.diff_test_str: + self.parse_diff_test(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()/check_differential_testing()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") From 7dd09a1d0dc40a4f18352606eb89b226472dd0bd Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Sun, 9 Feb 2025 17:01:11 -0500 Subject: [PATCH 05/10] fixing the issue where the assertion failed if the randn() generated nan since nan!= nan. Adding it in the allclose() call --- src/inline/plugin.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index ec0a6b9..74860ee 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -934,9 +934,8 @@ def parse_check_not_same(self, node): raise MalformedException("inline test: invalid check_not_same(), expected 2 args") def parse_diff_test(self, node): - if not self.cur_inline_test.devices: - raise MalformedException("diff_test() can only be used with the 'devices' parameter.") + raise MalformedException("diff_test can only be used with the 'devices' parameter.") if len(node.args) != 1: raise MalformedException("diff_test() requires exactly 1 argument.") @@ -957,16 +956,25 @@ def parse_diff_test(self, node): device_statements = [] device_outputs = [] - # Generate device-specific tensors and operations + # Use the input tensor from the given statement + input_tensor = self.cur_inline_test.given_stmts[0].value # This is the actual input tensor + input_var = self.cur_inline_test.given_stmts[0].targets[0].id + + # Store original input tensor to avoid regeneration + input_store = ast.Assign( + targets=[ast.Name(id=input_var, ctx=ast.Store())], + value=input_tensor + ) + device_statements.append(input_store) + for device in self.cur_inline_test.devices: # Create device-specific input tensor - input_var = self.cur_inline_test.given_stmts[0].targets[0].id device_input_var = f"{input_var}_{device}" device_input_stmt = ast.Assign( targets=[ast.Name(id=device_input_var, ctx=ast.Store())], value=ast.Call( func=ast.Attribute( - value=self.cur_inline_test.given_stmts[0].value, + value=ast.Name(id=input_var, ctx=ast.Load()), # Use stored input tensor attr="to" ), args=[ast.Constant(value=device)], @@ -975,6 +983,8 @@ def parse_diff_test(self, node): ) device_statements.append(device_input_stmt) + # Rest of your code remains the same... + # Create device-specific operation result device_output_var = f"output_{device}" # Copy the original operation but replace the input tensor with device-specific one @@ -1043,7 +1053,8 @@ def parse_diff_test(self, node): ], keywords=[ ast.keyword(arg="rtol", value=ast.Constant(value=1e-5)), - ast.keyword(arg="atol", value=ast.Constant(value=1e-5)) + ast.keyword(arg="atol", value=ast.Constant(value=1e-5)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) ] ), ast.Constant(value=True) From ddd381e5b4a24840d0e20d4c97ecbb4ecea53b51 Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Wed, 19 Feb 2025 15:28:56 -0500 Subject: [PATCH 06/10] changes to solve the tensor on multiple devices for multiple inputs --- src/inline/plugin.py | 74 ++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 74860ee..4a5f464 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -940,10 +940,9 @@ def parse_diff_test(self, node): if len(node.args) != 1: raise MalformedException("diff_test() requires exactly 1 argument.") - # Parse the tensor operation output_node = self.parse_group(node.args[0]) - # Get the original operation from the previous statements + # Get the original operation original_op = None for stmt in self.cur_inline_test.previous_stmts: if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: @@ -956,43 +955,48 @@ def parse_diff_test(self, node): device_statements = [] device_outputs = [] - # Use the input tensor from the given statement - input_tensor = self.cur_inline_test.given_stmts[0].value # This is the actual input tensor - input_var = self.cur_inline_test.given_stmts[0].targets[0].id - - # Store original input tensor to avoid regeneration - input_store = ast.Assign( - targets=[ast.Name(id=input_var, ctx=ast.Store())], - value=input_tensor - ) - device_statements.append(input_store) - - for device in self.cur_inline_test.devices: - # Create device-specific input tensor - device_input_var = f"{input_var}_{device}" - device_input_stmt = ast.Assign( - targets=[ast.Name(id=device_input_var, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=input_var, ctx=ast.Load()), # Use stored input tensor - attr="to" - ), - args=[ast.Constant(value=device)], - keywords=[] + # Handle all input tensors from given statements + input_vars = [] + for given_stmt in self.cur_inline_test.given_stmts: + input_vars.append(given_stmt.targets[0].id) + # Store original input + device_statements.append( + ast.Assign( + targets=[ast.Name(id=given_stmt.targets[0].id, ctx=ast.Store())], + value=given_stmt.value ) ) - device_statements.append(device_input_stmt) - # Rest of your code remains the same... + # For each device, create device-specific versions of all input tensors + for device in self.cur_inline_test.devices: + device_input_vars = {} # Map original var names to device-specific ones + + # Move each input tensor to the current device + for input_var in input_vars: + device_input_var = f"{input_var}_{device}" + device_input_vars[input_var] = device_input_var + + device_input_stmt = ast.Assign( + targets=[ast.Name(id=device_input_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=input_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + device_statements.append(device_input_stmt) - # Create device-specific operation result + # Create device-specific operation device_output_var = f"output_{device}" - # Copy the original operation but replace the input tensor with device-specific one device_op = copy.deepcopy(original_op) - # Replace the input tensor reference in the operation + + # Replace all input tensor references with device-specific ones for node in ast.walk(device_op): - if isinstance(node, ast.Name) and node.id == input_var: - node.id = device_input_var + if isinstance(node, ast.Name) and node.id in device_input_vars: + node.id = device_input_vars[node.id] device_output_stmt = ast.Assign( targets=[ast.Name(id=device_output_var, ctx=ast.Store())], @@ -1001,18 +1005,16 @@ def parse_diff_test(self, node): device_statements.append(device_output_stmt) device_outputs.append(device_output_var) + # Rest of the comparison code remains the same... # Add the comparison across devices - # Always convert tensors to CPU for comparison comparisons = [] for i in range(len(device_outputs) - 1): device1_var = device_outputs[i] device2_var = device_outputs[i + 1] - # Create CPU versions of the outputs for comparison device1_cpu_var = f"{device1_var}_cpu_compare" device2_cpu_var = f"{device2_var}_cpu_compare" - # Add statements to move tensors to CPU device_statements.append( ast.Assign( targets=[ast.Name(id=device1_cpu_var, ctx=ast.Store())], @@ -1041,7 +1043,6 @@ def parse_diff_test(self, node): ) ) - # Compare the CPU versions comparison = self.build_assert_eq( ast.Call( func=ast.Attribute( @@ -1061,7 +1062,6 @@ def parse_diff_test(self, node): ) comparisons.append(comparison) - # Update the inline test object self.cur_inline_test.previous_stmts.extend(device_statements) self.cur_inline_test.check_stmts.extend(comparisons) From 56d7d45fef81d57372a5632daf21daab47db64b2 Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Mon, 10 Mar 2025 17:22:33 -0400 Subject: [PATCH 07/10] new changes for in place torch apis --- src/inline/plugin.py | 250 ++++++++++++++++++++++++++++--------------- 1 file changed, 162 insertions(+), 88 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 4a5f464..68db137 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -951,119 +951,193 @@ def parse_diff_test(self, node): if not original_op: raise MalformedException("Could not find original operation for diff_test") - - device_statements = [] + + # Check if the operation is in-place (ends with _) + is_inplace = False + op_name = "" + + if isinstance(original_op, ast.Call) and hasattr(original_op, "func"): + if isinstance(original_op.func, ast.Attribute): + op_name = original_op.func.attr + + # Check if operation name ends with '_' (in-place) + is_inplace = op_name.endswith('_') + + # Create our new statements + new_statements = [] device_outputs = [] - - # Handle all input tensors from given statements - input_vars = [] + + # Process input tensors for given_stmt in self.cur_inline_test.given_stmts: - input_vars.append(given_stmt.targets[0].id) - # Store original input - device_statements.append( - ast.Assign( - targets=[ast.Name(id=given_stmt.targets[0].id, ctx=ast.Store())], - value=given_stmt.value - ) - ) - - # For each device, create device-specific versions of all input tensors - for device in self.cur_inline_test.devices: - device_input_vars = {} # Map original var names to device-specific ones + input_var = given_stmt.targets[0].id + ref_var = f"{input_var}_ref" - # Move each input tensor to the current device - for input_var in input_vars: - device_input_var = f"{input_var}_{device}" - device_input_vars[input_var] = device_input_var - - device_input_stmt = ast.Assign( - targets=[ast.Name(id=device_input_var, ctx=ast.Store())], + # Always clone inputs for in-place operations + new_statements.append( + ast.Assign( + targets=[ast.Name(id=ref_var, ctx=ast.Store())], value=ast.Call( func=ast.Attribute( - value=ast.Name(id=input_var, ctx=ast.Load()), - attr="to" + value=given_stmt.value, + attr="clone" ), - args=[ast.Constant(value=device)], + args=[], keywords=[] ) ) - device_statements.append(device_input_stmt) - - # Create device-specific operation - device_output_var = f"output_{device}" + ) + + # Create device-specific versions + for device in self.cur_inline_test.devices: + device_var = f"{input_var}_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=ref_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + ) + + # Create device-specific operations + device_input_map = {device: {} for device in self.cur_inline_test.devices} + for device in self.cur_inline_test.devices: + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + device_input_map[device][input_var] = f"{input_var}_{device}" + device_op = copy.deepcopy(original_op) - # Replace all input tensor references with device-specific ones - for node in ast.walk(device_op): - if isinstance(node, ast.Name) and node.id in device_input_vars: - node.id = device_input_vars[node.id] - - device_output_stmt = ast.Assign( - targets=[ast.Name(id=device_output_var, ctx=ast.Store())], - value=device_op - ) - device_statements.append(device_output_stmt) - device_outputs.append(device_output_var) - - # Rest of the comparison code remains the same... - # Add the comparison across devices - comparisons = [] - for i in range(len(device_outputs) - 1): - device1_var = device_outputs[i] - device2_var = device_outputs[i + 1] + # Replace input references + class ReplaceInputs(ast.NodeTransformer): + def visit_Name(self, node): + if node.id in device_input_map[device]: + return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) + return node - device1_cpu_var = f"{device1_var}_cpu_compare" - device2_cpu_var = f"{device2_var}_cpu_compare" + device_op = ReplaceInputs().visit(device_op) + device_output = f"output_{device}" - device_statements.append( + new_statements.append( ast.Assign( - targets=[ast.Name(id=device1_cpu_var, ctx=ast.Store())], - value=ast.Call( + targets=[ast.Name(id=device_output, ctx=ast.Store())], + value=device_op + ) + ) + device_outputs.append(device_output) + + # Choose appropriate comparison method + comparisons = [] + if is_inplace and ('random' in op_name.lower() or 'exponential' in op_name.lower() or + 'normal' in op_name.lower() or 'uniform' in op_name.lower()): + # For in-place stochastic operations, check basic properties + for device_output in device_outputs: + # 1. Check that output has values (not all zeros) + has_values_check = self.build_assert_true( + ast.Call( func=ast.Attribute( - value=ast.Name(id=device1_var, ctx=ast.Load()), - attr="to" + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=device_output, ctx=ast.Load()), + attr="abs" + ), + args=[], + keywords=[] + ), + attr="sum" ), - args=[ast.Constant(value="cpu")], + args=[], keywords=[] ) ) - ) - - device_statements.append( - ast.Assign( - targets=[ast.Name(id=device2_cpu_var, ctx=ast.Store())], - value=ast.Call( + comparisons.append(has_values_check) + + # 2. Check that output has finite values (not inf) + has_finite_check = self.build_assert_true( + ast.Call( func=ast.Attribute( - value=ast.Name(id=device2_var, ctx=ast.Load()), - attr="to" + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), + attr="isfinite" + ), + args=[ast.Name(id=device_output, ctx=ast.Load())], + keywords=[] + ), + attr="all" ), - args=[ast.Constant(value="cpu")], + args=[], keywords=[] ) ) - ) - - comparison = self.build_assert_eq( - ast.Call( - func=ast.Attribute( - value=ast.Name(id=device1_cpu_var, ctx=ast.Load()), - attr="allclose" + comparisons.append(has_finite_check) + else: + # For deterministic operations, use allclose comparison + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] ), - args=[ - ast.Name(id=device2_cpu_var, ctx=ast.Load()), - ], - keywords=[ - ast.keyword(arg="rtol", value=ast.Constant(value=1e-5)), - ast.keyword(arg="atol", value=ast.Constant(value=1e-5)), - ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) - ] - ), - ast.Constant(value=True) - ) - comparisons.append(comparison) - - self.cur_inline_test.previous_stmts.extend(device_statements) - self.cur_inline_test.check_stmts.extend(comparisons) + ast.Constant(value=True) + ) + comparisons.append(comparison) + + # Replace statements + self.cur_inline_test.previous_stmts = new_statements + self.cur_inline_test.check_stmts = comparisons From dc12691a8800261a985413aaac58d453ea588311 Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Tue, 20 May 2025 00:39:02 +0530 Subject: [PATCH 08/10] random sampling issue --- src/inline/plugin.py | 178 +++++++++++++++++++++++++++++++++---------- 1 file changed, 138 insertions(+), 40 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 68db137..66e3179 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -952,21 +952,93 @@ def parse_diff_test(self, node): if not original_op: raise MalformedException("Could not find original operation for diff_test") - # Check if the operation is in-place (ends with _) - is_inplace = False + # Check if the operation is stochastic + is_stochastic = False op_name = "" if isinstance(original_op, ast.Call) and hasattr(original_op, "func"): if isinstance(original_op.func, ast.Attribute): op_name = original_op.func.attr - # Check if operation name ends with '_' (in-place) - is_inplace = op_name.endswith('_') + # Check if operation name indicates a stochastic operation + stochastic_keywords = ['sample', 'random', 'exponential', 'normal', 'uniform', 'multinomial', 'dropout'] + is_stochastic = any(keyword in op_name.lower() for keyword in stochastic_keywords) # Create our new statements new_statements = [] device_outputs = [] + # Import necessary modules for seed setting + if is_stochastic: + # Import needed modules + import_random = ast.ImportFrom( + module='random', + names=[ast.alias(name='seed', asname=None)], + level=0 + ) + new_statements.append(import_random) + + import_np = ast.ImportFrom( + module='numpy', + names=[ast.alias(name='random', asname='np_random')], + level=0 + ) + new_statements.append(import_np) + + # Create seed function + seed_func_def = ast.FunctionDef( + name='set_random_seed', + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg='seed_value', annotation=None)], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Name(id='seed', ctx=ast.Load()), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='manual_seed', + ctx=ast.Load() + ), + attr='__call__', + ctx=ast.Load() + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id='np_random', ctx=ast.Load()), + attr='seed', + ctx=ast.Load() + ), + attr='__call__', + ctx=ast.Load() + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ) + ], + decorator_list=[], + returns=None + ) + new_statements.append(seed_func_def) + # Process input tensors for given_stmt in self.cur_inline_test.given_stmts: input_var = given_stmt.targets[0].id @@ -1012,6 +1084,18 @@ def parse_diff_test(self, node): input_var = given_stmt.targets[0].id device_input_map[device][input_var] = f"{input_var}_{device}" + # Set same seed before each device operation if stochastic + if is_stochastic: + new_statements.append( + ast.Expr( + value=ast.Call( + func=ast.Name(id='set_random_seed', ctx=ast.Load()), + args=[ast.Constant(value=42)], # Use constant seed 42 + keywords=[] + ) + ) + ) + device_op = copy.deepcopy(original_op) # Replace input references @@ -1034,51 +1118,65 @@ def visit_Name(self, node): # Choose appropriate comparison method comparisons = [] - if is_inplace and ('random' in op_name.lower() or 'exponential' in op_name.lower() or - 'normal' in op_name.lower() or 'uniform' in op_name.lower()): - # For in-place stochastic operations, check basic properties - for device_output in device_outputs: - # 1. Check that output has values (not all zeros) - has_values_check = self.build_assert_true( - ast.Call( - func=ast.Attribute( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=device_output, ctx=ast.Load()), - attr="abs" - ), - args=[], - keywords=[] + if is_stochastic: + # For stochastic operations, use standard comparison + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" ), - attr="sum" - ), - args=[], - keywords=[] + args=[ast.Constant(value="cpu")], + keywords=[] + ) ) ) - comparisons.append(has_values_check) - # 2. Check that output has finite values (not inf) - has_finite_check = self.build_assert_true( + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( ast.Call( func=ast.Attribute( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="torch", ctx=ast.Load()), - attr="isfinite" - ), - args=[ast.Name(id=device_output, ctx=ast.Load())], - keywords=[] - ), - attr="all" + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" ), - args=[], - keywords=[] - ) + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] + ), + ast.Constant(value=True) ) - comparisons.append(has_finite_check) + comparisons.append(comparison) else: - # For deterministic operations, use allclose comparison + # For deterministic operations, use standard comparison for i in range(len(device_outputs) - 1): dev1 = device_outputs[i] dev2 = device_outputs[i + 1] From 661c9527a41936e6b0e5588c269531b7e2a3e1d0 Mon Sep 17 00:00:00 2001 From: Chaitanya Shahane Date: Thu, 22 May 2025 00:12:28 +0530 Subject: [PATCH 09/10] new update to apply changes to tackle random operations for all --- src/inline/plugin.py | 312 ++++++++++++++++--------------------------- 1 file changed, 115 insertions(+), 197 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 66e3179..13f3025 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -952,92 +952,70 @@ def parse_diff_test(self, node): if not original_op: raise MalformedException("Could not find original operation for diff_test") - # Check if the operation is stochastic - is_stochastic = False - op_name = "" - - if isinstance(original_op, ast.Call) and hasattr(original_op, "func"): - if isinstance(original_op.func, ast.Attribute): - op_name = original_op.func.attr - - # Check if operation name indicates a stochastic operation - stochastic_keywords = ['sample', 'random', 'exponential', 'normal', 'uniform', 'multinomial', 'dropout'] - is_stochastic = any(keyword in op_name.lower() for keyword in stochastic_keywords) - # Create our new statements new_statements = [] device_outputs = [] - # Import necessary modules for seed setting - if is_stochastic: - # Import needed modules - import_random = ast.ImportFrom( - module='random', - names=[ast.alias(name='seed', asname=None)], - level=0 - ) - new_statements.append(import_random) - - import_np = ast.ImportFrom( - module='numpy', - names=[ast.alias(name='random', asname='np_random')], - level=0 - ) - new_statements.append(import_np) - - # Create seed function - seed_func_def = ast.FunctionDef( - name='set_random_seed', - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg='seed_value', annotation=None)], - kwonlyargs=[], - kw_defaults=[], - defaults=[] + # Import necessary modules for seed setting - Always add these + # Import random module + import_random = ast.ImportFrom( + module='random', + names=[ast.alias(name='seed', asname=None)], + level=0 + ) + new_statements.append(import_random) + + # Import numpy.random + import_np = ast.ImportFrom( + module='numpy', + names=[ast.alias(name='random', asname='np_random')], + level=0 + ) + new_statements.append(import_np) + + # Create seed function - Always add this + seed_func_def = ast.FunctionDef( + name='set_random_seed', + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg='seed_value', annotation=None)], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Name(id='seed', ctx=ast.Load()), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) ), - body=[ - ast.Expr( - value=ast.Call( - func=ast.Name(id='seed', ctx=ast.Load()), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Attribute( - value=ast.Name(id='torch', ctx=ast.Load()), - attr='manual_seed', - ctx=ast.Load() - ), - attr='__call__', - ctx=ast.Load() - ), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Attribute( - value=ast.Name(id='np_random', ctx=ast.Load()), - attr='seed', - ctx=ast.Load() - ), - attr='__call__', - ctx=ast.Load() - ), - args=[ast.Name(id='seed_value', ctx=ast.Load())], - keywords=[] - ) + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='manual_seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] ) - ], - decorator_list=[], - returns=None - ) - new_statements.append(seed_func_def) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='np_random', ctx=ast.Load()), + attr='seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ) + ], + decorator_list=[], + returns=None + ) + new_statements.append(seed_func_def) # Process input tensors for given_stmt in self.cur_inline_test.given_stmts: @@ -1084,17 +1062,16 @@ def parse_diff_test(self, node): input_var = given_stmt.targets[0].id device_input_map[device][input_var] = f"{input_var}_{device}" - # Set same seed before each device operation if stochastic - if is_stochastic: - new_statements.append( - ast.Expr( - value=ast.Call( - func=ast.Name(id='set_random_seed', ctx=ast.Load()), - args=[ast.Constant(value=42)], # Use constant seed 42 - keywords=[] - ) + # Always set seed before each device operation - no condition check + new_statements.append( + ast.Expr( + value=ast.Call( + func=ast.Name(id='set_random_seed', ctx=ast.Load()), + args=[ast.Constant(value=42)], # Use constant seed 42 + keywords=[] ) ) + ) device_op = copy.deepcopy(original_op) @@ -1116,122 +1093,63 @@ def visit_Name(self, node): ) device_outputs.append(device_output) - # Choose appropriate comparison method + # Standard comparison method for all operations - no condition check comparisons = [] - if is_stochastic: - # For stochastic operations, use standard comparison - for i in range(len(device_outputs) - 1): - dev1 = device_outputs[i] - dev2 = device_outputs[i + 1] - - dev1_cpu = f"{dev1}_cpu" - dev2_cpu = f"{dev2}_cpu" - - # Move outputs back to CPU for comparison - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev1, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev2, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - # Standard allclose comparison - comparison = self.build_assert_eq( - ast.Call( + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( func=ast.Attribute( - value=ast.Name(id=dev1_cpu, ctx=ast.Load()), - attr="allclose" + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" ), - args=[ - ast.Name(id=dev2_cpu, ctx=ast.Load()) - ], - keywords=[ - ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) - ] - ), - ast.Constant(value=True) - ) - comparisons.append(comparison) - else: - # For deterministic operations, use standard comparison - for i in range(len(device_outputs) - 1): - dev1 = device_outputs[i] - dev2 = device_outputs[i + 1] - - dev1_cpu = f"{dev1}_cpu" - dev2_cpu = f"{dev2}_cpu" - - # Move outputs back to CPU for comparison - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev1, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) - ) - ) - - new_statements.append( - ast.Assign( - targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=dev2, ctx=ast.Load()), - attr="to" - ), - args=[ast.Constant(value="cpu")], - keywords=[] - ) + args=[ast.Constant(value="cpu")], + keywords=[] ) ) - - # Standard allclose comparison - comparison = self.build_assert_eq( - ast.Call( + ) + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( func=ast.Attribute( - value=ast.Name(id=dev1_cpu, ctx=ast.Load()), - attr="allclose" + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" ), - args=[ - ast.Name(id=dev2_cpu, ctx=ast.Load()) - ], - keywords=[ - ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), - ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) - ] - ), - ast.Constant(value=True) + args=[ast.Constant(value="cpu")], + keywords=[] + ) ) - comparisons.append(comparison) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] + ), + ast.Constant(value=True) + ) + comparisons.append(comparison) # Replace statements self.cur_inline_test.previous_stmts = new_statements From bfcce75ab9dcaca68029902c7a3239517ce406cf Mon Sep 17 00:00:00 2001 From: hanse141 Date: Fri, 21 Nov 2025 13:03:16 -0500 Subject: [PATCH 10/10] Isolated Constructor Changes --- src/inline/plugin.py | 397 +++++++++++++++++++------------------------ 1 file changed, 177 insertions(+), 220 deletions(-) diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 13f3025..d331cf3 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -365,230 +365,44 @@ def parse_constructor(self, node): """ Parse a constructor call. """ + + # Argument Order: + # 0) test_name (str) + # 1) parameterized (bool) + # 2) repeated (positive integer) + # 3) tag (str) + # 4) disabled (bool) + # 5) timeout (positive float) + # 6) devices (str array) + + + + keyword_idxs = { + self.arg_test_name_str : 0, + self.arg_parameterized_str : 1, + self.arg_repeated_str : 2, + self.arg_tag_str : 3, + self.arg_disabled_str : 4, + self.arg_timeout_str : 5, + self.arg_devices_str : 6 + } + NUM_OF_ARGUMENTS = 7 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments - if sys.version_info >= (3, 8, 0): - for index, arg in enumerate(node.args): - # check if "test_name" is a string - if index == 0 and isinstance(arg, ast.Constant) and isinstance(arg.value, str): - # get the test name if exists - self.cur_inline_test.test_name = arg.value - # check if "parameterized" is a boolean - elif index == 1 and isinstance(arg, ast.Constant) and isinstance(arg.value, bool): - self.cur_inline_test.parameterized = arg.value - # check if "repeated" is a positive integer - elif index == 2 and isinstance(arg, ast.Constant) and isinstance(arg.value, int): - if arg.value <= 0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = arg.value - elif index == 3 and isinstance(arg.value, ast.List): - tags = [] - for elt in arg.value.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.value) - self.cur_inline_test.tag = tags - elif index == 4 and isinstance(arg, ast.Constant) and isinstance(arg.value, bool): - self.cur_inline_test.disabled = arg.value - elif ( - index == 5 - and isinstance(arg, ast.Constant) - and (isinstance(arg.value, float) or isinstance(arg.value, int)) - ): - self.cur_inline_test.timeout = arg.value - - elif index == 6 and isinstance(arg, ast.List): - devices = [] - for elt in arg.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException("devices can only be List of string") - if elt.value not in {"cpu", "cuda", "mps"}: - raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.value) - self.cur_inline_test.devices = devices + self.parse_constructor_args(node.args) + + #keyword arguments + keyword_args = [] + + #create list with 7 null values (for each position) + for i in range(0, NUM_OF_ARGUMENTS): + keyword_args.append(None) + + for keyword in node.keywords: + keyword_args[keyword_idxs[keyword.arg]] = keyword.value + self.parse_constructor_args(keyword_args) - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - # keyword arguments - for keyword in node.keywords: - # check if "test_name" is a string - if ( - keyword.arg == self.arg_test_name_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, str) - ): - self.cur_inline_test.test_name = keyword.value.value - # check if "parameterized" is a boolean - elif ( - keyword.arg == self.arg_parameterized_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.parameterized = keyword.value.value - # check if "repeated" is a positive integer - elif ( - keyword.arg == self.arg_repeated_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, int) - ): - if keyword.value.value <= 0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = keyword.value.value - # check if "tag" is a list of string - elif keyword.arg == self.arg_tag_str and isinstance(keyword.value, ast.List): - tags = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.value) - self.cur_inline_test.tag = tags - # Add devices handling for keyword args - elif keyword.arg == self.arg_devices_str and isinstance(keyword.value, ast.List): - devices = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException("devices can only be List of string") - if elt.value not in {"cpu", "cuda", "mps"}: - raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.value) - self.cur_inline_test.devices = devices - # check if "disabled" is a boolean - elif ( - keyword.arg == self.arg_disabled_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.disabled = keyword.value.value - # check if "timeout" is a positive float - elif ( - keyword.arg == self.arg_timeout_str - and isinstance(keyword.value, ast.Constant) - and (isinstance(keyword.value.value, float) or isinstance(keyword.value.value, int)) - ): - if keyword.value.value <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = keyword.value.value - # Add devices handling for Python 3.7 - elif index == 6 and isinstance(arg, ast.List): - devices = [] - for elt in arg.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): # Note: ast.Str for Python 3.7 - raise MalformedException("devices can only be List of string") - if elt.s not in {"cpu", "cuda", "mps"}: - raise MalformedException(f"Invalid device: {elt.s}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.s) - self.cur_inline_test.devices = devices - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - else: - for index, arg in enumerate(node.args): - # check if "test_name" is a string - if index == 0 and isinstance(arg, ast.Str) and isinstance(arg.s, str): - # get the test name if exists - self.cur_inline_test.test_name = arg.s - # check if "parameterized" is a boolean - elif index == 1 and isinstance(arg, ast.NameConstant) and isinstance(arg.value, bool): - self.cur_inline_test.parameterized = arg.value - # check if "repeated" is a positive integer - elif index == 2 and isinstance(arg, ast.Num) and isinstance(arg.n, int): - if arg.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = arg.n - # check if "tag" is a list of string - elif index == 3 and isinstance(arg.value, ast.List): - tags = [] - for elt in arg.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.s) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif index == 4 and isinstance(arg, ast.NameConstant) and isinstance(arg.value, bool): - self.cur_inline_test.disabled = arg.value - # check if "timeout" is a positive int - elif ( - index == 5 and isinstance(arg, ast.Num) and (isinstance(arg.n, float) or isinstance(arg.n, int)) - ): - if arg.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = arg.n - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive intege, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - # keyword arguments - for keyword in node.keywords: - # check if "test_name" is a string - if ( - keyword.arg == self.arg_test_name_str - and isinstance(keyword.value, ast.Str) - and isinstance(keyword.value.s, str) - ): - self.cur_inline_test.test_name = keyword.value.s - # check if "parameterized" is a boolean - elif ( - keyword.arg == self.arg_parameterized_str - and isinstance(keyword.value, ast.NameConstant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.parameterized = keyword.value.value - # check if "repeated" is a positive integer - elif ( - keyword.arg == self.arg_repeated_str - and isinstance(keyword.value, ast.Num) - and isinstance(keyword.value.n, int) - ): - if keyword.value.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = keyword.value.n - # check if "tag" is a list of string - elif keyword.arg == self.arg_tag_str and isinstance(keyword.value, ast.List): - tags = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.s) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif ( - keyword.arg == self.arg_disabled_str - and isinstance(keyword.value, ast.NameConstant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.disabled = keyword.value.value - # check if "timeout" is a positive float - elif ( - keyword.arg == self.arg_timeout_str - and isinstance(keyword.value, ast.Num) - and (isinstance(keyword.value.n, float) or isinstance(keyword.value.n, int)) - ): - if keyword.value.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = keyword.value.n - #keyword arg for devices - elif ( - keyword.arg == self.arg_devices_str - and isinstance(keyword.value, ast.List) - ): - devices = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException("devices can only be List of string") - if elt.s not in {"cpu", "cuda", "mps"}: - raise MalformedException(f"Invalid device: {elt.s}. Must be one of ['cpu', 'cuda', 'mps']") - devices.append(elt.s) - self.cur_inline_test.devices = devices - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - else: - raise MalformedException(f"inline test: invalid {self.class_name_str}(), expected at most 3 args") if not self.cur_inline_test.test_name: # by default, use lineno as test name @@ -596,6 +410,149 @@ def parse_constructor(self, node): # set the line number self.cur_inline_test.lineno = node.lineno + def parse_constructor_args(self, args): + class ConstrArgs(enum.Enum): + TEST_NAME = 0 + PARAMETERIZED = 1 + REPEATED = 2 + TAG_STR = 3 + DISABLED = 4 + TIMEOUT = 5 + DEVICES = 6 + + property_names = { + ConstrArgs.TEST_NAME : "test_name", + ConstrArgs.PARAMETERIZED : "parameterized", + ConstrArgs.REPEATED : "repeated", + ConstrArgs.TAG_STR : "tag", + ConstrArgs.DISABLED : "disabled", + ConstrArgs.TIMEOUT : "timeout", + ConstrArgs.DEVICES : "devices" + } + + pre_38_val_names = { + ConstrArgs.TEST_NAME : "s", + ConstrArgs.PARAMETERIZED : "value", + ConstrArgs.REPEATED : "n", + ConstrArgs.TAG_STR : "s", + ConstrArgs.DISABLED : "value", + ConstrArgs.TIMEOUT : "n", + ConstrArgs.DEVICES : "" + } + + pre_38_expec_ast_arg_type = { + ConstrArgs.TEST_NAME : ast.Str, + ConstrArgs.PARAMETERIZED : ast.NameConstant, + ConstrArgs.REPEATED : ast.Num, + ConstrArgs.TAG_STR : ast.List, + ConstrArgs.DISABLED : ast.NameConstant, + ConstrArgs.TIMEOUT : ast.Num, + } + + expected_ast_arg_type = { + ConstrArgs.TEST_NAME : ast.Constant, + ConstrArgs.PARAMETERIZED : ast.Constant, + ConstrArgs.REPEATED : ast.Constant, + ConstrArgs.TAG_STR : ast.List, + ConstrArgs.DISABLED : ast.Constant, + ConstrArgs.TIMEOUT : ast.Constant + } + + expected_ast_val_args = { + ConstrArgs.TEST_NAME : [str], + ConstrArgs.PARAMETERIZED : [bool], + ConstrArgs.REPEATED : [int], + ConstrArgs.TAG_STR : [None], + ConstrArgs.DISABLED : [bool], + ConstrArgs.TIMEOUT : [float, int], + ConstrArgs.DEVICES : [str] + } + + NUM_OF_ARGUMENTS = 7 + + # Arguments organized by expected ast type, value type, and index in that order + for index, arg in enumerate(args): + # Skips over null arguments; needed for keywords + if arg == None: + continue + + # Devices are not referenced in versions before 3.8; all other arguments can be from any version + if index == ConstrArgs.DEVICES and isinstance(arg, ast.List): + devices = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException("devices can only be List of string") + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + self.cur_inline_test.devices = devices + # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8 + else: + corr_arg_type = False + corr_val_type = False + value_prop_name = "" + arg_idx = ConstrArgs(index) + + if sys.version_info >= (3, 8, 0) and isinstance(arg, expected_ast_arg_type[arg_idx]): + corr_arg_type = True + value_prop_name = "value" + elif sys.version_info < (3, 8, 0) and isinstance(arg, pre_38_expec_ast_arg_type[arg_idx]): + corr_arg_type = True + value_prop_name = pre_38_val_names[arg_idx] + + # Verifies value types; skipped for ast node types with no nested values + for arg_type in expected_ast_val_args[arg_idx]: + if arg_type == None: + corr_val_type = True + break + if isinstance(arg.value, arg_type): + corr_val_type = True + break + + if corr_val_type and corr_arg_type: + # Accounts for additional checks for REPEATED and TAG_STR arguments + if arg_idx == ConstrArgs.REPEATED: + if arg.value <= 0: + raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") + self.cur_inline_test.repeated = getattr(arg, value_prop_name) + elif arg_idx == ConstrArgs.TAG_STR: + tags = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException(f"tag can only be List of string") + tags.append(getattr(elt, value_prop_name)) + self.cur_inline_test.tag = tags + # For non-special cases, set the attribute defined by the dictionary + else: + setattr(self.cur_inline_test, + property_names[arg_idx], + getattr(arg, value_prop_name)) + + + ## Match implementation of above conditional tree; commented since Python < 3.10 does not support match + + # match arg_idx: + # case ConstrArgs.REPEATED: + # if arg.value <= 0: + # raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") + # self.cur_inline_test.repeated = getattr(arg, value_prop_name) + # case ConstrArgs.TAG_STR: + # tags = [] + # for elt in arg.elts: + # if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + # raise MalformedException(f"tag can only be List of string") + # tags.append(getattr(elt, value_prop_name)) + # self.cur_inline_test.tag = tags + # # For non-special cases, set the attribute defined by the dictionary + # case _: + # setattr(self.cur_inline_test, + # property_names[arg_idx], + # getattr(arg, value_prop_name)) + else: + raise MalformedException( + f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" + ) + def parameterized_inline_tests_init(self, node: ast.List): if not self.cur_inline_test.parameterized_inline_tests: self.cur_inline_test.parameterized_inline_tests = [InlineTest() for _ in range(len(node.elts))]