diff --git a/tests/interop/test_validate_gpu_nodes.py b/tests/interop/test_validate_gpu_nodes.py index f38c8efc..c3ae585c 100644 --- a/tests/interop/test_validate_gpu_nodes.py +++ b/tests/interop/test_validate_gpu_nodes.py @@ -27,7 +27,11 @@ def test_validate_gpu_nodes(openshift_dyn_client): found = False for machineset in machinesets: logger.info(machineset.instance.metadata.name) - if re.search("gpu", machineset.instance.metadata.name): + # "gpu" for AWS machineset + # "nvidia" for Azure machineset + if re.search("gpu", machineset.instance.metadata.name) or re.search( + "nvidia", machineset.instance.metadata.name + ): gpu_machineset = machineset found = True break @@ -86,18 +90,20 @@ def test_validate_gpu_nodes(openshift_dyn_client): logger.info("Checking GPU machineset instance type") err_msg = "No instanceType found for GPU machineset" - try: - logger.info( - machineset.instance.spec.template.spec.providerSpec.value.instanceType - ) - except AttributeError: - logger.error(f"FAIL: {err_msg}") - assert False, err_msg - instance_type = str( - machineset.instance.spec.template.spec.providerSpec.value.instanceType + # for AWS + instance_type = ( + gpu_machineset.instance.spec.template.spec.providerSpec.value.instanceType ) - if instance_type == "None": + if instance_type is None: + # for Azure + instance_type = ( + gpu_machineset.instance.spec.template.spec.providerSpec.value.vmSize + ) + + logger.info(instance_type) + + if instance_type is None: logger.error(f"FAIL: {err_msg}") assert False, err_msg