# --- src/inline/inline.py: Inline.diff_test (user-facing stub) ---
def diff_test(self, outputs):
    """
    Assert that the statement under test produces consistent results on every
    device listed in the constructor's ``devices`` argument.

    :param outputs: the variable that receives the result of the statement under test;
        the plugin looks up the assignment to that variable and replays it once per device
    :returns: Inline object
    :raises: AssertionError if outputs differ across devices
    """
    return self


# --- src/inline/plugin.py: ExtractInlineTest parsing methods ---
def parse_constructor_args(self, args):
    """
    Validate constructor arguments and store them on the current inline test.

    :param args: AST nodes in positional order (test_name, parameterized, repeated,
        tag, disabled, timeout, devices); a slot holding None is skipped, which lets
        keyword arguments be passed as a sparse list
    :raises: MalformedException if an argument has the wrong node type, the wrong
        value type, or an out-of-range value
    """
    NUM_OF_ARGUMENTS = 7
    TAG_INDEX = 3
    DEVICES_INDEX = 6

    # The legacy node classes are only looked up on interpreters that still
    # produce them; newer interpreters deprecate and eventually remove them.
    if sys.version_info >= (3, 8, 0):
        str_node = num_node = const_node = ast.Constant
        str_attr = num_attr = "value"
    else:
        str_node, str_attr = ast.Str, "s"
        num_node, num_attr = ast.Num, "n"
        const_node = ast.NameConstant

    # index -> (attribute on the inline test, node class, value attribute, value types)
    scalar_specs = {
        0: ("test_name", str_node, str_attr, str),
        1: ("parameterized", const_node, "value", bool),
        2: ("repeated", num_node, num_attr, int),
        4: ("disabled", const_node, "value", bool),
        5: ("timeout", num_node, num_attr, (float, int)),
    }

    def malformed():
        return MalformedException(
            f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float"
        )

    def string_list(list_node, error_message):
        values = []
        for elt in list_node.elts:
            value = getattr(elt, str_attr, None) if isinstance(elt, str_node) else None
            if not isinstance(value, str):
                raise MalformedException(error_message)
            values.append(value)
        return values

    for index, arg in enumerate(args):
        # Empty slot: the keyword for this position was not supplied.
        if arg is None:
            continue

        if index == TAG_INDEX:
            if not isinstance(arg, ast.List):
                raise malformed()
            self.cur_inline_test.tag = string_list(arg, "tag can only be List of string")
        elif index == DEVICES_INDEX:
            if isinstance(arg, ast.List):
                devices = string_list(arg, "devices can only be List of string")
                for device in devices:
                    if device not in {"cpu", "cuda", "mps"}:
                        raise MalformedException(
                            f"Invalid device: {device}. Must be one of ['cpu', 'cuda', 'mps']"
                        )
                self.cur_inline_test.devices = devices
            elif isinstance(arg, const_node) and getattr(arg, "value", 0) is None:
                # An explicit devices=None matches the constructor default.
                self.cur_inline_test.devices = None
            else:
                raise MalformedException("devices can only be List of string")
        elif index in scalar_specs:
            attr_name, node_type, value_attr, value_types = scalar_specs[index]
            # Check the node class first so the value attribute is known to exist.
            if not isinstance(arg, node_type):
                raise malformed()
            value = getattr(arg, value_attr)
            if not isinstance(value, value_types):
                raise malformed()
            if attr_name == "repeated" and value <= 0:
                raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0")
            # The removed keyword-argument code rejected non-positive timeouts.
            if attr_name == "timeout" and value <= 0:
                raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0")
            setattr(self.cur_inline_test, attr_name, value)
        else:
            # More positional arguments than the constructor accepts.
            raise malformed()


def parse_diff_test(self, node):
    """
    Rewrite the current inline test so the statement under test runs once per device.

    The assignment that produced the variable passed to diff_test() is replayed for
    every entry of ``devices`` with the given() inputs cloned and moved to that device.
    Adjacent outputs are then compared on CPU with ``allclose``. The generated
    statements replace ``previous_stmts`` and the comparisons replace ``check_stmts``.

    :param node: the ``diff_test(...)`` call node
    :raises: MalformedException if no devices were configured, the call does not have
        exactly one plain-name argument, or no matching assignment is found
    """
    devices = self.cur_inline_test.devices
    if not devices:
        raise MalformedException("diff_test can only be used with the 'devices' parameter.")

    if len(node.args) != 1:
        raise MalformedException("diff_test() requires exactly 1 argument.")

    output_node = self.parse_group(node.args[0])
    if not isinstance(output_node, ast.Name):
        raise MalformedException("diff_test() argument must be the name of the output variable")

    # Find the assignment that produced the output variable.
    original_op = None
    for stmt in self.cur_inline_test.previous_stmts:
        if (
            isinstance(stmt, ast.Assign)
            and isinstance(stmt.targets[0], ast.Name)
            and stmt.targets[0].id == output_node.id
        ):
            original_op = stmt.value
            break
    if original_op is None:
        raise MalformedException("Could not find original operation for diff_test")

    def load(name):
        return ast.Name(id=name, ctx=ast.Load())

    def call(func, *call_args, keywords=()):
        return ast.Call(func=func, args=list(call_args), keywords=list(keywords))

    def method_call(receiver, method, *call_args, keywords=()):
        func = ast.Attribute(value=receiver, attr=method, ctx=ast.Load())
        return call(func, *call_args, keywords=keywords)

    def assign(name, value):
        return ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=value)

    class ReplaceInputs(ast.NodeTransformer):
        """Rename references to given() inputs to their per-device copies."""

        def __init__(self, renames):
            self.renames = renames

        def visit_Name(self, name_node):
            if name_node.id in self.renames:
                return ast.Name(id=self.renames[name_node.id], ctx=name_node.ctx)
            return name_node

    # set_random_seed(seed_value) seeds random, torch and numpy.random.
    seed_func_def = ast.FunctionDef(
        name="set_random_seed",
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg="seed_value", annotation=None)],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=[
            ast.Expr(value=call(load("seed"), load("seed_value"))),
            ast.Expr(value=method_call(load("torch"), "manual_seed", load("seed_value"))),
            ast.Expr(value=method_call(load("np_random"), "seed", load("seed_value"))),
        ],
        decorator_list=[],
        returns=None,
    )

    new_statements = [
        ast.ImportFrom(module="random", names=[ast.alias(name="seed", asname=None)], level=0),
        ast.ImportFrom(module="numpy", names=[ast.alias(name="random", asname="np_random")], level=0),
        # set_random_seed references torch, so the generated code imports it itself.
        ast.Import(names=[ast.alias(name="torch", asname=None)]),
        seed_func_def,
    ]

    # Clone every given() input once, then create one copy per device.
    input_names = []
    for given_stmt in self.cur_inline_test.given_stmts:
        target = given_stmt.targets[0]
        if not isinstance(target, ast.Name):
            raise MalformedException("diff_test only supports given() inputs bound to variable names")
        ref_var = f"{target.id}_ref"
        new_statements.append(assign(ref_var, method_call(given_stmt.value, "clone")))
        for device in devices:
            moved = method_call(load(ref_var), "to", ast.Constant(value=device))
            new_statements.append(assign(f"{target.id}_{device}", moved))
        input_names.append(target.id)

    # Replay the operation per device, reseeding with the same constant each time.
    device_outputs = []
    for device in devices:
        new_statements.append(ast.Expr(value=call(load("set_random_seed"), ast.Constant(value=42))))
        renames = {name: f"{name}_{device}" for name in input_names}
        device_op = ReplaceInputs(renames).visit(copy.deepcopy(original_op))
        device_output = f"output_{device}"
        new_statements.append(assign(device_output, device_op))
        device_outputs.append(device_output)

    # Move each output back to CPU once so neighbours can be compared.
    cpu_outputs = []
    for device_output in device_outputs:
        cpu_name = f"{device_output}_cpu"
        new_statements.append(assign(cpu_name, method_call(load(device_output), "to", ast.Constant(value="cpu"))))
        cpu_outputs.append(cpu_name)

    # A single device yields no pairs and therefore no comparisons.
    comparisons = [
        self.build_assert_eq(
            method_call(
                load(left),
                "allclose",
                load(right),
                keywords=[
                    ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)),
                    ast.keyword(arg="atol", value=ast.Constant(value=1e-4)),
                    ast.keyword(arg="equal_nan", value=ast.Constant(value=True)),
                ],
            ),
            ast.Constant(value=True),
        )
        for left, right in zip(cpu_outputs, cpu_outputs[1:])
    ]

    self.cur_inline_test.previous_stmts = new_statements
    self.cur_inline_test.check_stmts = comparisons