diff --git a/.circleci/config.yml b/.circleci/config.yml index 81329f6e85a..bc724dcc3df 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -99,6 +99,81 @@ commands: - brew_install: formulae: libtool + apt_install: + parameters: + args: + type: string + descr: + type: string + default: "" + update: + type: boolean + default: true + steps: + - run: + name: > + <<^ parameters.descr >> apt install << parameters.args >> <> + <<# parameters.descr >> << parameters.descr >> <> + command: | + <<# parameters.update >> sudo apt update -qy <> + sudo apt install << parameters.args >> + + pip_install: + parameters: + args: + type: string + descr: + type: string + default: "" + user: + type: boolean + default: true + steps: + - run: + name: > + <<^ parameters.descr >> pip install << parameters.args >> <> + <<# parameters.descr >> << parameters.descr >> <> + command: > + pip install + <<# parameters.user >> --user <> + --progress-bar=off + << parameters.args >> + + install_torchvision: + parameters: + editable: + type: boolean + default: true + steps: + - pip_install: + args: --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + descr: Install PyTorch from nightly releases + - pip_install: + args: --no-build-isolation <<# parameters.editable >> --editable <> . + descr: Install torchvision <<# parameters.editable >> in editable mode <> + + install_prototype_dependencies: + steps: + - pip_install: + args: iopath git+https://github.com/pytorch/data + descr: Install prototype dependencies + + # Most of the test suite is handled by the `unittest` jobs, with completely different workflow and setup. + # This command can be used if only a selection of tests need to be run, for ad-hoc files. + run_tests_selective: + parameters: + file_or_dir: + type: string + steps: + - run: + name: Install test utilities + command: pip install --progress-bar=off pytest pytest-mock + - run: + name: Run tests + command: pytest --junitxml=test-results/junit.xml -v --durations 20 <> + - store_test_results: + path: test-results + binary_common: &binary_common parameters: # Edit these defaults to do a release @@ -171,107 +246,99 @@ jobs: - image: circleci/python:3.7 steps: - checkout + - pip_install: + args: jinja2 pyyaml - run: + name: Check CircleCI config consistency command: | - pip install --user --progress-bar off jinja2 pyyaml python .circleci/regenerate.py git diff --exit-code || (echo ".circleci/config.yml not in sync with config.yml.in! 
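Editor's note: the `circleci_consistency` job above regenerates `config.yml` from `config.yml.in` and relies on `git diff --exit-code` to fail when the two diverge. A rough, hypothetical miniature of that render-then-compare mechanism follows; the template text and the `unittest_workflows` helper are illustrative stand-ins, not the real `.circleci/regenerate.py`:

    from jinja2 import Template

    # Hypothetical miniature: config.yml.in calls helper functions (e.g. the
    # real {{ unittest_workflows() }}) that regenerate.py injects into the
    # template context before writing the rendered result back to config.yml.
    template = Template(
        "workflows:\n"
        "  unittest:\n"
        "    jobs:\n"
        "{{ unittest_workflows() }}\n"
    )

    def unittest_workflows() -> str:
        # Illustrative only; the real helper expands the full job matrix.
        return "\n".join(f"      - unittest_linux_cpu_py{v}" for v in ("3.6", "3.7"))

    print(template.render(unittest_workflows=unittest_workflows))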
Run .circleci/regenerate.py to update config"; exit 1) - python_lint: + lint_python_and_config: docker: - image: circleci/python:3.7 steps: - checkout + - pip_install: + args: pre-commit + descr: Install lint utilities - run: - command: | - pip install --user --progress-bar off pre-commit - pre-commit install-hooks - - run: pre-commit run --all-files + name: Install pre-commit hooks + command: pre-commit install-hooks + - run: + name: Lint Python code and config files + command: pre-commit run --all-files - run: name: Required lint modifications when: on_fail command: git --no-pager diff - python_type_check: + lint_c: docker: - image: circleci/python:3.7 steps: + - apt_install: + args: libtinfo5 + descr: Install additional system libraries - checkout - run: + name: Install lint utilities command: | - sudo apt-get update -y - sudo apt install -y libturbojpeg-dev - pip install --user --progress-bar off mypy - pip install --user --progress-bar off types-requests - pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install --user --progress-bar off git+https://github.com/pytorch/data.git - pip install --user --progress-bar off --no-build-isolation --editable . - mypy --config-file mypy.ini - - docstring_parameters_sync: - docker: - - image: circleci/python:3.7 - steps: - - checkout + curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o clang-format + chmod +x clang-format + sudo mv clang-format /opt/clang-format - run: - command: | - pip install --user pydocstyle - pydocstyle + name: Lint C code + command: ./.circleci/unittest/linux/scripts/run-clang-format.py -r torchvision/csrc --clang-format-executable /opt/clang-format + - run: + name: Required lint modifications + when: on_fail + command: git --no-pager diff - clang_format: + type_check_python: docker: - image: circleci/python:3.7 steps: + - apt_install: + args: libturbojpeg-dev + descr: Install additional system libraries - checkout + - install_torchvision: + editable: true + - install_prototype_dependencies + - pip_install: + args: mypy + descr: Install Python type check utilities - run: - command: | - sudo apt-get update -y - sudo apt install -y libtinfo5 - curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o clang-format - chmod +x clang-format - sudo mv clang-format /opt/clang-format - ./.circleci/unittest/linux/scripts/run-clang-format.py -r torchvision/csrc --clang-format-executable /opt/clang-format + name: Check Python types statically + command: mypy --config-file mypy.ini - torchhub_test: + unittest_torchhub: docker: - image: circleci/python:3.7 steps: - checkout - - run: - command: | - pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - # need to install torchvision dependencies due to transitive imports - pip install --user --progress-bar off --no-build-isolation . - pip install pytest - python test/test_hub.py + - install_torchvision + - run_tests_selective: + file_or_dir: test/test_hub.py - torch_onnx_test: + unittest_onnx: docker: - image: circleci/python:3.7 steps: - checkout - - run: - command: | - pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - # need to install torchvision dependencies due to transitive imports - pip install --user --progress-bar off --no-build-isolation . 
- pip install --user onnx - pip install --user onnxruntime - pip install --user pytest - python test/test_onnx.py - - prototype_test: + - install_torchvision + - pip_install: + args: onnx onnxruntime + descr: Install ONNX + - run_tests_selective: + file_or_dir: test/test_onnx.py + + unittest_prototype: docker: - image: circleci/python:3.7 resource_class: xlarge steps: - - run: - name: Install torch - command: | - pip install --user --progress-bar=off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - - run: - name: Install prototype dependencies - command: pip install --user --progress-bar=off git+https://github.com/pytorch/data.git - checkout - run: name: Download model weights @@ -281,19 +348,16 @@ jobs: mkdir -p ~/.cache/torch/hub/checkpoints python scripts/collect_model_urls.py torchvision/prototype/models \ | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' + - install_torchvision + - install_prototype_dependencies + - pip_install: + args: scipy pycocotools + descr: Install optional dependencies - run: - name: Install torchvision - command: pip install --user --progress-bar off --no-build-isolation . - - run: - name: Install test requirements - command: pip install --user --progress-bar=off pytest pytest-mock scipy iopath - - run: - name: Run tests - environment: - PYTORCH_TEST_WITH_PROTOTYPE: 1 - command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py - - store_test_results: - path: test-results + name: Enable prototype tests + command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV + - run_tests_selective: + file_or_dir: test/test_prototype_*.py binary_linux_wheel: <<: *binary_common @@ -529,9 +593,10 @@ jobs: at: ~/workspace - designate_upload_channel - checkout + - pip_install: + args: awscli - run: command: | - pip install --user awscli export PATH="$HOME/.local/bin:$PATH" # Prevent credential from leaking set +x @@ -572,7 +637,8 @@ jobs: command: | set -x source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} - pip install $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html + - pip_install: + args: $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html - run: name: smoke test command: | @@ -641,7 +707,8 @@ jobs: eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" conda create -yn python${PYTHON_VERSION} python=${PYTHON_VERSION} conda activate python${PYTHON_VERSION} - pip install $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html + - pip_install: + args: $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html - run: name: smoke test command: | @@ -967,7 +1034,7 @@ jobs: eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env pushd docs - pip install -r requirements.txt + pip install --progress-bar=off -r requirements.txt make html popd - persist_to_workspace: @@ -1008,9 +1075,15 @@ jobs: workflows: - build: + lint: jobs: - circleci_consistency + - lint_python_and_config + - lint_c + - type_check_python + + build: + jobs: - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -1515,13 +1588,6 @@ workflows: python_version: '3.7' requires: - build_docs - - python_lint - - python_type_check - - docstring_parameters_sync - - clang_format - - torchhub_test - - torch_onnx_test - - 
prototype_test - binary_ios_build: build_environment: binary-libtorchvision_ops-ios-12.0.0-x86_64 ios_arch: x86_64 @@ -1538,6 +1604,9 @@ workflows: unittest: jobs: + - unittest_torchhub + - unittest_onnx + - unittest_prototype - unittest_linux_cpu: cu_version: cpu name: unittest_linux_cpu_py3.6 @@ -1675,14 +1744,6 @@ workflows: nightly: jobs: - - circleci_consistency - - python_lint - - python_type_check - - docstring_parameters_sync - - clang_format - - torchhub_test - - torch_onnx_test - - prototype_test - binary_ios_build: build_environment: nightly-binary-libtorchvision_ops-ios-12.0.0-x86_64 filters: diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 4f3adbff184..c029fa766ad 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -99,6 +99,81 @@ commands: - brew_install: formulae: libtool + apt_install: + parameters: + args: + type: string + descr: + type: string + default: "" + update: + type: boolean + default: true + steps: + - run: + name: > + <<^ parameters.descr >> apt install << parameters.args >> <> + <<# parameters.descr >> << parameters.descr >> <> + command: | + <<# parameters.update >> sudo apt update -qy <> + sudo apt install << parameters.args >> + + pip_install: + parameters: + args: + type: string + descr: + type: string + default: "" + user: + type: boolean + default: true + steps: + - run: + name: > + <<^ parameters.descr >> pip install << parameters.args >> <> + <<# parameters.descr >> << parameters.descr >> <> + command: > + pip install + <<# parameters.user >> --user <> + --progress-bar=off + << parameters.args >> + + install_torchvision: + parameters: + editable: + type: boolean + default: true + steps: + - pip_install: + args: --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + descr: Install PyTorch from nightly releases + - pip_install: + args: --no-build-isolation <<# parameters.editable >> --editable <> . + descr: Install torchvision <<# parameters.editable >> in editable mode <> + + install_prototype_dependencies: + steps: + - pip_install: + args: iopath git+https://github.com/pytorch/data + descr: Install prototype dependencies + + # Most of the test suite is handled by the `unittest` jobs, with completely different workflow and setup. + # This command can be used if only a selection of tests need to be run, for ad-hoc files. + run_tests_selective: + parameters: + file_or_dir: + type: string + steps: + - run: + name: Install test utilities + command: pip install --progress-bar=off pytest pytest-mock + - run: + name: Run tests + command: pytest --junitxml=test-results/junit.xml -v --durations 20 <> + - store_test_results: + path: test-results + binary_common: &binary_common parameters: # Edit these defaults to do a release @@ -171,107 +246,99 @@ jobs: - image: circleci/python:3.7 steps: - checkout + - pip_install: + args: jinja2 pyyaml - run: + name: Check CircleCI config consistency command: | - pip install --user --progress-bar off jinja2 pyyaml python .circleci/regenerate.py git diff --exit-code || (echo ".circleci/config.yml not in sync with config.yml.in! 
Run .circleci/regenerate.py to update config"; exit 1) - python_lint: + lint_python_and_config: docker: - image: circleci/python:3.7 steps: - checkout + - pip_install: + args: pre-commit + descr: Install lint utilities - run: - command: | - pip install --user --progress-bar off pre-commit - pre-commit install-hooks - - run: pre-commit run --all-files + name: Install pre-commit hooks + command: pre-commit install-hooks + - run: + name: Lint Python code and config files + command: pre-commit run --all-files - run: name: Required lint modifications when: on_fail command: git --no-pager diff - python_type_check: + lint_c: docker: - image: circleci/python:3.7 steps: + - apt_install: + args: libtinfo5 + descr: Install additional system libraries - checkout - run: + name: Install lint utilities command: | - sudo apt-get update -y - sudo apt install -y libturbojpeg-dev - pip install --user --progress-bar off mypy - pip install --user --progress-bar off types-requests - pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install --user --progress-bar off git+https://github.com/pytorch/data.git - pip install --user --progress-bar off --no-build-isolation --editable . - mypy --config-file mypy.ini - - docstring_parameters_sync: - docker: - - image: circleci/python:3.7 - steps: - - checkout + curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o clang-format + chmod +x clang-format + sudo mv clang-format /opt/clang-format - run: - command: | - pip install --user pydocstyle - pydocstyle + name: Lint C code + command: ./.circleci/unittest/linux/scripts/run-clang-format.py -r torchvision/csrc --clang-format-executable /opt/clang-format + - run: + name: Required lint modifications + when: on_fail + command: git --no-pager diff - clang_format: + type_check_python: docker: - image: circleci/python:3.7 steps: + - apt_install: + args: libturbojpeg-dev + descr: Install additional system libraries - checkout - - run: - command: | - sudo apt-get update -y - sudo apt install -y libtinfo5 - curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o clang-format - chmod +x clang-format - sudo mv clang-format /opt/clang-format - ./.circleci/unittest/linux/scripts/run-clang-format.py -r torchvision/csrc --clang-format-executable /opt/clang-format + - install_torchvision: + editable: true + - install_prototype_dependencies + - pip_install: + args: mypy + descr: Install Python type check utilities + - run: + name: Check Python types statically + command: mypy --config-file mypy.ini - torchhub_test: + unittest_torchhub: docker: - image: circleci/python:3.7 steps: - checkout - - run: - command: | - pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - # need to install torchvision dependencies due to transitive imports - pip install --user --progress-bar off --no-build-isolation . - pip install pytest - python test/test_hub.py + - install_torchvision + - run_tests_selective: + file_or_dir: test/test_hub.py - torch_onnx_test: + unittest_onnx: docker: - image: circleci/python:3.7 steps: - checkout - - run: - command: | - pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - # need to install torchvision dependencies due to transitive imports - pip install --user --progress-bar off --no-build-isolation . 
- pip install --user onnx - pip install --user onnxruntime - pip install --user pytest - python test/test_onnx.py + - install_torchvision + - pip_install: + args: onnx onnxruntime + descr: Install ONNX + - run_tests_selective: + file_or_dir: test/test_onnx.py - prototype_test: + unittest_prototype: docker: - image: circleci/python:3.7 resource_class: xlarge steps: - - run: - name: Install torch - command: | - pip install --user --progress-bar=off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - - run: - name: Install prototype dependencies - command: pip install --user --progress-bar=off git+https://github.com/pytorch/data.git - checkout - run: name: Download model weights @@ -281,19 +348,16 @@ jobs: mkdir -p ~/.cache/torch/hub/checkpoints python scripts/collect_model_urls.py torchvision/prototype/models \ | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' - - run: - name: Install torchvision - command: pip install --user --progress-bar off --no-build-isolation . - - run: - name: Install test requirements - command: pip install --user --progress-bar=off pytest pytest-mock scipy iopath - - run: - name: Run tests - environment: - PYTORCH_TEST_WITH_PROTOTYPE: 1 - command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py - - store_test_results: - path: test-results + - install_torchvision + - install_prototype_dependencies + - pip_install: + args: scipy pycocotools + descr: Install optional dependencies + - run: + name: Enable prototype tests + command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV + - run_tests_selective: + file_or_dir: test/test_prototype_*.py binary_linux_wheel: <<: *binary_common @@ -529,9 +593,10 @@ jobs: at: ~/workspace - designate_upload_channel - checkout + - pip_install: + args: awscli - run: command: | - pip install --user awscli export PATH="$HOME/.local/bin:$PATH" # Prevent credential from leaking set +x @@ -572,7 +637,8 @@ jobs: command: | set -x source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} - pip install $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html + - pip_install: + args: $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html - run: name: smoke test command: | @@ -641,7 +707,8 @@ jobs: eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" conda create -yn python${PYTHON_VERSION} python=${PYTHON_VERSION} conda activate python${PYTHON_VERSION} - pip install $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html + - pip_install: + args: $(ls ~/workspace/torchvision*.whl) --pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html - run: name: smoke test command: | @@ -967,7 +1034,7 @@ jobs: eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env pushd docs - pip install -r requirements.txt + pip install --progress-bar=off -r requirements.txt make html popd - persist_to_workspace: @@ -1008,23 +1075,24 @@ jobs: workflows: - build: -{%- if True %} + lint: jobs: - circleci_consistency + - lint_python_and_config + - lint_c + - type_check_python + + build: + jobs: {{ build_workflows(windows_latest_only=True) }} - - python_lint - - python_type_check - - docstring_parameters_sync - - clang_format - - torchhub_test - - torch_onnx_test - - prototype_test {{ ios_workflows() }} {{ android_workflows() }} unittest: jobs: + - 
unittest_torchhub + - unittest_onnx + - unittest_prototype {{ unittest_workflows() }} cmake: @@ -1032,16 +1100,7 @@ workflows: {{ cmake_workflows() }} nightly: -{%- endif %} jobs: - - circleci_consistency - - python_lint - - python_type_check - - docstring_parameters_sync - - clang_format - - torchhub_test - - torch_onnx_test - - prototype_test {{ ios_workflows(nightly=True) }} {{ android_workflows(nightly=True) }} {{ build_workflows(prefix="nightly_", filter_branch="nightly", upload=True) }} diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000000..f267cc7da50 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1 @@ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 920916030be..89f69bba52e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,8 @@ repos: - id: check-toml - id: check-yaml exclude: packaging/.* + - id: mixed-line-ending + args: [--fix=lf] - id: end-of-file-fixer # - repo: https://github.com/asottile/pyupgrade @@ -28,3 +30,8 @@ repos: hooks: - id: flake8 args: [--config=setup.cfg] + + - repo: https://github.com/PyCQA/pydocstyle + rev: 6.1.1 + hooks: + - id: pydocstyle diff --git a/android/gradlew.bat b/android/gradlew.bat index e95643d6a2c..f9553162f12 100644 --- a/android/gradlew.bat +++ b/android/gradlew.bat @@ -1,84 +1,84 @@ -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto init - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! 
-if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/docs/source/models.rst b/docs/source/models.rst index dbb1400e11e..ee8503a0857 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -7,7 +7,7 @@ Models and pre-trained weights The ``torchvision.models`` subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person -keypoint detection and video classification. +keypoint detection, video classification, and optical flow. .. note :: Backward compatibility is guaranteed for loading a serialized @@ -798,3 +798,16 @@ ResNet (2+1)D :template: function.rst torchvision.models.video.r2plus1d_18 + +Optical flow +============ + +Raft +---- + +.. 
autosummary:: + :toctree: generated/ + :template: function.rst + + torchvision.models.optical_flow.raft_large + torchvision.models.optical_flow.raft_small diff --git a/references/classification/train.py b/references/classification/train.py index b16ed3d2a42..b2c6844df9b 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -158,8 +158,7 @@ def load_data(traindir, valdir, args): crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) else: - fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) preprocessing = weights.transforms() dataset_test = torchvision.datasets.ImageFolder( diff --git a/references/detection/train.py b/references/detection/train.py index ae13a32bd22..0788895af20 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -53,8 +53,7 @@ def get_transform(train, args): elif not args.weights: return presets.DetectionPresetEval() else: - fn = PM.detection.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) return weights.transforms() diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 2dbb962fe2f..72a9bdb01f5 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -38,8 +38,7 @@ def get_transform(train, args): elif not args.weights: return presets.SegmentationPresetEval(base_size=520) else: - fn = PM.segmentation.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) return weights.transforms() diff --git a/references/video_classification/train.py b/references/video_classification/train.py index d66879e5b46..0cd88e8022f 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,19 +12,13 @@ from torch.utils.data.dataloader import default_collate from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler -try: - from apex import amp -except ImportError: - amp = None - - try: from torchvision.prototype import models as PM except ImportError: PM = None -def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): +def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) @@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi for video, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() video, target = video.to(device), target.to(device) - output = model(video) - loss = criterion(output, target) + with torch.cuda.amp.autocast(enabled=scaler is not None): + output = model(video) + loss = criterion(output, target) optimizer.zero_grad() - if apex: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() + + if scaler is not None: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() else: loss.backward() - optimizer.step() + optimizer.step() acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = video.shape[0] @@ -101,11 +98,6 @@ def collate_fn(batch): def main(args): if args.weights and PM is None: 
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if args.apex and amp is None: - raise RuntimeError( - "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " - "to enable mixed-precision training." - ) if args.output_dir: utils.mkdir(args.output_dir) @@ -160,8 +152,7 @@ def main(args): if not args.weights: transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) else: - fn = PM.video.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) transform_test = weights.transforms() if args.cache_dataset and os.path.exists(cache_path): @@ -225,9 +216,7 @@ def main(args): lr = args.lr * args.world_size optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) - - if args.apex: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) + scaler = torch.cuda.amp.GradScaler() if args.amp else None # convert scheduler to be per iteration, not per epoch, for warmup that lasts # between different epochs @@ -268,6 +257,8 @@ def main(args): optimizer.load_state_dict(checkpoint["optimizer"]) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) args.start_epoch = checkpoint["epoch"] + 1 + if args.amp: + scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -278,9 +269,7 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch( - model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex - ) + train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { @@ -290,6 +279,8 @@ def main(args): "epoch": epoch, "args": args, } + if args.amp: + checkpoint["scaler"] = scaler.state_dict() utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) @@ -364,17 +355,6 @@ def parse_args(): action="store_true", ) - # Mixed precision training parameters - parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training") - parser.add_argument( - "--apex-opt-level", - default="O1", - type=str, - help="For apex mixed precision training" - "O0 for FP32 training, O1 for mixed precision training." 
- "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet", - ) - # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") @@ -382,6 +362,9 @@ def parse_args(): # Prototype models only parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + # Mixed precision training parameters + parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + args = parser.parse_args() return args diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 8d27240c75d..f399125b0af 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1,5 +1,6 @@ import functools import gzip +import json import lzma import pathlib import pickle @@ -8,6 +9,7 @@ from typing import Any, Dict, Tuple import numpy as np +import PIL.Image import pytest import torch from datasets_utils import create_image_folder, make_tar, make_zip @@ -18,7 +20,9 @@ from torchvision.prototype.datasets._api import find from torchvision.prototype.utils._internal import add_suggestion + make_tensor = functools.partial(_make_tensor, device="cpu") +make_scalar = functools.partial(make_tensor, ()) __all__ = ["load"] @@ -490,3 +494,113 @@ def imagenet(info, root, config): make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz") return num_samples + + +class CocoMockData: + @classmethod + def _make_images_archive(cls, root, name, *, num_samples): + image_paths = create_image_folder( + root, name, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_samples + ) + + images_meta = [] + for path in image_paths: + with PIL.Image.open(path) as image: + width, height = image.size + images_meta.append(dict(file_name=path.name, id=int(path.stem), width=width, height=height)) + + make_zip(root, f"{name}.zip") + + return images_meta + + @classmethod + def _make_annotations_json( + cls, + root, + name, + *, + images_meta, + fn, + ): + num_anns_per_image = torch.randint(1, 5, (len(images_meta),)) + num_anns_total = int(num_anns_per_image.sum()) + ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)]) + + anns_meta = [] + for image_meta, num_anns in zip(images_meta, num_anns_per_image): + for _ in range(num_anns): + ann_id = int(next(ann_ids_iter)) + anns_meta.append(dict(fn(ann_id, image_meta), id=ann_id, image_id=image_meta["id"])) + anns_meta.sort(key=lambda ann: ann["id"]) + + with open(root / name, "w") as file: + json.dump(dict(images=images_meta, annotations=anns_meta), file) + + return num_anns_per_image + + @staticmethod + def _make_instances_data(ann_id, image_meta): + def make_rle_segmentation(): + height, width = image_meta["height"], image_meta["width"] + numel = height * width + counts = [] + while sum(counts) <= numel: + counts.append(int(torch.randint(5, 8, ()))) + if sum(counts) > numel: + counts[-1] -= sum(counts) - numel + return dict(counts=counts, size=[height, width]) + + return dict( + segmentation=make_rle_segmentation(), + bbox=make_tensor((4,), dtype=torch.float32, low=0).tolist(), + iscrowd=True, + area=float(make_scalar(dtype=torch.float32)), + category_id=int(make_scalar(dtype=torch.int64)), + ) + + @staticmethod + def _make_captions_data(ann_id, image_meta): + return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.") + + 
@classmethod + def _make_annotations(cls, root, name, *, images_meta): + num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64) + for annotations, fn in ( + ("instances", cls._make_instances_data), + ("captions", cls._make_captions_data), + ): + num_anns_per_image += cls._make_annotations_json( + root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn + ) + + return int(num_anns_per_image.sum()) + + @classmethod + def generate( + cls, + root, + *, + year, + num_samples, + ): + annotations_dir = root / "annotations" + annotations_dir.mkdir() + + for split in ("train", "val"): + config_name = f"{split}{year}" + + images_meta = cls._make_images_archive(root, config_name, num_samples=num_samples) + cls._make_annotations( + annotations_dir, + config_name, + images_meta=images_meta, + ) + + make_zip(root, f"annotations_trainval{year}.zip", annotations_dir) + + return num_samples + + +@dataset_mocks.register_mock_data_fn +def coco(info, root, config): + return CocoMockData.generate(root, year=config.year, num_samples=5) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 4e06fdfffbe..4012b29e7c5 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -866,6 +866,13 @@ def _split_files_or_dirs(root, *files_or_dirs): def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True): archive = pathlib.Path(root) / name + if not files_or_dirs: + dir = archive.parent / archive.name.replace("".join(archive.suffixes), "") + if dir.exists() and dir.is_dir(): + files_or_dirs = (dir,) + else: + raise ValueError("No file or dir provided.") + files, dirs = _split_files_or_dirs(root, *files_or_dirs) with opener(archive) as fh: diff --git a/test/expect/ModelTester.test_raft_large_expect.pkl b/test/expect/ModelTester.test_raft_large_expect.pkl new file mode 100644 index 00000000000..a6aad285f59 Binary files /dev/null and b/test/expect/ModelTester.test_raft_large_expect.pkl differ diff --git a/test/expect/ModelTester.test_raft_small_expect.pkl b/test/expect/ModelTester.test_raft_small_expect.pkl new file mode 100644 index 00000000000..bd0ee65add9 Binary files /dev/null and b/test/expect/ModelTester.test_raft_small_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index 5fbe0dca38f..2e0ed783849 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -22,7 +22,11 @@ def get_models_from_module(module): # TODO add a registration mechanism to torchvision.models - return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + return [ + v + for k, v in module.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" + ] @pytest.fixture @@ -89,7 +93,7 @@ def _get_expected_file(name=None): return expected_file -def _assert_expected(output, name, prec): +def _assert_expected(output, name, prec=None, atol=None, rtol=None): """Test that a python value matches the recorded contents of a file based on a "check" name. The value must be pickable with `torch.save`. 
This file @@ -106,10 +110,11 @@ def _assert_expected(output, name, prec): MAX_PICKLE_SIZE = 50 * 1000 # 50 KB binary_size = os.path.getsize(expected_file) if binary_size > MAX_PICKLE_SIZE: - raise RuntimeError(f"The output for {filename}, is larger than 50kb") + raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size}kb") else: expected = torch.load(expected_file) - rtol = atol = prec + rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally + atol = atol or prec torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False) @@ -814,5 +819,33 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"] +@needs_cuda +@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small)) +@pytest.mark.parametrize("scripted", (False, True)) +def test_raft(model_builder, scripted): + + torch.manual_seed(0) + + # We need very small images, otherwise the pickle size would exceed the 50KB + # As a resut we need to override the correlation pyramid to not downsample + # too much, otherwise we would get nan values (effective H and W would be + # reduced to 1) + corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2) + + model = model_builder(corr_block=corr_block).eval().to("cuda") + if scripted: + model = torch.jit.script(model) + + bs = 1 + img1 = torch.rand(bs, 3, 80, 72).cuda() + img2 = torch.rand(bs, 3, 80, 72).cuda() + + preds = model(img1, img2) + flow_pred = preds[-1] + # Tolerance is fairly high, but there are 2 * H * W outputs to check + # The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different + _assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_onnx.py b/test/test_onnx.py index 830699ab5ee..b49e0e24c8e 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -141,6 +141,7 @@ def test_roi_align(self): model = ops.RoIAlign((5, 5), 1, -1) self.run_model(model, [(x, single_roi)]) + @pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.") def test_roi_align_aligned(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32) diff --git a/test/test_ops.py b/test/test_ops.py index c8e4e396c7e..d687e2e2952 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,12 +7,54 @@ import numpy as np import pytest import torch +import torch.fx from common_utils import needs_cuda, cpu_and_gpu, assert_equal from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck from torch.nn.modules.utils import _pair from torchvision import models, ops +from torchvision.models.feature_extraction import get_graph_node_names + + +class RoIOpTesterModuleWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 2 + + def forward(self, a, b): + self.layer(a, b) + + +class MultiScaleRoIAlignModuleWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 3 + + def forward(self, a, b, c): + self.layer(a, b, c) + + +class DeformConvModuleWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 3 + + def forward(self, a, b, c): + self.layer(a, 
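Editor's note: beyond the CUDA-only test above, the newly exposed RAFT builders can be exercised directly. A minimal CPU usage sketch with random inputs and no pretrained weights; spatial sizes are kept divisible by 8 to match RAFT's feature pyramid:

    import torch
    from torchvision.models.optical_flow import raft_small

    model = raft_small().eval()
    img1 = torch.rand(1, 3, 128, 128)
    img2 = torch.rand(1, 3, 128, 128)
    with torch.no_grad():
        flow_predictions = model(img1, img2)  # list of iteratively refined flow estimates
    flow = flow_predictions[-1]               # final estimate, shape (1, 2, 128, 128): per-pixel (dx, dy)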
b, c) + + +class StochasticDepthWrapper(nn.Module): + def __init__(self, obj): + super().__init__() + self.layer = obj + self.n_inputs = 1 + + def forward(self, a): + self.layer(a) class RoIOpTester(ABC): @@ -46,6 +88,15 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_is_leaf_node(self, device): + op_obj = self.make_obj(wrap=True).to(device=device) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) @@ -91,6 +142,10 @@ def _helper_boxes_shape(self, func): def fn(*args, **kwargs): pass + @abstractmethod + def make_obj(*args, **kwargs): + pass + @abstractmethod def get_script_fn(*args, **kwargs): pass @@ -104,6 +159,10 @@ class TestRoiPool(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois) + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False): + obj = ops.RoIPool((pool_h, pool_w), spatial_scale) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.roi_pool) return lambda x: scriped(x, rois, pool_size) @@ -144,6 +203,10 @@ class TestPSRoIPool(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois) + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False): + obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.ps_roi_pool) return lambda x: scriped(x, rois, pool_size) @@ -223,6 +286,12 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligne (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned )(x, rois) + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False): + obj = ops.RoIAlign( + (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned + ) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.roi_align) return lambda x: scriped(x, rois, pool_size) @@ -374,6 +443,10 @@ class TestPSRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois) + def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False): + obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio) + return RoIOpTesterModuleWrapper(obj) if wrap else obj + def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.ps_roi_align) return lambda x: scriped(x, rois, pool_size) @@ -422,12 +495,18 @@ def test_boxes_shape(self): class TestMultiScaleRoIAlign: + def make_obj(self, fmap_names=None, output_size=(7, 7), 
sampling_ratio=2, wrap=False): + if fmap_names is None: + fmap_names = ["0"] + obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) + return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj + def test_msroialign_repr(self): fmap_names = ["0"] output_size = (7, 7) sampling_ratio = 2 # Pass mock feature map names - t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) + t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False) # Check integrity of object __repr__ attribute expected_string = ( @@ -436,6 +515,15 @@ def test_msroialign_repr(self): ) assert repr(t) == expected_string + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_is_leaf_node(self, device): + op_obj = self.make_obj(wrap=True).to(device=device) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + class TestNMS: def _reference_nms(self, boxes, scores, iou_threshold): @@ -693,6 +781,21 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): return x, weight, offset, mask, bias, stride, pad, dilation + def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False): + obj = ops.DeformConv2d( + in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups + ) + return DeformConvModuleWrapper(obj) if wrap else obj + + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_is_leaf_node(self, device): + op_obj = self.make_obj(wrap=True).to(device=device) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) @@ -705,9 +808,9 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): groups = 2 tol = 2e-3 if dtype is torch.half else 1e-5 - layer = ops.DeformConv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups - ).to(device=x.device, dtype=dtype) + layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to( + device=x.device, dtype=dtype + ) res = layer(x, offset, mask) weight = layer.weight.data @@ -1200,6 +1303,20 @@ def test_stochastic_depth(self, seed, mode, p): elif p == 1: assert out.equal(torch.zeros_like(x)) + def make_obj(self, p, mode, wrap=False): + obj = ops.StochasticDepth(p, mode) + return StochasticDepthWrapper(obj) if wrap else obj + + @pytest.mark.parametrize("p", (0, 1)) + @pytest.mark.parametrize("mode", ["batch", "row"]) + def test_is_leaf_node(self, p, mode): + op_obj = self.make_obj(p, mode, wrap=True) + graph_node_names = get_graph_node_names(op_obj) + + assert len(graph_node_names) == 2 + assert len(graph_node_names[0]) == len(graph_node_names[1]) + assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + class TestUtils: @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm]) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 4c2a05e2f0a..9f12324fe34 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -13,6 +13,17 @@ def to_bytes(file): return file.read() +def config_id(name, config): + 
parts = [name] + for name, value in config.items(): + if isinstance(value, bool): + part = ("" if value else "no_") + name + else: + part = str(value) + parts.append(part) + return "-".join(parts) + + def dataset_parametrization(*names, decoder=to_bytes): if not names: # TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported @@ -27,16 +38,17 @@ def dataset_parametrization(*names, decoder=to_bytes): "caltech256", "caltech101", "imagenet", + "coco", ) - params = [] - for name in names: - for config in datasets.info(name)._configs: - id = f"{name}-{'-'.join([str(value) for value in config.values()])}" - dataset, mock_info = builtin_dataset_mocks.load(name, decoder=decoder, **config) - params.append(pytest.param(dataset, mock_info, id=id)) - - return pytest.mark.parametrize(("dataset", "mock_info"), params) + return pytest.mark.parametrize( + ("dataset", "mock_info"), + [ + pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config)) + for name in names + for config in datasets.info(name)._configs + ], + ) class TestCommon: diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 92a88342534..1dc883528ef 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -24,6 +24,19 @@ def _get_parent_module(model_fn): return module +def _get_model_weights(model_fn): + module = _get_parent_module(model_fn) + weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" + try: + return next( + v + for k, v in module.__dict__.items() + if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__ + ) + except StopIteration: + return None + + def _build_model(fn, **kwargs): try: model = fn(**kwargs) @@ -36,24 +49,22 @@ def _build_model(fn, **kwargs): @pytest.mark.parametrize( - "model_fn, name, weight", + "name, weight", [ - (models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), - (models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2), + ("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), + ("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2), ( - models.quantization.resnet50, - "default", + "ResNet50_QuantizedWeights.default", models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2, ), ( - models.quantization.resnet50, - "ImageNet1K_FBGEMM_V1", + "ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1", models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1, ), ], ) -def test_get_weight(model_fn, name, weight): - assert models._api.get_weight(model_fn, name) == weight +def test_get_weight(name, weight): + assert models.get_weight(name) == weight @pytest.mark.parametrize( @@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight): + TM.get_models_from_module(models.video), ) def test_naming_conventions(model_fn): - model_name = model_fn.__name__ - module = _get_parent_module(model_fn) - weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" - assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name)) + weights_enum = _get_model_weights(model_fn) + assert weights_enum is not None + assert len(weights_enum) == 0 or hasattr(weights_enum, "default") @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp 
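Editor's note: the reference scripts and tests above replace the private `PM._api.get_weight(fn, name)` lookup with the public, string-based `get_weight`. A minimal sketch of the new call pattern, using the prototype namespace as it exists at the time of this diff:

    from torchvision.prototype import models as PM

    # Resolve a weights enum entry from its fully qualified name -- exactly the
    # string a user would pass via --weights in the reference scripts.
    weights = PM.get_weight("ResNet50_Weights.ImageNet1K_V1")
    preprocess = weights.transforms()           # matching inference transforms
    model = PM.resnet50(weights=weights).eval()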
b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index c6e971c3b12..0167ed70a64 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -70,6 +70,7 @@ static void torch_jpeg_set_source_mgr( } // namespace torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { + C10_LOG_API_USAGE_ONCE("torchvision.io.decode_jpeg_cpp"); // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 0df55daed68..8ab0fed205c 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -23,6 +23,7 @@ torch::Tensor decode_png( const torch::Tensor& data, ImageReadMode mode, bool allow_16_bits) { + C10_LOG_API_USAGE_ONCE("torchvision.io.decode_png_cpp"); // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional diff --git a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp index a8dbc7b2a28..739783919ae 100644 --- a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp @@ -25,6 +25,7 @@ using JpegSizeType = size_t; using namespace detail; torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { + C10_LOG_API_USAGE_ONCE("torchvision.io.encode_jpeg_cpp"); // Define compression structures and error handling struct jpeg_compress_struct cinfo {}; struct torch_jpeg_error_mgr jerr {}; diff --git a/torchvision/csrc/io/image/cpu/encode_png.cpp b/torchvision/csrc/io/image/cpu/encode_png.cpp index d28bad95890..ca308f357ff 100644 --- a/torchvision/csrc/io/image/cpu/encode_png.cpp +++ b/torchvision/csrc/io/image/cpu/encode_png.cpp @@ -63,6 +63,7 @@ void torch_png_write_data( } // namespace torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { + C10_LOG_API_USAGE_ONCE("torchvision.io.encode_png_cpp"); // Define compression structures and error handling png_structp png_write; png_infop info_ptr; diff --git a/torchvision/csrc/io/image/cpu/read_write_file.cpp b/torchvision/csrc/io/image/cpu/read_write_file.cpp index a0bb7df72d5..b1d1a48c4b9 100644 --- a/torchvision/csrc/io/image/cpu/read_write_file.cpp +++ b/torchvision/csrc/io/image/cpu/read_write_file.cpp @@ -33,6 +33,7 @@ std::wstring utf8_decode(const std::string& str) { #endif torch::Tensor read_file(const std::string& filename) { + C10_LOG_API_USAGE_ONCE("torchvision.io.read_file_cpp"); #ifdef _WIN32 // According to // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, @@ -76,6 +77,7 @@ torch::Tensor read_file(const std::string& filename) { } void write_file(const std::string& filename, torch::Tensor& data) { + C10_LOG_API_USAGE_ONCE("torchvision.io.write_file_cpp"); // Check that the input tensor is on CPU TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 68f63ced427..37674d2b44d 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -33,6 +33,7 @@ torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, ImageReadMode mode, torch::Device 
device) { + C10_LOG_API_USAGE_ONCE("torchvision.io.decode_jpeg_cuda_cpp"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK( diff --git a/torchvision/csrc/io/video/video.cpp b/torchvision/csrc/io/video/video.cpp index d5a24398694..ea4d31628e6 100644 --- a/torchvision/csrc/io/video/video.cpp +++ b/torchvision/csrc/io/video/video.cpp @@ -157,6 +157,7 @@ void Video::_getDecoderParams( } // _get decoder params Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { + C10_LOG_API_USAGE_ONCE("torchvision.io.Video_cpp"); // set number of threads global numThreads_ = numThreads; // parse stream information diff --git a/torchvision/csrc/io/video_reader/video_reader.cpp b/torchvision/csrc/io/video_reader/video_reader.cpp index 51b0750b431..6b1c70d0bed 100644 --- a/torchvision/csrc/io/video_reader/video_reader.cpp +++ b/torchvision/csrc/io/video_reader/video_reader.cpp @@ -583,6 +583,7 @@ torch::List read_video_from_memory( int64_t audioEndPts, int64_t audioTimeBaseNum, int64_t audioTimeBaseDen) { + C10_LOG_API_USAGE_ONCE("torchvision.io.read_video_from_memory_cpp"); return readVideo( false, input_video, @@ -627,6 +628,7 @@ torch::List read_video_from_file( int64_t audioEndPts, int64_t audioTimeBaseNum, int64_t audioTimeBaseDen) { + C10_LOG_API_USAGE_ONCE("torchvision.io.read_video_from_file_cpp"); torch::Tensor dummy_input_video = torch::ones({0}); return readVideo( true, @@ -653,10 +655,12 @@ torch::List read_video_from_file( } torch::List probe_video_from_memory(torch::Tensor input_video) { + C10_LOG_API_USAGE_ONCE("torchvision.io.probe_video_from_memory_cpp"); return probeVideo(false, input_video, ""); } torch::List probe_video_from_file(std::string videoPath) { + C10_LOG_API_USAGE_ONCE("torchvision.io.probe_video_from_file_cpp"); torch::Tensor dummy_input_video = torch::ones({0}); return probeVideo(true, dummy_input_video, videoPath); } diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index e3222eca41d..acaac029137 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -233,14 +233,6 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ _save_response_content(itertools.chain((first_chunk,), content), fpath) -def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - return value - - return None - - def _save_response_content( response_gen: Iterator[bytes], destination: str, # type: ignore[name-defined] diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 382e06fb4f2..f2ae6dff51e 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -2,6 +2,7 @@ import torch +from ..utils import _log_api_usage_once from ._video_opt import ( Timebase, VideoMetaData, @@ -106,11 +107,12 @@ class VideoReader: """ def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> None: + _log_api_usage_once(self) if not _has_video_opt(): raise RuntimeError( "Not compiled with video_reader support, " + "to enable video_reader support, please install " - + "ffmpeg (version 4.2 is currently supported) and" + + "ffmpeg (version 4.2 is currently supported) and " + "build torchvision from source." 
) self._c = torch.classes.torchvision.Video(path, stream, num_threads) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index f835565016c..dd1801d6bd6 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -3,6 +3,7 @@ import torch from .._internally_replaced_utils import _get_extension_path +from ..utils import _log_api_usage_once try: @@ -41,6 +42,7 @@ def read_file(path: str) -> torch.Tensor: Returns: data (Tensor) """ + _log_api_usage_once("torchvision.io.read_file") data = torch.ops.image.read_file(path) return data @@ -54,6 +56,7 @@ def write_file(filename: str, data: torch.Tensor) -> None: filename (str): the path to the file to be written data (Tensor): the contents to be written to the output file """ + _log_api_usage_once("torchvision.io.write_file") torch.ops.image.write_file(filename, data) @@ -74,6 +77,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE Returns: output (Tensor[image_channels, image_height, image_width]) """ + _log_api_usage_once("torchvision.io.decode_png") output = torch.ops.image.decode_png(input, mode.value, False) return output @@ -93,6 +97,7 @@ def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor: Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the PNG file. """ + _log_api_usage_once("torchvision.io.encode_png") output = torch.ops.image.encode_png(input, compression_level) return output @@ -109,6 +114,7 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): compression_level (int): Compression factor for the resulting file, it must be a number between 0 and 9. Default: 6 """ + _log_api_usage_once("torchvision.io.write_png") output = encode_png(input, compression_level) write_file(filename, output) @@ -137,6 +143,7 @@ def decode_jpeg( Returns: output (Tensor[image_channels, image_height, image_width]) """ + _log_api_usage_once("torchvision.io.decode_jpeg") device = torch.device(device) if device.type == "cuda": output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) @@ -160,6 +167,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the JPEG file. """ + _log_api_usage_once("torchvision.io.encode_jpeg") if quality < 1 or quality > 100: raise ValueError("Image quality should be a positive number between 1 and 100") @@ -178,6 +186,7 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): quality (int): Quality of the resulting JPEG file, it must be a number between 1 and 100. 
Default: 75 """ + _log_api_usage_once("torchvision.io.write_jpeg") output = encode_jpeg(input, quality) write_file(filename, output) @@ -201,6 +210,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN Returns: output (Tensor[image_channels, image_height, image_width]) """ + _log_api_usage_once("torchvision.io.decode_image") output = torch.ops.image.decode_image(input, mode.value) return output @@ -221,6 +231,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc Returns: output (Tensor[image_channels, image_height, image_width]) """ + _log_api_usage_once("torchvision.io.read_image") data = read_file(path) return decode_image(data, mode) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 0ddd60a4586..cdb426d6d09 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -9,6 +9,7 @@ import numpy as np import torch +from ..utils import _log_api_usage_once from . import _video_opt @@ -77,6 +78,7 @@ def write_video( audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc. audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream """ + _log_api_usage_once("torchvision.io.write_video") _check_av_available() video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() @@ -256,6 +258,7 @@ def read_video( aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) """ + _log_api_usage_once("torchvision.io.read_video") from torchvision import get_video_backend @@ -374,6 +377,7 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in video_fps (float, optional): the frame rate for the video """ + _log_api_usage_once("torchvision.io.read_video_timestamps") from torchvision import get_video_backend if get_video_backend() != "pyav": diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 516e47feb19..c9d11f88f01 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -12,6 +12,7 @@ from .regnet import * from . import detection from . import feature_extraction +from . import optical_flow from . import quantization from . import segmentation from . 
import video diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 0095f21f62b..0a2b597da23 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -1,11 +1,14 @@ +import inspect +import math import re import warnings from collections import OrderedDict from copy import deepcopy from itertools import chain -from typing import Dict, Callable, List, Union, Optional, Tuple +from typing import Dict, Callable, List, Union, Optional, Tuple, Any import torch +import torchvision from torch import fx from torch import nn from torch.fx.graph_module import _copy_attr @@ -172,8 +175,19 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT warnings.warn(msg + suggestion_msg) +def _get_leaf_modules_for_ops() -> List[type]: + members = inspect.getmembers(torchvision.ops) + result = [] + for _, obj in members: + if inspect.isclass(obj) and issubclass(obj, torch.nn.Module): + result.append(obj) + return result + + def get_graph_node_names( - model: nn.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False + model: nn.Module, + tracer_kwargs: Optional[Dict[str, Any]] = None, + suppress_diff_warning: bool = False, ) -> Tuple[List[str], List[str]]: """ Dev utility to return node names in order of execution. See note on node @@ -198,6 +212,7 @@ def get_graph_node_names( tracer_kwargs (dict, optional): a dictionary of keywork arguments for ``NodePathTracer`` (they are eventually passed onto `torch.fx.Tracer `_). + By default it will be set to wrap and make leaf nodes all torchvision ops. suppress_diff_warning (bool, optional): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. @@ -211,6 +226,14 @@ def get_graph_node_names( >>> model = torchvision.models.resnet18() >>> train_nodes, eval_nodes = get_graph_node_names(model) """ + if tracer_kwargs is None: + tracer_kwargs = { + "autowrap_modules": ( + math, + torchvision.ops, + ), + "leaf_modules": _get_leaf_modules_for_ops(), + } is_training = model.training train_tracer = NodePathTracer(**tracer_kwargs) train_tracer.trace(model.train()) @@ -294,7 +317,7 @@ def create_feature_extractor( return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - tracer_kwargs: Dict = {}, + tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, ) -> fx.GraphModule: """ @@ -353,6 +376,7 @@ def create_feature_extractor( tracer_kwargs (dict, optional): a dictionary of keywork arguments for ``NodePathTracer`` (which passes them onto it's parent class `torch.fx.Tracer `_). + By default it will be set to wrap and make leaf nodes all torchvision ops. suppress_diff_warning (bool, optional): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. 
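With tracer_kwargs now defaulting to None, both get_graph_node_names and create_feature_extractor build their own defaults (autowrap math and torchvision.ops, and treat every nn.Module in torchvision.ops as a leaf module) instead of sharing a mutable {} default across calls. A short usage sketch under those defaults; the model and node names are only illustrative:

    import torch
    import torchvision
    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

    model = torchvision.models.resnet18()
    # torchvision ops (e.g. ops.FrozenBatchNorm2d in detection models) are now leaf
    # modules by default, so they are not traced through.
    train_nodes, eval_nodes = get_graph_node_names(model)
    extractor = create_feature_extractor(model, return_nodes={"layer2": "mid", "layer4": "deep"})
    out = extractor(torch.rand(1, 3, 224, 224))
    print(sorted(out.keys()))  # ['deep', 'mid']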
@@ -397,6 +421,14 @@ def create_feature_extractor( >>> 'autowrap_functions': [leaf_function]}) """ + if tracer_kwargs is None: + tracer_kwargs = { + "autowrap_modules": ( + math, + torchvision.ops, + ), + "leaf_modules": _get_leaf_modules_for_ops(), + } is_training = model.training assert any( diff --git a/torchvision/models/optical_flow/__init__.py b/torchvision/models/optical_flow/__init__.py new file mode 100644 index 00000000000..9dd32f25dec --- /dev/null +++ b/torchvision/models/optical_flow/__init__.py @@ -0,0 +1 @@ +from .raft import RAFT, raft_large, raft_small diff --git a/torchvision/models/optical_flow/_utils.py b/torchvision/models/optical_flow/_utils.py new file mode 100644 index 00000000000..693b3f14009 --- /dev/null +++ b/torchvision/models/optical_flow/_utils.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None): + """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates.""" + h, w = img.shape[-2:] + + xgrid, ygrid = absolute_grid.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (w - 1) - 1 + ygrid = 2 * ygrid / (h - 1) - 1 + normalized_grid = torch.cat([xgrid, ygrid], dim=-1) + + return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners) + + +def make_coords_grid(batch_size: int, h: int, w: int): + coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch_size, 1, 1, 1) + + +def upsample_flow(flow, up_mask: Optional[Tensor] = None): + """Upsample flow by a factor of 8. + + If up_mask is None we just interpolate. + If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B. + Note that in appendix B the picture assumes a downsample factor of 4 instead of 8. + """ + batch_size, _, h, w = flow.shape + new_h, new_w = h * 8, w * 8 + + if up_mask is None: + return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True) + + up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w) + up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1 + + upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w) + upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2) + + return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w) diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py new file mode 100644 index 00000000000..02705a7ebdb --- /dev/null +++ b/torchvision/models/optical_flow/raft.py @@ -0,0 +1,659 @@ +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.instancenorm import InstanceNorm2d +from torchvision.ops import ConvNormActivation + +from ._utils import grid_sample, make_coords_grid, upsample_flow + + +__all__ = ( + "RAFT", + "raft_large", + "raft_small", +) + + +class ResidualBlock(nn.Module): + """Slightly modified Residual block with extra relu and biases.""" + + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): + super().__init__() + + # Note regarding bias=True: + # Usually we can pass bias=False in conv layers followed by a norm layer. 
+ # But in the RAFT training reference, the BatchNorm2d layers are only activated for the first dataset, + # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful + # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm + # because these aren't frozen, but we don't bother (also, we woudn't be able to load the original weights). + self.convnormrelu1 = ConvNormActivation( + in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True + ) + self.convnormrelu2 = ConvNormActivation( + out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True + ) + + if stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = ConvNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + y = x + y = self.convnormrelu1(y) + y = self.convnormrelu2(y) + + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + """Slightly modified BottleNeck block (extra relu and biases)""" + + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): + super(BottleneckBlock, self).__init__() + + # See note in ResidualBlock for the reason behind bias=True + self.convnormrelu1 = ConvNormActivation( + in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True + ) + self.convnormrelu2 = ConvNormActivation( + out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True + ) + self.convnormrelu3 = ConvNormActivation( + out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True + ) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = ConvNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + def forward(self, x): + y = x + y = self.convnormrelu1(y) + y = self.convnormrelu2(y) + y = self.convnormrelu3(y) + + x = self.downsample(x) + + return self.relu(x + y) + + +class FeatureEncoder(nn.Module): + """The feature encoder, used both as the actual feature encoder, and as the context encoder. + + It must downsample its input by 8. 
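The /8 contract stated in this docstring (one stride-2 stem plus two stride-2 stages) can be sanity-checked through the public factories defined further down in this file; a minimal sketch, assuming random weights on CPU:

    import torch
    from torchvision.models.optical_flow import raft_small

    model = raft_small()
    x = torch.rand(2, 3, 64, 96)        # H and W must be divisible by 8
    fmap = model.feature_encoder(x)
    print(fmap.shape)                   # torch.Size([2, 128, 8, 12]): downsampled by 8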
+ """ + + def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_layer=nn.BatchNorm2d): + super().__init__() + + assert len(layers) == 5 + + # See note in ResidualBlock for the reason behind bias=True + self.convnormrelu = ConvNormActivation(3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True) + + self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1) + self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2) + self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=2) + + self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride): + block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride) + block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1) + return nn.Sequential(block1, block2) + + def forward(self, x): + x = self.convnormrelu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv(x) + + return x + + +class MotionEncoder(nn.Module): + """The motion encoder, part of the update block. + + Takes the current predicted flow and the correlation features as input and returns an encoded version of these. + """ + + def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128): + super().__init__() + + assert len(flow_layers) == 2 + assert len(corr_layers) in (1, 2) + + self.convcorr1 = ConvNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) + if len(corr_layers) == 2: + self.convcorr2 = ConvNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) + else: + self.convcorr2 = nn.Identity() + + self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) + self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) + + # out_channels - 2 because we cat the flow (2 channels) at the end + self.conv = ConvNormActivation( + corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 + ) + + self.out_channels = out_channels + + def forward(self, flow, corr_features): + corr = self.convcorr1(corr_features) + corr = self.convcorr2(corr) + + flow_orig = flow + flow = self.convflow1(flow) + flow = self.convflow2(flow) + + corr_flow = torch.cat([corr, flow], dim=1) + corr_flow = self.conv(corr_flow) + return torch.cat([corr_flow, flow_orig], dim=1) + + +class ConvGRU(nn.Module): + """Convolutional Gru unit.""" + + def __init__(self, *, input_size, hidden_size, kernel_size, padding): + super().__init__() + self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = 
torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + return h + + +def _pass_through_h(h, _): + # Declared here for torchscript + return h + + +class RecurrentBlock(nn.Module): + """Recurrent block, part of the update block. + + Takes the current hidden state and the concatenation of (motion encoder output, context) as input. + Returns an updated hidden state. + """ + + def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))): + super().__init__() + + assert len(kernel_size) == len(padding) + assert len(kernel_size) in (1, 2) + + self.convgru1 = ConvGRU( + input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0] + ) + if len(kernel_size) == 2: + self.convgru2 = ConvGRU( + input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1] + ) + else: + self.convgru2 = _pass_through_h + + self.hidden_size = hidden_size + + def forward(self, h, x): + h = self.convgru1(h, x) + h = self.convgru2(h, x) + return h + + +class FlowHead(nn.Module): + """Flow head, part of the update block. + + Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow". + """ + + def __init__(self, *, in_channels, hidden_size): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class UpdateBlock(nn.Module): + """The update block which contains the motion encoder, the recurrent block, and the flow head. + + It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block. + """ + + def __init__(self, *, motion_encoder, recurrent_block, flow_head): + super().__init__() + self.motion_encoder = motion_encoder + self.recurrent_block = recurrent_block + self.flow_head = flow_head + + self.hidden_state_size = recurrent_block.hidden_size + + def forward(self, hidden_state, context, corr_features, flow): + motion_features = self.motion_encoder(flow, corr_features) + x = torch.cat([context, motion_features], dim=1) + + hidden_state = self.recurrent_block(hidden_state, x) + delta_flow = self.flow_head(hidden_state) + return hidden_state, delta_flow + + +class MaskPredictor(nn.Module): + """Mask predictor to be used when upsampling the predicted flow. + + It takes the hidden state of the recurrent unit as input and outputs the mask. + This is not used in the raft-small model. + """ + + def __init__(self, *, in_channels, hidden_size, multiplier=0.25): + super().__init__() + self.convrelu = ConvNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) + # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder + # and we interpolate with all 9 surrounding neighbors. See paper and appendix B. + self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) + + # In the original code, they use a factor of 0.25 to "downweight the gradients" of that branch. + # See e.g. https://github.com/princeton-vl/RAFT/issues/119#issuecomment-953950419 + # or https://github.com/princeton-vl/RAFT/issues/24. + # It doesn't seem to affect epe significantly and can likely be set to 1. 
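The 8 * 8 * 9 channels produced by this predictor feed the convex-combination upsampling in upsample_flow above. A toy check with made-up tensors that the softmax over the 9-neighbor axis yields weights summing to 1 at every upsampled location:

    import torch

    batch_size, h, w = 2, 4, 6
    up_mask = torch.randn(batch_size, 8 * 8 * 9, h, w)   # shape produced by MaskPredictor
    weights = torch.softmax(up_mask.view(batch_size, 1, 9, 8, 8, h, w), dim=2)
    print(torch.allclose(weights.sum(dim=2), torch.ones(batch_size, 1, 8, 8, h, w)))  # True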
+ self.multiplier = multiplier + + def forward(self, x): + x = self.convrelu(x) + x = self.conv(x) + return self.multiplier * x + + +class CorrBlock(nn.Module): + """The correlation block. + + Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder, + and then indexes from this pyramid to create correlation features. + The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that + are within a ``radius``, according to the infinity norm (see paper section 3.2). + Note: typo in the paper, it should be infinity norm, not 1-norm. + """ + + def __init__(self, *, num_levels: int = 4, radius: int = 4): + super().__init__() + self.num_levels = num_levels + self.radius = radius + + self.corr_pyramid: List[Tensor] = [torch.tensor(0)] # useless, but torchscript is otherwise confused :') + + # The neighborhood of a centroid pixel x' is {x' + delta, ||delta||_inf <= radius} + # so it's a square surrounding x', and its sides have a length of 2 * radius + 1 + # The paper claims that it's ||.||_1 instead of ||.||_inf but it's a typo: + # https://github.com/princeton-vl/RAFT/issues/122 + self.out_channels = num_levels * (2 * radius + 1) ** 2 + + def build_pyramid(self, fmap1, fmap2): + """Build the correlation pyramid from two feature maps. + + The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) + The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions + to build the correlation pyramid. + """ + + torch._assert(fmap1.shape == fmap2.shape, "Input feature maps should have the same shapes") + corr_volume = self._compute_corr_volume(fmap1, fmap2) + + batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w + corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w) + self.corr_pyramid = [corr_volume] + for _ in range(self.num_levels - 1): + corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2) + self.corr_pyramid.append(corr_volume) + + def index_pyramid(self, centroids_coords): + """Return correlation features by indexing from the pyramid.""" + neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels + di = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device) + delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2) + + batch_size, _, h, w = centroids_coords.shape # _ = 2 + centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2) + + indexed_pyramid = [] + for corr_volume in self.corr_pyramid: + sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2) + indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view( + batch_size, h, w, -1 + ) + indexed_pyramid.append(indexed_corr_volume) + centroids_coords = centroids_coords / 2 + + corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() + + expected_output_shape = (batch_size, self.out_channels, h, w) + torch._assert( + corr_features.shape == expected_output_shape, + f"Output shape of index pyramid is incorrect. 
Should be {expected_output_shape}, got {corr_features.shape}", + ) + + return corr_features + + def _compute_corr_volume(self, fmap1, fmap2): + batch_size, num_channels, h, w = fmap1.shape + fmap1 = fmap1.view(batch_size, num_channels, h * w) + fmap2 = fmap2.view(batch_size, num_channels, h * w) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch_size, h, w, 1, h, w) + return corr / torch.sqrt(torch.tensor(num_channels)) + + +class RAFT(nn.Module): + def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None): + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + args: + feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8. + Its input is the concatenation of ``image1`` and ``image2``. + context_encoder (nn.Module): The context encoder. It must downsample the input by 8. + Its input is ``image1``. As in the original implementation, its output will be split into 2 parts: + + - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block`` + - one part will be used to initialize the hidden state of the of the recurrent unit of + the ``update_block`` + + These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the output + of the ``context_encoder`` must be strictly greater than ``hidden_state_size``. + + corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the + ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose + 2 methods: + + - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the + output of the ``feature_encoder``). + - a ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns + the correlation features. See paper section 3.2. + + It must expose an ``out_channels`` attribute. + + update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the + flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation + features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow`` + prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute. + mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow. + The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B. + If ``None`` (default), the flow is upsampled using interpolation. 
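The channel split described above (initial hidden state vs. actual context) mirrors what the forward pass below does. A toy illustration with fabricated tensors, using the raft_large sizes (256 context-encoder output channels, hidden_state_size of 128):

    import torch
    import torch.nn.functional as F

    hidden_state_size = 128
    context_out = torch.rand(1, 256, 28, 28)   # fake context encoder output at H/8 x W/8
    out_channels_context = context_out.shape[1] - hidden_state_size
    hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
    hidden_state = torch.tanh(hidden_state)    # seeds the recurrent unit
    context = F.relu(context)                  # fed to the update block at every iteration
    print(hidden_state.shape, context.shape)   # (1, 128, 28, 28) (1, 128, 28, 28)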
+ """ + super().__init__() + + self.feature_encoder = feature_encoder + self.context_encoder = context_encoder + self.corr_block = corr_block + self.update_block = update_block + + self.mask_predictor = mask_predictor + + if not hasattr(self.update_block, "hidden_state_size"): + raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.") + + def forward(self, image1, image2, num_flow_updates: int = 12): + + batch_size, _, h, w = image1.shape + torch._assert((h, w) == image2.shape[-2:], "input images should have the same shape") + torch._assert((h % 8 == 0) and (w % 8 == 0), "input image H and W should be divisible by 8") + + fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0)) + fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0) + torch._assert(fmap1.shape[-2:] == (h // 8, w // 8), "The feature encoder should downsample H and W by 8") + + self.corr_block.build_pyramid(fmap1, fmap2) + + context_out = self.context_encoder(image1) + torch._assert(context_out.shape[-2:] == (h // 8, w // 8), "The context encoder should downsample H and W by 8") + + # As in the original paper, the actual output of the context encoder is split in 2 parts: + # - one part is used to initialize the hidden state of the recurent units of the update block + # - the rest is the "actual" context. + hidden_state_size = self.update_block.hidden_state_size + out_channels_context = context_out.shape[1] - hidden_state_size + torch._assert( + out_channels_context > 0, + f"The context encoder outputs {context_out.shape[1]} channels, but it should have at strictly more than" + f"hidden_state={hidden_state_size} channels", + ) + hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1) + hidden_state = torch.tanh(hidden_state) + context = F.relu(context) + + coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda() + coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda() + + flow_predictions = [] + for _ in range(num_flow_updates): + coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper + corr_features = self.corr_block.index_pyramid(centroids_coords=coords1) + + flow = coords1 - coords0 + hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow) + + coords1 = coords1 + delta_flow + + up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state) + upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask) + flow_predictions.append(upsampled_flow) + + return flow_predictions + + +def _raft( + *, + # Feature encoder + feature_encoder_layers, + feature_encoder_block, + feature_encoder_norm_layer, + # Context encoder + context_encoder_layers, + context_encoder_block, + context_encoder_norm_layer, + # Correlation block + corr_block_num_levels, + corr_block_radius, + # Motion encoder + motion_encoder_corr_layers, + motion_encoder_flow_layers, + motion_encoder_out_channels, + # Recurrent block + recurrent_block_hidden_state_size, + recurrent_block_kernel_size, + recurrent_block_padding, + # Flow Head + flow_head_hidden_size, + # Mask predictor + use_mask_predictor, + **kwargs, +): + feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder( + block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer + ) + context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder( + block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer + ) + + 
corr_block = kwargs.pop("corr_block", None) or CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius) + + update_block = kwargs.pop("update_block", None) + if update_block is None: + motion_encoder = MotionEncoder( + in_channels_corr=corr_block.out_channels, + corr_layers=motion_encoder_corr_layers, + flow_layers=motion_encoder_flow_layers, + out_channels=motion_encoder_out_channels, + ) + + # See comments in forward pass of RAFT class about why we split the output of the context encoder + out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size + recurrent_block = RecurrentBlock( + input_size=motion_encoder.out_channels + out_channels_context, + hidden_size=recurrent_block_hidden_state_size, + kernel_size=recurrent_block_kernel_size, + padding=recurrent_block_padding, + ) + + flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size) + + update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head) + + mask_predictor = kwargs.pop("mask_predictor", None) + if mask_predictor is None and use_mask_predictor: + mask_predictor = MaskPredictor( + in_channels=recurrent_block_hidden_state_size, + hidden_size=256, + multiplier=0.25, # See comment in MaskPredictor about this + ) + + return RAFT( + feature_encoder=feature_encoder, + context_encoder=context_encoder, + corr_block=corr_block, + update_block=update_block, + mask_predictor=mask_predictor, + **kwargs, # not really needed, all params should be consumed by now + ) + + +def raft_large(*, pretrained=False, progress=True, **kwargs): + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Args: + pretrained (bool): TODO not implemented yet + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. + + Returns: + nn.Module: The model. + """ + + if pretrained: + raise ValueError("Pretrained weights aren't available yet") + + return _raft( + # Feature encoder + feature_encoder_layers=(64, 64, 96, 128, 256), + feature_encoder_block=ResidualBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(64, 64, 96, 128, 256), + context_encoder_block=ResidualBlock, + context_encoder_norm_layer=BatchNorm2d, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=4, + # Motion encoder + motion_encoder_corr_layers=(256, 192), + motion_encoder_flow_layers=(128, 64), + motion_encoder_out_channels=128, + # Recurrent block + recurrent_block_hidden_state_size=128, + recurrent_block_kernel_size=((1, 5), (5, 1)), + recurrent_block_padding=((0, 2), (2, 0)), + # Flow head + flow_head_hidden_size=256, + # Mask predictor + use_mask_predictor=True, + **kwargs, + ) + + +def raft_small(*, pretrained=False, progress=True, **kwargs): + """RAFT "small" model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Args: + pretrained (bool): TODO not implemented yet + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. + + Returns: + nn.Module: The model. 
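A minimal usage sketch for the factory above. As written, the forward pass moves the coordinate grids with .cuda(), so this assumes a CUDA device is available; the shapes are illustrative only:

    import torch
    from torchvision.models.optical_flow import raft_large

    model = raft_large().cuda().eval()
    img1 = torch.rand(1, 3, 224, 224, device="cuda")   # H and W must be divisible by 8
    img2 = torch.rand(1, 3, 224, 224, device="cuda")
    with torch.no_grad():
        flow_predictions = model(img1, img2, num_flow_updates=12)
    print(len(flow_predictions), flow_predictions[-1].shape)  # 12 torch.Size([1, 2, 224, 224])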
+ + """ + + if pretrained: + raise ValueError("Pretrained weights aren't available yet") + + return _raft( + # Feature encoder + feature_encoder_layers=(32, 32, 64, 96, 128), + feature_encoder_block=BottleneckBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(32, 32, 64, 96, 160), + context_encoder_block=BottleneckBlock, + context_encoder_norm_layer=None, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=3, + # Motion encoder + motion_encoder_corr_layers=(96,), + motion_encoder_flow_layers=(64, 32), + motion_encoder_out_channels=82, + # Recurrent block + recurrent_block_hidden_state_size=96, + recurrent_block_kernel_size=(3,), + recurrent_block_padding=(1,), + # Flow head + flow_head_hidden_size=128, + # Mask predictor + use_mask_predictor=False, + **kwargs, + ) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 10a03a907e8..5ec46669be2 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -66,8 +66,7 @@ def batched_nms( _log_api_usage_once("torchvision.ops.batched_nms") # Benchmarks that drove the following thresholds are at # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 - # Ideally for GPU we'd use a higher threshold - if boxes.numel() > 4_000 and not torchvision._is_tracing(): + if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing(): return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) else: return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index fac9a3570d6..392517cb772 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -116,6 +116,7 @@ class ConvNormActivation(torch.nn.Sequential): activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` dilation (int): Spacing between kernel elements. Default: 1 inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 
""" @@ -131,9 +132,12 @@ def __init__( activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, inplace: bool = True, + bias: Optional[bool] = None, ) -> None: if padding is None: padding = (kernel_size - 1) // 2 * dilation + if bias is None: + bias = norm_layer is None layers = [ torch.nn.Conv2d( in_channels, @@ -143,7 +147,7 @@ def __init__( padding, dilation=dilation, groups=groups, - bias=norm_layer is None, + bias=bias, ) ] if norm_layer is not None: diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index a0cd238dc75..05cf5e4032e 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,6 +1,8 @@ +import warnings from typing import Optional, List, Dict, Tuple, Union import torch +import torch.fx import torchvision from torch import nn, Tensor from torchvision.ops.boxes import box_area @@ -106,6 +108,126 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float: return possible_scales[0] +@torch.fx.wrap +def _setup_scales( + features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int +) -> Tuple[List[float], LevelMapper]: + assert len(image_shapes) != 0 + max_x = 0 + max_y = 0 + for shape in image_shapes: + max_x = max(shape[0], max_x) + max_y = max(shape[1], max_y) + original_input_shape = (max_x, max_y) + + scales = [_infer_scale(feat, original_input_shape) for feat in features] + # get the levels in the feature map by leveraging the fact that the network always + # downsamples by a factor of 2 at each level. + lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item() + lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item() + + map_levels = initLevelMapper( + int(lvl_min), + int(lvl_max), + canonical_scale=canonical_scale, + canonical_level=canonical_level, + ) + return scales, map_levels + + +@torch.fx.wrap +def _filter_input(x: Dict[str, Tensor], featmap_names: List[str]) -> List[Tensor]: + x_filtered = [] + for k, v in x.items(): + if k in featmap_names: + x_filtered.append(v) + return x_filtered + + +@torch.fx.wrap +def _multiscale_roi_align( + x_filtered: List[Tensor], + boxes: List[Tensor], + output_size: List[int], + sampling_ratio: int, + scales: Optional[List[float]], + mapper: Optional[LevelMapper], +) -> Tensor: + """ + Args: + x_filtered (List[Tensor]): List of input tensors. + boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in + (x1, y1, x2, y2) format and in the image reference size, not the feature map + reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + output_size (Union[List[Tuple[int, int]], List[int]]): size of the output + sampling_ratio (int): sampling ratio for ROIAlign + scales (Optional[List[float]]): If None, scales will be automatically infered. Default value is None. + mapper (Optional[LevelMapper]): If none, mapper will be automatically infered. Default value is None. 
+ Returns: + result (Tensor) + """ + assert scales is not None + assert mapper is not None + + num_levels = len(x_filtered) + rois = _convert_to_roi_format(boxes) + + if num_levels == 1: + return roi_align( + x_filtered[0], + rois, + output_size=output_size, + spatial_scale=scales[0], + sampling_ratio=sampling_ratio, + ) + + levels = mapper(boxes) + + num_rois = len(rois) + num_channels = x_filtered[0].shape[1] + + dtype, device = x_filtered[0].dtype, x_filtered[0].device + result = torch.zeros( + ( + num_rois, + num_channels, + ) + + output_size, + dtype=dtype, + device=device, + ) + + tracing_results = [] + for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)): + idx_in_level = torch.where(levels == level)[0] + rois_per_level = rois[idx_in_level] + + result_idx_in_level = roi_align( + per_level_feature, + rois_per_level, + output_size=output_size, + spatial_scale=scale, + sampling_ratio=sampling_ratio, + ) + + if torchvision._is_tracing(): + tracing_results.append(result_idx_in_level.to(dtype)) + else: + # result and result_idx_in_level's dtypes are based on dtypes of different + # elements in x_filtered. x_filtered contains tensors output by different + # layers. When autocast is active, it may choose different dtypes for + # different layers' outputs. Therefore, we defensively match result's dtype + # before copying elements from result_idx_in_level in the following op. + # We need to cast manually (can't rely on autocast to cast for us) because + # the op acts on result in-place, and autocast only affects out-of-place ops. + result[idx_in_level] = result_idx_in_level.to(result.dtype) + + if torchvision._is_tracing(): + result = _onnx_merge_levels(levels, tracing_results) + + return result + + class MultiScaleRoIAlign(nn.Module): """ Multi-scale RoIAlign pooling, which is useful for detection with or without FPN. @@ -165,31 +287,24 @@ def __init__( self.canonical_scale = canonical_scale self.canonical_level = canonical_level - def setup_scales( + def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor: + # TODO: deprecate eventually + warnings.warn("`convert_to_roi_format` will no loger be public in future releases.", FutureWarning) + return _convert_to_roi_format(boxes) + + def infer_scale(self, feature: Tensor, original_size: List[int]) -> float: + # TODO: deprecate eventually + warnings.warn("`infer_scale` will no loger be public in future releases.", FutureWarning) + return _infer_scale(feature, original_size) + + def setup_setup_scales( self, features: List[Tensor], image_shapes: List[Tuple[int, int]], ) -> None: - assert len(image_shapes) != 0 - max_x = 0 - max_y = 0 - for shape in image_shapes: - max_x = max(shape[0], max_x) - max_y = max(shape[1], max_y) - original_input_shape = (max_x, max_y) - - scales = [_infer_scale(feat, original_input_shape) for feat in features] - # get the levels in the feature map by leveraging the fact that the network always - # downsamples by a factor of 2 at each level. 
- lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item() - lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item() - self.scales = scales - self.map_levels = initLevelMapper( - int(lvl_min), - int(lvl_max), - canonical_scale=self.canonical_scale, - canonical_level=self.canonical_level, - ) + # TODO: deprecate eventually + warnings.warn("`setup_setup_scales` will no loger be public in future releases.", FutureWarning) + self.scales, self.map_levels = _setup_scales(features, image_shapes, self.canonical_scale, self.canonical_level) def forward( self, @@ -210,76 +325,21 @@ def forward( Returns: result (Tensor) """ - x_filtered = [] - for k, v in x.items(): - if k in self.featmap_names: - x_filtered.append(v) - num_levels = len(x_filtered) - rois = _convert_to_roi_format(boxes) - if self.scales is None: - self.setup_scales(x_filtered, image_shapes) - - scales = self.scales - assert scales is not None - - if num_levels == 1: - return roi_align( - x_filtered[0], - rois, - output_size=self.output_size, - spatial_scale=scales[0], - sampling_ratio=self.sampling_ratio, + x_filtered = _filter_input(x, self.featmap_names) + if self.scales is None or self.map_levels is None: + self.scales, self.map_levels = _setup_scales( + x_filtered, image_shapes, self.canonical_scale, self.canonical_level ) - mapper = self.map_levels - assert mapper is not None - - levels = mapper(boxes) - - num_rois = len(rois) - num_channels = x_filtered[0].shape[1] - - dtype, device = x_filtered[0].dtype, x_filtered[0].device - result = torch.zeros( - ( - num_rois, - num_channels, - ) - + self.output_size, - dtype=dtype, - device=device, + return _multiscale_roi_align( + x_filtered, + boxes, + self.output_size, + self.sampling_ratio, + self.scales, + self.map_levels, ) - tracing_results = [] - for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)): - idx_in_level = torch.where(levels == level)[0] - rois_per_level = rois[idx_in_level] - - result_idx_in_level = roi_align( - per_level_feature, - rois_per_level, - output_size=self.output_size, - spatial_scale=scale, - sampling_ratio=self.sampling_ratio, - ) - - if torchvision._is_tracing(): - tracing_results.append(result_idx_in_level.to(dtype)) - else: - # result and result_idx_in_level's dtypes are based on dtypes of different - # elements in x_filtered. x_filtered contains tensors output by different - # layers. When autocast is active, it may choose different dtypes for - # different layers' outputs. Therefore, we defensively match result's dtype - # before copying elements from result_idx_in_level in the following op. - # We need to cast manually (can't rely on autocast to cast for us) because - # the op acts on result in-place, and autocast only affects out-of-place ops. 
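With this refactor, the forward path of MultiScaleRoIAlign delegates to the module-level torch.fx.wrap-ed helpers, presumably to keep the module traceable, so callers see no behavioural change. A usage sketch with fabricated feature maps (close to the class's documented example):

    import torch
    from torchvision.ops import MultiScaleRoIAlign

    m = MultiScaleRoIAlign(featmap_names=["feat1", "feat3"], output_size=3, sampling_ratio=2)
    features = {
        "feat1": torch.rand(1, 5, 64, 64),   # stride 8 w.r.t. the 512x512 image below
        "feat2": torch.rand(1, 5, 32, 32),   # ignored: not in featmap_names
        "feat3": torch.rand(1, 5, 16, 16),   # stride 32
    }
    boxes = [torch.tensor([[10.0, 10.0, 200.0, 200.0], [0.0, 0.0, 40.0, 40.0]])]
    output = m(features, boxes, image_shapes=[(512, 512)])
    print(output.shape)   # torch.Size([2, 5, 3, 3])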
- result[idx_in_level] = result_idx_in_level.to(result.dtype) - - if torchvision._is_tracing(): - result = _onnx_merge_levels(levels, tracing_results) - - return result - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(featmap_names={self.featmap_names}, " diff --git a/torchvision/prototype/datasets/_builtin/coco.categories b/torchvision/prototype/datasets/_builtin/coco.categories new file mode 100644 index 00000000000..27e612f6d7d --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/coco.categories @@ -0,0 +1,91 @@ +__background__,N/A +person,person +bicycle,vehicle +car,vehicle +motorcycle,vehicle +airplane,vehicle +bus,vehicle +train,vehicle +truck,vehicle +boat,vehicle +traffic light,outdoor +fire hydrant,outdoor +N/A,N/A +stop sign,outdoor +parking meter,outdoor +bench,outdoor +bird,animal +cat,animal +dog,animal +horse,animal +sheep,animal +cow,animal +elephant,animal +bear,animal +zebra,animal +giraffe,animal +N/A,N/A +backpack,accessory +umbrella,accessory +N/A,N/A +N/A,N/A +handbag,accessory +tie,accessory +suitcase,accessory +frisbee,sports +skis,sports +snowboard,sports +sports ball,sports +kite,sports +baseball bat,sports +baseball glove,sports +skateboard,sports +surfboard,sports +tennis racket,sports +bottle,kitchen +N/A,N/A +wine glass,kitchen +cup,kitchen +fork,kitchen +knife,kitchen +spoon,kitchen +bowl,kitchen +banana,food +apple,food +sandwich,food +orange,food +broccoli,food +carrot,food +hot dog,food +pizza,food +donut,food +cake,food +chair,furniture +couch,furniture +potted plant,furniture +bed,furniture +N/A,N/A +dining table,furniture +N/A,N/A +N/A,N/A +toilet,furniture +N/A,N/A +tv,electronic +laptop,electronic +mouse,electronic +remote,electronic +keyboard,electronic +cell phone,electronic +microwave,appliance +oven,appliance +toaster,appliance +sink,appliance +refrigerator,appliance +N/A,N/A +book,indoor +clock,indoor +vase,indoor +scissors,indoor +teddy bear,indoor +hair drier,indoor +toothbrush,indoor diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 641d584dc43..0ba34167b51 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,6 +1,8 @@ import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +import re +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional, Tuple, cast import torch from torchdata.datapipes.iter import ( @@ -26,24 +28,31 @@ from torchvision.prototype.datasets.utils._internal import ( MappingIterator, INFINITE_BUFFER_SIZE, + BUILTIN_DIR, getitem, path_accessor, - path_comparator, ) - -HERE = pathlib.Path(__file__).parent +from torchvision.prototype.features import BoundingBox, Label, Feature +from torchvision.prototype.utils._internal import FrozenMapping class Coco(Dataset): def _make_info(self) -> DatasetInfo: + name = "coco" + categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) + return DatasetInfo( - "coco", + name, type=DatasetType.IMAGE, + dependencies=("pycocotools",), + categories=categories, homepage="https://cocodataset.org/", valid_options=dict( split=("train", "val"), year=("2017", "2014"), + annotations=(*self._ANN_DECODERS.keys(), None), ), + extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))), ) _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" @@ -73,6 +82,64 @@ def resources(self, config: DatasetConfig) -> 
List[OnlineResource]: ) return [images, meta] + def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size: Tuple[int, int]) -> torch.Tensor: + from pycocotools import mask + + if is_crowd: + segmentation = mask.frPyObjects(segmentation, *image_size) + else: + segmentation = mask.merge(mask.frPyObjects(segmentation, *image_size)) + + return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) + + def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: + image_size = (image_meta["height"], image_meta["width"]) + labels = [ann["category_id"] for ann in anns] + categories = [self.info.categories[label] for label in labels] + return dict( + # TODO: create a segmentation feature + segmentations=Feature( + torch.stack( + [ + self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size) + for ann in anns + ] + ) + ), + areas=Feature([ann["area"] for ann in anns]), + crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool), + bounding_boxes=BoundingBox( + [ann["bbox"] for ann in anns], + format="xywh", + image_size=image_size, + ), + labels=Label(labels), + categories=categories, + super_categories=[self.info.extra.category_to_super_category[category] for category in categories], + ann_ids=[ann["id"] for ann in anns], + ) + + def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: + return dict( + captions=[ann["caption"] for ann in anns], + ann_ids=[ann["id"] for ann in anns], + ) + + _ANN_DECODERS = OrderedDict( + [ + ("instances", _decode_instances_anns), + ("captions", _decode_captions_ann), + ] + ) + + _META_FILE_PATTERN = re.compile( + fr"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" + ) + + def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool: + match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) + return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations) + def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: key, _ = data if key == "images": @@ -82,28 +149,27 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _decode_ann(self, ann: Dict[str, Any]) -> Dict[str, Any]: - area = torch.tensor(ann["area"]) - iscrowd = bool(ann["iscrowd"]) - bbox = torch.tensor(ann["bbox"]) - id = ann["id"] - return dict(area=area, iscrowd=iscrowd, bbox=bbox, id=id) + def _collate_and_decode_image( + self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] + ) -> Dict[str, Any]: + path, buffer = data + return dict(path=path, image=decoder(buffer) if decoder else buffer) def _collate_and_decode_sample( self, data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]], *, + annotations: Optional[str], decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: ann_data, image_data = data anns, image_meta = ann_data - path, buffer = image_data - - anns = [self._decode_ann(ann) for ann in anns] - image = decoder(buffer) if decoder else buffer + sample = self._collate_and_decode_image(image_data, decoder=decoder) + if annotations: + sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) - return dict(anns=anns, id=image_meta["id"], path=path, image=image) + return sample def _make_datapipe( self, @@ -114,8 +180,18 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, 
Any]]: images_dp, meta_dp = resource_dps + images_dp = ZipArchiveReader(images_dp) + + if config.annotations is None: + dp = Shuffler(images_dp) + return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder)) + meta_dp = ZipArchiveReader(meta_dp) - meta_dp = Filter(meta_dp, path_comparator("name", f"instances_{config.split}{config.year}.json")) + meta_dp = Filter( + meta_dp, + self._filter_meta_files, + fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations), + ) meta_dp = JsonParser(meta_dp) meta_dp = Mapper(meta_dp, getitem(1)) meta_dp = MappingIterator(meta_dp) @@ -129,24 +205,20 @@ def _make_datapipe( images_meta_dp = Mapper(images_meta_dp, getitem(1)) images_meta_dp = UnBatcher(images_meta_dp) + images_meta_dp = Shuffler(images_meta_dp) anns_meta_dp = Mapper(anns_meta_dp, getitem(1)) anns_meta_dp = UnBatcher(anns_meta_dp) + anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE) - anns_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE) - # drop images without annotations - anns_dp = Filter(anns_dp, bool) - anns_dp = Shuffler(anns_dp, buffer_size=INFINITE_BUFFER_SIZE) anns_dp = IterKeyZipper( - anns_dp, + anns_meta_dp, images_meta_dp, key_fn=getitem(0, "image_id"), ref_key_fn=getitem("id"), buffer_size=INFINITE_BUFFER_SIZE, ) - images_dp = ZipArchiveReader(images_dp) - dp = IterKeyZipper( anns_dp, images_dp, @@ -154,4 +226,35 @@ def _make_datapipe( ref_key_fn=path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + return Mapper( + dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder) + ) + + def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: + config = self.default_config + resources = self.resources(config) + + dp = resources[1].to_datapipe(pathlib.Path(root) / self.name) + dp = ZipArchiveReader(dp) + dp = Filter( + dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances") + ) + dp = JsonParser(dp) + + _, meta = next(iter(dp)) + # List[Tuple[super_category, id, category]] + label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]] + + # COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to the + # full set. To keep the labels dense, we fill the gaps with N/A. Note that there are only 10 gaps, so the total + # number of categories is 90 rather than 91. + _, ids, _ = zip(*label_data) + missing_ids = set(range(1, max(ids) + 1)) - set(ids) + label_data.extend([("N/A", id, "N/A") for id in missing_ids]) + + # We also add a background category to be used during segmentation. 
+ label_data.append(("N/A", 0, "__background__")) + + super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1])) + + return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories))) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.categories b/torchvision/prototype/datasets/_builtin/imagenet.categories index 18e24b85311..7b6006ff57f 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.categories +++ b/torchvision/prototype/datasets/_builtin/imagenet.categories @@ -1,1000 +1,1000 @@ -tench,n01440764 -goldfish,n01443537 -great white shark,n01484850 -tiger shark,n01491361 -hammerhead,n01494475 -electric ray,n01496331 -stingray,n01498041 -cock,n01514668 -hen,n01514859 -ostrich,n01518878 -brambling,n01530575 -goldfinch,n01531178 -house finch,n01532829 -junco,n01534433 -indigo bunting,n01537544 -robin,n01558993 -bulbul,n01560419 -jay,n01580077 -magpie,n01582220 -chickadee,n01592084 -water ouzel,n01601694 -kite,n01608432 -bald eagle,n01614925 -vulture,n01616318 -great grey owl,n01622779 -European fire salamander,n01629819 -common newt,n01630670 -eft,n01631663 -spotted salamander,n01632458 -axolotl,n01632777 -bullfrog,n01641577 -tree frog,n01644373 -tailed frog,n01644900 -loggerhead,n01664065 -leatherback turtle,n01665541 -mud turtle,n01667114 -terrapin,n01667778 -box turtle,n01669191 -banded gecko,n01675722 -common iguana,n01677366 -American chameleon,n01682714 -whiptail,n01685808 -agama,n01687978 -frilled lizard,n01688243 -alligator lizard,n01689811 -Gila monster,n01692333 -green lizard,n01693334 -African chameleon,n01694178 -Komodo dragon,n01695060 -African crocodile,n01697457 -American alligator,n01698640 -triceratops,n01704323 -thunder snake,n01728572 -ringneck snake,n01728920 -hognose snake,n01729322 -green snake,n01729977 -king snake,n01734418 -garter snake,n01735189 -water snake,n01737021 -vine snake,n01739381 -night snake,n01740131 -boa constrictor,n01742172 -rock python,n01744401 -Indian cobra,n01748264 -green mamba,n01749939 -sea snake,n01751748 -horned viper,n01753488 -diamondback,n01755581 -sidewinder,n01756291 -trilobite,n01768244 -harvestman,n01770081 -scorpion,n01770393 -black and gold garden spider,n01773157 -barn spider,n01773549 -garden spider,n01773797 -black widow,n01774384 -tarantula,n01774750 -wolf spider,n01775062 -tick,n01776313 -centipede,n01784675 -black grouse,n01795545 -ptarmigan,n01796340 -ruffed grouse,n01797886 -prairie chicken,n01798484 -peacock,n01806143 -quail,n01806567 -partridge,n01807496 -African grey,n01817953 -macaw,n01818515 -sulphur-crested cockatoo,n01819313 -lorikeet,n01820546 -coucal,n01824575 -bee eater,n01828970 -hornbill,n01829413 -hummingbird,n01833805 -jacamar,n01843065 -toucan,n01843383 -drake,n01847000 -red-breasted merganser,n01855032 -goose,n01855672 -black swan,n01860187 -tusker,n01871265 -echidna,n01872401 -platypus,n01873310 -wallaby,n01877812 -koala,n01882714 -wombat,n01883070 -jellyfish,n01910747 -sea anemone,n01914609 -brain coral,n01917289 -flatworm,n01924916 -nematode,n01930112 -conch,n01943899 -snail,n01944390 -slug,n01945685 -sea slug,n01950731 -chiton,n01955084 -chambered nautilus,n01968897 -Dungeness crab,n01978287 -rock crab,n01978455 -fiddler crab,n01980166 -king crab,n01981276 -American lobster,n01983481 -spiny lobster,n01984695 -crayfish,n01985128 -hermit crab,n01986214 -isopod,n01990800 -white stork,n02002556 -black stork,n02002724 -spoonbill,n02006656 -flamingo,n02007558 -little blue heron,n02009229 -American egret,n02009912 -bittern,n02011460 -crane,n02012849 
-limpkin,n02013706 -European gallinule,n02017213 -American coot,n02018207 -bustard,n02018795 -ruddy turnstone,n02025239 -red-backed sandpiper,n02027492 -redshank,n02028035 -dowitcher,n02033041 -oystercatcher,n02037110 -pelican,n02051845 -king penguin,n02056570 -albatross,n02058221 -grey whale,n02066245 -killer whale,n02071294 -dugong,n02074367 -sea lion,n02077923 -Chihuahua,n02085620 -Japanese spaniel,n02085782 -Maltese dog,n02085936 -Pekinese,n02086079 -Shih-Tzu,n02086240 -Blenheim spaniel,n02086646 -papillon,n02086910 -toy terrier,n02087046 -Rhodesian ridgeback,n02087394 -Afghan hound,n02088094 -basset,n02088238 -beagle,n02088364 -bloodhound,n02088466 -bluetick,n02088632 -black-and-tan coonhound,n02089078 -Walker hound,n02089867 -English foxhound,n02089973 -redbone,n02090379 -borzoi,n02090622 -Irish wolfhound,n02090721 -Italian greyhound,n02091032 -whippet,n02091134 -Ibizan hound,n02091244 -Norwegian elkhound,n02091467 -otterhound,n02091635 -Saluki,n02091831 -Scottish deerhound,n02092002 -Weimaraner,n02092339 -Staffordshire bullterrier,n02093256 -American Staffordshire terrier,n02093428 -Bedlington terrier,n02093647 -Border terrier,n02093754 -Kerry blue terrier,n02093859 -Irish terrier,n02093991 -Norfolk terrier,n02094114 -Norwich terrier,n02094258 -Yorkshire terrier,n02094433 -wire-haired fox terrier,n02095314 -Lakeland terrier,n02095570 -Sealyham terrier,n02095889 -Airedale,n02096051 -cairn,n02096177 -Australian terrier,n02096294 -Dandie Dinmont,n02096437 -Boston bull,n02096585 -miniature schnauzer,n02097047 -giant schnauzer,n02097130 -standard schnauzer,n02097209 -Scotch terrier,n02097298 -Tibetan terrier,n02097474 -silky terrier,n02097658 -soft-coated wheaten terrier,n02098105 -West Highland white terrier,n02098286 -Lhasa,n02098413 -flat-coated retriever,n02099267 -curly-coated retriever,n02099429 -golden retriever,n02099601 -Labrador retriever,n02099712 -Chesapeake Bay retriever,n02099849 -German short-haired pointer,n02100236 -vizsla,n02100583 -English setter,n02100735 -Irish setter,n02100877 -Gordon setter,n02101006 -Brittany spaniel,n02101388 -clumber,n02101556 -English springer,n02102040 -Welsh springer spaniel,n02102177 -cocker spaniel,n02102318 -Sussex spaniel,n02102480 -Irish water spaniel,n02102973 -kuvasz,n02104029 -schipperke,n02104365 -groenendael,n02105056 -malinois,n02105162 -briard,n02105251 -kelpie,n02105412 -komondor,n02105505 -Old English sheepdog,n02105641 -Shetland sheepdog,n02105855 -collie,n02106030 -Border collie,n02106166 -Bouvier des Flandres,n02106382 -Rottweiler,n02106550 -German shepherd,n02106662 -Doberman,n02107142 -miniature pinscher,n02107312 -Greater Swiss Mountain dog,n02107574 -Bernese mountain dog,n02107683 -Appenzeller,n02107908 -EntleBucher,n02108000 -boxer,n02108089 -bull mastiff,n02108422 -Tibetan mastiff,n02108551 -French bulldog,n02108915 -Great Dane,n02109047 -Saint Bernard,n02109525 -Eskimo dog,n02109961 -malamute,n02110063 -Siberian husky,n02110185 -dalmatian,n02110341 -affenpinscher,n02110627 -basenji,n02110806 -pug,n02110958 -Leonberg,n02111129 -Newfoundland,n02111277 -Great Pyrenees,n02111500 -Samoyed,n02111889 -Pomeranian,n02112018 -chow,n02112137 -keeshond,n02112350 -Brabancon griffon,n02112706 -Pembroke,n02113023 -Cardigan,n02113186 -toy poodle,n02113624 -miniature poodle,n02113712 -standard poodle,n02113799 -Mexican hairless,n02113978 -timber wolf,n02114367 -white wolf,n02114548 -red wolf,n02114712 -coyote,n02114855 -dingo,n02115641 -dhole,n02115913 -African hunting dog,n02116738 -hyena,n02117135 -red fox,n02119022 -kit 
fox,n02119789 -Arctic fox,n02120079 -grey fox,n02120505 -tabby,n02123045 -tiger cat,n02123159 -Persian cat,n02123394 -Siamese cat,n02123597 -Egyptian cat,n02124075 -cougar,n02125311 -lynx,n02127052 -leopard,n02128385 -snow leopard,n02128757 -jaguar,n02128925 -lion,n02129165 -tiger,n02129604 -cheetah,n02130308 -brown bear,n02132136 -American black bear,n02133161 -ice bear,n02134084 -sloth bear,n02134418 -mongoose,n02137549 -meerkat,n02138441 -tiger beetle,n02165105 -ladybug,n02165456 -ground beetle,n02167151 -long-horned beetle,n02168699 -leaf beetle,n02169497 -dung beetle,n02172182 -rhinoceros beetle,n02174001 -weevil,n02177972 -fly,n02190166 -bee,n02206856 -ant,n02219486 -grasshopper,n02226429 -cricket,n02229544 -walking stick,n02231487 -cockroach,n02233338 -mantis,n02236044 -cicada,n02256656 -leafhopper,n02259212 -lacewing,n02264363 -dragonfly,n02268443 -damselfly,n02268853 -admiral,n02276258 -ringlet,n02277742 -monarch,n02279972 -cabbage butterfly,n02280649 -sulphur butterfly,n02281406 -lycaenid,n02281787 -starfish,n02317335 -sea urchin,n02319095 -sea cucumber,n02321529 -wood rabbit,n02325366 -hare,n02326432 -Angora,n02328150 -hamster,n02342885 -porcupine,n02346627 -fox squirrel,n02356798 -marmot,n02361337 -beaver,n02363005 -guinea pig,n02364673 -sorrel,n02389026 -zebra,n02391049 -hog,n02395406 -wild boar,n02396427 -warthog,n02397096 -hippopotamus,n02398521 -ox,n02403003 -water buffalo,n02408429 -bison,n02410509 -ram,n02412080 -bighorn,n02415577 -ibex,n02417914 -hartebeest,n02422106 -impala,n02422699 -gazelle,n02423022 -Arabian camel,n02437312 -llama,n02437616 -weasel,n02441942 -mink,n02442845 -polecat,n02443114 -black-footed ferret,n02443484 -otter,n02444819 -skunk,n02445715 -badger,n02447366 -armadillo,n02454379 -three-toed sloth,n02457408 -orangutan,n02480495 -gorilla,n02480855 -chimpanzee,n02481823 -gibbon,n02483362 -siamang,n02483708 -guenon,n02484975 -patas,n02486261 -baboon,n02486410 -macaque,n02487347 -langur,n02488291 -colobus,n02488702 -proboscis monkey,n02489166 -marmoset,n02490219 -capuchin,n02492035 -howler monkey,n02492660 -titi,n02493509 -spider monkey,n02493793 -squirrel monkey,n02494079 -Madagascar cat,n02497673 -indri,n02500267 -Indian elephant,n02504013 -African elephant,n02504458 -lesser panda,n02509815 -giant panda,n02510455 -barracouta,n02514041 -eel,n02526121 -coho,n02536864 -rock beauty,n02606052 -anemone fish,n02607072 -sturgeon,n02640242 -gar,n02641379 -lionfish,n02643566 -puffer,n02655020 -abacus,n02666196 -abaya,n02667093 -academic gown,n02669723 -accordion,n02672831 -acoustic guitar,n02676566 -aircraft carrier,n02687172 -airliner,n02690373 -airship,n02692877 -altar,n02699494 -ambulance,n02701002 -amphibian,n02704792 -analog clock,n02708093 -apiary,n02727426 -apron,n02730930 -ashcan,n02747177 -assault rifle,n02749479 -backpack,n02769748 -bakery,n02776631 -balance beam,n02777292 -balloon,n02782093 -ballpoint,n02783161 -Band Aid,n02786058 -banjo,n02787622 -bannister,n02788148 -barbell,n02790996 -barber chair,n02791124 -barbershop,n02791270 -barn,n02793495 -barometer,n02794156 -barrel,n02795169 -barrow,n02797295 -baseball,n02799071 -basketball,n02802426 -bassinet,n02804414 -bassoon,n02804610 -bathing cap,n02807133 -bath towel,n02808304 -bathtub,n02808440 -beach wagon,n02814533 -beacon,n02814860 -beaker,n02815834 -bearskin,n02817516 -beer bottle,n02823428 -beer glass,n02823750 -bell cote,n02825657 -bib,n02834397 -bicycle-built-for-two,n02835271 -bikini,n02837789 -binder,n02840245 -binoculars,n02841315 -birdhouse,n02843684 -boathouse,n02859443 -bobsled,n02860847 
-bolo tie,n02865351 -bonnet,n02869837 -bookcase,n02870880 -bookshop,n02871525 -bottlecap,n02877765 -bow,n02879718 -bow tie,n02883205 -brass,n02892201 -brassiere,n02892767 -breakwater,n02894605 -breastplate,n02895154 -broom,n02906734 -bucket,n02909870 -buckle,n02910353 -bulletproof vest,n02916936 -bullet train,n02917067 -butcher shop,n02927161 -cab,n02930766 -caldron,n02939185 -candle,n02948072 -cannon,n02950826 -canoe,n02951358 -can opener,n02951585 -cardigan,n02963159 -car mirror,n02965783 -carousel,n02966193 -carpenter's kit,n02966687 -carton,n02971356 -car wheel,n02974003 -cash machine,n02977058 -cassette,n02978881 -cassette player,n02979186 -castle,n02980441 -catamaran,n02981792 -CD player,n02988304 -cello,n02992211 -cellular telephone,n02992529 -chain,n02999410 -chainlink fence,n03000134 -chain mail,n03000247 -chain saw,n03000684 -chest,n03014705 -chiffonier,n03016953 -chime,n03017168 -china cabinet,n03018349 -Christmas stocking,n03026506 -church,n03028079 -cinema,n03032252 -cleaver,n03041632 -cliff dwelling,n03042490 -cloak,n03045698 -clog,n03047690 -cocktail shaker,n03062245 -coffee mug,n03063599 -coffeepot,n03063689 -coil,n03065424 -combination lock,n03075370 -computer keyboard,n03085013 -confectionery,n03089624 -container ship,n03095699 -convertible,n03100240 -corkscrew,n03109150 -cornet,n03110669 -cowboy boot,n03124043 -cowboy hat,n03124170 -cradle,n03125729 -construction crane,n03126707 -crash helmet,n03127747 -crate,n03127925 -crib,n03131574 -Crock Pot,n03133878 -croquet ball,n03134739 -crutch,n03141823 -cuirass,n03146219 -dam,n03160309 -desk,n03179701 -desktop computer,n03180011 -dial telephone,n03187595 -diaper,n03188531 -digital clock,n03196217 -digital watch,n03197337 -dining table,n03201208 -dishrag,n03207743 -dishwasher,n03207941 -disk brake,n03208938 -dock,n03216828 -dogsled,n03218198 -dome,n03220513 -doormat,n03223299 -drilling platform,n03240683 -drum,n03249569 -drumstick,n03250847 -dumbbell,n03255030 -Dutch oven,n03259280 -electric fan,n03271574 -electric guitar,n03272010 -electric locomotive,n03272562 -entertainment center,n03290653 -envelope,n03291819 -espresso maker,n03297495 -face powder,n03314780 -feather boa,n03325584 -file,n03337140 -fireboat,n03344393 -fire engine,n03345487 -fire screen,n03347037 -flagpole,n03355925 -flute,n03372029 -folding chair,n03376595 -football helmet,n03379051 -forklift,n03384352 -fountain,n03388043 -fountain pen,n03388183 -four-poster,n03388549 -freight car,n03393912 -French horn,n03394916 -frying pan,n03400231 -fur coat,n03404251 -garbage truck,n03417042 -gasmask,n03424325 -gas pump,n03425413 -goblet,n03443371 -go-kart,n03444034 -golf ball,n03445777 -golfcart,n03445924 -gondola,n03447447 -gong,n03447721 -gown,n03450230 -grand piano,n03452741 -greenhouse,n03457902 -grille,n03459775 -grocery store,n03461385 -guillotine,n03467068 -hair slide,n03476684 -hair spray,n03476991 -half track,n03478589 -hammer,n03481172 -hamper,n03482405 -hand blower,n03483316 -hand-held computer,n03485407 -handkerchief,n03485794 -hard disc,n03492542 -harmonica,n03494278 -harp,n03495258 -harvester,n03496892 -hatchet,n03498962 -holster,n03527444 -home theater,n03529860 -honeycomb,n03530642 -hook,n03532672 -hoopskirt,n03534580 -horizontal bar,n03535780 -horse cart,n03538406 -hourglass,n03544143 -iPod,n03584254 -iron,n03584829 -jack-o'-lantern,n03590841 -jean,n03594734 -jeep,n03594945 -jersey,n03595614 -jigsaw puzzle,n03598930 -jinrikisha,n03599486 -joystick,n03602883 -kimono,n03617480 -knee pad,n03623198 -knot,n03627232 -lab coat,n03630383 -ladle,n03633091 
-lampshade,n03637318 -laptop,n03642806 -lawn mower,n03649909 -lens cap,n03657121 -letter opener,n03658185 -library,n03661043 -lifeboat,n03662601 -lighter,n03666591 -limousine,n03670208 -liner,n03673027 -lipstick,n03676483 -Loafer,n03680355 -lotion,n03690938 -loudspeaker,n03691459 -loupe,n03692522 -lumbermill,n03697007 -magnetic compass,n03706229 -mailbag,n03709823 -mailbox,n03710193 -maillot,n03710637 -tank suit,n03710721 -manhole cover,n03717622 -maraca,n03720891 -marimba,n03721384 -mask,n03724870 -matchstick,n03729826 -maypole,n03733131 -maze,n03733281 -measuring cup,n03733805 -medicine chest,n03742115 -megalith,n03743016 -microphone,n03759954 -microwave,n03761084 -military uniform,n03763968 -milk can,n03764736 -minibus,n03769881 -miniskirt,n03770439 -minivan,n03770679 -missile,n03773504 -mitten,n03775071 -mixing bowl,n03775546 -mobile home,n03776460 -Model T,n03777568 -modem,n03777754 -monastery,n03781244 -monitor,n03782006 -moped,n03785016 -mortar,n03786901 -mortarboard,n03787032 -mosque,n03788195 -mosquito net,n03788365 -motor scooter,n03791053 -mountain bike,n03792782 -mountain tent,n03792972 -mouse,n03793489 -mousetrap,n03794056 -moving van,n03796401 -muzzle,n03803284 -nail,n03804744 -neck brace,n03814639 -necklace,n03814906 -nipple,n03825788 -notebook,n03832673 -obelisk,n03837869 -oboe,n03838899 -ocarina,n03840681 -odometer,n03841143 -oil filter,n03843555 -organ,n03854065 -oscilloscope,n03857828 -overskirt,n03866082 -oxcart,n03868242 -oxygen mask,n03868863 -packet,n03871628 -paddle,n03873416 -paddlewheel,n03874293 -padlock,n03874599 -paintbrush,n03876231 -pajama,n03877472 -palace,n03877845 -panpipe,n03884397 -paper towel,n03887697 -parachute,n03888257 -parallel bars,n03888605 -park bench,n03891251 -parking meter,n03891332 -passenger car,n03895866 -patio,n03899768 -pay-phone,n03902125 -pedestal,n03903868 -pencil box,n03908618 -pencil sharpener,n03908714 -perfume,n03916031 -Petri dish,n03920288 -photocopier,n03924679 -pick,n03929660 -pickelhaube,n03929855 -picket fence,n03930313 -pickup,n03930630 -pier,n03933933 -piggy bank,n03935335 -pill bottle,n03937543 -pillow,n03938244 -ping-pong ball,n03942813 -pinwheel,n03944341 -pirate,n03947888 -pitcher,n03950228 -plane,n03954731 -planetarium,n03956157 -plastic bag,n03958227 -plate rack,n03961711 -plow,n03967562 -plunger,n03970156 -Polaroid camera,n03976467 -pole,n03976657 -police van,n03977966 -poncho,n03980874 -pool table,n03982430 -pop bottle,n03983396 -pot,n03991062 -potter's wheel,n03992509 -power drill,n03995372 -prayer rug,n03998194 -printer,n04004767 -prison,n04005630 -projectile,n04008634 -projector,n04009552 -puck,n04019541 -punching bag,n04023962 -purse,n04026417 -quill,n04033901 -quilt,n04033995 -racer,n04037443 -racket,n04039381 -radiator,n04040759 -radio,n04041544 -radio telescope,n04044716 -rain barrel,n04049303 -recreational vehicle,n04065272 -reel,n04067472 -reflex camera,n04069434 -refrigerator,n04070727 -remote control,n04074963 -restaurant,n04081281 -revolver,n04086273 -rifle,n04090263 -rocking chair,n04099969 -rotisserie,n04111531 -rubber eraser,n04116512 -rugby ball,n04118538 -rule,n04118776 -running shoe,n04120489 -safe,n04125021 -safety pin,n04127249 -saltshaker,n04131690 -sandal,n04133789 -sarong,n04136333 -sax,n04141076 -scabbard,n04141327 -scale,n04141975 -school bus,n04146614 -schooner,n04147183 -scoreboard,n04149813 -screen,n04152593 -screw,n04153751 -screwdriver,n04154565 -seat belt,n04162706 -sewing machine,n04179913 -shield,n04192698 -shoe shop,n04200800 -shoji,n04201297 -shopping basket,n04204238 -shopping 
cart,n04204347 -shovel,n04208210 -shower cap,n04209133 -shower curtain,n04209239 -ski,n04228054 -ski mask,n04229816 -sleeping bag,n04235860 -slide rule,n04238763 -sliding door,n04239074 -slot,n04243546 -snorkel,n04251144 -snowmobile,n04252077 -snowplow,n04252225 -soap dispenser,n04254120 -soccer ball,n04254680 -sock,n04254777 -solar dish,n04258138 -sombrero,n04259630 -soup bowl,n04263257 -space bar,n04264628 -space heater,n04265275 -space shuttle,n04266014 -spatula,n04270147 -speedboat,n04273569 -spider web,n04275548 -spindle,n04277352 -sports car,n04285008 -spotlight,n04286575 -stage,n04296562 -steam locomotive,n04310018 -steel arch bridge,n04311004 -steel drum,n04311174 -stethoscope,n04317175 -stole,n04325704 -stone wall,n04326547 -stopwatch,n04328186 -stove,n04330267 -strainer,n04332243 -streetcar,n04335435 -stretcher,n04336792 -studio couch,n04344873 -stupa,n04346328 -submarine,n04347754 -suit,n04350905 -sundial,n04355338 -sunglass,n04355933 -sunglasses,n04356056 -sunscreen,n04357314 -suspension bridge,n04366367 -swab,n04367480 -sweatshirt,n04370456 -swimming trunks,n04371430 -swing,n04371774 -switch,n04372370 -syringe,n04376876 -table lamp,n04380533 -tank,n04389033 -tape player,n04392985 -teapot,n04398044 -teddy,n04399382 -television,n04404412 -tennis ball,n04409515 -thatch,n04417672 -theater curtain,n04418357 -thimble,n04423845 -thresher,n04428191 -throne,n04429376 -tile roof,n04435653 -toaster,n04442312 -tobacco shop,n04443257 -toilet seat,n04447861 -torch,n04456115 -totem pole,n04458633 -tow truck,n04461696 -toyshop,n04462240 -tractor,n04465501 -trailer truck,n04467665 -tray,n04476259 -trench coat,n04479046 -tricycle,n04482393 -trimaran,n04483307 -tripod,n04485082 -triumphal arch,n04486054 -trolleybus,n04487081 -trombone,n04487394 -tub,n04493381 -turnstile,n04501370 -typewriter keyboard,n04505470 -umbrella,n04507155 -unicycle,n04509417 -upright,n04515003 -vacuum,n04517823 -vase,n04522168 -vault,n04523525 -velvet,n04525038 -vending machine,n04525305 -vestment,n04532106 -viaduct,n04532670 -violin,n04536866 -volleyball,n04540053 -waffle iron,n04542943 -wall clock,n04548280 -wallet,n04548362 -wardrobe,n04550184 -warplane,n04552348 -washbasin,n04553703 -washer,n04554684 -water bottle,n04557648 -water jug,n04560804 -water tower,n04562935 -whiskey jug,n04579145 -whistle,n04579432 -wig,n04584207 -window screen,n04589890 -window shade,n04590129 -Windsor tie,n04591157 -wine bottle,n04591713 -wing,n04592741 -wok,n04596742 -wooden spoon,n04597913 -wool,n04599235 -worm fence,n04604644 -wreck,n04606251 -yawl,n04612504 -yurt,n04613696 -web site,n06359193 -comic book,n06596364 -crossword puzzle,n06785654 -street sign,n06794110 -traffic light,n06874185 -book jacket,n07248320 -menu,n07565083 -plate,n07579787 -guacamole,n07583066 -consomme,n07584110 -hot pot,n07590611 -trifle,n07613480 -ice cream,n07614500 -ice lolly,n07615774 -French loaf,n07684084 -bagel,n07693725 -pretzel,n07695742 -cheeseburger,n07697313 -hotdog,n07697537 -mashed potato,n07711569 -head cabbage,n07714571 -broccoli,n07714990 -cauliflower,n07715103 -zucchini,n07716358 -spaghetti squash,n07716906 -acorn squash,n07717410 -butternut squash,n07717556 -cucumber,n07718472 -artichoke,n07718747 -bell pepper,n07720875 -cardoon,n07730033 -mushroom,n07734744 -Granny Smith,n07742313 -strawberry,n07745940 -orange,n07747607 -lemon,n07749582 -fig,n07753113 -pineapple,n07753275 -banana,n07753592 -jackfruit,n07754684 -custard apple,n07760859 -pomegranate,n07768694 -hay,n07802026 -carbonara,n07831146 -chocolate sauce,n07836838 -dough,n07860988 -meat 
loaf,n07871810 -pizza,n07873807 -potpie,n07875152 -burrito,n07880968 -red wine,n07892512 -espresso,n07920052 -cup,n07930864 -eggnog,n07932039 -alp,n09193705 -bubble,n09229709 -cliff,n09246464 -coral reef,n09256479 -geyser,n09288635 -lakeside,n09332890 -promontory,n09399592 -sandbar,n09421951 -seashore,n09428293 -valley,n09468604 -volcano,n09472597 -ballplayer,n09835506 -groom,n10148035 -scuba diver,n10565667 -rapeseed,n11879895 -daisy,n11939491 -yellow lady's slipper,n12057211 -corn,n12144580 -acorn,n12267677 -hip,n12620546 -buckeye,n12768682 -coral fungus,n12985857 -agaric,n12998815 -gyromitra,n13037406 -stinkhorn,n13040303 -earthstar,n13044778 -hen-of-the-woods,n13052670 -bolete,n13054560 -ear,n13133613 -toilet tissue,n15075141 +tench,n01440764 +goldfish,n01443537 +great white shark,n01484850 +tiger shark,n01491361 +hammerhead,n01494475 +electric ray,n01496331 +stingray,n01498041 +cock,n01514668 +hen,n01514859 +ostrich,n01518878 +brambling,n01530575 +goldfinch,n01531178 +house finch,n01532829 +junco,n01534433 +indigo bunting,n01537544 +robin,n01558993 +bulbul,n01560419 +jay,n01580077 +magpie,n01582220 +chickadee,n01592084 +water ouzel,n01601694 +kite,n01608432 +bald eagle,n01614925 +vulture,n01616318 +great grey owl,n01622779 +European fire salamander,n01629819 +common newt,n01630670 +eft,n01631663 +spotted salamander,n01632458 +axolotl,n01632777 +bullfrog,n01641577 +tree frog,n01644373 +tailed frog,n01644900 +loggerhead,n01664065 +leatherback turtle,n01665541 +mud turtle,n01667114 +terrapin,n01667778 +box turtle,n01669191 +banded gecko,n01675722 +common iguana,n01677366 +American chameleon,n01682714 +whiptail,n01685808 +agama,n01687978 +frilled lizard,n01688243 +alligator lizard,n01689811 +Gila monster,n01692333 +green lizard,n01693334 +African chameleon,n01694178 +Komodo dragon,n01695060 +African crocodile,n01697457 +American alligator,n01698640 +triceratops,n01704323 +thunder snake,n01728572 +ringneck snake,n01728920 +hognose snake,n01729322 +green snake,n01729977 +king snake,n01734418 +garter snake,n01735189 +water snake,n01737021 +vine snake,n01739381 +night snake,n01740131 +boa constrictor,n01742172 +rock python,n01744401 +Indian cobra,n01748264 +green mamba,n01749939 +sea snake,n01751748 +horned viper,n01753488 +diamondback,n01755581 +sidewinder,n01756291 +trilobite,n01768244 +harvestman,n01770081 +scorpion,n01770393 +black and gold garden spider,n01773157 +barn spider,n01773549 +garden spider,n01773797 +black widow,n01774384 +tarantula,n01774750 +wolf spider,n01775062 +tick,n01776313 +centipede,n01784675 +black grouse,n01795545 +ptarmigan,n01796340 +ruffed grouse,n01797886 +prairie chicken,n01798484 +peacock,n01806143 +quail,n01806567 +partridge,n01807496 +African grey,n01817953 +macaw,n01818515 +sulphur-crested cockatoo,n01819313 +lorikeet,n01820546 +coucal,n01824575 +bee eater,n01828970 +hornbill,n01829413 +hummingbird,n01833805 +jacamar,n01843065 +toucan,n01843383 +drake,n01847000 +red-breasted merganser,n01855032 +goose,n01855672 +black swan,n01860187 +tusker,n01871265 +echidna,n01872401 +platypus,n01873310 +wallaby,n01877812 +koala,n01882714 +wombat,n01883070 +jellyfish,n01910747 +sea anemone,n01914609 +brain coral,n01917289 +flatworm,n01924916 +nematode,n01930112 +conch,n01943899 +snail,n01944390 +slug,n01945685 +sea slug,n01950731 +chiton,n01955084 +chambered nautilus,n01968897 +Dungeness crab,n01978287 +rock crab,n01978455 +fiddler crab,n01980166 +king crab,n01981276 +American lobster,n01983481 +spiny lobster,n01984695 +crayfish,n01985128 +hermit crab,n01986214 
+isopod,n01990800 +white stork,n02002556 +black stork,n02002724 +spoonbill,n02006656 +flamingo,n02007558 +little blue heron,n02009229 +American egret,n02009912 +bittern,n02011460 +crane,n02012849 +limpkin,n02013706 +European gallinule,n02017213 +American coot,n02018207 +bustard,n02018795 +ruddy turnstone,n02025239 +red-backed sandpiper,n02027492 +redshank,n02028035 +dowitcher,n02033041 +oystercatcher,n02037110 +pelican,n02051845 +king penguin,n02056570 +albatross,n02058221 +grey whale,n02066245 +killer whale,n02071294 +dugong,n02074367 +sea lion,n02077923 +Chihuahua,n02085620 +Japanese spaniel,n02085782 +Maltese dog,n02085936 +Pekinese,n02086079 +Shih-Tzu,n02086240 +Blenheim spaniel,n02086646 +papillon,n02086910 +toy terrier,n02087046 +Rhodesian ridgeback,n02087394 +Afghan hound,n02088094 +basset,n02088238 +beagle,n02088364 +bloodhound,n02088466 +bluetick,n02088632 +black-and-tan coonhound,n02089078 +Walker hound,n02089867 +English foxhound,n02089973 +redbone,n02090379 +borzoi,n02090622 +Irish wolfhound,n02090721 +Italian greyhound,n02091032 +whippet,n02091134 +Ibizan hound,n02091244 +Norwegian elkhound,n02091467 +otterhound,n02091635 +Saluki,n02091831 +Scottish deerhound,n02092002 +Weimaraner,n02092339 +Staffordshire bullterrier,n02093256 +American Staffordshire terrier,n02093428 +Bedlington terrier,n02093647 +Border terrier,n02093754 +Kerry blue terrier,n02093859 +Irish terrier,n02093991 +Norfolk terrier,n02094114 +Norwich terrier,n02094258 +Yorkshire terrier,n02094433 +wire-haired fox terrier,n02095314 +Lakeland terrier,n02095570 +Sealyham terrier,n02095889 +Airedale,n02096051 +cairn,n02096177 +Australian terrier,n02096294 +Dandie Dinmont,n02096437 +Boston bull,n02096585 +miniature schnauzer,n02097047 +giant schnauzer,n02097130 +standard schnauzer,n02097209 +Scotch terrier,n02097298 +Tibetan terrier,n02097474 +silky terrier,n02097658 +soft-coated wheaten terrier,n02098105 +West Highland white terrier,n02098286 +Lhasa,n02098413 +flat-coated retriever,n02099267 +curly-coated retriever,n02099429 +golden retriever,n02099601 +Labrador retriever,n02099712 +Chesapeake Bay retriever,n02099849 +German short-haired pointer,n02100236 +vizsla,n02100583 +English setter,n02100735 +Irish setter,n02100877 +Gordon setter,n02101006 +Brittany spaniel,n02101388 +clumber,n02101556 +English springer,n02102040 +Welsh springer spaniel,n02102177 +cocker spaniel,n02102318 +Sussex spaniel,n02102480 +Irish water spaniel,n02102973 +kuvasz,n02104029 +schipperke,n02104365 +groenendael,n02105056 +malinois,n02105162 +briard,n02105251 +kelpie,n02105412 +komondor,n02105505 +Old English sheepdog,n02105641 +Shetland sheepdog,n02105855 +collie,n02106030 +Border collie,n02106166 +Bouvier des Flandres,n02106382 +Rottweiler,n02106550 +German shepherd,n02106662 +Doberman,n02107142 +miniature pinscher,n02107312 +Greater Swiss Mountain dog,n02107574 +Bernese mountain dog,n02107683 +Appenzeller,n02107908 +EntleBucher,n02108000 +boxer,n02108089 +bull mastiff,n02108422 +Tibetan mastiff,n02108551 +French bulldog,n02108915 +Great Dane,n02109047 +Saint Bernard,n02109525 +Eskimo dog,n02109961 +malamute,n02110063 +Siberian husky,n02110185 +dalmatian,n02110341 +affenpinscher,n02110627 +basenji,n02110806 +pug,n02110958 +Leonberg,n02111129 +Newfoundland,n02111277 +Great Pyrenees,n02111500 +Samoyed,n02111889 +Pomeranian,n02112018 +chow,n02112137 +keeshond,n02112350 +Brabancon griffon,n02112706 +Pembroke,n02113023 +Cardigan,n02113186 +toy poodle,n02113624 +miniature poodle,n02113712 +standard poodle,n02113799 +Mexican hairless,n02113978 
+timber wolf,n02114367 +white wolf,n02114548 +red wolf,n02114712 +coyote,n02114855 +dingo,n02115641 +dhole,n02115913 +African hunting dog,n02116738 +hyena,n02117135 +red fox,n02119022 +kit fox,n02119789 +Arctic fox,n02120079 +grey fox,n02120505 +tabby,n02123045 +tiger cat,n02123159 +Persian cat,n02123394 +Siamese cat,n02123597 +Egyptian cat,n02124075 +cougar,n02125311 +lynx,n02127052 +leopard,n02128385 +snow leopard,n02128757 +jaguar,n02128925 +lion,n02129165 +tiger,n02129604 +cheetah,n02130308 +brown bear,n02132136 +American black bear,n02133161 +ice bear,n02134084 +sloth bear,n02134418 +mongoose,n02137549 +meerkat,n02138441 +tiger beetle,n02165105 +ladybug,n02165456 +ground beetle,n02167151 +long-horned beetle,n02168699 +leaf beetle,n02169497 +dung beetle,n02172182 +rhinoceros beetle,n02174001 +weevil,n02177972 +fly,n02190166 +bee,n02206856 +ant,n02219486 +grasshopper,n02226429 +cricket,n02229544 +walking stick,n02231487 +cockroach,n02233338 +mantis,n02236044 +cicada,n02256656 +leafhopper,n02259212 +lacewing,n02264363 +dragonfly,n02268443 +damselfly,n02268853 +admiral,n02276258 +ringlet,n02277742 +monarch,n02279972 +cabbage butterfly,n02280649 +sulphur butterfly,n02281406 +lycaenid,n02281787 +starfish,n02317335 +sea urchin,n02319095 +sea cucumber,n02321529 +wood rabbit,n02325366 +hare,n02326432 +Angora,n02328150 +hamster,n02342885 +porcupine,n02346627 +fox squirrel,n02356798 +marmot,n02361337 +beaver,n02363005 +guinea pig,n02364673 +sorrel,n02389026 +zebra,n02391049 +hog,n02395406 +wild boar,n02396427 +warthog,n02397096 +hippopotamus,n02398521 +ox,n02403003 +water buffalo,n02408429 +bison,n02410509 +ram,n02412080 +bighorn,n02415577 +ibex,n02417914 +hartebeest,n02422106 +impala,n02422699 +gazelle,n02423022 +Arabian camel,n02437312 +llama,n02437616 +weasel,n02441942 +mink,n02442845 +polecat,n02443114 +black-footed ferret,n02443484 +otter,n02444819 +skunk,n02445715 +badger,n02447366 +armadillo,n02454379 +three-toed sloth,n02457408 +orangutan,n02480495 +gorilla,n02480855 +chimpanzee,n02481823 +gibbon,n02483362 +siamang,n02483708 +guenon,n02484975 +patas,n02486261 +baboon,n02486410 +macaque,n02487347 +langur,n02488291 +colobus,n02488702 +proboscis monkey,n02489166 +marmoset,n02490219 +capuchin,n02492035 +howler monkey,n02492660 +titi,n02493509 +spider monkey,n02493793 +squirrel monkey,n02494079 +Madagascar cat,n02497673 +indri,n02500267 +Indian elephant,n02504013 +African elephant,n02504458 +lesser panda,n02509815 +giant panda,n02510455 +barracouta,n02514041 +eel,n02526121 +coho,n02536864 +rock beauty,n02606052 +anemone fish,n02607072 +sturgeon,n02640242 +gar,n02641379 +lionfish,n02643566 +puffer,n02655020 +abacus,n02666196 +abaya,n02667093 +academic gown,n02669723 +accordion,n02672831 +acoustic guitar,n02676566 +aircraft carrier,n02687172 +airliner,n02690373 +airship,n02692877 +altar,n02699494 +ambulance,n02701002 +amphibian,n02704792 +analog clock,n02708093 +apiary,n02727426 +apron,n02730930 +ashcan,n02747177 +assault rifle,n02749479 +backpack,n02769748 +bakery,n02776631 +balance beam,n02777292 +balloon,n02782093 +ballpoint,n02783161 +Band Aid,n02786058 +banjo,n02787622 +bannister,n02788148 +barbell,n02790996 +barber chair,n02791124 +barbershop,n02791270 +barn,n02793495 +barometer,n02794156 +barrel,n02795169 +barrow,n02797295 +baseball,n02799071 +basketball,n02802426 +bassinet,n02804414 +bassoon,n02804610 +bathing cap,n02807133 +bath towel,n02808304 +bathtub,n02808440 +beach wagon,n02814533 +beacon,n02814860 +beaker,n02815834 +bearskin,n02817516 +beer bottle,n02823428 +beer glass,n02823750 
+bell cote,n02825657 +bib,n02834397 +bicycle-built-for-two,n02835271 +bikini,n02837789 +binder,n02840245 +binoculars,n02841315 +birdhouse,n02843684 +boathouse,n02859443 +bobsled,n02860847 +bolo tie,n02865351 +bonnet,n02869837 +bookcase,n02870880 +bookshop,n02871525 +bottlecap,n02877765 +bow,n02879718 +bow tie,n02883205 +brass,n02892201 +brassiere,n02892767 +breakwater,n02894605 +breastplate,n02895154 +broom,n02906734 +bucket,n02909870 +buckle,n02910353 +bulletproof vest,n02916936 +bullet train,n02917067 +butcher shop,n02927161 +cab,n02930766 +caldron,n02939185 +candle,n02948072 +cannon,n02950826 +canoe,n02951358 +can opener,n02951585 +cardigan,n02963159 +car mirror,n02965783 +carousel,n02966193 +carpenter's kit,n02966687 +carton,n02971356 +car wheel,n02974003 +cash machine,n02977058 +cassette,n02978881 +cassette player,n02979186 +castle,n02980441 +catamaran,n02981792 +CD player,n02988304 +cello,n02992211 +cellular telephone,n02992529 +chain,n02999410 +chainlink fence,n03000134 +chain mail,n03000247 +chain saw,n03000684 +chest,n03014705 +chiffonier,n03016953 +chime,n03017168 +china cabinet,n03018349 +Christmas stocking,n03026506 +church,n03028079 +cinema,n03032252 +cleaver,n03041632 +cliff dwelling,n03042490 +cloak,n03045698 +clog,n03047690 +cocktail shaker,n03062245 +coffee mug,n03063599 +coffeepot,n03063689 +coil,n03065424 +combination lock,n03075370 +computer keyboard,n03085013 +confectionery,n03089624 +container ship,n03095699 +convertible,n03100240 +corkscrew,n03109150 +cornet,n03110669 +cowboy boot,n03124043 +cowboy hat,n03124170 +cradle,n03125729 +construction crane,n03126707 +crash helmet,n03127747 +crate,n03127925 +crib,n03131574 +Crock Pot,n03133878 +croquet ball,n03134739 +crutch,n03141823 +cuirass,n03146219 +dam,n03160309 +desk,n03179701 +desktop computer,n03180011 +dial telephone,n03187595 +diaper,n03188531 +digital clock,n03196217 +digital watch,n03197337 +dining table,n03201208 +dishrag,n03207743 +dishwasher,n03207941 +disk brake,n03208938 +dock,n03216828 +dogsled,n03218198 +dome,n03220513 +doormat,n03223299 +drilling platform,n03240683 +drum,n03249569 +drumstick,n03250847 +dumbbell,n03255030 +Dutch oven,n03259280 +electric fan,n03271574 +electric guitar,n03272010 +electric locomotive,n03272562 +entertainment center,n03290653 +envelope,n03291819 +espresso maker,n03297495 +face powder,n03314780 +feather boa,n03325584 +file,n03337140 +fireboat,n03344393 +fire engine,n03345487 +fire screen,n03347037 +flagpole,n03355925 +flute,n03372029 +folding chair,n03376595 +football helmet,n03379051 +forklift,n03384352 +fountain,n03388043 +fountain pen,n03388183 +four-poster,n03388549 +freight car,n03393912 +French horn,n03394916 +frying pan,n03400231 +fur coat,n03404251 +garbage truck,n03417042 +gasmask,n03424325 +gas pump,n03425413 +goblet,n03443371 +go-kart,n03444034 +golf ball,n03445777 +golfcart,n03445924 +gondola,n03447447 +gong,n03447721 +gown,n03450230 +grand piano,n03452741 +greenhouse,n03457902 +grille,n03459775 +grocery store,n03461385 +guillotine,n03467068 +hair slide,n03476684 +hair spray,n03476991 +half track,n03478589 +hammer,n03481172 +hamper,n03482405 +hand blower,n03483316 +hand-held computer,n03485407 +handkerchief,n03485794 +hard disc,n03492542 +harmonica,n03494278 +harp,n03495258 +harvester,n03496892 +hatchet,n03498962 +holster,n03527444 +home theater,n03529860 +honeycomb,n03530642 +hook,n03532672 +hoopskirt,n03534580 +horizontal bar,n03535780 +horse cart,n03538406 +hourglass,n03544143 +iPod,n03584254 +iron,n03584829 +jack-o'-lantern,n03590841 +jean,n03594734 
+jeep,n03594945 +jersey,n03595614 +jigsaw puzzle,n03598930 +jinrikisha,n03599486 +joystick,n03602883 +kimono,n03617480 +knee pad,n03623198 +knot,n03627232 +lab coat,n03630383 +ladle,n03633091 +lampshade,n03637318 +laptop,n03642806 +lawn mower,n03649909 +lens cap,n03657121 +letter opener,n03658185 +library,n03661043 +lifeboat,n03662601 +lighter,n03666591 +limousine,n03670208 +liner,n03673027 +lipstick,n03676483 +Loafer,n03680355 +lotion,n03690938 +loudspeaker,n03691459 +loupe,n03692522 +lumbermill,n03697007 +magnetic compass,n03706229 +mailbag,n03709823 +mailbox,n03710193 +maillot,n03710637 +tank suit,n03710721 +manhole cover,n03717622 +maraca,n03720891 +marimba,n03721384 +mask,n03724870 +matchstick,n03729826 +maypole,n03733131 +maze,n03733281 +measuring cup,n03733805 +medicine chest,n03742115 +megalith,n03743016 +microphone,n03759954 +microwave,n03761084 +military uniform,n03763968 +milk can,n03764736 +minibus,n03769881 +miniskirt,n03770439 +minivan,n03770679 +missile,n03773504 +mitten,n03775071 +mixing bowl,n03775546 +mobile home,n03776460 +Model T,n03777568 +modem,n03777754 +monastery,n03781244 +monitor,n03782006 +moped,n03785016 +mortar,n03786901 +mortarboard,n03787032 +mosque,n03788195 +mosquito net,n03788365 +motor scooter,n03791053 +mountain bike,n03792782 +mountain tent,n03792972 +mouse,n03793489 +mousetrap,n03794056 +moving van,n03796401 +muzzle,n03803284 +nail,n03804744 +neck brace,n03814639 +necklace,n03814906 +nipple,n03825788 +notebook,n03832673 +obelisk,n03837869 +oboe,n03838899 +ocarina,n03840681 +odometer,n03841143 +oil filter,n03843555 +organ,n03854065 +oscilloscope,n03857828 +overskirt,n03866082 +oxcart,n03868242 +oxygen mask,n03868863 +packet,n03871628 +paddle,n03873416 +paddlewheel,n03874293 +padlock,n03874599 +paintbrush,n03876231 +pajama,n03877472 +palace,n03877845 +panpipe,n03884397 +paper towel,n03887697 +parachute,n03888257 +parallel bars,n03888605 +park bench,n03891251 +parking meter,n03891332 +passenger car,n03895866 +patio,n03899768 +pay-phone,n03902125 +pedestal,n03903868 +pencil box,n03908618 +pencil sharpener,n03908714 +perfume,n03916031 +Petri dish,n03920288 +photocopier,n03924679 +pick,n03929660 +pickelhaube,n03929855 +picket fence,n03930313 +pickup,n03930630 +pier,n03933933 +piggy bank,n03935335 +pill bottle,n03937543 +pillow,n03938244 +ping-pong ball,n03942813 +pinwheel,n03944341 +pirate,n03947888 +pitcher,n03950228 +plane,n03954731 +planetarium,n03956157 +plastic bag,n03958227 +plate rack,n03961711 +plow,n03967562 +plunger,n03970156 +Polaroid camera,n03976467 +pole,n03976657 +police van,n03977966 +poncho,n03980874 +pool table,n03982430 +pop bottle,n03983396 +pot,n03991062 +potter's wheel,n03992509 +power drill,n03995372 +prayer rug,n03998194 +printer,n04004767 +prison,n04005630 +projectile,n04008634 +projector,n04009552 +puck,n04019541 +punching bag,n04023962 +purse,n04026417 +quill,n04033901 +quilt,n04033995 +racer,n04037443 +racket,n04039381 +radiator,n04040759 +radio,n04041544 +radio telescope,n04044716 +rain barrel,n04049303 +recreational vehicle,n04065272 +reel,n04067472 +reflex camera,n04069434 +refrigerator,n04070727 +remote control,n04074963 +restaurant,n04081281 +revolver,n04086273 +rifle,n04090263 +rocking chair,n04099969 +rotisserie,n04111531 +rubber eraser,n04116512 +rugby ball,n04118538 +rule,n04118776 +running shoe,n04120489 +safe,n04125021 +safety pin,n04127249 +saltshaker,n04131690 +sandal,n04133789 +sarong,n04136333 +sax,n04141076 +scabbard,n04141327 +scale,n04141975 +school bus,n04146614 +schooner,n04147183 +scoreboard,n04149813 
+screen,n04152593 +screw,n04153751 +screwdriver,n04154565 +seat belt,n04162706 +sewing machine,n04179913 +shield,n04192698 +shoe shop,n04200800 +shoji,n04201297 +shopping basket,n04204238 +shopping cart,n04204347 +shovel,n04208210 +shower cap,n04209133 +shower curtain,n04209239 +ski,n04228054 +ski mask,n04229816 +sleeping bag,n04235860 +slide rule,n04238763 +sliding door,n04239074 +slot,n04243546 +snorkel,n04251144 +snowmobile,n04252077 +snowplow,n04252225 +soap dispenser,n04254120 +soccer ball,n04254680 +sock,n04254777 +solar dish,n04258138 +sombrero,n04259630 +soup bowl,n04263257 +space bar,n04264628 +space heater,n04265275 +space shuttle,n04266014 +spatula,n04270147 +speedboat,n04273569 +spider web,n04275548 +spindle,n04277352 +sports car,n04285008 +spotlight,n04286575 +stage,n04296562 +steam locomotive,n04310018 +steel arch bridge,n04311004 +steel drum,n04311174 +stethoscope,n04317175 +stole,n04325704 +stone wall,n04326547 +stopwatch,n04328186 +stove,n04330267 +strainer,n04332243 +streetcar,n04335435 +stretcher,n04336792 +studio couch,n04344873 +stupa,n04346328 +submarine,n04347754 +suit,n04350905 +sundial,n04355338 +sunglass,n04355933 +sunglasses,n04356056 +sunscreen,n04357314 +suspension bridge,n04366367 +swab,n04367480 +sweatshirt,n04370456 +swimming trunks,n04371430 +swing,n04371774 +switch,n04372370 +syringe,n04376876 +table lamp,n04380533 +tank,n04389033 +tape player,n04392985 +teapot,n04398044 +teddy,n04399382 +television,n04404412 +tennis ball,n04409515 +thatch,n04417672 +theater curtain,n04418357 +thimble,n04423845 +thresher,n04428191 +throne,n04429376 +tile roof,n04435653 +toaster,n04442312 +tobacco shop,n04443257 +toilet seat,n04447861 +torch,n04456115 +totem pole,n04458633 +tow truck,n04461696 +toyshop,n04462240 +tractor,n04465501 +trailer truck,n04467665 +tray,n04476259 +trench coat,n04479046 +tricycle,n04482393 +trimaran,n04483307 +tripod,n04485082 +triumphal arch,n04486054 +trolleybus,n04487081 +trombone,n04487394 +tub,n04493381 +turnstile,n04501370 +typewriter keyboard,n04505470 +umbrella,n04507155 +unicycle,n04509417 +upright,n04515003 +vacuum,n04517823 +vase,n04522168 +vault,n04523525 +velvet,n04525038 +vending machine,n04525305 +vestment,n04532106 +viaduct,n04532670 +violin,n04536866 +volleyball,n04540053 +waffle iron,n04542943 +wall clock,n04548280 +wallet,n04548362 +wardrobe,n04550184 +warplane,n04552348 +washbasin,n04553703 +washer,n04554684 +water bottle,n04557648 +water jug,n04560804 +water tower,n04562935 +whiskey jug,n04579145 +whistle,n04579432 +wig,n04584207 +window screen,n04589890 +window shade,n04590129 +Windsor tie,n04591157 +wine bottle,n04591713 +wing,n04592741 +wok,n04596742 +wooden spoon,n04597913 +wool,n04599235 +worm fence,n04604644 +wreck,n04606251 +yawl,n04612504 +yurt,n04613696 +web site,n06359193 +comic book,n06596364 +crossword puzzle,n06785654 +street sign,n06794110 +traffic light,n06874185 +book jacket,n07248320 +menu,n07565083 +plate,n07579787 +guacamole,n07583066 +consomme,n07584110 +hot pot,n07590611 +trifle,n07613480 +ice cream,n07614500 +ice lolly,n07615774 +French loaf,n07684084 +bagel,n07693725 +pretzel,n07695742 +cheeseburger,n07697313 +hotdog,n07697537 +mashed potato,n07711569 +head cabbage,n07714571 +broccoli,n07714990 +cauliflower,n07715103 +zucchini,n07716358 +spaghetti squash,n07716906 +acorn squash,n07717410 +butternut squash,n07717556 +cucumber,n07718472 +artichoke,n07718747 +bell pepper,n07720875 +cardoon,n07730033 +mushroom,n07734744 +Granny Smith,n07742313 +strawberry,n07745940 +orange,n07747607 +lemon,n07749582 
+fig,n07753113 +pineapple,n07753275 +banana,n07753592 +jackfruit,n07754684 +custard apple,n07760859 +pomegranate,n07768694 +hay,n07802026 +carbonara,n07831146 +chocolate sauce,n07836838 +dough,n07860988 +meat loaf,n07871810 +pizza,n07873807 +potpie,n07875152 +burrito,n07880968 +red wine,n07892512 +espresso,n07920052 +cup,n07930864 +eggnog,n07932039 +alp,n09193705 +bubble,n09229709 +cliff,n09246464 +coral reef,n09256479 +geyser,n09288635 +lakeside,n09332890 +promontory,n09399592 +sandbar,n09421951 +seashore,n09428293 +valley,n09468604 +volcano,n09472597 +ballplayer,n09835506 +groom,n10148035 +scuba diver,n10565667 +rapeseed,n11879895 +daisy,n11939491 +yellow lady's slipper,n12057211 +corn,n12144580 +acorn,n12267677 +hip,n12620546 +buckeye,n12768682 +coral fungus,n12985857 +agaric,n12998815 +gyromitra,n13037406 +stinkhorn,n13040303 +earthstar,n13044778 +hen-of-the-woods,n13052670 +bolete,n13054560 +ear,n13133613 +toilet tissue,n15075141 diff --git a/torchvision/prototype/datasets/benchmark.py b/torchvision/prototype/datasets/benchmark.py index a555c021368..104ef95c9ae 100644 --- a/torchvision/prototype/datasets/benchmark.py +++ b/torchvision/prototype/datasets/benchmark.py @@ -3,7 +3,6 @@ import argparse import collections.abc import contextlib -import copy import inspect import itertools import os @@ -20,6 +19,7 @@ from torch.utils.data import DataLoader from torch.utils.data.dataloader_experimental import DataLoader2 from torchvision import datasets as legacy_datasets +from torchvision.datasets.utils import extract_archive from torchvision.prototype import datasets as new_datasets from torchvision.transforms import PILToTensor @@ -27,6 +27,7 @@ def main( name, *, + variant=None, legacy=True, new=True, start=True, @@ -36,46 +37,57 @@ def main( temp_root=None, num_workers=0, ): - for benchmark in DATASET_BENCHMARKS: - if benchmark.name == name: - break - else: - raise ValueError(f"No DatasetBenchmark available for dataset '{name}'") - - if legacy and start: - print( - "legacy", - "cold_start", - Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts), - ) - print( - "legacy", - "warm_start", - Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts), - ) + benchmarks = [ + benchmark + for benchmark in DATASET_BENCHMARKS + if benchmark.name == name and (variant is None or benchmark.variant == variant) + ] + if not benchmarks: + msg = f"No DatasetBenchmark available for dataset '{name}'" + if variant is not None: + msg += f" and variant '{variant}'" + raise ValueError(msg) + + for benchmark in benchmarks: + print("#" * 80) + print(f"{benchmark.name}" + (f" ({benchmark.variant})" if benchmark.variant is not None else "")) + + if legacy and start: + print( + "legacy", + "cold_start", + Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts), + ) + print( + "legacy", + "warm_start", + Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts), + ) - if legacy and iteration: - print( - "legacy", - "iteration", - Measurement.iterations_per_time( - benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples) - ), - ) + if legacy and iteration: + print( + "legacy", + "iteration", + Measurement.iterations_per_time( + benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples) + ), + ) - if new and start: - print( - "new", - "cold_start", - 
Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts), - ) + if new and start: + print( + "new", + "cold_start", + Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts), + ) - if new and iteration: - print( - "new", - "iteration", - Measurement.iterations_per_time(benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)), - ) + if new and iteration: + print( + "new", + "iteration", + Measurement.iterations_per_time( + benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples) + ), + ) class DatasetBenchmark: @@ -83,6 +95,7 @@ def __init__( self, name: str, *, + variant=None, legacy_cls=None, new_config=None, legacy_config_map=None, @@ -90,6 +103,7 @@ def __init__( prepare_legacy_root=None, ): self.name = name + self.variant = variant self.new_raw_dataset = new_datasets._api.find(name) self.legacy_cls = legacy_cls or self._find_legacy_cls() @@ -97,14 +111,11 @@ def __init__( if new_config is None: new_config = self.new_raw_dataset.default_config elif isinstance(new_config, dict): - new_config = new_datasets.utils.DatasetConfig(new_config) + new_config = self.new_raw_dataset.info.make_config(**new_config) self.new_config = new_config - self.legacy_config = (legacy_config_map or dict)(copy.copy(new_config)) - - self.legacy_special_options = (legacy_special_options_map or self._legacy_special_options_map)( - copy.copy(new_config) - ) + self.legacy_config_map = legacy_config_map + self.legacy_special_options_map = legacy_special_options_map or self._legacy_special_options_map self.prepare_legacy_root = prepare_legacy_root def new_dataset(self, *, num_workers=0): @@ -142,12 +153,15 @@ def context_manager(): return context_manager() def legacy_dataset(self, root, *, num_workers=0, download=None): - special_options = self.legacy_special_options.copy() + legacy_config = self.legacy_config_map(self, root) if self.legacy_config_map else dict() + + special_options = self.legacy_special_options_map(self) if "download" in special_options and download is not None: special_options["download"] = download + with self.suppress_output(): return DataLoader( - self.legacy_cls(str(root), **self.legacy_config, **special_options), + self.legacy_cls(legacy_config.pop("root", str(root)), **legacy_config, **special_options), shuffle=True, num_workers=num_workers, ) @@ -260,16 +274,17 @@ def _find_legacy_cls(self): "download", } - def _legacy_special_options_map(self, config): + @staticmethod + def _legacy_special_options_map(benchmark): available_parameters = set() - for cls in self.legacy_cls.__mro__: + for cls in benchmark.legacy_cls.__mro__: if cls is legacy_datasets.VisionDataset: break available_parameters.update(inspect.signature(cls.__init__).parameters) - available_special_kwargs = self._SPECIAL_KWARGS.intersection(available_parameters) + available_special_kwargs = benchmark._SPECIAL_KWARGS.intersection(available_parameters) special_options = dict() @@ -345,15 +360,15 @@ def _compute_mean_and_std(cls, t): return mean, std -def no_split(config): - legacy_config = dict(config) +def no_split(benchmark, root): + legacy_config = dict(benchmark.new_config) del legacy_config["split"] return legacy_config def bool_split(name="train"): - def legacy_config_map(config): - legacy_config = dict(config) + def legacy_config_map(benchmark, root): + legacy_config = dict(benchmark.new_config) legacy_config[name] = legacy_config.pop("split") == "train" return legacy_config @@ -400,8 +415,8 @@ def __call__(self, *inputs): return 
tuple(transform(input) for transform, input in zip(self.transforms, inputs)) -def caltech101_legacy_config_map(config): - legacy_config = no_split(config) +def caltech101_legacy_config_map(benchmark, root): + legacy_config = no_split(benchmark, root) # The new dataset always returns the category and annotation legacy_config["target_type"] = ("category", "annotation") return legacy_config @@ -410,8 +425,8 @@ def caltech101_legacy_config_map(config): mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw") -def mnist_legacy_config_map(config): - return dict(train=config.split == "train") +def mnist_legacy_config_map(benchmark, root): + return dict(train=benchmark.new_config.split == "train") def emnist_prepare_legacy_root(benchmark, root): @@ -420,20 +435,36 @@ def emnist_prepare_legacy_root(benchmark, root): return folder -def emnist_legacy_config_map(config): - legacy_config = mnist_legacy_config_map(config) - legacy_config["split"] = config.image_set.replace("_", "").lower() +def emnist_legacy_config_map(benchmark, root): + legacy_config = mnist_legacy_config_map(benchmark, root) + legacy_config["split"] = benchmark.new_config.image_set.replace("_", "").lower() return legacy_config -def qmnist_legacy_config_map(config): - legacy_config = mnist_legacy_config_map(config) - legacy_config["what"] = config.split +def qmnist_legacy_config_map(benchmark, root): + legacy_config = mnist_legacy_config_map(benchmark, root) + legacy_config["what"] = benchmark.new_config.split # The new dataset always returns the full label legacy_config["compat"] = False return legacy_config +def coco_legacy_config_map(benchmark, root): + images, _ = benchmark.new_raw_dataset.resources(benchmark.new_config) + return dict( + root=str(root / pathlib.Path(images.file_name).stem), + annFile=str( + root / "annotations" / f"{benchmark.variant}_{benchmark.new_config.split}{benchmark.new_config.year}.json" + ), + ) + + +def coco_prepare_legacy_root(benchmark, root): + images, annotations = benchmark.new_raw_dataset.resources(benchmark.new_config) + extract_archive(str(root / images.file_name)) + extract_archive(str(root / annotations.file_name)) + + DATASET_BENCHMARKS = [ DatasetBenchmark( "caltech101", @@ -453,8 +484,8 @@ def qmnist_legacy_config_map(config): DatasetBenchmark( "celeba", prepare_legacy_root=base_folder(), - legacy_config_map=lambda config: dict( - split="valid" if config.split == "val" else config.split, + legacy_config_map=lambda benchmark: dict( + split="valid" if benchmark.new_config.split == "val" else benchmark.new_config.split, # The new dataset always returns all annotations target_type=("attr", "identity", "bbox", "landmarks"), ), @@ -495,17 +526,37 @@ def qmnist_legacy_config_map(config): DatasetBenchmark( "sbd", legacy_cls=legacy_datasets.SBDataset, - legacy_config_map=lambda config: dict( - image_set=config.split, - mode="boundaries" if config.boundaries else "segmentation", + legacy_config_map=lambda benchmark: dict( + image_set=benchmark.new_config.split, + mode="boundaries" if benchmark.new_config.boundaries else "segmentation", ), - legacy_special_options_map=lambda config: dict( + legacy_special_options_map=lambda benchmark: dict( download=True, - transforms=JointTransform(PILToTensor(), torch.tensor if config.boundaries else PILToTensor()), + transforms=JointTransform( + PILToTensor(), torch.tensor if benchmark.new_config.boundaries else PILToTensor() + ), ), ), DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection), 
DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet), + DatasetBenchmark( + "coco", + variant="instances", + legacy_cls=legacy_datasets.CocoDetection, + new_config=dict(split="train", annotations="instances"), + legacy_config_map=coco_legacy_config_map, + prepare_legacy_root=coco_prepare_legacy_root, + legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None), + ), + DatasetBenchmark( + "coco", + variant="captions", + legacy_cls=legacy_datasets.CocoCaptions, + new_config=dict(split="train", annotations="captions"), + legacy_config_map=coco_legacy_config_map, + prepare_legacy_root=coco_prepare_legacy_root, + legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None), + ), ] @@ -517,6 +568,9 @@ def parse_args(argv=None): ) parser.add_argument("name", help="Name of the dataset to benchmark.") + parser.add_argument( + "--variant", help="Variant of the dataset. If omitted all available variants will be benchmarked." + ) parser.add_argument( "-n", @@ -591,6 +645,7 @@ def parse_args(argv=None): try: main( args.name, + variant=args.variant, legacy=args.legacy, new=args.new, start=args.start, diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 3db10183f68..c4b91b4a14b 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -8,6 +8,7 @@ import os.path import pathlib import pickle +import platform from typing import BinaryIO from typing import ( Sequence, @@ -260,6 +261,11 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe: return dp +def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: + # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable + return bytearray(file.read(-1 if count == -1 else count * item_size)) + + def fromfile( file: BinaryIO, *, @@ -293,20 +299,24 @@ def fromfile( item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 np_dtype = byte_order + char + str(item_size) - # PyTorch does not support tensors with underlying read-only memory. In case - # - the file has a .fileno(), - # - the file was opened for updating, i.e. 'r+b' or 'w+b', - # - the file is seekable - # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to - # a mutable location afterwards. buffer: Union[memoryview, bytearray] - try: - buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] - # Reading from the memoryview does not advance the file cursor, so we have to do it manually. - file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) - except (PermissionError, io.UnsupportedOperation): - # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable - buffer = bytearray(file.read(-1 if count == -1 else count * item_size)) + if platform.system() != "Windows": + # PyTorch does not support tensors with underlying read-only memory. In case + # - the file has a .fileno(), + # - the file was opened for updating, i.e. 'r+b' or 'w+b', + # - the file is seekable + # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it + # to a mutable location afterwards. 
+ try: + buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] + # Reading from the memoryview does not advance the file cursor, so we have to do it manually. + file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) + except (PermissionError, io.UnsupportedOperation): + buffer = _read_mutable_buffer_fallback(file, count, item_size) + else: + # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state + # so no data can be read afterwards. Thus, we simply ignore the possible speed-up. + buffer = _read_mutable_buffer_fallback(file, count, item_size) # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index e225bf3df83..64ba449ae76 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -118,7 +118,7 @@ def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> T if data.dtype.is_floating_point: w = w.ceil() h = h.ceil() - return int(h), int(w) + return int(h.max()), int(w.max()) @classmethod def from_parts( diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 81adea2ed82..cd52f1f80ad 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -12,7 +12,7 @@ class Feature(torch.Tensor): - _META_ATTRS: Set[str] + _META_ATTRS: Set[str] = set() _meta_data: Dict[str, Any] def __init_subclass__(cls): diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index f675dc37f25..12a4738e53c 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -15,3 +15,4 @@ from . import quantization from . import segmentation from . import video +from ._api import get_weight diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py index 2935039e087..1f66fd2be45 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/prototype/models/_api.py @@ -1,7 +1,9 @@ +import importlib +import inspect +import sys from collections import OrderedDict from dataclasses import dataclass, fields from enum import Enum -from inspect import signature from typing import Any, Callable, Dict from ..._internally_replaced_utils import load_state_dict_from_url @@ -30,7 +32,6 @@ class Weights: url: str transforms: Callable meta: Dict[str, Any] - default: bool class WeightsEnum(Enum): @@ -50,7 +51,7 @@ def __init__(self, value: Weights): def verify(cls, obj: Any) -> Any: if obj is not None: if type(obj) is str: - obj = cls.from_str(obj) + obj = cls.from_str(obj.replace(cls.__name__ + ".", "")) elif not isinstance(obj, cls): raise TypeError( f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." 
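(Aside, not part of the patch: a short usage sketch of the reworked weights API from this and the following hunks. It assumes the patch is applied and that the top-level torchvision.prototype.models namespace re-exports the per-model weight enums, which the new get_weight() lookup relies on; AlexNet_Weights and its ImageNet1K_V1/default members are taken from the alexnet.py hunk further down.)

    from torchvision.prototype.models import get_weight
    from torchvision.prototype.models.alexnet import AlexNet_Weights

    # get_weight() now resolves a weight enum member from its fully qualified name.
    weights = get_weight("AlexNet_Weights.ImageNet1K_V1")
    assert weights is AlexNet_Weights.ImageNet1K_V1

    # verify() strips a redundant "<EnumName>." prefix before calling from_str(),
    # so both spellings resolve to the same member.
    assert AlexNet_Weights.verify("AlexNet_Weights.ImageNet1K_V1") is weights
    assert AlexNet_Weights.verify("ImageNet1K_V1") is weights

    # The per-enum default is now an alias member instead of a default=True field.
    assert AlexNet_Weights.default is AlexNet_Weights.ImageNet1K_V1
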
@@ -59,8 +60,8 @@ def verify(cls, obj: Any) -> Any: @classmethod def from_str(cls, value: str) -> "WeightsEnum": - for v in cls: - if v._name_ == value or (value == "default" and v.default): + for k, v in cls.__members__.items(): + if k == value: return v raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") @@ -78,41 +79,35 @@ def __getattr__(self, name): return super().__getattr__(name) -def get_weight(fn: Callable, weight_name: str) -> WeightsEnum: +def get_weight(name: str) -> WeightsEnum: """ - Gets the weight enum of a specific model builder method and weight name combination. + Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1" Args: - fn (Callable): The builder method used to create the model. - weight_name (str): The name of the weight enum entry of the specific model. + name (str): The name of the weight enum entry. Returns: WeightsEnum: The requested weight enum. """ - sig = signature(fn) - if "weights" not in sig.parameters: - raise ValueError("The method is missing the 'weights' parameter.") + try: + enum_name, value_name = name.split(".") + except ValueError: + raise ValueError(f"Invalid weight name provided: '{name}'.") + + base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1]) + base_module = importlib.import_module(base_module_name) + model_modules = [base_module] + [ + x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py") + ] - ann = signature(fn).parameters["weights"].annotation weights_enum = None - if isinstance(ann, type) and issubclass(ann, WeightsEnum): - weights_enum = ann - else: - # handle cases like Union[Optional, T] - # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 - for t in ann.__args__: # type: ignore[union-attr] - if isinstance(t, type) and issubclass(t, WeightsEnum): - # ensure the name exists. handles builders with multiple types of weights like in quantization - try: - t.from_str(weight_name) - except ValueError: - continue - weights_enum = t - break + for m in model_modules: + potential_class = m.__dict__.get(enum_name, None) + if potential_class is not None and issubclass(potential_class, WeightsEnum): + weights_enum = potential_class + break if weights_enum is None: - raise ValueError( - "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct." 
- ) + raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") - return weights_enum.from_str(weight_name) + return weights_enum.from_str(value_name) diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index b45ca1e7085..28b0fa60504 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum): "acc@1": 56.522, "acc@5": 79.066, }, - default=True, ) + default = ImageNet1K_V1 def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index e779a2cd239..b8abbdde947 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum): "acc@1": 74.434, "acc@5": 91.972, }, - default=True, ) + default = ImageNet1K_V1 class DenseNet161_Weights(WeightsEnum): @@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum): "acc@1": 77.138, "acc@5": 93.560, }, - default=True, ) + default = ImageNet1K_V1 class DenseNet169_Weights(WeightsEnum): @@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum): "acc@1": 75.600, "acc@5": 92.806, }, - default=True, ) + default = ImageNet1K_V1 class DenseNet201_Weights(WeightsEnum): @@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum): "acc@1": 76.896, "acc@5": 93.370, }, - default=True, ) + default = ImageNet1K_V1 def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index c83aaf222fb..1f5c6461698 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "map": 37.0, }, - default=True, ) + default = Coco_V1 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): @@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", "map": 32.8, }, - default=True, ) + default = Coco_V1 class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): @@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", "map": 22.8, }, - default=True, ) + default = Coco_V1 def fasterrcnn_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 85250ac2a33..a811999681d 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 50.6, "kp_map": 61.1, }, - default=False, ) Coco_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", @@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 54.6, "kp_map": 65.0, }, - default=True, ) + default = Coco_V1 def keypointrcnn_resnet50_fpn( diff 
--git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index ea7ab4f5fc7..4eb285fac0d 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 37.9, "mask_map": 34.6, }, - default=True, ) + default = Coco_V1 def maskrcnn_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index d442c79d5b6..799bc21c379 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "map": 36.4, }, - default=True, ) + default = Coco_V1 def retinanet_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 37f5c2a6944..f57b47c00d6 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", "map": 25.1, }, - default=True, ) + default = Coco_V1 def ssd300_vgg16( diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 309362f2f11..4a61c50101a 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", "map": 21.3, }, - default=True, ) + default = Coco_V1 def ssdlite320_mobilenet_v3_large( diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 74ca6ccc71d..f4a69aac70c 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -79,8 +79,8 @@ class EfficientNet_B0_Weights(WeightsEnum): "acc@1": 77.692, "acc@5": 93.532, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B1_Weights(WeightsEnum): @@ -93,8 +93,8 @@ class EfficientNet_B1_Weights(WeightsEnum): "acc@1": 78.642, "acc@5": 94.186, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B2_Weights(WeightsEnum): @@ -107,8 +107,8 @@ class EfficientNet_B2_Weights(WeightsEnum): "acc@1": 80.608, "acc@5": 95.310, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B3_Weights(WeightsEnum): @@ -121,8 +121,8 @@ class EfficientNet_B3_Weights(WeightsEnum): "acc@1": 82.008, "acc@5": 96.054, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B4_Weights(WeightsEnum): @@ -135,8 +135,8 @@ class EfficientNet_B4_Weights(WeightsEnum): "acc@1": 83.384, "acc@5": 96.594, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B5_Weights(WeightsEnum): @@ -149,8 +149,8 @@ class EfficientNet_B5_Weights(WeightsEnum): "acc@1": 83.444, "acc@5": 96.628, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B6_Weights(WeightsEnum): @@ -163,8 +163,8 @@ class EfficientNet_B6_Weights(WeightsEnum): "acc@1": 84.008, "acc@5": 96.916, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B7_Weights(WeightsEnum): @@ -177,8 +177,8 @@ class 
EfficientNet_B7_Weights(WeightsEnum): "acc@1": 84.122, "acc@5": 96.908, }, - default=True, ) + default = ImageNet1K_V1 def efficientnet_b0( diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 352c49d1a2e..f62c5a96e15 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -26,8 +26,8 @@ class GoogLeNet_Weights(WeightsEnum): "acc@1": 69.778, "acc@5": 89.530, }, - default=True, ) + default = ImageNet1K_V1 def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index 9837b1fc4a6..4814fa76c5c 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -25,8 +25,8 @@ class Inception_V3_Weights(WeightsEnum): "acc@1": 77.294, "acc@5": 93.450, }, - default=True, ) + default = ImageNet1K_V1 def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index 73aaea0beca..554057a9ba1 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -40,8 +40,8 @@ class MNASNet0_5_Weights(WeightsEnum): "acc@1": 67.734, "acc@5": 87.490, }, - default=True, ) + default = ImageNet1K_V1 class MNASNet0_75_Weights(WeightsEnum): @@ -58,8 +58,8 @@ class MNASNet1_0_Weights(WeightsEnum): "acc@1": 73.456, "acc@5": 91.510, }, - default=True, ) + default = ImageNet1K_V1 class MNASNet1_3_Weights(WeightsEnum): diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 0c0f80d081a..64c7221da6d 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -25,8 +25,8 @@ class MobileNet_V2_Weights(WeightsEnum): "acc@1": 71.878, "acc@5": 90.286, }, - default=True, ) + default = ImageNet1K_V1 def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index e014fb5acb2..a92c7667aab 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -54,7 +54,6 @@ class MobileNet_V3_Large_Weights(WeightsEnum): "acc@1": 74.042, "acc@5": 91.340, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", @@ -65,8 +64,8 @@ class MobileNet_V3_Large_Weights(WeightsEnum): "acc@1": 75.274, "acc@5": 92.566, }, - default=True, ) + default = ImageNet1K_V2 class MobileNet_V3_Small_Weights(WeightsEnum): @@ -79,8 +78,8 @@ class MobileNet_V3_Small_Weights(WeightsEnum): "acc@1": 67.668, "acc@5": 87.402, }, - default=True, ) + default = ImageNet1K_V1 def mobilenet_v3_large( diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 3d26fd7d607..dc3c875b79a 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -38,8 +38,8 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): "acc@1": 69.826, "acc@5": 89.404, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 def googlenet( diff --git a/torchvision/prototype/models/quantization/inception.py 
b/torchvision/prototype/models/quantization/inception.py index ff779076df6..d1d5d4ca8fe 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -37,8 +37,8 @@ class Inception_V3_QuantizedWeights(WeightsEnum): "acc@1": 77.176, "acc@5": 93.354, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 def inception_v3( diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index c5afd731fad..81540f2f840 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -38,8 +38,8 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): "acc@1": 71.658, "acc@5": 90.150, }, - default=True, ) + default = ImageNet1K_QNNPACK_V1 def mobilenet_v2( diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index a29e3f44697..9d29484c18f 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -71,8 +71,8 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): "acc@1": 73.004, "acc@5": 90.858, }, - default=True, ) + default = ImageNet1K_QNNPACK_V1 def mobilenet_v3_large( diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 0de4eb5557b..c6bd530f393 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -73,8 +73,8 @@ class ResNet18_QuantizedWeights(WeightsEnum): "acc@1": 69.494, "acc@5": 88.882, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 class ResNet50_QuantizedWeights(WeightsEnum): @@ -87,7 +87,6 @@ class ResNet50_QuantizedWeights(WeightsEnum): "acc@1": 75.920, "acc@5": 92.814, }, - default=False, ) ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", @@ -98,8 +97,8 @@ class ResNet50_QuantizedWeights(WeightsEnum): "acc@1": 80.282, "acc@5": 94.976, }, - default=True, ) + default = ImageNet1K_FBGEMM_V2 class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): @@ -112,7 +111,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): "acc@1": 78.986, "acc@5": 94.480, }, - default=False, ) ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", @@ -123,8 +121,8 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): "acc@1": 82.574, "acc@5": 96.132, }, - default=True, ) + default = ImageNet1K_FBGEMM_V2 def resnet18( diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 6677983a1d9..111763f2614 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -69,8 +69,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): "acc@1": 57.972, "acc@5": 79.780, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): @@ -83,8 +83,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): "acc@1": 68.360, "acc@5": 87.582, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 def shufflenet_v2_x0_5( diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index 1e12ae7bbd2..d810a0d1300 100644 --- a/torchvision/prototype/models/regnet.py 
+++ b/torchvision/prototype/models/regnet.py @@ -74,8 +74,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum): "acc@1": 74.046, "acc@5": 91.716, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_800MF_Weights(WeightsEnum): @@ -88,8 +88,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum): "acc@1": 76.420, "acc@5": 93.136, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_1_6GF_Weights(WeightsEnum): @@ -102,8 +102,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): "acc@1": 77.950, "acc@5": 93.966, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_3_2GF_Weights(WeightsEnum): @@ -116,8 +116,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): "acc@1": 78.948, "acc@5": 94.576, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_8GF_Weights(WeightsEnum): @@ -130,8 +130,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum): "acc@1": 80.032, "acc@5": 95.048, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_16GF_Weights(WeightsEnum): @@ -144,8 +144,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): "acc@1": 80.424, "acc@5": 95.240, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_32GF_Weights(WeightsEnum): @@ -158,8 +158,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): "acc@1": 80.878, "acc@5": 95.340, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_400MF_Weights(WeightsEnum): @@ -172,8 +172,8 @@ class RegNet_X_400MF_Weights(WeightsEnum): "acc@1": 72.834, "acc@5": 90.950, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_800MF_Weights(WeightsEnum): @@ -186,8 +186,8 @@ class RegNet_X_800MF_Weights(WeightsEnum): "acc@1": 75.212, "acc@5": 92.348, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_1_6GF_Weights(WeightsEnum): @@ -200,8 +200,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): "acc@1": 77.040, "acc@5": 93.440, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_3_2GF_Weights(WeightsEnum): @@ -214,8 +214,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): "acc@1": 78.364, "acc@5": 93.992, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_8GF_Weights(WeightsEnum): @@ -228,8 +228,8 @@ class RegNet_X_8GF_Weights(WeightsEnum): "acc@1": 79.344, "acc@5": 94.686, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_16GF_Weights(WeightsEnum): @@ -242,8 +242,8 @@ class RegNet_X_16GF_Weights(WeightsEnum): "acc@1": 80.058, "acc@5": 94.944, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_32GF_Weights(WeightsEnum): @@ -256,8 +256,8 @@ class RegNet_X_32GF_Weights(WeightsEnum): "acc@1": 80.622, "acc@5": 95.248, }, - default=True, ) + default = ImageNet1K_V1 def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index e213864acbe..3c68f0a430c 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -64,8 +64,8 @@ class ResNet18_Weights(WeightsEnum): "acc@1": 69.758, "acc@5": 89.078, }, - default=True, ) + default = ImageNet1K_V1 class ResNet34_Weights(WeightsEnum): @@ -78,8 +78,8 @@ class ResNet34_Weights(WeightsEnum): "acc@1": 73.314, "acc@5": 91.420, }, - default=True, ) + default = ImageNet1K_V1 class ResNet50_Weights(WeightsEnum): @@ -92,7 +92,6 @@ class ResNet50_Weights(WeightsEnum): "acc@1": 76.130, "acc@5": 92.862, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", @@ -103,8 +102,8 @@ class ResNet50_Weights(WeightsEnum): "acc@1": 
80.674, "acc@5": 95.166, }, - default=True, ) + default = ImageNet1K_V2 class ResNet101_Weights(WeightsEnum): @@ -117,7 +116,6 @@ class ResNet101_Weights(WeightsEnum): "acc@1": 77.374, "acc@5": 93.546, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", @@ -128,8 +126,8 @@ class ResNet101_Weights(WeightsEnum): "acc@1": 81.886, "acc@5": 95.780, }, - default=True, ) + default = ImageNet1K_V2 class ResNet152_Weights(WeightsEnum): @@ -142,7 +140,6 @@ class ResNet152_Weights(WeightsEnum): "acc@1": 78.312, "acc@5": 94.046, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", @@ -153,8 +150,8 @@ class ResNet152_Weights(WeightsEnum): "acc@1": 82.284, "acc@5": 96.002, }, - default=True, ) + default = ImageNet1K_V2 class ResNeXt50_32X4D_Weights(WeightsEnum): @@ -167,7 +164,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): "acc@1": 77.618, "acc@5": 93.698, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", @@ -178,8 +174,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): "acc@1": 81.198, "acc@5": 95.340, }, - default=True, ) + default = ImageNet1K_V2 class ResNeXt101_32X8D_Weights(WeightsEnum): @@ -192,7 +188,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): "acc@1": 79.312, "acc@5": 94.526, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", @@ -203,8 +198,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): "acc@1": 82.834, "acc@5": 96.228, }, - default=True, ) + default = ImageNet1K_V2 class Wide_ResNet50_2_Weights(WeightsEnum): @@ -217,7 +212,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum): "acc@1": 78.468, "acc@5": 94.086, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", @@ -228,8 +222,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum): "acc@1": 81.602, "acc@5": 95.758, }, - default=True, ) + default = ImageNet1K_V2 class Wide_ResNet101_2_Weights(WeightsEnum): @@ -242,7 +236,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum): "acc@1": 78.848, "acc@5": 94.284, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", @@ -253,8 +246,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum): "acc@1": 82.510, "acc@5": 96.020, }, - default=True, ) + default = ImageNet1K_V2 def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 638b771c333..30c90013c9b 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -40,8 +40,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): "mIoU": 66.4, "acc": 92.4, }, - default=True, ) + default = CocoWithVocLabels_V1 class DeepLabV3_ResNet101_Weights(WeightsEnum): @@ -54,8 +54,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): "mIoU": 67.4, "acc": 92.4, }, - default=True, ) + default = CocoWithVocLabels_V1 class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): @@ -68,8 +68,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): "mIoU": 60.3, "acc": 91.2, }, - default=True, ) + default = CocoWithVocLabels_V1 def deeplabv3_resnet50( diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py 
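The default=True/default=False flags removed throughout the model files above are replaced by a single `default = <Member>` assignment at the end of each weights enum. Because the assigned value is an existing member, Python's Enum machinery treats `default` as an alias: it appears in cls.__members__, so the reworked from_str() should still resolve the string "default". A short sketch of how this plays together with the new get_weight() (not part of the diff, ResNet50_Weights used as an illustration):

    from torchvision.prototype.models import ResNet50_Weights, get_weight

    # `default` is an alias for the member it was assigned to.
    assert ResNet50_Weights.default is ResNet50_Weights.ImageNet1K_V2

    # get_weight() now resolves a weight by its fully qualified name alone,
    # without needing the model builder function.
    assert get_weight("ResNet50_Weights.ImageNet1K_V1") is ResNet50_Weights.ImageNet1K_V1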
index 841e2ea95c5..42d15a0c3cf 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -30,8 +30,8 @@ class FCN_ResNet50_Weights(WeightsEnum): "mIoU": 60.5, "acc": 91.4, }, - default=True, ) + default = CocoWithVocLabels_V1 class FCN_ResNet101_Weights(WeightsEnum): @@ -44,8 +44,8 @@ class FCN_ResNet101_Weights(WeightsEnum): "mIoU": 63.7, "acc": 91.9, }, - default=True, ) + default = CocoWithVocLabels_V1 def fcn_resnet50( diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 9743e02fa16..f80e1079c87 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -25,8 +25,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): "mIoU": 57.9, "acc": 91.2, }, - default=True, ) + default = CocoWithVocLabels_V1 def lraspp_mobilenet_v3_large( diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 9fa98c44223..a8857c2996e 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -57,8 +57,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): "acc@1": 69.362, "acc@5": 88.316, }, - default=True, ) + default = ImageNet1K_V1 class ShuffleNet_V2_X1_0_Weights(WeightsEnum): @@ -70,8 +70,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum): "acc@1": 60.552, "acc@5": 81.746, }, - default=True, ) + default = ImageNet1K_V1 class ShuffleNet_V2_X1_5_Weights(WeightsEnum): diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index fdfaa01e8be..77c9a1629d4 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -30,8 +30,8 @@ class SqueezeNet1_0_Weights(WeightsEnum): "acc@1": 58.092, "acc@5": 80.420, }, - default=True, ) + default = ImageNet1K_V1 class SqueezeNet1_1_Weights(WeightsEnum): @@ -43,8 +43,8 @@ class SqueezeNet1_1_Weights(WeightsEnum): "acc@1": 58.178, "acc@5": 80.624, }, - default=True, ) + default = ImageNet1K_V1 def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index a357426693d..708608826e0 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -57,8 +57,8 @@ class VGG11_Weights(WeightsEnum): "acc@1": 69.020, "acc@5": 88.628, }, - default=True, ) + default = ImageNet1K_V1 class VGG11_BN_Weights(WeightsEnum): @@ -70,8 +70,8 @@ class VGG11_BN_Weights(WeightsEnum): "acc@1": 70.370, "acc@5": 89.810, }, - default=True, ) + default = ImageNet1K_V1 class VGG13_Weights(WeightsEnum): @@ -83,8 +83,8 @@ class VGG13_Weights(WeightsEnum): "acc@1": 69.928, "acc@5": 89.246, }, - default=True, ) + default = ImageNet1K_V1 class VGG13_BN_Weights(WeightsEnum): @@ -96,8 +96,8 @@ class VGG13_BN_Weights(WeightsEnum): "acc@1": 71.586, "acc@5": 90.374, }, - default=True, ) + default = ImageNet1K_V1 class VGG16_Weights(WeightsEnum): @@ -109,7 +109,6 @@ class VGG16_Weights(WeightsEnum): "acc@1": 71.592, "acc@5": 90.382, }, - default=True, ) # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. 
Only the `features` weights have proper values, those on the @@ -127,8 +126,8 @@ class VGG16_Weights(WeightsEnum): "acc@1": float("nan"), "acc@5": float("nan"), }, - default=False, ) + default = ImageNet1K_V1 class VGG16_BN_Weights(WeightsEnum): @@ -140,8 +139,8 @@ class VGG16_BN_Weights(WeightsEnum): "acc@1": 73.360, "acc@5": 91.516, }, - default=True, ) + default = ImageNet1K_V1 class VGG19_Weights(WeightsEnum): @@ -153,8 +152,8 @@ class VGG19_Weights(WeightsEnum): "acc@1": 72.376, "acc@5": 90.876, }, - default=True, ) + default = ImageNet1K_V1 class VGG19_BN_Weights(WeightsEnum): @@ -166,8 +165,8 @@ class VGG19_BN_Weights(WeightsEnum): "acc@1": 74.218, "acc@5": 91.842, }, - default=True, ) + default = ImageNet1K_V1 def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index c75f618a8b1..48c4293f0e1 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -68,8 +68,8 @@ class R3D_18_Weights(WeightsEnum): "acc@1": 52.75, "acc@5": 75.45, }, - default=True, ) + default = Kinetics400_V1 class MC3_18_Weights(WeightsEnum): @@ -81,8 +81,8 @@ class MC3_18_Weights(WeightsEnum): "acc@1": 53.90, "acc@5": 76.29, }, - default=True, ) + default = Kinetics400_V1 class R2Plus1D_18_Weights(WeightsEnum): @@ -94,8 +94,8 @@ class R2Plus1D_18_Weights(WeightsEnum): "acc@1": 57.50, "acc@5": 78.81, }, - default=True, ) + default = Kinetics400_V1 def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 9fd07af1e77..8062ff0fad0 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -360,7 +360,13 @@ def _transform_recursively(self, sample: Any, *, params: Dict[str, Any]) -> Any: else: feature_type = type(sample) if not self.supports(feature_type): - if not issubclass(feature_type, features.Feature) or feature_type in self.NO_OP_FEATURE_TYPES: + if ( + not issubclass(feature_type, features.Feature) + # issubclass is not a strict check, but also allows the type checked against. Thus, we need to + # check it separately + or feature_type is features.Feature + or feature_type in self.NO_OP_FEATURE_TYPES + ): return sample raise TypeError(