Skip to content
24 changes: 24 additions & 0 deletions .ci/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1731,6 +1731,30 @@ elif [[ "${TEST_CONFIG}" == smoke ]]; then
test_python_smoke
elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then
test_h100_distributed
elif [[ "${TEST_CONFIG}" == "check_unimplemented_calls" ]]; then
# Get the changed files from the PR that match torch/_dynamo/**/*.py pattern
CHANGED_FILES=$(git diff --name-only origin/main HEAD | grep -E 'torch/_dynamo/.*\.py$' | tr '\n' ' ')
##CHANGED_FILES=$(git diff --name-only "$GITHUB_BASE_SHA" "$GITHUB_SHA" | grep -E 'torch/_dynamo/.*\.py$' | tr '\n' ' ')

if [[ -z "$CHANGED_FILES" ]]; then
echo "No Python files in torch/_dynamo/ directory were modified."
exit 0
fi

echo "Checking unimplemented_v2 calls in: $CHANGED_FILES"

if ! python tools/dynamo/gb_id_mapping.py check --files "$CHANGED_FILES" --registry-path tools/dynamo/graph_break_registry.json; then
echo "::error::Found unimplemented_v2 calls that don't match the registry."
echo "::error::Please update the registry using one of these commands:"
echo "::error::- Add new entry: python tools/dynamo/gb_id_mapping.py add \"GB_TYPE\" PATH_TO_FILE [--additional-info \"INFO\"]"
echo "::error::- Update existing: python tools/dynamo/gb_id_mapping.py update \"GB_TYPE\" PATH_TO_FILE [--new_gb_type \"NEW_NAME\"] [--additional-info \"INFO\"]"
echo "::error::- Recreate registry: python tools/dynamo/gb_id_mapping.py create"
echo "::error::Note: If you've reset the entire registry file, you can force push to bypass this check."
exit 1
fi

echo "All unimplemented_v2 calls match the registry."

else
install_torchvision
install_monkeytype
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jobs:
{ config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "check_unimplemented_calls", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"},
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
Expand Down
75 changes: 74 additions & 1 deletion tools/dynamo/gb_id_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,50 @@ def cmd_update_gb_type(
return True


def check_unimplemented_calls(files, registry_path):
"""
Checks if the unimplemented_v2 calls in the specified files match the entries in the registry.
Args:
files (list of str): A list of file paths to check for unimplemented_v2 calls.
registry_path (str or Path): The path to the registry JSON file containing expected entries.
Returns:
bool: True if all unimplemented_v2 calls match the registry entries, False if there are mismatches.
The function compares the gb_type, context, explanation, and hints of each unimplemented_v2 call
in the provided files against the corresponding entries in the registry. If any discrepancies are
found, it prints the details and returns False. Otherwise, it confirms that all calls match the
registry and returns True.
"""
registry_path = Path(registry_path)
reg = load_registry(registry_path)

gb_type_to_entry = {entries[0]["Gb_type"]: entries[0] for _, entries in reg.items()}

mismatches = []
for file in files:
calls = find_unimplemented_v2_calls(Path(file))
for call in calls:
gb_type = call["gb_type"]
if gb_type not in gb_type_to_entry:
mismatches.append((gb_type, file, "Not found in registry"))
continue

entry = gb_type_to_entry[gb_type]
if call["context"] != entry["Context"]:
mismatches.append((gb_type, file, "Context mismatch"))
elif call["explanation"] != entry["Explanation"]:
mismatches.append((gb_type, file, "Explanation mismatch"))
elif sorted(call["hints"]) != sorted(entry["Hints"]):
mismatches.append((gb_type, file, "Hints mismatch"))

if mismatches:
for gb_type, file, reason in mismatches:
print(f" - {gb_type} in {file}: {reason}")
return False

print("All unimplemented_v2 calls match the registry.")
return True


def create_registry(dynamo_dir, registry_path):
calls = find_unimplemented_v2_calls(dynamo_dir)
registry = {}
Expand Down Expand Up @@ -314,6 +358,12 @@ def main():
default=default_dynamo_dir,
help="Directory to search for unimplemented_v2 calls.",
)
create_parser.add_argument(
"--registry-path",
type=str,
default=str(registry_path),
help="Path to save the registry JSON file",
)

add_parser = subparsers.add_parser("add", help="Add a gb_type to registry")
add_parser.add_argument("gb_type", help="The gb_type to add")
Expand All @@ -323,6 +373,12 @@ def main():
add_parser.add_argument(
"--additional-info", help="Optional additional information to include"
)
add_parser.add_argument(
"--registry-path",
type=str,
default=str(registry_path),
help="Path to save the registry JSON file",
)

update_parser = subparsers.add_parser(
"update", help="Update an existing gb_type in registry"
Expand All @@ -338,8 +394,20 @@ def main():
update_parser.add_argument(
"--additional-info", help="Optional additional information to include"
)
update_parser.add_argument(
"--registry-path",
type=str,
default=str(registry_path),
help="Path to save the registry JSON file",
)

parser.add_argument(
check_parser = subparsers.add_parser(
"check", help="Check if unimplemented_v2 calls match registry entries"
)
check_parser.add_argument(
"--files", type=str, help="Space-separated list of files to check"
)
check_parser.add_argument(
"--registry-path",
type=str,
default=str(registry_path),
Expand All @@ -366,6 +434,11 @@ def main():
)
if not success:
sys.exit(1)
elif args.command == "check":
files = args.files.split()
success = check_unimplemented_calls(files, args.registry_path)
if not success:
sys.exit(1)
else:
parser.print_help()

Expand Down
13 changes: 13 additions & 0 deletions torch/_dynamo/testing_github_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# mypy: ignore-errors

from . import graph_break_hints
from .exc import unimplemented_v2


def dummy_method():
unimplemented_v2(
gb_type="Testing gb_type",
context="Testing context",
explanation="testing",
hints=[*graph_break_hints.USER_ERROR],
)
Loading