diff --git a/.coveragerc b/.coveragerc index 4bbac7b27d..5272237caf 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,13 +1,15 @@ [run] branch=True source=trio -# For some reason coverage recording doesn't work for ipython_custom_exc.py, -# so leave it out of reports omit= setup.py - */ipython_custom_exc.py -# Omit the generated files in trio/_core starting with _public_ +# These are run in subprocesses, but still don't work. We follow +# coverage's documentation to no avail. + */trio/_core/_tests/test_multierror_scripts/* +# Omit the generated files in trio/_core starting with _generated_ */trio/_core/_generated_* +# Script used to check type completeness that isn't run in tests + */trio/_tests/check_type_completeness.py # The test suite spawns subprocesses to test some stuff, so make sure # this doesn't corrupt the coverage files parallel=True @@ -19,10 +21,15 @@ exclude_lines = abc.abstractmethod if TYPE_CHECKING: if _t.TYPE_CHECKING: + if t.TYPE_CHECKING: + @overload + class .*\bProtocol\b.*\): partial_branches = pragma: no branch if not TYPE_CHECKING: if not _t.TYPE_CHECKING: + if not t.TYPE_CHECKING: if .* or not TYPE_CHECKING: if .* or not _t.TYPE_CHECKING: + if .* or not t.TYPE_CHECKING: diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000..1d3079ad5a --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# sorting all imports with isort +933f77b96f0092e1baab4474a9208fc2e379aa32 diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index 0c2930b120..0000000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,36 +0,0 @@ -version: 2 -updates: -- package-ecosystem: pip - directory: "/" - schedule: - interval: daily - open-pull-requests-limit: 10 - allow: - - dependency-type: direct - - dependency-type: indirect - ignore: - - dependency-name: pytest - versions: - - ">= 4.6.1.a, < 4.6.2" - - dependency-name: astroid - versions: - - 2.5.2 - - dependency-name: sphinx - versions: - - 3.4.3 - - 3.5.0 - - 3.5.1 - - 3.5.2 - - 3.5.3 - - dependency-name: regex - versions: - - 2021.3.17 - - dependency-name: pygments - versions: - - 2.8.0 - - dependency-name: cryptography - versions: - - 3.4.5 - - dependency-name: pytest - versions: - - 6.2.2 diff --git a/.github/workflows/autodeps.yml b/.github/workflows/autodeps.yml new file mode 100644 index 0000000000..40cf05726c --- /dev/null +++ b/.github/workflows/autodeps.yml @@ -0,0 +1,82 @@ +name: Autodeps + +on: + workflow_dispatch: + schedule: + - cron: '0 0 1 * *' + +jobs: + Autodeps: + name: Autodeps + timeout-minutes: 10 + runs-on: 'ubuntu-latest' + # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/automating-dependabot-with-github-actions#changing-github_token-permissions + permissions: + pull-requests: write + issues: write + repository-projects: write + contents: write + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Bump dependencies + run: | + python -m pip install -U pip + python -m pip install -r test-requirements.txt + pip-compile -U test-requirements.in + pip-compile -U docs-requirements.in + - name: Black + run: | + # The new dependencies may contain a new black version. + # Commit any changes immediately. 
+ python -m pip install -r test-requirements.txt + black setup.py trio + - name: Commit changes and create automerge PR + env: + GH_TOKEN: ${{ github.token }} + run: | + # setup git repo + git switch --force-create autodeps/bump_from_${GITHUB_SHA:0:6} + git config user.name 'github-actions[bot]' + git config user.email '41898282+github-actions[bot]@users.noreply.github.com' + + if ! git commit -am "Dependency updates"; then + echo "No changes to commit!" + exit 0 + fi + + git push --force --set-upstream origin autodeps/bump_from_${GITHUB_SHA:0:6} + + # git push returns before github is ready for a pr, so we poll until success + for BACKOFF in 1 2 4 8 0; do + sleep $BACKOFF + if gh pr create \ + --label dependencies --body "" \ + --title "Bump dependencies from commit ${GITHUB_SHA:0:6}" \ + ; then + break + fi + done + + if [ $BACKOFF -eq 0 ]; then + echo "Could not create the PR" + exit 1 + fi + + # gh pr create returns before the pr is ready, so we again poll until success + # https://github.com/cli/cli/issues/2619#issuecomment-1240543096 + for BACKOFF in 1 2 4 8 0; do + sleep $BACKOFF + if gh pr merge --auto --squash; then + break + fi + done + + if [ $BACKOFF -eq 0 ]; then + echo "Could not set automerge" + exit 1 + fi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ebc22d6850..d5aeb3ec04 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,10 @@ on: - "dependabot/**" pull_request: +concurrency: + group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) && format('-{0}', github.sha) || '' }} + cancel-in-progress: true + jobs: Windows: name: 'Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }})' @@ -14,11 +18,18 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.7', '3.8', '3.9', '3.10'] + # pypy-3.10 is failing, see https://github.com/python-trio/trio/issues/2678 + python: ['3.8', '3.9', '3.10', 'pypy-3.9-nightly'] #, 'pypy-3.10-nightly'] arch: ['x86', 'x64'] lsp: [''] lsp_extract_file: [''] extra_name: [''] + exclude: + # pypy does not release 32-bit binaries + - python: 'pypy-3.9-nightly' + arch: 'x86' + #- python: 'pypy-3.10-nightly' + # arch: 'x86' include: - python: '3.8' arch: 'x64' @@ -35,16 +46,20 @@ jobs: # lsp: 'http://download.pctools.com/mirror/updates/9.0.0.2308-SDavfree-lite_en.exe' # lsp_extract_file: '' # extra_name: ', with non-IFS LSP' - - python: '3.8' # <- not actually used - arch: 'x64' - pypy_nightly_branch: 'py3.8' - extra_name: ', pypy 3.8 nightly' - + continue-on-error: >- + ${{ + ( + endsWith(matrix.python, '-dev') + || endsWith(matrix.python, '-nightly') + ) + && true + || false + }} steps: - name: Checkout uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: # This allows the matrix to specify just the major.minor version while still # expanding it to get the latest patch version including alpha releases. @@ -52,8 +67,8 @@ jobs: # and then finally an actual release version. actions/setup-python doesn't # support this for PyPy presently so we get no help there. 
# - # CPython -> 3.9.0-alpha - 3.9.X - # PyPy -> pypy-3.7 + # 'CPython' -> '3.9.0-alpha - 3.9.X' + # 'PyPy' -> 'pypy-3.9' python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} architecture: '${{ matrix.arch }}' cache: pip @@ -64,8 +79,13 @@ jobs: env: LSP: '${{ matrix.lsp }}' LSP_EXTRACT_FILE: '${{ matrix.lsp_extract_file }}' - # Should match 'name:' up above - JOB_NAME: 'Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }})' + - if: always() + uses: codecov/codecov-action@v3 + with: + directory: empty + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }}) + flags: Windows,${{ matrix.python }} Ubuntu: name: 'Ubuntu (${{ matrix.python }}${{ matrix.extra_name }})' @@ -74,29 +94,18 @@ jobs: strategy: fail-fast: false matrix: - python: ['pypy-3.7', 'pypy-3.8', 'pypy-3.9', '3.7', '3.8', '3.9', '3.10', '3.11', '3.12-dev'] + python: ['pypy-3.9', 'pypy-3.10', '3.8', '3.9', '3.10', '3.11', '3.12-dev', 'pypy-3.9-nightly', 'pypy-3.10-nightly'] check_formatting: ['0'] - pypy_nightly_branch: [''] extra_name: [''] include: - python: '3.8' check_formatting: '1' extra_name: ', check formatting' - - python: '3.7' # <- not actually used - pypy_nightly_branch: 'py3.7' - extra_name: ', pypy 3.7 nightly' - - python: '3.8' # <- not actually used - pypy_nightly_branch: 'py3.8' - extra_name: ', pypy 3.8 nightly' - - python: '3.9' # <- not actually used - pypy_nightly_branch: 'py3.9' - extra_name: ', pypy 3.9 nightly' continue-on-error: >- ${{ ( - matrix.check_formatting == '1' - || matrix.pypy_nightly_branch == 'py3.7' - || endsWith(matrix.python, '-dev') + endsWith(matrix.python, '-dev') + || endsWith(matrix.python, '-nightly') ) && true || false @@ -105,7 +114,7 @@ jobs: - name: Checkout uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 if: "!endsWith(matrix.python, '-dev')" with: python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} @@ -119,71 +128,50 @@ jobs: - name: Run tests run: ./ci.sh env: - PYPY_NIGHTLY_BRANCH: '${{ matrix.pypy_nightly_branch }}' CHECK_FORMATTING: '${{ matrix.check_formatting }}' - # Should match 'name:' up above - JOB_NAME: 'Ubuntu (${{ matrix.python }}${{ matrix.extra_name }})' - - autofmt: - name: Autoformat dependabot PR - timeout-minutes: 10 - if: github.actor == 'dependabot[bot]' - runs-on: 'ubuntu-latest' - # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/automating-dependabot-with-github-actions#changing-github_token-permissions - permissions: - pull-requests: write - issues: write - repository-projects: write - contents: write - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - ref: ${{ github.event.pull_request.head.ref }} - - name: Setup python - uses: actions/setup-python@v2 + - if: always() + uses: codecov/codecov-action@v3 with: - python-version: "3.8" - - name: Check formatting - run: | - python -m pip install -r test-requirements.txt - ./check.sh - - name: Commit autoformatter changes - if: failure() - run: | - black setup.py trio - git config user.name 'github-actions[bot]' - git config user.email '41898282+github-actions[bot]@users.noreply.github.com' - git commit -am "Autoformatter changes" - git push + directory: empty + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: 
Ubuntu (${{ matrix.python }}${{ matrix.extra_name }}) + flags: Ubuntu,${{ matrix.python }} macOS: name: 'macOS (${{ matrix.python }})' - timeout-minutes: 10 + timeout-minutes: 15 runs-on: 'macos-latest' strategy: fail-fast: false matrix: - python: ['3.7', '3.8', '3.9', '3.10'] - include: - - python: '3.8' # <- not actually used - arch: 'x64' - pypy_nightly_branch: 'py3.8' - extra_name: ', pypy 3.8 nightly' + python: ['3.8', '3.9', '3.10', 'pypy-3.9-nightly', 'pypy-3.10-nightly'] + continue-on-error: >- + ${{ + ( + endsWith(matrix.python, '-dev') + || endsWith(matrix.python, '-nightly') + ) + && true + || false + }} steps: - name: Checkout uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} cache: pip cache-dependency-path: test-requirements.txt - name: Run tests run: ./ci.sh - env: - # Should match 'name:' up above - JOB_NAME: 'macOS (${{ matrix.python }})' + - if: always() + uses: codecov/codecov-action@v3 + with: + directory: empty + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: macOS (${{ matrix.python }}) + flags: macOS,${{ matrix.python }} # https://github.com/marketplace/actions/alls-green#why check: # This job does nothing and is only used for the branch protection diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..f57321189e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-merge-conflict + - id: mixed-line-ending + - id: check-case-conflict + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + - repo: https://github.com/pycqa/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + additional_dependencies: + - "flake8-pyproject==1.2.3" + types: [file] + types_or: [python, pyi] + - repo: https://github.com/codespell-project/codespell + rev: v2.2.5 + hooks: + - id: codespell + +ci: + autofix_commit_msg: "[pre-commit.ci] auto fixes from pre-commit.com hooks" + autofix_prs: true + autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" + autoupdate_schedule: weekly + skip: [black,isort] + submodules: false diff --git a/MANIFEST.in b/MANIFEST.in index e2fd4c157f..eb9c0173da 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,7 @@ include LICENSE LICENSE.MIT LICENSE.APACHE2 include README.rst include CODE_OF_CONDUCT.md CONTRIBUTING.md include test-requirements.txt -recursive-include trio/tests/test_ssl_certs *.pem +include trio/py.typed +recursive-include trio/_tests/test_ssl_certs *.pem recursive-include docs * prune docs/build diff --git a/README.rst b/README.rst index 4e096eddf3..016823e1f5 100644 --- a/README.rst +++ b/README.rst @@ -92,8 +92,9 @@ demonstration of implementing the "Happy Eyeballs" algorithm in an older library versus Trio. **Cool, but will it work on my system?** Probably! 
As long as you have -some kind of Python 3.7-or-better (CPython or the latest PyPy3 are -both fine), and are using Linux, macOS, Windows, or FreeBSD, then Trio +some kind of Python 3.8-or-better (CPython or [currently maintained versions of +PyPy3](https://doc.pypy.org/en/latest/faq.html#which-python-versions-does-pypy-implement) +are both fine), and are using Linux, macOS, Windows, or FreeBSD, then Trio will work. Other environments might work too, but those are the ones we test on. And all of our dependencies are pure Python, except for CFFI on Windows, which has wheels available, so diff --git a/check.sh b/check.sh index 8416a9c5d1..f9458d95c0 100755 --- a/check.sh +++ b/check.sh @@ -18,15 +18,18 @@ if ! black --check setup.py trio; then black --diff setup.py trio fi -# Run flake8 without pycodestyle and import-related errors -flake8 trio/ \ - --ignore=D,E,W,F401,F403,F405,F821,F822\ - || EXIT_STATUS=$? +if ! isort --check setup.py trio; then + EXIT_STATUS=1 + isort --diff setup.py trio +fi + +# Run flake8, configured in pyproject.toml +flake8 trio/ || EXIT_STATUS=$? # Run mypy on all supported platforms -mypy -m trio -m trio.testing --platform linux || EXIT_STATUS=$? -mypy -m trio -m trio.testing --platform darwin || EXIT_STATUS=$? # tests FreeBSD too -mypy -m trio -m trio.testing --platform win32 || EXIT_STATUS=$? +mypy trio --platform linux || EXIT_STATUS=$? +mypy trio --platform darwin || EXIT_STATUS=$? # tests FreeBSD too +mypy trio --platform win32 || EXIT_STATUS=$? # Check pip compile is consistent pip-compile test-requirements.in @@ -34,6 +37,16 @@ pip-compile docs-requirements.in if git status --porcelain | grep -q "requirements.txt"; then git status --porcelain + git --no-pager diff --color *requirements.txt + EXIT_STATUS=1 +fi + +codespell || EXIT_STATUS=$? + +python trio/_tests/check_type_completeness.py --overwrite-file || EXIT_STATUS=$? +if git status --porcelain trio/_tests/verify_types*.json | grep -q "M"; then + echo "Type completeness changed, please update!" + git --no-pager diff --color trio/_tests/verify_types*.json EXIT_STATUS=1 fi @@ -48,6 +61,7 @@ To fix formatting and see remaining errors, run pip install -r test-requirements.txt black setup.py trio + isort setup.py trio ./check.sh in your local checkout. diff --git a/ci.sh b/ci.sh index d4f9df3a94..4fb68d617f 100755 --- a/ci.sh +++ b/ci.sh @@ -2,14 +2,14 @@ set -ex -o pipefail +# disable warnings about pyright being out of date +# used in test_exports and in check.sh +export PYRIGHT_PYTHON_IGNORE_WARNINGS=1 + # Log some general info about the environment uname -a env | sort -if [ "$JOB_NAME" = "" ]; then - JOB_NAME="${TRAVIS_OS_NAME}-${TRAVIS_PYTHON_VERSION:-unknown}" -fi - # Curl's built-in retry system is not very robust; it gives up on lots of # network errors that we want to retry on. Wget might work better, but it's # not installed on azure pipelines's windows boxes. So... let's try some good @@ -26,40 +26,6 @@ function curl-harder() { return 1 } -################################################################ -# Bootstrap python environment, if necessary -################################################################ - -### PyPy nightly ### - -if [ "$PYPY_NIGHTLY_BRANCH" != "" ]; then - JOB_NAME="pypy_nightly_${PYPY_NIGHTLY_BRANCH}" - curl-harder -o pypy.tar.bz2 http://buildbot.pypy.org/nightly/${PYPY_NIGHTLY_BRANCH}/pypy-c-jit-latest-linux64.tar.bz2 - if [ ! 
-s pypy.tar.bz2 ]; then - # We know: - # - curl succeeded (200 response code) - # - nonetheless, pypy.tar.bz2 does not exist, or contains no data - # This isn't going to work, and the failure is not informative of - # anything involving Trio. - ls -l - echo "PyPy3 nightly build failed to download – something is wrong on their end." - echo "Skipping testing against the nightly build for right now." - exit 0 - fi - tar xaf pypy.tar.bz2 - # something like "pypy-c-jit-89963-748aa3022295-linux64" - PYPY_DIR=$(echo pypy-c-jit-*) - PYTHON_EXE=$PYPY_DIR/bin/pypy3 - - if ! ($PYTHON_EXE -m ensurepip \ - && $PYTHON_EXE -m pip install virtualenv \ - && $PYTHON_EXE -m virtualenv testenv); then - echo "pypy nightly is broken; skipping tests" - exit 0 - fi - source testenv/bin/activate -fi - ################################################################ # We have a Python environment! ################################################################ @@ -115,7 +81,7 @@ else # when installing, and then running 'certmgr.msc' and exporting the # certificate. See: # http://www.migee.com/2010/09/24/solution-for-unattendedsilent-installs-and-would-you-like-to-install-this-device-software/ - certutil -addstore "TrustedPublisher" .github/workflows/astrill-codesigning-cert.cer + certutil -addstore "TrustedPublisher" trio/_tests/astrill-codesigning-cert.cer # Double-slashes are how you tell windows-bash that you want a single # slash, and don't treat this as a unix-style filename that needs to # be replaced by a windows-style filename. @@ -136,30 +102,31 @@ else INSTALLDIR=$(python -c "import os, trio; print(os.path.dirname(trio.__file__))") cp ../pyproject.toml $INSTALLDIR - # We have to copy .coveragerc into this directory, rather than passing - # --cov-config=../.coveragerc to pytest, because codecov.sh will run - # 'coverage xml' to generate the report that it uses, and that will only - # apply the ignore patterns in the current directory's .coveragerc. - cp ../.coveragerc . - if pytest -r a --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --cov="$INSTALLDIR" --verbose; then + + # TODO: remove this once we have a py.typed file + touch "$INSTALLDIR/py.typed" + + # get mypy tests a nice cache + MYPYPATH=".." mypy --config-file= --cache-dir=./.mypy_cache -c "import trio" >/dev/null 2>/dev/null || true + + # support subprocess spawning with coverage.py + echo "import coverage; coverage.process_startup()" | tee -a "$INSTALLDIR/../sitecustomize.py" + + if COVERAGE_PROCESS_START=$(pwd)/../.coveragerc coverage run --rcfile=../.coveragerc -m pytest -r a -p trio._tests.pytest_plugin --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --verbose --durations=10; then PASSED=true else PASSED=false fi + coverage combine --rcfile ../.coveragerc + coverage report -m --rcfile ../.coveragerc + coverage xml --rcfile ../.coveragerc + # Remove the LSP again; again we want to do this ASAP to avoid # accidentally breaking other stuff. if [ "$LSP" != "" ]; then netsh winsock reset fi - # The codecov docs recommend something like 'bash <(curl ...)' to pipe the - # script directly into bash as its being downloaded. But, the codecov - # server is flaky, so we instead save to a temp file with retries, and - # wait until we've successfully fetched the whole script before trying to - # run it. 
- curl-harder -o codecov.sh https://codecov.io/bash - bash codecov.sh -n "${JOB_NAME}" - $PASSED fi diff --git a/docs-requirements.in b/docs-requirements.in index fab339e1f9..9239fe3fce 100644 --- a/docs-requirements.in +++ b/docs-requirements.in @@ -1,18 +1,15 @@ # RTD is currently installing 1.5.3, which has a bug in :lineno-match: -# sphinx-3.4 causes warnings about some trio._abc classes: GH#2338 -sphinx >= 1.7.0, < 6.2 -# jinja2-3.1 causes importerror with sphinx<4.0 -jinja2 < 3.1 +sphinx >= 4.0, < 6.2 +jinja2 sphinx_rtd_theme +sphinxcontrib-jquery sphinxcontrib-trio towncrier # Trio's own dependencies cffi; os_name == "nt" -contextvars; python_version < "3.7" attrs >= 19.2.0 sortedcontainers -async_generator >= 1.9 idna outcome sniffio @@ -20,3 +17,6 @@ exceptiongroup >= 1.0.0rc9 # See note in test-requirements.in immutables >= 0.6 + +# types used in annotations +pyOpenSSL diff --git a/docs-requirements.txt b/docs-requirements.txt index 4c197c69b3..d533573c93 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -6,29 +6,31 @@ # alabaster==0.7.13 # via sphinx -async-generator==1.10 - # via -r docs-requirements.in -attrs==22.2.0 +attrs==23.1.0 # via # -r docs-requirements.in # outcome babel==2.12.1 # via sphinx -certifi==2022.12.7 +certifi==2023.7.22 # via requests -charset-normalizer==3.1.0 +cffi==1.15.1 + # via cryptography +charset-normalizer==3.2.0 # via requests -click==8.1.3 +click==8.1.7 # via # click-default-group # towncrier -click-default-group==1.2.2 +click-default-group==1.2.4 # via towncrier +cryptography==41.0.3 + # via pyopenssl docutils==0.18.1 # via # sphinx # sphinx-rtd-theme -exceptiongroup==1.1.0 +exceptiongroup==1.1.3 # via -r docs-requirements.in idna==3.4 # via @@ -36,28 +38,34 @@ idna==3.4 # requests imagesize==1.4.1 # via sphinx -immutables==0.19 +immutables==0.20 # via -r docs-requirements.in -importlib-metadata==6.0.0 +importlib-metadata==6.8.0 # via sphinx +importlib-resources==6.0.1 + # via towncrier incremental==22.10.0 # via towncrier -jinja2==3.0.3 +jinja2==3.1.2 # via # -r docs-requirements.in # sphinx # towncrier -markupsafe==2.1.2 +markupsafe==2.1.3 # via jinja2 outcome==1.2.0 # via -r docs-requirements.in -packaging==23.0 +packaging==23.1 # via sphinx -pygments==2.14.0 +pycparser==2.21 + # via cffi +pygments==2.16.1 # via sphinx -pytz==2022.7.1 +pyopenssl==23.2.0 + # via -r docs-requirements.in +pytz==2023.3 # via babel -requests==2.28.2 +requests==2.31.0 # via sphinx sniffio==1.3.0 # via -r docs-requirements.in @@ -69,8 +77,9 @@ sphinx==6.1.3 # via # -r docs-requirements.in # sphinx-rtd-theme + # sphinxcontrib-jquery # sphinxcontrib-trio -sphinx-rtd-theme==1.2.0 +sphinx-rtd-theme==1.3.0 # via -r docs-requirements.in sphinxcontrib-applehelp==1.0.4 # via sphinx @@ -78,8 +87,10 @@ sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx -sphinxcontrib-jquery==2.0.0 - # via sphinx-rtd-theme +sphinxcontrib-jquery==4.1 + # via + # -r docs-requirements.in + # sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 @@ -90,12 +101,11 @@ sphinxcontrib-trio==1.1.2 # via -r docs-requirements.in tomli==2.0.1 # via towncrier -towncrier==22.12.0 +towncrier==23.6.0 # via -r docs-requirements.in -urllib3==1.26.14 +urllib3==2.0.4 # via requests -zipp==3.15.0 - # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +zipp==3.16.2 + # via + # importlib-metadata + # importlib-resources diff --git a/docs/Makefile b/docs/Makefile index 
4fd0bb58f2..69095d6d90 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,4 +17,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/_static/hackrtd.css b/docs/source/_static/hackrtd.css index e75a889f69..48401f2389 100644 --- a/docs/source/_static/hackrtd.css +++ b/docs/source/_static/hackrtd.css @@ -12,6 +12,11 @@ pre { background-color: #ffe13b; } +/* Make typevar/paramspec names distinguishable from classes. */ +.typevarref { + text-decoration: dashed underline; +} + /* Add a snakey triskelion ornament to
<hr>s
 * https://stackoverflow.com/questions/8862344/css-hr-with-ornament/18541258#18541258
 * but only do it to <hr>
s in the content box, b/c the RTD popup control panel diff --git a/docs/source/awesome-trio-libraries.rst b/docs/source/awesome-trio-libraries.rst index 50b3d698a3..b3174c97a2 100644 --- a/docs/source/awesome-trio-libraries.rst +++ b/docs/source/awesome-trio-libraries.rst @@ -100,6 +100,7 @@ Tools and Utilities ------------------- * `trio-typing `__ - Type hints for Trio and related projects. * `trio-util `__ - An assortment of utilities for the Trio async/await framework. +* `flake8-trio `__ - Highly opinionated linter for various sorts of problems in Trio and/or AnyIO. Can run as a flake8 plugin, or standalone with support for autofixing some errors. * `tricycle `__ - This is a library of interesting-but-maybe-not-yet-fully-proven extensions to Trio. * `tenacity `__ - Retrying library for Python with async/await support. * `perf-timer `__ - A code timer with Trio async support (see ``TrioPerfTimer``). Collects execution time of a block of code excluding time when the coroutine isn't scheduled, such as during blocking I/O and sleep. Also offers ``trio_perf_counter()`` for low-level timing. diff --git a/docs/source/conf.py b/docs/source/conf.py index cfac66576b..66aa8dea05 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -44,7 +44,6 @@ nitpick_ignore = [ ("py:class", "CapacityLimiter-like object"), ("py:class", "bytes-like"), - ("py:class", "None"), # Was removed but still shows up in changelog ("py:class", "trio.lowlevel.RunLocal"), # trio.abc is documented at random places scattered throughout the docs @@ -53,19 +52,50 @@ ("py:exc", "Anything else"), ("py:class", "async function"), ("py:class", "sync function"), - # https://github.com/sphinx-doc/sphinx/issues/7722 - # TODO: why do these need to be spelled out? - ("py:class", "trio._abc.ReceiveType"), - ("py:class", "trio._abc.SendType"), - ("py:class", "trio._abc.T"), - ("py:obj", "trio._abc.ReceiveType"), - ("py:obj", "trio._abc.SendType"), - ("py:obj", "trio._abc.T"), - ("py:obj", "trio._abc.T_resource"), + # why aren't these found in stdlib? + ("py:class", "types.FrameType"), + # TODO: temporary type + ("py:class", "_SocketType"), + # these are not defined in https://docs.python.org/3/objects.inv + ("py:class", "socket.AddressFamily"), + ("py:class", "socket.SocketKind"), + ("py:class", "Buffer"), # collections.abc.Buffer, in 3.12 ] autodoc_inherit_docstrings = False default_role = "obj" +# These have incorrect __module__ set in stdlib and give the error +# `py:class reference target not found` +# Some of the nitpick_ignore's above can probably be fixed with this. +# See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798 +autodoc_type_aliases = { + # aliasing doesn't actually fix the warning for types.FrameType, but displaying + # "types.FrameType" is more helpful than just "frame" + "FrameType": "types.FrameType", + # unaliasing these makes intersphinx able to resolve them + "Outcome": "outcome.Outcome", + "Context": "OpenSSL.SSL.Context", +} + + +def autodoc_process_signature( + app, what, name, obj, options, signature, return_annotation +): + """Modify found signatures to fix various issues.""" + if signature is not None: + signature = signature.replace("~_contextvars.Context", "~contextvars.Context") + if name == "trio.lowlevel.start_guest_run": + signature = signature.replace("Outcome", "~outcome.Outcome") + if name == "trio.lowlevel.RunVar": # Typevar is not useful here. 
+ signature = signature.replace(": ~trio._core._local.T", "") + if "_NoValue" in signature: + # Strip the type from the union, make it look like = ... + signature = signature.replace(" | type[trio._core._local._NoValue]", "") + signature = signature.replace("", "...") + + return signature, return_annotation + + # XX hack the RTD theme until # https://github.com/rtfd/sphinx_rtd_theme/pull/382 # is shipped (should be in the release after 0.2.4) @@ -73,6 +103,7 @@ # though. def setup(app): app.add_css_file("hackrtd.css") + app.connect("autodoc-process-signature", autodoc_process_signature) # -- General configuration ------------------------------------------------ @@ -90,7 +121,9 @@ def setup(app): "sphinx.ext.coverage", "sphinx.ext.napoleon", "sphinxcontrib_trio", + "sphinxcontrib.jquery", "local_customization", + "typevars", ] intersphinx_mapping = { diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 7d66ae711d..6189814b3f 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -286,13 +286,13 @@ Code formatting ~~~~~~~~~~~~~~~ Instead of wasting time arguing about code formatting, we use `black -`__ to automatically format all our -code to a standard style. While you're editing code you can be as -sloppy as you like about whitespace; and then before you commit, just -run:: +`__ as well as other tools to automatically +format all our code to a standard style. While you're editing code you +can be as sloppy as you like about whitespace; and then before you commit, +just run:: - pip install -U black - black setup.py trio + pip install -U pre-commit + pre-commit to fix it up. (And don't worry if you forget – when you submit a pull request then we'll automatically check and remind you.) Hopefully this @@ -300,6 +300,17 @@ will let you focus on more important style issues like choosing good names, writing useful comments, and making sure your docstrings are nicely formatted. (black doesn't reformat comments or docstrings.) +If you would like, you can even have pre-commit run before you commit by +running:: + + pre-commit install + +and now pre-commit will run before git commits. You can uninstall the +pre-commit hook at any time by running:: + + pre-commit uninstall + + Very occasionally, you'll want to override black formatting. To do so, you can can add ``# fmt: off`` and ``# fmt: on`` comments. @@ -311,6 +322,11 @@ If you want to see what changes black will make, you can use:: in-place.) +Additionally, in some cases it is necessary to disable isort changing the +order of imports. To do so you can add ``# isort: split`` comments. +For more information, please see `isort's docs `__. + + .. _pull-request-release-notes: Release notes diff --git a/docs/source/history.rst b/docs/source/history.rst index bcf90c79ab..24eeb57261 100644 --- a/docs/source/history.rst +++ b/docs/source/history.rst @@ -5,6 +5,45 @@ Release history .. towncrier release notes start +Trio 0.22.2 (2023-07-13) +------------------------ + +Bugfixes +~~~~~~~~ + +- Fix ``PermissionError`` when importing `trio` due to trying to access ``pthread``. (`#2688 `__) + + +Trio 0.22.1 (2023-07-02) +------------------------ + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Timeout functions now raise `ValueError` if passed `math.nan`. This includes `trio.sleep`, `trio.sleep_until`, `trio.move_on_at`, `trio.move_on_after`, `trio.fail_at` and `trio.fail_after`. 
(`#2493 `__) + + +Features +~~~~~~~~ + +- Added support for naming threads created with `trio.to_thread.run_sync`, requires pthreads so is only available on POSIX platforms with glibc installed. (`#1148 `__) +- `trio.socket.socket` now prints the address it tried to connect to upon failure. (`#1810 `__) + + +Bugfixes +~~~~~~~~ + +- Fixed a crash that can occur when running Trio within an embedded Python interpreter, by handling the `TypeError` that is raised when trying to (re-)install a C signal handler. (`#2333 `__) +- Fix :func:`sniffio.current_async_library` when Trio tasks are spawned from a non-Trio context (such as when using trio-asyncio). Previously, a regular Trio task would inherit the non-Trio library name, and spawning a system task would cause the non-Trio caller to start thinking it was Trio. (`#2462 `__) +- Issued a new release as in the git tag for 0.22.0, ``trio.__version__`` is incorrectly set to 0.21.0+dev. (`#2485 `__) + + +Improved documentation +~~~~~~~~~~~~~~~~~~~~~~ + +- Documented that :obj:`Nursery.start_soon` does not guarantee task ordering. (`#970 `__) + + Trio 0.22.0 (2022-09-28) ------------------------ diff --git a/docs/source/index.rst b/docs/source/index.rst index 84d81880af..fc13227c3a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,7 +45,7 @@ Vital statistics: * Supported environments: We test on - - Python: 3.7+ (CPython and PyPy) + - Python: 3.8+ (CPython and PyPy) - Windows, macOS, Linux (glibc and musl), FreeBSD Other environments might also work; give it a try and see. diff --git a/docs/source/local_customization.py b/docs/source/local_customization.py index a970ad6e22..f071b6dfbb 100644 --- a/docs/source/local_customization.py +++ b/docs/source/local_customization.py @@ -1,11 +1,11 @@ -from docutils.parsers.rst import directives +from docutils.parsers.rst import directives as directives # noqa: F401 from sphinx import addnodes from sphinx.domains.python import PyClasslike -from sphinx.ext.autodoc import ( - FunctionDocumenter, - MethodDocumenter, - ClassLevelDocumenter, - Options, +from sphinx.ext.autodoc import ( # noqa: F401 + ClassLevelDocumenter as ClassLevelDocumenter, + FunctionDocumenter as FunctionDocumenter, + MethodDocumenter as MethodDocumenter, + Options as Options, ) """ diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 055cd9dd18..6d04d3ce4a 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -693,7 +693,7 @@ Errors in multiple child tasks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Normally, in Python, only one thing happens at a time, which means -that only one thing can wrong at a time. Trio has no such +that only one thing can go wrong at a time. Trio has no such limitation. Consider code like:: async def broken1(): @@ -916,12 +916,19 @@ The nursery API .. autoclass:: Nursery() - :members: + :members: child_tasks, parent_task + + .. automethod:: start(async_fn, *args, name = None) + + .. automethod:: start_soon(async_fn, *args, name = None) .. attribute:: TASK_STATUS_IGNORED + :type: TaskStatus - See :meth:`~Nursery.start`. + See :meth:`Nursery.start`. +.. autoclass:: TaskStatus(Protocol[StatusT]) + :members: .. _task-local-storage: @@ -974,12 +981,8 @@ work. What we need is something that's *like* a global variable, but that can have different values depending on which request handler is accessing it. -To solve this problem, Python 3.7 added a new module to the standard -library: :mod:`contextvars`. 
And not only does Trio have built-in -support for :mod:`contextvars`, but if you're using an earlier version -of Python, then Trio makes sure that a backported version of -:mod:`contextvars` is installed. So you can assume :mod:`contextvars` -is there and works regardless of what version of Python you're using. +To solve this problem, Python has a module in the standard +library: :mod:`contextvars`. Here's a toy example demonstrating how to use :mod:`contextvars`: @@ -1009,7 +1012,7 @@ Example output (yours may differ slightly): request 0: Request received finished For more information, read the -`contextvars docs `__. +`contextvars docs `__. .. _synchronization: @@ -1096,6 +1099,8 @@ Broadcasting an event with :class:`Event` .. autoclass:: Event :members: +.. autoclass:: EventStatistics + :members: .. _channels: @@ -1169,7 +1174,7 @@ the previous version, and then exits cleanly. The only change is the addition of ``async with`` blocks inside the producer and consumer: .. literalinclude:: reference-core/channels-shutdown.py - :emphasize-lines: 10,15 + :emphasize-lines: 11,17 The really important thing here is the producer's ``async with`` . When the producer exits, this closes the ``send_channel``, and that @@ -1248,7 +1253,7 @@ Fortunately, there's a better way! Here's a fixed version of our program above: .. literalinclude:: reference-core/channels-mpmc-fixed.py - :emphasize-lines: 7, 9, 10, 12, 13 + :emphasize-lines: 8, 10, 11, 13, 14 This example demonstrates using the `MemorySendChannel.clone` and `MemoryReceiveChannel.clone` methods. What these do is create copies @@ -1456,6 +1461,16 @@ don't have any special access to Trio's internals.) .. autoclass:: Condition :members: +These primitives return statistics objects that can be inspected. + +.. autoclass:: CapacityLimiterStatistics + :members: + +.. autoclass:: LockStatistics + :members: + +.. autoclass:: ConditionStatistics + :members: .. _async-generators: diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index a3291ef2ae..e270033b46 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -304,6 +304,9 @@ unfortunately that's not yet possible. .. automethod:: statistics +.. autoclass:: DTLSChannelStatistics + :members: + .. module:: trio.socket Low-level networking with :mod:`trio.socket` @@ -501,6 +504,14 @@ Socket objects * :meth:`~socket.socket.set_inheritable` * :meth:`~socket.socket.get_inheritable` +The internal SocketType +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: _SocketType +.. + TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` + TODO: rewrite ... all of the above when fixing _SocketType vs SocketType + + .. currentmodule:: trio @@ -634,9 +645,11 @@ Asynchronous path objects Asynchronous file objects ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: open_file +.. Suppress type annotations here, they refer to lots of internal types. + The normal Python docs go into better detail. +.. autofunction:: open_file(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=None, opener=None) -.. autofunction:: wrap_file +.. autofunction:: wrap_file(file) .. interface:: Asynchronous file interface @@ -718,7 +731,11 @@ task and interact with it while it's running: .. autofunction:: trio.run_process -.. autoclass:: trio.Process +.. autoclass:: trio._subprocess.HasFileno(Protocol) + + .. automethod:: fileno + +.. autoclass:: trio.Process() .. 
autoattribute:: returncode diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 815cff2ddf..712a36ad04 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -49,7 +49,11 @@ attributes, :meth:`trio.Lock.statistics`, etc.). Here are some more. Global statistics ----------------- -.. autofunction:: current_statistics +.. function:: current_statistics() -> RunStatistics + + Returns an object containing run-loop-level debugging information: + +.. autoclass:: RunStatistics() The current clock @@ -378,6 +382,8 @@ Wait queue abstraction :members: :undoc-members: +.. autoclass:: ParkingLotStatistics + :members: Low-level checkpoint functions ------------------------------ @@ -532,7 +538,6 @@ Task API putting a task to sleep and then waking it up again. (See :func:`wait_task_rescheduled` for details.) - .. _guest-mode: Using "guest mode" to run Trio on top of other event loops diff --git a/docs/source/releasing.rst b/docs/source/releasing.rst index 27cee864c0..0fe51370d5 100644 --- a/docs/source/releasing.rst +++ b/docs/source/releasing.rst @@ -29,7 +29,7 @@ Things to do for releasing: - review history change - - ``git rm`` changes + - ``git rm`` the now outdated newfragments + commit @@ -53,4 +53,10 @@ Things to do for releasing: * merge the release pull request +* make a GitHub release (go to the tag and press "Create release from tag") + + + paste in the new content in ``history.rst`` and convert it to markdown: turn the parts under section into ``---``, update links to just be the links, and whatever else is necessary. + + + include anything else that might be pertinent, like a link to the commits between the latest and current release. + * announce on gitter diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 19289ca991..0584446fb7 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -88,7 +88,7 @@ Okay, ready? Let's get started. Before you begin ---------------- -1. Make sure you're using Python 3.7 or newer. +1. Make sure you're using Python 3.8 or newer. 2. ``python3 -m pip install --upgrade trio`` (or on Windows, maybe ``py -3 -m pip install --upgrade trio`` – `details @@ -436,15 +436,15 @@ Now that we understand ``async with``, let's look at ``parent`` again: :end-at: all done! There are only 4 lines of code that really do anything here. On line -17, we use :func:`trio.open_nursery` to get a "nursery" object, and +20, we use :func:`trio.open_nursery` to get a "nursery" object, and then inside the ``async with`` block we call ``nursery.start_soon`` twice, -on lines 19 and 22. There are actually two ways to call an async +on lines 22 and 25. There are actually two ways to call an async function: the first one is the one we already saw, using ``await async_fn()``; the new one is ``nursery.start_soon(async_fn)``: it asks Trio to start running this async function, *but then returns immediately without waiting for the function to finish*. So after our two calls to ``nursery.start_soon``, ``child1`` and ``child2`` are now running in the -background. And then at line 25, the commented line, we hit the end of +background. And then at line 28, the commented line, we hit the end of the ``async with`` block, and the nursery's ``__aexit__`` function runs. What this does is force ``parent`` to stop here and wait for all the children in the nursery to exit. 
This is why you have to use diff --git a/docs/source/typevars.py b/docs/source/typevars.py new file mode 100644 index 0000000000..ab492b98b8 --- /dev/null +++ b/docs/source/typevars.py @@ -0,0 +1,103 @@ +"""Transform references to typevars to avoid missing reference errors. + +See https://github.com/sphinx-doc/sphinx/issues/7722 also. +""" +from __future__ import annotations + +import re +from pathlib import Path + +from sphinx.addnodes import Element, pending_xref +from sphinx.application import Sphinx +from sphinx.environment import BuildEnvironment +from sphinx.errors import NoUri + +import trio + + +def identify_typevars(trio_folder: Path) -> None: + """Record all typevars in trio.""" + for filename in trio_folder.rglob("*.py"): + with open(filename, encoding="utf8") as f: + for line in f: + # A simple regex should be sufficient to find them all, no need to actually parse. + match = re.search( + r"\b(TypeVar|TypeVarTuple|ParamSpec)\(['\"]([^'\"]+)['\"]", + line, + ) + if match is not None: + relative = "trio" / filename.relative_to(trio_folder) + relative = relative.with_suffix("") + if relative.name == "__init__": # Package, remove. + relative = relative.parent + kind = match.group(1) + name = match.group(2) + typevars_qualified[f'{".".join(relative.parts)}.{name}'] = kind + existing = typevars_named.setdefault(name, kind) + if existing != kind: + print("Mismatch: {} = {}, {}", name, existing, kind) + + +# All our typevars, so we can suppress reference errors for them. +typevars_qualified: dict[str, str] = {} +typevars_named: dict[str, str] = {} + + +def lookup_reference( + app: Sphinx, + env: BuildEnvironment, + node: pending_xref, + contnode: Element, +) -> Element | None: + """Handle missing references.""" + # If this is a typing_extensions object, redirect to typing. + # Most things there are backports, so the stdlib docs should have an entry. + target: str = node["reftarget"] + if target.startswith("typing_extensions."): + new_node = node.copy() + new_node["reftarget"] = f"typing.{target[18:]}" + # This fires off this same event, with our new modified node in order to fetch the right + # URL to use. + return app.emit_firstresult( + "missing-reference", + env, + new_node, + contnode, + allowed_exceptions=(NoUri,), + ) + + try: + typevar_type = typevars_qualified[target] + except KeyError: + # Imports might mean the typevar was defined in a different module or something. + # Fall back to checking just by name. + dot = target.rfind(".") + stem = target[dot + 1 :] if dot >= 0 else target + try: + typevar_type = typevars_named[stem] + except KeyError: + # Let other handlers deal with this name, it's not a typevar. + return None + + # Found a typevar. Redirect to the stdlib docs for that kind of var. + new_node = node.copy() + new_node["reftarget"] = f"typing.{typevar_type}" + new_node = app.emit_firstresult( + "missing-reference", + env, + new_node, + contnode, + allowed_exceptions=(NoUri,), + ) + reftitle = new_node["reftitle"] + # Is normally "(in Python 3.XX)", make it say typevar/paramspec/etc + paren = "(" if reftitle.startswith("(") else "" + new_node["reftitle"] = f"{paren}{typevar_type}, {reftitle.lstrip('(')}" + # Add a CSS class, for restyling. 
+ new_node["classes"].append("typevarref") + return new_node + + +def setup(app: Sphinx) -> None: + identify_typevars(Path(trio.__file__).parent) + app.connect("missing-reference", lookup_reference, -10) diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 31eeef1cd0..0000000000 --- a/mypy.ini +++ /dev/null @@ -1,25 +0,0 @@ -[mypy] -# TODO: run mypy against several OS/version combos in CI -# https://mypy.readthedocs.io/en/latest/command_line.html#platform-configuration - -# Be flexible about dependencies that don't have stubs yet (like pytest) -ignore_missing_imports = True - -# Be strict about use of Mypy -warn_unused_ignores = True -warn_unused_configs = True -warn_redundant_casts = True -warn_return_any = True - -# Avoid subtle backsliding -#disallow_any_decorated = True -#disallow_incomplete_defs = True -#disallow_subclassing_any = True - -# Enable gradually / for new modules -check_untyped_defs = False -disallow_untyped_calls = False -disallow_untyped_defs = False - -# DO NOT use `ignore_errors`; it doesn't apply -# downstream and users have to deal with them. diff --git a/newsfragments/1148.feature.rst b/newsfragments/1148.feature.rst deleted file mode 100644 index 51f2b792c3..0000000000 --- a/newsfragments/1148.feature.rst +++ /dev/null @@ -1 +0,0 @@ -Added support for naming threads created with `trio.to_thread.run_sync`, requires pthreads so is only available on POSIX platforms with glibc installed. diff --git a/newsfragments/1810.feature.rst b/newsfragments/1810.feature.rst deleted file mode 100644 index a2599d32b0..0000000000 --- a/newsfragments/1810.feature.rst +++ /dev/null @@ -1 +0,0 @@ -`trio.socket.socket` now prints the address it tried to connect to upon failure. diff --git a/newsfragments/2333.bugfix.rst b/newsfragments/2333.bugfix.rst deleted file mode 100644 index a0f4c9fd37..0000000000 --- a/newsfragments/2333.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fixed a crash that can occur when running Trio within an embedded Python interpreter, by handling the `TypeError` that is raised when trying to (re-)install a C signal handler. diff --git a/newsfragments/2462.bugfix.rst b/newsfragments/2462.bugfix.rst deleted file mode 100644 index 9208289ed9..0000000000 --- a/newsfragments/2462.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -Fix :func:`sniffio.current_async_library` when Trio tasks are spawned from a non-Trio context (such as when using trio-asyncio). Previously, a regular Trio task would inherit the non-Trio library name, and spawning a system task would cause the non-Trio caller to start thinking it was Trio. diff --git a/newsfragments/2668.removal.rst b/newsfragments/2668.removal.rst new file mode 100644 index 0000000000..512f681077 --- /dev/null +++ b/newsfragments/2668.removal.rst @@ -0,0 +1 @@ +Drop support for Python3.7 and PyPy3.7/3.8. diff --git a/newsfragments/2696.feature.rst b/newsfragments/2696.feature.rst new file mode 100644 index 0000000000..560cf3b365 --- /dev/null +++ b/newsfragments/2696.feature.rst @@ -0,0 +1,4 @@ +:func:`trio.lowlevel.start_guest_run` now does a bit more setup of the guest run +before it returns to its caller, so that the caller can immediately make calls to +:func:`trio.current_time`, :func:`trio.lowlevel.spawn_system_task`, +:func:`trio.lowlevel.current_trio_token`, etc. 
diff --git a/newsfragments/2700.misc.rst b/newsfragments/2700.misc.rst new file mode 100644 index 0000000000..a70924816e --- /dev/null +++ b/newsfragments/2700.misc.rst @@ -0,0 +1,4 @@ +Trio now indicates its presence to `sniffio` using the ``sniffio.thread_local`` +interface that is preferred since sniffio v1.3.0. This should be less likely +than the previous approach to cause :func:`sniffio.current_async_library` to +return incorrect results due to unintended inheritance of contextvars. diff --git a/newsfragments/970.doc.rst b/newsfragments/970.doc.rst deleted file mode 100644 index 6e114abf5b..0000000000 --- a/newsfragments/970.doc.rst +++ /dev/null @@ -1 +0,0 @@ -Documented that :obj:`Nursery.start_soon` does not guarantee task ordering. diff --git a/newsfragments/README.rst b/newsfragments/README.rst index 349e67eec0..52dc0716bb 100644 --- a/newsfragments/README.rst +++ b/newsfragments/README.rst @@ -14,6 +14,7 @@ Each file should be named like ``..rst``, where deprecated features after an appropriate time, go in the ``deprecated`` category instead) * ``feature``: any new feature that doesn't qualify for ``headline`` +* ``removal``: removing support for old python versions, or other removals with no deprecation period. * ``bugfix`` * ``doc`` * ``deprecated`` diff --git a/notes-to-self/afd-lab.py b/notes-to-self/afd-lab.py index ed420dbdbd..600975482c 100644 --- a/notes-to-self/afd-lab.py +++ b/notes-to-self/afd-lab.py @@ -77,22 +77,27 @@ # matter, energy, and life which lie close at hand yet can never be detected # with the senses we have." -import sys import os.path +import sys + sys.path.insert(0, os.path.abspath(os.path.dirname(__file__) + r"\..")) import trio + print(trio.__file__) -import trio.testing import socket +import trio.testing +from trio._core._io_windows import _afd_helper_handle, _check, _get_base_socket from trio._core._windows_cffi import ( - ffi, kernel32, AFDPollFlags, IoControlCodes, ErrorCodes -) -from trio._core._io_windows import ( - _get_base_socket, _afd_helper_handle, _check + AFDPollFlags, + ErrorCodes, + IoControlCodes, + ffi, + kernel32, ) + class AFDLab: def __init__(self): self._afd = _afd_helper_handle() @@ -173,4 +178,5 @@ async def main(): await trio.sleep(2) nursery.cancel_scope.cancel() + trio.run(main) diff --git a/notes-to-self/aio-guest-test.py b/notes-to-self/aio-guest-test.py index b64a11bd04..17d4bfb9e0 100644 --- a/notes-to-self/aio-guest-test.py +++ b/notes-to-self/aio-guest-test.py @@ -1,10 +1,13 @@ import asyncio + import trio + async def aio_main(): loop = asyncio.get_running_loop() trio_done_fut = loop.create_future() + def trio_done_callback(main_outcome): print(f"trio_main finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) @@ -35,6 +38,7 @@ async def trio_main(): if n >= 10: return + async def aio_pingpong(from_trio, to_trio): print("aio_pingpong!") diff --git a/notes-to-self/atomic-local.py b/notes-to-self/atomic-local.py index 212c9eef00..429211eaf6 100644 --- a/notes-to-self/atomic-local.py +++ b/notes-to-self/atomic-local.py @@ -3,9 +3,11 @@ # Has to be a string :-( sentinel = "_unique_name" + def f(): print(locals()) + # code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring, # constants, names, varnames, filename, name, firstlineno, # lnotab[, freevars[, cellvars]]) diff --git a/notes-to-self/blocking-read-hack.py b/notes-to-self/blocking-read-hack.py index b301058e85..f4a73f876d 100644 --- a/notes-to-self/blocking-read-hack.py +++ b/notes-to-self/blocking-read-hack.py @@ -1,13 +1,16 @@ -import 
trio +import errno import os import socket -import errno + +import trio bad_socket = socket.socket() + class BlockingReadTimeoutError(Exception): pass + async def blocking_read_with_timeout(fd, count, timeout): print("reading from fd", fd) cancel_requested = False @@ -42,4 +45,5 @@ async def kill_it_after_timeout(new_fd): finally: os.close(new_fd) + trio.run(blocking_read_with_timeout, 0, 10, 2) diff --git a/notes-to-self/estimate-task-size.py b/notes-to-self/estimate-task-size.py index 1e8597ba42..0010c7a2b4 100644 --- a/notes-to-self/estimate-task-size.py +++ b/notes-to-self/estimate-task-size.py @@ -1,15 +1,18 @@ # Little script to get a rough estimate of how much memory each task takes import resource + import trio import trio.testing LOW = 1000 HIGH = 10000 + async def tinytask(): await trio.sleep_forever() + async def measure(count): async with trio.open_nursery() as nursery: for _ in range(count): @@ -23,8 +26,8 @@ async def main(): low_usage = await measure(LOW) high_usage = await measure(HIGH + LOW) - print("Memory usage per task:", - (high_usage.ru_maxrss - low_usage.ru_maxrss) / HIGH) + print("Memory usage per task:", (high_usage.ru_maxrss - low_usage.ru_maxrss) / HIGH) print("(kilobytes on Linux, bytes on macOS)") + trio.run(main) diff --git a/notes-to-self/fbsd-pipe-close-notify.py b/notes-to-self/fbsd-pipe-close-notify.py index 7b18f65d6f..ab17f94c3f 100644 --- a/notes-to-self/fbsd-pipe-close-notify.py +++ b/notes-to-self/fbsd-pipe-close-notify.py @@ -4,9 +4,8 @@ # # Upstream bug: https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350 -import select import os -import threading +import select r, w = os.pipe() diff --git a/notes-to-self/file-read-latency.py b/notes-to-self/file-read-latency.py index 9af1b7222d..132e29dc4f 100644 --- a/notes-to-self/file-read-latency.py +++ b/notes-to-self/file-read-latency.py @@ -8,7 +8,7 @@ # ns per call, instead of ~500 ns/call for the syscall and related overhead. # That's probably more fair -- the BufferedIOBase code can't service random # accesses, even if your working set fits entirely in RAM. -f = open("/etc/passwd", "rb")#, buffering=0) +f = open("/etc/passwd", "rb") # , buffering=0) while True: start = time.perf_counter() @@ -23,5 +23,8 @@ both = (between - start) / COUNT * 1e9 seek = (end - between) / COUNT * 1e9 read = both - seek - print("{:.2f} ns/(seek+read), {:.2f} ns/seek, estimate ~{:.2f} ns/read" - .format(both, seek, read)) + print( + "{:.2f} ns/(seek+read), {:.2f} ns/seek, estimate ~{:.2f} ns/read".format( + both, seek, read + ) + ) diff --git a/notes-to-self/graceful-shutdown-idea.py b/notes-to-self/graceful-shutdown-idea.py index 792344de02..b454d7610a 100644 --- a/notes-to-self/graceful-shutdown-idea.py +++ b/notes-to-self/graceful-shutdown-idea.py @@ -1,5 +1,6 @@ import trio + class GracefulShutdownManager: def __init__(self): self._shutting_down = False @@ -21,6 +22,7 @@ def cancel_on_graceful_shutdown(self): def shutting_down(self): return self._shutting_down + # Code can check gsm.shutting_down occasionally at appropriate points to see # if it should exit. 
# @@ -31,9 +33,11 @@ async def stream_handler(stream): while True: with gsm.cancel_on_graceful_shutdown(): data = await stream.receive_some() + print(f"{data = }") if gsm.shutting_down: break + # To trigger the shutdown: async def listen_for_shutdown_signals(): with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as signal_aiter: diff --git a/notes-to-self/how-does-windows-so-reuseaddr-work.py b/notes-to-self/how-does-windows-so-reuseaddr-work.py index 4865ea17b3..3189d4d594 100644 --- a/notes-to-self/how-does-windows-so-reuseaddr-work.py +++ b/notes-to-self/how-does-windows-so-reuseaddr-work.py @@ -4,12 +4,13 @@ # # See https://github.com/python-trio/trio/issues/928 for details and context -import socket import errno +import socket modes = ["default", "SO_REUSEADDR", "SO_EXCLUSIVEADDRUSE"] bind_types = ["wildcard", "specific"] + def sock(mode): s = socket.socket(family=socket.AF_INET) if mode == "SO_REUSEADDR": @@ -18,6 +19,7 @@ def sock(mode): s.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) return s + def bind(sock, bind_type): if bind_type == "wildcard": sock.bind(("0.0.0.0", 12345)) @@ -26,6 +28,7 @@ def bind(sock, bind_type): else: assert False + def table_entry(mode1, bind_type1, mode2, bind_type2): with sock(mode1) as sock1: bind(sock1, bind_type1) @@ -41,19 +44,22 @@ def table_entry(mode1, bind_type1, mode2, bind_type2): else: return "Success" -print(""" + +print( + """ second bind | """ -+ " | ".join(["%-19s" % mode for mode in modes]) + + " | ".join(["%-19s" % mode for mode in modes]) ) -print(""" """, end='') +print(""" """, end="") for mode in modes: - print(" | " + " | ".join(["%8s" % bind_type for bind_type in bind_types]), end='') + print(" | " + " | ".join(["%8s" % bind_type for bind_type in bind_types]), end="") -print(""" +print( + """ first bind -----------------------------------------------------------------""" -# default | wildcard | INUSE | Success | ACCESS | Success | INUSE | Success + # default | wildcard | INUSE | Success | ACCESS | Success | INUSE | Success ) for i, mode1 in enumerate(modes): @@ -63,6 +69,8 @@ def table_entry(mode1, bind_type1, mode2, bind_type2): for l, bind_type2 in enumerate(bind_types): entry = table_entry(mode1, bind_type1, mode2, bind_type2) row.append(entry) - #print(mode1, bind_type1, mode2, bind_type2, entry) - print("{:>19} | {:>8} | ".format(mode1, bind_type1) - + " | ".join(["%8s" % entry for entry in row])) + # print(mode1, bind_type1, mode2, bind_type2, entry) + print( + f"{mode1:>19} | {bind_type1:>8} | " + + " | ".join(["%8s" % entry for entry in row]) + ) diff --git a/notes-to-self/loopy.py b/notes-to-self/loopy.py index 9f893590bd..0297a32dd8 100644 --- a/notes-to-self/loopy.py +++ b/notes-to-self/loopy.py @@ -1,6 +1,8 @@ -import trio import time +import trio + + async def loopy(): try: while True: @@ -9,10 +11,12 @@ async def loopy(): except KeyboardInterrupt: print("KI!") + async def main(): async with trio.open_nursery() as nursery: nursery.start_soon(loopy) nursery.start_soon(loopy) nursery.start_soon(loopy) + trio.run(main) diff --git a/notes-to-self/lots-of-tasks.py b/notes-to-self/lots-of-tasks.py index fca2741de9..048c69a7ec 100644 --- a/notes-to-self/lots-of-tasks.py +++ b/notes-to-self/lots-of-tasks.py @@ -1,12 +1,15 @@ import sys + import trio (COUNT_STR,) = sys.argv[1:] COUNT = int(COUNT_STR) + async def main(): async with trio.open_nursery() as nursery: for _ in range(COUNT): nursery.start_soon(trio.sleep, 1) + trio.run(main) diff --git a/notes-to-self/manual-signal-handler.py 
b/notes-to-self/manual-signal-handler.py index 39ffeb5a4b..e1b5ee3036 100644 --- a/notes-to-self/manual-signal-handler.py +++ b/notes-to-self/manual-signal-handler.py @@ -3,16 +3,20 @@ if os.name == "nt": import cffi + ffi = cffi.FFI() - ffi.cdef(""" + ffi.cdef( + """ void* WINAPI GetProcAddress(void* hModule, char* lpProcName); typedef void (*PyOS_sighandler_t)(int); - """) + """ + ) kernel32 = ffi.dlopen("kernel32.dll") PyOS_getsig_ptr = kernel32.GetProcAddress( - ffi.cast("void*", sys.dllhandle), b"PyOS_getsig") + ffi.cast("void*", sys.dllhandle), b"PyOS_getsig" + ) PyOS_getsig = ffi.cast("PyOS_sighandler_t (*)(int)", PyOS_getsig_ptr) - import signal + PyOS_getsig(signal.SIGINT)(signal.SIGINT) diff --git a/notes-to-self/measure-listen-backlog.py b/notes-to-self/measure-listen-backlog.py index dc32732dfe..b7253b86cc 100644 --- a/notes-to-self/measure-listen-backlog.py +++ b/notes-to-self/measure-listen-backlog.py @@ -1,5 +1,6 @@ import trio + async def run_test(nominal_backlog): print("--\nnominal:", nominal_backlog) @@ -22,5 +23,6 @@ async def run_test(nominal_backlog): for client_sock in client_socks: client_sock.close() + for nominal_backlog in [10, trio.socket.SOMAXCONN, 65535]: trio.run(run_test, nominal_backlog) diff --git a/notes-to-self/ntp-example.py b/notes-to-self/ntp-example.py index 44db8cc873..2bb9f80fb3 100644 --- a/notes-to-self/ntp-example.py +++ b/notes-to-self/ntp-example.py @@ -3,9 +3,11 @@ # - use the hostname "2.pool.ntp.org" # (see: https://news.ntppool.org/2011/06/continuing-ipv6-deployment/) -import trio -import struct import datetime +import struct + +import trio + def make_query_packet(): """Construct a UDP packet suitable for querying an NTP server to ask for @@ -27,6 +29,7 @@ def make_query_packet(): return packet + def extract_transmit_timestamp(ntp_packet): """Given an NTP packet, extract the "transmit timestamp" field, as a Python datetime.""" @@ -49,15 +52,16 @@ def extract_transmit_timestamp(ntp_packet): offset = datetime.timedelta(seconds=seconds + fraction / 2**32) return base_time + offset + async def main(): print("Our clock currently reads (in UTC):", datetime.datetime.utcnow()) # Look up some random NTP servers. # (See www.pool.ntp.org for information about the NTP pool.) 
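# Aside: a self-contained check of the timestamp arithmetic used in
# extract_transmit_timestamp() above. NTP timestamps count seconds since
# 1900-01-01 plus a 32-bit binary fraction; the inputs below are made up
# purely for illustration.
import datetime

base_time = datetime.datetime(1900, 1, 1)
seconds, fraction = 3_900_000_000, 2**31  # fraction / 2**32 == 0.5 s
print(base_time + datetime.timedelta(seconds=seconds + fraction / 2**32))
# -> a datetime in mid-2023, with the half-second fraction applied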
servers = await trio.socket.getaddrinfo( - "pool.ntp.org", # host - "ntp", # port - family=trio.socket.AF_INET, # IPv4 + "pool.ntp.org", # host + "ntp", # port + family=trio.socket.AF_INET, # IPv4 type=trio.socket.SOCK_DGRAM, # UDP ) @@ -66,7 +70,7 @@ async def main(): # Create a UDP socket udp_sock = trio.socket.socket( - family=trio.socket.AF_INET, # IPv4 + family=trio.socket.AF_INET, # IPv4 type=trio.socket.SOCK_DGRAM, # UDP ) @@ -88,4 +92,5 @@ async def main(): transmit_timestamp = extract_transmit_timestamp(data) print("Their clock read (in UTC):", transmit_timestamp) + trio.run(main) diff --git a/notes-to-self/proxy-benchmarks.py b/notes-to-self/proxy-benchmarks.py index a45d94d056..ea92e10c6f 100644 --- a/notes-to-self/proxy-benchmarks.py +++ b/notes-to-self/proxy-benchmarks.py @@ -3,6 +3,7 @@ methods = {"fileno"} + class Proxy1: strategy = "__getattr__" works_for = "any attr" @@ -15,8 +16,10 @@ def __getattr__(self, name): return getattr(self._wrapped, name) raise AttributeError(name) + ################################################################ + class Proxy2: strategy = "generated methods (getattr + closure)" works_for = "methods" @@ -24,16 +27,20 @@ class Proxy2: def __init__(self, wrapped): self._wrapped = wrapped + def add_wrapper(cls, method): def wrapper(self, *args, **kwargs): return getattr(self._wrapped, method)(*args, **kwargs) + setattr(cls, method, wrapper) + for method in methods: add_wrapper(Proxy2, method) ################################################################ + class Proxy3: strategy = "generated methods (exec)" works_for = "methods" @@ -41,20 +48,27 @@ class Proxy3: def __init__(self, wrapped): self._wrapped = wrapped + def add_wrapper(cls, method): - code = textwrap.dedent(""" + code = textwrap.dedent( + """ def wrapper(self, *args, **kwargs): return self._wrapped.{}(*args, **kwargs) - """.format(method)) + """.format( + method + ) + ) ns = {} exec(code, ns) setattr(cls, method, ns["wrapper"]) + for method in methods: add_wrapper(Proxy3, method) ################################################################ + class Proxy4: strategy = "generated properties (getattr + closure)" works_for = "any attr" @@ -62,6 +76,7 @@ class Proxy4: def __init__(self, wrapped): self._wrapped = wrapped + def add_wrapper(cls, attr): def getter(self): return getattr(self._wrapped, attr) @@ -74,11 +89,13 @@ def deleter(self): setattr(cls, attr, property(getter, setter, deleter)) + for method in methods: add_wrapper(Proxy4, method) ################################################################ + class Proxy5: strategy = "generated properties (exec)" works_for = "any attr" @@ -86,8 +103,10 @@ class Proxy5: def __init__(self, wrapped): self._wrapped = wrapped + def add_wrapper(cls, attr): - code = textwrap.dedent(""" + code = textwrap.dedent( + """ def getter(self): return self._wrapped.{attr} @@ -96,16 +115,21 @@ def setter(self, newval): def deleter(self): del self._wrapped.{attr} - """.format(attr=attr)) + """.format( + attr=attr + ) + ) ns = {} exec(code, ns) setattr(cls, attr, property(ns["getter"], ns["setter"], ns["deleter"])) + for method in methods: add_wrapper(Proxy5, method) ################################################################ + # methods only class Proxy6: strategy = "copy attrs from wrappee to wrapper" @@ -116,17 +140,19 @@ def __init__(self, wrapper): for method in methods: setattr(self, method, getattr(self._wrapper, method)) - + ################################################################ classes = [Proxy1, Proxy2, Proxy3, Proxy4, Proxy5, 
Proxy6] + def check(cls): with open("/etc/passwd") as f: p = cls(f) assert p.fileno() == f.fileno() + for cls in classes: check(cls) @@ -135,7 +161,7 @@ def check(cls): COUNT = 1000000 try: - import __pypy__ + import __pypy__ # noqa: F401 # __pypy__ imported but unused except ImportError: pass else: @@ -147,8 +173,7 @@ def check(cls): start = time.perf_counter() for _ in range(COUNT): obj.fileno() - #obj.fileno + # obj.fileno end = time.perf_counter() per_usec = COUNT / (end - start) / 1e6 - print("{:7.2f} / us: {} ({})" - .format(per_usec, obj.strategy, obj.works_for)) + print("{:7.2f} / us: {} ({})".format(per_usec, obj.strategy, obj.works_for)) diff --git a/notes-to-self/reopen-pipe.py b/notes-to-self/reopen-pipe.py index 910def397c..dbccd567d7 100644 --- a/notes-to-self/reopen-pipe.py +++ b/notes-to-self/reopen-pipe.py @@ -1,14 +1,15 @@ import os +import tempfile import threading import time -import tempfile + def check_reopen(r1, w): try: print("Reopening read end") - r2 = os.open("/proc/self/fd/{}".format(r1), os.O_RDONLY) + r2 = os.open(f"/proc/self/fd/{r1}", os.O_RDONLY) - print("r1 is {}, r2 is {}".format(r1, r2)) + print(f"r1 is {r1}, r2 is {r2}") print("checking they both can receive from w...") @@ -36,11 +37,12 @@ def check_reopen(r1, w): def sleep_then_write(): time.sleep(1) os.write(w, b"c") + threading.Thread(target=sleep_then_write, daemon=True).start() assert os.read(r1, 1) == b"c" print("r1 definitely seems to be in blocking mode") except Exception as exc: - print("ERROR: {!r}".format(exc)) + print(f"ERROR: {exc!r}") print("-- testing anonymous pipe --") @@ -63,6 +65,6 @@ def sleep_then_write(): print("-- testing socketpair --") import socket + rs, ws = socket.socketpair() check_reopen(rs.fileno(), ws.fileno()) - diff --git a/notes-to-self/schedule-timing.py b/notes-to-self/schedule-timing.py index c3093066e2..c84ec9a436 100644 --- a/notes-to-self/schedule-timing.py +++ b/notes-to-self/schedule-timing.py @@ -1,19 +1,22 @@ -import trio import time +import trio + LOOPS = 0 RUNNING = True + async def reschedule_loop(depth): if depth == 0: global LOOPS while RUNNING: LOOPS += 1 await trio.sleep(0) - #await trio.lowlevel.cancel_shielded_checkpoint() + # await trio.lowlevel.cancel_shielded_checkpoint() else: await reschedule_loop(depth - 1) + async def report_loop(): global RUNNING try: @@ -25,13 +28,15 @@ async def report_loop(): end_count = LOOPS loops = end_count - start_count duration = end_time - start_time - print("{} loops/sec".format(loops / duration)) + print(f"{loops / duration} loops/sec") finally: RUNNING = False + async def main(): async with trio.open_nursery() as nursery: nursery.start_soon(reschedule_loop, 10) nursery.start_soon(report_loop) + trio.run(main) diff --git a/notes-to-self/socket-scaling.py b/notes-to-self/socket-scaling.py index 1571be4d17..bd7e32ef7f 100644 --- a/notes-to-self/socket-scaling.py +++ b/notes-to-self/socket-scaling.py @@ -17,13 +17,16 @@ # # or similar. 
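# Why proxy-benchmarks.py above routes wrapper creation through the
# add_wrapper() helpers instead of defining each closure directly inside
# `for method in methods:`: Python closures capture variables, not
# values, so closures created in a loop body all see the final loop
# value. A standalone demonstration:
fns = [lambda: name for name in ("a", "b")]
print([f() for f in fns])  # ['b', 'b'] -- both share the loop variable

def make(name):
    return lambda: name  # the function call pins the value per wrapper

fns = [make(name) for name in ("a", "b")]
print([f() for f in fns])  # ['a', 'b']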
+import socket import time + import trio import trio.testing -import socket + async def main(): for total in [10, 100, 500, 1_000, 10_000, 20_000, 30_000]: + def pt(desc, *, count=total, item="socket"): nonlocal last_time now = time.perf_counter() @@ -53,4 +56,5 @@ def pt(desc, *, count=total, item="socket"): sock.close() pt("closing sockets") + trio.run(main) diff --git a/notes-to-self/socketpair-buffering.py b/notes-to-self/socketpair-buffering.py index dd3b1ad97d..5e77a709b7 100644 --- a/notes-to-self/socketpair-buffering.py +++ b/notes-to-self/socketpair-buffering.py @@ -32,6 +32,6 @@ except BlockingIOError: pass - print("setsockopt bufsize {}: {}".format(bufsize, i)) + print(f"setsockopt bufsize {bufsize}: {i}") a.close() b.close() diff --git a/notes-to-self/ssl-close-notify/ssl-close-notify.py b/notes-to-self/ssl-close-notify/ssl-close-notify.py index cd4b450de8..32ecbea2f0 100644 --- a/notes-to-self/ssl-close-notify/ssl-close-notify.py +++ b/notes-to-self/ssl-close-notify/ssl-close-notify.py @@ -22,6 +22,7 @@ client_done = threading.Event() + def server_thread_fn(): server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) server_ctx.load_cert_chain("trio-test-1.pem") @@ -42,13 +43,12 @@ def server_thread_fn(): break server.sendall(data) + server_thread = threading.Thread(target=server_thread_fn) server_thread.start() client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") -client = client_ctx.wrap_socket( - client_sock, - server_hostname="trio-test-1.example.org") +client = client_ctx.wrap_socket(client_sock, server_hostname="trio-test-1.example.org") # Now we have two SSLSockets that have established an encrypted connection diff --git a/notes-to-self/ssl-close-notify/ssl2.py b/notes-to-self/ssl-close-notify/ssl2.py index 32a68e1495..54ee1fb9b6 100644 --- a/notes-to-self/ssl-close-notify/ssl2.py +++ b/notes-to-self/ssl-close-notify/ssl2.py @@ -5,7 +5,7 @@ import ssl import threading -#client_sock, server_sock = socket.socketpair() +# client_sock, server_sock = socket.socketpair() listen_sock = socket.socket() listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) @@ -52,12 +52,12 @@ server.shutdown(socket.SHUT_WR) # Attempting to read/write to the fd after it's closed should raise EBADF -#os.close(server.fileno()) +# os.close(server.fileno()) # Attempting to read/write to an fd opened with O_DIRECT raises EINVAL in most # cases (unless you're very careful with alignment etc. 
which openssl isn't) -#os.dup2(os.open("/tmp/blah-example-file", os.O_RDWR | os.O_CREAT | os.O_DIRECT), server.fileno()) +# os.dup2(os.open("/tmp/blah-example-file", os.O_RDWR | os.O_CREAT | os.O_DIRECT), server.fileno()) # Sending or receiving server.sendall(b"hello") -#server.recv(10) +# server.recv(10) diff --git a/notes-to-self/ssl-handshake/ssl-handshake.py b/notes-to-self/ssl-handshake/ssl-handshake.py index 81d875be6a..e906bc2a87 100644 --- a/notes-to-self/ssl-handshake/ssl-handshake.py +++ b/notes-to-self/ssl-handshake/ssl-handshake.py @@ -1,5 +1,5 @@ -import ssl import socket +import ssl import threading from contextlib import contextmanager @@ -8,6 +8,7 @@ server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) server_ctx.load_cert_chain("trio-test-1.pem") + def _ssl_echo_serve_sync(sock): try: wrapped = server_ctx.wrap_socket(sock, server_side=True) @@ -20,16 +21,19 @@ def _ssl_echo_serve_sync(sock): except BrokenPipeError: pass + @contextmanager def echo_server_connection(): client_sock, server_sock = socket.socketpair() with client_sock, server_sock: t = threading.Thread( - target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True) + target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True + ) t.start() yield client_sock + class ManuallyWrappedSocket: def __init__(self, ctx, sock, **kwargs): self.incoming = ssl.MemoryBIO() @@ -82,21 +86,23 @@ def unwrap(self): def wrap_socket_via_wrap_socket(ctx, sock, **kwargs): return ctx.wrap_socket(sock, do_handshake_on_connect=False, **kwargs) + def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): return ManuallyWrappedSocket(ctx, sock, **kwargs) for wrap_socket in [ - wrap_socket_via_wrap_socket, - wrap_socket_via_wrap_bio, + wrap_socket_via_wrap_socket, + wrap_socket_via_wrap_bio, ]: - print("\n--- checking {} ---\n".format(wrap_socket.__name__)) + print(f"\n--- checking {wrap_socket.__name__} ---\n") print("checking with do_handshake + correct hostname...") with echo_server_connection() as client_sock: client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-1.example.org") + client_ctx, client_sock, server_hostname="trio-test-1.example.org" + ) wrapped.do_handshake() wrapped.sendall(b"x") assert wrapped.recv(1) == b"x" @@ -107,7 +113,8 @@ def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): with echo_server_connection() as client_sock: client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-2.example.org") + client_ctx, client_sock, server_hostname="trio-test-2.example.org" + ) try: wrapped.do_handshake() except Exception: @@ -119,7 +126,8 @@ def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): with echo_server_connection() as client_sock: client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-2.example.org") + client_ctx, client_sock, server_hostname="trio-test-2.example.org" + ) # We forgot to call do_handshake # But the hostname is wrong so something had better error out... 
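# For orientation, the general shape of the BIO pumping that a wrap_bio
# based wrapper like ManuallyWrappedSocket has to perform -- a sketch
# under assumptions (names and buffer size invented; SSLWantWriteError
# and EOF handling omitted), not the file's exact code:
import ssl

def pump(sock, sslobj, incoming, outgoing, fn, *args):
    # Retry an SSLObject operation, shuttling ciphertext between the
    # ssl.MemoryBIO pair and the real socket as needed.
    while True:
        try:
            result = fn(*args)
        except ssl.SSLWantReadError:
            data = outgoing.read()  # flush anything we produced
            if data:
                sock.sendall(data)
            incoming.write(sock.recv(65536))  # feed more ciphertext in
        else:
            data = outgoing.read()
            if data:
                sock.sendall(data)
            return result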
sent = b"x" diff --git a/notes-to-self/sslobject.py b/notes-to-self/sslobject.py index cfac98676e..a6e7b07a08 100644 --- a/notes-to-self/sslobject.py +++ b/notes-to-self/sslobject.py @@ -1,5 +1,5 @@ -from contextlib import contextmanager import ssl +from contextlib import contextmanager client_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) client_ctx.check_hostname = False @@ -15,6 +15,7 @@ soutb = ssl.MemoryBIO() sso = server_ctx.wrap_bio(sinb, soutb, server_side=True) + @contextmanager def expect(etype): try: @@ -22,7 +23,8 @@ def expect(etype): except etype: pass else: - raise AssertionError("expected {}".format(etype)) + raise AssertionError(f"expected {etype}") + with expect(ssl.SSLWantReadError): cso.do_handshake() diff --git a/notes-to-self/thread-closure-bug-demo.py b/notes-to-self/thread-closure-bug-demo.py index 514636a1b4..b09a87fe5f 100644 --- a/notes-to-self/thread-closure-bug-demo.py +++ b/notes-to-self/thread-closure-bug-demo.py @@ -8,18 +8,21 @@ COUNT = 100 + def slow_tracefunc(frame, event, arg): # A no-op trace function that sleeps briefly to make us more likely to hit # the race condition. time.sleep(0.01) return slow_tracefunc + def run_with_slow_tracefunc(fn): # settrace() only takes effect when you enter a new frame, so we need this # little dance: sys.settrace(slow_tracefunc) return fn() + def outer(): x = 0 # We hide the done variable inside a list, because we want to use it to @@ -46,13 +49,14 @@ def traced_looper(): t.start() for i in range(COUNT): - print("after {} increments, x is {}".format(i, x)) + print(f"after {i} increments, x is {x}") x += 1 time.sleep(0.01) done[0] = True t.join() - print("Final discrepancy: {} (should be 0)".format(COUNT - x)) + print(f"Final discrepancy: {COUNT - x} (should be 0)") + outer() diff --git a/notes-to-self/thread-dispatch-bench.py b/notes-to-self/thread-dispatch-bench.py index 1625efae17..70547a6000 100644 --- a/notes-to-self/thread-dispatch-bench.py +++ b/notes-to-self/thread-dispatch-bench.py @@ -5,16 +5,18 @@ # trio.to_thread.run_sync import threading -from queue import Queue import time +from queue import Queue COUNT = 10000 + def worker(in_q, out_q): while True: job = in_q.get() out_q.put(job()) + def main(): in_q = Queue() out_q = Queue() @@ -28,6 +30,7 @@ def main(): in_q.put(lambda: None) out_q.get() end = time.monotonic() - print("{:.2f} µs/job".format((end - start) / COUNT * 1e6)) + print(f"{(end - start) / COUNT * 1e6:.2f} µs/job") + main() diff --git a/notes-to-self/time-wait-windows-exclusiveaddruse.py b/notes-to-self/time-wait-windows-exclusiveaddruse.py index db3aaad08a..dcb4a27dd0 100644 --- a/notes-to-self/time-wait-windows-exclusiveaddruse.py +++ b/notes-to-self/time-wait-windows-exclusiveaddruse.py @@ -8,15 +8,17 @@ import socket from contextlib import contextmanager + @contextmanager def report_outcome(tagline): try: yield except OSError as exc: - print("{}: failed".format(tagline)) - print(" details: {!r}".format(exc)) + print(f"{tagline}: failed") + print(f" details: {exc!r}") else: - print("{}: succeeded".format(tagline)) + print(f"{tagline}: succeeded") + # Set up initial listening socket lsock = socket.socket() diff --git a/notes-to-self/time-wait.py b/notes-to-self/time-wait.py index e865a94982..772f6c2727 100644 --- a/notes-to-self/time-wait.py +++ b/notes-to-self/time-wait.py @@ -26,11 +26,12 @@ # Also, it must be set on listen2 before calling bind(), or it will conflict # with the lingering server1 socket. 
-import socket import errno +import socket import attr + @attr.s(repr=False) class Options: listen1_early = attr.ib(default=None) @@ -49,9 +50,10 @@ def describe(self): for f in attr.fields(self.__class__): value = getattr(self, f.name) if value is not None: - info.append("{}={}".format(f.name, value)) + info.append(f"{f.name}={value}") return "Set/unset: {}".format(", ".join(info)) + def time_wait(options): print(options.describe()) @@ -60,7 +62,7 @@ def time_wait(options): listen0 = socket.socket() listen0.bind(("127.0.0.1", 0)) sockaddr = listen0.getsockname() - #print(" ", sockaddr) + # print(" ", sockaddr) listen0.close() listen1 = socket.socket() @@ -98,6 +100,7 @@ def time_wait(options): else: print(" -> ok") + time_wait(Options()) time_wait(Options(listen1_early=True, server=True, listen2=True)) time_wait(Options(listen1_early=True)) diff --git a/notes-to-self/trace.py b/notes-to-self/trace.py index c024a36ba5..aa68fac125 100644 --- a/notes-to-self/trace.py +++ b/notes-to-self/trace.py @@ -1,8 +1,9 @@ -import trio -import os import json +import os from itertools import count +import trio + # Experiment with generating Chrome Event Trace format, which can be browsed # through chrome://tracing or other mechanisms. # @@ -29,6 +30,7 @@ # let us also show "task is running", because neither kind of event is # strictly nested inside the other + class Trace(trio.abc.Instrument): def __init__(self, out): self.out = out @@ -108,14 +110,14 @@ def task_scheduled(self, task): def before_io_wait(self, timeout): self._write( - name=f"I/O wait", + name="I/O wait", ph="B", tid=-1, ) def after_io_wait(self, timeout): self._write( - name=f"I/O wait", + name="I/O wait", ph="E", tid=-1, ) @@ -126,11 +128,13 @@ async def child1(): await trio.sleep(1) print(" child1: exiting!") + async def child2(): print(" child2: started! 
sleeping now...") await trio.sleep(1) print(" child2: exiting!") + async def parent(): print("parent: started!") async with trio.open_nursery() as nursery: @@ -144,5 +148,6 @@ async def parent(): # -- we exit the nursery block here -- print("parent: all done!") + t = Trace(open("/tmp/t.json", "w")) trio.run(parent, instruments=[t]) diff --git a/notes-to-self/trivial-err.py b/notes-to-self/trivial-err.py index ed11ec33e6..6c32617c74 100644 --- a/notes-to-self/trivial-err.py +++ b/notes-to-self/trivial-err.py @@ -1,26 +1,33 @@ import sys + import trio sys.stderr = sys.stdout + async def child1(): raise ValueError + async def child2(): async with trio.open_nursery() as nursery: nursery.start_soon(grandchild1) nursery.start_soon(grandchild2) + async def grandchild1(): raise KeyError + async def grandchild2(): raise NameError("Bob") + async def main(): async with trio.open_nursery() as nursery: nursery.start_soon(child1) nursery.start_soon(child2) - #nursery.start_soon(grandchild1) + # nursery.start_soon(grandchild1) + trio.run(main) diff --git a/notes-to-self/trivial.py b/notes-to-self/trivial.py index 6852d63200..405d92daf5 100644 --- a/notes-to-self/trivial.py +++ b/notes-to-self/trivial.py @@ -1,8 +1,10 @@ import trio + async def foo(): print("in foo!") return 3 + print("running!") print(trio.run(foo)) diff --git a/notes-to-self/wakeup-fd-racer.py b/notes-to-self/wakeup-fd-racer.py index c6ef6caec1..b56cbdc91c 100644 --- a/notes-to-self/wakeup-fd-racer.py +++ b/notes-to-self/wakeup-fd-racer.py @@ -1,19 +1,21 @@ +import itertools import os +import select import signal +import socket import threading import time -import socket -import select -import itertools # Equivalent to the C function raise(), which Python doesn't wrap if os.name == "nt": import cffi + _ffi = cffi.FFI() _ffi.cdef("int raise(int);") _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll") signal_raise = getattr(_lib, "raise") else: + def signal_raise(signum): # Use pthread_kill to make sure we're actually using the wakeup fd on # Unix @@ -26,7 +28,7 @@ def raise_SIGINT_soon(): # Sending 2 signals becomes reliable, as we'd expect (because we need # set-flags -> write-to-fd, and doing it twice does # write-to-fd -> set-flags -> write-to-fd -> set-flags) - #signal_raise(signal.SIGINT) + # signal_raise(signal.SIGINT) def drain(sock): @@ -87,8 +89,10 @@ def main(): # them. 
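# The mechanism this racer exercises, in isolation (POSIX sketch; on
# Windows set_wakeup_fd needs a socket rather than a pipe): the
# interpreter's C-level signal handler writes a byte to the wakeup fd,
# so a select() watching it wakes up even if the Python-level handler
# hasn't run yet.
import os
import select
import signal

signal.signal(signal.SIGINT, lambda *args: None)  # keep KeyboardInterrupt away
r, w = os.pipe()
os.set_blocking(w, False)  # set_wakeup_fd requires a non-blocking fd
signal.set_wakeup_fd(w)
signal.raise_signal(signal.SIGINT)
print(select.select([r], [], [], 1))  # ([r], [], []) -- wakeup byte arrived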
duration = time.perf_counter() - start if duration < 2: - print(f"Attempt {attempt}: OK, trying again " - f"(select_calls = {select_calls}, drained = {drained})") + print( + f"Attempt {attempt}: OK, trying again " + f"(select_calls = {select_calls}, drained = {drained})" + ) else: print(f"Attempt {attempt}: FAILED, took {duration} seconds") print(f"select_calls = {select_calls}, drained = {drained}") @@ -96,5 +100,6 @@ def main(): thread.join() + if __name__ == "__main__": main() diff --git a/notes-to-self/win-waitable-timer.py b/notes-to-self/win-waitable-timer.py index 92bfd7a39a..5309f43867 100644 --- a/notes-to-self/win-waitable-timer.py +++ b/notes-to-self/win-waitable-timer.py @@ -24,12 +24,12 @@ # make this fairly straightforward, but you obviously need to use a separate # time source -import cffi from datetime import datetime, timedelta, timezone -import time + +import cffi import trio -from trio._core._windows_cffi import (ffi, kernel32, raise_winerror) +from trio._core._windows_cffi import ffi, kernel32, raise_winerror try: ffi.cdef( @@ -91,7 +91,7 @@ LPFILETIME lpFileTime ); """, - override=True + override=True, ) ProcessLeapSecondInfo = 8 @@ -106,10 +106,10 @@ def set_leap_seconds_enabled(enabled): plsi.Flags = 0 plsi.Reserved = 0 if not kernel32.SetProcessInformation( - ffi.cast("HANDLE", -1), # current process - ProcessLeapSecondInfo, - plsi, - ffi.sizeof("PROCESS_LEAP_SECOND_INFO"), + ffi.cast("HANDLE", -1), # current process + ProcessLeapSecondInfo, + plsi, + ffi.sizeof("PROCESS_LEAP_SECOND_INFO"), ): raise_winerror() @@ -135,9 +135,7 @@ def now_as_filetime(): # https://www.epochconverter.com/ldap # FILETIME_TICKS_PER_SECOND = 10**7 -FILETIME_EPOCH = datetime.strptime( - '1601-01-01 00:00:00 Z', '%Y-%m-%d %H:%M:%S %z' -) +FILETIME_EPOCH = datetime.strptime("1601-01-01 00:00:00 Z", "%Y-%m-%d %H:%M:%S %z") # XXX THE ABOVE IS WRONG: # # https://techcommunity.microsoft.com/t5/networking-blog/leap-seconds-for-the-appdev-what-you-should-know/ba-p/339813# @@ -159,11 +157,9 @@ def now_as_filetime(): def py_datetime_to_win_filetime(dt): # We'll want to call this on every datetime as it comes in - #dt = dt.astimezone(timezone.utc) + # dt = dt.astimezone(timezone.utc) assert dt.tzinfo is timezone.utc - return round( - (dt - FILETIME_EPOCH).total_seconds() * FILETIME_TICKS_PER_SECOND - ) + return round((dt - FILETIME_EPOCH).total_seconds() * FILETIME_TICKS_PER_SECOND) async def main(): diff --git a/pyproject.toml b/pyproject.toml index 4b939510ef..17dd2aa1b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,160 @@ [tool.black] -target-version = ['py37'] +target-version = ['py38'] +force-exclude = ''' +( + ^/docs/source/reference-.* + | ^/docs/source/tutorial +) +''' +[tool.codespell] +ignore-words-list = 'astroid,crasher,asend' + +[tool.flake8] +extend-ignore = ['D', 'E', 'W', 'F403', 'F405', 'F821', 'F822'] +per-file-ignores = [ + 'trio/__init__.py: F401', + 'trio/_core/__init__.py: F401', + 'trio/_core/_tests/test_multierror_scripts/*: F401', + 'trio/abc.py: F401', + 'trio/lowlevel.py: F401', + 'trio/socket.py: F401', + 'trio/testing/__init__.py: F401' +] + +[tool.isort] +combine_as_imports = true +profile = "black" +skip_gitignore = true +skip_glob = [ + "docs/source/reference-*", + "docs/source/tutorial/*" +] + +[tool.mypy] +python_version = "3.8" + +# Be flexible about dependencies that don't have stubs yet (like pytest) +ignore_missing_imports = true + +# Be strict about use of Mypy +local_partial_types = true +warn_unused_ignores = true +warn_unused_configs = true 
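# What the strict flags above buy, concretely: with disallow_untyped_defs
# and disallow_incomplete_defs enabled, mypy rejects the first definition
# and accepts the second. (Hypothetical example, not from the trio tree.)
def f(x):  # error: Function is missing a type annotation
    return x

def g(x: int) -> int:  # fully annotated: accepted
    return x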
+warn_redundant_casts = true +warn_return_any = true + +# Avoid subtle backsliding +disallow_any_decorated = true +disallow_any_generics = true +disallow_any_unimported = false # Enable once Outcome has stubs. +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_untyped_defs = true + +# Enable once other problems are dealt with +check_untyped_defs = true +disallow_untyped_calls = false + +# files not yet fully typed +[[tool.mypy.overrides]] +module = [ +# 2761 +"trio/_core/_generated_io_windows", +"trio/_core/_io_windows", + +"trio/_signals", + +# internal +"trio/_windows_pipes", + +# tests +"trio/testing/_fake_net", +"trio/_core/_tests/test_asyncgen", +"trio/_core/_tests/test_guest_mode", +"trio/_core/_tests/test_instrumentation", +"trio/_core/_tests/test_ki", +"trio/_core/_tests/test_local", +"trio/_core/_tests/test_mock_clock", +"trio/_core/_tests/test_multierror", +"trio/_core/_tests/test_multierror_scripts/ipython_custom_exc", +"trio/_core/_tests/test_multierror_scripts/simple_excepthook", +"trio/_core/_tests/test_parking_lot", +"trio/_core/_tests/test_thread_cache", +"trio/_core/_tests/test_tutil", +"trio/_core/_tests/test_unbounded_queue", +"trio/_core/_tests/test_windows", +"trio/_core/_tests/tutil", +"trio/_tests/pytest_plugin", +"trio/_tests/test_abc", +"trio/_tests/test_channel", +"trio/_tests/test_deprecate", +"trio/_tests/test_dtls", +"trio/_tests/test_exports", +"trio/_tests/test_file_io", +"trio/_tests/test_highlevel_generic", +"trio/_tests/test_highlevel_open_tcp_listeners", +"trio/_tests/test_highlevel_open_tcp_stream", +"trio/_tests/test_highlevel_open_unix_stream", +"trio/_tests/test_highlevel_serve_listeners", +"trio/_tests/test_highlevel_socket", +"trio/_tests/test_highlevel_ssl_helpers", +"trio/_tests/test_path", +"trio/_tests/test_scheduler_determinism", +"trio/_tests/test_signals", +"trio/_tests/test_socket", +"trio/_tests/test_ssl", +"trio/_tests/test_subprocess", +"trio/_tests/test_sync", +"trio/_tests/test_testing", +"trio/_tests/test_threads", +"trio/_tests/test_timeouts", +"trio/_tests/test_tracing", +"trio/_tests/test_util", +"trio/_tests/test_wait_for_object", +"trio/_tests/test_windows_pipes", +"trio/_tests/tools/test_gen_exports", +] +check_untyped_defs = false +disallow_any_decorated = false +disallow_any_generics = false +disallow_any_unimported = false +disallow_incomplete_defs = false +disallow_untyped_defs = false + +[tool.pytest.ini_options] +addopts = ["--strict-markers", "--strict-config"] +faulthandler_timeout = 60 +filterwarnings = [ + "error", + # https://gitter.im/python-trio/general?at=63bb8d0740557a3d5c688d67 + 'ignore:You are using cryptography on a 32-bit Python on a 64-bit Windows Operating System. 
Cryptography will be significantly faster if you switch to using a 64-bit Python.:UserWarning', + # this should remain until https://github.com/pytest-dev/pytest/pull/10894 is merged + 'ignore:ast.Str is deprecated:DeprecationWarning', + 'ignore:Attribute s is deprecated and will be removed:DeprecationWarning', + 'ignore:ast.NameConstant is deprecated:DeprecationWarning', + 'ignore:ast.Num is deprecated:DeprecationWarning', + # https://github.com/python/mypy/issues/15330 + 'ignore:ast.Ellipsis is deprecated:DeprecationWarning', + 'ignore:ast.Bytes is deprecated:DeprecationWarning' +] +junit_family = "xunit2" +markers = ["redistributors_should_skip: tests that should be skipped by downstream redistributors"] +xfail_strict = true [tool.towncrier] +directory = "newsfragments" +filename = "docs/source/history.rst" +issue_format = "`#{issue} <https://github.com/python-trio/trio/issues/{issue}>`__" # Usage: # - PRs should drop a file like "issuenumber.feature" in newsfragments -# (or "bugfix", "doc", "removal", "misc"; misc gets no text, we can -# customize this) +# (or "bugfix", "doc", "removal", "misc"; misc gets no text, we can +# customize this) # - At release time after bumping version number, run: towncrier -# (or towncrier --draft) +# (or towncrier --draft) package = "trio" -filename = "docs/source/history.rst" -directory = "newsfragments" underlines = ["-", "~", "^"] -issue_format = "`#{issue} <https://github.com/python-trio/trio/issues/{issue}>`__" [[tool.towncrier.type]] directory = "headline" @@ -49,15 +190,3 @@ showcontent = true directory = "misc" name = "Miscellaneous internal changes" showcontent = true - -[tool.pytest.ini_options] -addopts = ["--strict-markers", "--strict-config"] -xfail_strict = true -faulthandler_timeout = 60 -markers = ["redistributors_should_skip: tests that should be skipped by downstream redistributors"] -junit_family = "xunit2" -filterwarnings = [ - "error", - # https://gitter.im/python-trio/general?at=63bb8d0740557a3d5c688d67 - 'ignore:You are using cryptography on a 32-bit Python on a 64-bit Windows Operating System. Cryptography will be significantly faster if you switch to using a 64-bit Python.:UserWarning', -] diff --git a/setup.py b/setup.py index a8e1154dc6..dbce61c0fd 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup exec(open("trio/_version.py", encoding="utf-8").read()) @@ -44,7 +44,7 @@ Vital statistics: * Supported environments: Linux, macOS, or Windows running some kind of Python - 3.7-or-better (either CPython or PyPy3 is fine). \\*BSD and illumos likely + 3.8-or-better (either CPython or PyPy3 is fine). \\*BSD and illumos likely work too, but are not tested. * Install: ``python3 -m pip install -U trio`` (or on Windows, maybe @@ -52,6 +52,8 @@ * Tutorial and reference manual: https://trio.readthedocs.io +* Changelog: https://trio.readthedocs.io/en/latest/history.html + * Bug tracker and source code: https://github.com/python-trio/trio * Real-time chat: https://gitter.im/python-trio/general @@ -73,6 +75,7 @@ version=__version__, description="A friendly Python library for async concurrency and I/O", long_description=LONG_DESC, + long_description_content_type="text/x-rst", author="Nathaniel J. 
Smith", author_email="njs@pobox.com", url="https://github.com/python-trio/trio", @@ -85,7 +88,7 @@ "sortedcontainers", "idna", "outcome", - "sniffio", + "sniffio >= 1.3.0", # cffi 1.12 adds from_buffer(require_writable=True) and ffi.release() # cffi 1.14 fixes memory leak inside ffi.getwinerror() # cffi is required on Windows, except on PyPy where it is built-in @@ -95,7 +98,7 @@ # This means, just install *everything* you see under trio/, even if it # doesn't look like a source file, so long as it appears in MANIFEST.in: include_package_data=True, - python_requires=">=3.7", + python_requires=">=3.8", keywords=["async", "io", "networking", "trio"], classifiers=[ "Development Status :: 3 - Alpha", @@ -109,12 +112,16 @@ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: System :: Networking", "Framework :: Trio", ], + project_urls={ + "Documentation": "https://trio.readthedocs.io/", + "Changelog": "https://trio.readthedocs.io/en/latest/history.html", + }, ) diff --git a/test-requirements.in b/test-requirements.in index eda15ef4b3..1911b1bf11 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -1,28 +1,28 @@ # For tests pytest >= 5.0 # for faulthandler in core -pytest-cov >= 2.6.0 +coverage >= 7.2.5 async_generator >= 1.9 -# ipython 7.x is the last major version supporting Python 3.7 -ipython < 7.35 # for the IPython traceback integration tests +pyright +ipython # for the IPython traceback integration tests pyOpenSSL >= 22.0.0 # for the ssl + DTLS tests trustme # for the ssl + DTLS tests pylint # for pylint finding all symbols tests jedi # for jedi code completion tests -cryptography>=36.0.0 # 35.0.0 is transitive but fails +cryptography>=41.0.0 # cryptography<41 segfaults on pypy3.10 # Tools black; implementation_name == "cpython" mypy; implementation_name == "cpython" -types-pyOpenSSL; implementation_name == "cpython" +types-pyOpenSSL; implementation_name == "cpython" # and annotations flake8 +flake8-pyproject astor # code generation -pip-tools +pip-tools >= 6.13.0 +codespell # https://github.com/python-trio/trio/pull/654#issuecomment-420518745 -# typed_ast is deprecated as of 3.8, and straight up doesn't compile on 3.10-dev as of 2021-12-13 -typed_ast; implementation_name == "cpython" and python_version < "3.8" mypy-extensions; implementation_name == "cpython" -typing-extensions; implementation_name == "cpython" +typing-extensions # Trio's own dependencies cffi; os_name == "nt" diff --git a/test-requirements.txt b/test-requirements.txt index cc80d23675..86a8f14aee 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -6,30 +6,33 @@ # astor==0.8.1 # via -r test-requirements.in -astroid==2.15.0 +astroid==2.15.6 # via pylint +asttokens==2.2.1 + # via stack-data async-generator==1.10 # via -r test-requirements.in -attrs==22.2.0 +attrs==23.1.0 # via # -r test-requirements.in # outcome - # pytest backcall==0.2.0 # via ipython -black==23.1.0 ; implementation_name == "cpython" +black==23.7.0 ; implementation_name == "cpython" # via -r test-requirements.in build==0.10.0 # via pip-tools cffi==1.15.1 # via cryptography -click==8.1.3 +click==8.1.7 # via # black # pip-tools -coverage[toml]==6.4.1 - # via 
pytest-cov -cryptography==39.0.2 +codespell==2.2.5 + # via -r test-requirements.in +coverage==7.3.0 + # via -r test-requirements.in +cryptography==41.0.3 # via # -r test-requirements.in # pyopenssl @@ -37,13 +40,19 @@ cryptography==39.0.2 # types-pyopenssl decorator==5.1.1 # via ipython -dill==0.3.6 +dill==0.3.7 # via pylint -exceptiongroup==1.1.0 ; python_version < "3.11" +exceptiongroup==1.1.3 ; python_version < "3.11" # via # -r test-requirements.in # pytest -flake8==4.0.1 +executing==1.2.0 + # via stack-data +flake8==6.1.0 + # via + # -r test-requirements.in + # flake8-pyproject +flake8-pyproject==1.2.3 # via -r test-requirements.in idna==3.4 # via @@ -51,11 +60,11 @@ idna==3.4 # trustme iniconfig==2.0.0 # via pytest -ipython==7.34.0 +ipython==8.12.2 # via -r test-requirements.in -isort==5.10.1 +isort==5.12.0 # via pylint -jedi==0.18.2 +jedi==0.19.0 # via # -r test-requirements.in # ipython @@ -63,96 +72,105 @@ lazy-object-proxy==1.9.0 # via astroid matplotlib-inline==0.1.6 # via ipython -mccabe==0.6.1 +mccabe==0.7.0 # via # flake8 # pylint -mypy==1.1.1 ; implementation_name == "cpython" +mypy==1.5.1 ; implementation_name == "cpython" # via -r test-requirements.in mypy-extensions==1.0.0 ; implementation_name == "cpython" # via # -r test-requirements.in # black # mypy +nodeenv==1.8.0 + # via pyright outcome==1.2.0 # via -r test-requirements.in -packaging==23.0 +packaging==23.1 # via # black # build # pytest parso==0.8.3 # via jedi -pathspec==0.11.0 +pathspec==0.11.2 # via black pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -pip-tools==6.12.3 +pip-tools==7.3.0 # via -r test-requirements.in -platformdirs==3.1.0 +platformdirs==3.10.0 # via # black # pylint -pluggy==1.0.0 +pluggy==1.3.0 # via pytest -prompt-toolkit==3.0.38 +prompt-toolkit==3.0.39 # via ipython ptyprocess==0.7.0 # via pexpect -pycodestyle==2.8.0 +pure-eval==0.2.2 + # via stack-data +pycodestyle==2.11.0 # via flake8 pycparser==2.21 # via cffi -pyflakes==2.4.0 +pyflakes==3.1.0 # via flake8 -pygments==2.14.0 +pygments==2.16.1 # via ipython -pylint==2.17.0 +pylint==2.17.5 # via -r test-requirements.in -pyopenssl==23.0.0 +pyopenssl==23.2.0 # via -r test-requirements.in pyproject-hooks==1.0.0 # via build -pytest==7.2.2 - # via - # -r test-requirements.in - # pytest-cov -pytest-cov==4.0.0 +pyright==1.1.325 + # via -r test-requirements.in +pytest==7.4.0 # via -r test-requirements.in +six==1.16.0 + # via asttokens sniffio==1.3.0 # via -r test-requirements.in sortedcontainers==2.4.0 # via -r test-requirements.in +stack-data==0.6.2 + # via ipython tomli==2.0.1 # via # black # build - # coverage + # flake8-pyproject # mypy + # pip-tools # pylint + # pyproject-hooks # pytest -tomlkit==0.11.6 +tomlkit==0.12.1 # via pylint traitlets==5.9.0 # via # ipython # matplotlib-inline -trustme==0.9.0 +trustme==1.1.0 # via -r test-requirements.in -types-pyopenssl==23.0.0.4 ; implementation_name == "cpython" +types-pyopenssl==23.2.0.2 ; implementation_name == "cpython" # via -r test-requirements.in -typing-extensions==4.5.0 ; implementation_name == "cpython" +typing-extensions==4.7.1 # via # -r test-requirements.in # astroid # black + # ipython # mypy # pylint wcwidth==0.2.6 # via prompt-toolkit -wheel==0.38.4 +wheel==0.41.2 # via pip-tools wrapt==1.15.0 # via astroid diff --git a/trio/__init__.py b/trio/__init__.py index d6d2adb4bb..be7de42cde 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -1,5 +1,6 @@ """Trio - A friendly Python library for async concurrency and I/O """ +from __future__ import annotations # General layout: # 
@@ -12,105 +13,112 @@ # # This file pulls together the friendly public API, by re-exporting the more # innocuous bits of the _core API + the higher-level tools from trio/*.py. +# +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) -from ._version import __version__ +# must be imported early to avoid circular import +from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: split +# Submodules imported by default +from . import abc, from_thread, lowlevel, socket, to_thread +from ._channel import ( + MemoryReceiveChannel as MemoryReceiveChannel, + MemorySendChannel as MemorySendChannel, + open_memory_channel as open_memory_channel, +) from ._core import ( - TrioInternalError, - RunFinishedError, - WouldBlock, - Cancelled, - BusyResourceError, - ClosedResourceError, - run, - open_nursery, - CancelScope, - current_effective_deadline, - TASK_STATUS_IGNORED, - current_time, - BrokenResourceError, - EndOfChannel, - Nursery, + BrokenResourceError as BrokenResourceError, + BusyResourceError as BusyResourceError, + Cancelled as Cancelled, + CancelScope as CancelScope, + ClosedResourceError as ClosedResourceError, + EndOfChannel as EndOfChannel, + Nursery as Nursery, + RunFinishedError as RunFinishedError, + TaskStatus as TaskStatus, + TrioInternalError as TrioInternalError, + WouldBlock as WouldBlock, + current_effective_deadline as current_effective_deadline, + current_time as current_time, + open_nursery as open_nursery, + run as run, ) - -from ._timeouts import ( - move_on_at, - move_on_after, - sleep_forever, - sleep_until, - sleep, - fail_at, - fail_after, - TooSlowError, +from ._core._multierror import ( + MultiError as _MultiError, + NonBaseMultiError as _NonBaseMultiError, ) - -from ._sync import ( - Event, - CapacityLimiter, - Semaphore, - Lock, - StrictFIFOLock, - Condition, +from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning +from ._dtls import ( + DTLSChannel as DTLSChannel, + DTLSChannelStatistics as DTLSChannelStatistics, + DTLSEndpoint as DTLSEndpoint, ) - -from ._highlevel_generic import aclose_forcefully, StapledStream - -from ._channel import ( - open_memory_channel, - MemorySendChannel, - MemoryReceiveChannel, +from ._file_io import open_file as open_file, wrap_file as wrap_file +from ._highlevel_generic import ( + StapledStream as StapledStream, + aclose_forcefully as aclose_forcefully, +) +from ._highlevel_open_tcp_listeners import ( + open_tcp_listeners as open_tcp_listeners, + serve_tcp as serve_tcp, +) +from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream +from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket +from ._highlevel_serve_listeners import serve_listeners as serve_listeners +from ._highlevel_socket import ( + SocketListener as SocketListener, + SocketStream as SocketStream, ) - -from ._signals import open_signal_receiver - -from ._highlevel_socket import SocketStream, SocketListener - -from ._file_io import open_file, wrap_file - -from ._path import Path - -from ._subprocess import Process, run_process - -from ._ssl import SSLStream, SSLListener, NeedHandshakeError - -from ._dtls import DTLSEndpoint, DTLSChannel - -from ._highlevel_serve_listeners import serve_listeners - -from ._highlevel_open_tcp_stream import open_tcp_stream - -from ._highlevel_open_tcp_listeners import open_tcp_listeners, serve_tcp - -from ._highlevel_open_unix_stream import open_unix_socket - from ._highlevel_ssl_helpers import ( - open_ssl_over_tcp_stream, - 
open_ssl_over_tcp_listeners, - serve_ssl_over_tcp, + open_ssl_over_tcp_listeners as open_ssl_over_tcp_listeners, + open_ssl_over_tcp_stream as open_ssl_over_tcp_stream, + serve_ssl_over_tcp as serve_ssl_over_tcp, +) +from ._path import Path as Path +from ._signals import open_signal_receiver as open_signal_receiver +from ._ssl import ( + NeedHandshakeError as NeedHandshakeError, + SSLListener as SSLListener, + SSLStream as SSLStream, +) +from ._subprocess import Process as Process, run_process as run_process +from ._sync import ( + CapacityLimiter as CapacityLimiter, + CapacityLimiterStatistics as CapacityLimiterStatistics, + Condition as Condition, + ConditionStatistics as ConditionStatistics, + Event as Event, + EventStatistics as EventStatistics, + Lock as Lock, + LockStatistics as LockStatistics, + Semaphore as Semaphore, + StrictFIFOLock as StrictFIFOLock, +) +from ._timeouts import ( + TooSlowError as TooSlowError, + fail_after as fail_after, + fail_at as fail_at, + move_on_after as move_on_after, + move_on_at as move_on_at, + sleep as sleep, + sleep_forever as sleep_forever, + sleep_until as sleep_until, ) -from ._core._multierror import MultiError as _MultiError -from ._core._multierror import NonBaseMultiError as _NonBaseMultiError - -from ._deprecate import TrioDeprecationWarning - -# Submodules imported by default -from . import lowlevel -from . import socket -from . import abc -from . import from_thread -from . import to_thread +# pyright explicitly does not care about `__version__` +# see https://github.com/microsoft/pyright/blob/main/docs/typed-libraries.md#type-completeness +from ._version import __version__ # Not imported by default, but mentioned here so static analysis tools like # pylint will know that it exists. if False: from . import testing -from . import _deprecate +from . import _deprecate as _deprecate _deprecate.enable_attribute_deprecations(__name__) -__deprecated_attributes__ = { +__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = { "open_process": _deprecate.DeprecatedAttribute( value=lowlevel.open_process, version="0.20.0", diff --git a/trio/_abc.py b/trio/_abc.py index c085c82b89..746360c8f8 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,7 +1,20 @@ +from __future__ import annotations + +import socket from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar + import trio +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import Self + + # both of these introduce circular imports if outside a TYPE_CHECKING guard + from ._socket import _SocketType + from .lowlevel import Task + # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a # __dict__ onto subclasses. @@ -11,7 +24,7 @@ class Clock(metaclass=ABCMeta): __slots__ = () @abstractmethod - def start_clock(self): + def start_clock(self) -> None: """Do any setup this clock might need. Called at the beginning of the run. @@ -19,7 +32,7 @@ def start_clock(self): """ @abstractmethod - def current_time(self): + def current_time(self) -> float: """Return the current time, according to this clock. This is used to implement functions like :func:`trio.current_time` and @@ -31,7 +44,7 @@ def current_time(self): """ @abstractmethod - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: """Compute the real time until the given deadline. 
This is called before we enter a system-specific wait function like @@ -65,13 +78,13 @@ class Instrument(metaclass=ABCMeta): __slots__ = () - def before_run(self): + def before_run(self) -> None: """Called at the beginning of :func:`trio.run`.""" - def after_run(self): + def after_run(self) -> None: """Called just before :func:`trio.run` returns.""" - def task_spawned(self, task): + def task_spawned(self, task: Task) -> None: """Called when the given task is created. Args: @@ -79,7 +92,7 @@ def task_spawned(self, task): """ - def task_scheduled(self, task): + def task_scheduled(self, task: Task) -> None: """Called when the given task becomes runnable. It may still be some time before it actually runs, if there are other @@ -90,7 +103,7 @@ def task_scheduled(self, task): """ - def before_task_step(self, task): + def before_task_step(self, task: Task) -> None: """Called immediately before we resume running the given task. Args: @@ -98,7 +111,7 @@ def before_task_step(self, task): """ - def after_task_step(self, task): + def after_task_step(self, task: Task) -> None: """Called when we return to the main run loop after a task has yielded. Args: @@ -106,7 +119,7 @@ def after_task_step(self, task): """ - def task_exited(self, task): + def task_exited(self, task: Task) -> None: """Called when the given task exits. Args: @@ -114,7 +127,7 @@ def task_exited(self, task): """ - def before_io_wait(self, timeout): + def before_io_wait(self, timeout: float) -> None: """Called before blocking to wait for I/O readiness. Args: @@ -122,7 +135,7 @@ def before_io_wait(self, timeout): """ - def after_io_wait(self, timeout): + def after_io_wait(self, timeout: float) -> None: """Called after handling pending I/O. Args: @@ -144,7 +157,23 @@ class HostnameResolver(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: """A custom implementation of :func:`~trio.socket.getaddrinfo`. Called by :func:`trio.socket.getaddrinfo`. @@ -161,7 +190,9 @@ async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): """ @abstractmethod - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: """A custom implementation of :func:`~trio.socket.getnameinfo`. Called by :func:`trio.socket.getnameinfo`. @@ -178,7 +209,12 @@ class SocketFactory(metaclass=ABCMeta): """ @abstractmethod - def socket(self, family=None, type=None, proto=None): + def socket( + self, + family: socket.AddressFamily | int | None = None, + type: socket.SocketKind | int | None = None, + proto: int | None = None, + ) -> _SocketType: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, @@ -224,7 +260,7 @@ class AsyncResource(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def aclose(self): + async def aclose(self) -> None: """Close this resource, possibly blocking. 
IMPORTANT: This method may block in order to perform a "graceful" @@ -252,10 +288,15 @@ async def aclose(self): """ - async def __aenter__(self): + async def __aenter__(self) -> Self: return self - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: await self.aclose() @@ -278,7 +319,7 @@ class SendStream(AsyncResource): __slots__ = () @abstractmethod - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Sends the given data through the stream, blocking if necessary. Args: @@ -304,7 +345,7 @@ async def send_all(self, data): """ @abstractmethod - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Block until it's possible that :meth:`send_all` might not block. This method may return early: it's possible that after it returns, @@ -384,7 +425,7 @@ class ReceiveStream(AsyncResource): __slots__ = () @abstractmethod - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Wait until there is data available on this stream, and then return some of it. @@ -412,10 +453,10 @@ async def receive_some(self, max_bytes=None): """ - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> bytes | bytearray: data = await self.receive_some() if not data: raise StopAsyncIteration @@ -445,7 +486,7 @@ class HalfCloseableStream(Stream): __slots__ = () @abstractmethod - async def send_eof(self): + async def send_eof(self) -> None: """Send an end-of-file indication on this stream, if possible. The difference between :meth:`send_eof` and @@ -524,7 +565,7 @@ class Listener(AsyncResource, Generic[T_resource]): __slots__ = () @abstractmethod - async def accept(self): + async def accept(self) -> T_resource: """Wait until an incoming connection arrives, and then return it. 
         Returns:
@@ -631,7 +672,7 @@ async def receive(self) -> ReceiveType:

         """

-    def __aiter__(self):
+    def __aiter__(self) -> Self:
         return self

     async def __anext__(self) -> ReceiveType:
diff --git a/trio/_channel.py b/trio/_channel.py
index 2059b1fb4b..db122d37f5 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -1,35 +1,26 @@
 from __future__ import annotations

-from collections import deque, OrderedDict
-from collections.abc import Callable
+from collections import OrderedDict, deque
 from math import inf
-
 from types import TracebackType
-from typing import (
-    Any,
-    Generic,
-    NoReturn,
-    TypeVar,
-    TYPE_CHECKING,
-    Tuple,  # only needed for typechecking on <3.9
-)
+from typing import Tuple  # only needed for typechecking on <3.9
+from typing import TYPE_CHECKING, Generic

 import attr
 from outcome import Error, Value

-from ._abc import SendChannel, ReceiveChannel, Channel, ReceiveType, SendType, T
-from ._util import generic_function, NoPublicConstructor
-
 import trio
-from ._core import enable_ki_protection, Task, Abort, RaiseCancelT
+from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
+from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
+from ._util import NoPublicConstructor, generic_function

-# Temporary TypeVar needed until mypy release supports Self as a type
-SelfT = TypeVar("SelfT")
+if TYPE_CHECKING:
+    from typing_extensions import Self


 def _open_memory_channel(
-    max_buffer_size: int,
+    max_buffer_size: int | float,
 ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
     """Open a channel for passing objects between tasks within a process.
@@ -101,11 +92,11 @@ def _open_memory_channel(
     # Need to use Tuple instead of tuple due to CI check running on 3.8
     class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
         def __new__(  # type: ignore[misc]  # "must return a subtype"
-            cls, max_buffer_size: int
+            cls, max_buffer_size: int | float
         ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
             return _open_memory_channel(max_buffer_size)

-        def __init__(self, max_buffer_size: int):
+        def __init__(self, max_buffer_size: int | float):
             ...

 else:
@@ -117,7 +108,7 @@ def __init__(self, max_buffer_size: int):
 @attr.s(frozen=True, slots=True)
 class MemoryChannelStats:
     current_buffer_used: int = attr.ib()
-    max_buffer_size: int = attr.ib()
+    max_buffer_size: int | float = attr.ib()
     open_send_channels: int = attr.ib()
     open_receive_channels: int = attr.ib()
     tasks_waiting_send: int = attr.ib()
@@ -126,7 +117,7 @@ class MemoryChannelStats:

 @attr.s(slots=True)
 class MemoryChannelState(Generic[T]):
-    max_buffer_size: int = attr.ib()
+    max_buffer_size: int | float = attr.ib()
     data: deque[T] = attr.ib(factory=deque)
     # Counts of open endpoints using this state
     open_send_channels: int = attr.ib(default=0)
@@ -218,7 +209,7 @@ def abort_fn(_: RaiseCancelT) -> Abort:

     # Return type must be stringified or use a TypeVar
     @enable_ki_protection
-    def clone(self) -> "MemorySendChannel[SendType]":
+    def clone(self) -> MemorySendChannel[SendType]:
         """Clone this send channel object.

         This returns a new `MemorySendChannel` object, which acts as a
@@ -246,14 +237,14 @@ def clone(self) -> "MemorySendChannel[SendType]":
             raise trio.ClosedResourceError
         return MemorySendChannel._create(self._state)

-    def __enter__(self: SelfT) -> SelfT:
+    def __enter__(self) -> Self:
         return self

     def __exit__(
         self,
         exc_type: type[BaseException] | None,
-        exc_val: BaseException | None,
-        exc_tb: TracebackType | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
     ) -> None:
         self.close()
@@ -361,7 +352,7 @@ def abort_fn(_: RaiseCancelT) -> Abort:
         return await trio.lowlevel.wait_task_rescheduled(abort_fn)  # type: ignore[no-any-return]

     @enable_ki_protection
-    def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
+    def clone(self) -> MemoryReceiveChannel[ReceiveType]:
         """Clone this receive channel object.

         This returns a new `MemoryReceiveChannel` object, which acts as a
@@ -392,14 +383,14 @@ def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
             raise trio.ClosedResourceError
         return MemoryReceiveChannel._create(self._state)

-    def __enter__(self: SelfT) -> SelfT:
+    def __enter__(self) -> Self:
         return self

     def __exit__(
         self,
         exc_type: type[BaseException] | None,
-        exc_val: BaseException | None,
-        exc_tb: TracebackType | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
     ) -> None:
         self.close()
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index f9919b8323..b9bd0d8cc4 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -6,83 +6,74 @@

 import sys

+from ._entry_queue import TrioToken
 from ._exceptions import (
-    TrioInternalError,
-    RunFinishedError,
-    WouldBlock,
-    Cancelled,
+    BrokenResourceError,
     BusyResourceError,
+    Cancelled,
     ClosedResourceError,
-    BrokenResourceError,
     EndOfChannel,
+    RunFinishedError,
+    TrioInternalError,
+    WouldBlock,
 )
-
-from ._ki import (
-    enable_ki_protection,
-    disable_ki_protection,
-    currently_ki_protected,
-)
+from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection
+from ._local import RunVar
+from ._mock_clock import MockClock
+from ._parking_lot import ParkingLot, ParkingLotStatistics

 # Imports that always exist
 from ._run import (
-    Task,
+    TASK_STATUS_IGNORED,
     CancelScope,
-    run,
-    open_nursery,
+    Nursery,
+    RunStatistics,
+    Task,
+    TaskStatus,
+    add_instrument,
     checkpoint,
-    current_task,
-    current_effective_deadline,
     checkpoint_if_cancelled,
-    TASK_STATUS_IGNORED,
+    current_clock,
+    current_effective_deadline,
+    current_root_task,
     current_statistics,
+    current_task,
+    current_time,
     current_trio_token,
-    reschedule,
+    notify_closing,
+    open_nursery,
     remove_instrument,
-    add_instrument,
-    current_clock,
-    current_root_task,
+    reschedule,
+    run,
     spawn_system_task,
-    current_time,
+    start_guest_run,
     wait_all_tasks_blocked,
     wait_readable,
     wait_writable,
-    notify_closing,
-    Nursery,
-    start_guest_run,
 )
+from ._thread_cache import start_thread_soon

 # Has to come after _run to resolve a circular import
 from ._traps import (
-    cancel_shielded_checkpoint,
     Abort,
     RaiseCancelT,
-    wait_task_rescheduled,
-    temporarily_detach_coroutine_object,
+    cancel_shielded_checkpoint,
     permanently_detach_coroutine_object,
     reattach_detached_coroutine_object,
+    temporarily_detach_coroutine_object,
+    wait_task_rescheduled,
 )
-
-from ._entry_queue import TrioToken
-
-from ._parking_lot import ParkingLot
-
-from ._unbounded_queue import UnboundedQueue
-
-from ._local import RunVar
-
-from ._thread_cache import start_thread_soon
-
-from ._mock_clock import MockClock
+from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics

 # Windows imports
 if sys.platform == "win32":
     from ._run import (
-        monitor_completion_key,
         current_iocp,
+        monitor_completion_key,
+        readinto_overlapped,
         register_with_iocp,
         wait_overlapped,
         write_overlapped,
-        readinto_overlapped,
     )
 # Kqueue imports
 elif sys.platform != "linux" and sys.platform != "win32":
diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py
index 1eab150488..4261328278 100644
--- a/trio/_core/_asyncgens.py
+++ b/trio/_core/_asyncgens.py
@@ -1,16 +1,30 @@
-import attr
+from __future__ import annotations
+
 import logging
 import sys
 import warnings
 import weakref
+from types import AsyncGeneratorType
+from typing import TYPE_CHECKING, NoReturn
+
+import attr

+from .. import _core
 from .._util import name_asyncgen
 from . import _run
-from .. import _core

 # Used to log exceptions in async generator finalizers
 ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")

+if TYPE_CHECKING:
+    from typing import Set
+
+    _WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]]
+    _ASYNC_GEN_SET = Set[AsyncGeneratorType[object, NoReturn]]
+else:
+    _WEAK_ASYNC_GEN_SET = weakref.WeakSet
+    _ASYNC_GEN_SET = set
+

 @attr.s(eq=False, slots=True)
 class AsyncGenerators:
@@ -21,17 +35,17 @@ class AsyncGenerators:
     # asyncgens after the system nursery has been closed, it's a
     # regular set so we don't have to deal with GC firing at
     # unexpected times.
-    alive = attr.ib(factory=weakref.WeakSet)
+    alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attr.ib(factory=_WEAK_ASYNC_GEN_SET)

     # This collects async generators that get garbage collected during
     # the one-tick window between the system nursery closing and the
     # init task starting end-of-run asyncgen finalization.
-    trailing_needs_finalize = attr.ib(factory=set)
+    trailing_needs_finalize: _ASYNC_GEN_SET = attr.ib(factory=_ASYNC_GEN_SET)

     prev_hooks = attr.ib(init=False)

-    def install_hooks(self, runner):
-        def firstiter(agen):
+    def install_hooks(self, runner: _run.Runner) -> None:
+        def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
             if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
                 self.alive.add(agen)
             else:
@@ -45,7 +59,9 @@ def firstiter(agen):
             if self.prev_hooks.firstiter is not None:
                 self.prev_hooks.firstiter(agen)

-        def finalize_in_trio_context(agen, agen_name):
+        def finalize_in_trio_context(
+            agen: AsyncGeneratorType[object, NoReturn], agen_name: str
+        ) -> None:
             try:
                 runner.spawn_system_task(
                     self._finalize_one,
@@ -60,7 +76,7 @@ def finalize_in_trio_context(agen, agen_name):
                 # have hit it.
                 self.trailing_needs_finalize.add(agen)

-        def finalizer(agen):
+        def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
             agen_name = name_asyncgen(agen)
             try:
                 is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
@@ -111,9 +127,9 @@ def finalizer(agen):
             )

         self.prev_hooks = sys.get_asyncgen_hooks()
-        sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer)
+        sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer)  # type: ignore[arg-type]  # Finalizer doesn't use AsyncGeneratorType

-    async def finalize_remaining(self, runner):
+    async def finalize_remaining(self, runner: _run.Runner) -> None:
         # This is called from init after shutting down the system nursery.
         # The only tasks running at this point are init and
         # the run_sync_soon task, and since the system nursery is closed,
@@ -169,14 +185,16 @@ async def finalize_remaining(self, runner):
         # all are gone.
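         # (Each pass below drains a snapshot of ``alive``: finalizers may
         # create new async generators, which land in the fresh set and get
         # picked up on the next pass.)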
         while self.alive:
             batch = self.alive
-            self.alive = set()
+            self.alive = _ASYNC_GEN_SET()
             for agen in batch:
                 await self._finalize_one(agen, name_asyncgen(agen))

-    def close(self):
+    def close(self) -> None:
         sys.set_asyncgen_hooks(*self.prev_hooks)

-    async def _finalize_one(self, agen, name):
+    async def _finalize_one(
+        self, agen: AsyncGeneratorType[object, NoReturn], name: object
+    ) -> None:
         try:
             # This shield ensures that finalize_asyncgen never exits
             # with an exception, not even a Cancelled. The inside
diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py
index 9f3301b3d2..468a13462a 100644
--- a/trio/_core/_entry_queue.py
+++ b/trio/_core/_entry_queue.py
@@ -1,5 +1,8 @@
-from collections import deque
+from __future__ import annotations
+
 import threading
+from collections import deque
+from typing import Callable, Iterable, NoReturn, Tuple

 import attr

@@ -7,6 +10,11 @@
 from .._util import NoPublicConstructor
 from ._wakeup_socketpair import WakeupSocketpair

+# TODO: Type with TypeVarTuple, at least to an extent where it makes
+# the public interface safe.
+Function = Callable[..., object]
+Job = Tuple[Function, Iterable[object]]
+

 @attr.s(slots=True)
 class EntryQueue:
@@ -17,11 +25,11 @@ class EntryQueue:
     # atomic WRT signal delivery (signal handlers can run on either side, but
     # not *during* a deque operation). dict makes similar guarantees - and
     # it's even ordered!
-    queue = attr.ib(factory=deque)
-    idempotent_queue = attr.ib(factory=dict)
+    queue: deque[Job] = attr.ib(factory=deque)
+    idempotent_queue: dict[Job, None] = attr.ib(factory=dict)

-    wakeup = attr.ib(factory=WakeupSocketpair)
-    done = attr.ib(default=False)
+    wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+    done: bool = attr.ib(default=False)
     # Must be a reentrant lock, because it's acquired from signal handlers.
     # RLock is signal-safe as of cpython 3.2. NB that this does mean that the
     # lock is effectively *disabled* when we enter from signal context. The
@@ -30,9 +38,9 @@ class EntryQueue:
     # main thread -- it just might happen at some inconvenient place. But if
     # you look at the one place where the main thread holds the lock, it's
     # just to make 1 assignment, so that's atomic WRT a signal anyway.
-    lock = attr.ib(factory=threading.RLock)
+    lock: threading.RLock = attr.ib(factory=threading.RLock)

-    async def task(self):
+    async def task(self) -> None:
         assert _core.currently_ki_protected()
         # RLock has two implementations: a signal-safe version in _thread, and
         # and signal-UNsafe version in threading. We need the signal safe
@@ -43,7 +51,7 @@ async def task(self):
         # https://bugs.python.org/issue13697#msg237140
         assert self.lock.__class__.__module__ == "_thread"

-        def run_cb(job):
+        def run_cb(job: Job) -> None:
             # We run this with KI protection enabled; it's the callback's
             # job to disable it if it wants it disabled. Exceptions are
             # treated like system task exceptions (i.e., converted into
@@ -53,7 +61,7 @@ def run_cb(job):
                 sync_fn(*args)
             except BaseException as exc:

-                async def kill_everything(exc):
+                async def kill_everything(exc: BaseException) -> NoReturn:
                     raise exc

                 try:
@@ -63,14 +71,17 @@ async def kill_everything(exc):
                     # system nursery is already closed.
                     # TODO(2020-06): this is a gross hack and should
                     # be fixed soon when we address #1607.
-                    _core.current_task().parent_nursery.start_soon(kill_everything, exc)
-
-            return True
+                    parent_nursery = _core.current_task().parent_nursery
+                    if parent_nursery is None:
+                        raise AssertionError(
+                            "Internal error: `parent_nursery` should never be `None`"
+                        ) from exc  # pragma: no cover
+                    parent_nursery.start_soon(kill_everything, exc)

         # This has to be carefully written to be safe in the face of new items
         # being queued while we iterate, and to do a bounded amount of work on
         # each pass:
-        def run_all_bounded():
+        def run_all_bounded() -> None:
             for _ in range(len(self.queue)):
                 run_cb(self.queue.popleft())
             for job in list(self.idempotent_queue):
@@ -104,13 +115,15 @@ def run_all_bounded():
         assert not self.queue
         assert not self.idempotent_queue

-    def close(self):
+    def close(self) -> None:
         self.wakeup.close()

-    def size(self):
+    def size(self) -> int:
         return len(self.queue) + len(self.idempotent_queue)

-    def run_sync_soon(self, sync_fn, *args, idempotent=False):
+    def run_sync_soon(
+        self, sync_fn: Function, *args: object, idempotent: bool = False
+    ) -> None:
         with self.lock:
             if self.done:
                 raise _core.RunFinishedError("run() has exited")
@@ -146,9 +159,11 @@ class TrioToken(metaclass=NoPublicConstructor):

     """

-    _reentry_queue = attr.ib()
+    _reentry_queue: EntryQueue = attr.ib()

-    def run_sync_soon(self, sync_fn, *args, idempotent=False):
+    def run_sync_soon(
+        self, sync_fn: Function, *args: object, idempotent: bool = False
+    ) -> None:
         """Schedule a call to ``sync_fn(*args)`` to occur in the context of
         a Trio task.
diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py
index 6189c484b4..bdc7b31c21 100644
--- a/trio/_core/_exceptions.py
+++ b/trio/_core/_exceptions.py
@@ -1,5 +1,3 @@
-import attr
-
 from trio._util import NoPublicConstructor

@@ -61,7 +59,7 @@ class Cancelled(BaseException, metaclass=NoPublicConstructor):

     """

-    def __str__(self):
+    def __str__(self) -> str:
         return "Cancelled"
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index 986ab2c7f5..652fed1288 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -1,22 +1,22 @@
 # ***********************************************************
 # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
 # *************************************************************
-from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
-from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._instrumentation import Instrument
+from __future__ import annotations

-# fmt: off
+from ._instrumentation import Instrument
+from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
+from ._run import GLOBAL_RUN_CONTEXT


-def add_instrument(instrument: Instrument) ->None:
+def add_instrument(instrument: Instrument) -> None:
     """Start instrumenting the current run loop with the given instrument.

-        Args:
-          instrument (trio.abc.Instrument): The instrument to activate.
+    Args:
+      instrument (trio.abc.Instrument): The instrument to activate.

-        If ``instrument`` is already active, does nothing.
+    If ``instrument`` is already active, does nothing.

-        """
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument)
@@ -24,24 +24,21 @@ def add_instrument(instrument: Instrument) ->None:
         raise RuntimeError("must be called from async context")


-def remove_instrument(instrument: Instrument) ->None:
+def remove_instrument(instrument: Instrument) -> None:
     """Stop instrumenting the current run loop with the given instrument.

-        Args:
-          instrument (trio.abc.Instrument): The instrument to de-activate.
+    Args:
+      instrument (trio.abc.Instrument): The instrument to de-activate.

-        Raises:
-          KeyError: if the instrument is not currently active. This could
-              occur either because you never added it, or because you added it
-              and then it raised an unhandled exception and was automatically
-              deactivated.
+    Raises:
+      KeyError: if the instrument is not currently active. This could
+          occur either because you never added it, or because you added it
+          and then it raised an unhandled exception and was automatically
+          deactivated.

-        """
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument)
     except AttributeError:
         raise RuntimeError("must be called from async context")
-
-
-# fmt: on
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index 9ae54e4f68..4dc2b59c98 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -1,14 +1,19 @@
 # ***********************************************************
 # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
 # *************************************************************
-from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
+from __future__ import annotations
+
+import sys
+from socket import socket
+from typing import TYPE_CHECKING
+
 from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._instrumentation import Instrument
+from ._run import GLOBAL_RUN_CONTEXT

-# fmt: off
+assert not TYPE_CHECKING or sys.platform == "linux"


-async def wait_readable(fd):
+async def wait_readable(fd: (int | socket)) -> None:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -16,7 +21,7 @@ async def wait_readable(fd):
         raise RuntimeError("must be called from async context")


-async def wait_writable(fd):
+async def wait_writable(fd: (int | socket)) -> None:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -24,12 +29,9 @@ async def wait_writable(fd):
         raise RuntimeError("must be called from async context")


-def notify_closing(fd):
+def notify_closing(fd: (int | socket)) -> None:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
     except AttributeError:
         raise RuntimeError("must be called from async context")
-
-
-# fmt: on
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index 7549899dbe..9c8ca26ef3 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -1,14 +1,27 @@
 # ***********************************************************
 # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
 # *************************************************************
-from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, ContextManager
+
 from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._instrumentation import Instrument
+from ._run import GLOBAL_RUN_CONTEXT
+
+if TYPE_CHECKING:
+    import select
+    from socket import socket
+
+    from ._traps import Abort, RaiseCancelT

-# fmt: off
+    from .. import _core
+import sys

-def current_kqueue():
+assert not TYPE_CHECKING or sys.platform == "darwin"
+
+
+def current_kqueue() -> select.kqueue:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
@@ -16,7 +29,9 @@ def current_kqueue():
         raise RuntimeError("must be called from async context")


-def monitor_kevent(ident, filter):
+def monitor_kevent(
+    ident: int, filter: int
+) -> ContextManager[_core.UnboundedQueue[select.kevent]]:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
@@ -24,15 +39,19 @@ def monitor_kevent(ident, filter):
         raise RuntimeError("must be called from async context")


-async def wait_kevent(ident, filter, abort_func):
+async def wait_kevent(
+    ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort]
+) -> Abort:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
-        return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func)
+        return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(
+            ident, filter, abort_func
+        )
     except AttributeError:
         raise RuntimeError("must be called from async context")


-async def wait_readable(fd):
+async def wait_readable(fd: (int | socket)) -> None:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -40,7 +59,7 @@ async def wait_readable(fd):
         raise RuntimeError("must be called from async context")


-async def wait_writable(fd):
+async def wait_writable(fd: (int | socket)) -> None:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -48,12 +67,9 @@ async def wait_writable(fd):
         raise RuntimeError("must be called from async context")


-def notify_closing(fd):
+def notify_closing(fd: (int | socket)) -> None:
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
     except AttributeError:
         raise RuntimeError("must be called from async context")
-
-
-# fmt: on
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index e6337e94b0..b81255d8a9 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -1,11 +1,15 @@
 # ***********************************************************
 # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
 # *************************************************************
-from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+
 from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._instrumentation import Instrument
+from ._run import GLOBAL_RUN_CONTEXT

-# fmt: off
+assert not TYPE_CHECKING or sys.platform == "win32"


 async def wait_readable(sock):
@@ -43,7 +47,9 @@ def register_with_iocp(handle):
 async def wait_overlapped(handle, lpOverlapped):
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
-        return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped)
+        return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(
+            handle, lpOverlapped
+        )
     except AttributeError:
         raise RuntimeError("must be called from async context")

@@ -51,7 +57,9 @@ async def wait_overlapped(handle, lpOverlapped):
 async def write_overlapped(handle, data, file_offset=0):
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
-        return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset)
+        return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(
+            handle, data, file_offset
+        )
     except AttributeError:
         raise RuntimeError("must be called from async context")

@@ -59,7 +67,9 @@ async def write_overlapped(handle, data, file_offset=0):
 async def readinto_overlapped(handle, buffer, file_offset=0):
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
-        return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset)
+        return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(
+            handle, buffer, file_offset
+        )
     except AttributeError:
         raise RuntimeError("must be called from async context")

@@ -78,6 +88,3 @@ def monitor_completion_key():
         return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
     except AttributeError:
         raise RuntimeError("must be called from async context")
-
-
-# fmt: on
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index d20891c55e..3e1b7b78f1 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -1,36 +1,43 @@
 # ***********************************************************
 # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
 # *************************************************************
-from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
+from __future__ import annotations
+
+import contextvars
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from outcome import Outcome
+
+from .._abc import Clock
+from ._entry_queue import TrioToken
 from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._instrumentation import Instrument
-
-# fmt: off
-
-
-def current_statistics():
-    """Returns an object containing run-loop-level debugging information.
-
-        Currently the following fields are defined:
-
-        * ``tasks_living`` (int): The number of tasks that have been spawned
-          and not yet exited.
-        * ``tasks_runnable`` (int): The number of tasks that are currently
-          queued on the run queue (as opposed to blocked waiting for something
-          to happen).
-        * ``seconds_to_next_deadline`` (float): The time until the next
-          pending cancel scope deadline. May be negative if the deadline has
-          expired but we haven't yet processed cancellations. May be
-          :data:`~math.inf` if there are no pending deadlines.
-        * ``run_sync_soon_queue_size`` (int): The number of
-          unprocessed callbacks queued via
-          :meth:`trio.lowlevel.TrioToken.run_sync_soon`.
-        * ``io_statistics`` (object): Some statistics from Trio's I/O
-          backend. This always has an attribute ``backend`` which is a string
-          naming which operating-system-specific I/O backend is in use; the
-          other attributes vary between backends.
-
-    """
+from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task
+
+
+def current_statistics() -> RunStatistics:
+    """Returns ``RunStatistics``, which contains run-loop-level debugging information.
+
+    Currently, the following fields are defined:
+
+    * ``tasks_living`` (int): The number of tasks that have been spawned
+      and not yet exited.
+    * ``tasks_runnable`` (int): The number of tasks that are currently
+      queued on the run queue (as opposed to blocked waiting for something
+      to happen).
+    * ``seconds_to_next_deadline`` (float): The time until the next
+      pending cancel scope deadline. May be negative if the deadline has
+      expired but we haven't yet processed cancellations. May be
+      :data:`~math.inf` if there are no pending deadlines.
+    * ``run_sync_soon_queue_size`` (int): The number of
+      unprocessed callbacks queued via
+      :meth:`trio.lowlevel.TrioToken.run_sync_soon`.
+    * ``io_statistics`` (object): Some statistics from Trio's I/O
+      backend. This always has an attribute ``backend`` which is a string
+      naming which operating-system-specific I/O backend is in use; the
+      other attributes vary between backends.
+
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.current_statistics()
@@ -38,16 +45,16 @@ def current_statistics():
         raise RuntimeError("must be called from async context")


-def current_time():
+def current_time() -> float:
     """Returns the current time according to Trio's internal clock.

-        Returns:
-            float: The current time.
+    Returns:
+        float: The current time.

-        Raises:
-            RuntimeError: if not inside a call to :func:`trio.run`.
+    Raises:
+        RuntimeError: if not inside a call to :func:`trio.run`.

-    """
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.current_time()
@@ -55,7 +62,7 @@ def current_time():
         raise RuntimeError("must be called from async context")


-def current_clock():
+def current_clock() -> Clock:
     """Returns the current :class:`~trio.abc.Clock`."""
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
@@ -64,12 +71,12 @@ def current_clock():
         raise RuntimeError("must be called from async context")


-def current_root_task():
+def current_root_task() -> Task | None:
     """Returns the current root :class:`Task`.

-        This is the task that is the ultimate parent of all other tasks.
+    This is the task that is the ultimate parent of all other tasks.

-    """
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.current_root_task()
@@ -77,24 +84,24 @@ def current_root_task():
         raise RuntimeError("must be called from async context")


-def reschedule(task, next_send=_NO_SEND):
+def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None:
     """Reschedule the given task with the given
-        :class:`outcome.Outcome`.
+    :class:`outcome.Outcome`.

-        See :func:`wait_task_rescheduled` for the gory details.
+    See :func:`wait_task_rescheduled` for the gory details.

-        There must be exactly one call to :func:`reschedule` for every call to
-        :func:`wait_task_rescheduled`. (And when counting, keep in mind that
-        returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
-        to calling :func:`reschedule` once.)
+    There must be exactly one call to :func:`reschedule` for every call to
+    :func:`wait_task_rescheduled`. (And when counting, keep in mind that
+    returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
+    to calling :func:`reschedule` once.)

-        Args:
-          task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
-            in a call to :func:`wait_task_rescheduled`.
-          next_send (outcome.Outcome): the value (or error) to return (or
-            raise) from :func:`wait_task_rescheduled`.
+    Args:
+      task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
+        in a call to :func:`wait_task_rescheduled`.
+      next_send (outcome.Outcome): the value (or error) to return (or
+        raise) from :func:`wait_task_rescheduled`.

-    """
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send)
@@ -102,70 +109,77 @@ def reschedule(task, next_send=_NO_SEND):
         raise RuntimeError("must be called from async context")


-def spawn_system_task(async_fn, *args, name=None, context=None):
+def spawn_system_task(
+    async_fn: Callable[..., Awaitable[object]],
+    *args: object,
+    name: object = None,
+    context: (contextvars.Context | None) = None,
+) -> Task:
     """Spawn a "system" task.

-        System tasks have a few differences from regular tasks:
-
-        * They don't need an explicit nursery; instead they go into the
-          internal "system nursery".
-
-        * If a system task raises an exception, then it's converted into a
-          :exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
-          write a system task, you should be careful to make sure it doesn't
-          crash.
-
-        * System tasks are automatically cancelled when the main task exits.
-
-        * By default, system tasks have :exc:`KeyboardInterrupt` protection
-          *enabled*. If you want your task to be interruptible by control-C,
-          then you need to use :func:`disable_ki_protection` explicitly (and
-          come up with some plan for what to do with a
-          :exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
-          raise exceptions).
-
-        * System tasks do not inherit context variables from their creator.
-
-        Towards the end of a call to :meth:`trio.run`, after the main
-        task and all system tasks have exited, the system nursery
-        becomes closed. At this point, new calls to
-        :func:`spawn_system_task` will raise ``RuntimeError("Nursery
-        is closed to new arrivals")`` instead of creating a system
-        task. It's possible to encounter this state either in
-        a ``finally`` block in an async generator, or in a callback
-        passed to :meth:`TrioToken.run_sync_soon` at the right moment.
-
-        Args:
-          async_fn: An async callable.
-          args: Positional arguments for ``async_fn``. If you want to pass
-              keyword arguments, use :func:`functools.partial`.
-          name: The name for this task. Only used for debugging/introspection
-              (e.g. ``repr(task_obj)``). If this isn't a string,
-              :func:`spawn_system_task` will try to make it one. A common use
-              case is if you're wrapping a function before spawning a new
-              task, you might pass the original function as the ``name=`` to
-              make debugging easier.
-          context: An optional ``contextvars.Context`` object with context variables
-              to use for this task. You would normally get a copy of the current
-              context with ``context = contextvars.copy_context()`` and then you would
-              pass that ``context`` object here.
-
-        Returns:
-          Task: the newly spawned task
-
-    """
+    System tasks have a few differences from regular tasks:
+
+    * They don't need an explicit nursery; instead they go into the
+      internal "system nursery".
+
+    * If a system task raises an exception, then it's converted into a
+      :exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
+      write a system task, you should be careful to make sure it doesn't
+      crash.
+
+    * System tasks are automatically cancelled when the main task exits.
+
+    * By default, system tasks have :exc:`KeyboardInterrupt` protection
+      *enabled*. If you want your task to be interruptible by control-C,
+      then you need to use :func:`disable_ki_protection` explicitly (and
+      come up with some plan for what to do with a
+      :exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
+      raise exceptions).
+
+    * System tasks do not inherit context variables from their creator.
+
+    Towards the end of a call to :meth:`trio.run`, after the main
+    task and all system tasks have exited, the system nursery
+    becomes closed. At this point, new calls to
+    :func:`spawn_system_task` will raise ``RuntimeError("Nursery
+    is closed to new arrivals")`` instead of creating a system
+    task. It's possible to encounter this state either in
+    a ``finally`` block in an async generator, or in a callback
+    passed to :meth:`TrioToken.run_sync_soon` at the right moment.
+
+    Args:
+      async_fn: An async callable.
+      args: Positional arguments for ``async_fn``. If you want to pass
+          keyword arguments, use :func:`functools.partial`.
+      name: The name for this task. Only used for debugging/introspection
+          (e.g. ``repr(task_obj)``). If this isn't a string,
+          :func:`spawn_system_task` will try to make it one. A common use
+          case is if you're wrapping a function before spawning a new
+          task, you might pass the original function as the ``name=`` to
+          make debugging easier.
+      context: An optional ``contextvars.Context`` object with context variables
+          to use for this task. You would normally get a copy of the current
+          context with ``context = contextvars.copy_context()`` and then you would
+          pass that ``context`` object here.
+
+    Returns:
+      Task: the newly spawned task
+
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
-        return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name, context=context)
+        return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(
+            async_fn, *args, name=name, context=context
+        )
     except AttributeError:
         raise RuntimeError("must be called from async context")


-def current_trio_token():
+def current_trio_token() -> TrioToken:
     """Retrieve the :class:`TrioToken` for the current call to
-        :func:`trio.run`.
+    :func:`trio.run`.

-    """
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return GLOBAL_RUN_CONTEXT.runner.current_trio_token()
@@ -173,69 +187,66 @@ def current_trio_token():
         raise RuntimeError("must be called from async context")


-async def wait_all_tasks_blocked(cushion=0.0):
+async def wait_all_tasks_blocked(cushion: float = 0.0) -> None:
     """Block until there are no runnable tasks.

-        This is useful in testing code when you want to give other tasks a
-        chance to "settle down". The calling task is blocked, and doesn't wake
-        up until all other tasks are also blocked for at least ``cushion``
-        seconds. (Setting a non-zero ``cushion`` is intended to handle cases
-        like two tasks talking to each other over a local socket, where we
-        want to ignore the potential brief moment between a send and receive
-        when all tasks are blocked.)
-
-        Note that ``cushion`` is measured in *real* time, not the Trio clock
-        time.
-
-        If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
-        then the one with the shortest ``cushion`` is the one woken (and
-        this task becoming unblocked resets the timers for the remaining
-        tasks). If there are multiple tasks that have exactly the same
-        ``cushion``, then all are woken.
-
-        You should also consider :class:`trio.testing.Sequencer`, which
-        provides a more explicit way to control execution ordering within a
-        test, and will often produce more readable tests.
-
-        Example:
-          Here's an example of one way to test that Trio's locks are fair: we
-          take the lock in the parent, start a child, wait for the child to be
-          blocked waiting for the lock (!), and then check that we can't
-          release and immediately re-acquire the lock::
-
-             async def lock_taker(lock):
-                 await lock.acquire()
+    This is useful in testing code when you want to give other tasks a
+    chance to "settle down". The calling task is blocked, and doesn't wake
+    up until all other tasks are also blocked for at least ``cushion``
+    seconds. (Setting a non-zero ``cushion`` is intended to handle cases
+    like two tasks talking to each other over a local socket, where we
+    want to ignore the potential brief moment between a send and receive
+    when all tasks are blocked.)
+
+    Note that ``cushion`` is measured in *real* time, not the Trio clock
+    time.
+
+    If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
+    then the one with the shortest ``cushion`` is the one woken (and
+    this task becoming unblocked resets the timers for the remaining
+    tasks). If there are multiple tasks that have exactly the same
+    ``cushion``, then all are woken.
+
+    You should also consider :class:`trio.testing.Sequencer`, which
+    provides a more explicit way to control execution ordering within a
+    test, and will often produce more readable tests.
+
+    Example:
+      Here's an example of one way to test that Trio's locks are fair: we
+      take the lock in the parent, start a child, wait for the child to be
+      blocked waiting for the lock (!), and then check that we can't
+      release and immediately re-acquire the lock::
+
+        async def lock_taker(lock):
+            await lock.acquire()
+            lock.release()
+
+        async def test_lock_fairness():
+            lock = trio.Lock()
+            await lock.acquire()
+            async with trio.open_nursery() as nursery:
+                nursery.start_soon(lock_taker, lock)
+                # child hasn't run yet, we have the lock
+                assert lock.locked()
+                assert lock._owner is trio.lowlevel.current_task()
+                await trio.testing.wait_all_tasks_blocked()
+                # now the child has run and is blocked on lock.acquire(), we
+                # still have the lock
+                assert lock.locked()
+                assert lock._owner is trio.lowlevel.current_task()
                 lock.release()
-
-             async def test_lock_fairness():
-                 lock = trio.Lock()
-                 await lock.acquire()
-                 async with trio.open_nursery() as nursery:
-                     nursery.start_soon(lock_taker, lock)
-                     # child hasn't run yet, we have the lock
-                     assert lock.locked()
-                     assert lock._owner is trio.lowlevel.current_task()
-                     await trio.testing.wait_all_tasks_blocked()
-                     # now the child has run and is blocked on lock.acquire(), we
-                     # still have the lock
-                     assert lock.locked()
-                     assert lock._owner is trio.lowlevel.current_task()
-                     lock.release()
-                     try:
-                         # The child has a prior claim, so we can't have it
-                         lock.acquire_nowait()
-                     except trio.WouldBlock:
-                         assert lock._owner is not trio.lowlevel.current_task()
-                         print("PASS")
-                     else:
-                         print("FAIL")
-
-    """
+                try:
+                    # The child has a prior claim, so we can't have it
+                    lock.acquire_nowait()
+                except trio.WouldBlock:
+                    assert lock._owner is not trio.lowlevel.current_task()
+                    print("PASS")
+                else:
+                    print("FAIL")
+
+    """
     locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
     try:
         return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion)
     except AttributeError:
         raise RuntimeError("must be called from async context")
-
-
-# fmt: on
diff --git a/trio/_core/_instrumentation.py b/trio/_core/_instrumentation.py
index b133d47406..a0757a5b83 100644
--- a/trio/_core/_instrumentation.py
+++ b/trio/_core/_instrumentation.py
@@ -1,7 +1,6 @@
 import logging
 import types
-import attr
-from typing import Any, Callable, Dict, List, Sequence, Iterator, TypeVar
+from typing import Any, Callable, Dict, Sequence, TypeVar

 from .._abc import Instrument
diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py
index 9891849bc9..c1af293278 100644
--- a/trio/_core/_io_common.py
+++ b/trio/_core/_io_common.py
@@ -1,10 +1,18 @@
+from __future__ import annotations
+
 import copy
+from typing import TYPE_CHECKING
+
 import outcome
+
 from .. import _core

+if TYPE_CHECKING:
+    from ._io_epoll import EpollWaiters
+

 # Utility function shared between _io_epoll and _io_windows
-def wake_all(waiters, exc):
+def wake_all(waiters: EpollWaiters, exc: BaseException) -> None:
     try:
         current_task = _core.current_task()
     except RuntimeError:
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index c1537cf53e..0d247cae64 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -1,22 +1,43 @@
+from __future__ import annotations
+
 import select
 import sys
-import attr
 from collections import defaultdict
-from typing import Dict, TYPE_CHECKING
+from typing import TYPE_CHECKING, DefaultDict, Literal
+
+import attr

 from .. import _core
-from ._run import _public
 from ._io_common import wake_all
+from ._run import Task, _public
 from ._wakeup_socketpair import WakeupSocketpair

+if TYPE_CHECKING:
+    from socket import socket
+
+    from typing_extensions import TypeAlias
+
+    from .._core import Abort, RaiseCancelT
+
+
+@attr.s(slots=True, eq=False)
+class EpollWaiters:
+    read_task: Task | None = attr.ib(default=None)
+    write_task: Task | None = attr.ib(default=None)
+    current_flags: int = attr.ib(default=0)
+
+
 assert not TYPE_CHECKING or sys.platform == "linux"

+EventResult: TypeAlias = "list[tuple[int, int]]"
+

 @attr.s(slots=True, eq=False, frozen=True)
 class _EpollStatistics:
-    tasks_waiting_read = attr.ib()
-    tasks_waiting_write = attr.ib()
-    backend = attr.ib(default="epoll")
+    tasks_waiting_read: int = attr.ib()
+    tasks_waiting_write: int = attr.ib()
+    backend: Literal["epoll"] = attr.ib(init=False, default="epoll")


 # Some facts about epoll
@@ -177,28 +198,21 @@ class _EpollStatistics:
 # wanted to about how epoll works.


-@attr.s(slots=True, eq=False)
-class EpollWaiters:
-    read_task = attr.ib(default=None)
-    write_task = attr.ib(default=None)
-    current_flags = attr.ib(default=0)
-
-
 @attr.s(slots=True, eq=False, hash=False)
 class EpollIOManager:
-    _epoll = attr.ib(factory=select.epoll)
+    _epoll: select.epoll = attr.ib(factory=select.epoll)
     # {fd: EpollWaiters}
-    _registered = attr.ib(
-        factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
+    _registered: DefaultDict[int, EpollWaiters] = attr.ib(
+        factory=lambda: defaultdict(EpollWaiters)
     )
-    _force_wakeup = attr.ib(factory=WakeupSocketpair)
-    _force_wakeup_fd = attr.ib(default=None)
+    _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+    _force_wakeup_fd: int | None = attr.ib(default=None)

-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
         self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()

-    def statistics(self):
+    def statistics(self) -> _EpollStatistics:
         tasks_waiting_read = 0
         tasks_waiting_write = 0
         for waiter in self._registered.values():
@@ -211,24 +225,24 @@ def statistics(self):
             tasks_waiting_write=tasks_waiting_write,
         )

-    def close(self):
+    def close(self) -> None:
         self._epoll.close()
         self._force_wakeup.close()

-    def force_wakeup(self):
+    def force_wakeup(self) -> None:
         self._force_wakeup.wakeup_thread_and_signal_safe()

     # Return value must be False-y IFF the timeout expired, NOT if any I/O
     # happened or force_wakeup was called. Otherwise it can be anything; gets
     # passed straight through to process_events.
-    def get_events(self, timeout):
+    def get_events(self, timeout: float) -> EventResult:
         # max_events must be > 0 or epoll gets cranky
         # accessing self._registered from a thread looks dangerous, but it's
         # OK because it doesn't matter if our value is a little bit off.
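         # (A stale length just means we ask epoll for slightly fewer or more
         # events than ideal on this pass, which is harmless.)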
         max_events = max(1, len(self._registered))
         return self._epoll.poll(timeout, max_events)

-    def process_events(self, events):
+    def process_events(self, events: EventResult) -> None:
         for fd, flags in events:
             if fd == self._force_wakeup_fd:
                 self._force_wakeup.drain()
@@ -247,7 +261,7 @@ def process_events(self, events):
                     waiters.read_task = None
             self._update_registrations(fd)

-    def _update_registrations(self, fd):
+    def _update_registrations(self, fd: int) -> None:
         waiters = self._registered[fd]
         wanted_flags = 0
         if waiters.read_task is not None:
@@ -276,7 +290,7 @@ def _update_registrations(self, fd):
         if not wanted_flags:
             del self._registered[fd]

-    async def _epoll_wait(self, fd, attr_name):
+    async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
         if not isinstance(fd, int):
             fd = fd.fileno()
         waiters = self._registered[fd]
@@ -287,7 +301,7 @@ async def _epoll_wait(self, fd, attr_name):
         setattr(waiters, attr_name, _core.current_task())
         self._update_registrations(fd)

-        def abort(_):
+        def abort(_: RaiseCancelT) -> Abort:
             setattr(waiters, attr_name, None)
             self._update_registrations(fd)
             return _core.Abort.SUCCEEDED
@@ -295,15 +309,15 @@ def abort(_):
         await _core.wait_task_rescheduled(abort)

     @_public
-    async def wait_readable(self, fd):
+    async def wait_readable(self, fd: int | socket) -> None:
         await self._epoll_wait(fd, "read_task")

     @_public
-    async def wait_writable(self, fd):
+    async def wait_writable(self, fd: int | socket) -> None:
         await self._epoll_wait(fd, "write_task")

     @_public
-    def notify_closing(self, fd):
+    def notify_closing(self, fd: int | socket) -> None:
         if not isinstance(fd, int):
             fd = fd.fileno()
         wake_all(
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index 31940d5694..56a6559091 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -1,42 +1,55 @@
+from __future__ import annotations
+
+import errno
 import select
 import sys
-from typing import TYPE_CHECKING
-
-import outcome
 from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Iterator, Literal
+
 import attr
-import errno
+import outcome

 from .. import _core
 from ._run import _public
 from ._wakeup_socketpair import WakeupSocketpair

+if TYPE_CHECKING:
+    from socket import socket
+
+    from typing_extensions import TypeAlias
+
+    from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
+
 assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")

+EventResult: TypeAlias = "list[select.kevent]"
+

 @attr.s(slots=True, eq=False, frozen=True)
 class _KqueueStatistics:
-    tasks_waiting = attr.ib()
-    monitors = attr.ib()
-    backend = attr.ib(default="kqueue")
+    tasks_waiting: int = attr.ib()
+    monitors: int = attr.ib()
+    backend: Literal["kqueue"] = attr.ib(init=False, default="kqueue")


 @attr.s(slots=True, eq=False)
 class KqueueIOManager:
-    _kqueue = attr.ib(factory=select.kqueue)
+    _kqueue: select.kqueue = attr.ib(factory=select.kqueue)
     # {(ident, filter): Task or UnboundedQueue}
-    _registered = attr.ib(factory=dict)
-    _force_wakeup = attr.ib(factory=WakeupSocketpair)
-    _force_wakeup_fd = attr.ib(default=None)
+    _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib(
+        factory=dict
+    )
+    _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+    _force_wakeup_fd: int | None = attr.ib(default=None)

-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         force_wakeup_event = select.kevent(
             self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD
         )
         self._kqueue.control([force_wakeup_event], 0)
         self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()

-    def statistics(self):
+    def statistics(self) -> _KqueueStatistics:
         tasks_waiting = 0
         monitors = 0
         for receiver in self._registered.values():
@@ -46,14 +59,14 @@ def statistics(self):
             monitors += 1
         return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)

-    def close(self):
+    def close(self) -> None:
         self._kqueue.close()
         self._force_wakeup.close()

-    def force_wakeup(self):
+    def force_wakeup(self) -> None:
         self._force_wakeup.wakeup_thread_and_signal_safe()

-    def get_events(self, timeout):
+    def get_events(self, timeout: float) -> EventResult:
         # max_events must be > 0 or kqueue gets cranky
         # and we generally want this to be strictly larger than the actual
         # number of events we get, so that we can tell that we've gotten
@@ -70,7 +83,7 @@ def get_events(self, timeout):
         # and loop back to the start
         return events

-    def process_events(self, events):
+    def process_events(self, events: EventResult) -> None:
         for event in events:
             key = (event.ident, event.filter)
             if event.ident == self._force_wakeup_fd:
@@ -79,7 +92,7 @@ def process_events(self, events):
             receiver = self._registered[key]
             if event.flags & select.KQ_EV_ONESHOT:
                 del self._registered[key]
-            if type(receiver) is _core.Task:
+            if isinstance(receiver, _core.Task):
                 _core.reschedule(receiver, outcome.Value(event))
             else:
                 receiver.put_nowait(event)
@@ -96,18 +109,20 @@ def process_events(self, events):
     # be more ergonomic...
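     # (For now, the raw kqueue object and the monitor/wait helpers below are
     # exposed as-is.)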
     @_public
-    def current_kqueue(self):
+    def current_kqueue(self) -> select.kqueue:
         return self._kqueue

     @contextmanager
     @_public
-    def monitor_kevent(self, ident, filter):
+    def monitor_kevent(
+        self, ident: int, filter: int
+    ) -> Iterator[_core.UnboundedQueue[select.kevent]]:
         key = (ident, filter)
         if key in self._registered:
             raise _core.BusyResourceError(
                 "attempt to register multiple listeners for same ident/filter pair"
             )
-        q = _core.UnboundedQueue()
+        q = _core.UnboundedQueue[select.kevent]()
         self._registered[key] = q
         try:
             yield q
@@ -115,7 +130,9 @@ def monitor_kevent(self, ident, filter):
             del self._registered[key]

     @_public
-    async def wait_kevent(self, ident, filter, abort_func):
+    async def wait_kevent(
+        self, ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort]
+    ) -> Abort:
         key = (ident, filter)
         if key in self._registered:
             raise _core.BusyResourceError(
@@ -123,22 +140,23 @@ async def wait_kevent(self, ident, filter, abort_func):
             )
         self._registered[key] = _core.current_task()

-        def abort(raise_cancel):
+        def abort(raise_cancel: RaiseCancelT) -> Abort:
             r = abort_func(raise_cancel)
             if r is _core.Abort.SUCCEEDED:
                 del self._registered[key]
             return r

-        return await _core.wait_task_rescheduled(abort)
+        # wait_task_rescheduled does not have its return type typed
+        return await _core.wait_task_rescheduled(abort)  # type: ignore[no-any-return]

-    async def _wait_common(self, fd, filter):
+    async def _wait_common(self, fd: int | socket, filter: int) -> None:
         if not isinstance(fd, int):
             fd = fd.fileno()
         flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
         event = select.kevent(fd, filter, flags)
         self._kqueue.control([event], 0)

-        def abort(_):
+        def abort(_: RaiseCancelT) -> Abort:
             event = select.kevent(fd, filter, select.KQ_EV_DELETE)
             try:
                 self._kqueue.control([event], 0)
@@ -163,15 +181,15 @@ def abort(_):
         await self.wait_kevent(fd, filter, abort)

     @_public
-    async def wait_readable(self, fd):
+    async def wait_readable(self, fd: int | socket) -> None:
         await self._wait_common(fd, select.KQ_FILTER_READ)

     @_public
-    async def wait_writable(self, fd):
+    async def wait_writable(self, fd: int | socket) -> None:
         await self._wait_common(fd, select.KQ_FILTER_WRITE)

     @_public
-    def notify_closing(self, fd):
+    def notify_closing(self, fd: int | socket) -> None:
         if not isinstance(fd, int):
             fd = fd.fileno()
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index 9b5ebfc268..9757d25b5f 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -1,35 +1,40 @@
-import itertools
-from contextlib import contextmanager
+from __future__ import annotations
+
 import enum
+import itertools
 import socket
 import sys
-from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Literal

 import attr
 from outcome import Value

 from .. import _core
-from ._run import _public
 from ._io_common import wake_all
-
+from ._run import _public
 from ._windows_cffi import (
-    ffi,
-    kernel32,
-    ntdll,
-    ws2_32,
     INVALID_HANDLE_VALUE,
-    raise_winerror,
-    _handle,
-    ErrorCodes,
-    FileFlags,
     AFDPollFlags,
-    WSAIoctls,
     CompletionModes,
+    ErrorCodes,
+    FileFlags,
     IoControlCodes,
+    WSAIoctls,
+    _handle,
+    ffi,
+    kernel32,
+    ntdll,
+    raise_winerror,
+    ws2_32,
 )

 assert not TYPE_CHECKING or sys.platform == "win32"

+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
+EventResult: TypeAlias = int
+
 # There's a lot to be said about the overall design of a Windows event
 # loop.  See
 #
@@ -366,11 +371,11 @@ class AFDGroup:

 @attr.s(slots=True, eq=False, frozen=True)
 class _WindowsStatistics:
-    tasks_waiting_read = attr.ib()
-    tasks_waiting_write = attr.ib()
-    tasks_waiting_overlapped = attr.ib()
-    completion_key_monitors = attr.ib()
-    backend = attr.ib(default="windows")
+    tasks_waiting_read: int = attr.ib()
+    tasks_waiting_write: int = attr.ib()
+    tasks_waiting_overlapped: int = attr.ib()
+    completion_key_monitors: int = attr.ib()
+    backend: Literal["windows"] = attr.ib(init=False, default="windows")


 # Maximum number of events to dequeue from the completion port on each pass
@@ -486,7 +491,7 @@ def force_wakeup(self):
             )
         )

-    def get_events(self, timeout):
+    def get_events(self, timeout: float) -> EventResult:
         received = ffi.new("PULONG")
         milliseconds = round(1000 * timeout)
         if timeout > 0 and milliseconds == 0:
@@ -501,9 +506,11 @@ def get_events(self, timeout):
             if exc.winerror != ErrorCodes.WAIT_TIMEOUT:  # pragma: no cover
                 raise
             return 0
-        return received[0]
+        result = received[0]
+        assert isinstance(result, int)
+        return result

-    def process_events(self, received):
+    def process_events(self, received: EventResult) -> None:
         for i in range(received):
             entry = self._events[i]
             if entry.lpCompletionKey == CKeys.AFD_POLL:
diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py
index fec23863f1..10172e4989 100644
--- a/trio/_core/_ki.py
+++ b/trio/_core/_ki.py
@@ -3,17 +3,21 @@
 import inspect
 import signal
 import sys
+import types
+from collections.abc import Callable
 from functools import wraps
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Final, TypeVar

 import attr

 from .._util import is_main_thread

+RetT = TypeVar("RetT")
+
 if TYPE_CHECKING:
-    from typing import Any, TypeVar, Callable
+    from typing_extensions import ParamSpec, TypeGuard

-    F = TypeVar("F", bound=Callable[..., Any])
+    ArgsT = ParamSpec("ArgsT")

 # In ordinary single-threaded Python code, when you hit control-C, it raises
 # an exception and automatically does all the regular unwinding stuff.
@@ -80,22 +84,22 @@
 # We use this special string as a unique key into the frame locals dictionary.
 # The @ ensures it is not a valid identifier and can't clash with any possible
 # real local name. See: https://github.com/python-trio/trio/issues/469
-LOCALS_KEY_KI_PROTECTION_ENABLED = "@TRIO_KI_PROTECTION_ENABLED"
+LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED"


 # NB: according to the signal.signal docs, 'frame' can be None on entry to
 # this function:
-def ki_protection_enabled(frame):
+def ki_protection_enabled(frame: types.FrameType | None) -> bool:
     while frame is not None:
         if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
-            return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]
+            return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED])
         if frame.f_code.co_name == "__del__":
             return True
         frame = frame.f_back
     return True


-def currently_ki_protected():
+def currently_ki_protected() -> bool:
     r"""Check whether the calling code has :exc:`KeyboardInterrupt`
     protection enabled.
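A note on the mechanism in this hunk: ``ki_protection_enabled`` walks outward
through ``frame.f_back`` until some frame's ``f_locals`` contains the
``@TRIO_KI_PROTECTION_ENABLED`` key. Here is a minimal standalone sketch of
the same frame-walking trick -- the names ``FLAG``, ``flag_enabled`` and
``protected`` are hypothetical, and it relies on the CPython (pre-3.13)
implementation detail that extra keys written into ``locals()`` stay visible
through ``frame.f_locals``::

    from __future__ import annotations

    import sys
    import types

    FLAG = "@EXAMPLE_FLAG"  # the "@" makes it an impossible identifier, as above

    def flag_enabled(frame: types.FrameType | None) -> bool:
        # Walk outward through calling frames until one declares the flag.
        while frame is not None:
            if FLAG in frame.f_locals:
                return bool(frame.f_locals[FLAG])
            frame = frame.f_back
        return False

    def protected() -> None:
        locals()[FLAG] = True  # stash the flag in this frame's locals
        assert flag_enabled(sys._getframe())

    protected()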
@@ -115,29 +119,35 @@ def currently_ki_protected(): # functions decorated @async_generator are given this magic property that's a # reference to the object itself # see python-trio/async_generator/async_generator/_impl.py -def legacy_isasyncgenfunction(obj): +def legacy_isasyncgenfunction( + obj: object, +) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: return getattr(obj, "_async_gen_function", None) == id(obj) -def _ki_protection_decorator(enabled): - def decorator(fn): +def _ki_protection_decorator( + enabled: bool, +) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: + # The "ignore[return-value]" below is because the inspect functions cast away the + # original return type of fn, making it just CoroutineType[Any, Any, Any] etc. + def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: # In some version of Python, isgeneratorfunction returns true for # coroutine functions, so we have to check for coroutine functions # first. if inspect.iscoroutinefunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # See the comment for regular generators below coro = fn(*args, **kwargs) coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return coro + return coro # type: ignore[return-value] return wrapper elif inspect.isgeneratorfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # It's important that we inject this directly into the # generator's locals, as opposed to setting it here and then # doing 'yield from'. The reason is, if a generator is @@ -148,23 +158,23 @@ def wrapper(*args, **kwargs): # https://bugs.python.org/issue29590 gen = fn(*args, **kwargs) gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return gen + return gen # type: ignore[return-value] return wrapper elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # See the comment for regular generators above agen = fn(*args, **kwargs) agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return agen + return agen # type: ignore[return-value] return wrapper else: @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return fn(*args, **kwargs) @@ -173,18 +183,28 @@ def wrapper(*args, **kwargs): return decorator -enable_ki_protection: Callable[[F], F] = _ki_protection_decorator(True) +enable_ki_protection: Callable[ + [Callable[ArgsT, RetT]], Callable[ArgsT, RetT] +] = _ki_protection_decorator(True) enable_ki_protection.__name__ = "enable_ki_protection" -disable_ki_protection: Callable[[F], F] = _ki_protection_decorator(False) +disable_ki_protection: Callable[ + [Callable[ArgsT, RetT]], Callable[ArgsT, RetT] +] = _ki_protection_decorator(False) disable_ki_protection.__name__ = "disable_ki_protection" @attr.s class KIManager: - handler = attr.ib(default=None) - - def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): + handler: Callable[[int, types.FrameType | None], None] | None = attr.ib( + default=None + ) + + def install( + self, + deliver_cb: Callable[[], object], + restrict_keyboard_interrupt_to_checkpoints: bool, + ) -> None: assert self.handler is None if ( not is_main_thread() @@ -192,7 +212,7 @@ def install(self, deliver_cb, 
restrict_keyboard_interrupt_to_checkpoints): ): return - def handler(signum, frame): + def handler(signum: int, frame: types.FrameType | None) -> None: assert signum == signal.SIGINT protection_enabled = ki_protection_enabled(frame) if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: @@ -203,7 +223,7 @@ def handler(signum, frame): self.handler = handler signal.signal(signal.SIGINT, handler) - def close(self): + def close(self) -> None: if self.handler is not None: if signal.getsignal(signal.SIGINT) is self.handler: signal.signal(signal.SIGINT, signal.default_int_handler) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index f898a13cff..8286a5578f 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,26 +1,34 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, cast, final + # Runvar implementations import attr +from .._util import Final, NoPublicConstructor from . import _run -from .._util import Final +T = TypeVar("T") + +@final +class _NoValue(metaclass=Final): + ... -@attr.s(eq=False, hash=False, slots=True) -class _RunVarToken: - _no_value = object() - _var = attr.ib() - previous_value = attr.ib(default=_no_value) - redeemed = attr.ib(default=False, init=False) +@attr.s(eq=False, hash=False, slots=False) +class RunVarToken(Generic[T], metaclass=NoPublicConstructor): + _var: RunVar[T] = attr.ib() + previous_value: T | type[_NoValue] = attr.ib(default=_NoValue) + redeemed: bool = attr.ib(default=False, init=False) @classmethod - def empty(cls, var): - return cls(var) + def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: + return cls._create(var) -@attr.s(eq=False, hash=False, slots=True) -class RunVar(metaclass=Final): +@attr.s(eq=False, hash=False, slots=True, repr=False) +class RunVar(Generic[T], metaclass=Final): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, @@ -29,27 +37,27 @@ class RunVar(metaclass=Final): """ - _NO_DEFAULT = object() - _name = attr.ib() - _default = attr.ib(default=_NO_DEFAULT) + _name: str = attr.ib() + _default: T | type[_NoValue] = attr.ib(default=_NoValue) - def get(self, default=_NO_DEFAULT): + def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] + return cast(T, _run.GLOBAL_RUN_CONTEXT.runner._locals[self]) except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency - if default is not self._NO_DEFAULT: - return default + # `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released + if default is not _NoValue: + return default # type: ignore[return-value] - if self._default is not self._NO_DEFAULT: - return self._default + if self._default is not _NoValue: + return self._default # type: ignore[return-value] raise LookupError(self) from None - def set(self, value): + def set(self, value: T) -> RunVarToken[T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -57,16 +65,16 @@ def set(self, value): try: old_value = self.get() except LookupError: - token = _RunVarToken.empty(self) + token = RunVarToken._empty(self) else: - token = _RunVarToken(self, old_value) + token = RunVarToken[T]._create(self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. 
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value return token - def reset(self, token): + def reset(self, token: RunVarToken[T]) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. @@ -82,7 +90,7 @@ def reset(self, token): previous = token.previous_value try: - if previous is _RunVarToken._no_value: + if previous is _NoValue: _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) else: _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous @@ -91,5 +99,5 @@ def reset(self, token): token.redeemed = True - def __repr__(self): + def __repr__(self) -> str: return f"<RunVar name={self._name!r}>" diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index 0e95e4e5c5..27a5829076 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -2,9 +2,9 @@ from math import inf from .. import _core -from ._run import GLOBAL_RUN_CONTEXT from .._abc import Clock from .._util import Final +from ._run import GLOBAL_RUN_CONTEXT ################################################################ # The glorious MockClock @@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final): """ - def __init__(self, rate=0.0, autojump_threshold=inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. @@ -77,17 +77,17 @@ def __init__(self, rate=0.0, autojump_threshold=inf): self.rate = rate self.autojump_threshold = autojump_threshold - def __repr__(self): + def __repr__(self) -> str: return "<MockClock, time={:.7f}, rate={} @ {:#x}>".format( self.current_time(), self._rate, id(self) ) @property - def rate(self): + def rate(self) -> float: return self._rate @rate.setter - def rate(self, new_rate): + def rate(self, new_rate: float) -> None: if new_rate < 0: raise ValueError("rate must be >= 0") else: @@ -98,11 +98,11 @@ def rate(self, new_rate): self._rate = float(new_rate) @property - def autojump_threshold(self): + def autojump_threshold(self) -> float: return self._autojump_threshold @autojump_threshold.setter - def autojump_threshold(self, new_autojump_threshold): + def autojump_threshold(self, new_autojump_threshold: float) -> None: self._autojump_threshold = float(new_autojump_threshold) self._try_resync_autojump_threshold() @@ -112,7 +112,7 @@ def autojump_threshold(self, new_autojump_threshold): # API. Discussion: # # https://github.com/python-trio/trio/issues/1587 - def _try_resync_autojump_threshold(self): + def _try_resync_autojump_threshold(self) -> None: try: runner = GLOBAL_RUN_CONTEXT.runner if runner.is_guest: @@ -124,24 +124,24 @@ def _try_resync_autojump_threshold(self): # Invoked by the run loop when runner.clock_autojump_threshold is # exceeded.
- def _autojump(self): + def _autojump(self) -> None: statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if 0 < jump < inf: self.jump(jump) - def _real_to_virtual(self, real): + def _real_to_virtual(self, real: float) -> float: real_offset = real - self._real_base virtual_offset = self._rate * real_offset return self._virtual_base + virtual_offset - def start_clock(self): + def start_clock(self) -> None: self._try_resync_autojump_threshold() - def current_time(self): + def current_time(self) -> float: return self._real_to_virtual(self._real_clock()) - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: virtual_timeout = deadline - self.current_time() if virtual_timeout <= 0: return 0 @@ -150,7 +150,7 @@ def deadline_to_sleep_time(self, deadline): else: return 999999999 - def jump(self, seconds): + def jump(self, seconds: float) -> None: """Manually advance the clock by the given number of seconds. Args: diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index a9778fd244..d55e89554d 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import sys import warnings -from typing import Sequence +from collections.abc import Callable, Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Any, cast, overload import attr @@ -11,12 +15,16 @@ else: from traceback import print_exception +if TYPE_CHECKING: + from typing_extensions import Self ################################################################ # MultiError ################################################################ -def _filter_impl(handler, root_exc): +def _filter_impl( + handler: Callable[[BaseException], BaseException | None], root_exc: BaseException +) -> BaseException | None: # We have a tree of MultiError's, like: # # MultiError([ @@ -75,7 +83,9 @@ def _filter_impl(handler, root_exc): # Filters a subtree, ignoring tracebacks, while keeping a record of # which MultiErrors were preserved unchanged - def filter_tree(exc, preserved): + def filter_tree( + exc: MultiError | BaseException, preserved: set[int] + ) -> MultiError | BaseException | None: if isinstance(exc, MultiError): new_exceptions = [] changed = False @@ -99,7 +109,9 @@ def filter_tree(exc, preserved): new_exc.__context__ = exc return new_exc - def push_tb_down(tb, exc, preserved): + def push_tb_down( + tb: TracebackType | None, exc: BaseException, preserved: set[int] + ) -> None: if id(exc) in preserved: return new_tb = concat_tb(tb, exc.__traceback__) @@ -110,7 +122,7 @@ def push_tb_down(tb, exc, preserved): else: exc.__traceback__ = new_tb - preserved = set() + preserved: set[int] = set() new_root_exc = filter_tree(root_exc, preserved) push_tb_down(None, root_exc, preserved) # Delete the local functions to avoid a reference cycle (see @@ -126,16 +138,21 @@ def push_tb_down(tb, exc, preserved): # frame show up in the traceback; otherwise, we leave no trace.) 
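# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff): what the handler-based filtering
# implemented by _filter_impl above looks like from the public MultiError
# API. The handler runs once per leaf exception; returning None drops that
# leaf, returning an exception keeps (or replaces) it, and a MultiError left
# holding a single child is collapsed into that child. A minimal,
# hypothetical example, assuming this version of trio is installed
# (MultiError.filter may also emit a deprecation warning on newer releases):
from __future__ import annotations

from trio import MultiError

def drop_key_errors(exc: BaseException) -> BaseException | None:
    # Keep every leaf except KeyError.
    return None if isinstance(exc, KeyError) else exc

tree = MultiError([ValueError("a"), KeyError("b")])
remaining = MultiError.filter(drop_key_errors, tree)
# Only one leaf survives, so collapse semantics hand it back directly:
assert isinstance(remaining, ValueError)
# ---------------------------------------------------------------------------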
@attr.s(frozen=True) class MultiErrorCatcher: - _handler = attr.ib() + _handler: Callable[[BaseException], BaseException | None] = attr.ib() - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, etype, exc, tb): - if exc is not None: - filtered_exc = _filter_impl(self._handler, exc) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + if exc_value is not None: + filtered_exc = _filter_impl(self._handler, exc_value) - if filtered_exc is exc: + if filtered_exc is exc_value: # Let the interpreter re-raise it return False if filtered_exc is None: @@ -155,9 +172,16 @@ def __exit__(self, etype, exc, tb): # delete references from locals to avoid creating cycles # see test_MultiError_catch_doesnt_create_cyclic_garbage del _, filtered_exc, value + return False + + +if TYPE_CHECKING: + _BaseExceptionGroup = BaseExceptionGroup[BaseException] +else: + _BaseExceptionGroup = BaseExceptionGroup -class MultiError(BaseExceptionGroup): +class MultiError(_BaseExceptionGroup): """An exception that contains other exceptions; also known as an "inception". @@ -180,7 +204,9 @@ class MultiError(BaseExceptionGroup): """ - def __init__(self, exceptions, *, _collapse=True): + def __init__( + self, exceptions: Sequence[BaseException], *, _collapse: bool = True + ) -> None: self.collapse = _collapse # Avoid double initialization when _collapse is True and exceptions[0] returned @@ -191,7 +217,9 @@ def __init__(self, exceptions, *, _collapse=True): super().__init__("multiple tasks failed", exceptions) - def __new__(cls, exceptions, *, _collapse=True): + def __new__( # type: ignore[misc] # mypy says __new__ must return a class instance + cls, exceptions: Sequence[BaseException], *, _collapse: bool = True + ) -> NonBaseMultiError | Self | BaseException: exceptions = list(exceptions) for exc in exceptions: if not isinstance(exc, BaseException): @@ -208,26 +236,54 @@ def __new__(cls, exceptions, *, _collapse=True): # In an earlier version of the code, we didn't define __init__ and # simply set the `exceptions` attribute directly on the new object. # However, linters expect attributes to be initialized in __init__. + from_class: type[Self] | type[NonBaseMultiError] = cls if all(isinstance(exc, Exception) for exc in exceptions): - cls = NonBaseMultiError - - return super().__new__(cls, "multiple tasks failed", exceptions) + from_class = NonBaseMultiError + + # Ignoring arg-type: 'Argument 3 to "__new__" of "BaseExceptionGroup" has incompatible type "list[BaseException]"; expected "Sequence[_BaseExceptionT_co]"' + # We have checked that exceptions is indeed a list of BaseException objects, this is fine. + new_obj = super().__new__(from_class, "multiple tasks failed", exceptions) # type: ignore[arg-type] + assert isinstance(new_obj, (cls, NonBaseMultiError)) + return new_obj + + def __reduce__( + self, + ) -> tuple[object, tuple[type[Self], list[BaseException]], dict[str, bool]]: + return ( + self.__new__, + (self.__class__, list(self.exceptions)), + {"collapse": self.collapse}, + ) - def __str__(self): + def __str__(self) -> str: return ", ".join(repr(exc) for exc in self.exceptions) - def __repr__(self): + def __repr__(self) -> str: return f"<MultiError: {self}>" - def derive(self, __excs): + @overload # type: ignore[override] # 'Exception' != '_ExceptionT' + def derive(self, excs: Sequence[Exception], /) -> NonBaseMultiError: + ... + + @overload + def derive(self, excs: Sequence[BaseException], /) -> MultiError: + ...
+ + def derive( + self, excs: Sequence[Exception | BaseException], / + ) -> NonBaseMultiError | MultiError: # We use _collapse=False here to get ExceptionGroup semantics, since derive() # is part of the PEP 654 API - exc = MultiError(__excs, _collapse=False) + exc = MultiError(excs, _collapse=False) exc.collapse = self.collapse return exc @classmethod - def filter(cls, handler, root_exc): + def filter( + cls, + handler: Callable[[BaseException], BaseException | None], + root_exc: BaseException, + ) -> BaseException | None: """Apply the given ``handler`` to all the exceptions in ``root_exc``. Args: @@ -251,7 +307,9 @@ def filter(cls, handler, root_exc): return _filter_impl(handler, root_exc) @classmethod - def catch(cls, handler): + def catch( + cls, handler: Callable[[BaseException], BaseException | None] + ) -> MultiErrorCatcher: """Return a context manager that catches and re-throws exceptions after running :meth:`filter` on them. @@ -269,8 +327,14 @@ def catch(cls, handler): return MultiErrorCatcher(handler) -class NonBaseMultiError(MultiError, ExceptionGroup): - pass +if TYPE_CHECKING: + _ExceptionGroup = ExceptionGroup[Exception] +else: + _ExceptionGroup = ExceptionGroup + + +class NonBaseMultiError(MultiError, _ExceptionGroup): + __slots__ = () # Clean up exception printing: @@ -299,30 +363,6 @@ class NonBaseMultiError(MultiError, ExceptionGroup): try: import tputil except ImportError: - have_tproxy = False -else: - have_tproxy = True - -if have_tproxy: - # http://doc.pypy.org/en/latest/objspace-proxies.html - def copy_tb(base_tb, tb_next): - def controller(operation): - # Rationale for pragma: I looked fairly carefully and tried a few - # things, and AFAICT it's not actually possible to get any - # 'opname' that isn't __getattr__ or __getattribute__. So there's - # no missing test we could add, and no value in coverage nagging - # us about adding one. - if operation.opname in [ - "__getattribute__", - "__getattr__", - ]: # pragma: no cover - if operation.args[0] == "tb_next": - return tb_next - return operation.delegate() - - return tputil.make_proxy(controller, type(base_tb), base_tb) - -else: # ctypes it is import ctypes @@ -342,12 +382,13 @@ class CTraceback(ctypes.Structure): ("tb_lineno", ctypes.c_int), ] - def copy_tb(base_tb, tb_next): + def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: # TracebackType has no public constructor, so allocate one the hard way try: raise ValueError except ValueError as exc: new_tb = exc.__traceback__ + assert new_tb is not None c_new_tb = CTraceback.from_address(id(new_tb)) # At the C level, tb_next either pointer to the next traceback or is @@ -360,14 +401,14 @@ def copy_tb(base_tb, tb_next): # which it already is, so we're done. 
Otherwise, we have to actually # do some work: if tb_next is not None: - _ctypes.Py_INCREF(tb_next) + _ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined] c_new_tb.tb_next = id(tb_next) assert c_new_tb.tb_frame is not None - _ctypes.Py_INCREF(base_tb.tb_frame) + _ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined] old_tb_frame = new_tb.tb_frame c_new_tb.tb_frame = id(base_tb.tb_frame) - _ctypes.Py_DECREF(old_tb_frame) + _ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined] c_new_tb.tb_lasti = base_tb.tb_lasti c_new_tb.tb_lineno = base_tb.tb_lineno @@ -379,8 +420,33 @@ def copy_tb(base_tb, tb_next): # see test_MultiError_catch_doesnt_create_cyclic_garbage del new_tb, old_tb_frame +else: + # http://doc.pypy.org/en/latest/objspace-proxies.html + def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: + # Mypy refuses to believe that ProxyOperation can be imported properly + # TODO: will need no-any-unimported if/when that's toggled on + def controller(operation: tputil.ProxyOperation) -> Any | None: + # Rationale for pragma: I looked fairly carefully and tried a few + # things, and AFAICT it's not actually possible to get any + # 'opname' that isn't __getattr__ or __getattribute__. So there's + # no missing test we could add, and no value in coverage nagging + # us about adding one. + if operation.opname in [ + "__getattribute__", + "__getattr__", + ]: # pragma: no cover + if operation.args[0] == "tb_next": + return tb_next + return operation.delegate() # Delegate reverts to the original behaviour + + return cast( + TracebackType, tputil.make_proxy(controller, type(base_tb), base_tb) + ) # Returns a proxy to the traceback + -def concat_tb(head, tail): +def concat_tb( + head: TracebackType | None, tail: TracebackType | None +) -> TracebackType | None: # We have to use an iterative algorithm here, because in the worst case # this might be a RecursionError stack that is by definition too deep to # process by recursion! @@ -412,7 +478,13 @@ def concat_tb(head, tail): ) else: - def trio_show_traceback(self, etype, value, tb, tb_offset=None): + def trio_show_traceback( + self: IPython.core.interactiveshell.InteractiveShell, + etype: type[BaseException], + value: BaseException, + tb: TracebackType, + tb_offset: int | None = None, + ) -> None: # XX it would be better to integrate with IPython's fancy # exception formatting stuff (and not ignore tb_offset) print_exception(value) @@ -443,10 +515,14 @@ def trio_show_traceback(self, etype, value, tb, tb_offset=None): assert sys.excepthook is apport_python_hook.apport_excepthook - def replacement_excepthook(etype, value, tb): - sys.stderr.write("".join(format_exception(etype, value, tb))) + def replacement_excepthook( + etype: type[BaseException], value: BaseException, tb: TracebackType | None + ) -> None: + # This does work; it's an overloaded function + sys.stderr.write("".join(format_exception(etype, value, tb))) # type: ignore[arg-type] fake_sys = ModuleType("trio_fake_sys") fake_sys.__dict__.update(sys.__dict__) - fake_sys.__excepthook__ = replacement_excepthook # type: ignore + # Fake does have __excepthook__ after __dict__ update, but type checkers don't recognize this + fake_sys.__excepthook__ = replacement_excepthook # type: ignore[attr-defined] apport_python_hook.sys = fake_sys diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index f38123540f..6510745e5b 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -69,17 +69,34 @@ # unpark is called.
# # See: https://github.com/python-trio/trio/issues/53 +from __future__ import annotations -import attr +import math from collections import OrderedDict +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import attr from .. import _core from .._util import Final +if TYPE_CHECKING: + from ._run import Task + @attr.s(frozen=True, slots=True) -class _ParkingLotStatistics: - tasks_waiting = attr.ib() +class ParkingLotStatistics: + """An object containing debugging information for a ParkingLot. + + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this lot's + :meth:`trio.lowlevel.ParkingLot.park` method. + + """ + + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False, slots=True) @@ -98,13 +115,13 @@ class ParkingLot(metaclass=Final): # {task: None}, we just want a deque where we can quickly delete random # items - _parked = attr.ib(factory=OrderedDict, init=False) + _parked: OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) - def __len__(self): + def __len__(self) -> int: """Returns the number of parked tasks.""" return len(self._parked) - def __bool__(self): + def __bool__(self) -> bool: """True if there are parked tasks, False otherwise.""" return bool(self._parked) @@ -113,7 +130,7 @@ def __bool__(self): # line (for false wakeups), then we could have it return a ticket that # abstracts the "place in line" concept. @_core.enable_ki_protection - async def park(self): + async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. @@ -122,19 +139,26 @@ async def park(self): self._parked[task] = None task.custom_sleep_data = self - def abort_fn(_): + def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: del task.custom_sleep_data._parked[task] return _core.Abort.SUCCEEDED await _core.wait_task_rescheduled(abort_fn) - def _pop_several(self, count): - for _ in range(min(count, len(self._parked))): + def _pop_several(self, count: int | float) -> Iterator[Task]: + if isinstance(count, float): + if math.isinf(count): + count = len(self._parked) + else: + raise ValueError("Cannot pop a non-integer number of tasks.") + else: + count = min(count, len(self._parked)) + for _ in range(count): task, _ = self._parked.popitem(last=False) yield task @_core.enable_ki_protection - def unpark(self, *, count=1): + def unpark(self, *, count: int | float = 1) -> list[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If @@ -142,7 +166,7 @@ def unpark(self, *, count=1): are available and then returns successfully. Args: - count (int): the number of tasks to unpark. + count (int | math.inf): the number of tasks to unpark. """ tasks = list(self._pop_several(count)) @@ -150,12 +174,12 @@ def unpark(self, *, count=1): _core.reschedule(task) return tasks - def unpark_all(self): + def unpark_all(self) -> list[Task]: """Unpark all parked tasks.""" return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot, *, count=1): + def repark(self, new_lot: ParkingLot, *, count: int | float = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. This dequeues ``count`` tasks from one lot, and requeues them on @@ -185,7 +209,7 @@ async def main(): Args: new_lot (ParkingLot): the parking lot to move tasks to. - count (int): the number of tasks to move. + count (int|math.inf): the number of tasks to move. 
""" if not isinstance(new_lot, ParkingLot): @@ -194,7 +218,7 @@ async def main(): new_lot._parked[task] = None task.custom_sleep_data = new_lot - def repark_all(self, new_lot): + def repark_all(self, new_lot: ParkingLot) -> None: """Move all parked tasks from one :class:`ParkingLot` object to another. @@ -203,7 +227,7 @@ def repark_all(self, new_lot): """ return self.repark(new_lot, count=len(self)) - def statistics(self): + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -212,4 +236,4 @@ def statistics(self): :meth:`park` method. """ - return _ParkingLotStatistics(tasks_waiting=len(self._parked)) + return ParkingLotStatistics(tasks_waiting=len(self._parked)) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index c07e29ab97..b2f3a65ddd 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -10,20 +10,38 @@ import threading import warnings from collections import deque -from collections.abc import Callable -from contextlib import contextmanager +from collections.abc import ( + Awaitable, + Callable, + Coroutine, + Generator, + Iterator, + Sequence, +) +from contextlib import AbstractAsyncContextManager, contextmanager from contextvars import copy_context from heapq import heapify, heappop, heappush from math import inf from time import perf_counter -from typing import TYPE_CHECKING, Any, NoReturn, TypeVar +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + NoReturn, + Protocol, + TypeVar, + cast, + final, + overload, +) import attr from outcome import Error, Outcome, Value, capture -from sniffio import current_async_library_cvar +from sniffio import thread_local as sniffio_library from sortedcontainers import SortedDict from .. import _core +from .._abc import Clock, Instrument from .._util import Final, NoPublicConstructor, coroutine_or_error from ._asyncgens import AsyncGenerators from ._entry_queue import EntryQueue, TrioToken @@ -44,15 +62,31 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +from types import FrameType + if TYPE_CHECKING: + import contextvars + # An unfortunate name collision here with trio._util.Final - from typing_extensions import Final as FinalT + from typing import Final as FinalT + + from typing_extensions import Self DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 -_NO_SEND: FinalT = object() +# Passed as a sentinel +_NO_SEND: FinalT = cast("Outcome[Any]", object()) FnT = TypeVar("FnT", bound="Callable[..., Any]") +StatusT = TypeVar("StatusT") +StatusT_co = TypeVar("StatusT_co", covariant=True) +StatusT_contra = TypeVar("StatusT_contra", contravariant=True) +RetT = TypeVar("RetT") + + +@final +class _NoStatus(metaclass=NoPublicConstructor): + """Sentinel for unset TaskStatus._value.""" # Decorator to mark methods public. 
This does nothing by itself, but @@ -115,7 +149,7 @@ def function_with_unique_name_xyzzy() -> NoReturn: @attr.s(frozen=True, slots=True) -class SystemClock: +class SystemClock(Clock): # Add a large random offset to our clock to ensure that if people # accidentally call time.perf_counter() directly or start comparing clocks # between different runs, then they'll notice the bug quickly: @@ -144,7 +178,9 @@ class IdlePrimedTypes(enum.Enum): ################################################################ -def collapse_exception_group(excgroup): +def collapse_exception_group( + excgroup: BaseExceptionGroup[BaseException], +) -> BaseException: """Recursively collapse any single-exception groups into that single contained exception. @@ -179,18 +215,18 @@ class Deadlines: """ # Heap of (deadline, id(CancelScope), CancelScope) - _heap = attr.ib(factory=list) + _heap: list[tuple[float, int, CancelScope]] = attr.ib(factory=list) # Count of active deadlines (those that haven't been changed) - _active = attr.ib(default=0) + _active: int = attr.ib(default=0) - def add(self, deadline, cancel_scope): + def add(self, deadline: float, cancel_scope: CancelScope) -> None: heappush(self._heap, (deadline, id(cancel_scope), cancel_scope)) self._active += 1 - def remove(self, deadline, cancel_scope): + def remove(self, deadline: float, cancel_scope: CancelScope) -> None: self._active -= 1 - def next_deadline(self): + def next_deadline(self) -> float: while self._heap: deadline, _, cancel_scope = self._heap[0] if deadline == cancel_scope._registered_deadline: @@ -200,7 +236,7 @@ def next_deadline(self): heappop(self._heap) return inf - def _prune(self): + def _prune(self) -> None: # In principle, it's possible for a cancel scope to toggle back and # forth repeatedly between the same two deadlines, and end up with # lots of stale entries that *look* like they're still active, because @@ -221,7 +257,7 @@ def _prune(self): heapify(pruned_heap) self._heap = pruned_heap - def expire(self, now): + def expire(self, now: float) -> bool: did_something = False while self._heap and self._heap[0][0] <= now: deadline, _, cancel_scope = heappop(self._heap) @@ -271,7 +307,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope = attr.ib() + _scope: CancelScope = attr.ib() # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -281,31 +317,31 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. - effectively_cancelled = attr.ib(default=False) + effectively_cancelled: bool = attr.ib(default=False) # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). - _parent = attr.ib(default=None, repr=False) + _parent: CancelStatus | None = attr.ib(default=None, repr=False) # All of the CancelStatuses that have this CancelStatus as their parent. 
- _children = attr.ib(factory=set, init=False, repr=False) + _children: set[CancelStatus] = attr.ib(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks = attr.ib(factory=set, init=False, repr=False) + _tasks: set[Task] = attr.ib(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). - abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._parent is not None: self._parent._children.add(self) self.recalculate() @@ -313,11 +349,11 @@ def __attrs_post_init__(self): # parent/children/tasks accessors are used by TaskStatus.started() @property - def parent(self): + def parent(self) -> CancelStatus | None: return self._parent @parent.setter - def parent(self, parent): + def parent(self, parent: CancelStatus) -> None: if self._parent is not None: self._parent._children.remove(self) self._parent = parent @@ -326,14 +362,14 @@ def parent(self, parent): self.recalculate() @property - def children(self): + def children(self) -> frozenset[CancelStatus]: return frozenset(self._children) @property - def tasks(self): + def tasks(self) -> frozenset[Task]: return frozenset(self._tasks) - def encloses(self, other): + def encloses(self, other: CancelStatus | None) -> bool: """Returns true if this cancel status is a direct or indirect parent of cancel status *other*, or if *other* is *self*. """ @@ -343,7 +379,7 @@ def encloses(self, other): other = other.parent return False - def close(self): + def close(self) -> None: self.parent = None # now we're not a child of self.parent anymore if self._tasks or self._children: # Cancel scopes weren't exited in opposite order of being @@ -372,14 +408,14 @@ def close(self): child.recalculate() @property - def parent_cancellation_is_visible_to_us(self): + def parent_cancellation_is_visible_to_us(self) -> bool: return ( self._parent is not None and not self._scope.shield and self._parent.effectively_cancelled ) - def recalculate(self): + def recalculate(self) -> None: # This does a depth-first traversal over this and descendent cancel # statuses, to ensure their state is up-to-date. It's basically a # recursive algorithm, but we use an explicit stack to avoid any @@ -398,12 +434,12 @@ def recalculate(self): task._attempt_delivery_of_any_pending_cancel() todo.extend(current._children) - def _mark_abandoned(self): + def _mark_abandoned(self) -> None: self.abandoned_by_misnesting = True for child in self._children: child._mark_abandoned() - def effective_deadline(self): + def effective_deadline(self) -> float: if self.effectively_cancelled: return -inf if self._parent is None or self._scope.shield: @@ -435,6 +471,7 @@ def effective_deadline(self): """ +@final @attr.s(eq=False, repr=False, slots=True) class CancelScope(metaclass=Final): """A *cancellation scope*: the link between a unit of cancellable @@ -475,18 +512,18 @@ class CancelScope(metaclass=Final): has been entered yet, and changes take immediate effect. 
""" - _cancel_status = attr.ib(default=None, init=False) - _has_been_entered = attr.ib(default=False, init=False) - _registered_deadline = attr.ib(default=inf, init=False) - _cancel_called = attr.ib(default=False, init=False) - cancelled_caught = attr.ib(default=False, init=False) + _cancel_status: CancelStatus | None = attr.ib(default=None, init=False) + _has_been_entered: bool = attr.ib(default=False, init=False) + _registered_deadline: float = attr.ib(default=inf, init=False) + _cancel_called: bool = attr.ib(default=False, init=False) + cancelled_caught: bool = attr.ib(default=False, init=False) # Constructor arguments: - _deadline = attr.ib(default=inf, kw_only=True) - _shield = attr.ib(default=False, kw_only=True) + _deadline: float = attr.ib(default=inf, kw_only=True) + _shield: bool = attr.ib(default=False, kw_only=True) @enable_ki_protection - def __enter__(self): + def __enter__(self) -> Self: task = _core.current_task() if self._has_been_entered: raise RuntimeError( @@ -500,7 +537,7 @@ def __enter__(self): task._activate_cancel_status(self._cancel_status) return self - def _close(self, exc): + def _close(self, exc: BaseException | None) -> BaseException | None: if self._cancel_status is None: new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " @@ -573,7 +610,12 @@ def _close(self, exc): self._cancel_status = None return exc - def __exit__(self, etype, exc, tb): + def __exit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: # NB: NurseryManager calls _close() directly rather than __exit__(), # so __exit__() must be just _close() plus this logic for adapting # the exception-filtering result to the context manager API. @@ -607,7 +649,7 @@ def __exit__(self, etype, exc, tb): # TODO: check if PEP558 changes the need for this call # https://github.com/python/cpython/pull/3640 - def __repr__(self): + def __repr__(self) -> str: if self._cancel_status is not None: binding = "active" elif self._has_been_entered: @@ -634,7 +676,7 @@ def __repr__(self): @contextmanager @enable_ki_protection - def _might_change_registered_deadline(self): + def _might_change_registered_deadline(self) -> Iterator[None]: try: yield finally: @@ -658,7 +700,7 @@ def _might_change_registered_deadline(self): runner.force_guest_tick_asap() @property - def deadline(self): + def deadline(self) -> float: """Read-write, :class:`float`. An absolute time on the current run's clock at which this scope will automatically become cancelled. You can adjust the deadline by modifying this @@ -684,12 +726,12 @@ def deadline(self): return self._deadline @deadline.setter - def deadline(self, new_deadline): + def deadline(self, new_deadline: float) -> None: with self._might_change_registered_deadline(): self._deadline = float(new_deadline) @property - def shield(self): + def shield(self) -> bool: """Read-write, :class:`bool`, default :data:`False`. So long as this is set to :data:`True`, then the code inside this scope will not receive :exc:`~trio.Cancelled` exceptions from scopes @@ -714,7 +756,7 @@ def shield(self): @shield.setter @enable_ki_protection - def shield(self, new_value): + def shield(self, new_value: bool) -> None: if not isinstance(new_value, bool): raise TypeError("shield must be a bool") self._shield = new_value @@ -722,7 +764,7 @@ def shield(self, new_value): self._cancel_status.recalculate() @enable_ki_protection - def cancel(self): + def cancel(self) -> None: """Cancels this scope immediately. 
This method is idempotent, i.e., if the scope was already @@ -736,7 +778,7 @@ def cancel(self): self._cancel_status.recalculate() @property - def cancel_called(self): + def cancel_called(self) -> bool: """Readonly :class:`bool`. Records whether cancellation has been requested for this scope, either by an explicit call to :meth:`cancel` or by the deadline expiring. @@ -770,28 +812,57 @@ def cancel_called(self): ################################################################ +class TaskStatus(Protocol[StatusT_contra]): + """The interface provided by :meth:`Nursery.start()` to the spawned task. + + This is provided via the ``task_status`` keyword-only parameter. + """ + + @overload + def started(self: TaskStatus[None]) -> None: + ... + + @overload + def started(self, value: StatusT_contra) -> None: + ... + + def started(self, value: StatusT_contra | None = None) -> None: + """Tasks call this method to indicate that they have initialized. + + See `nursery.start() <trio.Nursery.start>` for more information. + """ + + # This code needs to be read alongside the code from Nursery.start to make # sense. @attr.s(eq=False, hash=False, repr=False) -class _TaskStatus: - _old_nursery = attr.ib() - _new_nursery = attr.ib() - _called_started = attr.ib(default=False) - _value = attr.ib(default=None) +class _TaskStatus(TaskStatus[StatusT]): + _old_nursery: Nursery = attr.ib() + _new_nursery: Nursery = attr.ib() + # NoStatus is a sentinel. + _value: StatusT | type[_NoStatus] = attr.ib(default=_NoStatus) - def __repr__(self): + def __repr__(self) -> str: return f"<Task status object at {id(self):#x}>" - def started(self, value=None): - if self._called_started: + @overload + def started(self: _TaskStatus[None]) -> None: + ... + + @overload + def started(self: _TaskStatus[StatusT], value: StatusT) -> None: + ... + + def started(self, value: StatusT | None = None) -> None: + if self._value is not _NoStatus: raise RuntimeError("called 'started' twice on the same task status") - self._called_started = True - self._value = value + self._value = cast(StatusT, value) # If None, StatusT == None # If the old nursery is cancelled, then quietly quit now; the child # will eventually exit on its own, and we don't want to risk moving # children that might have propagating Cancelled exceptions into # a place with no cancelled cancel scopes to catch them. + assert self._old_nursery._cancel_status is not None if self._old_nursery._cancel_status.effectively_cancelled: return @@ -846,10 +917,10 @@ class NurseryManager: """ - strict_exception_groups = attr.ib(default=False) + strict_exception_groups: bool = attr.ib(default=False) @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> Nursery: self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create( @@ -858,7 +929,12 @@ async def __aenter__(self): return self._nursery @enable_ki_protection - async def __aexit__(self, etype, exc, tb): + async def __aexit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: new_exc = await self._nursery._nested_child_finished(exc) # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context.
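# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff): the nursery.start() /
# task_status.started() handshake that the TaskStatus protocol above gives a
# static type to; StatusT is the type of the value handed to started(). This
# assumes TaskStatus is re-exported at the top level as trio.TaskStatus (true
# for recent releases); the serve/main names are purely illustrative.
import trio

async def serve(
    *, task_status: trio.TaskStatus[int] = trio.TASK_STATUS_IGNORED
) -> None:
    port = 54321  # pretend we bound a listener here and learned our port
    task_status.started(port)  # reparents this task into the caller's nursery
    await trio.sleep_forever()

async def main() -> None:
    async with trio.open_nursery() as nursery:
        # start() blocks until started(value) is called, then returns value
        port = await nursery.start(serve)
        print("serving on port", port)
        nursery.cancel_scope.cancel()

trio.run(main)
# ---------------------------------------------------------------------------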
@@ -881,16 +957,26 @@ async def __aexit__(self, etype, exc, tb): # see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage del _, combined_error_from_nursery, value, new_exc - def __enter__(self): - raise RuntimeError( - "use 'async with open_nursery(...)', not 'with open_nursery(...)'" - ) + # make sure these raise errors in static analysis if called + if not TYPE_CHECKING: + + def __enter__(self) -> NoReturn: + raise RuntimeError( + "use 'async with open_nursery(...)', not 'with open_nursery(...)'" + ) - def __exit__(self): # pragma: no cover - assert False, """Never called, but should be defined""" + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> NoReturn: # pragma: no cover + raise AssertionError("Never called, but should be defined") -def open_nursery(strict_exception_groups=None): +def open_nursery( + strict_exception_groups: bool | None = None, +) -> AbstractAsyncContextManager[Nursery]: """Returns an async context manager which must be used to create a new `Nursery`. @@ -909,6 +995,7 @@ def open_nursery(strict_exception_groups=None): return NurseryManager(strict_exception_groups=strict_exception_groups) +@final class Nursery(metaclass=NoPublicConstructor): """A context which may be used to spawn (or cancel) child tasks. @@ -931,7 +1018,12 @@ class Nursery(metaclass=NoPublicConstructor): in response to some external event. """ - def __init__(self, parent_task, cancel_scope, strict_exception_groups): + def __init__( + self, + parent_task: Task, + cancel_scope: CancelScope, + strict_exception_groups: bool, + ): self._parent_task = parent_task self._strict_exception_groups = strict_exception_groups parent_task._child_nurseries.append(self) @@ -942,8 +1034,8 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups): # children. self.cancel_scope = cancel_scope assert self.cancel_scope._cancel_status is self._cancel_status - self._children = set() - self._pending_excs = [] + self._children: set[Task] = set() + self._pending_excs: list[BaseException] = [] # The "nested child" is how this code refers to the contents of the # nursery's 'async with' block, which acts like a child Task in all # the ways we can make it. @@ -953,34 +1045,36 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups): self._closed = False @property - def child_tasks(self): + def child_tasks(self) -> frozenset[Task]: """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` objects which are still running.""" return frozenset(self._children) @property - def parent_task(self): + def parent_task(self) -> Task: "(`~trio.lowlevel.Task`): The Task that opened this nursery." 
return self._parent_task - def _add_exc(self, exc): + def _add_exc(self, exc: BaseException) -> None: self._pending_excs.append(exc) self.cancel_scope.cancel() - def _check_nursery_closed(self): + def _check_nursery_closed(self) -> None: if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: self._parent_waiting_in_aexit = False GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) - def _child_finished(self, task, outcome): + def _child_finished(self, task: Task, outcome: Outcome[Any]) -> None: self._children.remove(task) if isinstance(outcome, Error): self._add_exc(outcome.error) self._check_nursery_closed() - async def _nested_child_finished(self, nested_child_exc): + async def _nested_child_finished( + self, nested_child_exc: BaseException | None + ) -> BaseException | None: # Returns MultiError instance (or any exception if the nursery is in loose mode # and there is just one contained exception) if there are pending exceptions if nested_child_exc is not None: @@ -992,7 +1086,7 @@ async def _nested_child_finished(self, nested_child_exc): # If we get cancelled (or have an exception injected, like # KeyboardInterrupt), then save that, but still wait until our # children finish. - def aborted(raise_cancel): + def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: self._add_exc(capture(raise_cancel).error) return Abort.FAILED @@ -1018,8 +1112,15 @@ def aborted(raise_cancel): # avoid a garbage cycle # (see test_nursery_cancel_doesnt_create_cyclic_garbage) del self._pending_excs - - def start_soon(self, async_fn, *args, name=None): + return None + + def start_soon( + self, + # TODO: TypeVarTuple + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: object = None, + ) -> None: """Creates a child task, scheduling ``await async_fn(*args)``. If you want to run a function and immediately wait for its result, @@ -1061,7 +1162,12 @@ def start_soon(self, async_fn, *args, name=None): """ GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) - async def start(self, async_fn, *args, name=None): + async def start( + self, + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: object = None, + ) -> StatusT: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1070,7 +1176,7 @@ async def start(self, async_fn, *args, name=None): The ``async_fn`` must accept a ``task_status`` keyword argument, and it must make sure that it (or someone) eventually calls - ``task_status.started()``. + :meth:`task_status.started() <trio.TaskStatus.started>`. The conventional way to define ``async_fn`` is like:: @@ -1083,49 +1189,48 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): a do-nothing ``started`` method. This way your function supports being called either like ``await nursery.start(async_fn, arg1, arg2)`` or directly like ``await async_fn(arg1, arg2)``, and - either way it can call ``task_status.started()`` without - worrying about which mode it's in. Defining your function like + either way it can call :meth:`task_status.started() <trio.TaskStatus.started>` + without worrying about which mode it's in. Defining your function like this will make it obvious to readers that it supports being used in both modes.
- Before the child calls ``task_status.started()``, it's - effectively run underneath the call to :meth:`start`: if it + Before the child calls :meth:`task_status.started() <trio.TaskStatus.started>`, + it's effectively run underneath the call to :meth:`start`: if it raises an exception then that exception is reported by :meth:`start`, and does *not* propagate out of the nursery. If :meth:`start` is cancelled, then the child task is also cancelled. - When the child calls ``task_status.started()``, it's moved out - from underneath :meth:`start` and into the given nursery. + When the child calls :meth:`task_status.started() <trio.TaskStatus.started>`, + it's moved out from underneath :meth:`start` and into the given nursery. - If the child task passes a value to - ``task_status.started(value)``, then :meth:`start` returns this - value. Otherwise it returns ``None``. + If the child task passes a value to :meth:`task_status.started(value) <trio.TaskStatus.started>`, + then :meth:`start` returns this value. Otherwise, it returns ``None``. """ if self._closed: raise RuntimeError("Nursery is closed to new arrivals") try: self._pending_starts += 1 async with open_nursery() as old_nursery: - task_status = _TaskStatus(old_nursery, self) + task_status: _TaskStatus[StatusT] = _TaskStatus(old_nursery, self) thunk = functools.partial(async_fn, task_status=task_status) task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( thunk, args, old_nursery, name ) task._eventual_parent_nursery = self - # Wait for either _TaskStatus.started or an exception to + # Wait for either TaskStatus.started or an exception to # cancel this nursery: # If we get here, then the child either got reparented or exited - # normally. The complicated logic is all in _TaskStatus.started(). + # normally. The complicated logic is all in TaskStatus.started(). # (Any exceptions propagate directly out of the above.) - if not task_status._called_started: + if task_status._value is _NoStatus: raise RuntimeError("child exited without calling task_status.started()") - return task_status._value + return task_status._value # type: ignore[return-value] # Mypy doesn't narrow yet. finally: self._pending_starts -= 1 self._check_nursery_closed() - def __del__(self): + def __del__(self) -> None: assert not self._children @@ -1134,14 +1239,14 @@ def __del__(self): ################################################################ +@final @attr.s(eq=False, hash=False, repr=False, slots=True) class Task(metaclass=NoPublicConstructor): - _parent_nursery = attr.ib() - coro = attr.ib() - _runner = attr.ib() - name = attr.ib() - # PEP 567 contextvars context - context = attr.ib() + _parent_nursery: Nursery | None = attr.ib() + coro: Coroutine[Any, Outcome[object], Any] = attr.ib() + _runner: Runner = attr.ib() + name: str = attr.ib() + context: contextvars.Context = attr.ib() _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: @@ -1155,26 +1260,26 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled.
- _next_send_fn = attr.ib(default=None) - _next_send = attr.ib(default=None) - _abort_func = attr.ib(default=None) - custom_sleep_data = attr.ib(default=None) + _next_send_fn: Callable[[Any], object] = attr.ib(default=None) + _next_send: Outcome[Any] | None | BaseException = attr.ib(default=None) + _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = attr.ib(default=None) + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() - _child_nurseries = attr.ib(factory=list) - _eventual_parent_nursery = attr.ib(default=None) + _child_nurseries: list[Nursery] = attr.ib(factory=list) + _eventual_parent_nursery: Nursery | None = attr.ib(default=None) # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? - _cancel_points = attr.ib(default=0) - _schedule_points = attr.ib(default=0) + _cancel_points: int = attr.ib(default=0) + _schedule_points: int = attr.ib(default=0) - def __repr__(self): + def __repr__(self) -> str: return f"<Task {self.name!r} at {id(self):#x}>" @property - def parent_nursery(self): + def parent_nursery(self) -> Nursery | None: """The nursery this task is inside (or None if this is the "init" task). @@ -1185,7 +1290,7 @@ def parent_nursery(self): return self._parent_nursery @property - def eventual_parent_nursery(self): + def eventual_parent_nursery(self) -> Nursery | None: """The nursery this task will be inside after it calls ``task_status.started()``. @@ -1197,7 +1302,7 @@ def eventual_parent_nursery(self): return self._eventual_parent_nursery @property - def child_nurseries(self): + def child_nurseries(self) -> list[Nursery]: """The nurseries this task contains. This is a list, with outer nurseries before inner nurseries. @@ -1205,7 +1310,7 @@ def child_nurseries(self): """ return list(self._child_nurseries) - def iter_await_frames(self): + def iter_await_frames(self) -> Iterator[tuple[FrameType, int]]: """Iterates recursively over the coroutine-like objects this task is waiting on, yielding the frame and line number at each frame. @@ -1225,7 +1330,8 @@ def print_stack_for_task(task): print("".join(ss.format())) """ - coro = self.coro + # Ignore static typing as we're doing lots of dynamic introspection + coro: Any = self.coro while coro is not None: if hasattr(coro, "cr_frame"): # A real coroutine @@ -1258,22 +1364,28 @@ def print_stack_for_task(task): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status = attr.ib(default=None, repr=False) + # This can be None, but only in the init task.
+ _cancel_status: CancelStatus = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status): + def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None: if self._cancel_status is not None: self._cancel_status._tasks.remove(self) - self._cancel_status = cancel_status + self._cancel_status = cancel_status # type: ignore[assignment] if self._cancel_status is not None: self._cancel_status._tasks.add(self) if self._cancel_status.effectively_cancelled: self._attempt_delivery_of_any_pending_cancel() - def _attempt_abort(self, raise_cancel): + def _attempt_abort(self, raise_cancel: _core.RaiseCancelT) -> None: # Either the abort succeeds, in which case we will reschedule the # task, or else it fails, in which case it will worry about # rescheduling itself (hopefully eventually calling reraise to raise # the given exception, but not necessarily). + + # This is only called by the functions immediately below, which both check + # `self.abort_func is not None`. + assert self._abort_func is not None, "FATAL INTERNAL ERROR" + success = self._abort_func(raise_cancel) if type(success) is not Abort: raise TrioInternalError("abort function must return Abort enum") @@ -1283,23 +1395,23 @@ def _attempt_abort(self, raise_cancel): if success is Abort.SUCCEEDED: self._runner.reschedule(self, capture(raise_cancel)) - def _attempt_delivery_of_any_pending_cancel(self): + def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return if not self._cancel_status.effectively_cancelled: return - def raise_cancel(): + def raise_cancel() -> NoReturn: raise Cancelled._create() self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self): + def _attempt_delivery_of_pending_ki(self) -> None: assert self._runner.ki_pending if self._abort_func is None: return - def raise_cancel(): + def raise_cancel() -> NoReturn: self._runner.ki_pending = False raise KeyboardInterrupt @@ -1312,20 +1424,42 @@ def raise_cancel(): class RunContext(threading.local): - runner: "Runner" + runner: Runner task: Task GLOBAL_RUN_CONTEXT: FinalT = RunContext() -@attr.s(frozen=True) -class _RunStatistics: - tasks_living = attr.ib() - tasks_runnable = attr.ib() - seconds_to_next_deadline = attr.ib() - io_statistics = attr.ib() - run_sync_soon_queue_size = attr.ib() +@attr.frozen +class RunStatistics: + """An object containing run-loop-level debugging information. + + Currently, the following fields are defined: + + * ``tasks_living`` (int): The number of tasks that have been spawned + and not yet exited. + * ``tasks_runnable`` (int): The number of tasks that are currently + queued on the run queue (as opposed to blocked waiting for something + to happen). + * ``seconds_to_next_deadline`` (float): The time until the next + pending cancel scope deadline. May be negative if the deadline has + expired but we haven't yet processed cancellations. May be + :data:`~math.inf` if there are no pending deadlines. + * ``run_sync_soon_queue_size`` (int): The number of + unprocessed callbacks queued via + :meth:`trio.lowlevel.TrioToken.run_sync_soon`. + * ``io_statistics`` (object): Some statistics from Trio's I/O + backend. This always has an attribute ``backend`` which is a string + naming which operating-system-specific I/O backend is in use; the + other attributes vary between backends. 
+ """ + + tasks_living: int + tasks_runnable: int + seconds_to_next_deadline: float + io_statistics: IOStatistics + run_sync_soon_queue_size: int # This holds all the state that gets trampolined back and forth between @@ -1349,26 +1483,32 @@ class _RunStatistics: # worker thread. @attr.s(eq=False, hash=False, slots=True) class GuestState: - runner = attr.ib() - run_sync_soon_threadsafe = attr.ib() - run_sync_soon_not_threadsafe = attr.ib() - done_callback = attr.ib() - unrolled_run_gen = attr.ib() - _value_factory: Callable[[], Value] = lambda: Value(None) - unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) - - def guest_tick(self): + runner: Runner = attr.ib() + run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] = attr.ib() + run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] = attr.ib() + done_callback: Callable[[Outcome[Any]], object] = attr.ib() + unrolled_run_gen: Generator[float, EventResult, None] = attr.ib() + _value_factory: Callable[[], Value[Any]] = lambda: Value(None) + unrolled_run_next_send: Outcome[Any] = attr.ib(factory=_value_factory) + + def guest_tick(self) -> None: + prev_library, sniffio_library.name = sniffio_library.name, "trio" try: timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen) except StopIteration: + assert self.runner.main_task_outcome is not None self.done_callback(self.runner.main_task_outcome) return except TrioInternalError as exc: self.done_callback(Error(exc)) return + finally: + sniffio_library.name = prev_library # Optimization: try to skip going into the thread if we can avoid it - events_outcome = capture(self.runner.io_manager.get_events, 0) + events_outcome: Value[EventResult] | Error = capture( + self.runner.io_manager.get_events, 0 + ) if timeout <= 0 or isinstance(events_outcome, Error) or events_outcome.value: # No need to go into the thread self.unrolled_run_next_send = events_outcome @@ -1378,11 +1518,11 @@ def guest_tick(self): # Need to go into the thread and call get_events() there self.runner.guest_tick_scheduled = False - def get_events(): + def get_events() -> EventResult: return self.runner.io_manager.get_events(timeout) - def deliver(events_outcome): - def in_main_thread(): + def deliver(events_outcome: Outcome[EventResult]) -> None: + def in_main_thread() -> None: self.unrolled_run_next_send = events_outcome self.runner.guest_tick_scheduled = True self.guest_tick() @@ -1394,44 +1534,44 @@ def in_main_thread(): @attr.s(eq=False, hash=False, slots=True) class Runner: - clock = attr.ib() + clock: Clock = attr.ib() instruments: Instruments = attr.ib() - io_manager = attr.ib() - ki_manager = attr.ib() - strict_exception_groups = attr.ib() + io_manager: TheIOManager = attr.ib() + ki_manager: KIManager = attr.ib() + strict_exception_groups: bool = attr.ib() # Run-local values, see _local.py - _locals = attr.ib(factory=dict) + _locals: dict[_core.RunVar[Any], Any] = attr.ib(factory=dict) runq: deque[Task] = attr.ib(factory=deque) - tasks = attr.ib(factory=set) + tasks: set[Task] = attr.ib(factory=set) - deadlines = attr.ib(factory=Deadlines) + deadlines: Deadlines = attr.ib(factory=Deadlines) - init_task = attr.ib(default=None) - system_nursery = attr.ib(default=None) - system_context = attr.ib(default=None) - main_task = attr.ib(default=None) - main_task_outcome = attr.ib(default=None) + init_task: Task | None = attr.ib(default=None) + system_nursery: Nursery | None = attr.ib(default=None) + system_context: contextvars.Context = attr.ib(kw_only=True) + main_task: Task 
| None = attr.ib(default=None) + main_task_outcome: Outcome[Any] | None = attr.ib(default=None) - entry_queue = attr.ib(factory=EntryQueue) - trio_token = attr.ib(default=None) - asyncgens = attr.ib(factory=AsyncGenerators) + entry_queue: EntryQueue = attr.ib(factory=EntryQueue) + trio_token: TrioToken | None = attr.ib(default=None) + asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() - clock_autojump_threshold = attr.ib(default=inf) + clock_autojump_threshold: float = attr.ib(default=inf) # Guest mode stuff - is_guest = attr.ib(default=False) - guest_tick_scheduled = attr.ib(default=False) + is_guest: bool = attr.ib(default=False) + guest_tick_scheduled: bool = attr.ib(default=False) - def force_guest_tick_asap(self): + def force_guest_tick_asap(self) -> None: if self.guest_tick_scheduled: return self.guest_tick_scheduled = True self.io_manager.force_wakeup() - def close(self): + def close(self) -> None: self.io_manager.close() self.entry_queue.close() self.asyncgens.close() @@ -1441,10 +1581,10 @@ def close(self): self.ki_manager.close() @_public - def current_statistics(self): - """Returns an object containing run-loop-level debugging information. + def current_statistics(self) -> RunStatistics: + """Returns ``RunStatistics``, which contains run-loop-level debugging information. - Currently the following fields are defined: + Currently, the following fields are defined: * ``tasks_living`` (int): The number of tasks that have been spawned and not yet exited. @@ -1465,7 +1605,7 @@ def current_statistics(self): """ seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time() - return _RunStatistics( + return RunStatistics( tasks_living=len(self.tasks), tasks_runnable=len(self.runq), seconds_to_next_deadline=seconds_to_next_deadline, @@ -1474,7 +1614,7 @@ def current_statistics(self): ) @_public - def current_time(self): + def current_time(self) -> float: """Returns the current time according to Trio's internal clock. Returns: @@ -1487,12 +1627,12 @@ def current_time(self): return self.clock.current_time() @_public - def current_clock(self): + def current_clock(self) -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" return self.clock @_public - def current_root_task(self): + def current_root_task(self) -> Task | None: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -1504,8 +1644,10 @@ def current_root_task(self): # Core task handling primitives ################ - @_public - def reschedule(self, task, next_send=_NO_SEND): + @_public # Type-ignore due to use of Any here. + def reschedule( # type: ignore[misc] + self, task: Task, next_send: Outcome[Any] = _NO_SEND + ) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1539,8 +1681,16 @@ def reschedule(self, task, next_send=_NO_SEND): self.instruments.call("task_scheduled", task) def spawn_impl( - self, async_fn, args, nursery, name, *, system_task=False, context=None - ): + self, + # TODO: TypeVarTuple + async_fn: Callable[..., Awaitable[object]], + args: tuple[object, ...], + nursery: Nursery | None, + name: object, + *, + system_task: bool = False, + context: contextvars.Context | None = None, + ) -> Task: ###### # Make sure the nursery is in working order ###### @@ -1554,23 +1704,20 @@ def spawn_impl( assert self.init_task is None ###### - # Propagate contextvars, and make sure that async_fn can use sniffio. 
+ # Propagate contextvars ###### if context is None: if system_task: context = self.system_context.copy() else: context = copy_context() - # start_soon() or spawn_system_task() might have been invoked - # from a different async library; make sure the new task - # understands it's Trio-flavored. - context.run(current_async_library_cvar.set, "trio") ###### # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. ###### - coro = context.run(coroutine_or_error, async_fn, *args) + # TODO: resolve the type: ignore when implementing TypeVarTuple + coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type] if name is None: name = async_fn @@ -1578,13 +1725,13 @@ def spawn_impl( name = name.func if not isinstance(name, str): try: - name = f"{name.__module__}.{name.__qualname__}" + name = f"{name.__module__}.{name.__qualname__}" # type: ignore[attr-defined] except AttributeError: name = repr(name) if not hasattr(coro, "cr_frame"): # This async function is implemented in C or Cython - async def python_wrapper(orig_coro): + async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: return await orig_coro coro = python_wrapper(coro) @@ -1606,10 +1753,11 @@ async def python_wrapper(orig_coro): self.instruments.call("task_spawned", task) # Special case: normally next_send should be an Outcome, but for the # very first send we have to send a literal unboxed None. - self.reschedule(task, None) + # TODO: remove [unused-ignore] when Outcome is typed + self.reschedule(task, None) # type: ignore[arg-type, unused-ignore] return task - def task_exited(self, task, outcome): + def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: if ( task._cancel_status is not None and task._cancel_status.abandoned_by_misnesting @@ -1648,6 +1796,7 @@ def task_exited(self, task, outcome): if task is self.main_task: self.main_task_outcome = outcome outcome = Value(None) + assert task._parent_nursery is not None, task task._parent_nursery._child_finished(task, outcome) if "task_exited" in self.instruments: @@ -1657,8 +1806,15 @@ def task_exited(self, task, outcome): # System tasks and init ################ - @_public - def spawn_system_task(self, async_fn, *args, name=None, context=None): + @_public # Type-ignore due to use of Any here. + def spawn_system_task( # type: ignore[misc] + self, + # TODO: TypeVarTuple + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: object = None, + context: contextvars.Context | None = None, + ) -> Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -1719,7 +1875,12 @@ def spawn_system_task(self, async_fn, *args, name=None, context=None): context=context, ) - async def init(self, async_fn, args): + async def init( + # TODO: TypeVarTuple + self, + async_fn: Callable[..., Awaitable[object]], + args: tuple[object, ...], + ) -> None: # run_sync_soon task runs here: async with open_nursery() as run_sync_soon_nursery: # All other system tasks run here: @@ -1757,7 +1918,7 @@ async def init(self, async_fn, args): ################ @_public - def current_trio_token(self): + def current_trio_token(self) -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -1770,7 +1931,7 @@ def current_trio_token(self): # KI handling ################ - ki_pending = attr.ib(default=False) + ki_pending: bool = attr.ib(default=False) # deliver_ki is broke. 
Maybe move all the actual logic and state into # RunToken, and we'll only have one instance per runner? But then we can't @@ -1779,14 +1940,14 @@ def current_trio_token(self): # keep the class public so people can isinstance() it if they want. # This gets called from signal context - def deliver_ki(self): + def deliver_ki(self) -> None: self.ki_pending = True try: self.entry_queue.run_sync_soon(self._deliver_ki_cb) except RunFinishedError: pass - def _deliver_ki_cb(self): + def _deliver_ki_cb(self) -> None: if not self.ki_pending: return # Can't happen because main_task and run_sync_soon_task are created at @@ -1803,10 +1964,12 @@ def _deliver_ki_cb(self): # Quiescing ################ - waiting_for_idle = attr.ib(factory=SortedDict) + # sortedcontainers doesn't have types, and is reportedly very hard to type: + # https://github.com/grantjenks/python-sortedcontainers/issues/68 + waiting_for_idle: Any = attr.ib(factory=SortedDict) @_public - async def wait_all_tasks_blocked(self, cushion=0.0): + async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -1868,7 +2031,7 @@ async def test_lock_fairness(): key = (cushion, id(task)) self.waiting_for_idle[key] = task - def abort(_): + def abort(_: _core.RaiseCancelT) -> Abort: del self.waiting_for_idle[key] return Abort.SUCCEEDED @@ -1943,11 +2106,11 @@ def abort(_): def setup_runner( - clock, - instruments, - restrict_keyboard_interrupt_to_checkpoints, - strict_exception_groups, -): + clock: Clock | None, + instruments: Sequence[Instrument], + restrict_keyboard_interrupt_to_checkpoints: bool, + strict_exception_groups: bool, +) -> Runner: """Create a Runner object and install it as the GLOBAL_RUN_CONTEXT.""" # It wouldn't be *hard* to support nested calls to run(), but I can't # think of a single good reason for it, so let's be conservative for @@ -1957,14 +2120,14 @@ def setup_runner( if clock is None: clock = SystemClock() - instruments = Instruments(instruments) + instrument_group = Instruments(instruments) io_manager = TheIOManager() system_context = copy_context() ki_manager = KIManager() runner = Runner( clock=clock, - instruments=instruments, + instruments=instrument_group, io_manager=io_manager, system_context=system_context, ki_manager=ki_manager, @@ -1981,13 +2144,13 @@ def setup_runner( def run( - async_fn, - *args, - clock=None, - instruments=(), + async_fn: Callable[..., Awaitable[RetT]], + *args: object, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, strict_exception_groups: bool = False, -): +) -> RetT: """Run a Trio-flavored async function, and return the result. Calling:: @@ -2069,34 +2232,42 @@ def run( strict_exception_groups, ) - gen = unrolled_run(runner, async_fn, args) - next_send = None - while True: - try: - timeout = gen.send(next_send) - except StopIteration: - break - next_send = runner.io_manager.get_events(timeout) + prev_library, sniffio_library.name = sniffio_library.name, "trio" + try: + gen = unrolled_run(runner, async_fn, args) + # Need to send None the first time.
+ next_send: EventResult = None # type: ignore[assignment] + while True: + try: + timeout = gen.send(next_send) + except StopIteration: + break + next_send = runner.io_manager.get_events(timeout) + finally: + sniffio_library.name = prev_library # Inlined copy of runner.main_task_outcome.unwrap() to avoid # cluttering every single Trio traceback with an extra frame. if isinstance(runner.main_task_outcome, Value): - return runner.main_task_outcome.value - else: + return cast(RetT, runner.main_task_outcome.value) + elif isinstance(runner.main_task_outcome, Error): raise runner.main_task_outcome.error + else: # pragma: no cover + raise AssertionError(runner.main_task_outcome) def start_guest_run( - async_fn, - *args, - run_sync_soon_threadsafe, - done_callback, - run_sync_soon_not_threadsafe=None, + async_fn: Callable[..., Awaitable[RetT]], + *args: object, + run_sync_soon_threadsafe: Callable[[Callable[[], object]], object], + done_callback: Callable[[Outcome[RetT]], object], + run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] + | None = None, host_uses_signal_set_wakeup_fd: bool = False, - clock=None, - instruments=(), + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, strict_exception_groups: bool = False, -): +) -> None: """Start a "guest" run of Trio on top of some other "host" event loop. Each host loop can only have one guest run at a time. @@ -2110,6 +2281,16 @@ def start_guest_run( the host loop and then immediately starts the guest run, and then shuts down the host when the guest run completes. + Once :func:`start_guest_run` returns successfully, the guest run + has been set up enough that you can invoke sync-colored Trio + functions such as :func:`~trio.current_time`, :func:`spawn_system_task`, + and :func:`current_trio_token`. If a `~trio.TrioInternalError` occurs + during this early setup of the guest run, it will be raised out of + :func:`start_guest_run`. All other errors, including all errors + raised by the *async_fn*, will be delivered to your + *done_callback* at some point after :func:`start_guest_run` returns + successfully. + Args: run_sync_soon_threadsafe: An arbitrary callable, which will be passed a @@ -2170,6 +2351,48 @@ def my_done_callback(run_outcome): host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, ), ) + + # Run a few ticks of the guest run synchronously, so that by the + # time we return, the system nursery exists and callers can use + # spawn_system_task. We don't actually run any user code during + # this time, so it shouldn't be possible to get an exception here, + # except for a TrioInternalError. + next_send = cast( + EventResult, None + ) # First iteration must be `None`, every iteration after that is EventResult + for tick in range(5): # expected need is 2 iterations + leave some wiggle room + if runner.system_nursery is not None: + # We're initialized enough to switch to async guest ticks + break + try: + timeout = guest_state.unrolled_run_gen.send(next_send) + except StopIteration: # pragma: no cover + raise TrioInternalError( + "Guest runner exited before system nursery was initialized" + ) + if timeout != 0: # pragma: no cover + guest_state.unrolled_run_gen.throw( + TrioInternalError( + "Guest runner blocked before system nursery was initialized" + ) + ) + # next_send should be the return value of + # IOManager.get_events() if no I/O was waiting, which is + # platform-dependent. 
We don't actually check for I/O during + # this init phase because no one should be expecting any yet. + if sys.platform == "win32": + next_send = 0 + else: + next_send = [] + else: # pragma: no cover + guest_state.unrolled_run_gen.throw( + TrioInternalError( + "Guest runner yielded too many times before " + "system nursery was initialized" + ) + ) + + guest_state.unrolled_run_next_send = Value(next_send) run_sync_soon_not_threadsafe(guest_state.guest_tick) @@ -2184,10 +2407,10 @@ def my_done_callback(run_outcome): # straight through. def unrolled_run( runner: Runner, - async_fn, - args, + async_fn: Callable[..., object], + args: tuple[object, ...], host_uses_signal_set_wakeup_fd: bool = False, -): +) -> Generator[float, EventResult, None]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True @@ -2273,6 +2496,7 @@ def unrolled_run( break else: assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK + assert isinstance(runner.clock, _core.MockClock) runner.clock._autojump() # Process all runnable tasks, but only the ones that are already @@ -2317,7 +2541,7 @@ def unrolled_run( next_send_fn = task._next_send_fn next_send = task._next_send task._next_send_fn = task._next_send = None - final_outcome = None + final_outcome: Outcome[Any] | None = None try: # We used to unwrap the Outcome object here and send/throw # its contents in directly, but it turns out that .throw() @@ -2336,9 +2560,8 @@ def unrolled_run( # more Context.run adds. tb = task_exc.__traceback__ for _ in range(1 + CONTEXT_RUN_TB_FRAMES): - if tb is None: - break - tb = tb.tb_next + if tb is not None: # pragma: no branch + tb = tb.tb_next final_outcome = Error(task_exc.with_traceback(tb)) # Remove local refs so that e.g. cancelled coroutine locals # are not kept alive by this frame until another exception @@ -2381,7 +2604,8 @@ def unrolled_run( # protocol of unwrapping whatever outcome gets sent in. # Instead, we'll arrange to throw `exc` in directly, # which works for at least asyncio and curio. - runner.reschedule(task, exc) + # TODO: remove [unused-ignore] when Outcome is typed + runner.reschedule(task, exc) # type: ignore[arg-type, unused-ignore] task._next_send_fn = task.coro.throw # prevent long-lived reference # TODO: develop test for this deletion @@ -2424,18 +2648,18 @@ def unrolled_run( ################################################################ -class _TaskStatusIgnored: - def __repr__(self): +class _TaskStatusIgnored(TaskStatus[Any]): + def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value=None): + def started(self, value: Any = None) -> None: pass -TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored() +TASK_STATUS_IGNORED: FinalT[TaskStatus[Any]] = _TaskStatusIgnored() -def current_task(): +def current_task() -> Task: """Return the :class:`Task` object representing the current task. Returns: @@ -2449,7 +2673,7 @@ def current_task(): raise RuntimeError("must be called from async context") from None -def current_effective_deadline(): +def current_effective_deadline() -> float: """Returns the current effective deadline for the current task. This function examines all the cancellation scopes that are currently in @@ -2476,7 +2700,7 @@ def current_effective_deadline(): return current_task()._cancel_status.effective_deadline() -async def checkpoint(): +async def checkpoint() -> None: """A pure :ref:`checkpoint <checkpoints>`.
This checks for cancellation and allows other tasks to be scheduled, @@ -2503,7 +2727,7 @@ async def checkpoint(): await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """Issue a :ref:`checkpoint <checkpoints>` if the calling context has been cancelled. @@ -2529,13 +2753,25 @@ async def checkpoint_if_cancelled(): if sys.platform == "win32": from ._generated_io_windows import * - from ._io_windows import WindowsIOManager as TheIOManager + from ._io_windows import ( + EventResult as EventResult, + WindowsIOManager as TheIOManager, + _WindowsStatistics as IOStatistics, + ) elif sys.platform == "linux" or (not TYPE_CHECKING and hasattr(select, "epoll")): from ._generated_io_epoll import * - from ._io_epoll import EpollIOManager as TheIOManager + from ._io_epoll import ( + EpollIOManager as TheIOManager, + EventResult as EventResult, + _EpollStatistics as IOStatistics, + ) elif TYPE_CHECKING or hasattr(select, "kqueue"): from ._generated_io_kqueue import * - from ._io_kqueue import KqueueIOManager as TheIOManager + from ._io_kqueue import ( + EventResult as EventResult, + KqueueIOManager as TheIOManager, + _KqueueStatistics as IOStatistics, + ) else: # pragma: no cover raise NotImplementedError("unsupported platform") diff --git a/trio/_core/tests/__init__.py b/trio/_core/_tests/__init__.py similarity index 100% rename from trio/_core/tests/__init__.py rename to trio/_core/_tests/__init__.py diff --git a/trio/_core/tests/test_asyncgen.py b/trio/_core/_tests/test_asyncgen.py similarity index 99% rename from trio/_core/tests/test_asyncgen.py rename to trio/_core/_tests/test_asyncgen.py index 65bde5857f..f72d5c6859 100644 --- a/trio/_core/tests/test_asyncgen.py +++ b/trio/_core/_tests/test_asyncgen.py @@ -1,12 +1,12 @@ +import contextlib import sys import weakref -import pytest -import contextlib from math import inf -from functools import partial + +import pytest from ... import _core -from .tutil import gc_collect_harder, buggy_pypy_asyncgens, restore_unraisablehook +from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook @pytest.mark.skipif(sys.version_info < (3, 10), reason="no aclosing() in stdlib<3.10") diff --git a/trio/_core/tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py similarity index 87% rename from trio/_core/tests/test_guest_mode.py rename to trio/_core/_tests/test_guest_mode.py index 9fed232214..6b1bc2df51 100644 --- a/trio/_core/tests/test_guest_mode.py +++ b/trio/_core/_tests/test_guest_mode.py @@ -1,21 +1,23 @@ -import pytest import asyncio import contextvars -import sys -import traceback import queue -from functools import partial -from math import inf import signal import socket +import sys import threading import time +import traceback import warnings +from functools import partial +from math import inf + +import pytest import trio import trio.testing -from .tutil import gc_collect_harder, buggy_pypy_asyncgens, restore_unraisablehook + from ..._util import signal_raise +from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook # The simplest possible "host" loop.
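(As a sketch of the contract this helper exercises: the host only needs a thread-safe "run this callback soon" hook plus a completion callback. The names `minimal_host`, `todo`, and `main` below are invented for illustration; the real `trivial_guest_run` adds crash propagation and the `in_host_after_start` hook.)

import queue

import trio

def minimal_host(async_fn):
    todo = queue.Queue()  # the entire "host loop" is just a callback queue

    trio.lowlevel.start_guest_run(
        async_fn,
        # Queue.put is thread-safe, so it satisfies run_sync_soon_threadsafe
        run_sync_soon_threadsafe=lambda fn: todo.put(("run", fn)),
        done_callback=lambda result: todo.put(("done", result)),
    )
    while True:
        op, obj = todo.get()
        if op == "run":
            obj()  # execute a callback the guest scheduled
        else:
            return obj.unwrap()  # obj is an outcome.Outcome of the guest run

async def main():
    await trio.sleep(0.1)
    return "ok"

assert minimal_host(main) == "ok"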
@@ -24,7 +26,7 @@ # our main # - final result is returned # - any unhandled exceptions cause an immediate crash -def trivial_guest_run(trio_fn, **start_guest_run_kwargs): +def trivial_guest_run(trio_fn, *, in_host_after_start=None, **start_guest_run_kwargs): todo = queue.Queue() host_thread = threading.current_thread() @@ -56,6 +58,8 @@ def done_callback(outcome): done_callback=done_callback, **start_guest_run_kwargs, ) + if in_host_after_start is not None: + in_host_after_start() try: while True: @@ -107,6 +111,49 @@ async def do_receive(): trivial_guest_run(trio_main) +def test_guest_is_initialized_when_start_returns(): + trio_token = None + record = [] + + async def trio_main(in_host): + record.append("main task ran") + await trio.sleep(0) + assert trio.lowlevel.current_trio_token() is trio_token + return "ok" + + def after_start(): + # We should get control back before the main task executes any code + assert record == [] + + nonlocal trio_token + trio_token = trio.lowlevel.current_trio_token() + trio_token.run_sync_soon(record.append, "run_sync_soon cb ran") + + @trio.lowlevel.spawn_system_task + async def early_task(): + record.append("system task ran") + await trio.sleep(0) + + res = trivial_guest_run(trio_main, in_host_after_start=after_start) + assert res == "ok" + assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"} + + # Errors during initialization (which can only be TrioInternalErrors) + # are raised out of start_guest_run, not out of the done_callback + with pytest.raises(trio.TrioInternalError): + + class BadClock: + def start_clock(self): + raise ValueError("whoops") + + def after_start_never_runs(): # pragma: no cover + pytest.fail("shouldn't get here") + + trivial_guest_run( + trio_main, clock=BadClock(), in_host_after_start=after_start_never_runs + ) + + def test_host_can_directly_wake_trio_task(): async def trio_main(in_host): ev = trio.Event() @@ -140,6 +187,35 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" +def test_guest_mode_sniffio_integration(): + from sniffio import current_async_library, thread_local as sniffio_library + + async def trio_main(in_host): + async def synchronize(): + """Wait for all in_host() calls issued so far to complete.""" + evt = trio.Event() + in_host(evt.set) + await evt.wait() + + # Host and guest have separate sniffio_library contexts + in_host(partial(setattr, sniffio_library, "name", "nullio")) + await synchronize() + assert current_async_library() == "trio" + + record = [] + in_host(lambda: record.append(current_async_library())) + await synchronize() + assert record == ["nullio"] + assert current_async_library() == "trio" + + return "ok" + + try: + assert trivial_guest_run(trio_main) == "ok" + finally: + sniffio_library.name = None + + def test_warn_set_wakeup_fd_overwrite(): assert signal.set_wakeup_fd(-1) == -1 @@ -502,10 +578,6 @@ async def trio_main(in_host): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") -@pytest.mark.xfail( - sys.implementation.name == "pypy", - reason="async generator issue under investigation", -) @restore_unraisablehook() def test_guest_mode_asyncgens(): import sniffio @@ -525,8 +597,6 @@ async def agen(label): record.add((label, library)) async def iterate_in_aio(): - # "trio" gets inherited from our Trio caller if we don't set this - sniffio.current_async_library_cvar.set("asyncio") await agen("asyncio").asend(None) async def trio_main(): diff --git a/trio/_core/tests/test_instrumentation.py 
b/trio/_core/_tests/test_instrumentation.py similarity index 99% rename from trio/_core/tests/test_instrumentation.py rename to trio/_core/_tests/test_instrumentation.py index 57d3461d3b..498a3eb272 100644 --- a/trio/_core/tests/test_instrumentation.py +++ b/trio/_core/_tests/test_instrumentation.py @@ -1,6 +1,7 @@ import attr import pytest -from ... import _core, _abc + +from ... import _abc, _core from .tutil import check_sequence_matches diff --git a/trio/_core/tests/test_io.py b/trio/_core/_tests/test_io.py similarity index 81% rename from trio/_core/tests/test_io.py rename to trio/_core/_tests/test_io.py index 916ba6cd6f..7a4689d3c1 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/_tests/test_io.py @@ -1,19 +1,27 @@ -import pytest +from __future__ import annotations -import socket as stdlib_socket -import select import random -import errno +import socket as stdlib_socket +from collections.abc import Generator from contextlib import suppress +from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar + +import pytest -from ... import _core -from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints import trio +from ... import _core +from ...testing import assert_checkpoints, wait_all_tasks_blocked + # Cross-platform tests for IO handling +if TYPE_CHECKING: + from typing_extensions import ParamSpec -def fill_socket(sock): + ArgsT = ParamSpec("ArgsT") + + +def fill_socket(sock: stdlib_socket.socket) -> None: try: while True: sock.send(b"x" * 65536) @@ -21,7 +29,7 @@ def fill_socket(sock): pass -def drain_socket(sock): +def drain_socket(sock: stdlib_socket.socket) -> None: try: while True: sock.recv(65536) @@ -29,8 +37,13 @@ def drain_socket(sock): pass +WaitSocket = Callable[[stdlib_socket.socket], Awaitable[object]] +SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket] +RetT = TypeVar("RetT") + + @pytest.fixture -def socketpair(): +def socketpair() -> Generator[SocketPair, None, None]: pair = stdlib_socket.socketpair() for sock in pair: sock.setblocking(False) @@ -39,38 +52,35 @@ def socketpair(): sock.close() -def using_fileno(fn): - def fileno_wrapper(fileobj): +def also_using_fileno( + fn: Callable[[stdlib_socket.socket | int], RetT], +) -> list[Callable[[stdlib_socket.socket], RetT]]: + def fileno_wrapper(fileobj: stdlib_socket.socket) -> RetT: return fn(fileobj.fileno()) name = f"<{fn.__name__} on fileno>" fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name - return fileno_wrapper + return [fn, fileno_wrapper] -wait_readable_options = [trio.lowlevel.wait_readable] -wait_writable_options = [trio.lowlevel.wait_writable] -notify_closing_options = [trio.lowlevel.notify_closing] - -for options_list in [ - wait_readable_options, - wait_writable_options, - notify_closing_options, -]: - options_list += [using_fileno(f) for f in options_list] - # Decorators that feed in different settings for wait_readable / wait_writable # / notify_closing. 
# Note that if you use all three decorators on the same test, it will run all # N**3 *combinations* read_socket_test = pytest.mark.parametrize( - "wait_readable", wait_readable_options, ids=lambda fn: fn.__name__ + "wait_readable", + also_using_fileno(trio.lowlevel.wait_readable), + ids=lambda fn: fn.__name__, ) write_socket_test = pytest.mark.parametrize( - "wait_writable", wait_writable_options, ids=lambda fn: fn.__name__ + "wait_writable", + also_using_fileno(trio.lowlevel.wait_writable), + ids=lambda fn: fn.__name__, ) notify_closing_test = pytest.mark.parametrize( - "notify_closing", notify_closing_options, ids=lambda fn: fn.__name__ + "notify_closing", + also_using_fileno(trio.lowlevel.notify_closing), + ids=lambda fn: fn.__name__, ) @@ -79,7 +89,9 @@ def fileno_wrapper(fileobj): # momentarily and then immediately resuming. @read_socket_test @write_socket_test -async def test_wait_basic(socketpair, wait_readable, wait_writable): +async def test_wait_basic( + socketpair: SocketPair, wait_readable: WaitSocket, wait_writable: WaitSocket +) -> None: a, b = socketpair # They start out writable() @@ -89,7 +101,7 @@ async def test_wait_basic(socketpair, wait_readable, wait_writable): # But readable() blocks until data arrives record = [] - async def block_on_read(): + async def block_on_read() -> None: try: with assert_checkpoints(): await wait_readable(a) @@ -112,7 +124,7 @@ async def block_on_read(): await wait_readable(b) record = [] - async def block_on_write(): + async def block_on_write() -> None: try: with assert_checkpoints(): await wait_writable(a) @@ -145,7 +157,7 @@ async def block_on_write(): @read_socket_test -async def test_double_read(socketpair, wait_readable): +async def test_double_read(socketpair: SocketPair, wait_readable: WaitSocket) -> None: a, b = socketpair # You can't have two tasks trying to read from a socket at the same time @@ -158,7 +170,7 @@ async def test_double_read(socketpair, wait_readable): @write_socket_test -async def test_double_write(socketpair, wait_writable): +async def test_double_write(socketpair: SocketPair, wait_writable: WaitSocket) -> None: a, b = socketpair # You can't have two tasks trying to write to a socket at the same time @@ -175,15 +187,18 @@ async def test_double_write(socketpair, wait_writable): @write_socket_test @notify_closing_test async def test_interrupted_by_close( - socketpair, wait_readable, wait_writable, notify_closing -): + socketpair: SocketPair, + wait_readable: WaitSocket, + wait_writable: WaitSocket, + notify_closing: Callable[[stdlib_socket.socket], object], +) -> None: a, b = socketpair - async def reader(): + async def reader() -> None: with pytest.raises(_core.ClosedResourceError): await wait_readable(a) - async def writer(): + async def writer() -> None: with pytest.raises(_core.ClosedResourceError): await wait_writable(a) @@ -198,14 +213,16 @@ async def writer(): @read_socket_test @write_socket_test -async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable): - record = [] +async def test_socket_simultaneous_read_write( + socketpair: SocketPair, wait_readable: WaitSocket, wait_writable: WaitSocket +) -> None: + record: list[str] = [] - async def r_task(sock): + async def r_task(sock: stdlib_socket.socket) -> None: await wait_readable(sock) record.append("r_task") - async def w_task(sock): + async def w_task(sock: stdlib_socket.socket) -> None: await wait_writable(sock) record.append("w_task") @@ -226,7 +243,9 @@ async def w_task(sock): @read_socket_test @write_socket_test -async def 
test_socket_actual_streaming(socketpair, wait_readable, wait_writable): +async def test_socket_actual_streaming( + socketpair: SocketPair, wait_readable: WaitSocket, wait_writable: WaitSocket +) -> None: a, b = socketpair # Use a small send buffer on one of the sockets to increase the chance of @@ -236,9 +255,9 @@ async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable) N = 1000000 # 1 megabyte MAX_CHUNK = 65536 - results = {} + results: dict[str, int] = {} - async def sender(sock, seed, key): + async def sender(sock: stdlib_socket.socket, seed: int, key: str) -> None: r = random.Random(seed) sent = 0 while sent < N: @@ -253,7 +272,7 @@ async def sender(sock, seed, key): sock.shutdown(stdlib_socket.SHUT_WR) results[key] = sent - async def receiver(sock, key): + async def receiver(sock: stdlib_socket.socket, key: str) -> None: received = 0 while True: print("received", received) @@ -275,7 +294,7 @@ async def receiver(sock, key): assert results["send_b"] == results["recv_a"] -async def test_notify_closing_on_invalid_object(): +async def test_notify_closing_on_invalid_object() -> None: # It should either be a no-op (generally on Unix, where we don't know # which fds are valid), or an OSError (on Windows, where we currently only # support sockets, so we have to do some validation to figure out whether @@ -291,7 +310,7 @@ async def test_notify_closing_on_invalid_object(): assert got_oserror or got_no_error -async def test_wait_on_invalid_object(): +async def test_wait_on_invalid_object() -> None: # We definitely want to raise an error everywhere if you pass in an # invalid fd to wait_* for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]: @@ -303,12 +322,12 @@ async def test_wait_on_invalid_object(): await wait(fileno) -async def test_io_manager_statistics(): - def check(*, expected_readers, expected_writers): +async def test_io_manager_statistics() -> None: + def check(*, expected_readers: int, expected_writers: int) -> None: statistics = _core.current_statistics() print(statistics) iostats = statistics.io_statistics - if iostats.backend in ["epoll", "windows"]: + if iostats.backend == "epoll" or iostats.backend == "windows": assert iostats.tasks_waiting_read == expected_readers assert iostats.tasks_waiting_write == expected_writers else: @@ -351,7 +370,7 @@ def check(*, expected_readers, expected_writers): check(expected_readers=1, expected_writers=0) -async def test_can_survive_unnotified_close(): +async def test_can_survive_unnotified_close() -> None: # An "unnotified" close is when the user closes an fd/socket/handle # directly, without calling notify_closing first. This should never happen # -- users should call notify_closing before closing things. But, just in @@ -369,9 +388,13 @@ async def test_can_survive_unnotified_close(): # This test exercises some tricky "unnotified close" scenarios, to make # sure we get the "acceptable" behaviors. - async def allow_OSError(async_func, *args): + async def allow_OSError( + async_func: Callable[ArgsT, Awaitable[object]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, + ) -> None: with suppress(OSError): - await async_func(*args) + await async_func(*args, **kwargs) with stdlib_socket.socket() as s: async with trio.open_nursery() as nursery: @@ -429,7 +452,7 @@ async def allow_OSError(async_func, *args): # sleep waiting on 'a2', with the idea that the 'a2' notification will # definitely arrive, and when it does then we can assume that whatever # notification was going to arrive for 'a' has also arrived. 
- async def wait_readable_a2_then_set(): + async def wait_readable_a2_then_set() -> None: await trio.lowlevel.wait_readable(a2) e.set() diff --git a/trio/_core/tests/test_ki.py b/trio/_core/_tests/test_ki.py similarity index 97% rename from trio/_core/tests/test_ki.py rename to trio/_core/_tests/test_ki.py index 101e21441d..b6eef68e22 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -1,23 +1,26 @@ -import outcome -import pytest -import sys -import os -import signal -import threading +from __future__ import annotations + import contextlib -import time import inspect +import signal +import threading +from typing import TYPE_CHECKING + +import outcome +import pytest try: - from async_generator import yield_, async_generator + from async_generator import async_generator, yield_ except ImportError: # pragma: no cover async_generator = yield_ = None from ... import _core -from ...testing import wait_all_tasks_blocked -from ..._util import signal_raise, is_main_thread from ..._timeouts import sleep -from .tutil import slow +from ..._util import signal_raise +from ...testing import wait_all_tasks_blocked + +if TYPE_CHECKING: + from ..._core import Abort, RaiseCancelT def ki_self(): @@ -378,7 +381,7 @@ async def main(): ki_self() task = _core.current_task() - def abort(_): + def abort(_: RaiseCancelT) -> Abort: _core.reschedule(task, outcome.Value(1)) return _core.Abort.FAILED @@ -397,7 +400,7 @@ async def main(): ki_self() task = _core.current_task() - def abort(raise_cancel): + def abort(raise_cancel: RaiseCancelT) -> Abort: result = outcome.capture(raise_cancel) _core.reschedule(task, result) return _core.Abort.FAILED diff --git a/trio/_core/tests/test_local.py b/trio/_core/_tests/test_local.py similarity index 98% rename from trio/_core/tests/test_local.py rename to trio/_core/_tests/test_local.py index 619dcd20d4..d36be0479e 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/_tests/test_local.py @@ -8,7 +8,7 @@ def test_runvar_smoketest(): t1 = _core.RunVar("test1") t2 = _core.RunVar("test2", default="catfish") - assert "RunVar" in repr(t1) + assert repr(t1) == "<RunVar name='test1'>" async def first_check(): with pytest.raises(LookupError): diff --git a/trio/_core/tests/test_mock_clock.py b/trio/_core/_tests/test_mock_clock.py similarity index 99% rename from trio/_core/tests/test_mock_clock.py rename to trio/_core/_tests/test_mock_clock.py index e5b2373ca5..9c74df3334 100644 --- a/trio/_core/tests/test_mock_clock.py +++ b/trio/_core/_tests/test_mock_clock.py @@ -1,9 +1,10 @@ -from math import inf import time +from math import inf import pytest from trio import sleep + from ... import _core from ..
import wait_all_tasks_blocked from .._mock_clock import MockClock diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/_tests/test_multierror.py similarity index 88% rename from trio/_core/tests/test_multierror.py rename to trio/_core/_tests/test_multierror.py index 650f9bf597..6990a7b756 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -1,37 +1,34 @@ +from __future__ import annotations + import gc -import logging import os +import pickle +import re import subprocess +import sys +import warnings from pathlib import Path +from traceback import extract_tb, print_exception import pytest -from traceback import ( - extract_tb, - print_exception, - format_exception, -) -from traceback import _cause_message # type: ignore -import sys -import re - -from .tutil import slow -from .._multierror import MultiError, concat_tb, NonBaseMultiError from ... import TrioDeprecationWarning from ..._core import open_nursery +from .._multierror import MultiError, NonBaseMultiError, concat_tb +from .tutil import slow if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup class NotHashableException(Exception): - code = None + code: int | None = None - def __init__(self, code): + def __init__(self, code: int) -> None: super().__init__() self.code = code - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, NotHashableException): return False return self.code == other.code @@ -451,7 +448,11 @@ def run_script(name, use_ipython=False): print("subprocess PYTHONPATH:", env.get("PYTHONPATH")) if use_ipython: - lines = [script_path.read_text(), "exit()"] + lines = [ + "import runpy", + f"runpy.run_path(r'{script_path}', run_name='trio.fake')", + "exit()", + ] cmd = [ sys.executable, @@ -473,12 +474,10 @@ def run_script(name, use_ipython=False): return completed -def check_simple_excepthook(completed, uses_ipython): +def check_simple_excepthook(completed): assert_match_in_seq( [ - "in <module>" if sys.version_info >= (3, 8) - else "in <string>", + "in <module>", "MultiError", "--- 1 ---", "in exc1_fn", "KeyError", "--- 2 ---", "in exc2_fn", "ValueError", @@ -492,7 +491,7 @@ def check_simple_excepthook(completed, uses_ipython): try: - import IPython + import IPython # noqa: F401 except ImportError: # pragma: no cover have_ipython = False else: @@ -505,14 +504,14 @@ def check_simple_excepthook(completed, uses_ipython): @need_ipython def test_ipython_exc_handler(): completed = run_script("simple_excepthook.py", use_ipython=True) - check_simple_excepthook(completed, True) + check_simple_excepthook(completed) @slow @need_ipython def test_ipython_imported_but_unused(): completed = run_script("simple_excepthook_IPython.py") - check_simple_excepthook(completed, False) + check_simple_excepthook(completed) @slow @@ -555,3 +554,36 @@ def test_apport_excepthook_monkeypatch_interaction(): ["--- 1 ---", "KeyError", "--- 2 ---", "ValueError"], stdout, ) + + +@pytest.mark.parametrize("protocol", range(0, pickle.HIGHEST_PROTOCOL + 1)) +def test_pickle_multierror(protocol: int) -> None: + # use trio.MultiError to make sure that pickle works through the deprecation layer + import trio + + my_except = ZeroDivisionError() + + try: + 1 / 0 + except ZeroDivisionError as e: + my_except = e + + # MultiError will collapse into different classes depending on the errors + for cls, errors in ( + (ZeroDivisionError, [my_except]), + (NonBaseMultiError, [my_except, ValueError()]), + (MultiError, [BaseException(), my_except]), + ): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", TrioDeprecationWarning) + me =
trio.MultiError(errors) # type: ignore[attr-defined] + dump = pickle.dumps(me, protocol=protocol) + load = pickle.loads(dump) + assert repr(me) == repr(load) + assert me.__class__ == load.__class__ == cls + + assert me.__dict__.keys() == load.__dict__.keys() + for me_val, load_val in zip(me.__dict__.values(), load.__dict__.values()): + # tracebacks etc are not preserved through pickling for the default + # exceptions, so we only check that the repr stays the same + assert repr(me_val) == repr(load_val) diff --git a/trio/_core/tests/test_multierror_scripts/__init__.py b/trio/_core/_tests/test_multierror_scripts/__init__.py similarity index 100% rename from trio/_core/tests/test_multierror_scripts/__init__.py rename to trio/_core/_tests/test_multierror_scripts/__init__.py diff --git a/trio/_core/tests/test_multierror_scripts/_common.py b/trio/_core/_tests/test_multierror_scripts/_common.py similarity index 100% rename from trio/_core/tests/test_multierror_scripts/_common.py rename to trio/_core/_tests/test_multierror_scripts/_common.py diff --git a/trio/_core/tests/test_multierror_scripts/apport_excepthook.py b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py similarity index 63% rename from trio/_core/tests/test_multierror_scripts/apport_excepthook.py rename to trio/_core/_tests/test_multierror_scripts/apport_excepthook.py index 12e7fb0851..0e46f37e17 100644 --- a/trio/_core/tests/test_multierror_scripts/apport_excepthook.py +++ b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py @@ -3,11 +3,13 @@ # make sure it's on sys.path. import sys +import _common # isort: split + sys.path.append("/usr/lib/python3/dist-packages") import apport_python_hook apport_python_hook.install() -import trio +from trio._core._multierror import MultiError # Bypass deprecation warnings -raise trio.MultiError([KeyError("key_error"), ValueError("value_error")]) +raise MultiError([KeyError("key_error"), ValueError("value_error")]) diff --git a/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py similarity index 80% rename from trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py rename to trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py index b3fd110e50..7ccb341dc9 100644 --- a/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py +++ b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py @@ -1,10 +1,10 @@ -import _common - # Override the regular excepthook too -- it doesn't change anything either way # because ipython doesn't use it, but we want to make sure Trio doesn't warn # about it. import sys +import _common # isort: split + def custom_excepthook(*args): print("custom running!") @@ -29,8 +29,8 @@ def custom_exc_hook(etype, value, tb, tb_offset=None): ip.set_custom_exc((SomeError,), custom_exc_hook) -import trio +from trio._core._multierror import MultiError # Bypass deprecation warnings. 
# The custom excepthook should run, because Trio was polite and didn't # override it -raise trio.MultiError([ValueError(), KeyError()]) +raise MultiError([ValueError(), KeyError()]) diff --git a/trio/_core/tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py similarity index 64% rename from trio/_core/tests/test_multierror_scripts/simple_excepthook.py rename to trio/_core/_tests/test_multierror_scripts/simple_excepthook.py index 94004525db..65371107bc 100644 --- a/trio/_core/tests/test_multierror_scripts/simple_excepthook.py +++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py @@ -1,6 +1,6 @@ -import _common +import _common # isort: split -import trio +from trio._core._multierror import MultiError # Bypass deprecation warnings def exc1_fn(): @@ -18,4 +18,4 @@ def exc2_fn(): # This should be printed nicely, because Trio overrode sys.excepthook -raise trio.MultiError([exc1_fn(), exc2_fn()]) +raise MultiError([exc1_fn(), exc2_fn()]) diff --git a/trio/_core/tests/test_multierror_scripts/simple_excepthook_IPython.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook_IPython.py similarity index 99% rename from trio/_core/tests/test_multierror_scripts/simple_excepthook_IPython.py rename to trio/_core/_tests/test_multierror_scripts/simple_excepthook_IPython.py index 6aa12493b0..51a88c96ce 100644 --- a/trio/_core/tests/test_multierror_scripts/simple_excepthook_IPython.py +++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook_IPython.py @@ -3,5 +3,4 @@ # To tickle the "is IPython loaded?" logic, make sure that Trio tolerates # IPython loaded but not actually in use import IPython - import simple_excepthook diff --git a/trio/_core/tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py similarity index 98% rename from trio/_core/tests/test_parking_lot.py rename to trio/_core/_tests/test_parking_lot.py index db3fc76709..3f03fdbade 100644 --- a/trio/_core/tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -72,6 +72,9 @@ async def waiter(i, lot): ) lot.unpark_all() + with pytest.raises(ValueError): + lot.unpark(count=1.5) + async def cancellable_waiter(name, lot, scopes, record): with _core.CancelScope() as scope: diff --git a/trio/_core/tests/test_run.py b/trio/_core/_tests/test_run.py similarity index 78% rename from trio/_core/tests/test_run.py rename to trio/_core/_tests/test_run.py index 4d2cf204fe..5c45cf828f 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/_tests/test_run.py @@ -1,56 +1,69 @@ +from __future__ import annotations + import contextvars import functools -import platform +import gc import sys import threading import time import types -import warnings import weakref -from contextlib import contextmanager, ExitStack +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Generator, +) +from contextlib import ExitStack, contextmanager from math import inf -from textwrap import dedent -import gc +from typing import NoReturn, TypeVar -import attr import outcome -import sniffio import pytest +import sniffio +from ... 
import _core +from ..._core._multierror import MultiError, NonBaseMultiError +from ..._threads import to_thread_run_sync +from ..._timeouts import fail_after, sleep +from ...testing import Sequencer, assert_checkpoints, wait_all_tasks_blocked +from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD from .tutil import ( - slow, + buggy_pypy_asyncgens, check_sequence_matches, + create_asyncio_future_in_new_loop, gc_collect_harder, ignore_coroutine_never_awaited_warnings, - buggy_pypy_asyncgens, restore_unraisablehook, - create_asyncio_future_in_new_loop, -) - -from ... import _core -from ..._core._multierror import MultiError, NonBaseMultiError -from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD -from ..._threads import to_thread_run_sync -from ..._timeouts import sleep, fail_after -from ...testing import ( - wait_all_tasks_blocked, - Sequencer, - assert_checkpoints, + slow, ) if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup +T = TypeVar("T") + + # slightly different from _timeouts.sleep_forever because it returns the value # its rescheduled with, which is really only useful for tests of # rescheduling... -async def sleep_forever(): +async def sleep_forever() -> object: return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -def test_basic(): - async def trivial(x): +def not_none(x: T | None) -> T: + """Assert that this object is not None. + + This is just to satisfy type checkers, if this ever fails the test is broken. + """ + assert x is not None + return x + + +def test_basic() -> None: + async def trivial(x: T) -> T: return x assert _core.run(trivial, 8) == 8 @@ -61,17 +74,17 @@ async def trivial(x): with pytest.raises(TypeError): # Not an async function - _core.run(lambda: None) + _core.run(lambda: None) # type: ignore - async def trivial2(x): + async def trivial2(x: T) -> T: await _core.checkpoint() return x assert _core.run(trivial2, 1) == 1 -def test_initial_task_error(): - async def main(x): +def test_initial_task_error() -> None: + async def main(x: object) -> NoReturn: raise ValueError(x) with pytest.raises(ValueError) as excinfo: @@ -79,9 +92,9 @@ async def main(x): assert excinfo.value.args == (17,) -def test_run_nesting(): - async def inception(): - async def main(): # pragma: no cover +def test_run_nesting() -> None: + async def inception() -> None: + async def main() -> None: # pragma: no cover pass return _core.run(main) @@ -91,10 +104,10 @@ async def main(): # pragma: no cover assert "from inside" in str(excinfo.value) -async def test_nursery_warn_use_async_with(): +async def test_nursery_warn_use_async_with() -> None: with pytest.raises(RuntimeError) as excinfo: on = _core.open_nursery() - with on: + with on: # type: ignore pass # pragma: no cover excinfo.match( r"use 'async with open_nursery\(...\)', not 'with open_nursery\(...\)'" @@ -105,7 +118,7 @@ async def test_nursery_warn_use_async_with(): pass -async def test_nursery_main_block_error_basic(): +async def test_nursery_main_block_error_basic() -> None: exc = ValueError("whoops") with pytest.raises(ValueError) as excinfo: @@ -114,10 +127,10 @@ async def test_nursery_main_block_error_basic(): assert excinfo.value is exc -async def test_child_crash_basic(): +async def test_child_crash_basic() -> None: exc = ValueError("uh oh") - async def erroring(): + async def erroring() -> NoReturn: raise exc try: @@ -128,13 +141,13 @@ async def erroring(): assert e is exc -async def test_basic_interleave(): - async def looper(whoami, record): +async def test_basic_interleave() -> None: + 
async def looper(whoami: str, record: list[tuple[str, int]]) -> None: for i in range(3): record.append((whoami, i)) await _core.checkpoint() - record = [] + record: list[tuple[str, int]] = [] async with _core.open_nursery() as nursery: nursery.start_soon(looper, "a", record) nursery.start_soon(looper, "b", record) @@ -144,10 +157,10 @@ async def looper(whoami, record): ) -def test_task_crash_propagation(): - looper_record = [] +def test_task_crash_propagation() -> None: + looper_record: list[str] = [] - async def looper(): + async def looper() -> None: try: while True: await _core.checkpoint() @@ -155,10 +168,10 @@ async def looper(): print("looper cancelled") looper_record.append("cancelled") - async def crasher(): + async def crasher() -> NoReturn: raise ValueError("argh") - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(looper) nursery.start_soon(crasher) @@ -170,13 +183,13 @@ async def main(): assert excinfo.value.args == ("argh",) -def test_main_and_task_both_crash(): +def test_main_and_task_both_crash() -> None: # If main crashes and there's also a task crash, then we get both in a # MultiError - async def crasher(): + async def crasher() -> NoReturn: raise ValueError - async def main(): + async def main() -> NoReturn: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise KeyError @@ -190,11 +203,11 @@ async def main(): } -def test_two_child_crashes(): - async def crasher(etype): +def test_two_child_crashes() -> None: + async def crasher(etype: type[Exception]) -> NoReturn: raise etype - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) @@ -207,8 +220,8 @@ async def main(): } -async def test_child_crash_wakes_parent(): - async def crasher(): +async def test_child_crash_wakes_parent() -> None: + async def crasher() -> NoReturn: raise ValueError with pytest.raises(ValueError): @@ -217,11 +230,11 @@ async def crasher(): await sleep_forever() -async def test_reschedule(): - t1 = None - t2 = None +async def test_reschedule() -> None: + t1: _core.Task | None = None + t2: _core.Task | None = None - async def child1(): + async def child1() -> None: nonlocal t1, t2 t1 = _core.current_task() print("child1 start") @@ -229,14 +242,14 @@ async def child1(): print("child1 woke") assert x == 0 print("child1 rescheduling t2") - _core.reschedule(t2, outcome.Error(ValueError())) + _core.reschedule(not_none(t2), outcome.Error(ValueError())) print("child1 exit") - async def child2(): + async def child2() -> None: nonlocal t1, t2 print("child2 start") t2 = _core.current_task() - _core.reschedule(t1, outcome.Value(0)) + _core.reschedule(not_none(t1), outcome.Value(0)) print("child2 sleep") with pytest.raises(ValueError): await sleep_forever() @@ -249,7 +262,7 @@ async def child2(): nursery.start_soon(child2) -async def test_current_time(): +async def test_current_time() -> None: t1 = _core.current_time() # Windows clock is pretty low-resolution -- appveyor tests fail unless we # sleep for a bit here. 
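(For orientation: `sleep_forever` and `test_reschedule` above rely on Trio's low-level park/wake protocol. A minimal self-contained sketch, with invented names `demo`, `parker`, and `parked`:)

import outcome

import trio
import trio.testing

async def demo():
    parked = None

    async def parker():
        nonlocal parked
        parked = trio.lowlevel.current_task()
        # Park until someone reschedules us; the abort callback says a
        # cancellation may simply succeed.
        value = await trio.lowlevel.wait_task_rescheduled(
            lambda raise_cancel: trio.lowlevel.Abort.SUCCEEDED
        )
        assert value == 42  # the Outcome handed to reschedule() arrives here

    async with trio.open_nursery() as nursery:
        nursery.start_soon(parker)
        await trio.testing.wait_all_tasks_blocked()
        trio.lowlevel.reschedule(parked, outcome.Value(42))

trio.run(demo)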
@@ -258,7 +271,7 @@ async def test_current_time(): assert t1 < t2 -async def test_current_time_with_mock_clock(mock_clock): +async def test_current_time_with_mock_clock(mock_clock: _core.MockClock) -> None: start = mock_clock.current_time() assert mock_clock.current_time() == _core.current_time() assert mock_clock.current_time() == _core.current_time() @@ -266,38 +279,38 @@ async def test_current_time_with_mock_clock(mock_clock): assert start + 3.14 == mock_clock.current_time() == _core.current_time() -async def test_current_clock(mock_clock): +async def test_current_clock(mock_clock: _core.MockClock) -> None: assert mock_clock is _core.current_clock() -async def test_current_task(): +async def test_current_task() -> None: parent_task = _core.current_task() - async def child(): - assert _core.current_task().parent_nursery.parent_task is parent_task + async def child() -> None: + assert not_none(_core.current_task().parent_nursery).parent_task is parent_task async with _core.open_nursery() as nursery: nursery.start_soon(child) -async def test_root_task(): - root = _core.current_root_task() +async def test_root_task() -> None: + root = not_none(_core.current_root_task()) assert root.parent_nursery is root.eventual_parent_nursery is None -def test_out_of_context(): +def test_out_of_context() -> None: with pytest.raises(RuntimeError): _core.current_task() with pytest.raises(RuntimeError): _core.current_time() -async def test_current_statistics(mock_clock): +async def test_current_statistics(mock_clock: _core.MockClock) -> None: # Make sure all the early startup stuff has settled down await wait_all_tasks_blocked() # A child that sticks around to make some interesting stats: - async def child(): + async def child() -> None: try: await sleep_forever() except _core.Cancelled: @@ -344,7 +357,7 @@ async def child(): assert stats.seconds_to_next_deadline == inf -async def test_cancel_scope_repr(mock_clock): +async def test_cancel_scope_repr(mock_clock: _core.MockClock) -> None: scope = _core.CancelScope() assert "unbound" in repr(scope) with scope: @@ -360,8 +373,8 @@ async def test_cancel_scope_repr(mock_clock): assert "exited" in repr(scope) -def test_cancel_points(): - async def main1(): +def test_cancel_points() -> None: + async def main1() -> None: with _core.CancelScope() as scope: await _core.checkpoint_if_cancelled() scope.cancel() @@ -370,7 +383,7 @@ async def main1(): _core.run(main1) - async def main2(): + async def main2() -> None: with _core.CancelScope() as scope: await _core.checkpoint() scope.cancel() @@ -379,7 +392,7 @@ async def main2(): _core.run(main2) - async def main3(): + async def main3() -> None: with _core.CancelScope() as scope: scope.cancel() with pytest.raises(_core.Cancelled): @@ -387,7 +400,7 @@ async def main3(): _core.run(main3) - async def main4(): + async def main4() -> None: with _core.CancelScope() as scope: scope.cancel() await _core.cancel_shielded_checkpoint() @@ -398,7 +411,7 @@ async def main4(): _core.run(main4) -async def test_cancel_edge_cases(): +async def test_cancel_edge_cases() -> None: with _core.CancelScope() as scope: # Two cancels in a row -- idempotent scope.cancel() @@ -416,8 +429,8 @@ async def test_cancel_edge_cases(): await sleep_forever() -async def test_cancel_scope_multierror_filtering(): - async def crasher(): +async def test_cancel_scope_multierror_filtering() -> None: + async def crasher() -> NoReturn: raise KeyError try: @@ -442,7 +455,7 @@ async def crasher(): # nursery block continue propagating to reach the # outer scope. 
assert len(multi_exc.exceptions) == 5 - summary = {} + summary: dict[type, int] = {} for exc in multi_exc.exceptions: summary.setdefault(type(exc), 0) summary[type(exc)] += 1 @@ -459,13 +472,13 @@ async def crasher(): assert False -async def test_precancelled_task(): +async def test_precancelled_task() -> None: # a task that gets spawned into an already-cancelled nursery should begin # execution (https://github.com/python-trio/trio/issues/41), but get a # cancelled error at its first blocking call. - record = [] + record: list[str] = [] - async def blocker(): + async def blocker() -> None: record.append("started") await sleep_forever() @@ -475,7 +488,7 @@ async def blocker(): assert record == ["started"] -async def test_cancel_shielding(): +async def test_cancel_shielding() -> None: with _core.CancelScope() as outer: with _core.CancelScope() as inner: await _core.checkpoint() @@ -485,7 +498,7 @@ async def test_cancel_shielding(): assert inner.shield is False with pytest.raises(TypeError): - inner.shield = "hello" + inner.shield = "hello" # type: ignore assert inner.shield is False inner.shield = True @@ -516,16 +529,16 @@ async def test_cancel_shielding(): # make sure that cancellation propagates immediately to all children -async def test_cancel_inheritance(): - record = set() +async def test_cancel_inheritance() -> None: + record: set[str] = set() - async def leaf(ident): + async def leaf(ident: str) -> None: try: await sleep_forever() except _core.Cancelled: record.add(ident) - async def worker(ident): + async def worker(ident: str) -> None: async with _core.open_nursery() as nursery: nursery.start_soon(leaf, ident + "-l1") nursery.start_soon(leaf, ident + "-l2") @@ -538,7 +551,7 @@ async def worker(ident): assert record == {"w1-l1", "w1-l2", "w2-l1", "w2-l2"} -async def test_cancel_shield_abort(): +async def test_cancel_shield_abort() -> None: with _core.CancelScope() as outer: async with _core.open_nursery() as nursery: outer.cancel() @@ -547,7 +560,7 @@ async def test_cancel_shield_abort(): # shield, so it manages to get to sleep record = [] - async def sleeper(): + async def sleeper() -> None: record.append("sleeping") try: await sleep_forever() @@ -569,7 +582,7 @@ async def sleeper(): assert record == ["sleeping", "cancelled"] -async def test_basic_timeout(mock_clock): +async def test_basic_timeout(mock_clock: _core.MockClock) -> None: start = _core.current_time() with _core.CancelScope() as scope: assert scope.deadline == inf @@ -606,7 +619,7 @@ async def test_basic_timeout(mock_clock): await _core.checkpoint() -async def test_cancel_scope_nesting(): +async def test_cancel_scope_nesting() -> None: # Nested scopes: if two triggering at once, the outer one wins with _core.CancelScope() as scope1: with _core.CancelScope() as scope2: @@ -645,7 +658,7 @@ async def test_cancel_scope_nesting(): # Regression test for https://github.com/python-trio/trio/issues/1175 -async def test_unshield_while_cancel_propagating(): +async def test_unshield_while_cancel_propagating() -> None: with _core.CancelScope() as outer: with _core.CancelScope() as inner: outer.cancel() @@ -656,8 +669,8 @@ async def test_unshield_while_cancel_propagating(): assert outer.cancelled_caught and not inner.cancelled_caught -async def test_cancel_unbound(): - async def sleep_until_cancelled(scope): +async def test_cancel_unbound() -> None: + async def sleep_until_cancelled(scope: _core.CancelScope) -> None: with scope, fail_after(1): await sleep_forever() @@ -706,7 +719,7 @@ async def sleep_until_cancelled(scope): # Can't 
enter from multiple tasks simultaneously scope = _core.CancelScope() - async def enter_scope(): + async def enter_scope() -> None: with scope: await sleep_forever() @@ -730,7 +743,7 @@ async def enter_scope(): assert scope.cancel_called # never become un-cancelled -async def test_cancel_scope_misnesting(): +async def test_cancel_scope_misnesting() -> None: outer = _core.CancelScope() inner = _core.CancelScope() with ExitStack() as stack: @@ -742,12 +755,12 @@ async def test_cancel_scope_misnesting(): # If there are other tasks inside the abandoned part of the cancel tree, # they get cancelled when the misnesting is detected - async def task1(): + async def task1() -> None: with pytest.raises(_core.Cancelled): await sleep_forever() # Even if inside another cancel scope - async def task2(): + async def task2() -> None: with _core.CancelScope(): with pytest.raises(_core.Cancelled): await sleep_forever() @@ -786,20 +799,20 @@ async def task2(): # Trying to exit a cancel scope from an unrelated task raises an error # without affecting any state - async def task3(task_status): + async def task3(task_status: _core.TaskStatus[_core.CancelScope]) -> None: with _core.CancelScope() as scope: task_status.started(scope) await sleep_forever() async with _core.open_nursery() as nursery: - scope = await nursery.start(task3) + scope: _core.CancelScope = await nursery.start(task3) with pytest.raises(RuntimeError, match="from unrelated"): scope.__exit__(None, None, None) scope.cancel() @slow -async def test_timekeeping(): +async def test_timekeeping() -> None: # probably a good idea to use a real clock for *one* test anyway... TARGET = 1.0 # give it a few tries in case of random CI server flakiness @@ -819,15 +832,16 @@ async def test_timekeeping(): assert False -async def test_failed_abort(): - stubborn_task = [None] - stubborn_scope = [None] - record = [] +async def test_failed_abort() -> None: + stubborn_task: _core.Task | None = None + stubborn_scope: _core.CancelScope | None = None + record: list[str] = [] - async def stubborn_sleeper(): - stubborn_task[0] = _core.current_task() + async def stubborn_sleeper() -> None: + nonlocal stubborn_task, stubborn_scope + stubborn_task = _core.current_task() with _core.CancelScope() as scope: - stubborn_scope[0] = scope + stubborn_scope = scope record.append("sleep") x = await _core.wait_task_rescheduled(lambda _: _core.Abort.FAILED) assert x == 1 @@ -841,18 +855,18 @@ async def stubborn_sleeper(): nursery.start_soon(stubborn_sleeper) await wait_all_tasks_blocked() assert record == ["sleep"] - stubborn_scope[0].cancel() + not_none(stubborn_scope).cancel() await wait_all_tasks_blocked() # cancel didn't wake it up assert record == ["sleep"] # wake it up again by hand - _core.reschedule(stubborn_task[0], outcome.Value(1)) + _core.reschedule(not_none(stubborn_task), outcome.Value(1)) assert record == ["sleep", "woke", "cancelled"] @restore_unraisablehook() -def test_broken_abort(): - async def main(): +def test_broken_abort() -> None: + async def main() -> None: # These yields are here to work around an annoying warning -- we're # going to crash the main loop, and if we (by chance) do this before # the run_sync_soon task runs for the first time, then Python gives us @@ -866,7 +880,7 @@ async def main(): with _core.CancelScope() as scope: scope.cancel() # None is not a legal return value here - await _core.wait_task_rescheduled(lambda _: None) + await _core.wait_task_rescheduled(lambda _: None) # type: ignore with pytest.raises(_core.TrioInternalError): _core.run(main) 
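(The shield behavior these cancel-scope tests exercise, reduced to a sketch with the invented name `demo`: a shielded scope keeps running despite a pending outer cancel, and the cancellation lands at the first unshielded checkpoint.)

import trio

async def demo():
    with trio.CancelScope() as outer:
        outer.cancel()
        with trio.CancelScope(shield=True):
            # Shielded: this checkpoint completes despite the pending cancel.
            await trio.sleep(0.01)
        # Unshielded again: the next checkpoint raises Cancelled, which the
        # outer scope absorbs on exit.
        await trio.sleep(0)
    assert outer.cancelled_caught

trio.run(demo)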
@@ -878,11 +892,11 @@ async def main(): @restore_unraisablehook() -def test_error_in_run_loop(): +def test_error_in_run_loop() -> None: # Blow stuff up real good to check we at least get a TrioInternalError - async def main(): + async def main() -> None: task = _core.current_task() - task._schedule_points = "hello!" + task._schedule_points = "hello!" # type: ignore await _core.checkpoint() with ignore_coroutine_never_awaited_warnings(): @@ -890,10 +904,10 @@ async def main(): _core.run(main) -async def test_spawn_system_task(): - record = [] +async def test_spawn_system_task() -> None: + record: list[tuple[str, int]] = [] - async def system_task(x): + async def system_task(x: int) -> None: record.append(("x", x)) record.append(("ki", _core.currently_ki_protected())) await _core.checkpoint() @@ -904,11 +918,11 @@ async def system_task(x): # intentionally make a system task crash -def test_system_task_crash(): - async def crasher(): +def test_system_task_crash() -> None: + async def crasher() -> NoReturn: raise KeyError - async def main(): + async def main() -> None: _core.spawn_system_task(crasher) await sleep_forever() @@ -916,19 +930,19 @@ async def main(): _core.run(main) -def test_system_task_crash_MultiError(): - async def crasher1(): +def test_system_task_crash_MultiError() -> None: + async def crasher1() -> NoReturn: raise KeyError - async def crasher2(): + async def crasher2() -> NoReturn: raise ValueError - async def system_task(): + async def system_task() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher1) nursery.start_soon(crasher2) - async def main(): + async def main() -> None: _core.spawn_system_task(system_task) await sleep_forever() @@ -942,24 +956,24 @@ async def main(): assert isinstance(exc, (KeyError, ValueError)) -def test_system_task_crash_plus_Cancelled(): +def test_system_task_crash_plus_Cancelled() -> None: # Set up a situation where a system task crashes with a # MultiError([Cancelled, ValueError]) - async def crasher(): + async def crasher() -> None: try: await sleep_forever() except _core.Cancelled: raise ValueError - async def cancelme(): + async def cancelme() -> None: await sleep_forever() - async def system_task(): + async def system_task() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) nursery.start_soon(cancelme) - async def main(): + async def main() -> None: _core.spawn_system_task(system_task) # then we exit, triggering a cancellation @@ -968,11 +982,11 @@ async def main(): assert type(excinfo.value.__cause__) is ValueError -def test_system_task_crash_KeyboardInterrupt(): - async def ki(): +def test_system_task_crash_KeyboardInterrupt() -> None: + async def ki() -> NoReturn: raise KeyboardInterrupt - async def main(): + async def main() -> None: _core.spawn_system_task(ki) await sleep_forever() @@ -990,7 +1004,7 @@ async def main(): # 4) this task has timed out # 5) ...but it's on the run queue, so the timeout is queued to be delivered # the next time that it's blocked. -async def test_yield_briefly_checks_for_timeout(mock_clock): +async def test_yield_briefly_checks_for_timeout(mock_clock: _core.MockClock) -> None: with _core.CancelScope(deadline=_core.current_time() + 5): await _core.checkpoint() with pytest.raises(_core.Cancelled): @@ -1004,11 +1018,11 @@ async def test_yield_briefly_checks_for_timeout(mock_clock): # still nice to know that it works :-). # # Update: it turns out I was right to be nervous! see the next test... 
-async def test_exc_info(): - record = [] +async def test_exc_info() -> None: + record: list[str] = [] seq = Sequencer() - async def child1(): + async def child1() -> None: with pytest.raises(ValueError) as excinfo: try: async with seq(0): @@ -1025,7 +1039,7 @@ async def child1(): assert excinfo.value.__context__ is None record.append("child1 success") - async def child2(): + async def child2() -> None: with pytest.raises(KeyError) as excinfo: async with seq(1): pass # we don't yield until seq(3) below @@ -1065,10 +1079,10 @@ async def child2(): # like re-raising and exception chaining are broken. # # https://bugs.python.org/issue29587 -async def test_exc_info_after_yield_error(): - child_task = None +async def test_exc_info_after_yield_error() -> None: + child_task: _core.Task | None = None - async def child(): + async def child() -> None: nonlocal child_task child_task = _core.current_task() @@ -1085,15 +1099,15 @@ async def child(): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() - _core.reschedule(child_task, outcome.Error(ValueError())) + _core.reschedule(not_none(child_task), outcome.Error(ValueError())) # Similar to previous test -- if the ValueError() gets sent in via 'throw', # then Python's normal implicit chaining stuff is broken. -async def test_exception_chaining_after_yield_error(): - child_task = None +async def test_exception_chaining_after_yield_error() -> None: + child_task: _core.Task | None = None - async def child(): + async def child() -> None: nonlocal child_task child_task = _core.current_task() @@ -1106,13 +1120,13 @@ async def child(): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() - _core.reschedule(child_task, outcome.Error(ValueError())) + _core.reschedule(not_none(child_task), outcome.Error(ValueError())) assert isinstance(excinfo.value.__context__, KeyError) -async def test_nursery_exception_chaining_doesnt_make_context_loops(): - async def crasher(): +async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: + async def crasher() -> NoReturn: raise KeyError with pytest.raises(MultiError) as excinfo: @@ -1123,8 +1137,8 @@ async def crasher(): assert excinfo.value.__context__ is None -def test_TrioToken_identity(): - async def get_and_check_token(): +def test_TrioToken_identity() -> None: + async def get_and_check_token() -> _core.TrioToken: token = _core.current_trio_token() # Two calls in the same run give the same object assert token is _core.current_trio_token() @@ -1137,10 +1151,10 @@ async def get_and_check_token(): assert hash(t1) != hash(t2) -async def test_TrioToken_run_sync_soon_basic(): - record = [] +async def test_TrioToken_run_sync_soon_basic() -> None: + record: list[tuple[str, int]] = [] - def cb(x): + def cb(x: int) -> None: record.append(("cb", x)) token = _core.current_trio_token() @@ -1150,23 +1164,22 @@ def cb(x): assert record == [("cb", 1)] -def test_TrioToken_run_sync_soon_too_late(): - token = None +def test_TrioToken_run_sync_soon_too_late() -> None: + token: _core.TrioToken | None = None - async def main(): + async def main() -> None: nonlocal token token = _core.current_trio_token() _core.run(main) - assert token is not None with pytest.raises(_core.RunFinishedError): - token.run_sync_soon(lambda: None) # pragma: no branch + not_none(token).run_sync_soon(lambda: None) # pragma: no branch -async def test_TrioToken_run_sync_soon_idempotent(): - record = [] +async def test_TrioToken_run_sync_soon_idempotent() 
-> None: + record: list[int] = [] - def cb(x): + def cb(x: int) -> None: record.append(x) token = _core.current_trio_token() @@ -1190,21 +1203,21 @@ def cb(x): assert record == list(range(100)) -def test_TrioToken_run_sync_soon_idempotent_requeue(): +def test_TrioToken_run_sync_soon_idempotent_requeue() -> None: # We guarantee that if a call has finished, queueing it again will call it # again. Due to the lack of synchronization, this effectively means that # we have to guarantee that once a call has *started*, queueing it again # will call it again. Also this is much easier to test :-) - record = [] + record: list[None] = [] - def redo(token): + def redo(token: _core.TrioToken) -> None: record.append(None) try: token.run_sync_soon(redo, token, idempotent=True) except _core.RunFinishedError: pass - async def main(): + async def main() -> None: token = _core.current_trio_token() token.run_sync_soon(redo, token, idempotent=True) await _core.checkpoint() @@ -1216,10 +1229,10 @@ async def main(): assert len(record) >= 2 -def test_TrioToken_run_sync_soon_after_main_crash(): - record = [] +def test_TrioToken_run_sync_soon_after_main_crash() -> None: + record: list[str] = [] - async def main(): + async def main() -> None: token = _core.current_trio_token() # After main exits but before finally cleaning up, callback processed # normally @@ -1232,12 +1245,12 @@ async def main(): assert record == ["sync-cb"] -def test_TrioToken_run_sync_soon_crashes(): - record = set() +def test_TrioToken_run_sync_soon_crashes() -> None: + record: set[str] = set() - async def main(): + async def main() -> None: token = _core.current_trio_token() - token.run_sync_soon(lambda: dict()["nope"]) + token.run_sync_soon(lambda: {}["nope"]) # type: ignore[index] # check that a crashing run_sync_soon callback doesn't stop further # calls to run_sync_soon token.run_sync_soon(lambda: record.add("2nd run_sync_soon ran")) @@ -1253,7 +1266,7 @@ async def main(): assert record == {"2nd run_sync_soon ran", "cancelled!"} -async def test_TrioToken_run_sync_soon_FIFO(): +async def test_TrioToken_run_sync_soon_FIFO() -> None: N = 100 record = [] token = _core.current_trio_token() @@ -1263,43 +1276,42 @@ async def test_TrioToken_run_sync_soon_FIFO(): assert record == list(range(N)) -def test_TrioToken_run_sync_soon_starvation_resistance(): +def test_TrioToken_run_sync_soon_starvation_resistance() -> None: # Even if we push callbacks in from callbacks, so that the callback queue # never empties out, then we still can't starve out other tasks from # running. 
- token = None - record = [] + token: _core.TrioToken | None = None + record: list[tuple[str, int]] = [] - def naughty_cb(i): - nonlocal token + def naughty_cb(i: int) -> None: try: - token.run_sync_soon(naughty_cb, i + 1) + not_none(token).run_sync_soon(naughty_cb, i + 1) except _core.RunFinishedError: record.append(("run finished", i)) - async def main(): + async def main() -> None: nonlocal token token = _core.current_trio_token() token.run_sync_soon(naughty_cb, 0) - record.append("starting") + record.append(("starting", 0)) for _ in range(20): await _core.checkpoint() _core.run(main) assert len(record) == 2 - assert record[0] == "starting" + assert record[0] == ("starting", 0) assert record[1][0] == "run finished" assert record[1][1] >= 19 -def test_TrioToken_run_sync_soon_threaded_stress_test(): +def test_TrioToken_run_sync_soon_threaded_stress_test() -> None: cb_counter = 0 - def cb(): + def cb() -> None: nonlocal cb_counter cb_counter += 1 - def stress_thread(token): + def stress_thread(token: _core.TrioToken) -> None: try: while True: token.run_sync_soon(cb) @@ -1307,7 +1319,7 @@ def stress_thread(token): except _core.RunFinishedError: pass - async def main(): + async def main() -> None: token = _core.current_trio_token() thread = threading.Thread(target=stress_thread, args=(token,)) thread.start() @@ -1320,7 +1332,7 @@ async def main(): print(cb_counter) -async def test_TrioToken_run_sync_soon_massive_queue(): +async def test_TrioToken_run_sync_soon_massive_queue() -> None: # There are edge cases in the wakeup fd code when the wakeup fd overflows, # so let's try to make that happen. This is also just a good stress test # in general. (With the current-as-of-2017-02-14 code using a socketpair @@ -1331,7 +1343,7 @@ async def test_TrioToken_run_sync_soon_massive_queue(): token = _core.current_trio_token() counter = [0] - def cb(i): + def cb(i: int) -> None: # This also tests FIFO ordering of callbacks assert counter[0] == i counter[0] += 1 @@ -1343,21 +1355,21 @@ def cb(i): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") -def test_TrioToken_run_sync_soon_late_crash(): +def test_TrioToken_run_sync_soon_late_crash() -> None: # Crash after system nursery is closed -- easiest way to do that is # from an async generator finalizer. 
- record = [] - saved = [] + record: list[str] = [] + saved: list[AsyncGenerator[int, None]] = [] - async def agen(): + async def agen() -> AsyncGenerator[int, None]: token = _core.current_trio_token() try: yield 1 finally: - token.run_sync_soon(lambda: {}["nope"]) + token.run_sync_soon(lambda: {}["nope"]) # type: ignore[index] token.run_sync_soon(lambda: record.append("2nd ran")) - async def main(): + async def main() -> None: saved.append(agen()) await saved[-1].asend(None) record.append("main exiting") @@ -1369,14 +1381,14 @@ async def main(): assert record == ["main exiting", "2nd ran"] -async def test_slow_abort_basic(): +async def test_slow_abort_basic() -> None: with _core.CancelScope() as scope: scope.cancel() with pytest.raises(_core.Cancelled): task = _core.current_task() token = _core.current_trio_token() - def slow_abort(raise_cancel): + def slow_abort(raise_cancel: _core.RaiseCancelT) -> _core.Abort: result = outcome.capture(raise_cancel) token.run_sync_soon(_core.reschedule, task, result) return _core.Abort.FAILED @@ -1384,14 +1396,14 @@ def slow_abort(raise_cancel): await _core.wait_task_rescheduled(slow_abort) -async def test_slow_abort_edge_cases(): - record = [] +async def test_slow_abort_edge_cases() -> None: + record: list[str] = [] - async def slow_aborter(): + async def slow_aborter() -> None: task = _core.current_task() token = _core.current_trio_token() - def slow_abort(raise_cancel): + def slow_abort(raise_cancel: _core.RaiseCancelT) -> _core.Abort: record.append("abort-called") result = outcome.capture(raise_cancel) token.run_sync_soon(_core.reschedule, task, result) @@ -1427,11 +1439,13 @@ def slow_abort(raise_cancel): assert record == ["sleeping", "abort-called", "cancelled", "done"] -async def test_task_tree_introspection(): - tasks = {} - nurseries = {} +async def test_task_tree_introspection() -> None: + tasks: dict[str, _core.Task] = {} + nurseries: dict[str, _core.Nursery] = {} - async def parent(task_status=_core.TASK_STATUS_IGNORED): + async def parent( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: tasks["parent"] = _core.current_task() assert tasks["parent"].child_nurseries == [] @@ -1442,6 +1456,7 @@ async def parent(task_status=_core.TASK_STATUS_IGNORED): assert tasks["parent"].child_nurseries == [] + nursery: _core.Nursery | None async with _core.open_nursery() as nursery: nurseries["parent"] = nursery await nursery.start(child1) @@ -1459,7 +1474,7 @@ async def parent(task_status=_core.TASK_STATUS_IGNORED): t = nursery.parent_task nursery = t.parent_nursery - async def child2(): + async def child2() -> None: tasks["child2"] = _core.current_task() assert tasks["parent"].child_nurseries == [nurseries["parent"]] assert nurseries["parent"].child_tasks == frozenset({tasks["child1"]}) @@ -1467,9 +1482,11 @@ async def child2(): assert nurseries["child1"].child_tasks == frozenset({tasks["child2"]}) assert tasks["child2"].child_nurseries == [] - async def child1(task_status=_core.TASK_STATUS_IGNORED): + async def child1( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: me = tasks["child1"] = _core.current_task() - assert me.parent_nursery.parent_task is tasks["parent"] + assert not_none(me.parent_nursery).parent_task is tasks["parent"] assert me.parent_nursery is not nurseries["parent"] assert me.eventual_parent_nursery is nurseries["parent"] task_status.started() @@ -1493,13 +1510,13 @@ async def child1(task_status=_core.TASK_STATUS_IGNORED): assert task.eventual_parent_nursery is None -async 
def test_nursery_closure(): - async def child1(nursery): +async def test_nursery_closure() -> None: + async def child1(nursery: _core.Nursery) -> None: # We can add new tasks to the nursery even after entering __aexit__, # so long as there are still tasks running nursery.start_soon(child2) - async def child2(): + async def child2() -> None: pass async with _core.open_nursery() as nursery: @@ -1510,24 +1527,27 @@ async def child2(): nursery.start_soon(child2) -async def test_spawn_name(): - async def func1(expected): +async def test_spawn_name() -> None: + async def func1(expected: str) -> None: task = _core.current_task() assert expected in task.name - async def func2(): # pragma: no cover + async def func2() -> None: # pragma: no cover pass + async def check(spawn_fn: Callable[..., object]) -> None: + spawn_fn(func1, "func1") + spawn_fn(func1, "func2", name=func2) + spawn_fn(func1, "func3", name="func3") + spawn_fn(functools.partial(func1, "func1")) + spawn_fn(func1, "object", name=object()) + async with _core.open_nursery() as nursery: - for spawn_fn in [nursery.start_soon, _core.spawn_system_task]: - spawn_fn(func1, "func1") - spawn_fn(func1, "func2", name=func2) - spawn_fn(func1, "func3", name="func3") - spawn_fn(functools.partial(func1, "func1")) - spawn_fn(func1, "object", name=object()) + await check(nursery.start_soon) + await check(_core.spawn_system_task) -async def test_current_effective_deadline(mock_clock): +async def test_current_effective_deadline(mock_clock: _core.MockClock) -> None: assert _core.current_effective_deadline() == inf with _core.CancelScope(deadline=5) as scope1: @@ -1549,39 +1569,45 @@ async def test_current_effective_deadline(mock_clock): assert _core.current_effective_deadline() == inf -def test_nice_error_on_bad_calls_to_run_or_spawn(): - def bad_call_run(*args): - _core.run(*args) +def test_nice_error_on_bad_calls_to_run_or_spawn() -> None: + def bad_call_run( + func: Callable[..., Awaitable[object]], + *args: tuple[object, ...], + ) -> None: + _core.run(func, *args) - def bad_call_spawn(*args): - async def main(): + def bad_call_spawn( + func: Callable[..., Awaitable[object]], + *args: tuple[object, ...], + ) -> None: + async def main() -> None: async with _core.open_nursery() as nursery: - nursery.start_soon(*args) + nursery.start_soon(func, *args) _core.run(main) for bad_call in bad_call_run, bad_call_spawn: - async def f(): # pragma: no cover + async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expecting an async function"): - bad_call(f()) + bad_call(f()) # type: ignore[arg-type] - async def async_gen(arg): # pragma: no cover + async def async_gen(arg: T) -> AsyncGenerator[T, None]: # pragma: no cover yield arg with pytest.raises( TypeError, match="expected an async function but got an async generator" ): - bad_call(async_gen, 0) + bad_call(async_gen, 0) # type: ignore -def test_calling_asyncio_function_gives_nice_error(): - async def child_xyzzy(): +def test_calling_asyncio_function_gives_nice_error() -> None: + async def child_xyzzy() -> None: await create_asyncio_future_in_new_loop() - async def misguided(): + async def misguided() -> None: await child_xyzzy() with pytest.raises(TypeError) as excinfo: @@ -1594,18 +1620,16 @@ async def misguided(): ) -async def test_asyncio_function_inside_nursery_does_not_explode(): +async def test_asyncio_function_inside_nursery_does_not_explode() -> None: # Regression test for https://github.com/python-trio/trio/issues/552 with pytest.raises(TypeError) as excinfo: async with 
_core.open_nursery() as nursery: - import asyncio - nursery.start_soon(sleep_forever) await create_asyncio_future_in_new_loop() assert "asyncio" in str(excinfo.value) -async def test_trivial_yields(): +async def test_trivial_yields() -> None: with assert_checkpoints(): await _core.checkpoint() @@ -1629,8 +1653,8 @@ async def test_trivial_yields(): } -async def test_nursery_start(autojump_clock): - async def no_args(): # pragma: no cover +async def test_nursery_start(autojump_clock: _core.MockClock) -> None: + async def no_args() -> None: # pragma: no cover pass # Errors in calling convention get raised immediately from start @@ -1638,7 +1662,9 @@ async def no_args(): # pragma: no cover with pytest.raises(TypeError): await nursery.start(no_args) - async def sleep_then_start(seconds, *, task_status=_core.TASK_STATUS_IGNORED): + async def sleep_then_start( + seconds: int, *, task_status: _core.TaskStatus[int] = _core.TASK_STATUS_IGNORED + ) -> None: repr(task_status) # smoke test await sleep(seconds) task_status.started(seconds) @@ -1663,7 +1689,9 @@ async def sleep_then_start(seconds, *, task_status=_core.TASK_STATUS_IGNORED): assert _core.current_time() - t0 == 2 * 3 # calling started twice - async def double_started(task_status=_core.TASK_STATUS_IGNORED): + async def double_started( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: task_status.started() with pytest.raises(RuntimeError): task_status.started() @@ -1672,7 +1700,9 @@ async def double_started(task_status=_core.TASK_STATUS_IGNORED): await nursery.start(double_started) # child crashes before calling started -> error comes out of .start() - async def raise_keyerror(task_status=_core.TASK_STATUS_IGNORED): + async def raise_keyerror( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: raise KeyError("oops") async with _core.open_nursery() as nursery: @@ -1680,18 +1710,22 @@ async def raise_keyerror(task_status=_core.TASK_STATUS_IGNORED): await nursery.start(raise_keyerror) # child exiting cleanly before calling started -> triggers a RuntimeError - async def nothing(task_status=_core.TASK_STATUS_IGNORED): + async def nothing( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: return async with _core.open_nursery() as nursery: - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError) as excinfo1: await nursery.start(nothing) - assert "exited without calling" in str(excinfo.value) + assert "exited without calling" in str(excinfo1.value) # if the call to start() is cancelled, then the call to started() does # nothing -- the child keeps executing under start(). The value it passed # is ignored; start() raises Cancelled. 
- async def just_started(task_status=_core.TASK_STATUS_IGNORED): + async def just_started( + task_status: _core.TaskStatus[str] = _core.TASK_STATUS_IGNORED, + ) -> None: task_status.started("hi") async with _core.open_nursery() as nursery: @@ -1702,16 +1736,18 @@ async def just_started(task_status=_core.TASK_STATUS_IGNORED): # and if after the no-op started(), the child crashes, the error comes out # of start() - async def raise_keyerror_after_started(task_status=_core.TASK_STATUS_IGNORED): + async def raise_keyerror_after_started( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED + ) -> None: task_status.started() raise KeyError("whoopsiedaisy") async with _core.open_nursery() as nursery: with _core.CancelScope() as cs: cs.cancel() - with pytest.raises(MultiError) as excinfo: + with pytest.raises(MultiError) as excinfo2: await nursery.start(raise_keyerror_after_started) - assert {type(e) for e in excinfo.value.exceptions} == { + assert {type(e) for e in excinfo2.value.exceptions} == { _core.Cancelled, KeyError, } @@ -1725,7 +1761,7 @@ async def raise_keyerror_after_started(task_status=_core.TASK_STATUS_IGNORED): assert _core.current_time() == t0 -async def test_task_nursery_stack(): +async def test_task_nursery_stack() -> None: task = _core.current_task() assert task._child_nurseries == [] async with _core.open_nursery() as nursery1: @@ -1738,10 +1774,12 @@ async def test_task_nursery_stack(): assert task._child_nurseries == [] -async def test_nursery_start_with_cancelled_nursery(): +async def test_nursery_start_with_cancelled_nursery() -> None: # This function isn't testing task_status, it's using task_status as a # convenient way to get a nursery that we can test spawning stuff into. - async def setup_nursery(task_status=_core.TASK_STATUS_IGNORED): + async def setup_nursery( + task_status: _core.TaskStatus[_core.Nursery] = _core.TASK_STATUS_IGNORED, + ) -> None: async with _core.open_nursery() as nursery: task_status.started(nursery) await sleep_forever() @@ -1749,7 +1787,11 @@ async def setup_nursery(task_status=_core.TASK_STATUS_IGNORED): # Calls started() while children are asleep, so we can make sure # that the cancellation machinery notices and aborts when a sleeping task # is moved into a cancelled scope. 
- async def sleeping_children(fn, *, task_status=_core.TASK_STATUS_IGNORED): + async def sleeping_children( + fn: Callable[[], object], + *, + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) nursery.start_soon(sleep_forever) @@ -1759,7 +1801,7 @@ async def sleeping_children(fn, *, task_status=_core.TASK_STATUS_IGNORED): # Cancelling the setup_nursery just *before* calling started() async with _core.open_nursery() as nursery: - target_nursery = await nursery.start(setup_nursery) + target_nursery: _core.Nursery = await nursery.start(setup_nursery) await target_nursery.start( sleeping_children, target_nursery.cancel_scope.cancel ) @@ -1771,8 +1813,12 @@ async def sleeping_children(fn, *, task_status=_core.TASK_STATUS_IGNORED): target_nursery.cancel_scope.cancel() -async def test_nursery_start_keeps_nursery_open(autojump_clock): - async def sleep_a_bit(task_status=_core.TASK_STATUS_IGNORED): +async def test_nursery_start_keeps_nursery_open( + autojump_clock: _core.MockClock, +) -> None: + async def sleep_a_bit( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: await sleep(2) task_status.started() await sleep(3) @@ -1794,11 +1840,13 @@ async def sleep_a_bit(task_status=_core.TASK_STATUS_IGNORED): # Check that it still works even if the task that the nursery is waiting # for ends up crashing, and never actually enters the nursery. - async def sleep_then_crash(task_status=_core.TASK_STATUS_IGNORED): + async def sleep_then_crash( + task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, + ) -> None: await sleep(7) raise KeyError - async def start_sleep_then_crash(nursery): + async def start_sleep_then_crash(nursery: _core.Nursery) -> None: with pytest.raises(KeyError): await nursery.start(sleep_then_crash) @@ -1810,14 +1858,14 @@ async def start_sleep_then_crash(nursery): assert _core.current_time() - t0 == 7 -async def test_nursery_explicit_exception(): +async def test_nursery_explicit_exception() -> None: with pytest.raises(KeyError): async with _core.open_nursery(): raise KeyError() -async def test_nursery_stop_iteration(): - async def fail(): +async def test_nursery_stop_iteration() -> None: + async def fail() -> NoReturn: raise ValueError try: @@ -1828,13 +1876,13 @@ async def fail(): assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) -async def test_nursery_stop_async_iteration(): +async def test_nursery_stop_async_iteration() -> None: class it: - def __init__(self, count): + def __init__(self, count: int): self.count = count self.val = 0 - async def __anext__(self): + async def __anext__(self) -> int: await sleep(0) val = self.val if val >= self.count: @@ -1843,18 +1891,20 @@ async def __anext__(self): return val class async_zip: - def __init__(self, *largs): + def __init__(self, *largs: it): self.nexts = [obj.__anext__ for obj in largs] - async def _accumulate(self, f, items, i): + async def _accumulate( + self, f: Callable[[], Awaitable[int]], items: list[int | None], i: int + ) -> None: items[i] = await f() - def __aiter__(self): + def __aiter__(self) -> async_zip: return self - async def __anext__(self): + async def __anext__(self) -> list[int]: nexts = self.nexts - items = [None] * len(nexts) + items: list[int] = [-1] * len(nexts) async with _core.open_nursery() as nursery: for i, f in enumerate(nexts): @@ -1862,14 +1912,14 @@ async def __anext__(self): return items - result = [] + result: list[list[int]] = [] async for 
vals in async_zip(it(4), it(2)): result.append(vals) assert result == [[0, 0], [1, 1]] -async def test_traceback_frame_removal(): - async def my_child_task(): +async def test_traceback_frame_removal() -> None: + async def my_child_task() -> NoReturn: raise KeyError() try: @@ -1888,17 +1938,18 @@ async def my_child_task(): # task, not trio/contextvars internals. And there's only one frame # inside the child task, so this will also detect if our frame-removal # is too eager. - frame = first_exc.__traceback__.tb_frame - assert frame.f_code is my_child_task.__code__ + tb = first_exc.__traceback__ + assert tb is not None + assert tb.tb_frame.f_code is my_child_task.__code__ -def test_contextvar_support(): - var = contextvars.ContextVar("test") +def test_contextvar_support() -> None: + var: contextvars.ContextVar[str] = contextvars.ContextVar("test") var.set("before") assert var.get() == "before" - async def inner(): + async def inner() -> None: task = _core.current_task() assert task.context.get(var) == "before" assert var.get() == "before" @@ -1911,15 +1962,15 @@ async def inner(): assert var.get() == "before" -async def test_contextvar_multitask(): +async def test_contextvar_multitask() -> None: var = contextvars.ContextVar("test", default="hmmm") - async def t1(): + async def t1() -> None: assert var.get() == "hmmm" var.set("hmmmm") assert var.get() == "hmmmm" - async def t2(): + async def t2() -> None: assert var.get() == "hmmmm" async with _core.open_nursery() as n: @@ -1931,17 +1982,17 @@ async def t2(): await wait_all_tasks_blocked() -def test_system_task_contexts(): - cvar = contextvars.ContextVar("qwilfish") +def test_system_task_contexts() -> None: + cvar: contextvars.ContextVar[str] = contextvars.ContextVar("qwilfish") cvar.set("water") - async def system_task(): + async def system_task() -> None: assert cvar.get() == "water" - async def regular_task(): + async def regular_task() -> None: assert cvar.get() == "poison" - async def inner(): + async def inner() -> None: async with _core.open_nursery() as nursery: cvar.set("poison") nursery.start_soon(regular_task) @@ -1951,25 +2002,28 @@ async def inner(): _core.run(inner) -def test_Nursery_init(): +async def test_Nursery_init() -> None: + """Test that nurseries cannot be constructed directly.""" + # This function is async so that we have access to a task object we can + # pass in. It should never be accessed though. 
+ task = _core.current_task() + scope = _core.CancelScope() with pytest.raises(TypeError): - _core._run.Nursery(None, None) + _core._run.Nursery(task, scope, True) -async def test_Nursery_private_init(): +async def test_Nursery_private_init() -> None: # context manager creation should not raise async with _core.open_nursery() as nursery: assert False == nursery._closed -def test_Nursery_subclass(): +def test_Nursery_subclass() -> None: with pytest.raises(TypeError): - - class Subclass(_core._run.Nursery): - pass + type("Subclass", (_core._run.Nursery,), {}) -def test_Cancelled_init(): +def test_Cancelled_init() -> None: with pytest.raises(TypeError): raise _core.Cancelled @@ -1980,33 +2034,29 @@ def test_Cancelled_init(): _core.Cancelled._create() -def test_Cancelled_str(): +def test_Cancelled_str() -> None: cancelled = _core.Cancelled._create() assert str(cancelled) == "Cancelled" -def test_Cancelled_subclass(): +def test_Cancelled_subclass() -> None: with pytest.raises(TypeError): - - class Subclass(_core.Cancelled): - pass + type("Subclass", (_core.Cancelled,), {}) -def test_CancelScope_subclass(): +def test_CancelScope_subclass() -> None: with pytest.raises(TypeError): - - class Subclass(_core.CancelScope): - pass + type("Subclass", (_core.CancelScope,), {}) -def test_sniffio_integration(): +def test_sniffio_integration() -> None: with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() - async def check_inside_trio(): + async def check_inside_trio() -> None: assert sniffio.current_async_library() == "trio" - def check_function_returning_coroutine(): + def check_function_returning_coroutine() -> Awaitable[object]: assert sniffio.current_async_library() == "trio" return check_inside_trio() @@ -2015,18 +2065,29 @@ def check_function_returning_coroutine(): with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() - async def check_new_task_resets_sniffio_library(): - sniffio.current_async_library_cvar.set("nullio") - _core.spawn_system_task(check_inside_trio) + @contextmanager + def alternate_sniffio_library() -> Generator[None, None, None]: + prev_token = sniffio.current_async_library_cvar.set("nullio") + prev_library, sniffio.thread_local.name = sniffio.thread_local.name, "nullio" + try: + yield + assert sniffio.current_async_library() == "nullio" + finally: + sniffio.thread_local.name = prev_library + sniffio.current_async_library_cvar.reset(prev_token) + + async def check_new_task_resets_sniffio_library() -> None: + with alternate_sniffio_library(): + _core.spawn_system_task(check_inside_trio) async with _core.open_nursery() as nursery: - nursery.start_soon(check_inside_trio) - nursery.start_soon(check_function_returning_coroutine) - assert sniffio.current_async_library() == "nullio" + with alternate_sniffio_library(): + nursery.start_soon(check_inside_trio) + nursery.start_soon(check_function_returning_coroutine) _core.run(check_new_task_resets_sniffio_library) -async def test_Task_custom_sleep_data(): +async def test_Task_custom_sleep_data() -> None: task = _core.current_task() assert task.custom_sleep_data is None task.custom_sleep_data = 1 @@ -2036,15 +2097,18 @@ async def test_Task_custom_sleep_data(): @types.coroutine -def async_yield(value): +def async_yield(value: T) -> Generator[T, None, None]: yield value -async def test_permanently_detach_coroutine_object(): - task = None - pdco_outcome = None +async def test_permanently_detach_coroutine_object() -> None: + task: _core.Task | None = None + pdco_outcome: 
outcome.Outcome[str] | None = None - async def detachable_coroutine(task_outcome, yield_value): + async def detachable_coroutine( + task_outcome: outcome.Outcome[None], + yield_value: object, + ) -> None: await sleep(0) nonlocal task, pdco_outcome task = _core.current_task() @@ -2059,10 +2123,10 @@ async def detachable_coroutine(task_outcome, yield_value): # If we get here then Trio thinks the task has exited... but the coroutine # is still iterable assert pdco_outcome is None - assert task.coro.send("be free!") == "I'm free!" + assert not_none(task).coro.send("be free!") == "I'm free!" assert pdco_outcome == outcome.Value("be free!") with pytest.raises(StopIteration): - task.coro.send(None) + not_none(task).coro.send(None) # Check the exception paths too task = None @@ -2071,12 +2135,13 @@ async def detachable_coroutine(task_outcome, yield_value): async with _core.open_nursery() as nursery: nursery.start_soon(detachable_coroutine, outcome.Error(KeyError()), "uh oh") throw_in = ValueError() - assert task.coro.throw(throw_in) == "uh oh" + assert isinstance(task, _core.Task) # For type checkers. + assert not_none(task).coro.throw(throw_in) == "uh oh" assert pdco_outcome == outcome.Error(throw_in) with pytest.raises(StopIteration): task.coro.send(None) - async def bad_detach(): + async def bad_detach() -> None: async with _core.open_nursery(): with pytest.raises(RuntimeError) as excinfo: await _core.permanently_detach_coroutine_object(outcome.Value(None)) @@ -2086,21 +2151,21 @@ async def bad_detach(): nursery.start_soon(bad_detach) -async def test_detach_and_reattach_coroutine_object(): - unrelated_task = None - task = None +async def test_detach_and_reattach_coroutine_object() -> None: + unrelated_task: _core.Task | None = None + task: _core.Task | None = None - async def unrelated_coroutine(): + async def unrelated_coroutine() -> None: nonlocal unrelated_task unrelated_task = _core.current_task() - async def reattachable_coroutine(): + async def reattachable_coroutine() -> None: + nonlocal task await sleep(0) - nonlocal task task = _core.current_task() - def abort_fn(_): # pragma: no cover + def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: # pragma: no cover return _core.Abort.FAILED got = await _core.temporarily_detach_coroutine_object(abort_fn) @@ -2110,7 +2175,9 @@ def abort_fn(_): # pragma: no cover await async_yield(2) with pytest.raises(RuntimeError) as excinfo: - await _core.reattach_detached_coroutine_object(unrelated_task, None) + await _core.reattach_detached_coroutine_object( + not_none(unrelated_task), None + ) assert "does not match" in str(excinfo.value) await _core.reattach_detached_coroutine_object(task, "byebye") @@ -2121,28 +2188,26 @@ def abort_fn(_): # pragma: no cover nursery.start_soon(unrelated_coroutine) nursery.start_soon(reattachable_coroutine) await wait_all_tasks_blocked() - assert unrelated_task is not None - assert task is not None # Okay, it's detached. 
Here's our coroutine runner: - assert task.coro.send("not trio!") == 1 - assert task.coro.send(None) == 2 - assert task.coro.send(None) == "byebye" + assert not_none(task).coro.send("not trio!") == 1 + assert not_none(task).coro.send(None) == 2 + assert not_none(task).coro.send(None) == "byebye" # Now it's been reattached, and we can leave the nursery -async def test_detached_coroutine_cancellation(): +async def test_detached_coroutine_cancellation() -> None: abort_fn_called = False - task = None + task: _core.Task | None = None - async def reattachable_coroutine(): + async def reattachable_coroutine() -> None: await sleep(0) nonlocal task task = _core.current_task() - def abort_fn(_): + def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: nonlocal abort_fn_called abort_fn_called = True return _core.Abort.FAILED @@ -2163,22 +2228,22 @@ def abort_fn(_): @restore_unraisablehook() -def test_async_function_implemented_in_C(): +def test_async_function_implemented_in_C() -> None: # These used to crash because we'd try to mutate the coroutine object's # cr_frame, but C functions don't have Python frames. - async def agen_fn(record): + async def agen_fn(record: list[str]) -> AsyncIterator[None]: assert not _core.currently_ki_protected() record.append("the generator ran") yield - run_record = [] + run_record: list[str] = [] agen = agen_fn(run_record) _core.run(agen.__anext__) assert run_record == ["the generator ran"] - async def main(): - start_soon_record = [] + async def main() -> None: + start_soon_record: list[str] = [] agen = agen_fn(start_soon_record) async with _core.open_nursery() as nursery: nursery.start_soon(agen.__anext__) @@ -2187,7 +2252,7 @@ async def main(): _core.run(main) -async def test_very_deep_cancel_scope_nesting(): +async def test_very_deep_cancel_scope_nesting() -> None: # This used to crash with a RecursionError in CancelStatus.recalculate with ExitStack() as exit_stack: outermost_scope = _core.CancelScope() @@ -2197,7 +2262,7 @@ async def test_very_deep_cancel_scope_nesting(): outermost_scope.cancel() -async def test_cancel_scope_deadline_duplicates(): +async def test_cancel_scope_deadline_duplicates() -> None: # This exercises an assert in Deadlines._prune, by intentionally creating # duplicate entries in the deadline heap. 
now = _core.current_time() @@ -2211,16 +2276,16 @@ async def test_cancel_scope_deadline_duplicates(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage(): +async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770 gc.collect() - async def do_a_cancel(): + async def do_a_cancel() -> None: with _core.CancelScope() as cscope: cscope.cancel() await sleep_forever() - async def crasher(): + async def crasher() -> NoReturn: raise ValueError old_flags = gc.get_debug() @@ -2250,11 +2315,11 @@ async def crasher(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_cancel_scope_exit_doesnt_create_cyclic_garbage(): +async def test_cancel_scope_exit_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/pull/2063 gc.collect() - async def crasher(): + async def crasher() -> NoReturn: raise ValueError old_flags = gc.get_debug() @@ -2279,16 +2344,21 @@ async def crasher(): gc.garbage.clear() +@pytest.mark.xfail( + sys.version_info >= (3, 12), + reason="Waiting on https://github.com/python/cpython/issues/100964", +) @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_nursery_cancel_doesnt_create_cyclic_garbage(): +async def test_nursery_cancel_doesnt_create_cyclic_garbage() -> None: + collected = False + # https://github.com/python-trio/trio/issues/1770#issuecomment-730229423 - def toggle_collected(): + def toggle_collected() -> None: nonlocal collected collected = True - collected = False gc.collect() old_flags = gc.get_debug() try: @@ -2317,17 +2387,17 @@ def toggle_collected(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_locals_destroyed_promptly_on_cancel(): +async def test_locals_destroyed_promptly_on_cancel() -> None: destroyed = False - def finalizer(): + def finalizer() -> None: nonlocal destroyed destroyed = True class A: pass - async def task(): + async def task() -> None: a = A() weakref.finalize(a, finalizer) await _core.checkpoint() @@ -2338,12 +2408,12 @@ async def task(): assert destroyed -def test_run_strict_exception_groups(): +def test_run_strict_exception_groups() -> None: """ Test that nurseries respect the global context setting of strict_exception_groups. """ - async def main(): + async def main() -> NoReturn: async with _core.open_nursery(): raise Exception("foo") @@ -2355,13 +2425,13 @@ async def main(): assert exc.value.exceptions[0].args == ("foo",) -def test_run_strict_exception_groups_nursery_override(): +def test_run_strict_exception_groups_nursery_override() -> None: """ Test that a nursery can override the global context setting of strict_exception_groups. 
""" - async def main(): + async def main() -> NoReturn: async with _core.open_nursery(strict_exception_groups=False): raise Exception("foo") @@ -2369,7 +2439,7 @@ async def main(): _core.run(main, strict_exception_groups=True) -async def test_nursery_strict_exception_groups(): +async def test_nursery_strict_exception_groups() -> None: """Test that strict exception groups can be enabled on a per-nursery basis.""" with pytest.raises(MultiError) as exc: async with _core.open_nursery(strict_exception_groups=True): @@ -2380,13 +2450,13 @@ async def test_nursery_strict_exception_groups(): assert exc.value.exceptions[0].args == ("foo",) -async def test_nursery_collapse_strict(): +async def test_nursery_collapse_strict() -> None: """ Test that a single exception from a nested nursery with strict semantics doesn't get collapsed when CancelledErrors are stripped from it. """ - async def raise_error(): + async def raise_error() -> NoReturn: raise RuntimeError("test error") with pytest.raises(MultiError) as exc: @@ -2406,13 +2476,13 @@ async def raise_error(): assert isinstance(exceptions[1].exceptions[0], RuntimeError) -async def test_nursery_collapse_loose(): +async def test_nursery_collapse_loose() -> None: """ Test that a single exception from a nested nursery with loose semantics gets collapsed when CancelledErrors are stripped from it. """ - async def raise_error(): + async def raise_error() -> NoReturn: raise RuntimeError("test error") with pytest.raises(MultiError) as exc: @@ -2430,7 +2500,7 @@ async def raise_error(): assert isinstance(exceptions[1], RuntimeError) -async def test_cancel_scope_no_cancellederror(): +async def test_cancel_scope_no_cancellederror() -> None: """ Test that when a cancel scope encounters an exception group that does NOT contain a Cancelled exception, it will NOT set the ``cancelled_caught`` flag. diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py similarity index 97% rename from trio/_core/tests/test_thread_cache.py rename to trio/_core/_tests/test_thread_cache.py index 5f19a5ac64..de78443f4e 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/_tests/test_thread_cache.py @@ -1,13 +1,13 @@ -import pytest import threading -from queue import Queue import time -import sys from contextlib import contextmanager +from queue import Queue + +import pytest -from .tutil import slow, gc_collect_harder, disable_threading_excepthook from .. 
import _thread_cache -from .._thread_cache import start_thread_soon, ThreadCache +from .._thread_cache import ThreadCache, start_thread_soon +from .tutil import gc_collect_harder, slow def test_thread_cache_basics(): diff --git a/trio/_core/tests/test_tutil.py b/trio/_core/_tests/test_tutil.py similarity index 100% rename from trio/_core/tests/test_tutil.py rename to trio/_core/_tests/test_tutil.py diff --git a/trio/_core/tests/test_unbounded_queue.py b/trio/_core/_tests/test_unbounded_queue.py similarity index 100% rename from trio/_core/tests/test_unbounded_queue.py rename to trio/_core/_tests/test_unbounded_queue.py diff --git a/trio/_core/tests/test_windows.py b/trio/_core/_tests/test_windows.py similarity index 82% rename from trio/_core/tests/test_windows.py rename to trio/_core/_tests/test_windows.py index bd81ef0f33..99bb97284b 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/_tests/test_windows.py @@ -1,6 +1,9 @@ import os +import sys import tempfile from contextlib import contextmanager +from typing import TYPE_CHECKING +from unittest.mock import create_autospec import pytest @@ -8,20 +11,59 @@ # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") -from .tutil import slow, gc_collect_harder, restore_unraisablehook -from ... import _core, sleep, move_on_after +assert sys.platform == "win32" or not TYPE_CHECKING # Skip type checking on Windows + +from ... import _core, sleep from ...testing import wait_all_tasks_blocked +from .tutil import gc_collect_harder, restore_unraisablehook, slow if on_windows: from .._windows_cffi import ( + INVALID_HANDLE_VALUE, + FileFlags, ffi, kernel32, - INVALID_HANDLE_VALUE, raise_winerror, - FileFlags, ) +def test_winerror(monkeypatch) -> None: + mock = create_autospec(ffi.getwinerror) + monkeypatch.setattr(ffi, "getwinerror", mock) + + # Returning none = no error, should not happen. + mock.return_value = None + with pytest.raises(RuntimeError, match="No error set"): + raise_winerror() + mock.assert_called_once_with() + mock.reset_mock() + + with pytest.raises(RuntimeError, match="No error set"): + raise_winerror(38) + mock.assert_called_once_with(38) + mock.reset_mock() + + mock.return_value = (12, "test error") + with pytest.raises(OSError) as exc: + raise_winerror(filename="file_1", filename2="file_2") + mock.assert_called_once_with() + mock.reset_mock() + assert exc.value.winerror == 12 + assert exc.value.strerror == "test error" + assert exc.value.filename == "file_1" + assert exc.value.filename2 == "file_2" + + # With an explicit number passed in, it overrides what getwinerror() returns. + with pytest.raises(OSError) as exc: + raise_winerror(18, filename="a/file", filename2="b/file") + mock.assert_called_once_with(18) + mock.reset_mock() + assert exc.value.winerror == 18 + assert exc.value.strerror == "test error" + assert exc.value.filename == "a/file" + assert exc.value.filename2 == "b/file" + + # The undocumented API that this is testing should be changed to stop using # UnboundedQueue (or just removed until we have time to redo it), but until # then we filter out the warning. 
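
[Note: the new `test_winerror` above leans on a stubbing pattern worth calling out: `unittest.mock.create_autospec` builds a mock that enforces the real callable's signature, and pytest's `monkeypatch` swaps it in for the duration of the test. A toy sketch of the same pattern — `FakeFFI` here is a hypothetical stand-in, not trio's real cffi object:]

```python
from __future__ import annotations

from unittest.mock import create_autospec

import pytest


class FakeFFI:
    """Hypothetical stand-in for an object with a getwinerror-like method."""

    def getwinerror(self, code: int = -1) -> tuple[int, str]:
        return (0, "ok")


ffi = FakeFFI()


def test_autospec_stub(monkeypatch: pytest.MonkeyPatch) -> None:
    mock = create_autospec(ffi.getwinerror, return_value=(12, "test error"))
    monkeypatch.setattr(ffi, "getwinerror", mock)

    assert ffi.getwinerror() == (12, "test error")
    mock.assert_called_once_with()

    # Unlike a bare Mock, an autospec mock rejects calls that don't
    # match the original signature.
    with pytest.raises(TypeError):
        ffi.getwinerror(1, 2, 3)
```
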
@@ -91,7 +133,7 @@ async def read_region(start, end): assert buffer == data - with pytest.raises(BufferError): + with pytest.raises((BufferError, TypeError)): await _core.readinto_overlapped(handle, b"immutable") finally: kernel32.CloseHandle(handle) @@ -99,8 +141,8 @@ async def read_region(start, end): @contextmanager def pipe_with_overlapped_read(): - from asyncio.windows_utils import pipe import msvcrt + from asyncio.windows_utils import pipe read_handle, write_handle = pipe(overlapped=(True, False)) try: @@ -175,8 +217,8 @@ async def test_too_late_to_cancel(): def test_lsp_that_hooks_select_gives_good_error(monkeypatch): - from .._windows_cffi import WSAIoctls, _handle from .. import _io_windows + from .._windows_cffi import WSAIoctls, _handle def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): if hasattr(sock, "fileno"): # pragma: no branch @@ -199,8 +241,8 @@ def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to # make sure we get an error rather than an infinite loop. - from .._windows_cffi import WSAIoctls, _handle from .. import _io_windows + from .._windows_cffi import WSAIoctls, _handle def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): if hasattr(sock, "fileno"): # pragma: no branch diff --git a/trio/_core/tests/tutil.py b/trio/_core/_tests/tutil.py similarity index 78% rename from trio/_core/tests/tutil.py rename to trio/_core/_tests/tutil.py index 016e0fd3e1..b3aa73fb7d 100644 --- a/trio/_core/tests/tutil.py +++ b/trio/_core/_tests/tutil.py @@ -1,19 +1,17 @@ # Utilities for testing import asyncio -import socket as stdlib_socket -import threading +import gc import os +import socket as stdlib_socket import sys +import warnings +from contextlib import closing, contextmanager from typing import TYPE_CHECKING import pytest -import warnings -from contextlib import contextmanager, closing - -import gc -# See trio/tests/conftest.py for the other half of this -from trio.tests.conftest import RUN_SLOW +# See trio/_tests/conftest.py for the other half of this +from trio._tests.pytest_plugin import RUN_SLOW slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests") @@ -43,7 +41,7 @@ with s: try: s.bind(("::1", 0)) - except OSError: + except OSError: # pragma: no cover # since support for 3.7 was removed can_bind_ipv6 = False else: can_bind_ipv6 = True @@ -62,7 +60,7 @@ def gc_collect_harder(): # garbage collection, because executing their __del__ method to print the # warning can cause them to be resurrected. So we call collect a few times # to make sure. 
- for _ in range(4): + for _ in range(5): gc.collect() @@ -86,37 +84,13 @@ def _noop(*args, **kwargs): pass -if sys.version_info >= (3, 8): - - @contextmanager - def restore_unraisablehook(): - sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook - try: - yield - finally: - sys.unraisablehook = prev - - @contextmanager - def disable_threading_excepthook(): - if sys.version_info >= (3, 10): - threading.excepthook, prev = threading.__excepthook__, threading.excepthook - else: - threading.excepthook, prev = _noop, threading.excepthook - - try: - yield - finally: - threading.excepthook = prev - -else: - - @contextmanager - def restore_unraisablehook(): # pragma: no cover - yield - - @contextmanager - def disable_threading_excepthook(): # pragma: no cover +@contextmanager +def restore_unraisablehook(): + sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook + try: yield + finally: + sys.unraisablehook = prev # template is like: diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index a36181ee36..8381153576 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -1,13 +1,17 @@ -import sys -import traceback -from threading import Thread, Lock -import outcome +from __future__ import annotations + import ctypes import ctypes.util +import sys +import traceback +from functools import partial from itertools import count +from threading import Lock, Thread +from typing import Any, Callable, Generic, TypeVar -from typing import Callable, Optional, Tuple -from functools import partial +import outcome + +RetT = TypeVar("RetT") def _to_os_thread_name(name: str) -> bytes: @@ -17,18 +21,20 @@ def _to_os_thread_name(name: str) -> bytes: # used to construct the method used to set os thread name, or None, depending on platform. # called once on import -def get_os_thread_name_func() -> Optional[Callable[[Optional[int], str], None]]: - def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: str): +def get_os_thread_name_func() -> Callable[[int | None, str], None] | None: + def namefunc( + setname: Callable[[int, bytes], int], ident: int | None, name: str + ) -> None: # Thread.ident is None "if it has not been started". Unclear if that can happen # with current usage. if ident is not None: # pragma: no cover setname(ident, _to_os_thread_name(name)) - # namefunc on mac also takes an ident, even if pthread_setname_np doesn't/can't use it + # namefunc on Mac also takes an ident, even if pthread_setname_np doesn't/can't use it # so the caller don't need to care about platform. def darwin_namefunc( - setname: Callable[[bytes], int], ident: Optional[int], name: str - ): + setname: Callable[[bytes], int], ident: int | None, name: str + ) -> None: # I don't know if Mac can rename threads that hasn't been started, but default # to no to be on the safe side. if ident is not None: # pragma: no cover @@ -39,7 +45,14 @@ def darwin_namefunc( libpthread_path = ctypes.util.find_library("pthread") if not libpthread_path: return None - libpthread = ctypes.CDLL(libpthread_path) + + # Sometimes windows can find the path, but gives a permission error when + # accessing it. Catching a wider exception in case of more esoteric errors. 
+ # https://github.com/python-trio/trio/issues/2688 + try: + libpthread = ctypes.CDLL(libpthread_path) + except Exception: # pragma: no cover + return None # get the setname method from it # afaik this should never fail @@ -103,9 +116,13 @@ def darwin_namefunc( name_counter = count() -class WorkerThread: - def __init__(self, thread_cache): - self._job: Optional[Tuple[Callable, Callable, str]] = None +class WorkerThread(Generic[RetT]): + def __init__(self, thread_cache: ThreadCache) -> None: + self._job: tuple[ + Callable[[], RetT], + Callable[[outcome.Outcome[RetT]], object], + str | None, + ] | None = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. # @@ -123,7 +140,7 @@ def __init__(self, thread_cache): set_os_thread_name(self._thread.ident, self._default_name) self._thread.start() - def _handle_job(self): + def _handle_job(self) -> None: # Handle job in a separate method to ensure user-created # objects are cleaned up in a consistent manner. assert self._job is not None @@ -154,7 +171,7 @@ def _handle_job(self): print("Exception while delivering result of thread", file=sys.stderr) traceback.print_exception(type(e), e, e.__traceback__) - def _work(self): + def _work(self) -> None: while True: if self._worker_lock.acquire(timeout=IDLE_TIMEOUT): # We got a job @@ -178,10 +195,16 @@ def _work(self): class ThreadCache: - def __init__(self): - self._idle_workers = {} - - def start_thread_soon(self, fn, deliver, name: Optional[str] = None): + def __init__(self) -> None: + self._idle_workers: dict[WorkerThread[Any], None] = {} + + def start_thread_soon( + self, + fn: Callable[[], RetT], + deliver: Callable[[outcome.Outcome[RetT]], object], + name: str | None = None, + ) -> None: + worker: WorkerThread[RetT] try: worker, _ = self._idle_workers.popitem() except KeyError: @@ -193,7 +216,11 @@ def start_thread_soon(self, fn, deliver, name: Optional[str] = None): THREAD_CACHE = ThreadCache() -def start_thread_soon(fn, deliver, name: Optional[str] = None): +def start_thread_soon( + fn: Callable[[], RetT], + deliver: Callable[[outcome.Outcome[RetT]], object], + name: str | None = None, +) -> None: """Runs ``deliver(outcome.capture(fn))`` in a worker thread. Generally ``fn`` does some blocking work, and ``deliver`` delivers the diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index aedf839a8d..760c46bc51 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -1,14 +1,20 @@ -# These are the only functions that ever yield back to the task runner. +"""These are the only functions that ever yield back to the task runner.""" +from __future__ import annotations -import types import enum +import types +from typing import TYPE_CHECKING, Any, Callable, NoReturn import attr import outcome from . import _run -from typing import Callable, NoReturn, Any +if TYPE_CHECKING: + from outcome import Outcome + from typing_extensions import TypeAlias + + from ._run import Task # Helper for the bottommost 'yield'. You can't use 'yield' inside an async @@ -19,7 +25,7 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj): +def _async_yield(obj: Any) -> Any: # type: ignore[misc] return (yield obj) @@ -29,7 +35,7 @@ class CancelShieldedCheckpoint: pass -async def cancel_shielded_checkpoint(): +async def cancel_shielded_checkpoint() -> None: """Introduce a schedule point, but not a cancel point. 
This is *not* a :ref:`checkpoint `, but it is half of a @@ -42,7 +48,7 @@ async def cancel_shielded_checkpoint(): await trio.lowlevel.checkpoint() """ - return (await _async_yield(CancelShieldedCheckpoint)).unwrap() + (await _async_yield(CancelShieldedCheckpoint)).unwrap() # Return values for abort functions @@ -63,10 +69,10 @@ class Abort(enum.Enum): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class WaitTaskRescheduled: - abort_func = attr.ib() + abort_func: Callable[[RaiseCancelT], Abort] = attr.ib() -RaiseCancelT = Callable[[], NoReturn] # TypeAlias +RaiseCancelT: TypeAlias = Callable[[], NoReturn] # Should always return the type a Task "expects", unless you willfully reschedule it @@ -176,10 +182,10 @@ def abort(inner_raise_cancel): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class PermanentlyDetachCoroutineObject: - final_outcome = attr.ib() + final_outcome: Outcome = attr.ib() -async def permanently_detach_coroutine_object(final_outcome): +async def permanently_detach_coroutine_object(final_outcome: Outcome) -> Any: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -210,7 +216,9 @@ async def permanently_detach_coroutine_object(final_outcome): return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) -async def temporarily_detach_coroutine_object(abort_func): +async def temporarily_detach_coroutine_object( + abort_func: Callable[[RaiseCancelT], Abort] +) -> Any: """Temporarily detach the current coroutine object from the Trio scheduler. @@ -246,7 +254,7 @@ async def temporarily_detach_coroutine_object(abort_func): return await _async_yield(WaitTaskRescheduled(abort_func)) -async def reattach_detached_coroutine_object(task, yield_value): +async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None: """Reattach a coroutine object that was detached using :func:`temporarily_detach_coroutine_object`. diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index 9c747749b4..1b7dea095f 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,17 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + import attr from .. import _core from .._deprecate import deprecated from .._util import Final +T = TypeVar("T") + +if TYPE_CHECKING: + from typing_extensions import Self + -@attr.s(frozen=True) -class _UnboundedQueueStats: - qsize = attr.ib() - tasks_waiting = attr.ib() +@attr.s(slots=True, frozen=True) +class UnboundedQueueStatistics: + """An object containing debugging information. + Currently the following fields are defined: + + * ``qsize``: The number of items currently in the queue. + * ``tasks_waiting``: The number of tasks blocked on this queue's + :meth:`get_batch` method. + + """ -class UnboundedQueue(metaclass=Final): + qsize: int = attr.ib() + tasks_waiting: int = attr.ib() + + +class UnboundedQueue(Generic[T], metaclass=Final): """An unbounded queue suitable for certain unusual forms of inter-task communication. 
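
[Note: the hunk above makes `UnboundedQueue` generic over its item type, and the `@deprecated` decorator it carries steers users toward `trio.open_memory_channel(math.inf)` instead. The distinctive behavior being typed in the following hunk is `get_batch`, which drains everything queued since the last call and returns it as one list. A quick sketch of that behavior under the post-patch API — note that constructing the queue emits a `TrioDeprecationWarning`:]

```python
import trio
import trio.lowlevel


async def main() -> None:
    # Deprecated API: instantiating this warns and points at
    # trio.open_memory_channel(math.inf) as the replacement.
    queue: trio.lowlevel.UnboundedQueue[int] = trio.lowlevel.UnboundedQueue()
    queue.put_nowait(1)
    queue.put_nowait(2)
    # A single get_batch call returns the whole backlog at once.
    assert await queue.get_batch() == [1, 2]
    assert queue.empty()


trio.run(main)
```
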
@@ -47,20 +66,20 @@ class UnboundedQueue(metaclass=Final): thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", ) - def __init__(self): + def __init__(self) -> None: self._lot = _core.ParkingLot() - self._data = [] + self._data: list[T] = [] # used to allow handoff from put to the first task in the lot self._can_get = False - def __repr__(self): + def __repr__(self) -> str: return f"<UnboundedQueue holding {len(self._data)} items>" - def qsize(self): + def qsize(self) -> int: """Returns the number of items currently in the queue.""" return len(self._data) - def empty(self): + def empty(self) -> bool: """Returns True if the queue is empty, False otherwise. There is some subtlety to interpreting this method's return value: see @@ -70,7 +89,7 @@ def empty(self): return not self._data @_core.enable_ki_protection - def put_nowait(self, obj): + def put_nowait(self, obj: T) -> None: """Put an object into the queue, without blocking. This always succeeds, because the queue is unbounded. We don't provide @@ -88,13 +107,13 @@ def put_nowait(self, obj): self._can_get = True self._data.append(obj) - def _get_batch_protected(self): + def _get_batch_protected(self) -> list[T]: data = self._data.copy() self._data.clear() self._can_get = False return data - def get_batch_nowait(self): + def get_batch_nowait(self) -> list[T]: """Attempt to get the next batch from the queue, without blocking. Returns: @@ -110,7 +129,7 @@ def get_batch_nowait(self): raise _core.WouldBlock return self._get_batch_protected() - async def get_batch(self): + async def get_batch(self) -> list[T]: """Get the next batch from the queue, blocking as necessary. Returns: @@ -128,22 +147,14 @@ async def get_batch(self): finally: await _core.cancel_shielded_checkpoint() - def statistics(self): - """Return an object containing debugging information. - - Currently the following fields are defined: - - * ``qsize``: The number of items currently in the queue. - * ``tasks_waiting``: The number of tasks blocked on this queue's - :meth:`get_batch` method. - - """ - return _UnboundedQueueStats( + def statistics(self) -> UnboundedQueueStatistics: + """Return an :class:`UnboundedQueueStatistics` object containing debugging information.""" + return UnboundedQueueStatistics( qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> list[T]: return await self.get_batch() diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index c084403eaa..2ad1a023fe 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -1,5 +1,7 @@ -import socket +from __future__ import annotations + import signal +import socket import warnings from ..
import _core @@ -7,7 +9,7 @@ class WakeupSocketpair: - def __init__(self): + def __init__(self) -> None: self.wakeup_sock, self.write_sock = socket.socketpair() self.wakeup_sock.setblocking(False) self.write_sock.setblocking(False) @@ -27,26 +29,26 @@ def __init__(self): self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) except OSError: pass - self.old_wakeup_fd = None + self.old_wakeup_fd: int | None = None - def wakeup_thread_and_signal_safe(self): + def wakeup_thread_and_signal_safe(self) -> None: try: self.write_sock.send(b"\x00") except BlockingIOError: pass - async def wait_woken(self): + async def wait_woken(self) -> None: await _core.wait_readable(self.wakeup_sock) self.drain() - def drain(self): + def drain(self) -> None: try: while True: self.wakeup_sock.recv(2**16) except BlockingIOError: pass - def wakeup_on_signals(self): + def wakeup_on_signals(self) -> None: assert self.old_wakeup_fd is None if not is_main_thread(): return @@ -64,7 +66,7 @@ def wakeup_on_signals(self): ) ) - def close(self): + def close(self) -> None: self.wakeup_sock.close() self.write_sock.close() if self.old_wakeup_fd is not None: diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index a1071519e9..a65a332c2f 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -1,6 +1,13 @@ -import cffi -import re +from __future__ import annotations + import enum +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import NoReturn, TypeAlias + +import cffi ################################################################ # Functions and types @@ -214,7 +221,8 @@ # being _MSC_VER >= 800) LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB) -ffi = cffi.FFI() +ffi = cffi.api.FFI() +CData: TypeAlias = cffi.api.FFI.CData ffi.cdef(LIB) kernel32 = ffi.dlopen("kernel32.dll") @@ -301,23 +309,33 @@ class IoControlCodes(enum.IntEnum): ################################################################ -def _handle(obj): +def _handle(obj: int | CData) -> CData: # For now, represent handles as either cffi HANDLEs or as ints. If you # try to pass in a file descriptor instead, it's not going to work # out. (For that msvcrt.get_osfhandle does the trick, but I don't know if # we'll actually need that for anything...) For sockets this doesn't # matter, Python never allocates an fd. So let's wait until we actually # encounter the problem before worrying about it. 
- if type(obj) is int: + if isinstance(obj, int): return ffi.cast("HANDLE", obj) - else: - return obj + return obj -def raise_winerror(winerror=None, *, filename=None, filename2=None): +def raise_winerror( + winerror: int | None = None, + *, + filename: str | None = None, + filename2: str | None = None, +) -> NoReturn: if winerror is None: - winerror, msg = ffi.getwinerror() + err = ffi.getwinerror() + if err is None: + raise RuntimeError("No error set?") + winerror, msg = err else: - _, msg = ffi.getwinerror(winerror) + err = ffi.getwinerror(winerror) + if err is None: + raise RuntimeError("No error set?") + _, msg = err # https://docs.python.org/3/library/exceptions.html#OSError raise OSError(0, msg, filename, winerror, filename2) diff --git a/trio/_core/tests/conftest.py b/trio/_core/tests/conftest.py deleted file mode 100644 index aca1f98a65..0000000000 --- a/trio/_core/tests/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -import inspect - -# XX this should move into a global something -from ...testing import MockClock, trio_test - - -@pytest.fixture -def mock_clock(): - return MockClock() - - -@pytest.fixture -def autojump_clock(): - return MockClock(autojump_threshold=0) - - -# FIXME: split off into a package (or just make part of Trio's public -# interface?), with config file to enable? and I guess a mark option too; I -# guess it's useful with the class- and file-level marking machinery (where -# the raw @trio_test decorator isn't enough). -@pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): - if inspect.iscoroutinefunction(pyfuncitem.obj): - pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/_core/tests/test_util.py b/trio/_core/tests/test_util.py deleted file mode 100644 index 5871ed8eef..0000000000 --- a/trio/_core/tests/test_util.py +++ /dev/null @@ -1 +0,0 @@ -import pytest diff --git a/trio/_deprecate.py b/trio/_deprecate.py index 7641baefd3..0a9553b854 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -1,10 +1,21 @@ +from __future__ import annotations + import sys +import warnings +from collections.abc import Callable from functools import wraps from types import ModuleType -import warnings +from typing import TYPE_CHECKING, ClassVar, TypeVar import attr +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + ArgsT = ParamSpec("ArgsT") + +RetT = TypeVar("RetT") + # We want our warnings to be visible by default (at least for now), but we # also want it to be possible to override that using the -W switch. AFAICT @@ -29,17 +40,24 @@ class TrioDeprecationWarning(FutureWarning): """ -def _url_for_issue(issue): +def _url_for_issue(issue: int) -> str: return f"https://github.com/python-trio/trio/issues/{issue}" -def _stringify(thing): +def _stringify(thing: object) -> str: if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"): return f"{thing.__module__}.{thing.__qualname__}" return str(thing) -def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): +def warn_deprecated( + thing: object, + version: str, + *, + issue: int | None, + instead: object, + stacklevel: int = 2, +) -> None: stacklevel += 1 msg = f"{_stringify(thing)} is deprecated since Trio {version}" if instead is None: @@ -53,12 +71,14 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): # @deprecated("0.2.0", issue=..., instead=...) # def ... 
-def deprecated(version, *, thing=None, issue, instead): - def do_wrap(fn): +def deprecated( + version: str, *, thing: object = None, issue: int | None, instead: object +) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: + def do_wrap(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: nonlocal thing @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: warn_deprecated(thing, version, instead=instead, issue=issue) return fn(*args, **kwargs) @@ -87,11 +107,17 @@ def wrapper(*args, **kwargs): return do_wrap -def deprecated_alias(old_qualname, new_fn, version, *, issue): +def deprecated_alias( + old_qualname: str, + new_fn: Callable[ArgsT, RetT], + version: str, + *, + issue: int | None, +) -> Callable[ArgsT, RetT]: @deprecated(version, issue=issue, instead=new_fn) @wraps(new_fn, assigned=("__module__", "__annotations__")) - def wrapper(*args, **kwargs): - "Deprecated alias." + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + """Deprecated alias.""" return new_fn(*args, **kwargs) wrapper.__qualname__ = old_qualname @@ -101,16 +127,18 @@ def wrapper(*args, **kwargs): @attr.s(frozen=True) class DeprecatedAttribute: - _not_set = object() + _not_set: ClassVar[object] = object() - value = attr.ib() - version = attr.ib() - issue = attr.ib() - instead = attr.ib(default=_not_set) + value: object = attr.ib() + version: str = attr.ib() + issue: int | None = attr.ib() + instead: object = attr.ib(default=_not_set) class _ModuleWithDeprecations(ModuleType): - def __getattr__(self, name): + __deprecated_attributes__: dict[str, DeprecatedAttribute] + + def __getattr__(self, name: str) -> object: if name in self.__deprecated_attributes__: info = self.__deprecated_attributes__[name] instead = info.instead @@ -124,9 +152,10 @@ def __getattr__(self, name): raise AttributeError(msg.format(self.__name__, name)) -def enable_attribute_deprecations(module_name): +def enable_attribute_deprecations(module_name: str) -> None: module = sys.modules[module_name] module.__class__ = _ModuleWithDeprecations + assert isinstance(module, _ModuleWithDeprecations) # Make sure that this is always defined so that # _ModuleWithDeprecations.__getattr__ can access it without jumping # through hoops or risking infinite recursion. diff --git a/trio/_dtls.py b/trio/_dtls.py index 910637455a..541144de07 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -6,38 +6,63 @@ # Hopefully they fix this before implementing DTLS 1.3, because it's a very different # protocol, and it's probably impossible to pull tricks like we do here. 
-import struct -import hmac -import os +from __future__ import annotations + import enum -from itertools import count -import weakref import errno +import hmac +import os +import struct import warnings +import weakref +from itertools import count +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Iterable, + Iterator, + TypeVar, + Union, +) +from weakref import ReferenceType, WeakValueDictionary import attr import trio -from trio._util import NoPublicConstructor, Final + +from ._util import Final, NoPublicConstructor + +if TYPE_CHECKING: + from types import TracebackType + + # See DTLSEndpoint.__init__ for why this is imported here + from OpenSSL import SSL + from OpenSSL.SSL import Context + from typing_extensions import Self, TypeAlias + + from trio.socket import Address, _SocketType MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock): +def packet_header_overhead(sock: _SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock): +def worst_case_mtu(sock: _SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock): +def best_guess_mtu(sock: _SocketType) -> int: return 1500 - packet_header_overhead(sock) @@ -99,14 +124,14 @@ class BadPacket(Exception): # ChangeCipherSpec is used during the handshake but has its own ContentType. # # Cannot fail. -def part_of_handshake_untrusted(packet): +def part_of_handshake_untrusted(packet: bytes) -> bool: # If the packet is too short, then slicing will successfully return a # short string, which will necessarily fail to match. return packet[3:5] == b"\x00\x00" # Cannot fail -def is_client_hello_untrusted(packet): +def is_client_hello_untrusted(packet: bytes) -> bool: try: return ( packet[0] == ContentType.handshake @@ -141,7 +166,7 @@ class Record: payload: bytes = attr.ib(repr=to_hex) -def records_untrusted(packet): +def records_untrusted(packet: bytes) -> Iterator[Record]: i = 0 while i < len(packet): try: @@ -159,7 +184,7 @@ def records_untrusted(packet): yield Record(ct, version, epoch_seqno, payload) -def encode_record(record): +def encode_record(record: Record) -> bytes: header = RECORD_HEADER.pack( record.content_type, record.version, @@ -188,7 +213,7 @@ class HandshakeFragment: frag: bytes = attr.ib(repr=to_hex) -def decode_handshake_fragment_untrusted(payload): +def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: # Raises BadPacket if decoding fails try: ( @@ -218,7 +243,7 @@ def decode_handshake_fragment_untrusted(payload): ) -def encode_handshake_fragment(hsf): +def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes: hs_header = HANDSHAKE_MESSAGE_HEADER.pack( hsf.msg_type, hsf.msg_len.to_bytes(3, "big"), @@ -229,7 +254,7 @@ def encode_handshake_fragment(hsf): return hs_header + hsf.frag -def decode_client_hello_untrusted(packet): +def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]: # Raises BadPacket if parsing fails # Returns (record epoch_seqno, cookie from the packet, data that should be # hashed into cookie) @@ -325,12 +350,19 @@ class OpaqueHandshakeMessage: record: Record +_AnyHandshakeMessage: TypeAlias = Union[ + HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage +] + + # This takes a raw outgoing handshake volley that openssl generated, and # reconstructs the handshake messages inside it, so that we can repack them # into records while retransmitting. 
So the data ought to be well-behaved -- # it's not coming from the network. -def decode_volley_trusted(volley): - messages = [] +def decode_volley_trusted( + volley: bytes, +) -> list[_AnyHandshakeMessage]: + messages: list[_AnyHandshakeMessage] = [] messages_by_seq = {} for record in records_untrusted(volley): # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. @@ -374,13 +406,17 @@ def decode_volley_trusted(volley): class RecordEncoder: - def __init__(self): + def __init__(self) -> None: self._record_seq = count() - def set_first_record_number(self, n): + def set_first_record_number(self, n: int) -> None: self._record_seq = count(n) - def encode_volley(self, messages, mtu): + def encode_volley( + self, + messages: Iterable[_AnyHandshakeMessage], + mtu: int, + ) -> list[bytearray]: packets = [] packet = bytearray() for message in messages: @@ -512,13 +548,13 @@ def encode_volley(self, messages, mtu): COOKIE_LENGTH = 32 -def _current_cookie_tick(): +def _current_cookie_tick() -> int: return int(trio.current_time() / COOKIE_REFRESH_INTERVAL) # Simple deterministic and invertible serializer -- i.e., a useful tool for converting # structured data into something we can cryptographically sign. -def _signable(*fields): +def _signable(*fields: bytes) -> bytes: out = [] for field in fields: out.append(struct.pack("!Q", len(field))) @@ -526,7 +562,9 @@ def _signable(*fields): return b"".join(out) -def _make_cookie(key, salt, tick, address, client_hello_bits): +def _make_cookie( + key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes +) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -542,7 +580,9 @@ def _make_cookie(key, salt, tick, address, client_hello_bits): return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] -def valid_cookie(key, cookie, address, client_hello_bits): +def valid_cookie( + key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes +) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -562,7 +602,9 @@ def valid_cookie(key, cookie, address, client_hello_bits): return False -def challenge_for(key, address, epoch_seqno, client_hello_bits): +def challenge_for( + key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes +) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() cookie = _make_cookie(key, salt, tick, address, client_hello_bits) @@ -602,12 +644,15 @@ def challenge_for(key, address, epoch_seqno, client_hello_bits): return packet -class _Queue: - def __init__(self, incoming_packets_buffer): - self.s, self.r = trio.open_memory_channel(incoming_packets_buffer) +_T = TypeVar("_T") + +class _Queue(Generic[_T]): + def __init__(self, incoming_packets_buffer: int | float): + self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer) -def _read_loop(read_fn): + +def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: chunks = [] while True: try: @@ -618,7 +663,9 @@ def _read_loop(read_fn): return b"".join(chunks) -async def handle_client_hello_untrusted(endpoint, address, packet): +async def handle_client_hello_untrusted( + endpoint: DTLSEndpoint, address: Address, packet: bytes +) -> None: if endpoint._listening_context is None: return @@ -691,7 +738,9 @@ async def handle_client_hello_untrusted(endpoint, address, packet): endpoint._incoming_connections_q.s.send_nowait(stream) -async def dtls_receive_loop(endpoint_ref, sock): +async def dtls_receive_loop( + endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType 
+) -> None: try: while True: try: @@ -726,7 +775,8 @@ async def dtls_receive_loop(endpoint_ref, sock): await stream._resend_final_volley() else: try: - stream._q.s.send_nowait(packet) + # mypy for some reason cannot determine type of _q + stream._q.s.send_nowait(packet) # type:ignore[has-type] except trio.WouldBlock: stream._packets_dropped_in_trio += 1 else: @@ -748,6 +798,17 @@ async def dtls_receive_loop(endpoint_ref, sock): @attr.frozen class DTLSChannelStatistics: + """Currently this has only one attribute: + + - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of + incoming packets from this peer that Trio successfully received from the + network, but then got dropped because the internal channel buffer was full. If + this is non-zero, then you might want to call ``receive`` more often, or use a + larger ``incoming_packets_buffer``, or just not worry about it because your + UDP-based protocol should be able to handle the occasional lost packet, right? + + """ + incoming_packets_dropped_in_trio: int @@ -767,7 +828,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint, peer_address, ctx): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -778,25 +839,32 @@ def __init__(self, endpoint, peer_address, ctx): # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to # support and isn't useful anyway -- especially for DTLS where it's equivalent # to just performing a new handshake. - ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) + ctx.set_options( + ( + SSL.OP_NO_QUERY_MTU + | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined] + ) + ) self._ssl = SSL.Connection(ctx) - self._handshake_mtu = None + self._handshake_mtu = 0 # This calls self._ssl.set_ciphertext_mtu, which is important, because if you # don't call it then openssl doesn't work. self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) self._replaced = False self._closed = False - self._q = _Queue(endpoint.incoming_packets_buffer) + self._q = _Queue[bytes](endpoint.incoming_packets_buffer) self._handshake_lock = trio.Lock() - self._record_encoder = RecordEncoder() + self._record_encoder: RecordEncoder = RecordEncoder() + + self._final_volley: list[_AnyHandshakeMessage] = [] - def _set_replaced(self): + def _set_replaced(self) -> None: self._replaced = True # Any packets we already received could maybe possibly still be processed, but # there are no more coming. So we close this on the sender side. self._q.s.close() - def _check_replaced(self): + def _check_replaced(self) -> None: if self._replaced: raise trio.BrokenResourceError( "peer tore down this connection to start a new one" @@ -809,7 +877,7 @@ def _check_replaced(self): # DTLS where packets are all independent and can be lost anyway. We do at least need # to handle receiving it properly though, which might be easier if we send it... - def close(self): + def close(self) -> None: """Close this connection. 
`DTLSChannel`\\s don't actually own any OS-level resources – the @@ -830,13 +898,18 @@ def close(self): # ClosedResourceError self._q.r.close() - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, *args): - self.close() + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() - async def aclose(self): + async def aclose(self) -> None: """Close this connection, but asynchronously. This is included to satisfy the `trio.abc.Channel` contract. It's @@ -846,7 +919,7 @@ async def aclose(self): self.close() await trio.lowlevel.checkpoint() - async def _send_volley(self, volley_messages): + async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None: packets = self._record_encoder.encode_volley( volley_messages, self._handshake_mtu ) @@ -854,10 +927,10 @@ async def _send_volley(self, volley_messages): async with self.endpoint._send_lock: await self.endpoint.socket.sendto(packet, self.peer_address) - async def _resend_final_volley(self): + async def _resend_final_volley(self) -> None: await self._send_volley(self._final_volley) - async def do_handshake(self, *, initial_retransmit_timeout=1.0): + async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None: """Perform the handshake. Calling this is optional – if you don't, then it will be automatically called @@ -890,16 +963,17 @@ async def do_handshake(self, *, initial_retransmit_timeout=1.0): return timeout = initial_retransmit_timeout - volley_messages = [] + volley_messages: list[_AnyHandshakeMessage] = [] volley_failed_sends = 0 - def read_volley(): + def read_volley() -> list[_AnyHandshakeMessage]: volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) if ( new_volley_messages and volley_messages and isinstance(new_volley_messages[0], HandshakeMessage) + and isinstance(volley_messages[0], HandshakeMessage) and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq ): # openssl decided to retransmit; discard because we handle @@ -987,7 +1061,7 @@ def read_volley(): self._handshake_mtu, worst_case_mtu(self.endpoint.socket) ) - async def send(self, data): + async def send(self, data: bytes) -> None: """Send a packet of data, securely.""" if self._closed: @@ -1003,7 +1077,7 @@ async def send(self, data): _read_loop(self._ssl.bio_read), self.peer_address ) - async def receive(self): + async def receive(self) -> bytes: """Fetch the next packet of data from this connection's peer, waiting if necessary. @@ -1029,7 +1103,7 @@ async def receive(self): if cleartext: return cleartext - def set_ciphertext_mtu(self, new_mtu): + def set_ciphertext_mtu(self, new_mtu: int) -> None: """Tells Trio the `largest amount of data that can be sent in a single packet to this peer <https://en.wikipedia.org/wiki/Maximum_transmission_unit>`__. @@ -1064,7 +1138,7 @@ def set_ciphertext_mtu(self, new_mtu): self._handshake_mtu = new_mtu self._ssl.set_ciphertext_mtu(new_mtu) - def get_cleartext_mtu(self): + def get_cleartext_mtu(self) -> int: """Returns the largest number of bytes that you can pass in a single call to `send` while still fitting within the network-level MTU. @@ -1073,21 +1147,10 @@ def get_cleartext_mtu(self): """ if not self._did_handshake: raise trio.NeedHandshakeError - return self._ssl.get_cleartext_mtu() - - def statistics(self): - """Returns an object with statistics about this connection.
+ return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] - Currently this has only one attribute: - - - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of - incoming packets from this peer that Trio successfully received from the - network, but then got dropped because the internal channel buffer was full. If - this is non-zero, then you might want to call ``receive`` more often, or use a - larger ``incoming_packets_buffer``, or just not worry about it because your - UDP-based protocol should be able to handle the occasional lost packet, right? - - """ + def statistics(self) -> DTLSChannelStatistics: + """Returns a `DTLSChannelStatistics` object with statistics about this connection.""" return DTLSChannelStatistics(self._packets_dropped_in_trio) @@ -1115,16 +1178,18 @@ class DTLSEndpoint(metaclass=Final): """ - def __init__(self, socket, *, incoming_packets_buffer=10): + def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL from OpenSSL import SSL - self.socket = None # for __del__, in case the next line raises + # for __del__, in case the next line raises + self._initialized: bool = False if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") - self.socket = socket + self._initialized = True + self.socket: _SocketType = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() @@ -1133,15 +1198,15 @@ def __init__(self, socket, *, incoming_packets_buffer=10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams = weakref.WeakValueDictionary() - self._listening_context = None - self._listening_key = None - self._incoming_connections_q = _Queue(float("inf")) + self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary() + self._listening_context: Context | None = None + self._listening_key: bytes | None = None + self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) self._send_lock = trio.Lock() self._closed = False self._receive_loop_spawned = False - def _ensure_receive_loop(self): + def _ensure_receive_loop(self) -> None: # We have to spawn this lazily, because on Windows it will immediately error out # if the socket isn't already bound -- which for clients might not happen until # after we send our first packet. @@ -1151,9 +1216,9 @@ def _ensure_receive_loop(self): ) self._receive_loop_spawned = True - def __del__(self): + def __del__(self) -> None: # Do nothing if this object was never fully constructed - if self.socket is None: + if not self._initialized: return # Close the socket in Trio context (if our Trio context still exists), so that # the background task gets notified about the closure and can exit. @@ -1167,7 +1232,7 @@ def __del__(self): f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self ) - def close(self): + def close(self) -> None: """Close this socket, and all associated DTLS connections. This object can also be used as a context manager. 
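For orientation, here is the client-side shape of the API annotated above: a hedged sketch that elides certificate and trust configuration, and assumes a DTLS server is reachable at the (hypothetical) address used below.

    import trio
    from OpenSSL import SSL

    async def dtls_client() -> None:
        ctx = SSL.Context(SSL.DTLS_METHOD)  # real code also configures certs/verification
        with trio.DTLSEndpoint(trio.socket.socket(type=trio.socket.SOCK_DGRAM)) as endpoint:
            # connect() is synchronous; the handshake itself happens below
            channel = endpoint.connect(("127.0.0.1", 54321), ctx)
            await channel.do_handshake()
            await channel.send(b"hello")
            print(await channel.receive())
            # the DTLSChannelStatistics documented above
            print(channel.statistics().incoming_packets_dropped_in_trio)

    trio.run(dtls_client)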
@@ -1179,19 +1244,31 @@ def close(self): stream.close() self._incoming_connections_q.s.close() - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, *args): - self.close() + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() - def _check_closed(self): + def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError + # async_fn cannot be typed with ParamSpec, since we don't accept + # kwargs. Can be typed with TypeVarTuple once it's fully supported + # in mypy. async def serve( - self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED - ): + self, + ssl_context: Context, + async_fn: Callable[..., Awaitable[object]], + *args: Any, + task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED, + ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. @@ -1235,7 +1312,7 @@ async def handler(dtls_channel): self._listening_context = ssl_context task_status.started() - async def handler_wrapper(stream): + async def handler_wrapper(stream: DTLSChannel) -> None: with stream: await async_fn(stream, *args) @@ -1245,7 +1322,7 @@ async def handler_wrapper(stream): finally: self._listening_context = None - def connect(self, address, ssl_context): + def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel: """Initiate an outgoing DTLS connection. Notice that this is a synchronous method. That's because it doesn't actually diff --git a/trio/_file_io.py b/trio/_file_io.py index 8c8425c775..6b79ae25b5 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -1,13 +1,39 @@ -from functools import partial +from __future__ import annotations + import io +from functools import partial +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + BinaryIO, + Callable, + Generic, + Iterable, + TypeVar, + Union, + overload, +) + +import trio -from .abc import AsyncResource from ._util import async_wraps +from .abc import AsyncResource -import trio +if TYPE_CHECKING: + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + StrOrBytesPath, + ) + from typing_extensions import Literal # This list is also in the docs, make sure to keep them in sync -_FILE_SYNC_ATTRS = { +_FILE_SYNC_ATTRS: set[str] = { "closed", "encoding", "errors", @@ -29,7 +55,7 @@ } # This list is also in the docs, make sure to keep them in sync -_FILE_ASYNC_METHODS = { +_FILE_ASYNC_METHODS: set[str] = { "flush", "read", "read1", @@ -48,59 +74,201 @@ } -class AsyncIOWrapper(AsyncResource): +FileT = TypeVar("FileT") +FileT_co = TypeVar("FileT_co", covariant=True) +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) +AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True) +AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True) + +# This is a little complicated. IO objects have a lot of methods, and which are available on +# different types varies wildly. We want to match the interface of whatever file we're wrapping. +# This pile of protocols each has one sync method/property, meaning they're going to be compatible +# with a file class that supports that method/property. The ones parameterized with AnyStr take +# either str or bytes depending. 
+ +# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're +# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be +# conditional - it's only valid to call them if the object you're accessing them on is compatible +# with that type hint. By using the protocols, the type checker will be checking to see if the +# wrapped type has that method, and only allow the methods that do to be called. We can then alter +# the signature however it needs to match runtime behaviour. +# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types +if TYPE_CHECKING: + from typing_extensions import Buffer, Protocol + + # fmt: off + + class _HasClosed(Protocol): + @property + def closed(self) -> bool: ... + + class _HasEncoding(Protocol): + @property + def encoding(self) -> str: ... + + class _HasErrors(Protocol): + @property + def errors(self) -> str | None: ... + + class _HasFileNo(Protocol): + def fileno(self) -> int: ... + + class _HasIsATTY(Protocol): + def isatty(self) -> bool: ... + + class _HasNewlines(Protocol[T_co]): + # Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any. + @property + def newlines(self) -> T_co: ... + + class _HasReadable(Protocol): + def readable(self) -> bool: ... + + class _HasSeekable(Protocol): + def seekable(self) -> bool: ... + + class _HasWritable(Protocol): + def writable(self) -> bool: ... + + class _HasBuffer(Protocol): + @property + def buffer(self) -> BinaryIO: ... + + class _HasRaw(Protocol): + @property + def raw(self) -> io.RawIOBase: ... + + class _HasLineBuffering(Protocol): + @property + def line_buffering(self) -> bool: ... + + class _HasCloseFD(Protocol): + @property + def closefd(self) -> bool: ... + + class _HasName(Protocol): + @property + def name(self) -> str: ... + + class _HasMode(Protocol): + @property + def mode(self) -> str: ... + + class _CanGetValue(Protocol[AnyStr_co]): + def getvalue(self) -> AnyStr_co: ... + + class _CanGetBuffer(Protocol): + def getbuffer(self) -> memoryview: ... + + class _CanFlush(Protocol): + def flush(self) -> None: ... + + class _CanRead(Protocol[AnyStr_co]): + def read(self, size: int | None = ..., /) -> AnyStr_co: ... + + class _CanRead1(Protocol): + def read1(self, size: int | None = ..., /) -> bytes: ... + + class _CanReadAll(Protocol[AnyStr_co]): + def readall(self) -> AnyStr_co: ... + + class _CanReadInto(Protocol): + def readinto(self, buf: Buffer, /) -> int | None: ... + + class _CanReadInto1(Protocol): + def readinto1(self, buffer: Buffer, /) -> int: ... + + class _CanReadLine(Protocol[AnyStr_co]): + def readline(self, size: int = ..., /) -> AnyStr_co: ... + + class _CanReadLines(Protocol[AnyStr]): + def readlines(self, hint: int = ...) -> list[AnyStr]: ... + + class _CanSeek(Protocol): + def seek(self, target: int, whence: int = 0, /) -> int: ... + + class _CanTell(Protocol): + def tell(self) -> int: ... + + class _CanTruncate(Protocol): + def truncate(self, size: int | None = ..., /) -> int: ... + + class _CanWrite(Protocol[AnyStr_contra]): + def write(self, data: AnyStr_contra, /) -> int: ... + + class _CanWriteLines(Protocol[T_contra]): + # The lines parameter varies for bytes/str, so use a typevar to make the async match. + def writelines(self, lines: Iterable[T_contra], /) -> None: ... + + class _CanPeek(Protocol[AnyStr_co]): + def peek(self, size: int = 0, /) -> AnyStr_co: ... 
+ + class _CanDetach(Protocol[T_co]): + # The T typevar will be the unbuffered/binary file this file wraps. + def detach(self) -> T_co: ... + + class _CanClose(Protocol): + def close(self) -> None: ... + + +# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a +# subtype of the protocols. +class AsyncIOWrapper(AsyncResource, Generic[FileT_co]): """A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous file object` interface. Wrapped methods that could block are executed in :meth:`trio.to_thread.run_sync`. - All properties and methods defined in in :mod:`~io` are exposed by this + All properties and methods defined in :mod:`~io` are exposed by this wrapper, if they exist in the wrapped file object. - """ - def __init__(self, file): + def __init__(self, file: FileT_co) -> None: self._wrapped = file @property - def wrapped(self): + def wrapped(self) -> FileT_co: """object: A reference to the wrapped file object""" return self._wrapped - def __getattr__(self, name): - if name in _FILE_SYNC_ATTRS: - return getattr(self._wrapped, name) - if name in _FILE_ASYNC_METHODS: - meth = getattr(self._wrapped, name) + if not TYPE_CHECKING: + + def __getattr__(self, name: str) -> object: + if name in _FILE_SYNC_ATTRS: + return getattr(self._wrapped, name) + if name in _FILE_ASYNC_METHODS: + meth = getattr(self._wrapped, name) - @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): - func = partial(meth, *args, **kwargs) - return await trio.to_thread.run_sync(func) + @async_wraps(self.__class__, self._wrapped.__class__, name) + async def wrapper(*args, **kwargs): + func = partial(meth, *args, **kwargs) + return await trio.to_thread.run_sync(func) - # cache the generated method - setattr(self, name, wrapper) - return wrapper + # cache the generated method + setattr(self, name, wrapper) + return wrapper - raise AttributeError(name) + raise AttributeError(name) - def __dir__(self): + def __dir__(self) -> Iterable[str]: attrs = set(super().__dir__()) attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a)) attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a)) return attrs - def __aiter__(self): + def __aiter__(self) -> AsyncIOWrapper[FileT_co]: return self - async def __anext__(self): + async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr: line = await self.readline() if line: return line else: raise StopAsyncIteration - async def detach(self): + async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]: """Like :meth:`io.BufferedIOBase.detach`, but async. This also re-wraps the result in a new :term:`asynchronous file object` @@ -111,7 +279,7 @@ async def detach(self): raw = await trio.to_thread.run_sync(self._wrapped.detach) return wrap_file(raw) - async def aclose(self): + async def aclose(self: AsyncIOWrapper[_CanClose]) -> None: """Like :meth:`io.IOBase.close`, but async. This is also shielded from cancellation; if a cancellation scope is @@ -125,18 +293,167 @@ async def aclose(self): await trio.lowlevel.checkpoint_if_cancelled() + if TYPE_CHECKING: + # fmt: off + # Based on typing.IO and io stubs. + @property + def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ... + @property + def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ... + @property + def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ... + @property + def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ... 
+ @property + def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ... + @property + def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ... + @property + def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> int: ... + @property + def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ... + @property + def name(self: AsyncIOWrapper[_HasName]) -> str: ... + @property + def mode(self: AsyncIOWrapper[_HasMode]) -> str: ... + + def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ... + def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ... + def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ... + def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ... + def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ... + def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ... + def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ... + async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ... + async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ... + async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ... + async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ... + async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ... + async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ... + async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]]) -> list[AnyStr]: ... + async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ... + async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ... + async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ... + async def write(self: AsyncIOWrapper[_CanWrite[AnyStr]], data: AnyStr, /) -> int: ... + async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ... + async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ... + async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ... + + +# Type hints are copied from builtin open. +_OpenFile = Union["StrOrBytesPath", int] +_Opener = Callable[[str, int], int] + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.TextIOWrapper]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.FileIO]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedRandom]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedWriter]: + ... 
+ + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedReader]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryMode, + buffering: int, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[BinaryIO]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[IO[Any]]: + ... + async def open_file( - file, - mode="r", - buffering=-1, - encoding=None, - errors=None, - newline=None, - closefd=True, - opener=None, -): - """Asynchronous version of :func:`io.open`. + file: _OpenFile, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[Any]: + """Asynchronous version of :func:`open`. Returns: An :term:`asynchronous file object` @@ -161,7 +478,7 @@ async def open_file( return _file -def wrap_file(file): +def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]: """This wraps any file object in a wrapper that provides an asynchronous file object interface. @@ -179,7 +496,7 @@ def wrap_file(file): """ - def has(attr): + def has(attr: str) -> bool: return hasattr(file, attr) and callable(getattr(file, attr)) if not (has("close") and (has("read") or has("write"))): diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index c31b4fdbf3..e136b2e4bc 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,12 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + import attr import trio -from .abc import HalfCloseableStream - from trio._util import Final +from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + -async def aclose_forcefully(resource): +async def aclose_forcefully(resource: AsyncResource) -> None: """Close an async resource or async generator immediately, without blocking to do any graceful cleanup. @@ -36,8 +47,17 @@ async def aclose_forcefully(resource): await resource.aclose() +def _is_halfclosable(stream: SendStream) -> TypeGuard[HalfCloseableStream]: + """Check if the stream has a send_eof() method.""" + return hasattr(stream, "send_eof") + + @attr.s(eq=False, hash=False) -class StapledStream(HalfCloseableStream, metaclass=Final): +class StapledStream( + HalfCloseableStream, + Generic[SendStreamT, ReceiveStreamT], + metaclass=Final, +): """This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__ together two unidirectional streams to make single bidirectional stream.
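With `StapledStream` now generic in its two halves, a stapled pair of in-memory test streams is inferred as `StapledStream[MemorySendStream, MemoryReceiveStream]`. A small sketch, which essentially mirrors how `trio.testing.memory_stream_pair` is assembled:

    import trio
    import trio.testing

    async def main() -> None:
        # two independent one-way byte pipes...
        send1, recv1 = trio.testing.memory_stream_one_way_pair()
        send2, recv2 = trio.testing.memory_stream_one_way_pair()
        # ...stapled crosswise into two bidirectional endpoints
        left = trio.StapledStream(send1, recv2)
        right = trio.StapledStream(send2, recv1)
        await left.send_all(b"ping")
        assert await right.receive_some() == b"ping"

    trio.run(main)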
@@ -72,34 +92,36 @@ class StapledStream(metaclass=Final): """ - send_stream = attr.ib() - receive_stream = attr.ib() + send_stream: SendStreamT = attr.ib() + receive_stream: ReceiveStreamT = attr.ib() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Calls ``self.send_stream.send_all``.""" return await self.send_stream.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls ``self.send_stream.wait_send_all_might_not_block``.""" return await self.send_stream.wait_send_all_might_not_block() - async def send_eof(self): + async def send_eof(self) -> None: """Shuts down the send side of the stream. - If ``self.send_stream.send_eof`` exists, then calls it. Otherwise, - calls ``self.send_stream.aclose()``. - + If :meth:`self.send_stream.send_eof() <trio.abc.HalfCloseableStream.send_eof>` exists, + then this calls it. Otherwise, this calls + :meth:`self.send_stream.aclose() <trio.abc.AsyncResource.aclose>`. """ - if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() + stream = self.send_stream + if _is_halfclosable(stream): + return await stream.send_eof() else: - return await self.send_stream.aclose() + return await stream.aclose() - async def receive_some(self, max_bytes=None): + # we intentionally accept more types from the caller than we support returning + async def receive_some(self, max_bytes: int | None = None) -> bytes: """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) - async def aclose(self): + async def aclose(self) -> None: """Calls ``aclose`` on both underlying streams.""" try: await self.send_stream.aclose() diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index 2028d30766..e6840eae97 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -1,8 +1,13 @@ +from __future__ import annotations + import errno import sys +from collections.abc import Awaitable, Callable from math import inf import trio +from trio import TaskStatus + from . import socket as tsocket if sys.version_info < (3, 11): @@ -22,7 +27,7 @@ # backpressure. If a connection gets stuck waiting in the backlog queue, then # from the peer's point of view the connection succeeded but then their # send/recv will stall until we get to it, possibly for a long time. OTOH if -# there isn't room in the backlog queue... then their connect stalls, possibly +# there isn't room in the backlog queue, then their connect stalls, possibly
# https://github.com/golang/go/issues/5030 + if isinstance(backlog, float): + # TODO: Remove when removing infinity support + # https://github.com/python-trio/trio/pull/2724#discussion_r1278541729 + if backlog != inf: + raise ValueError(f"Only accepts infinity, not {backlog!r}") + backlog = None + if backlog is None: + return 0xFFFF return min(backlog, 0xFFFF) -async def open_tcp_listeners(port, *, host=None, backlog=None): +async def open_tcp_listeners( + port: int, *, host: str | bytes | None = None, backlog: int | float | None = None +) -> list[trio.SocketListener]: """Create :class:`SocketListener` objects to listen for TCP connections. Args: @@ -61,7 +74,7 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): :func:`open_tcp_listeners` will bind to both the IPv4 wildcard address (``0.0.0.0``) and also the IPv6 wildcard address (``::``). - host (str, bytes-like, or None): The local interface to bind to. This is + host (str, bytes, or None): The local interface to bind to. This is passed to :func:`~socket.getaddrinfo` with the ``AI_PASSIVE`` flag set. @@ -77,13 +90,16 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): all interfaces, pass the family-specific wildcard address: ``"0.0.0.0"`` for IPv4-only and ``"::"`` for IPv6-only. - backlog (int or None): The listen backlog to use. If you leave this as - ``None`` then Trio will pick a good default. (Currently: whatever + backlog (int, math.inf, or None): The listen backlog to use. If you leave this as + ``None`` or ``math.inf`` then Trio will pick a good default. (Currently: whatever your system has configured as the maximum backlog.) Returns: list of :class:`SocketListener` + Raises: + :class:`TypeError` if invalid arguments. + """ # getaddrinfo sometimes allows port=None, sometimes not (depending on # whether host=None). And on some systems it treats "" as 0, others it @@ -92,7 +108,7 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): if not isinstance(port, int): raise TypeError(f"port must be an int not {port!r}") - backlog = _compute_backlog(backlog) + computed_backlog = _compute_backlog(backlog) addresses = await tsocket.getaddrinfo( host, port, type=tsocket.SOCK_STREAM, flags=tsocket.AI_PASSIVE @@ -125,7 +141,7 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1) await sock.bind(sockaddr) - sock.listen(backlog) + sock.listen(computed_backlog) listeners.append(trio.SocketListener(sock)) except: @@ -149,14 +165,14 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): async def serve_tcp( - handler, - port, + handler: Callable[[trio.SocketStream], Awaitable[object]], + port: int, *, - host=None, - backlog=None, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED, -): + host: str | bytes | None = None, + backlog: int | float | None = None, + handler_nursery: trio.Nursery | None = None, + task_status: TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED, +) -> None: """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
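Under the new annotations, `serve_tcp`'s handler is a `Callable[[trio.SocketStream], Awaitable[object]]`. A minimal echo server as a sketch (the port number is arbitrary):

    import trio

    async def echo_handler(stream: trio.SocketStream) -> None:
        # a ReceiveStream supports async iteration; the loop ends at EOF
        async for chunk in stream:
            await stream.send_all(chunk)

    async def main() -> None:
        # backlog=None (or math.inf) lets Trio use the OS maximum, per _compute_backlog
        await trio.serve_tcp(echo_handler, 12345, host="127.0.0.1", backlog=None)

    trio.run(main)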
diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 0fcffbcb06..0c4e8a4a8d 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import sys +from collections.abc import Generator from contextlib import contextmanager +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING import trio from trio._core._multierror import MultiError -from trio.socket import getaddrinfo, SOCK_STREAM, socket +from trio.socket import SOCK_STREAM, Address, _SocketType, getaddrinfo, socket if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -109,8 +114,8 @@ @contextmanager -def close_all(): - sockets_to_close = set() +def close_all() -> Generator[set[_SocketType], None, None]: + sockets_to_close: set[_SocketType] = set() try: yield sockets_to_close finally: @@ -126,7 +131,17 @@ def close_all(): raise MultiError(errs) -def reorder_for_rfc_6555_section_5_4(targets): +def reorder_for_rfc_6555_section_5_4( + targets: list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ] +) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first # and second attempts use different families: @@ -144,7 +159,7 @@ def reorder_for_rfc_6555_section_5_4(targets): break -def format_host_port(host, port): +def format_host_port(host: str | bytes, port: int) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return f"[{host}]:{port}" @@ -173,8 +188,12 @@ def format_host_port(host, port): # AF_INET6: "..."} # this might be simpler after async def open_tcp_stream( - host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None -): + host: str | bytes, + port: int, + *, + happy_eyeballs_delay: float | None = DEFAULT_DELAY, + local_address: str | None = None, +) -> trio.abc.Stream: """Connect to the given host and port over TCP. If the given ``host`` has multiple IP addresses associated with it, then @@ -212,9 +231,9 @@ async def open_tcp_stream( port (int): The port to connect to. - happy_eyeballs_delay (float): How many seconds to wait for each + happy_eyeballs_delay (float or None): How many seconds to wait for each connection attempt to succeed or fail before getting impatient and - starting another one in parallel. Set to `math.inf` if you want + starting another one in parallel. Set to `None` if you want to limit to only one connection attempt at a time (like :func:`socket.create_connection`). Default: 0.25 (250 ms). @@ -247,9 +266,8 @@ async def open_tcp_stream( # To keep our public API surface smaller, rule out some cases that # getaddrinfo will accept in some circumstances, but that act weird or # have non-portable behavior or are just plain not useful. - # No type check on host though b/c we want to allow bytes-likes. - if host is None: - raise ValueError("host cannot be None") + if not isinstance(host, (str, bytes)): + raise ValueError(f"host must be str or bytes, not {host!r}") if not isinstance(port, int): raise TypeError(f"port must be int, not {port!r}") @@ -274,7 +292,7 @@ async def open_tcp_stream( # Keeps track of the socket that we're going to complete with, # need to make sure this isn't automatically closed - winning_socket = None + winning_socket: _SocketType | None = None # Try connecting to the specified address. 
Possible outcomes: # - success: record connected socket in winning_socket and cancel @@ -283,7 +301,11 @@ # the next connection attempt to start early # code needs to ensure sockets can be closed appropriately in the # face of crash or cancellation - async def attempt_connect(socket_args, sockaddr, attempt_failed): + async def attempt_connect( + socket_args: tuple[AddressFamily, SocketKind, int], + sockaddr: Address, + attempt_failed: trio.Event, + ) -> None: nonlocal winning_socket try: @@ -334,7 +356,7 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): except OSError: raise OSError( f"local_address={local_address!r} is incompatible " - f"with remote address {sockaddr}" + f"with remote address {sockaddr!r}" ) await sock.connect(sockaddr) @@ -355,12 +377,23 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): # nursery spawns a task for each connection attempt, will be # cancelled by the task that gets a successful connection async with trio.open_nursery() as nursery: - for *sa, _, addr in targets: + for address_family, socket_type, proto, _, addr in targets: # create an event to indicate connection failure, # allowing the next target to be tried early attempt_failed = trio.Event() - nursery.start_soon(attempt_connect, sa, addr, attempt_failed) + # workaround to check types until typing of nursery.start_soon improved + if TYPE_CHECKING: + await attempt_connect( + (address_family, socket_type, proto), addr, attempt_failed + ) + + nursery.start_soon( + attempt_connect, + (address_family, socket_type, proto), + addr, + attempt_failed, + ) # give this attempt at most this time before moving on with trio.move_on_after(happy_eyeballs_delay): diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index e5aba4695f..c05b8f3fc8 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,8 +1,21 @@ +from __future__ import annotations + import os +from collections.abc import Generator from contextlib import contextmanager +from typing import Protocol, TypeVar import trio -from trio.socket import socket, SOCK_STREAM +from trio.socket import SOCK_STREAM, socket + + +class Closable(Protocol): + def close(self) -> None: + ... + + +CloseT = TypeVar("CloseT", bound=Closable) + try: from trio.socket import AF_UNIX @@ -13,7 +26,7 @@ @contextmanager -def close_on_error(obj): +def close_on_error(obj: CloseT) -> Generator[CloseT, None, None]: try: yield obj except: @@ -21,7 +34,9 @@ def close_on_error(obj): raise -async def open_unix_socket(filename): +async def open_unix_socket( + filename: str | bytes | os.PathLike[str] | os.PathLike[bytes], +) -> trio.SocketStream: """Opens a connection to the specified `Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__.
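A usage sketch for the annotated `open_unix_socket`, parallel to `open_tcp_stream` above. The socket path here is hypothetical and something must already be listening on it; on platforms without `AF_UNIX` the call raises `RuntimeError`:

    import trio

    async def main() -> None:
        # returns a trio.SocketStream, per the new return annotation
        stream = await trio.open_unix_socket("/tmp/example.sock")
        async with stream:
            await stream.send_all(b"hello")
            print(await stream.receive_some())

    trio.run(main)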
diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 0585fa516f..d5c7a3bdad 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import errno import logging import os +from typing import Any, Awaitable, Callable, NoReturn, TypeVar import trio @@ -20,14 +23,23 @@ LOGGER = logging.getLogger("trio.serve_listeners") -async def _run_handler(stream, handler): +StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource) +ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) +Handler = Callable[[StreamT], Awaitable[object]] + + +async def _run_handler(stream: StreamT, handler: Handler[StreamT]) -> None: try: await handler(stream) finally: await trio.aclose_forcefully(stream) -async def _serve_one_listener(listener, handler_nursery, handler): +async def _serve_one_listener( + listener: trio.abc.Listener[StreamT], + handler_nursery: trio.Nursery, + handler: Handler[StreamT], +) -> NoReturn: async with listener: while True: try: @@ -48,9 +60,21 @@ async def _serve_one_listener(listener, handler_nursery, handler): handler_nursery.start_soon(_run_handler, stream, handler) -async def serve_listeners( - handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED -): +# This cannot be typed correctly; we need generic typevar bounds / HKT to indicate the +# relationship between StreamT & ListenerT. +# https://github.com/python/typing/issues/1226 +# https://github.com/python/typing/issues/548 + + +# It never returns (since _serve_one_listener never completes), but type checkers can't +# understand nurseries. +async def serve_listeners( # type: ignore[misc] + handler: Handler[StreamT], + listeners: list[ListenerT], + *, + handler_nursery: trio.Nursery | None = None, + task_status: trio.TaskStatus[list[ListenerT]] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: r"""Listen for incoming connections on ``listeners``, and for each one start a task running ``handler(stream)``. diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 1e8dc16ebc..f8d01cd755 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -1,13 +1,22 @@ # "High-level" networking interface +from __future__ import annotations import errno +from collections.abc import Generator from contextlib import contextmanager +from typing import TYPE_CHECKING, overload import trio + from . import socket as tsocket from ._util import ConflictDetector, Final from .abc import HalfCloseableStream, Listener +if TYPE_CHECKING: + from typing_extensions import Buffer + + from ._socket import _SocketType as SocketType + # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it.
(Or make it dynamic -- one idea is to start small and increase it # if we observe single reads filling up the whole buffer, at least within some @@ -23,7 +32,7 @@ @contextmanager -def _translate_socket_errors_to_stream_errors(): +def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]: try: yield except OSError as exc: @@ -57,7 +66,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -91,7 +100,7 @@ def __init__(self, socket): except OSError: pass - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: if self.socket.did_shutdown_SHUT_WR: raise trio.ClosedResourceError("can't send data after sending EOF") with self._send_conflict_detector: @@ -108,14 +117,14 @@ async def send_all(self, data): sent = await self.socket.send(remaining) total_sent += sent - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self.socket.fileno() == -1: raise trio.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() - async def send_eof(self): + async def send_eof(self) -> None: with self._send_conflict_detector: await trio.lowlevel.checkpoint() # On macOS, calling shutdown a second time raises ENOTCONN, but @@ -125,7 +134,7 @@ async def send_eof(self): with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -133,21 +142,53 @@ async def receive_some(self, max_bytes=None): with _translate_socket_errors_to_stream_errors(): return await self.socket.recv(max_bytes) - async def aclose(self): + async def aclose(self) -> None: self.socket.close() await trio.lowlevel.checkpoint() # __aenter__, __aexit__ inherited from HalfCloseableStream are OK - def setsockopt(self, level, option, value): + @overload + def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, level: int, option: int, value: None, length: int) -> None: + ... + + def setsockopt( + self, + level: int, + option: int, + value: int | Buffer | None, + length: int | None = None, + ) -> None: """Set an option on the underlying socket. See :meth:`socket.socket.setsockopt` for details. """ - return self.socket.setsockopt(level, option, value) - - def getsockopt(self, level, option, buffersize=0): + if length is None: + if value is None: + raise TypeError( + "invalid value for argument 'value', must not be None unless 'length' is specified" + ) + return self.socket.setsockopt(level, option, value) + if value is not None: + raise TypeError( + f"invalid value for argument 'value': {value!r}, must be None when specifying 'length'" + ) + return self.socket.setsockopt(level, option, value, length) + + @overload + def getsockopt(self, level: int, option: int) -> int: + ... + + @overload + def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: + ... + + def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | bytes: """Check the current value of an option on the underlying socket.
See :meth:`socket.socket.getsockopt` for details. @@ -305,7 +346,7 @@ def getsockopt(self, level, option, buffersize=0): ] # Not all errnos are defined on all platforms -_ignorable_accept_errnos = set() +_ignorable_accept_errnos: set[int] = set() for name in _ignorable_accept_errno_names: try: _ignorable_accept_errnos.add(getattr(errno, name)) @@ -330,7 +371,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -346,7 +387,7 @@ def __init__(self, socket): self.socket = socket - async def accept(self): + async def accept(self) -> SocketStream: """Accept an incoming connection. Returns: @@ -374,7 +415,7 @@ async def accept(self): else: return SocketStream(sock) - async def aclose(self): + async def aclose(self) -> None: """Close this listener and its underlying socket.""" self.socket.close() await trio.lowlevel.checkpoint() diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index 19b1ff8777..1647f373c2 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -1,5 +1,10 @@ -import trio +from __future__ import annotations + import ssl +from collections.abc import Awaitable, Callable +from typing import NoReturn + +import trio from ._highlevel_open_tcp_stream import DEFAULT_DELAY @@ -14,13 +19,13 @@ # So... let's punt on that for now. Hopefully we'll be getting a new Python # TLS API soon and can revisit this then. async def open_ssl_over_tcp_stream( - host, - port, + host: str | bytes, + port: int, *, - https_compatible=False, - ssl_context=None, - happy_eyeballs_delay=DEFAULT_DELAY, -): + https_compatible: bool = False, + ssl_context: ssl.SSLContext | None = None, + happy_eyeballs_delay: float | None = DEFAULT_DELAY, +) -> trio.SSLStream: """Make a TLS-encrypted Connection to the given host and port over TCP. This is a convenience wrapper that calls :func:`open_tcp_stream` and @@ -62,8 +67,13 @@ async def open_ssl_over_tcp_stream( async def open_ssl_over_tcp_listeners( - port, ssl_context, *, host=None, https_compatible=False, backlog=None -): + port: int, + ssl_context: ssl.SSLContext, + *, + host: str | bytes | None = None, + https_compatible: bool = False, + backlog: int | float | None = None, +) -> list[trio.SSLListener]: """Start listening for SSL/TLS-encrypted TCP connections to the given port. Args: @@ -85,16 +95,16 @@ async def open_ssl_over_tcp_listeners( async def serve_ssl_over_tcp( - handler, - port, - ssl_context, + handler: Callable[[trio.SSLStream], Awaitable[object]], + port: int, + ssl_context: ssl.SSLContext, *, - host=None, - https_compatible=False, - backlog=None, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED, -): + host: str | bytes | None = None, + https_compatible: bool = False, + backlog: int | float | None = None, + handler_nursery: trio.Nursery | None = None, + task_status: trio.TaskStatus[list[trio.SSLListener]] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
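With the annotations above, a compliant ``handler`` for :func:`trio.serve_ssl_over_tcp` is any ``Callable[[trio.SSLStream], Awaitable[object]]``. A hedged sketch of such a handler; the certificate paths are placeholders:

```python
import ssl

import trio


async def echo_handler(stream: trio.SSLStream) -> None:
    # ReceiveStream supports async iteration, yielding chunks until EOF.
    async for chunk in stream:
        await stream.send_all(chunk)


async def main() -> None:
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain("cert.pem", "key.pem")  # placeholder paths
    # Annotated -> NoReturn: this call only exits via cancellation or error.
    await trio.serve_ssl_over_tcp(echo_handler, 8443, context)


trio.run(main)
```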
diff --git a/trio/_path.py b/trio/_path.py index ea8cf98c34..c2763e03af 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -1,49 +1,96 @@ -# type: ignore +from __future__ import annotations -from functools import wraps, partial +import inspect import os -import types import pathlib +import sys +import types +from collections.abc import Awaitable, Callable, Iterable +from functools import partial, wraps +from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper +from typing import ( + IO, + TYPE_CHECKING, + Any, + BinaryIO, + ClassVar, + TypeVar, + Union, + cast, + overload, +) import trio -from trio._util import async_wraps, Final +from trio._file_io import AsyncIOWrapper as _AsyncIOWrapper +from trio._util import Final, async_wraps + +if TYPE_CHECKING: + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + ) + from typing_extensions import Concatenate, Literal, ParamSpec, TypeAlias + + P = ParamSpec("P") + +T = TypeVar("T") +StrPath: TypeAlias = Union[str, "os.PathLike[str]"] # Only subscriptable in 3.9+ # re-wrap return value from methods that return new instances of pathlib.Path -def rewrap_path(value): +def rewrap_path(value: T) -> T | Path: if isinstance(value, pathlib.Path): - value = Path(value) - return value + return Path(value) + else: + return value -def _forward_factory(cls, attr_name, attr): +def _forward_factory( + cls: AsyncAutoWrapperType, + attr_name: str, + attr: Callable[Concatenate[pathlib.Path, P], T], +) -> Callable[Concatenate[Path, P], T | Path]: @wraps(attr) - def wrapper(self, *args, **kwargs): + def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> T | Path: attr = getattr(self._wrapped, attr_name) value = attr(*args, **kwargs) return rewrap_path(value) + # Assigning this makes inspect and therefore Sphinx show the original parameters. + # It's not defined on functions normally though, this is a custom attribute. 
+ assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) + return wrapper -def _forward_magic(cls, attr): +def _forward_magic( + cls: AsyncAutoWrapperType, attr: Callable[..., T] +) -> Callable[..., Path | T]: sentinel = object() @wraps(attr) - def wrapper(self, other=sentinel): + def wrapper(self: Path, other: object = sentinel) -> Path | T: if other is sentinel: return attr(self._wrapped) if isinstance(other, cls): - other = other._wrapped + other = cast(Path, other)._wrapped value = attr(self._wrapped, other) return rewrap_path(value) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) return wrapper -def iter_wrapper_factory(cls, meth_name): +def iter_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> Callable[Concatenate[Path, P], Awaitable[Iterable[Path]]]: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Iterable[Path]: meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) # Make sure that the full iteration is performed in the thread @@ -54,9 +101,11 @@ async def wrapper(self, *args, **kwargs): return wrapper -def thread_wrapper_factory(cls, meth_name): +def thread_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> Callable[Concatenate[Path, P], Awaitable[Path]]: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Path: meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -65,20 +114,31 @@ async def wrapper(self, *args, **kwargs): return wrapper -def classmethod_wrapper_factory(cls, meth_name): - @classmethod +def classmethod_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> classmethod: # type: ignore[type-arg] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(cls, *args, **kwargs): + async def wrapper(cls: type[Path], *args: Any, **kwargs: Any) -> Path: # type: ignore[misc] # contains Any meth = getattr(cls._wraps, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) return rewrap_path(value) - return wrapper + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(getattr(cls._wraps, meth_name)) + return classmethod(wrapper) class AsyncAutoWrapperType(Final): - def __init__(cls, name, bases, attrs): + _forwards: type + _wraps: type + _forward_magic: list[str] + _wrap_iter: list[str] + _forward: list[str] + + def __init__( + cls, name: str, bases: tuple[type, ...], attrs: dict[str, object] + ) -> None: super().__init__(name, bases, attrs) cls._forward = [] @@ -87,7 +147,7 @@ def __init__(cls, name, bases, attrs): type(cls).generate_magic(cls, attrs) type(cls).generate_iter(cls, attrs) - def generate_forwards(cls, attrs): + def generate_forwards(cls, attrs: dict[str, object]) -> None: # forward functions of _forwards for attr_name, attr in cls._forwards.__dict__.items(): if attr_name.startswith("_") or attr_name in attrs: @@ -101,8 +161,9 @@ def generate_forwards(cls, attrs): else: raise TypeError(attr_name, type(attr)) - def generate_wraps(cls, attrs): + def generate_wraps(cls, attrs: dict[str, object]) -> None: # generate wrappers for functions of _wraps + wrapper: classmethod | Callable[..., object] # type: ignore[type-arg] for attr_name, attr in 
cls._wraps.__dict__.items(): # .z. exclude cls._wrap_iter if attr_name.startswith("_") or attr_name in attrs: @@ -112,22 +173,27 @@ def generate_wraps(cls, attrs): setattr(cls, attr_name, wrapper) elif isinstance(attr, types.FunctionType): wrapper = thread_wrapper_factory(cls, attr_name) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) setattr(cls, attr_name, wrapper) else: raise TypeError(attr_name, type(attr)) - def generate_magic(cls, attrs): + def generate_magic(cls, attrs: dict[str, object]) -> None: # generate wrappers for magic for attr_name in cls._forward_magic: attr = getattr(cls._forwards, attr_name) wrapper = _forward_magic(cls, attr) setattr(cls, attr_name, wrapper) - def generate_iter(cls, attrs): + def generate_iter(cls, attrs: dict[str, object]) -> None: # generate wrappers for methods that return iterators + wrapper: Callable[..., object] for attr_name, attr in cls._wraps.__dict__.items(): if attr_name in cls._wrap_iter: wrapper = iter_wrapper_factory(cls, attr_name) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) setattr(cls, attr_name, wrapper) @@ -137,9 +203,10 @@ class Path(metaclass=AsyncAutoWrapperType): """ - _wraps = pathlib.Path - _forwards = pathlib.PurePath - _forward_magic = [ + _forward: ClassVar[list[str]] + _wraps: ClassVar[type] = pathlib.Path + _forwards: ClassVar[type] = pathlib.PurePath + _forward_magic: ClassVar[list[str]] = [ "__str__", "__bytes__", "__truediv__", @@ -151,28 +218,110 @@ class Path(metaclass=AsyncAutoWrapperType): "__ge__", "__hash__", ] - _wrap_iter = ["glob", "rglob", "iterdir"] + _wrap_iter: ClassVar[list[str]] = ["glob", "rglob", "iterdir"] - def __init__(self, *args): + def __init__(self, *args: StrPath) -> None: self._wrapped = pathlib.Path(*args) - def __getattr__(self, name): - if name in self._forward: - value = getattr(self._wrapped, name) - return rewrap_path(value) - raise AttributeError(name) + # type checkers allow accessing any attributes on class instances with `__getattr__` + # so we hide it behind a type guard forcing it to rely on the hardcoded attribute + # list below. + if not TYPE_CHECKING: - def __dir__(self): - return super().__dir__() + self._forward + def __getattr__(self, name): + if name in self._forward: + value = getattr(self._wrapped, name) + return rewrap_path(value) + raise AttributeError(name) - def __repr__(self): + def __dir__(self) -> list[str]: + return [*super().__dir__(), *self._forward] + + def __repr__(self) -> str: return f"trio.Path({repr(str(self))})" - def __fspath__(self): + def __fspath__(self) -> str: return os.fspath(self._wrapped) - @wraps(pathlib.Path.open) - async def open(self, *args, **kwargs): + @overload + def open( + self, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> _AsyncIOWrapper[TextIOWrapper]: + ... + + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[FileIO]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedRandom]: + ... 
+ + @overload + def open( + self, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedWriter]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedReader]: + ... + + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: int = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BinaryIO]: + ... + + @overload + def open( + self, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> _AsyncIOWrapper[IO[Any]]: + ... + + @wraps(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch. + async def open(self, *args: Any, **kwargs: Any) -> _AsyncIOWrapper[IO[Any]]: """Open the file pointed to by the path, like the :func:`trio.open_file` function does. @@ -182,6 +331,104 @@ async def open(self, *args, **kwargs): value = await trio.to_thread.run_sync(func) return trio.wrap_file(value) + if TYPE_CHECKING: + # the dunders listed in _forward_magic that aren't seen otherwise + # fmt: off + def __bytes__(self) -> bytes: ... + def __truediv__(self, other: StrPath) -> Path: ... + def __rtruediv__(self, other: StrPath) -> Path: ... + + # wrapped methods handled by __getattr__ + async def absolute(self) -> Path: ... + async def as_posix(self) -> str: ... + async def as_uri(self) -> str: ... + + if sys.version_info >= (3, 10): + async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: ... + async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: ... + else: + async def stat(self) -> os.stat_result: ... + async def chmod(self, mode: int) -> None: ... + + @classmethod + async def cwd(cls) -> Path: ... + + async def exists(self) -> bool: ... + async def expanduser(self) -> Path: ... + async def glob(self, pattern: str) -> Iterable[Path]: ... + async def home(self) -> Path: ... + async def is_absolute(self) -> bool: ... + async def is_block_device(self) -> bool: ... + async def is_char_device(self) -> bool: ... + async def is_dir(self) -> bool: ... + async def is_fifo(self) -> bool: ... + async def is_file(self) -> bool: ... + async def is_reserved(self) -> bool: ... + async def is_socket(self) -> bool: ... + async def is_symlink(self) -> bool: ... + async def iterdir(self) -> Iterable[Path]: ... + async def joinpath(self, *other: StrPath) -> Path: ... + async def lchmod(self, mode: int) -> None: ... + async def lstat(self) -> os.stat_result: ... + async def match(self, path_pattern: str) -> bool: ... + async def mkdir(self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False) -> None: ... + async def read_bytes(self) -> bytes: ... + async def read_text(self, encoding: str | None = None, errors: str | None = None) -> str: ... + async def relative_to(self, *other: StrPath) -> Path: ... + + if sys.version_info >= (3, 8): + async def rename(self, target: str | pathlib.PurePath) -> Path: ... + async def replace(self, target: str | pathlib.PurePath) -> Path: ... + else: + async def rename(self, target: str | pathlib.PurePath) -> None: ... + async def replace(self, target: str | pathlib.PurePath) -> None: ... + + async def resolve(self, strict: bool = False) -> Path: ... + async def rglob(self, pattern: str) -> Iterable[Path]: ... + async def rmdir(self) -> None: ...
+ async def samefile(self, other_path: str | bytes | int | Path) -> bool: ... + async def symlink_to(self, target: str | Path, target_is_directory: bool = False) -> None: ... + async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None: ... + if sys.version_info >= (3, 8): + async def unlink(self, missing_ok: bool = False) -> None: ... + else: + async def unlink(self) -> None: ... + async def with_name(self, name: str) -> Path: ... + async def with_suffix(self, suffix: str) -> Path: ... + async def write_bytes(self, data: bytes) -> int: ... + + if sys.version_info >= (3, 10): + async def write_text( + self, data: str, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> int: ... + else: + async def write_text( + self, data: str, + encoding: str | None = None, + errors: str | None = None, + ) -> int: ... + + if sys.platform != "win32": + async def owner(self) -> str: ... + async def group(self) -> str: ... + async def is_mount(self) -> bool: ... + + if sys.version_info >= (3, 9): + async def is_relative_to(self, *other: StrPath) -> bool: ... + async def with_stem(self, stem: str) -> Path: ... + async def readlink(self) -> Path: ... + if sys.version_info >= (3, 10): + async def hardlink_to(self, target: str | pathlib.Path) -> None: ... + if sys.version_info < (3, 12): + async def link_to(self, target: StrPath | bytes) -> None: ... + if sys.version_info >= (3, 12): + async def is_junction(self) -> bool: ... + walk: Any # TODO + async def with_segments(self, *pathsegments: StrPath) -> Path: ... + Path.iterdir.__doc__ = """ Like :meth:`pathlib.Path.iterdir`, but async. @@ -203,4 +450,6 @@ async def open(self, *args, **kwargs): # sense than inventing our own special docstring for this. del Path.absolute.__doc__ +# TODO: This is likely not supported by all the static tools out there, see discussion in +# https://github.com/python-trio/trio/pull/2631#discussion_r1185612528 os.PathLike.register(Path) diff --git a/trio/_path.pyi b/trio/_path.pyi deleted file mode 100644 index 85a8e1f960..0000000000 --- a/trio/_path.pyi +++ /dev/null @@ -1 +0,0 @@ -class Path: ... diff --git a/trio/_signals.py b/trio/_signals.py index cee3b7db53..fe2bde946e 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -1,9 +1,10 @@ import signal -from contextlib import contextmanager from collections import OrderedDict +from contextlib import contextmanager import trio -from ._util import signal_raise, is_main_thread, ConflictDetector + +from ._util import ConflictDetector, is_main_thread, signal_raise # Discussion of signal handling strategies: # diff --git a/trio/_socket.py b/trio/_socket.py index b12126f7e1..2834a5b055 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,15 +1,50 @@ +from __future__ import annotations + import os -import sys import select import socket as _stdlib_socket +import sys from functools import wraps as _wraps -from typing import TYPE_CHECKING +from operator import index +from socket import AddressFamily, SocketKind +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Literal, + NoReturn, + SupportsIndex, + Tuple, + TypeVar, + Union, + overload, +) import idna as _idna import trio + from .
import _core +if TYPE_CHECKING: + from collections.abc import Iterable + from types import TracebackType + + from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias + + from ._abc import HostnameResolver, SocketFactory + + P = ParamSpec("P") + + +T = TypeVar("T") + +# must use old-style typing because it's evaluated at runtime +Address: TypeAlias = Union[ + str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] +] + # Usage: # @@ -20,20 +55,27 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None): + def __init__( + self, blocking_exc_override: Callable[[BaseException], bool] | None = None + ): self._blocking_exc_override = blocking_exc_override - def _is_blocking_io_error(self, exc): + def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) else: return self._blocking_exc_override(exc) - async def __aenter__(self): + async def __aenter__(self) -> None: await trio.lowlevel.checkpoint_if_cancelled() - async def __aexit__(self, etype, value, tb): - if value is not None and self._is_blocking_io_error(value): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + if exc_value is not None and self._is_blocking_io_error(exc_value): # Discard the exception and fall through to the code below the # block return True @@ -43,27 +85,17 @@ async def __aexit__(self, etype, value, tb): return False -################################################################ -# CONSTANTS -################################################################ - -try: - from socket import IPPROTO_IPV6 -except ImportError: - # Before Python 3.8, Windows is missing IPPROTO_IPV6 - # https://bugs.python.org/issue29515 - if sys.platform == "win32": # pragma: no branch - IPPROTO_IPV6 = 41 - ################################################################ # Overrides ################################################################ -_resolver = _core.RunVar("hostname_resolver") -_socket_factory = _core.RunVar("socket_factory") +_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver") +_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory") -def set_custom_hostname_resolver(hostname_resolver): +def set_custom_hostname_resolver( + hostname_resolver: HostnameResolver | None, +) -> HostnameResolver | None: """Set a custom hostname resolver. By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions @@ -95,7 +127,9 @@ def set_custom_hostname_resolver(hostname_resolver): return old -def set_custom_socket_factory(socket_factory): +def set_custom_socket_factory( + socket_factory: SocketFactory | None, +) -> SocketFactory | None: """Set a custom socket object factory. 
This function allows you to replace Trio's normal socket class with a @@ -129,7 +163,23 @@ def set_custom_socket_factory(socket_factory): _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV -async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): +# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first +async def getaddrinfo( + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: """Look up a numeric address given a name. Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -150,7 +200,7 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # skip the whole thread thing, which seems worthwhile. So we try first # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: - def numeric_only_failure(exc): + def numeric_only_failure(exc: BaseException) -> bool: return ( isinstance(exc, _stdlib_socket.gaierror) and exc.errno == _stdlib_socket.EAI_NONAME @@ -192,7 +242,9 @@ def numeric_only_failure(exc): ) -async def getnameinfo(sockaddr, flags): +async def getnameinfo( + sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int +) -> tuple[str, str]: """Look up a name given a numeric address. Arguments and return values are identical to :func:`socket.getnameinfo`, @@ -211,7 +263,7 @@ async def getnameinfo(sockaddr, flags): ) -async def getprotobyname(name): +async def getprotobyname(name: str) -> int: """Look up a protocol number by name. (Rarely used.) Like :func:`socket.getprotobyname`, but async. @@ -230,7 +282,7 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock): +def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: """Convert a standard library :class:`socket.socket` object into a Trio socket object. 
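The two setters above now both accept and return ``HostnameResolver | None`` and ``SocketFactory | None``, which types the usual set-and-restore idiom. A sketch under that assumption; ``LocalhostResolver`` is hypothetical and simply pins every lookup to 127.0.0.1:

```python
import socket

import trio


class LocalhostResolver(trio.abc.HostnameResolver):
    """Hypothetical resolver returning a canned getaddrinfo result."""

    async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0):
        await trio.lowlevel.checkpoint()
        return [
            (
                socket.AF_INET,
                socket.SOCK_STREAM,
                socket.IPPROTO_TCP,
                "",
                ("127.0.0.1", int(port or 0)),
            )
        ]

    async def getnameinfo(self, sockaddr, flags):
        raise NotImplementedError


async def main() -> None:
    # The setter returns the previously installed resolver (or None),
    # so it can be restored afterwards.
    old = trio.socket.set_custom_hostname_resolver(LocalhostResolver())
    try:
        print(await trio.socket.getaddrinfo("anything.invalid", 80))
    finally:
        trio.socket.set_custom_hostname_resolver(old)


trio.run(main)
```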
@@ -239,9 +291,14 @@ def from_stdlib_socket(sock): @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd(fd, family, type, proto=0): +def fromfd( + fd: SupportsIndex, + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, +) -> _SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -250,27 +307,41 @@ def fromfd(fd, family, type, proto=0): ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args, **kwargs): - return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) + def fromshare(info: bytes) -> _SocketType: + return from_stdlib_socket(_stdlib_socket.fromshare(info)) + + +if sys.platform == "win32": + FamilyT: TypeAlias = int + TypeT: TypeAlias = int + FamilyDefault = _stdlib_socket.AF_INET +else: + FamilyDefault: Literal[None] = None + FamilyT: TypeAlias = Union[int, AddressFamily, None] + TypeT: TypeAlias = Union[_stdlib_socket.socket, int] @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair(*args, **kwargs): +def socketpair( + family: FamilyT = FamilyDefault, + type: TypeT = SocketKind.SOCK_STREAM, + proto: int = 0, +) -> tuple[_SocketType, _SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. """ - left, right = _stdlib_socket.socketpair(*args, **kwargs) + left, right = _stdlib_socket.socketpair(family, type, proto) return (from_stdlib_socket(left), from_stdlib_socket(right)) @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family=_stdlib_socket.AF_INET, - type=_stdlib_socket.SOCK_STREAM, - proto=0, - fileno=None, -): + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, +) -> _SocketType: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -287,14 +358,24 @@ def socket( return from_stdlib_socket(stdlib_socket) -def _sniff_sockopts_for_fileno(family, type, proto, fileno): +def _sniff_sockopts_for_fileno( + family: AddressFamily | int, + type: SocketKind | int, + proto: int, + fileno: int | None, +) -> tuple[AddressFamily | int, SocketKind | int, int]: """Correct SOCKOPTS for given fileno, falling back to provided values.""" # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt # and then we'll throw it away and construct a new one with the correct metadata. if sys.platform != "linux": return family, type, proto - from socket import SO_DOMAIN, SO_PROTOCOL, SOL_SOCKET, SO_TYPE + from socket import ( # type: ignore[attr-defined] + SO_DOMAIN, + SO_PROTOCOL, + SO_TYPE, + SOL_SOCKET, + ) sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno) try: @@ -324,26 +405,21 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): ) -# This function will modify the given socket to match the behavior in python -# 3.7. This will become unnecessary and can be removed when support for versions -# older than 3.7 is dropped. 
-def real_socket_type(type_num): - return type_num & _SOCK_TYPE_MASK - - -def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): - fn = getattr(_stdlib_socket.socket, methname) - +def _make_simple_sock_method_wrapper( + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + maybe_avail: bool = False, +) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self, *args, **kwargs): - return await self._nonblocking_helper(fn, args, kwargs, wait_fn) + async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: + return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) - wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async. + wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. """ if maybe_avail: wrapper.__doc__ += ( - f"Only available on platforms where :meth:`socket.socket.{methname}` is " + f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is " "available." ) return wrapper @@ -362,8 +438,21 @@ async def wrapper(self, *args, **kwargs): # local=False means that the address is being used with connect() or sendto() or # similar. # + + +# Using a TypeVar to indicate we return the same type of address appears to give errors +# when passed a union of address types. +# @overload likely works, but is extremely verbose. # NOTE: this function does not always checkpoint -async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local): +async def _resolve_address_nocp( + type: int, + family: AddressFamily, + proto: int, + *, + ipv6_v6only: bool | int, + address: Address, + local: bool, +) -> Address: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -373,13 +462,15 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo raise ValueError( "address should be a (host, port, [flowinfo, [scopeid]]) tuple" ) - elif family == _stdlib_socket.AF_UNIX: + elif family == getattr(_stdlib_socket, "AF_UNIX"): # unwrap path-likes + assert isinstance(address, (str, bytes)) return os.fspath(address) else: return address # -- From here on we know we have IPv4 or IPV6 -- + host: str | None host, port, *_ = address # Fast path for the simple case: already-resolved IP address, # already-resolved port. This is particularly important for UDP, since @@ -417,18 +508,24 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo # The above ignored any flowid and scopeid in the passed-in address, # so restore them if present: if family == _stdlib_socket.AF_INET6: - normed = list(normed) + list_normed = list(normed) assert len(normed) == 4 + # typechecking certainly doesn't like this logic, but given just how broad + # Address is, it's quite cumbersome to write the below without type: ignore if len(address) >= 3: - normed[2] = address[2] + list_normed[2] = address[2] # type: ignore if len(address) >= 4: - normed[3] = address[3] - normed = tuple(normed) + list_normed[3] = address[3] # type: ignore + return tuple(list_normed) # type: ignore return normed +# TODO: stopping users from initializing this type should be done in a different way, +# so SocketType can be used as a type. Note that this is *far* from trivial without +# breaking subclasses of SocketType. 
Can maybe add abstract methods to SocketType, +# or rename _SocketType. class SocketType: - def __init__(self): + def __init__(self) -> NoReturn: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" @@ -436,14 +533,12 @@ def __init__(self): class _SocketType(SocketType): - def __init__(self, sock): + def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. raise TypeError( - "expected object of type 'socket.socket', not '{}".format( - type(sock).__name__ - ) + f"expected object of type 'socket.socket', not '{type(sock).__name__}'" ) self._sock = sock self._sock.setblocking(False) @@ -453,75 +548,121 @@ def __init__(self, sock): # Simple + portable methods and attributes ################################################################ - # NB this doesn't work because for loops don't create a scope - # for _name in [ - # ]: - # _meth = getattr(_stdlib_socket.socket, _name) - # @_wraps(_meth, assigned=("__name__", "__doc__"), updated=()) - # def _wrapped(self, *args, **kwargs): - # return getattr(self._sock, _meth)(*args, **kwargs) - # locals()[_meth] = _wrapped - # del _name, _meth, _wrapped - - _forward = { - "detach", - "get_inheritable", - "set_inheritable", - "fileno", - "getpeername", - "getsockname", - "getsockopt", - "setsockopt", - "listen", - "share", - } - - def __getattr__(self, name): - if name in self._forward: - return getattr(self._sock, name) - raise AttributeError(name) - - def __dir__(self): - return super().__dir__() + list(self._forward) - - def __enter__(self): + # forwarded methods + def detach(self) -> int: + return self._sock.detach() + + def fileno(self) -> int: + return self._sock.fileno() + + def getpeername(self) -> Any: + return self._sock.getpeername() + + def getsockname(self) -> Any: + return self._sock.getsockname() + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + if buflen is None: + return self._sock.getsockopt(level, optname) + return self._sock.getsockopt(level, optname, buflen) + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + if optlen is None: + if value is None: + raise TypeError( + "invalid value for argument 'value', must not be None when specifying optlen" + ) + return self._sock.setsockopt(level, optname, value) + if value is not None: + raise TypeError( + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen" + ) + + # Note: PyPy may crash here due to setsockopt only supporting + # four parameters. 
+ return self._sock.setsockopt(level, optname, value, optlen) + + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + return self._sock.listen(backlog) + + def get_inheritable(self) -> bool: + return self._sock.get_inheritable() + + def set_inheritable(self, inheritable: bool) -> None: + return self._sock.set_inheritable(inheritable) + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, /, process_id: int) -> bytes: + return self._sock.share(process_id) + + def __enter__(self) -> Self: return self - def __exit__(self, *exc_info): - return self._sock.__exit__(*exc_info) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self._sock.__exit__(exc_type, exc_value, traceback) @property - def family(self): + def family(self) -> AddressFamily: return self._sock.family @property - def type(self): - # Modify the socket type do match what is done on python 3.7. When - # support for versions older than 3.7 is dropped, this can be updated - # to just return self._sock.type - return real_socket_type(self._sock.type) + def type(self) -> SocketKind: + return self._sock.type @property - def proto(self): + def proto(self) -> int: return self._sock.proto @property - def did_shutdown_SHUT_WR(self): + def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR - def __repr__(self): + def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self): + def dup(self) -> _SocketType: """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) - def close(self): + def close(self) -> None: if self._sock.fileno() != -1: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address): + async def bind(self, address: Address) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -539,14 +680,14 @@ async def bind(self, address): await trio.lowlevel.checkpoint() return self._sock.bind(address) - def shutdown(self, flag): + def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True - def is_readable(self): + def is_readable(self) -> bool: # use select.select on Windows, and select.poll everywhere else if sys.platform == "win32": rready, _, _ = select.select([self._sock], [], [], 0) @@ -555,13 +696,18 @@ def is_readable(self): p.register(self._sock, select.POLLIN) return bool(p.poll(0)) - async def wait_writable(self): + async def wait_writable(self) -> None: await _core.wait_writable(self._sock) - async def _resolve_address_nocp(self, address, *, local): + async def _resolve_address_nocp( + self, + address: Address, + *, + local: bool, + ) -> Address: if self.family == _stdlib_socket.AF_INET6: ipv6_v6only = self._sock.getsockopt( - IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY + _stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY ) else: ipv6_v6only = False @@ -574,7 +720,19 @@ async def _resolve_address_nocp(self, address, *, local): local=local, ) - async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): + # args and kwargs must be starred, otherwise pyright complains: + # '"args" member of ParamSpec is valid only when used with *args 
parameter' + # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter' + # wait_fn and fn must also be first in the signature + # 'Keyword parameter cannot appear in signature after ParamSpec args parameter' + + async def _nonblocking_helper( + self, + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: # We have to reconcile two conflicting goals: # - We want to make it look like we always blocked in doing these # operations. The obvious way is to always do an IO wait before @@ -610,9 +768,11 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # accept ################################################################ - _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable) + _accept = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.accept, _core.wait_readable + ) - async def accept(self): + async def accept(self) -> tuple[_SocketType, object]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -621,7 +781,7 @@ async def accept(self): # connect ################################################################ - async def connect(self, address): + async def connect(self, address: Address) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -689,32 +849,71 @@ async def connect(self, address): # Okay, the connect finished, but it might have failed: err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) if err != 0: - raise OSError(err, f"Error connecting to {address}: {os.strerror(err)}") + raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}") ################################################################ # recv ################################################################ - recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) + # Not possible to typecheck with a Callable (due to DefaultArg), nor with a + # callback Protocol (https://github.com/python/typing/discussions/1040) + # but this seems to work. If not explicitly defined then pyright --verifytypes will + # complain about AmbiguousType + if TYPE_CHECKING: + + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + ... + + # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct + # this requires that we refrain from using `/` to specify pos-only + # args, or mypy thinks the signature differs from typeshed. + recv = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv, _core.wait_readable + ) ################################################################ # recv_into ################################################################ - recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable) + if TYPE_CHECKING: + + def recv_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + ... 
+ + recv_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv_into, _core.wait_readable + ) ################################################################ # recvfrom ################################################################ - recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable) + if TYPE_CHECKING: + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + ... + + recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom, _core.wait_readable + ) ################################################################ # recvfrom_into ################################################################ - recvfrom_into = _make_simple_sock_method_wrapper( - "recvfrom_into", _core.wait_readable + if TYPE_CHECKING: + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: + ... + + recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom_into, _core.wait_readable ) ################################################################ @@ -722,8 +921,15 @@ async def connect(self, address): ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg"): - recvmsg = _make_simple_sock_method_wrapper( - "recvmsg", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg( + __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True ) ################################################################ @@ -731,29 +937,58 @@ async def connect(self, address): ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg_into"): - recvmsg_into = _make_simple_sock_method_wrapper( - "recvmsg_into", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg_into( + __self, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True ) ################################################################ # send ################################################################ - send = _make_simple_sock_method_wrapper("send", _core.wait_writable) + if TYPE_CHECKING: + + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + ... + + send = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.send, _core.wait_writable + ) ################################################################ # sendto ################################################################ - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args): + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... 
+ + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: Any) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted - args = list(args) - args[-1] = await self._resolve_address_nocp(args[-1], local=False) + args_list = list(args) + args_list[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendto, args, {}, _core.wait_writable + _core.wait_writable, _stdlib_socket.socket.sendto, *args_list ) ################################################################ @@ -765,20 +1000,28 @@ async def sendto(self, *args): ): @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg(self, *args): + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: Address | None = None, + ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. Only available on platforms where :meth:`socket.socket.sendmsg` is available. """ - # args is: buffers[, ancdata[, flags[, address]]] - # and kwargs are not accepted - if len(args) == 4 and args[-1] is not None: - args = list(args) - args[-1] = await self._resolve_address_nocp(args[-1], local=False) + if __address is not None: + __address = await self._resolve_address_nocp(__address, local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable + _core.wait_writable, + _stdlib_socket.socket.sendmsg, + __buffers, + __ancdata, + __flags, + __address, ) ################################################################ diff --git a/trio/_ssl.py b/trio/_ssl.py index 8f005c2c9a..f0f01f7583 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -1,3 +1,18 @@ +from __future__ import annotations + +import operator as _operator +import ssl as _stdlib_ssl +from collections.abc import Awaitable, Callable +from enum import Enum as _Enum +from typing import Any, Final as TFinal, TypeVar + +import trio + +from . import _sync +from ._highlevel_generic import aclose_forcefully +from ._util import ConflictDetector, Final +from .abc import Listener, Stream + # General theory of operation: # # We implement an API that closely mirrors the stdlib ssl module's blocking @@ -149,16 +164,8 @@ # docs will need to make very clear that this is different from all the other # cancellations in core Trio -import operator as _operator -import ssl as _stdlib_ssl -from enum import Enum as _Enum - -import trio -from .abc import Stream, Listener -from ._highlevel_generic import aclose_forcefully -from . import _sync -from ._util import ConflictDetector, Final +T = TypeVar("T") ################################################################ # SSLStream @@ -187,16 +194,16 @@ # MTU and an initial window of 10 (see RFC 6928), then the initial burst of # data will be limited to ~15000 bytes (or a bit less due to IP-level framing # overhead), so this is chosen to be larger than that. -STARTING_RECEIVE_SIZE = 16384 +STARTING_RECEIVE_SIZE: TFinal = 16384 -def _is_eof(exc): +def _is_eof(exc: BaseException | None) -> bool: # There appears to be a bug on Python 3.10, where SSLErrors # aren't properly translated into SSLEOFErrors. # This stringly-typed error check is borrowed from the AnyIO # project. 
return isinstance(exc, _stdlib_ssl.SSLEOFError) or ( - hasattr(exc, "strerror") and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror + "UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ()) ) @@ -209,13 +216,13 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn, *args): + def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None: self._afn = afn self._args = args self.started = False self._done = _sync.Event() - async def ensure(self, *, checkpoint): + async def ensure(self, *, checkpoint: bool) -> None: if not self.started: self.started = True await self._afn(*self._args) @@ -226,8 +233,8 @@ async def ensure(self, *, checkpoint): await self._done.wait() @property - def done(self): - return self._done.is_set() + def done(self) -> bool: + return bool(self._done.is_set()) _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) @@ -257,8 +264,8 @@ class SSLStream(Stream, metaclass=Final): this connection. Required. Usually created by calling :func:`ssl.create_default_context`. - server_hostname (str or None): The name of the server being connected - to. Used for `SNI + server_hostname (str, bytes, or None): The name of the server being + connected to. Used for `SNI `__ and for validating the server's certificate (if hostname checking is enabled). This is effectively mandatory for clients, and actually @@ -331,24 +338,24 @@ class SSLStream(Stream, metaclass=Final): # SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers. def __init__( self, - transport_stream, - ssl_context, + transport_stream: Stream, + ssl_context: _stdlib_ssl.SSLContext, *, - server_hostname=None, - server_side=False, - https_compatible=False, - ): - self.transport_stream = transport_stream + server_hostname: str | bytes | None = None, + server_side: bool = False, + https_compatible: bool = False, + ) -> None: + self.transport_stream: Stream = transport_stream self._state = _State.OK self._https_compatible = https_compatible self._outgoing = _stdlib_ssl.MemoryBIO() - self._delayed_outgoing = None + self._delayed_outgoing: bytes | None = None self._incoming = _stdlib_ssl.MemoryBIO() self._ssl_object = ssl_context.wrap_bio( self._incoming, self._outgoing, server_side=server_side, - server_hostname=server_hostname, + server_hostname=server_hostname, # type: ignore[arg-type] # Typeshed bug, does accept bytes as well (typeshed#10590) ) # Tracks whether we've already done the initial handshake self._handshook = _Once(self._do_handshake) @@ -399,7 +406,7 @@ def __init__( "version", } - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in self._forwarded: if name in self._after_handshake and not self._handshook.done: raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") @@ -408,16 +415,16 @@ def __getattr__(self, name): else: raise AttributeError(name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: object) -> None: if name in self._forwarded: setattr(self._ssl_object, name, value) else: super().__setattr__(name, value) - def __dir__(self): - return super().__dir__() + list(self._forwarded) + def __dir__(self) -> list[str]: + return list(super().__dir__()) + list(self._forwarded) - def _check_status(self): + def _check_status(self) -> None: if self._state is _State.OK: return elif self._state is _State.BROKEN: @@ -431,7 +438,13 @@ def _check_status(self): # comments, though, just make sure to think carefully if you ever have to # touch it. 
The big comment at the top of this file will help explain # too. - async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): + async def _retry( + self, + fn: Callable[..., T], + *args: object, + ignore_want_read: bool = False, + is_handshake: bool = False, + ) -> T | None: await trio.lowlevel.checkpoint_if_cancelled() yielded = False finished = False @@ -603,14 +616,14 @@ async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): await trio.lowlevel.cancel_shielded_checkpoint() return ret - async def _do_handshake(self): + async def _do_handshake(self) -> None: try: await self._retry(self._ssl_object.do_handshake, is_handshake=True) except: self._state = _State.BROKEN raise - async def do_handshake(self): + async def do_handshake(self) -> None: """Ensure that the initial handshake has completed. The SSL protocol requires an initial handshake to exchange @@ -645,7 +658,7 @@ async def do_handshake(self): # https://bugs.python.org/issue30141 # So we *definitely* have to make sure that do_handshake is called # before doing anything else. - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Read some data from the underlying transport, decrypt it, and return it. @@ -684,7 +697,9 @@ async def receive_some(self, max_bytes=None): if max_bytes < 1: raise ValueError("max_bytes must be >= 1") try: - return await self._retry(self._ssl_object.read, max_bytes) + received = await self._retry(self._ssl_object.read, max_bytes) + assert received is not None + return received except trio.BrokenResourceError as exc: # This isn't quite equivalent to just returning b"" in the # first place, because we still end up with self._state set to @@ -698,7 +713,7 @@ async def receive_some(self, max_bytes=None): else: raise - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Encrypt some data and then send it on the underlying transport. See :meth:`trio.abc.SendStream.send_all` for details. @@ -719,7 +734,7 @@ async def send_all(self, data): return await self._retry(self._ssl_object.write, data) - async def unwrap(self): + async def unwrap(self) -> tuple[Stream, bytes | bytearray]: """Cleanly close down the SSL/TLS encryption layer, allowing the underlying stream to be used for unencrypted communication. @@ -741,11 +756,11 @@ async def unwrap(self): await self._handshook.ensure(checkpoint=False) await self._retry(self._ssl_object.unwrap) transport_stream = self.transport_stream - self.transport_stream = None self._state = _State.CLOSED + self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use return (transport_stream, self._incoming.read()) - async def aclose(self): + async def aclose(self) -> None: """Gracefully shut down this connection, and close the underlying transport. @@ -832,7 +847,7 @@ async def aclose(self): finally: self._state = _State.CLOSED - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" # This method's implementation is deceptively simple. 
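With `do_handshake`, `send_all`, `receive_some`, and `unwrap` typed above, `SSLStream` presents the same surface as any other `Stream`. A minimal client sketch; `example.com` stands in for a real TLS server:

```python
import ssl
import trio

async def fetch_banner() -> bytes:
    tcp_stream = await trio.open_tcp_stream("example.com", 443)
    ssl_stream = trio.SSLStream(
        tcp_stream, ssl.create_default_context(), server_hostname="example.com"
    )
    await ssl_stream.do_handshake()  # optional: send/receive trigger it lazily
    await ssl_stream.send_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
    return await ssl_stream.receive_some(16384)

print(trio.run(fetch_banner))
```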
# @@ -897,16 +912,16 @@ class SSLListener(Listener[SSLStream], metaclass=Final): def __init__( self, - transport_listener, - ssl_context, + transport_listener: Listener[Stream], + ssl_context: _stdlib_ssl.SSLContext, *, - https_compatible=False, - ): + https_compatible: bool = False, + ) -> None: self.transport_listener = transport_listener self._ssl_context = ssl_context self._https_compatible = https_compatible - async def accept(self): + async def accept(self) -> SSLStream: """Accept the next connection and wrap it in an :class:`SSLStream`. See :meth:`trio.abc.Listener.accept` for details. @@ -920,6 +935,6 @@ async def accept(self): https_compatible=self._https_compatible, ) - async def aclose(self): + async def aclose(self) -> None: """Close the transport listener.""" await self.transport_listener.aclose() diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 34eeb22dbb..978f7e6188 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -1,24 +1,37 @@ +from __future__ import annotations + import os +import signal import subprocess import sys +import warnings +from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import ExitStack -from typing import Optional from functools import partial -import warnings -from typing import TYPE_CHECKING +from io import TextIOWrapper +from typing import TYPE_CHECKING, Final, Literal, Protocol, Union, overload + +import trio -from ._abc import AsyncResource, SendStream, ReceiveStream -from ._core import ClosedResourceError +from ._abc import AsyncResource, ReceiveStream, SendStream +from ._core import ClosedResourceError, TaskStatus +from ._deprecate import deprecated from ._highlevel_generic import StapledStream -from ._sync import Lock from ._subprocess_platform import ( - wait_child_exiting, - create_pipe_to_child_stdin, create_pipe_from_child_output, + create_pipe_to_child_stdin, + wait_child_exiting, ) -from ._deprecate import deprecated +from ._sync import Lock from ._util import NoPublicConstructor -import trio + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + +# Only subscriptable in 3.9+ +StrOrBytesPath: TypeAlias = Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"] + # Linux-specific, but has complex lifetime management stuff so we hard-code it # here instead of hiding it behind the _subprocess_platform abstraction @@ -65,6 +78,13 @@ def pidfd_open(fd: int, flags: int) -> int: can_try_pidfd_open = False +class HasFileno(Protocol): + """Represents any file-like object that has a file descriptor.""" + + def fileno(self) -> int: + ... + + class Process(AsyncResource, metaclass=NoPublicConstructor): r"""A child process. Like :class:`subprocess.Popen`, but async. @@ -107,32 +127,38 @@ class Process(AsyncResource, metaclass=NoPublicConstructor): available; otherwise this will be None. """ - - universal_newlines = False - encoding = None - errors = None + # We're always in binary mode. + universal_newlines: Final = False + encoding: Final = None + errors: Final = None # Available for the per-platform wait_child_exiting() implementations # to stash some state; waitid platforms use this to avoid spawning # arbitrarily many threads if wait() keeps getting cancelled. 
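`HasFileno` is structural: any object exposing a `fileno()` method satisfies the protocol without inheriting from it, which is exactly what the stdio redirection arguments need. A small illustration; `as_fd` is a hypothetical helper, not part of trio:

```python
import sys
from typing import Protocol, Union

class HasFileno(Protocol):
    def fileno(self) -> int:
        ...

def as_fd(target: Union[int, HasFileno]) -> int:
    # Raw descriptors pass through; file-like objects (open files,
    # sockets, pipes) contribute their descriptor via fileno().
    return target if isinstance(target, int) else target.fileno()

print(as_fd(0), as_fd(sys.stderr))  # -> 0 2
```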
-    _wait_for_exit_data = None
-
-    def __init__(self, popen, stdin, stdout, stderr):
+    _wait_for_exit_data: object = None
+
+    def __init__(
+        self,
+        popen: subprocess.Popen[bytes],
+        stdin: SendStream | None,
+        stdout: ReceiveStream | None,
+        stderr: ReceiveStream | None,
+    ) -> None:
         self._proc = popen
-        self.stdin: Optional[SendStream] = stdin
-        self.stdout: Optional[ReceiveStream] = stdout
-        self.stderr: Optional[ReceiveStream] = stderr
+        self.stdin = stdin
+        self.stdout = stdout
+        self.stderr = stderr
 
-        self.stdio: Optional[StapledStream] = None
+        self.stdio: StapledStream[SendStream, ReceiveStream] | None = None
         if self.stdin is not None and self.stdout is not None:
             self.stdio = StapledStream(self.stdin, self.stdout)
 
-        self._wait_lock = Lock()
+        self._wait_lock: Lock = Lock()
 
-        self._pidfd = None
+        self._pidfd: TextIOWrapper | None = None
         if can_try_pidfd_open:
             try:
-                fd = pidfd_open(self._proc.pid, 0)
+                fd: int = pidfd_open(self._proc.pid, 0)
             except OSError:
                 # Well, we tried, but it didn't work (probably because we're
                 # running on an older kernel, or in an older sandbox, that
@@ -144,10 +170,10 @@ def __init__(self, popen, stdin, stdout, stderr):
             # make sure it'll get closed.
             self._pidfd = open(fd)
 
-        self.args = self._proc.args
-        self.pid = self._proc.pid
+        self.args: StrOrBytesPath | Sequence[StrOrBytesPath] = self._proc.args
+        self.pid: int = self._proc.pid
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         returncode = self.returncode
         if returncode is None:
             status = f"running with PID {self.pid}"
@@ -159,7 +185,7 @@ def __repr__(self):
         return f"<trio.Process {self.args!r}: {status}>"
 
     @property
-    def returncode(self):
+    def returncode(self) -> int | None:
         """The exit status of the process (an integer), or ``None`` if
         it's still running.
 
@@ -186,13 +212,13 @@ def returncode(self):
         issue=1104,
         instead="run_process or nursery.start(run_process, ...)",
     )
-    async def __aenter__(self):
+    async def __aenter__(self) -> Process:
         return self
 
     @deprecated(
         "0.20.0", issue=1104, instead="run_process or nursery.start(run_process, ...)"
    )
-    async def aclose(self):
+    async def aclose(self) -> None:
        """Close any pipes we have to the process (both input and output)
        and wait for it to exit.
 
@@ -214,13 +240,13 @@ async def aclose(self):
             with trio.CancelScope(shield=True):
                 await self.wait()
 
-    def _close_pidfd(self):
+    def _close_pidfd(self) -> None:
         if self._pidfd is not None:
             trio.lowlevel.notify_closing(self._pidfd.fileno())
             self._pidfd.close()
             self._pidfd = None
 
-    async def wait(self):
+    async def wait(self) -> int:
         """Block until the process exits.
 
         Returns:
@@ -230,7 +256,7 @@ async def wait(self):
         if self.poll() is None:
             if self._pidfd is not None:
                 try:
-                    await trio.lowlevel.wait_readable(self._pidfd)
+                    await trio.lowlevel.wait_readable(self._pidfd.fileno())
                 except ClosedResourceError:
                     # something else (probably a call to poll) already closed the
                     # pidfd
@@ -248,7 +274,7 @@ async def wait(self):
         assert self._proc.returncode is not None
         return self._proc.returncode
 
-    def poll(self):
+    def poll(self) -> int | None:
         """Returns the exit status of the process (an integer), or ``None`` if
         it's still running.
 
@@ -260,7 +286,7 @@ def poll(self):
         """
         return self.returncode
 
-    def send_signal(self, sig):
+    def send_signal(self, sig: signal.Signals | int) -> None:
         """Send signal ``sig`` to the process.
 
         On UNIX, ``sig`` may be any signal defined in the
@@ -270,7 +296,7 @@ def send_signal(self, sig):
         """
         self._proc.send_signal(sig)
 
-    def terminate(self):
+    def terminate(self) -> None:
         """Terminate the process, politely if possible.
On UNIX, this is equivalent to @@ -281,7 +307,7 @@ def terminate(self): """ self._proc.terminate() - def kill(self): + def kill(self) -> None: """Immediately terminate the process. On UNIX, this is equivalent to @@ -294,8 +320,13 @@ def kill(self): self._proc.kill() -async def open_process( - command, *, stdin=None, stdout=None, stderr=None, **options +async def _open_process( + command: list[str] | str, + *, + stdin: int | HasFileno | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + **options: object, ) -> Process: r"""Execute a child program in a new process. @@ -366,9 +397,9 @@ async def open_process( "on UNIX systems" ) - trio_stdin: Optional[ClosableSendStream] = None - trio_stdout: Optional[ClosableReceiveStream] = None - trio_stderr: Optional[ClosableReceiveStream] = None + trio_stdin: ClosableSendStream | None = None + trio_stdout: ClosableReceiveStream | None = None + trio_stderr: ClosableReceiveStream | None = None # Close the parent's handle for each child side of a pipe; we want the child to # have the only copy, so that when it exits we can read EOF on our side. The # trio ends of pipes will be transferred to the Process object, which will be @@ -414,14 +445,14 @@ async def open_process( return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) -async def _windows_deliver_cancel(p): +async def _windows_deliver_cancel(p: Process) -> None: try: p.terminate() except OSError as exc: warnings.warn(RuntimeWarning(f"TerminateProcess on {p!r} failed with: {exc!r}")) -async def _posix_deliver_cancel(p): +async def _posix_deliver_cancel(p: Process) -> None: try: p.terminate() await trio.sleep(5) @@ -439,17 +470,18 @@ async def _posix_deliver_cancel(p): ) -async def run_process( - command, +# Use a private name, so we can declare platform-specific stubs below. +async def _run_process( + command: StrOrBytesPath | Sequence[StrOrBytesPath], *, - stdin=b"", - capture_stdout=False, - capture_stderr=False, - check=True, - deliver_cancel=None, - task_status=trio.TASK_STATUS_IGNORED, - **options, -): + stdin: bytes | bytearray | memoryview | int | HasFileno | None = b"", + capture_stdout: bool = False, + capture_stderr: bool = False, + check: bool = True, + deliver_cancel: Callable[[Process], Awaitable[object]] | None = None, + task_status: TaskStatus[Process] = trio.TASK_STATUS_IGNORED, + **options: object, +) -> subprocess.CompletedProcess[bytes]: """Run ``command`` in a subprocess and wait for it to complete. This function can be called in two different ways. @@ -687,23 +719,28 @@ async def my_deliver_cancel(process): assert os.name == "posix" deliver_cancel = _posix_deliver_cancel - stdout_chunks = [] - stderr_chunks = [] + stdout_chunks: list[bytes | bytearray] = [] + stderr_chunks: list[bytes | bytearray] = [] - async def feed_input(stream): + async def feed_input(stream: SendStream) -> None: async with stream: try: + assert input is not None await stream.send_all(input) except trio.BrokenResourceError: pass - async def read_output(stream, chunks): + async def read_output( + stream: ReceiveStream, + chunks: list[bytes | bytearray], + ) -> None: async with stream: async for chunk in stream: chunks.append(chunk) async with trio.open_nursery() as nursery: - proc = await open_process(command, **options) + # options needs a complex TypedDict. The overload error only occurs on Unix. 
+ proc = await open_process(command, **options) # type: ignore[arg-type, call-overload, unused-ignore] try: if input is not None: nursery.start_soon(feed_input, proc.stdin) @@ -722,7 +759,7 @@ async def read_output(stream, chunks): with trio.CancelScope(shield=True): killer_cscope = trio.CancelScope(shield=True) - async def killer(): + async def killer() -> None: with killer_cscope: await deliver_cancel(proc) @@ -739,4 +776,147 @@ async def killer(): proc.returncode, proc.args, output=stdout, stderr=stderr ) else: + assert proc.returncode is not None return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) + + +# There's a lot of duplication here because type checkers don't +# have a good way to represent overloads that differ only +# slightly. A cheat sheet: +# - on Windows, command is Union[str, Sequence[str]]; +# on Unix, command is str if shell=True and Sequence[str] otherwise +# - on Windows, there are startupinfo and creationflags options; +# on Unix, there are preexec_fn, restore_signals, start_new_session, and pass_fds +# - run_process() has the signature of open_process() plus arguments +# capture_stdout, capture_stderr, check, deliver_cancel, and the ability to pass +# bytes as stdin + +if TYPE_CHECKING: + if sys.platform == "win32": + + async def open_process( + command: Union[StrOrBytesPath, Sequence[StrOrBytesPath]], + *, + stdin: int | HasFileno | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + close_fds: bool = True, + shell: bool = False, + cwd: StrOrBytesPath | None = None, + env: Mapping[str, str] | None = None, + startupinfo: subprocess.STARTUPINFO | None = None, + creationflags: int = 0, + ) -> trio.Process: + ... + + async def run_process( + command: StrOrBytesPath | Sequence[StrOrBytesPath], + *, + task_status: TaskStatus[Process] = trio.TASK_STATUS_IGNORED, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = None, + capture_stdout: bool = False, + capture_stderr: bool = False, + check: bool = True, + deliver_cancel: Callable[[Process], Awaitable[object]] | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + close_fds: bool = True, + shell: bool = False, + cwd: StrOrBytesPath | None = None, + env: Mapping[str, str] | None = None, + startupinfo: subprocess.STARTUPINFO | None = None, + creationflags: int = 0, + ) -> subprocess.CompletedProcess[bytes]: + ... + + else: # Unix + + @overload # type: ignore[no-overload-impl] + async def open_process( + command: StrOrBytesPath, + *, + stdin: int | HasFileno | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + close_fds: bool = True, + shell: Literal[True], + cwd: StrOrBytesPath | None = None, + env: Mapping[str, str] | None = None, + preexec_fn: Callable[[], object] | None = None, + restore_signals: bool = True, + start_new_session: bool = False, + pass_fds: Sequence[int] = (), + ) -> trio.Process: + ... + + @overload + async def open_process( + command: Sequence[StrOrBytesPath], + *, + stdin: int | HasFileno | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + close_fds: bool = True, + shell: bool = False, + cwd: StrOrBytesPath | None = None, + env: Mapping[str, str] | None = None, + preexec_fn: Callable[[], object] | None = None, + restore_signals: bool = True, + start_new_session: bool = False, + pass_fds: Sequence[int] = (), + ) -> trio.Process: + ... 
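The `shell: Literal[True]` trick above is what splits the Unix overloads: a plain string command only type-checks together with `shell=True`, while a sequence of arguments is exec'd directly. Roughly, on a Unix machine:

```python
import trio

async def demo() -> None:
    # str command: routed through the shell (the Literal[True] overload)
    await trio.run_process("echo hello | tr a-z A-Z", shell=True)
    # sequence command: exec'd directly, no shell involved
    await trio.run_process(["echo", "hello"])

trio.run(demo)
```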
+ + @overload # type: ignore[no-overload-impl] + async def run_process( + command: StrOrBytesPath, + *, + task_status: TaskStatus[Process] = trio.TASK_STATUS_IGNORED, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = None, + capture_stdout: bool = False, + capture_stderr: bool = False, + check: bool = True, + deliver_cancel: Callable[[Process], Awaitable[object]] | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + close_fds: bool = True, + shell: Literal[True], + cwd: StrOrBytesPath | None = None, + env: Mapping[str, str] | None = None, + preexec_fn: Callable[[], object] | None = None, + restore_signals: bool = True, + start_new_session: bool = False, + pass_fds: Sequence[int] = (), + ) -> subprocess.CompletedProcess[bytes]: + ... + + @overload + async def run_process( + command: Sequence[StrOrBytesPath], + *, + task_status: TaskStatus[Process] = trio.TASK_STATUS_IGNORED, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = None, + capture_stdout: bool = False, + capture_stderr: bool = False, + check: bool = True, + deliver_cancel: Callable[[Process], Awaitable[None]] | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + close_fds: bool = True, + shell: bool = False, + cwd: StrOrBytesPath | None = None, + env: Mapping[str, str] | None = None, + preexec_fn: Callable[[], object] | None = None, + restore_signals: bool = True, + start_new_session: bool = False, + pass_fds: Sequence[int] = (), + ) -> subprocess.CompletedProcess[bytes]: + ... + +else: + # At runtime, use the actual implementations. + open_process = _open_process + open_process.__name__ = open_process.__qualname__ = "open_process" + + run_process = _run_process + run_process.__name__ = run_process.__qualname__ = "run_process" diff --git a/trio/_subprocess_platform/__init__.py b/trio/_subprocess_platform/__init__.py index 7a131e090c..b6767af8f5 100644 --- a/trio/_subprocess_platform/__init__.py +++ b/trio/_subprocess_platform/__init__.py @@ -2,12 +2,12 @@ import os import sys -from typing import Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple import trio -from .. import _core, _subprocess -from .._abc import SendStream, ReceiveStream +from .. import _core, _subprocess +from .._abc import ReceiveStream, SendStream _wait_child_exiting_error: Optional[ImportError] = None _create_child_pipe_error: Optional[ImportError] = None @@ -74,7 +74,8 @@ def create_pipe_from_child_output() -> Tuple["ClosableReceiveStream", int]: elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")): from .kqueue import wait_child_exiting # noqa: F811 else: - from .waitid import wait_child_exiting # noqa: F811 + # noqa'd as it's an exported symbol + from .waitid import wait_child_exiting # noqa: F811, F401 except ImportError as ex: # pragma: no cover _wait_child_exiting_error = ex @@ -94,7 +95,7 @@ def create_pipe_from_child_output(): # noqa: F811 return trio.lowlevel.FdStream(rfd), wfd elif os.name == "nt": - from .._windows_pipes import PipeSendStream, PipeReceiveStream + import msvcrt # This isn't exported or documented, but it's also not # underscore-prefixed, and seems kosher to use. The asyncio docs @@ -103,7 +104,8 @@ def create_pipe_from_child_output(): # noqa: F811 # when asyncio.windows_utils.socketpair was removed in 3.7, the # removal was mentioned in the release notes. 
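The `if TYPE_CHECKING:` stubs plus the runtime aliasing at the end form a general pattern: type checkers see precise per-platform signatures, while the interpreter sees a single implementation published under the public name. A self-contained sketch of the same trick, with made-up names:

```python
from typing import TYPE_CHECKING

def _double(x):
    return x * 2  # the one real implementation

if TYPE_CHECKING:
    # Only type checkers see this precise signature...
    def double(x: int) -> int:
        ...

else:
    # ...at runtime the public name is an alias, renamed so tracebacks
    # and introspection don't leak the private _double name.
    double = _double
    double.__name__ = double.__qualname__ = "double"
```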
from asyncio.windows_utils import pipe as windows_pipe - import msvcrt + + from .._windows_pipes import PipeReceiveStream, PipeSendStream def create_pipe_to_child_stdin(): # noqa: F811 # for stdin, we want the write end (our end) to use overlapped I/O diff --git a/trio/_subprocess_platform/kqueue.py b/trio/_subprocess_platform/kqueue.py index 412ccf8732..efd0562fc2 100644 --- a/trio/_subprocess_platform/kqueue.py +++ b/trio/_subprocess_platform/kqueue.py @@ -1,6 +1,9 @@ -import sys +from __future__ import annotations + import select +import sys from typing import TYPE_CHECKING + from .. import _core, _subprocess assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING @@ -34,7 +37,7 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: # in Chromium it seems we should still keep the check. return - def abort(_): + def abort(_: _core.RaiseCancelT) -> _core.Abort: kqueue.control([make_event(select.KQ_EV_DELETE)], 0) return _core.Abort.SUCCEEDED diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index ad69017219..2a2ca6719d 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -2,15 +2,19 @@ import math import os import sys +from typing import TYPE_CHECKING from .. import _core, _subprocess from .._sync import CapacityLimiter, Event from .._threads import to_thread_run_sync +assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING + + try: from os import waitid - def sync_wait_reapable(pid): + def sync_wait_reapable(pid: int) -> None: waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT) except ImportError: @@ -39,9 +43,9 @@ def sync_wait_reapable(pid): int waitid(int idtype, int id, siginfo_t* result, int options); """ ) - waitid = waitid_ffi.dlopen(None).waitid + waitid_cffi = waitid_ffi.dlopen(None).waitid - def sync_wait_reapable(pid): + def sync_wait_reapable(pid: int) -> None: P_PID = 1 WEXITED = 0x00000004 if sys.platform == "darwin": # pragma: no cover @@ -52,7 +56,7 @@ def sync_wait_reapable(pid): else: WNOWAIT = 0x01000000 result = waitid_ffi.new("siginfo_t *") - while waitid(P_PID, pid, result, WEXITED | WNOWAIT) < 0: + while waitid_cffi(P_PID, pid, result, WEXITED | WNOWAIT) < 0: got_errno = waitid_ffi.errno if got_errno == errno.EINTR: continue @@ -101,7 +105,7 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: # process. if process._wait_for_exit_data is None: - process._wait_for_exit_data = event = Event() # type: ignore + process._wait_for_exit_data = event = Event() _core.spawn_system_task(_waitid_system_task, process.pid, event) assert isinstance(process._wait_for_exit_data, Event) await process._wait_for_exit_data.wait() diff --git a/trio/_subprocess_platform/windows.py b/trio/_subprocess_platform/windows.py index 958be8675c..1634e74fa7 100644 --- a/trio/_subprocess_platform/windows.py +++ b/trio/_subprocess_platform/windows.py @@ -3,4 +3,5 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: - await WaitForSingleObject(int(process._proc._handle)) + # _handle is not in Popen stubs, though it is present on Windows. + await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined] diff --git a/trio/_sync.py b/trio/_sync.py index 8d2fdc0a2d..df4790ae74 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,17 +1,35 @@ +from __future__ import annotations + import math +from typing import TYPE_CHECKING, Protocol import attr import trio from . 
import _core -from ._core import enable_ki_protection, ParkingLot +from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection from ._util import Final +if TYPE_CHECKING: + from types import TracebackType + + from ._core import Task + from ._core._parking_lot import ParkingLotStatistics + + +@attr.s(frozen=True, slots=True) +class EventStatistics: + """An object containing debugging information. -@attr.s(frozen=True) -class _EventStatistics: - tasks_waiting = attr.ib() + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this event's + :meth:`trio.Event.wait` method. + + """ + + tasks_waiting: int = attr.ib() @attr.s(repr=False, eq=False, hash=False, slots=True) @@ -41,15 +59,15 @@ class Event(metaclass=Final): """ - _tasks = attr.ib(factory=set, init=False) - _flag = attr.ib(default=False, init=False) + _tasks: set[Task] = attr.ib(factory=set, init=False) + _flag: bool = attr.ib(default=False, init=False) - def is_set(self): + def is_set(self) -> bool: """Return the current value of the internal flag.""" return self._flag @enable_ki_protection - def set(self): + def set(self) -> None: """Set the internal flag value to True, and wake any waiting tasks.""" if not self._flag: self._flag = True @@ -57,7 +75,7 @@ def set(self): _core.reschedule(task) self._tasks.clear() - async def wait(self): + async def wait(self) -> None: """Block until the internal flag value becomes True. If it's already True, then this method returns immediately. @@ -69,13 +87,13 @@ async def wait(self): task = _core.current_task() self._tasks.add(task) - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: self._tasks.remove(task) return _core.Abort.SUCCEEDED await _core.wait_task_rescheduled(abort_fn) - def statistics(self): + def statistics(self) -> EventStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -84,27 +102,62 @@ def statistics(self): :meth:`wait` method. """ - return _EventStatistics(tasks_waiting=len(self._tasks)) + return EventStatistics(tasks_waiting=len(self._tasks)) + + +class _HasAcquireRelease(Protocol): + """Only classes with acquire() and release() can use the mixin's implementations.""" + + async def acquire(self) -> object: + ... + + def release(self) -> object: + ... class AsyncContextManagerMixin: @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self: _HasAcquireRelease) -> None: await self.acquire() @enable_ki_protection - async def __aexit__(self, *args): + async def __aexit__( + self: _HasAcquireRelease, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.release() -@attr.s(frozen=True) -class _CapacityLimiterStatistics: - borrowed_tokens = attr.ib() - total_tokens = attr.ib() - borrowers = attr.ib() - tasks_waiting = attr.ib() +@attr.s(frozen=True, slots=True) +class CapacityLimiterStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + * ``borrowed_tokens``: The number of tokens currently borrowed from + the sack. + * ``total_tokens``: The total number of tokens in the sack. Usually + this will be larger than ``borrowed_tokens``, but it's possibly for + it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased. + * ``borrowers``: A list of all tasks or other entities that currently + hold a token. 
+    * ``tasks_waiting``: The number of tasks blocked on this
+      :class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or
+      :meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods.
+    """
+
+    borrowed_tokens: int = attr.ib()
+    total_tokens: int | float = attr.ib()
+    borrowers: list[Task | object] = attr.ib()
+    tasks_waiting: int = attr.ib()
+
+
+# Can be a generic type with a default of Task if/when PEP 696 is released
+# and implemented in type checkers. Making it fully generic would currently
+# introduce a lot of unnecessary hassle.
 
 
 class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
     """An object for controlling access to a resource with limited capacity.
 
@@ -159,22 +212,23 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
 
     """
 
-    def __init__(self, total_tokens):
+    # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
+    def __init__(self, total_tokens: int | float):
         self._lot = ParkingLot()
-        self._borrowers = set()
+        self._borrowers: set[Task | object] = set()
         # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
-        self._pending_borrowers = {}
+        self._pending_borrowers: dict[Task, Task | object] = {}
         # invoke the property setter for validation
-        self.total_tokens = total_tokens
+        self.total_tokens: int | float = total_tokens
         assert self._total_tokens == total_tokens
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<trio.CapacityLimiter at {:#x}, {}/{} with {} waiting>".format(
             id(self), len(self._borrowers), self._total_tokens, len(self._lot)
         )
 
     @property
-    def total_tokens(self):
+    def total_tokens(self) -> int | float:
         """The total capacity available.
 
         You can change :attr:`total_tokens` by assigning to this attribute. If
@@ -189,7 +243,7 @@ def total_tokens(self):
         return self._total_tokens
 
     @total_tokens.setter
-    def total_tokens(self, new_total_tokens):
+    def total_tokens(self, new_total_tokens: int | float) -> None:
         if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
             raise TypeError("total_tokens must be an int or math.inf")
         if new_total_tokens < 1:
@@ -197,23 +251,23 @@ def total_tokens(self, new_total_tokens):
         self._total_tokens = new_total_tokens
         self._wake_waiters()
 
-    def _wake_waiters(self):
+    def _wake_waiters(self) -> None:
         available = self._total_tokens - len(self._borrowers)
         for woken in self._lot.unpark(count=available):
             self._borrowers.add(self._pending_borrowers.pop(woken))
 
     @property
-    def borrowed_tokens(self):
+    def borrowed_tokens(self) -> int:
         """The amount of capacity that's currently in use."""
         return len(self._borrowers)
 
     @property
-    def available_tokens(self):
+    def available_tokens(self) -> int | float:
         """The amount of capacity that's available to use."""
         return self.total_tokens - self.borrowed_tokens
 
     @enable_ki_protection
-    def acquire_nowait(self):
+    def acquire_nowait(self) -> None:
         """Borrow a token from the sack, without blocking.
 
         Raises:
@@ -225,7 +279,7 @@ def acquire_nowait(self):
         self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
 
     @enable_ki_protection
-    def acquire_on_behalf_of_nowait(self, borrower):
+    def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None:
         """Borrow a token from the sack on behalf of ``borrower``, without
         blocking.
 
@@ -253,7 +307,7 @@ def acquire_on_behalf_of_nowait(self, borrower):
         raise trio.WouldBlock
 
     @enable_ki_protection
-    async def acquire(self):
+    async def acquire(self) -> None:
         """Borrow a token from the sack, blocking if necessary.
Raises: @@ -264,7 +318,7 @@ async def acquire(self): await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -293,7 +347,7 @@ async def acquire_on_behalf_of(self, borrower): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Put a token back into the sack. Raises: @@ -304,7 +358,7 @@ def release(self): self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Task | object) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: @@ -319,7 +373,7 @@ def release_on_behalf_of(self, borrower): self._borrowers.remove(borrower) self._wake_waiters() - def statistics(self): + def statistics(self) -> CapacityLimiterStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -336,7 +390,7 @@ def statistics(self): :meth:`acquire_on_behalf_of` methods. """ - return _CapacityLimiterStatistics( + return CapacityLimiterStatistics( borrowed_tokens=len(self._borrowers), total_tokens=self._total_tokens, # Use a list instead of a frozenset just in case we start to allow @@ -373,7 +427,7 @@ class Semaphore(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, initial_value, *, max_value=None): + def __init__(self, initial_value: int, *, max_value: int | None = None): if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -391,7 +445,7 @@ def __init__(self, initial_value, *, max_value=None): self._value = initial_value self._max_value = max_value - def __repr__(self): + def __repr__(self) -> str: if self._max_value is None: max_value_str = "" else: @@ -401,17 +455,17 @@ def __repr__(self): ) @property - def value(self): + def value(self) -> int: """The current value of the semaphore.""" return self._value @property - def max_value(self): + def max_value(self) -> int | None: """The maximum allowed value. May be None to indicate no limit.""" return self._max_value @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to decrement the semaphore value, without blocking. Raises: @@ -425,7 +479,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary to avoid letting it drop below zero. @@ -439,7 +493,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Increment the semaphore value, possibly waking a task blocked in :meth:`acquire`. @@ -456,7 +510,7 @@ def release(self): raise ValueError("semaphore released too many times") self._value += 1 - def statistics(self): + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. 
Currently the following fields are defined: @@ -468,19 +522,31 @@ def statistics(self): return self._lot.statistics() -@attr.s(frozen=True) -class _LockStatistics: - locked = attr.ib() - owner = attr.ib() - tasks_waiting = attr.ib() +@attr.s(frozen=True, slots=True) +class LockStatistics: + """An object containing debugging information for a Lock. + + Currently the following fields are defined: + + * ``locked`` (boolean): indicating whether the lock is held. + * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, + or None if the lock is not held. + * ``tasks_waiting`` (int): The number of tasks blocked on this lock's + :meth:`trio.Lock.acquire` method. + + """ + + locked: bool = attr.ib() + owner: Task | None = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False, repr=False) class _LockImpl(AsyncContextManagerMixin): - _lot = attr.ib(factory=ParkingLot, init=False) - _owner = attr.ib(default=None, init=False) + _lot: ParkingLot = attr.ib(factory=ParkingLot, init=False) + _owner: Task | None = attr.ib(default=None, init=False) - def __repr__(self): + def __repr__(self) -> str: if self.locked(): s1 = "locked" s2 = f" with {len(self._lot)} waiters" @@ -491,7 +557,7 @@ def __repr__(self): s1, self.__class__.__name__, id(self), s2 ) - def locked(self): + def locked(self) -> bool: """Check whether the lock is currently held. Returns: @@ -501,7 +567,7 @@ def locked(self): return self._owner is not None @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the lock, without blocking. Raises: @@ -519,7 +585,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock, blocking if necessary.""" await trio.lowlevel.checkpoint_if_cancelled() try: @@ -533,7 +599,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Release the lock. Raises: @@ -548,7 +614,7 @@ def release(self): else: self._owner = None - def statistics(self): + def statistics(self) -> LockStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -560,7 +626,7 @@ def statistics(self): :meth:`acquire` method. """ - return _LockStatistics( + return LockStatistics( locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot) ) @@ -642,10 +708,20 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): """ -@attr.s(frozen=True) -class _ConditionStatistics: - tasks_waiting = attr.ib() - lock_statistics = attr.ib() +@attr.s(frozen=True, slots=True) +class ConditionStatistics: + r"""An object containing debugging information for a Condition. + + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this condition's + :meth:`trio.Condition.wait` method. + * ``lock_statistics``: The result of calling the underlying + :class:`Lock`\s :meth:`~Lock.statistics` method. 
+ + """ + tasks_waiting: int = attr.ib() + lock_statistics: LockStatistics = attr.ib() class Condition(AsyncContextManagerMixin, metaclass=Final): @@ -663,7 +739,7 @@ class Condition(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, lock=None): + def __init__(self, lock: Lock | None = None): if lock is None: lock = Lock() if not type(lock) is Lock: @@ -671,7 +747,7 @@ def __init__(self, lock=None): self._lock = lock self._lot = trio.lowlevel.ParkingLot() - def locked(self): + def locked(self) -> bool: """Check whether the underlying lock is currently held. Returns: @@ -680,7 +756,7 @@ def locked(self): """ return self._lock.locked() - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the underlying lock, without blocking. Raises: @@ -689,16 +765,16 @@ def acquire_nowait(self): """ return self._lock.acquire_nowait() - async def acquire(self): + async def acquire(self) -> None: """Acquire the underlying lock, blocking if necessary.""" await self._lock.acquire() - def release(self): + def release(self) -> None: """Release the underlying lock.""" self._lock.release() @enable_ki_protection - async def wait(self): + async def wait(self) -> None: """Wait for another task to call :meth:`notify` or :meth:`notify_all`. @@ -733,7 +809,7 @@ async def wait(self): await self.acquire() raise - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """Wake one or more tasks that are blocked in :meth:`wait`. Args: @@ -747,7 +823,7 @@ def notify(self, n=1): raise RuntimeError("must hold the lock to notify") self._lot.repark(self._lock._lot, count=n) - def notify_all(self): + def notify_all(self) -> None: """Wake all tasks that are currently blocked in :meth:`wait`. Raises: @@ -758,7 +834,7 @@ def notify_all(self): raise RuntimeError("must hold the lock to notify") self._lot.repark_all(self._lock._lot) - def statistics(self): + def statistics(self) -> ConditionStatistics: r"""Return an object containing debugging information. Currently the following fields are defined: @@ -769,6 +845,6 @@ def statistics(self): :class:`Lock`\s :meth:`~Lock.statistics` method. """ - return _ConditionStatistics( + return ConditionStatistics( tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics() ) diff --git a/trio/tests/__init__.py b/trio/_tests/__init__.py similarity index 100% rename from trio/tests/__init__.py rename to trio/_tests/__init__.py diff --git a/.github/workflows/astrill-codesigning-cert.cer b/trio/_tests/astrill-codesigning-cert.cer similarity index 100% rename from .github/workflows/astrill-codesigning-cert.cer rename to trio/_tests/astrill-codesigning-cert.cer diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py new file mode 100755 index 0000000000..1352926be3 --- /dev/null +++ b/trio/_tests/check_type_completeness.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +# this file is not run as part of the tests, instead it's run standalone from check.sh +import argparse +import json +import subprocess +import sys +from collections.abc import Mapping +from pathlib import Path + +# the result file is not marked in MANIFEST.in so it's not included in the package +failed = False + + +def get_result_file_name(platform: str) -> Path: + return Path(__file__).parent / f"verify_types_{platform.lower()}.json" + + +# TODO: consider checking manually without `--ignoreexternal`, and/or +# removing it from the below call later on. 
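A note on `check_less_than` below: it implements a ratchet, where each tracked metric may stay equal or improve but must never regress. Condensed into a standalone predicate with illustrative names:

```python
def ratchet_ok(current: float, last: float, *, invert: bool = False) -> bool:
    # invert=True marks "higher is better" metrics (completenessScore);
    # staying equal is always fine, moving the wrong way fails the run.
    return current == last or (current > last) == invert

assert ratchet_ok(3, 5)                         # error count dropped: fine
assert not ratchet_ok(0.93, 0.95, invert=True)  # completeness regressed
```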
+def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]:
+    return subprocess.run(
+        [
+            "pyright",
+            # Specify a platform and version to keep imported modules consistent.
+            f"--pythonplatform={platform}",
+            "--pythonversion=3.8",
+            "--verifytypes=trio",
+            "--outputjson",
+            "--ignoreexternal",
+        ],
+        capture_output=True,
+    )
+
+
+def check_less_than(
+    key: str,
+    current_dict: Mapping[str, float],
+    last_dict: Mapping[str, float],
+    /,
+    invert: bool = False,
+) -> None:
+    global failed
+    current = current_dict[key]
+    last = last_dict[key]
+    assert isinstance(current, (float, int))
+    assert isinstance(last, (float, int))
+    if current == last:
+        return
+    if (current > last) ^ invert:
+        failed = True
+        print("ERROR: ", end="")
+    if isinstance(current, float):
+        strcurrent = f"{current:.4}"
+    else:
+        strcurrent = str(current)
+    if isinstance(last, float):
+        strlast = f"{last:.4}"
+    else:
+        strlast = str(last)
+    print(
+        f"{key} has gone {'down' if current < last else 'up'} "
+        f"from {strlast} to {strcurrent}"
+    )
+
+
+def check_zero(key: str, current_dict: Mapping[str, float]) -> None:
+    global failed
+    if current_dict[key] != 0:
+        failed = True
+        print(f"ERROR: {key} is {current_dict[key]}")
+
+
+def check_type(args: argparse.Namespace, platform: str) -> int:
+    print("*" * 20, "\nChecking type completeness hasn't gone down...")
+
+    res = run_pyright(platform)
+    current_result = json.loads(res.stdout)
+    py_typed_file: Path | None = None
+
+    # check if py.typed file was missing
+    if (
+        current_result["generalDiagnostics"]
+        and current_result["generalDiagnostics"][0]["message"]
+        == "No py.typed file found"
+    ):
+        print("creating py.typed")
+        py_typed_file = (
+            Path(current_result["typeCompleteness"]["packageRootDirectory"])
+            / "py.typed"
+        )
+        py_typed_file.write_text("")
+
+        res = run_pyright(platform)
+        current_result = json.loads(res.stdout)
+
+    if res.stderr:
+        print(res.stderr)
+
+    last_result = json.loads(get_result_file_name(platform).read_text())
+
+    for key in "errorCount", "warningCount", "informationCount":
+        check_zero(key, current_result["summary"])
+
+    for key, invert in (
+        ("missingFunctionDocStringCount", False),
+        ("missingClassDocStringCount", False),
+        ("missingDefaultParamCount", False),
+        ("completenessScore", True),
+    ):
+        check_less_than(
+            key,
+            current_result["typeCompleteness"],
+            last_result["typeCompleteness"],
+            invert=invert,
+        )
+
+    for key, invert in (
+        ("withUnknownType", False),
+        ("withAmbiguousType", False),
+        ("withKnownType", True),
+    ):
+        check_less_than(
+            key,
+            current_result["typeCompleteness"]["exportedSymbolCounts"],
+            last_result["typeCompleteness"]["exportedSymbolCounts"],
+            invert=invert,
+        )
+
+    if args.overwrite_file:
+        print("Overwriting file")
+
+        # don't care about differences in time taken
+        del current_result["time"]
+        del current_result["summary"]["timeInSec"]
+
+        # don't fail on version diff so pyright updates can be automerged
+        del current_result["version"]
+
+        for key in (
+            # don't save path (because that varies between machines)
+            "moduleRootDirectory",
+            "packageRootDirectory",
+            "pyTypedPath",
+        ):
+            del current_result["typeCompleteness"][key]
+
+        # prune the symbols to only be the name of the symbols with
+        # errors, instead of saving a huge file.
+        new_symbols: list[dict[str, str]] = []
+        for symbol in current_result["typeCompleteness"]["symbols"]:
+            if symbol["diagnostics"]:
+                # function name + message should be enough context for people!
+                new_symbols.extend(
+                    {"name": symbol["name"], "message": diagnostic["message"]}
+                    for diagnostic in symbol["diagnostics"]
+                )
+                continue
+
+        # Ensure order of arrays does not affect result.
+ new_symbols.sort(key=lambda module: module.get("name", "")) + current_result["generalDiagnostics"].sort() + current_result["typeCompleteness"]["modules"].sort( + key=lambda module: module.get("name", "") + ) + + del current_result["typeCompleteness"]["symbols"] + current_result["typeCompleteness"]["diagnostics"] = new_symbols + + with open(get_result_file_name(platform), "w") as file: + json.dump(current_result, file, sort_keys=True, indent=2) + # add newline at end of file so it's easier to manually modify + file.write("\n") + + if py_typed_file is not None: + print("deleting py.typed") + py_typed_file.unlink() + + print("*" * 20) + + return int(failed) + + +def main(args: argparse.Namespace) -> int: + res = 0 + for platform in "Linux", "Windows", "Darwin": + res += check_type(args, platform) + return res + + +parser = argparse.ArgumentParser() +parser.add_argument("--overwrite-file", action="store_true", default=False) +parser.add_argument("--full-diagnostics-file", type=Path, default=None) +args = parser.parse_args() + +assert __name__ == "__main__", "This script should be run standalone" +sys.exit(main(args)) diff --git a/trio/tests/module_with_deprecations.py b/trio/_tests/module_with_deprecations.py similarity index 100% rename from trio/tests/module_with_deprecations.py rename to trio/_tests/module_with_deprecations.py diff --git a/trio/tests/conftest.py b/trio/_tests/pytest_plugin.py similarity index 76% rename from trio/tests/conftest.py rename to trio/_tests/pytest_plugin.py index 772486e1eb..c6d73e25ea 100644 --- a/trio/tests/conftest.py +++ b/trio/_tests/pytest_plugin.py @@ -1,13 +1,8 @@ -# XX this does not belong here -- b/c it's here, these things only apply to -# the tests in trio/_core/tests, not in trio/tests. For now there's some -# copy-paste... -# -# this stuff should become a proper pytest plugin +import inspect import pytest -import inspect -from ..testing import trio_test, MockClock +from ..testing import MockClock, trio_test RUN_SLOW = True diff --git a/trio/tests/test_abc.py b/trio/_tests/test_abc.py similarity index 96% rename from trio/tests/test_abc.py rename to trio/_tests/test_abc.py index c445c97103..2b0b7088b0 100644 --- a/trio/tests/test_abc.py +++ b/trio/_tests/test_abc.py @@ -1,8 +1,6 @@ -import pytest - import attr +import pytest -from ..testing import assert_checkpoints from .. import abc as tabc diff --git a/trio/tests/test_channel.py b/trio/_tests/test_channel.py similarity index 99% rename from trio/tests/test_channel.py rename to trio/_tests/test_channel.py index aabb368799..4478c523f5 100644 --- a/trio/tests/test_channel.py +++ b/trio/_tests/test_channel.py @@ -1,8 +1,9 @@ import pytest -from ..testing import wait_all_tasks_blocked, assert_checkpoints import trio -from trio import open_memory_channel, EndOfChannel +from trio import EndOfChannel, open_memory_channel + +from ..testing import assert_checkpoints, wait_all_tasks_blocked async def test_channel(): diff --git a/trio/tests/test_contextvars.py b/trio/_tests/test_contextvars.py similarity index 70% rename from trio/tests/test_contextvars.py rename to trio/_tests/test_contextvars.py index 63853f5171..ae0c25f876 100644 --- a/trio/tests/test_contextvars.py +++ b/trio/_tests/test_contextvars.py @@ -1,15 +1,19 @@ +from __future__ import annotations + import contextvars from .. 
import _core -trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar") +trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar( + "trio_testing_contextvar" +) -async def test_contextvars_default(): +async def test_contextvars_default() -> None: trio_testing_contextvar.set("main") - record = [] + record: list[str] = [] - async def child(): + async def child() -> None: value = trio_testing_contextvar.get() record.append(value) @@ -18,11 +22,11 @@ async def child(): assert record == ["main"] -async def test_contextvars_set(): +async def test_contextvars_set() -> None: trio_testing_contextvar.set("main") - record = [] + record: list[str] = [] - async def child(): + async def child() -> None: trio_testing_contextvar.set("child") value = trio_testing_contextvar.get() record.append(value) @@ -34,13 +38,13 @@ async def child(): assert value == "main" -async def test_contextvars_copy(): +async def test_contextvars_copy() -> None: trio_testing_contextvar.set("main") context = contextvars.copy_context() trio_testing_contextvar.set("second_main") - record = [] + record: list[str] = [] - async def child(): + async def child() -> None: value = trio_testing_contextvar.get() record.append(value) diff --git a/trio/tests/test_deprecate.py b/trio/_tests/test_deprecate.py similarity index 88% rename from trio/tests/test_deprecate.py rename to trio/_tests/test_deprecate.py index e5e1da8c5f..33c05ffd25 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/_tests/test_deprecate.py @@ -1,15 +1,14 @@ -import pytest - import inspect import warnings +import pytest + from .._deprecate import ( TrioDeprecationWarning, - warn_deprecated, deprecated, deprecated_alias, + warn_deprecated, ) - from . import module_with_deprecations @@ -241,3 +240,32 @@ def test_module_with_deprecations(recwarn_always): with pytest.raises(AttributeError): module_with_deprecations.asdf + + +def test_tests_is_deprecated1() -> None: + with pytest.warns(TrioDeprecationWarning): + from trio import tests # warning on import + + # warning on access of any member + with pytest.warns(TrioDeprecationWarning): + assert tests.test_abc # type: ignore[attr-defined] + + +def test_tests_is_deprecated2() -> None: + # warning on direct import of test since that accesses `__spec__` + with pytest.warns(TrioDeprecationWarning): + import trio.tests + + with pytest.warns(TrioDeprecationWarning): + assert trio.tests.test_deprecate # type: ignore[attr-defined] + + +def test_tests_is_deprecated3() -> None: + import trio + + # no warning on accessing the submodule + assert trio.tests + + # only when accessing a submodule member + with pytest.warns(TrioDeprecationWarning): + assert trio.tests.test_abc # type: ignore[attr-defined] diff --git a/trio/tests/test_dtls.py b/trio/_tests/test_dtls.py similarity index 98% rename from trio/tests/test_dtls.py rename to trio/_tests/test_dtls.py index 680a8793eb..8cb06ccb3d 100644 --- a/trio/tests/test_dtls.py +++ b/trio/_tests/test_dtls.py @@ -1,25 +1,26 @@ -import pytest -import trio -import trio.testing -from trio import DTLSEndpoint import random -import attr from contextlib import asynccontextmanager from itertools import count +import attr +import pytest import trustme from OpenSSL import SSL +import trio +import trio.testing +from trio import DTLSEndpoint from trio.testing._fake_net import FakeNet -from .._core.tests.tutil import slow, binds_ipv6, gc_collect_harder + +from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow ca = trustme.CA() server_cert = 
ca.issue_cert("example.com") -server_ctx = SSL.Context(SSL.DTLS_METHOD) +server_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined] server_cert.configure_cert(server_ctx) -client_ctx = SSL.Context(SSL.DTLS_METHOD) +client_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined] ca.configure_trust(client_ctx) @@ -101,10 +102,12 @@ async def test_smoke(ipv6): @slow async def test_handshake_over_terrible_network(autojump_clock): - HANDSHAKES = 1000 + HANDSHAKES = 100 r = random.Random(0) fn = FakeNet() fn.enable() + # avoid spurious timeouts on slow machines + autojump_clock.autojump_threshold = 0.001 async with dtls_echo_server() as (_, address): async with trio.open_nursery() as nursery: @@ -333,13 +336,13 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): fn.enable() from trio._dtls import ( - Record, - encode_record, - HandshakeFragment, - encode_handshake_fragment, ContentType, + HandshakeFragment, HandshakeType, ProtocolVersion, + Record, + encode_handshake_fragment, + encode_record, ) client_hello = encode_record( @@ -444,7 +447,7 @@ async def test_invalid_cookie_rejected(autojump_clock): fn = FakeNet() fn.enable() - from trio._dtls import decode_client_hello_untrusted, BadPacket + from trio._dtls import BadPacket, decode_client_hello_untrusted with trio.CancelScope() as cscope: # the first 11 bytes of ClientHello aren't protected by the cookie, so only test diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py new file mode 100644 index 0000000000..2f1157db06 --- /dev/null +++ b/trio/_tests/test_exports.py @@ -0,0 +1,522 @@ +from __future__ import annotations # isort: split +import __future__ # Regular import, not special! + +import enum +import functools +import importlib +import inspect +import json +import socket as stdlib_socket +import sys +from pathlib import Path +from types import ModuleType +from typing import Protocol + +import attrs +import pytest + +import trio +import trio.testing + +from .. import _core, _util +from .._core._tests.tutil import slow +from .pytest_plugin import RUN_SLOW + +mypy_cache_updated = False + + +try: # If installed, check both versions of this class. + from typing_extensions import Protocol as Protocol_ext +except ImportError: # pragma: no cover + Protocol_ext = Protocol # type: ignore[assignment] + + +def _ensure_mypy_cache_updated(): + # This pollutes the `empty` dir. Should this be changed? 
+ from mypy.api import run + + global mypy_cache_updated + if not mypy_cache_updated: + # mypy cache was *probably* already updated by the other tests, + # but `pytest -k ...` might run just this test on its own + result = run( + [ + "--config-file=", + "--cache-dir=./.mypy_cache", + "--no-error-summary", + "-c", + "import trio", + ] + ) + assert not result[1] # stderr + assert not result[0] # stdout + mypy_cache_updated = True + + +def test_core_is_properly_reexported(): + # Each export from _core should be re-exported by exactly one of these + # three modules: + sources = [trio, trio.lowlevel, trio.testing] + for symbol in dir(_core): + if symbol.startswith("_"): + continue + found = 0 + for source in sources: + if symbol in dir(source) and getattr(source, symbol) is getattr( + _core, symbol + ): + found += 1 + print(symbol, found) + assert found == 1 + + +def public_modules(module): + yield module + for name, class_ in module.__dict__.items(): + if name.startswith("_"): # pragma: no cover + continue + if not isinstance(class_, ModuleType): + continue + if not class_.__name__.startswith(module.__name__): # pragma: no cover + continue + if class_ is module: # pragma: no cover + continue + yield from public_modules(class_) + + +PUBLIC_MODULES = list(public_modules(trio)) +PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES] + + +# It doesn't make sense for downstream redistributors to run this test, since +# they might be using a newer version of Python with additional symbols which +# won't be reflected in trio.socket, and this shouldn't cause downstream test +# runs to start failing. +@pytest.mark.redistributors_should_skip +# Static analysis tools often have trouble with alpha releases, where Python's +# internals are in flux, grammar may not have settled down, etc. +@pytest.mark.skipif( + sys.version_info.releaselevel == "alpha", + reason="skip static introspection tools on Python dev/alpha releases", +) +@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES) +@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"]) +@pytest.mark.filterwarnings( + # https://github.com/pypa/setuptools/issues/3274 + "ignore:module 'sre_constants' is deprecated:DeprecationWarning", +) +def test_static_tool_sees_all_symbols(tool, modname, tmpdir): + module = importlib.import_module(modname) + + def no_underscores(symbols): + return {symbol for symbol in symbols if not symbol.startswith("_")} + + runtime_names = no_underscores(dir(module)) + + # ignore deprecated module `tests` being invisible + if modname == "trio": + runtime_names.discard("tests") + + # Ignore any __future__ feature objects, if imported under that name. + for name in __future__.all_feature_names: + if getattr(module, name, None) is getattr(__future__, name): + runtime_names.remove(name) + + if tool in ("mypy", "pyright_verifytypes"): + # create py.typed file + py_typed_path = Path(trio.__file__).parent / "py.typed" + py_typed_exists = py_typed_path.exists() + if not py_typed_exists: # pragma: no branch + py_typed_path.write_text("") + + if tool == "pylint": + from pylint.lint import PyLinter + + linter = PyLinter() + ast = linter.get_ast(module.__file__, modname) + static_names = no_underscores(ast) + elif tool == "jedi": + import jedi + + # Simulate typing "import trio; trio." 
+ script = jedi.Script(f"import {modname}; {modname}.") + completions = script.complete() + static_names = no_underscores(c.name for c in completions) + elif tool == "mypy": + if not RUN_SLOW: # pragma: no cover + pytest.skip("use --run-slow to check against mypy") + if sys.implementation.name != "cpython": + pytest.skip("mypy not installed in tests on pypy") + + cache = Path.cwd() / ".mypy_cache" + + _ensure_mypy_cache_updated() + + trio_cache = next(cache.glob("*/trio")) + _, modname = (modname + ".").split(".", 1) + modname = modname[:-1] + mod_cache = trio_cache / modname if modname else trio_cache + if mod_cache.is_dir(): + mod_cache = mod_cache / "__init__.data.json" + else: + mod_cache = trio_cache / (modname + ".data.json") + + assert mod_cache.exists() and mod_cache.is_file() + with mod_cache.open() as cache_file: + cache_json = json.loads(cache_file.read()) + static_names = no_underscores( + key + for key, value in cache_json["names"].items() + if not key.startswith(".") and value["kind"] == "Gdef" + ) + elif tool == "pyright_verifytypes": + if not RUN_SLOW: # pragma: no cover + pytest.skip("use --run-slow to check against mypy") + import subprocess + + res = subprocess.run( + ["pyright", f"--verifytypes={modname}", "--outputjson"], + capture_output=True, + ) + current_result = json.loads(res.stdout) + + static_names = { + x["name"][len(modname) + 1 :] + for x in current_result["typeCompleteness"]["symbols"] + if x["name"].startswith(modname) + } + + # pyright ignores the symbol defined behind `if False` + if modname == "trio": + static_names.add("testing") + + # these are hidden behind `if sys.platform != "win32" or not TYPE_CHECKING` + # so presumably pyright is parsing that if statement, in which case we don't + # care about them being missing. + if modname == "trio.socket" and sys.platform == "win32": + ignored_missing_names = {"if_indextoname", "if_nameindex", "if_nametoindex"} + assert static_names.isdisjoint(ignored_missing_names) + static_names.update(ignored_missing_names) + + else: # pragma: no cover + assert False + + # remove py.typed file + if tool in ("mypy", "pyright_verifytypes") and not py_typed_exists: + py_typed_path.unlink() + + # mypy handles errors with an `assert` in its branch + if tool == "mypy": + return + + # It's expected that the static set will contain more names than the + # runtime set: + # - static tools are sometimes sloppy and include deleted names + # - some symbols are platform-specific at runtime, but always show up in + # static analysis (e.g. in trio.socket or trio.lowlevel) + # So we check that the runtime names are a subset of the static names. + missing_names = runtime_names - static_names + + # ignore warnings about deprecated module tests + missing_names -= {"tests"} + + if missing_names: # pragma: no cover + print(f"{tool} can't see the following names in {modname}:") + print() + for name in sorted(missing_names): + print(f" {name}") + assert False + + +# this could be sped up by only invoking mypy once per module, or even once for all +# modules, instead of once per class. +@slow +# see comment on test_static_tool_sees_all_symbols +@pytest.mark.redistributors_should_skip +# Static analysis tools often have trouble with alpha releases, where Python's +# internals are in flux, grammar may not have settled down, etc. 
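The cache walking above relies on mypy's on-disk layout: one JSON file per module under `.mypy_cache/<version>/`, with packages keyed by `__init__.data.json` and `"Gdef"` marking global definitions. A rough standalone sketch under those assumptions:

```python
import json
from pathlib import Path

def module_symbols(version_dir: Path, modname: str) -> set[str]:
    # e.g. module_symbols(Path(".mypy_cache/3.8"), "trio.socket")
    mod_cache = version_dir / (modname.replace(".", "/") + ".data.json")
    if not mod_cache.is_file():  # packages cache under __init__.data.json
        mod_cache = version_dir / modname.replace(".", "/") / "__init__.data.json"
    names = json.loads(mod_cache.read_text())["names"]
    return {
        name
        for name, value in names.items()
        if not name.startswith(".") and value["kind"] == "Gdef"
    }
```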
+@pytest.mark.skipif(
+ sys.version_info.releaselevel == "alpha",
+ reason="skip static introspection tools on Python dev/alpha releases",
+)
+@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
+@pytest.mark.parametrize("tool", ["jedi", "mypy"])
+def test_static_tool_sees_class_members(
+ tool: str, module_name: str, tmpdir: Path
+) -> None:
+ module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
+
+ # ignore hidden, but not dunder, symbols
+ def no_hidden(symbols):
+ return {
+ symbol
+ for symbol in symbols
+ if (not symbol.startswith("_")) or symbol.startswith("__")
+ }
+
+ py_typed_path = Path(trio.__file__).parent / "py.typed"
+ py_typed_exists = py_typed_path.exists()
+
+ if tool == "mypy":
+ if sys.implementation.name != "cpython":
+ pytest.skip("mypy not installed in tests on pypy")
+ # create py.typed file
+ # remove this logic when trio is marked with py.typed proper
+ if not py_typed_exists: # pragma: no branch
+ py_typed_path.write_text("")
+
+ cache = Path.cwd() / ".mypy_cache"
+
+ _ensure_mypy_cache_updated()
+
+ trio_cache = next(cache.glob("*/trio"))
+ modname = module_name
+ _, modname = (modname + ".").split(".", 1)
+ modname = modname[:-1]
+ mod_cache = trio_cache / modname if modname else trio_cache
+ if mod_cache.is_dir():
+ mod_cache = mod_cache / "__init__.data.json"
+ else:
+ mod_cache = trio_cache / (modname + ".data.json")
+
+ assert mod_cache.exists() and mod_cache.is_file()
+ with mod_cache.open() as cache_file:
+ cache_json = json.loads(cache_file.read())
+
+ # skip a bunch of file-system activity (probably can un-memoize?)
+ @functools.lru_cache
+ def lookup_symbol(symbol):
+ topname, *modname, name = symbol.split(".")
+ version = next(cache.glob("3.*/"))
+ mod_cache = version / topname
+ if not mod_cache.is_dir():
+ mod_cache = version / (topname + ".data.json")
+
+ if modname:
+ for piece in modname[:-1]:
+ mod_cache /= piece
+ next_cache = mod_cache / modname[-1]
+ if next_cache.is_dir():
+ mod_cache = next_cache / "__init__.data.json"
+ else:
+ mod_cache = mod_cache / (modname[-1] + ".data.json")
+
+ with mod_cache.open() as f:
+ return json.loads(f.read())["names"][name]
+
+ errors: dict[str, object] = {}
+ for class_name, class_ in module.__dict__.items():
+ if not isinstance(class_, type):
+ continue
+ if module_name == "trio.socket" and class_name in dir(stdlib_socket):
+ continue
+ # Deprecated classes are exported with a leading underscore
+ # We don't care about errors in _MultiError as that's on its way out anyway
+ if class_name.startswith("_"): # pragma: no cover
+ continue
+
+ # dir() and inspect.getmembers don't display properties from the metaclass
+ # also ignore some dunder methods that tend to differ but are of no consequence
+ ignore_names = set(dir(type(class_))) | {
+ "__annotations__",
+ "__attrs_attrs__",
+ "__attrs_own_setattr__",
+ "__callable_proto_members_only__",
+ "__class_getitem__",
+ "__final__",
+ "__getstate__",
+ "__match_args__",
+ "__order__",
+ "__orig_bases__",
+ "__parameters__",
+ "__protocol_attrs__",
+ "__setstate__",
+ "__slots__",
+ "__weakref__",
+ }
+
+ # pypy seems to have some additional dunders that differ
+ if sys.implementation.name == "pypy":
+ ignore_names |= {
+ "__basicsize__",
+ "__dictoffset__",
+ "__itemsize__",
+ "__sizeof__",
+ "__weakrefoffset__",
+ "__unicode__",
+ }
+
+ # inspect.getmembers sees `name` and `value` in Enums, otherwise
+ # it behaves the same way as `dir`
+ # runtime_names = no_underscores(dir(class_))
+ runtime_names = (
+ no_hidden(x[0] for x in 
inspect.getmembers(class_)) - ignore_names
+ )
+
+ if tool == "jedi":
+ import jedi
+
+ script = jedi.Script(
+ f"from {module_name} import {class_name}; {class_name}."
+ )
+ completions = script.complete()
+ static_names = no_hidden(c.name for c in completions) - ignore_names
+
+ elif tool == "mypy":
+ # load the cached type information
+ cached_type_info = cache_json["names"][class_name]
+ if "node" not in cached_type_info:
+ cached_type_info = lookup_symbol(cached_type_info["cross_ref"])
+
+ assert "node" in cached_type_info
+ node = cached_type_info["node"]
+ static_names = no_hidden(k for k in node["names"] if not k.startswith("."))
+ for symbol in node["mro"][1:]:
+ node = lookup_symbol(symbol)["node"]
+ static_names |= no_hidden(
+ k for k in node["names"] if not k.startswith(".")
+ )
+ static_names -= ignore_names
+
+ else: # pragma: no cover
+ assert False, "unknown tool"
+
+ missing = runtime_names - static_names
+ extra = static_names - runtime_names
+
+ # using .remove() instead of .discard() so we get an error if these
+ # names ever stop being missing/extra
+
+ if (
+ tool == "jedi"
+ and BaseException in class_.__mro__
+ and sys.version_info >= (3, 11)
+ ):
+ missing.remove("add_note")
+
+ if (
+ tool == "mypy"
+ and BaseException in class_.__mro__
+ and sys.version_info >= (3, 11)
+ ):
+ extra.remove("__notes__")
+
+ if tool == "mypy" and attrs.has(class_):
+ # e.g. __trio__core__run_CancelScope_AttrsAttributes__
+ before = len(extra)
+ extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
+ assert len(extra) == before - 1
+
+ # TODO: this *should* be visible via `dir`!!
+ if tool == "mypy" and class_ == trio.Nursery:
+ extra.remove("cancel_scope")
+
+ # TODO: I'm not so sure about these, but should still be looked at.
+ EXTRAS = {
+ trio.DTLSChannel: {"peer_address", "endpoint"},
+ trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
+ trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
+ trio.SSLListener: {"transport_listener"},
+ trio.SSLStream: {"transport_stream"},
+ trio.SocketListener: {"socket"},
+ trio.SocketStream: {"socket"},
+ trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
+ trio.testing.MemorySendStream: {
+ "close_hook",
+ "send_all_hook",
+ "wait_send_all_might_not_block_hook",
+ },
+ }
+ if tool == "mypy" and class_ in EXTRAS:
+ before = len(extra)
+ extra -= EXTRAS[class_]
+ assert len(extra) == before - len(EXTRAS[class_])
+
+ # probably an issue with mypy....
+ if tool == "mypy" and class_ == trio.Path and sys.platform == "win32":
+ before = len(missing)
+ missing -= {"owner", "group", "is_mount"}
+ assert len(missing) == before - 3
+
+ # TODO: why is this? Is it a problem? 
+ # see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916 + if class_ == trio.StapledStream: + extra.remove("receive_stream") + extra.remove("send_stream") + + # I have not researched why these are missing, should maybe create an issue + # upstream with jedi + if tool == "jedi" and sys.version_info >= (3, 11): + if class_ in ( + trio.DTLSChannel, + trio.MemoryReceiveChannel, + trio.MemorySendChannel, + trio.SSLListener, + trio.SocketListener, + ): + missing.remove("__aenter__") + missing.remove("__aexit__") + if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel): + missing.remove("__aiter__") + missing.remove("__anext__") + + # intentionally hidden behind type guard + if class_ == trio.Path: + missing.remove("__getattr__") + + if missing or extra: # pragma: no cover + errors[f"{module_name}.{class_name}"] = { + "missing": missing, + "extra": extra, + } + + # clean up created py.typed file + if tool == "mypy" and not py_typed_exists: + py_typed_path.unlink() + + # `assert not errors` will not print the full content of errors, even with + # `--verbose`, so we manually print it + if errors: # pragma: no cover + from pprint import pprint + + print(f"\n{tool} can't see the following symbols in {module_name}:") + pprint(errors) + assert not errors + + +def test_classes_are_final() -> None: + for module in PUBLIC_MODULES: + for name, class_ in module.__dict__.items(): + if not isinstance(class_, type): + continue + # Deprecated classes are exported with a leading underscore + if name.startswith("_"): # pragma: no cover + continue + + # Abstract classes can be subclassed, because that's the whole + # point of ABCs + if inspect.isabstract(class_): + continue + # Same with protocols, but only direct children. + if Protocol in class_.__bases__ or Protocol_ext in class_.__bases__: + continue + # Exceptions are allowed to be subclassed, because exception + # subclassing isn't used to inherit behavior. + if issubclass(class_, BaseException): + continue + # These are classes that are conceptually abstract, but + # inspect.isabstract returns False for boring reasons. + if class_ is trio.abc.Instrument or class_ is trio.socket.SocketType: + continue + # Enums have their own metaclass, so we can't use our metaclasses. + # And I don't think there's a lot of risk from people subclassing + # enums... + if issubclass(class_, enum.Enum): + continue + # ... insert other special cases here ... + + # don't care about the *Statistics classes + if name.endswith("Statistics"): + continue + + assert isinstance(class_, _util.Final) diff --git a/trio/tests/test_fakenet.py b/trio/_tests/test_fakenet.py similarity index 77% rename from trio/tests/test_fakenet.py rename to trio/_tests/test_fakenet.py index bc691c9db5..d250a105a3 100644 --- a/trio/tests/test_fakenet.py +++ b/trio/_tests/test_fakenet.py @@ -1,16 +1,18 @@ +import errno + import pytest import trio from trio.testing._fake_net import FakeNet -def fn(): +def fn() -> FakeNet: fn = FakeNet() fn.enable() return fn -async def test_basic_udp(): +async def test_basic_udp() -> None: fn() s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) @@ -19,6 +21,11 @@ async def test_basic_udp(): ip, port = s1.getsockname() assert ip == "127.0.0.1" assert port != 0 + + with pytest.raises(OSError) as exc: # Cannot rebind. 
+ await s1.bind(("192.0.2.1", 0)) + assert exc.value.errno == errno.EINVAL + await s2.sendto(b"xyz", s1.getsockname()) data, addr = await s1.recvfrom(10) assert data == b"xyz" @@ -29,7 +36,7 @@ async def test_basic_udp(): assert addr == s1.getsockname() -async def test_msg_trunc(): +async def test_msg_trunc() -> None: fn() s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) @@ -38,7 +45,7 @@ async def test_msg_trunc(): data, addr = await s1.recvfrom(10) -async def test_basic_tcp(): +async def test_basic_tcp() -> None: fn() with pytest.raises(NotImplementedError): trio.socket.socket() diff --git a/trio/tests/test_file_io.py b/trio/_tests/test_file_io.py similarity index 73% rename from trio/tests/test_file_io.py rename to trio/_tests/test_file_io.py index dcbd1a63bb..bae426cf48 100644 --- a/trio/tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -1,13 +1,16 @@ +import importlib import io import os - -import pytest +import re +from typing import List, Tuple from unittest import mock from unittest.mock import sentinel +import pytest + import trio -from trio import _core -from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS +from trio import _core, _file_io +from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper @pytest.fixture @@ -78,6 +81,46 @@ def unsupported_attr(self): # pragma: no cover getattr(async_file, "unsupported_attr") +def test_type_stubs_match_lists() -> None: + """Check the manual stubs match the list of wrapped methods.""" + # Fetch the module's source code. + assert _file_io.__spec__ is not None + loader = _file_io.__spec__.loader + assert isinstance(loader, importlib.abc.SourceLoader) + source = io.StringIO(loader.get_source("trio._file_io")) + + # Find the class, then find the TYPE_CHECKING block. + for line in source: + if "class AsyncIOWrapper" in line: + break + else: # pragma: no cover - should always find this + pytest.fail("No class definition line?") + + for line in source: + if "if TYPE_CHECKING" in line: + break + else: # pragma: no cover - should always find this + pytest.fail("No TYPE CHECKING line?") + + # Now we should be at the type checking block. + found: List[Tuple[str, str]] = [] + for line in source: # pragma: no branch - expected to break early + if line.strip() and not line.startswith(" " * 8): + break # Dedented out of the if TYPE_CHECKING block. + match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line) + if match is not None: + kind = "async" if match.group(1) is not None else "sync" + found.append((match.group(2), kind)) + + # Compare two lists so that we can easily see duplicates, and see what is different overall. + expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS] + expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS] + # Ignore order, error if duplicates are present. 
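# The same scanning idea, self-contained and runnable on a made-up class
# body (only the eight-space dedent check and the def-matching regex
# mirror the test above):
import io
import re

SAMPLE = '''\
class AsyncIOWrapper:
    if TYPE_CHECKING:
        def seek(self, target: int) -> int: ...
        async def read(self, size: int = -1) -> str: ...
    wrapped = None
'''

found = []
sample = io.StringIO(SAMPLE)
for line in sample:
    if "if TYPE_CHECKING" in line:
        break
for line in sample:
    if line.strip() and not line.startswith(" " * 8):
        break  # dedented back out of the TYPE_CHECKING block
    match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line)
    if match is not None:
        found.append((match.group(2), "async" if match.group(1) else "sync"))

assert found == [("seek", "sync"), ("read", "async")]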
+ found.sort() + expected.sort() + assert found == expected + + def test_sync_attrs_forwarded(async_file, wrapped): for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): diff --git a/trio/tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py similarity index 98% rename from trio/tests/test_highlevel_generic.py rename to trio/_tests/test_highlevel_generic.py index df2b2cecf7..38bcedee25 100644 --- a/trio/tests/test_highlevel_generic.py +++ b/trio/_tests/test_highlevel_generic.py @@ -1,9 +1,8 @@ -import pytest - import attr +import pytest -from ..abc import SendStream, ReceiveStream from .._highlevel_generic import StapledStream +from ..abc import ReceiveStream, SendStream @attr.s diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py similarity index 94% rename from trio/tests/test_highlevel_open_tcp_listeners.py rename to trio/_tests/test_highlevel_open_tcp_listeners.py index 0c38b4ca69..6eca844f0c 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -1,17 +1,17 @@ -import sys - -import pytest - -import socket as stdlib_socket import errno +import socket as stdlib_socket +import sys +from math import inf import attr +import pytest import trio -from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream +from trio import SocketListener, open_tcp_listeners, open_tcp_stream, serve_tcp from trio.testing import open_stream_to_socket_listener + from .. import socket as tsocket -from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 +from .._core._tests.tutil import binds_ipv6 if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup @@ -290,6 +290,7 @@ async def test_open_tcp_listeners_backlog(): tsocket.set_custom_socket_factory(fsf) for given, expected in [ (None, 0xFFFF), + (inf, 0xFFFF), (99999999, 0xFFFF), (10, 10), (1, 1), @@ -298,3 +299,13 @@ async def test_open_tcp_listeners_backlog(): assert listeners for listener in listeners: assert listener.socket.backlog == expected + + +async def test_open_tcp_listeners_backlog_float_error(): + fsf = FakeSocketFactory(99) + tsocket.set_custom_socket_factory(fsf) + for should_fail in (0.0, 2.18, 3.14, 9.75): + with pytest.raises( + ValueError, match=f"Only accepts infinity, not {should_fail!r}" + ): + await open_tcp_listeners(0, backlog=should_fail) diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py similarity index 99% rename from trio/tests/test_highlevel_open_tcp_stream.py rename to trio/_tests/test_highlevel_open_tcp_stream.py index 35ddd3e118..24f82bddd5 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -1,17 +1,17 @@ -import pytest -import sys import socket +import sys import attr +import pytest import trio -from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP from trio._highlevel_open_tcp_stream import ( - reorder_for_rfc_6555_section_5_4, close_all, - open_tcp_stream, format_host_port, + open_tcp_stream, + reorder_for_rfc_6555_section_5_4, ) +from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py similarity index 97% rename from trio/tests/test_highlevel_open_unix_stream.py rename to 
trio/_tests/test_highlevel_open_unix_stream.py index 211aff3e70..64a15f9e9d 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/_tests/test_highlevel_open_unix_stream.py @@ -4,7 +4,7 @@ import pytest -from trio import open_unix_socket, Path +from trio import Path, open_unix_socket from trio._highlevel_open_unix_stream import close_on_error if not hasattr(socket, "AF_UNIX"): diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py similarity index 97% rename from trio/tests/test_highlevel_serve_listeners.py rename to trio/_tests/test_highlevel_serve_listeners.py index b028092eb9..65804f4222 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -1,9 +1,8 @@ -import pytest - -from functools import partial import errno +from functools import partial import attr +import pytest import trio from trio.testing import memory_stream_pair, wait_all_tasks_blocked @@ -13,7 +12,9 @@ class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) accepted_streams = attr.ib(factory=list) - queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1))) + queued_streams = attr.ib( + factory=(lambda: trio.open_memory_channel[trio.StapledStream](1)) + ) accept_hook = attr.ib(default=None) async def connect(self): diff --git a/trio/tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py similarity index 98% rename from trio/tests/test_highlevel_socket.py rename to trio/_tests/test_highlevel_socket.py index 9dcb834d2c..1a987df3f3 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -1,17 +1,17 @@ -import pytest - -import sys -import socket as stdlib_socket import errno +import socket as stdlib_socket +import sys -from .. import _core +import pytest + +from .. import _core, socket as tsocket +from .._highlevel_socket import * from ..testing import ( + assert_checkpoints, check_half_closeable_stream, wait_all_tasks_blocked, - assert_checkpoints, ) -from .._highlevel_socket import * -from .. import socket as tsocket +from .test_socket import setsockopt_tests async def test_SocketStream_basics(): @@ -51,6 +51,8 @@ async def test_SocketStream_basics(): b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) assert isinstance(b, bytes) + setsockopt_tests(s) + async def test_SocketStream_send_all(): BIG = 10000000 diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py similarity index 95% rename from trio/tests/test_highlevel_ssl_helpers.py rename to trio/_tests/test_highlevel_ssl_helpers.py index c00f5dc464..f6eda0b578 100644 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,20 +1,21 @@ -import pytest - from functools import partial import attr +import pytest import trio -from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP import trio.testing -from .test_ssl import client_ctx, SERVER_CTX +from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM from .._highlevel_ssl_helpers import ( - open_ssl_over_tcp_stream, open_ssl_over_tcp_listeners, + open_ssl_over_tcp_stream, serve_ssl_over_tcp, ) +# noqa is needed because flake8 doesn't understand how pytest fixtures work. 
+from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 + async def echo_handler(stream): async with stream: diff --git a/trio/tests/test_path.py b/trio/_tests/test_path.py similarity index 100% rename from trio/tests/test_path.py rename to trio/_tests/test_path.py index b4345e4d55..bfef1aaf2c 100644 --- a/trio/tests/test_path.py +++ b/trio/_tests/test_path.py @@ -4,8 +4,8 @@ import pytest import trio -from trio._path import AsyncAutoWrapperType as Type from trio._file_io import AsyncIOWrapper +from trio._path import AsyncAutoWrapperType as Type @pytest.fixture diff --git a/trio/tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py similarity index 100% rename from trio/tests/test_scheduler_determinism.py rename to trio/_tests/test_scheduler_determinism.py diff --git a/trio/tests/test_signals.py b/trio/_tests/test_signals.py similarity index 99% rename from trio/tests/test_signals.py rename to trio/_tests/test_signals.py index 235772f900..313cce259f 100644 --- a/trio/tests/test_signals.py +++ b/trio/_tests/test_signals.py @@ -3,9 +3,10 @@ import pytest import trio + from .. import _core +from .._signals import _signal_handler, open_signal_receiver from .._util import signal_raise -from .._signals import open_signal_receiver, _signal_handler async def test_open_signal_receiver(): diff --git a/trio/tests/test_socket.py b/trio/_tests/test_socket.py similarity index 96% rename from trio/tests/test_socket.py rename to trio/_tests/test_socket.py index db21096fac..036098b8e5 100644 --- a/trio/tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -1,17 +1,15 @@ import errno - -import pytest -import attr - +import inspect import os import socket as stdlib_socket -import inspect +import sys import tempfile -import sys as _sys -from .._core.tests.tutil import creates_ipv6, binds_ipv6 -from .. import _core -from .. import _socket as _tsocket -from .. import socket as tsocket + +import attr +import pytest + +from .. import _core, socket as tsocket +from .._core._tests.tutil import binds_ipv6, creates_ipv6 from .._socket import _NUMERIC_ONLY, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked @@ -279,7 +277,7 @@ async def test_socket_v6(): assert s.family == tsocket.AF_INET6 -@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") +@pytest.mark.skipif(not sys.platform == "linux", reason="linux only") async def test_sniff_sockopts(): from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM @@ -356,12 +354,37 @@ async def test_SocketType_basics(): # type family proto stdlib_sock = stdlib_socket.socket() sock = tsocket.from_stdlib_socket(stdlib_sock) - assert sock.type == _tsocket.real_socket_type(stdlib_sock.type) + assert sock.type == stdlib_sock.type assert sock.family == stdlib_sock.family assert sock.proto == stdlib_sock.proto sock.close() +async def test_SocketType_setsockopt() -> None: + sock = tsocket.socket() + with sock as _: + setsockopt_tests(sock) + + +def setsockopt_tests(sock): + """Extract these out, to be reused for SocketStream also.""" + # specifying optlen. Not supported on pypy, and I couldn't find + # valid calls on darwin or win32. 
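# The calls below exercise this contract; as a hypothetical standalone
# checker (not trio's implementation), the dispatch looks like:
def check_setsockopt_args(value, optlen=None):
    # exactly one of `value` and `optlen` may be used, mirroring the two
    # stdlib setsockopt overloads
    if (value is None) == (optlen is None):
        # both supplied, or neither supplied: reject, as the test expects
        raise TypeError("invalid value for argument 'value'")
    return ("buffer", optlen) if value is None else ("value", value)

assert check_setsockopt_args(False) == ("value", False)
assert check_setsockopt_args(None, 0) == ("buffer", 0)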
+ if hasattr(tsocket, "SO_BINDTODEVICE"): + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0) + + # specifying value + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + + # specifying both + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) + + # specifying neither + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) + + async def test_SocketType_dup(): a, b = tsocket.socketpair() with a, b: diff --git a/trio/tests/test_ssl.py b/trio/_tests/test_ssl.py similarity index 96% rename from trio/tests/test_ssl.py rename to trio/_tests/test_ssl.py index 26e107e08f..f91cea8549 100644 --- a/trio/tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -1,45 +1,35 @@ from __future__ import annotations import os -import re -import sys -from typing import TYPE_CHECKING - -import pytest - -import threading import socket as stdlib_socket import ssl +import sys +import threading from contextlib import asynccontextmanager, contextmanager from functools import partial -from OpenSSL import SSL +import pytest import trustme +from OpenSSL import SSL import trio -from .. import _core -from .._highlevel_socket import SocketStream, SocketListener + +from .. import _core, socket as tsocket +from .._core import BrokenResourceError, ClosedResourceError +from .._core._tests.tutil import slow from .._highlevel_generic import aclose_forcefully -from .._core import ClosedResourceError, BrokenResourceError from .._highlevel_open_tcp_stream import open_tcp_stream -from .. import socket as tsocket -from .._ssl import SSLStream, SSLListener, NeedHandshakeError, _is_eof +from .._highlevel_socket import SocketListener, SocketStream +from .._ssl import NeedHandshakeError, SSLListener, SSLStream, _is_eof from .._util import ConflictDetector - -from .._core.tests.tutil import slow - from ..testing import ( - assert_checkpoints, Sequencer, - memory_stream_pair, - lockstep_stream_pair, + assert_checkpoints, check_two_way_stream, + lockstep_stream_pair, + memory_stream_pair, ) -if TYPE_CHECKING: - from _pytest.mark import MarkDecorator - - # We have two different kinds of echo server fixtures we use for testing. The # first is a real server written using the stdlib ssl module and blocking # sockets. It runs in a thread and we talk to it over a real socketpair(), to @@ -68,12 +58,6 @@ TRIO_TEST_1_CERT.configure_cert(SERVER_CTX) -skip_on_broken_openssl: MarkDecorator = pytest.mark.skipif( - sys.version_info < (3, 8) and ssl.OPENSSL_VERSION_INFO[0] > 1, - reason="Python 3.7 does not work with OpenSSL versions higher than 1.X", -) - - # TLS 1.3 has a lot of changes from previous versions. So we want to run tests # with both TLS 1.3, and TLS 1.2. # "tls13" means that we're willing to negotiate TLS 1.3. Usually that's @@ -116,22 +100,6 @@ def ssl_echo_serve_sync(sock, *, expect_fail=False): wrapped.unwrap() except exceptions: pass - except ssl.SSLWantWriteError: # pragma: no cover - # Under unclear conditions, CPython sometimes raises - # SSLWantWriteError here. This is a bug (bpo-32219), - # but it's not our bug. Christian Heimes thinks - # it's fixed in 'recent' CPython versions so we fail - # the test for those and ignore it for earlier - # versions. 
- if ( - sys.implementation.name != "cpython" - or sys.version_info >= (3, 8) - ): - pytest.fail( - "still an issue on recent python versions " - "add a comment to " - "https://bugs.python.org/issue32219" - ) return wrapped.sendall(data) # This is an obscure workaround for an openssl bug. In server mode, in @@ -822,7 +790,6 @@ async def test_send_all_empty_string(client_ctx): await s.aclose() -@skip_on_broken_openssl @pytest.mark.parametrize("https_compatible", [False, True]) async def test_SSLStream_generic(client_ctx, https_compatible): async def stream_maker(): @@ -1038,7 +1005,6 @@ async def test_ssl_bad_shutdown(client_ctx): await server.aclose() -@skip_on_broken_openssl async def test_ssl_bad_shutdown_but_its_ok(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, @@ -1103,7 +1069,6 @@ def close_hook(): assert transport_close_count == 1 -@skip_on_broken_openssl async def test_ssl_https_compatibility_disagreement(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, @@ -1128,7 +1093,6 @@ async def receive_and_expect_error(): nursery.start_soon(receive_and_expect_error) -@skip_on_broken_openssl async def test_https_mode_eof_before_handshake(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, diff --git a/trio/tests/test_subprocess.py b/trio/_tests/test_subprocess.py similarity index 98% rename from trio/tests/test_subprocess.py rename to trio/_tests/test_subprocess.py index e2d66f654d..7986dfd71e 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import random import signal @@ -6,6 +8,7 @@ from contextlib import asynccontextmanager from functools import partial from pathlib import Path as SyncPath +from typing import TYPE_CHECKING import pytest @@ -20,12 +23,19 @@ sleep, sleep_forever, ) -from .._core.tests.tutil import skip_if_fbsd_pipes_broken, slow +from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow from ..lowlevel import open_process from ..testing import assert_no_checkpoints, wait_all_tasks_blocked +if TYPE_CHECKING: + ... + from signal import Signals + posix = os.name == "posix" -if posix: +SIGKILL: Signals | None +SIGTERM: Signals | None +SIGUSR1: Signals | None +if (not TYPE_CHECKING and posix) or sys.platform != "win32": from signal import SIGKILL, SIGTERM, SIGUSR1 else: SIGKILL, SIGTERM, SIGUSR1 = None, None, None @@ -574,7 +584,7 @@ async def test_for_leaking_fds(): async def test_subprocess_pidfd_unnotified(): noticed_exit = None - async def wait_and_tell(proc) -> None: + async def wait_and_tell(proc: Process) -> None: nonlocal noticed_exit noticed_exit = Event() await proc.wait() diff --git a/trio/tests/test_sync.py b/trio/_tests/test_sync.py similarity index 99% rename from trio/tests/test_sync.py rename to trio/_tests/test_sync.py index 33f79c4df2..7de42b86f9 100644 --- a/trio/tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -1,13 +1,11 @@ -import pytest - import weakref -from ..testing import wait_all_tasks_blocked, assert_checkpoints +import pytest from .. import _core -from .. 
import _timeouts -from .._timeouts import sleep_forever, move_on_after from .._sync import * +from .._timeouts import sleep_forever +from ..testing import assert_checkpoints, wait_all_tasks_blocked async def test_Event(): @@ -401,8 +399,8 @@ async def waiter(i): assert c.locked() -from .._sync import AsyncContextManagerMixin from .._channel import open_memory_channel +from .._sync import AsyncContextManagerMixin # Three ways of implementing a Lock in terms of a channel. Used to let us put # the channel through the generic lock tests. diff --git a/trio/tests/test_testing.py b/trio/_tests/test_testing.py similarity index 99% rename from trio/tests/test_testing.py rename to trio/_tests/test_testing.py index a2dba728d5..3b5a57d3ec 100644 --- a/trio/tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -4,15 +4,13 @@ import pytest -from .._core.tests.tutil import can_bind_ipv6 -from .. import sleep -from .. import _core +from .. import _core, sleep, socket as tsocket +from .._core._tests.tutil import can_bind_ipv6 from .._highlevel_generic import aclose_forcefully +from .._highlevel_socket import SocketListener from ..testing import * from ..testing._check_streams import _assert_raises from ..testing._memory_streams import _UnboundedByteQueue -from .. import socket as tsocket -from .._highlevel_socket import SocketListener async def test_wait_all_tasks_blocked(): diff --git a/trio/tests/test_threads.py b/trio/_tests/test_threads.py similarity index 91% rename from trio/tests/test_threads.py rename to trio/_tests/test_threads.py index 920b3d95f0..0be067da5b 100644 --- a/trio/tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextvars import queue as stdlib_queue import re @@ -9,16 +11,16 @@ from typing import Callable, Optional import pytest -from sniffio import current_async_library_cvar +import sniffio -from .. import CapacityLimiter, Event, _core, sleep, sleep_forever, fail_after -from .._core.tests.test_ki import ki_self -from .._core.tests.tutil import buggy_pypy_asyncgens +from .. import CapacityLimiter, Event, _core, fail_after, sleep, sleep_forever +from .._core._tests.test_ki import ki_self +from .._core._tests.tutil import buggy_pypy_asyncgens from .._threads import ( current_default_thread_limiter, + from_thread_check_cancelled, from_thread_run, from_thread_run_sync, - from_thread_check_cancelled, to_thread_run_sync, ) from ..testing import wait_all_tasks_blocked @@ -167,9 +169,9 @@ async def main(): async def test_named_thread(): - ending = " from trio.tests.test_threads.test_named_thread" + ending = " from trio._tests.test_threads.test_named_thread" - def inner(name="inner" + ending) -> threading.Thread: + def inner(name: str = "inner" + ending) -> threading.Thread: assert threading.current_thread().name == name return threading.current_thread() @@ -184,7 +186,7 @@ def f(name: str) -> Callable[[None], threading.Thread]: await to_thread_run_sync(f("None" + ending)) # test that you can set a custom name, and that it's reset afterwards - async def test_thread_name(name: str): + async def test_thread_name(name: str) -> None: thread = await to_thread_run_sync(f(name), thread_name=name) assert re.match("Trio thread [0-9]*", thread.name) @@ -234,7 +236,7 @@ def _get_thread_name(ident: Optional[int] = None) -> Optional[str]: # and most mac machines. So unless the platform is linux it will just skip # in case it fails to fetch the os thread name. 
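# A usage sketch of the Python-level naming that this test and the one
# below rely on (assuming this revision's API, where
# trio.to_thread.run_sync accepts thread_name= and otherwise generates a
# name mentioning the caller):
import threading

import trio

async def _naming_demo() -> None:
    def current_name() -> str:
        return threading.current_thread().name

    # an explicit name is visible inside the worker thread...
    assert await trio.to_thread.run_sync(current_name, thread_name="demo") == "demo"
    # ...while the default name embeds where the call came from
    assert " from " in await trio.to_thread.run_sync(current_name)

trio.run(_naming_demo)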
async def test_named_thread_os(): - def inner(name) -> threading.Thread: + def inner(name: str) -> threading.Thread: os_thread_name = _get_thread_name() if os_thread_name is None and sys.platform != "linux": pytest.skip(f"no pthread OS support on {sys.platform}") @@ -247,12 +249,12 @@ def f(name: str) -> Callable[[None], threading.Thread]: return partial(inner, name) # test defaults - default = "None from trio.tests.test_threads.test_named_thread" + default = "None from trio._tests.test_threads.test_named_thread" await to_thread_run_sync(f(default)) await to_thread_run_sync(f(default), thread_name=None) # test that you can set a custom name, and that it's reset afterwards - async def test_thread_name(name: str, expected: Optional[str] = None): + async def test_thread_name(name: str, expected: Optional[str] = None) -> None: if expected is None: expected = name thread = await to_thread_run_sync(f(expected), thread_name=name) @@ -583,7 +585,9 @@ async def async_fn(): # pragma: no cover await to_thread_run_sync(async_fn) -trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar") +trio_test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar( + "trio_test_contextvar" +) async def test_trio_to_thread_run_sync_contextvars(): @@ -592,42 +596,35 @@ async def test_trio_to_thread_run_sync_contextvars(): def f(): value = trio_test_contextvar.get() - sniffio_cvar_value = current_async_library_cvar.get() - return (value, sniffio_cvar_value, threading.current_thread()) + with pytest.raises(sniffio.AsyncLibraryNotFoundError): + sniffio.current_async_library() + return (value, threading.current_thread()) - value, sniffio_cvar_value, child_thread = await to_thread_run_sync(f) + value, child_thread = await to_thread_run_sync(f) assert value == "main" - assert sniffio_cvar_value == None assert child_thread != trio_thread def g(): parent_value = trio_test_contextvar.get() trio_test_contextvar.set("worker") inner_value = trio_test_contextvar.get() - sniffio_cvar_value = current_async_library_cvar.get() + with pytest.raises(sniffio.AsyncLibraryNotFoundError): + sniffio.current_async_library() return ( parent_value, inner_value, - sniffio_cvar_value, threading.current_thread(), ) - ( - parent_value, - inner_value, - sniffio_cvar_value, - child_thread, - ) = await to_thread_run_sync(g) + parent_value, inner_value, child_thread = await to_thread_run_sync(g) current_value = trio_test_contextvar.get() - sniffio_outer_value = current_async_library_cvar.get() assert parent_value == "main" assert inner_value == "worker" assert current_value == "main", ( "The contextvar value set on the worker would not propagate back to the main" " thread" ) - assert sniffio_cvar_value is None - assert sniffio_outer_value == "trio" + assert sniffio.current_async_library() == "trio" async def test_trio_from_thread_run_sync(): @@ -699,7 +696,7 @@ def thread_fn(token): assert callee_token == caller_token -def test_from_thread_no_token(): +async def test_from_thread_no_token(): # Test that a "raw call" to trio.from_thread.run() fails because no token # has been provided @@ -714,50 +711,40 @@ def thread_fn(): thread_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("worker") thread_current_value = trio_test_contextvar.get() - sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + with pytest.raises(sniffio.AsyncLibraryNotFoundError): + sniffio.current_async_library() def back_in_main(): back_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("back_in_main") 
back_current_value = trio_test_contextvar.get() - sniffio_cvar_back_value = current_async_library_cvar.get() - return back_parent_value, back_current_value, sniffio_cvar_back_value + assert sniffio.current_async_library() == "trio" + return back_parent_value, back_current_value - ( - back_parent_value, - back_current_value, - sniffio_cvar_back_value, - ) = from_thread_run_sync(back_in_main) + back_parent_value, back_current_value = from_thread_run_sync(back_in_main) thread_after_value = trio_test_contextvar.get() - sniffio_cvar_thread_after_value = current_async_library_cvar.get() + with pytest.raises(sniffio.AsyncLibraryNotFoundError): + sniffio.current_async_library() return ( thread_parent_value, thread_current_value, thread_after_value, - sniffio_cvar_thread_pre_value, - sniffio_cvar_thread_after_value, back_parent_value, back_current_value, - sniffio_cvar_back_value, ) ( thread_parent_value, thread_current_value, thread_after_value, - sniffio_cvar_thread_pre_value, - sniffio_cvar_thread_after_value, back_parent_value, back_current_value, - sniffio_cvar_back_value, ) = await to_thread_run_sync(thread_fn) current_value = trio_test_contextvar.get() - sniffio_cvar_out_value = current_async_library_cvar.get() assert current_value == thread_parent_value == "main" assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert sniffio.current_async_library() == "trio" assert back_current_value == "back_in_main" - assert sniffio_cvar_out_value == sniffio_cvar_back_value == "trio" - assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None async def test_trio_from_thread_run_contextvars(): @@ -767,49 +754,40 @@ def thread_fn(): thread_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("worker") thread_current_value = trio_test_contextvar.get() - sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + with pytest.raises(sniffio.AsyncLibraryNotFoundError): + sniffio.current_async_library() async def async_back_in_main(): back_parent_value = trio_test_contextvar.get() trio_test_contextvar.set("back_in_main") back_current_value = trio_test_contextvar.get() - sniffio_cvar_back_value = current_async_library_cvar.get() - return back_parent_value, back_current_value, sniffio_cvar_back_value + assert sniffio.current_async_library() == "trio" + return back_parent_value, back_current_value - ( - back_parent_value, - back_current_value, - sniffio_cvar_back_value, - ) = from_thread_run(async_back_in_main) + back_parent_value, back_current_value = from_thread_run(async_back_in_main) thread_after_value = trio_test_contextvar.get() - sniffio_cvar_thread_after_value = current_async_library_cvar.get() + with pytest.raises(sniffio.AsyncLibraryNotFoundError): + sniffio.current_async_library() return ( thread_parent_value, thread_current_value, thread_after_value, - sniffio_cvar_thread_pre_value, - sniffio_cvar_thread_after_value, back_parent_value, back_current_value, - sniffio_cvar_back_value, ) ( thread_parent_value, thread_current_value, thread_after_value, - sniffio_cvar_thread_pre_value, - sniffio_cvar_thread_after_value, back_parent_value, back_current_value, - sniffio_cvar_back_value, ) = await to_thread_run_sync(thread_fn) current_value = trio_test_contextvar.get() assert current_value == thread_parent_value == "main" assert thread_current_value == back_parent_value == thread_after_value == "worker" assert back_current_value == "back_in_main" - assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None - 
assert sniffio_cvar_back_value == "trio" + assert sniffio.current_async_library() == "trio" def test_run_fn_as_system_task_catched_badly_typed_token(): diff --git a/trio/tests/test_timeouts.py b/trio/_tests/test_timeouts.py similarity index 81% rename from trio/tests/test_timeouts.py rename to trio/_tests/test_timeouts.py index 382c015b1d..9507d88a78 100644 --- a/trio/tests/test_timeouts.py +++ b/trio/_tests/test_timeouts.py @@ -1,11 +1,12 @@ +import time + import outcome import pytest -import time -from .._core.tests.tutil import slow from .. import _core -from ..testing import assert_checkpoints +from .._core._tests.tutil import slow from .._timeouts import * +from ..testing import assert_checkpoints async def check_takes_about(f, expected_dur): @@ -53,9 +54,6 @@ async def sleep_2(): await check_takes_about(sleep_2, TARGET) - with pytest.raises(ValueError): - await sleep(-1) - with assert_checkpoints(): await sleep(0) # This also serves as a test of the trivial move_on_at @@ -66,10 +64,6 @@ async def sleep_2(): @slow async def test_move_on_after(): - with pytest.raises(ValueError): - with move_on_after(-1): - pass # pragma: no cover - async def sleep_3(): with move_on_after(TARGET): await sleep(100) @@ -99,6 +93,29 @@ async def sleep_5(): with fail_after(100): await sleep(0) - with pytest.raises(ValueError): - with fail_after(-1): - pass # pragma: no cover + +async def test_timeouts_raise_value_error(): + # deadlines are allowed to be negative, but not delays. + # neither delays nor deadlines are allowed to be NaN + + nan = float("nan") + + for fun, val in ( + (sleep, -1), + (sleep, nan), + (sleep_until, nan), + ): + with pytest.raises(ValueError): + await fun(val) + + for cm, val in ( + (fail_after, -1), + (fail_after, nan), + (fail_at, nan), + (move_on_after, -1), + (move_on_after, nan), + (move_on_at, nan), + ): + with pytest.raises(ValueError): + with cm(val): + pass # pragma: no cover diff --git a/trio/tests/test_tracing.py b/trio/_tests/test_tracing.py similarity index 84% rename from trio/tests/test_tracing.py rename to trio/_tests/test_tracing.py index 07d1ff7609..e5110eaff3 100644 --- a/trio/tests/test_tracing.py +++ b/trio/_tests/test_tracing.py @@ -1,26 +1,26 @@ import trio -async def coro1(event: trio.Event): +async def coro1(event: trio.Event) -> None: event.set() await trio.sleep_forever() -async def coro2(event: trio.Event): +async def coro2(event: trio.Event) -> None: await coro1(event) -async def coro3(event: trio.Event): +async def coro3(event: trio.Event) -> None: await coro2(event) -async def coro2_async_gen(event: trio.Event): +async def coro2_async_gen(event): yield await trio.lowlevel.checkpoint() yield await coro1(event) yield await trio.lowlevel.checkpoint() -async def coro3_async_gen(event: trio.Event): +async def coro3_async_gen(event: trio.Event) -> None: async for x in coro2_async_gen(event): pass diff --git a/trio/tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py similarity index 96% rename from trio/tests/test_unix_pipes.py rename to trio/_tests/test_unix_pipes.py index cf98942ea4..0b0d2ceb23 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/_tests/test_unix_pipes.py @@ -1,17 +1,22 @@ +from __future__ import annotations + import errno -import select import os -import tempfile +import select import sys +from typing import TYPE_CHECKING import pytest -from .._core.tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken -from .. import _core, move_on_after -from ..testing import wait_all_tasks_blocked, check_one_way_stream +from .. 
import _core
+from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
+from ..testing import check_one_way_stream, wait_all_tasks_blocked
 
 posix = os.name == "posix"
 pytestmark = pytest.mark.skipif(not posix, reason="posix only")
+
+assert not TYPE_CHECKING or sys.platform == "unix"
+
 if posix:
 from .._unix_pipes import FdStream
 else:
@@ -20,7 +25,7 @@
 # Have to use quoted types so import doesn't crash on windows
-async def make_pipe() -> "Tuple[FdStream, FdStream]":
+async def make_pipe() -> "tuple[FdStream, FdStream]":
 """Makes a new pair of pipes."""
 (r, w) = os.pipe()
 return FdStream(w), FdStream(r)
diff --git a/trio/tests/test_util.py b/trio/_tests/test_util.py
similarity index 68%
rename from trio/tests/test_util.py
rename to trio/_tests/test_util.py
index 15ab09a80b..1ab6f825de 100644
--- a/trio/tests/test_util.py
+++ b/trio/_tests/test_util.py
@@ -1,22 +1,25 @@
 import signal
 import sys
+import types
 
 import pytest
 
 import trio
+
 from .. import _core
-from .._core.tests.tutil import (
- ignore_coroutine_never_awaited_warnings,
+from .._core._tests.tutil import (
 create_asyncio_future_in_new_loop,
+ ignore_coroutine_never_awaited_warnings,
 )
 from .._util import (
- signal_raise,
 ConflictDetector,
- is_main_thread,
- coroutine_or_error,
- generic_function,
 Final,
 NoPublicConstructor,
+ coroutine_or_error,
+ fixup_module_metadata,
+ generic_function,
+ is_main_thread,
+ signal_raise,
 )
 from ..testing import wait_all_tasks_blocked
 
@@ -191,3 +194,70 @@ class SubClass(SpecialClass):
 
 # Private constructor should not raise
 assert isinstance(SpecialClass._create(), SpecialClass)
+
+
+def test_fixup_module_metadata():
+ # Ignores modules not in the trio.X tree.
+ non_trio_module = types.ModuleType("not_trio")
+ non_trio_module.some_func = lambda: None
+ non_trio_module.some_func.__name__ = "some_func"
+ non_trio_module.some_func.__qualname__ = "some_func"
+
+ fixup_module_metadata(non_trio_module.__name__, vars(non_trio_module))
+
+ assert non_trio_module.some_func.__name__ == "some_func"
+ assert non_trio_module.some_func.__qualname__ == "some_func"
+
+ # Build up a fake module to test. Just use lambdas since all we care about is the names.
+ mod = types.ModuleType("trio._somemodule_impl")
+ mod.some_func = lambda: None
+ mod.some_func.__name__ = "_something_else"
+ mod.some_func.__qualname__ = "_something_else"
+
+ # No __module__ means it's unchanged.
+ mod.not_funclike = types.SimpleNamespace()
+ mod.not_funclike.__name__ = "not_funclike"
+
+ # Check __qualname__ being absent works.
+ mod.only_has_name = types.SimpleNamespace()
+ mod.only_has_name.__module__ = "trio._somemodule_impl"
+ mod.only_has_name.__name__ = "only_name"
+
+ # Underscored names are unchanged.
+ mod._private = lambda: None
+ mod._private.__module__ = "trio._somemodule_impl"
+ mod._private.__name__ = mod._private.__qualname__ = "_private"
+
+ # We recurse into classes.
+ mod.SomeClass = type(
+ "SomeClass",
+ (),
+ {
+ "__init__": lambda self: None,
+ "method": lambda self: None,
+ },
+ )
+ mod.SomeClass.recursion = mod.SomeClass # Reference loop is fine. 
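# For reference while reading the assertions that follow, a simplified
# sketch of the helper under test (the real fixup_module_metadata lives
# in trio._util and differs in details):
def sketch_fixup(module_name, namespace):
    seen = set()  # guards against reference loops like SomeClass.recursion

    def fix_one(qualname, name, obj):
        if id(obj) in seen:
            return
        seen.add(id(obj))
        mod = getattr(obj, "__module__", None)
        if mod is None or not mod.startswith("trio."):
            return  # leave non-trio objects alone
        obj.__module__ = module_name
        if hasattr(obj, "__name__"):
            obj.__name__ = name
            if hasattr(obj, "__qualname__"):  # only fixed up when present
                obj.__qualname__ = qualname
        if isinstance(obj, type):
            for attr, value in obj.__dict__.items():
                fix_one(f"{qualname}.{attr}", attr, value)

    for name, obj in namespace.items():
        if not name.startswith("_"):  # underscored names keep their metadata
            fix_one(name, name, obj)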
+ + fixup_module_metadata("trio.somemodule", vars(mod)) + assert mod.some_func.__name__ == "some_func" + assert mod.some_func.__module__ == "trio.somemodule" + assert mod.some_func.__qualname__ == "some_func" + + assert mod.not_funclike.__name__ == "not_funclike" + assert mod._private.__name__ == "_private" + assert mod._private.__module__ == "trio._somemodule_impl" + assert mod._private.__qualname__ == "_private" + + assert mod.only_has_name.__name__ == "only_has_name" + assert mod.only_has_name.__module__ == "trio.somemodule" + assert not hasattr(mod.only_has_name, "__qualname__") + + assert mod.SomeClass.method.__name__ == "method" + assert mod.SomeClass.method.__module__ == "trio.somemodule" + assert mod.SomeClass.method.__qualname__ == "SomeClass.method" + # Make coverage happy. + non_trio_module.some_func() + mod.some_func() + mod._private() + mod.SomeClass().method() diff --git a/trio/tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py similarity index 97% rename from trio/tests/test_wait_for_object.py rename to trio/_tests/test_wait_for_object.py index 38acfa802d..ea16684289 100644 --- a/trio/tests/test_wait_for_object.py +++ b/trio/_tests/test_wait_for_object.py @@ -6,17 +6,14 @@ # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") -from .._core.tests.tutil import slow import trio -from .. import _core -from .. import _timeouts + +from .. import _core, _timeouts +from .._core._tests.tutil import slow if on_windows: from .._core._windows_cffi import ffi, kernel32 - from .._wait_for_object import ( - WaitForSingleObject, - WaitForMultipleObjects_sync, - ) + from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject async def test_WaitForMultipleObjects_sync(): diff --git a/trio/tests/test_windows_pipes.py b/trio/_tests/test_windows_pipes.py similarity index 90% rename from trio/tests/test_windows_pipes.py rename to trio/_tests/test_windows_pipes.py index 2bcc64a072..5c4bae7d25 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/_tests/test_windows_pipes.py @@ -1,21 +1,16 @@ -import errno -import select - -import os import sys -from typing import Any -from typing import Tuple +from typing import Any, Tuple import pytest -from .._core.tests.tutil import gc_collect_harder -from .. import _core, move_on_after -from ..testing import wait_all_tasks_blocked, check_one_way_stream +from .. 
import _core +from ..testing import check_one_way_stream, wait_all_tasks_blocked if sys.platform == "win32": - from .._windows_pipes import PipeSendStream, PipeReceiveStream - from .._core._windows_cffi import _handle, kernel32 from asyncio.windows_utils import pipe + + from .._core._windows_cffi import _handle, kernel32 + from .._windows_pipes import PipeReceiveStream, PipeSendStream else: pytestmark = pytest.mark.skip(reason="windows only") pipe: Any = None diff --git a/trio/tests/tools/__init__.py b/trio/_tests/tools/__init__.py similarity index 100% rename from trio/tests/tools/__init__.py rename to trio/_tests/tools/__init__.py diff --git a/trio/tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py similarity index 52% rename from trio/tests/tools/test_gen_exports.py rename to trio/_tests/tools/test_gen_exports.py index 73eacc098a..e7d8ab94f2 100644 --- a/trio/tests/tools/test_gen_exports.py +++ b/trio/_tests/tools/test_gen_exports.py @@ -1,17 +1,18 @@ import ast -import astor -import pytest -import os import sys -from shutil import copyfile +import pytest + from trio._tools.gen_exports import ( - get_public_methods, + File, create_passthrough_args, + get_public_methods, process, + run_linters, ) SOURCE = '''from _run import _public +from somewhere import Thing class Test: @_public @@ -21,7 +22,7 @@ def public_func(self): @ignore_this @_public @another_decorator - async def public_async_func(self): + async def public_async_func(self) -> Thing: pass # no doc string def not_public(self): @@ -31,6 +32,21 @@ async def not_public_async(self): pass ''' +IMPORT_1 = """\ +from somewhere import Thing +""" + +IMPORT_2 = """\ +from somewhere import Thing +import os +""" + +IMPORT_3 = """\ +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from somewhere import Thing +""" + def test_get_public_methods(): methods = list(get_public_methods(ast.parse(SOURCE))) @@ -55,18 +71,46 @@ def test_create_pass_through_args(): assert create_passthrough_args(func_node) == expected -def test_process(tmp_path): +skip_lints = pytest.mark.skipif( + sys.implementation.name != "cpython", + reason="gen_exports is internal, black/isort only runs on CPython", +) + + +@skip_lints +@pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3]) +def test_process(tmp_path, imports): modpath = tmp_path / "_module.py" genpath = tmp_path / "_generated_module.py" modpath.write_text(SOURCE, encoding="utf-8") + file = File(modpath, "runner", platform="linux", imports=imports) assert not genpath.exists() with pytest.raises(SystemExit) as excinfo: - process([(str(modpath), "runner")], do_test=True) + process([file], do_test=True) assert excinfo.value.code == 1 - process([(str(modpath), "runner")], do_test=False) + process([file], do_test=False) assert genpath.exists() - process([(str(modpath), "runner")], do_test=True) + process([file], do_test=True) # But if we change the lookup path it notices with pytest.raises(SystemExit) as excinfo: - process([(str(modpath), "runner.io_manager")], do_test=True) + process( + [File(modpath, "runner.io_manager", platform="linux", imports=imports)], + do_test=True, + ) + assert excinfo.value.code == 1 + # Also if the platform is changed. 
+ with pytest.raises(SystemExit) as excinfo: + process([File(modpath, "runner", imports=imports)], do_test=True) assert excinfo.value.code == 1 + + +@skip_lints +def test_lint_failure(tmp_path) -> None: + """Test that processing properly fails if black or isort does.""" + file = File(tmp_path / "module.py", "module") + + with pytest.raises(SystemExit): + run_linters(file, "class not valid code ><") + + with pytest.raises(SystemExit): + run_linters(file, "# isort: skip_file") diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json new file mode 100644 index 0000000000..667f4fe680 --- /dev/null +++ b/trio/_tests/verify_types_darwin.json @@ -0,0 +1,84 @@ +{ + "generalDiagnostics": [], + "summary": { + "errorCount": 0, + "filesAnalyzed": 8, + "informationCount": 0, + "warningCount": 0 + }, + "typeCompleteness": { + "completenessScore": 1, + "diagnostics": [ + { + "message": "No docstring found for function \"trio.lowlevel.current_kqueue\"", + "name": "trio.lowlevel.current_kqueue" + }, + { + "message": "No docstring found for function \"trio.lowlevel.monitor_kevent\"", + "name": "trio.lowlevel.monitor_kevent" + }, + { + "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", + "name": "trio.lowlevel.notify_closing" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_kevent\"", + "name": "trio.lowlevel.wait_kevent" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", + "name": "trio.lowlevel.wait_readable" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", + "name": "trio.lowlevel.wait_writable" + }, + { + "message": "No docstring found for class \"trio.tests.TestsDeprecationWrapper\"", + "name": "trio.tests.TestsDeprecationWrapper" + } + ], + "exportedSymbolCounts": { + "withAmbiguousType": 0, + "withKnownType": 632, + "withUnknownType": 0 + }, + "ignoreUnknownTypesFromImports": true, + "missingClassDocStringCount": 1, + "missingDefaultParamCount": 0, + "missingFunctionDocStringCount": 6, + "moduleName": "trio", + "modules": [ + { + "name": "trio" + }, + { + "name": "trio.abc" + }, + { + "name": "trio.from_thread" + }, + { + "name": "trio.lowlevel" + }, + { + "name": "trio.socket" + }, + { + "name": "trio.testing" + }, + { + "name": "trio.tests" + }, + { + "name": "trio.to_thread" + } + ], + "otherSymbolCounts": { + "withAmbiguousType": 0, + "withKnownType": 680, + "withUnknownType": 0 + }, + "packageName": "trio" + } +} diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json new file mode 100644 index 0000000000..02ce7516eb --- /dev/null +++ b/trio/_tests/verify_types_linux.json @@ -0,0 +1,72 @@ +{ + "generalDiagnostics": [], + "summary": { + "errorCount": 0, + "filesAnalyzed": 8, + "informationCount": 0, + "warningCount": 0 + }, + "typeCompleteness": { + "completenessScore": 1, + "diagnostics": [ + { + "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", + "name": "trio.lowlevel.notify_closing" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", + "name": "trio.lowlevel.wait_readable" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", + "name": "trio.lowlevel.wait_writable" + }, + { + "message": "No docstring found for class \"trio.tests.TestsDeprecationWrapper\"", + "name": "trio.tests.TestsDeprecationWrapper" + } + ], + "exportedSymbolCounts": { + "withAmbiguousType": 0, + "withKnownType": 629, + 
"withUnknownType": 0 + }, + "ignoreUnknownTypesFromImports": true, + "missingClassDocStringCount": 1, + "missingDefaultParamCount": 0, + "missingFunctionDocStringCount": 3, + "moduleName": "trio", + "modules": [ + { + "name": "trio" + }, + { + "name": "trio.abc" + }, + { + "name": "trio.from_thread" + }, + { + "name": "trio.lowlevel" + }, + { + "name": "trio.socket" + }, + { + "name": "trio.testing" + }, + { + "name": "trio.tests" + }, + { + "name": "trio.to_thread" + } + ], + "otherSymbolCounts": { + "withAmbiguousType": 0, + "withKnownType": 680, + "withUnknownType": 0 + }, + "packageName": "trio" + } +} diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json new file mode 100644 index 0000000000..90b4324578 --- /dev/null +++ b/trio/_tests/verify_types_windows.json @@ -0,0 +1,188 @@ +{ + "generalDiagnostics": [], + "summary": { + "errorCount": 0, + "filesAnalyzed": 8, + "informationCount": 0, + "warningCount": 0 + }, + "typeCompleteness": { + "completenessScore": 0.9857594936708861, + "diagnostics": [ + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.current_iocp" + }, + { + "message": "No docstring found for function \"trio.lowlevel.current_iocp\"", + "name": "trio.lowlevel.current_iocp" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.monitor_completion_key" + }, + { + "message": "No docstring found for function \"trio.lowlevel.monitor_completion_key\"", + "name": "trio.lowlevel.monitor_completion_key" + }, + { + "message": "Type annotation for parameter \"handle\" is missing", + "name": "trio.lowlevel.notify_closing" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.notify_closing" + }, + { + "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", + "name": "trio.lowlevel.notify_closing" + }, + { + "message": "No docstring found for function \"trio.lowlevel.open_process\"", + "name": "trio.lowlevel.open_process" + }, + { + "message": "Type annotation for parameter \"handle\" is missing", + "name": "trio.lowlevel.readinto_overlapped" + }, + { + "message": "Type annotation for parameter \"buffer\" is missing", + "name": "trio.lowlevel.readinto_overlapped" + }, + { + "message": "Type annotation for parameter \"file_offset\" is missing", + "name": "trio.lowlevel.readinto_overlapped" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.readinto_overlapped" + }, + { + "message": "No docstring found for function \"trio.lowlevel.readinto_overlapped\"", + "name": "trio.lowlevel.readinto_overlapped" + }, + { + "message": "Type annotation for parameter \"handle\" is missing", + "name": "trio.lowlevel.register_with_iocp" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.register_with_iocp" + }, + { + "message": "No docstring found for function \"trio.lowlevel.register_with_iocp\"", + "name": "trio.lowlevel.register_with_iocp" + }, + { + "message": "Type annotation for parameter \"handle\" is missing", + "name": "trio.lowlevel.wait_overlapped" + }, + { + "message": "Type annotation for parameter \"lpOverlapped\" is missing", + "name": "trio.lowlevel.wait_overlapped" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.wait_overlapped" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_overlapped\"", + "name": "trio.lowlevel.wait_overlapped" + }, + { + "message": "Type annotation for parameter \"sock\" is missing", + "name": 
"trio.lowlevel.wait_readable" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.wait_readable" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", + "name": "trio.lowlevel.wait_readable" + }, + { + "message": "Type annotation for parameter \"sock\" is missing", + "name": "trio.lowlevel.wait_writable" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.wait_writable" + }, + { + "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", + "name": "trio.lowlevel.wait_writable" + }, + { + "message": "Type annotation for parameter \"handle\" is missing", + "name": "trio.lowlevel.write_overlapped" + }, + { + "message": "Type annotation for parameter \"data\" is missing", + "name": "trio.lowlevel.write_overlapped" + }, + { + "message": "Type annotation for parameter \"file_offset\" is missing", + "name": "trio.lowlevel.write_overlapped" + }, + { + "message": "Return type annotation is missing", + "name": "trio.lowlevel.write_overlapped" + }, + { + "message": "No docstring found for function \"trio.lowlevel.write_overlapped\"", + "name": "trio.lowlevel.write_overlapped" + }, + { + "message": "No docstring found for function \"trio.run_process\"", + "name": "trio.run_process" + }, + { + "message": "No docstring found for class \"trio.tests.TestsDeprecationWrapper\"", + "name": "trio.tests.TestsDeprecationWrapper" + } + ], + "exportedSymbolCounts": { + "withAmbiguousType": 0, + "withKnownType": 623, + "withUnknownType": 9 + }, + "ignoreUnknownTypesFromImports": true, + "missingClassDocStringCount": 1, + "missingDefaultParamCount": 0, + "missingFunctionDocStringCount": 11, + "moduleName": "trio", + "modules": [ + { + "name": "trio" + }, + { + "name": "trio.abc" + }, + { + "name": "trio.from_thread" + }, + { + "name": "trio.lowlevel" + }, + { + "name": "trio.socket" + }, + { + "name": "trio.testing" + }, + { + "name": "trio.tests" + }, + { + "name": "trio.to_thread" + } + ], + "otherSymbolCounts": { + "withAmbiguousType": 0, + "withKnownType": 671, + "withUnknownType": 0 + }, + "packageName": "trio" + } +} diff --git a/trio/_threads.py b/trio/_threads.py index cd76f1907a..9f88df443f 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -1,16 +1,20 @@ +from __future__ import annotations + import contextvars import functools import inspect import queue as stdlib_queue import threading +from collections.abc import Awaitable, Callable from itertools import count -from typing import Optional +from typing import Generic, TypeVar import attr import outcome from sniffio import current_async_library_cvar import trio +from trio._core._traps import RaiseCancelT from ._core import ( RunVar, @@ -22,17 +26,27 @@ from ._sync import CapacityLimiter from ._util import coroutine_or_error -# Global due to Threading API, thread local storage for trio token and raise_cancel -THREAD_LOCAL = threading.local() +RetT = TypeVar("RetT") + + +class _TokenLocal(threading.local): + """Global due to Threading API, thread local storage for trio token.""" + + token: TrioToken + cancel_register: list[RaiseCancelT | None] + task_register: list[trio.lowlevel.Task | None] + -_limiter_local = RunVar("limiter") +TOKEN_LOCAL = _TokenLocal() + +_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we # should make some kind of measurements to pick a good value. 
DEFAULT_LIMIT = 40 _thread_counter = count() -def current_default_thread_limiter(): +def current_default_thread_limiter() -> CapacityLimiter: """Get the default `~trio.CapacityLimiter` used by `trio.to_thread.run_sync`. @@ -54,28 +68,34 @@ def current_default_thread_limiter(): # keep track of who's holding the CapacityLimiter's token. @attr.s(frozen=True, eq=False, hash=False) class ThreadPlaceholder: - name = attr.ib() + name: str = attr.ib() # Types for the to_thread_run_sync message loop +class ThreadMessage(Generic[RetT]): + pass + + @attr.s(frozen=True, eq=False) -class ThreadDone: - result = attr.ib() +class ThreadDone(ThreadMessage[RetT]): + result: outcome.Outcome[RetT] = attr.ib() @attr.s(frozen=True, eq=False) -class Run: - afn = attr.ib() - args = attr.ib() - context = attr.ib() - queue = attr.ib(init=False, factory=stdlib_queue.SimpleQueue) +class Run(ThreadMessage[RetT]): + afn: Callable[..., Awaitable[RetT]] = attr.ib() + args: tuple[object, ...] = attr.ib() + context: contextvars.Context = attr.ib() + queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( + init=False, factory=stdlib_queue.SimpleQueue + ) @disable_ki_protection - async def unprotected_afn(self): + async def unprotected_afn(self) -> RetT: coro = coroutine_or_error(self.afn, *self.args) return await coro - async def run(self): + async def run(self) -> None: task = trio.lowlevel.current_task() old_context = task.context task.context = self.context.copy() @@ -88,20 +108,22 @@ async def run(self): task.context = old_context await trio.lowlevel.cancel_shielded_checkpoint() - async def run_system(self): + async def run_system(self) -> None: result = await outcome.acapture(self.unprotected_afn) self.queue.put_nowait(result) @attr.s(frozen=True, eq=False) -class RunSync: - fn = attr.ib() - args = attr.ib() - context = attr.ib() - queue = attr.ib(init=False, factory=stdlib_queue.SimpleQueue) +class RunSync(ThreadMessage[RetT]): + fn: Callable[..., RetT] = attr.ib() + args: tuple[object, ...] = attr.ib() + context: contextvars.Context = attr.ib() + queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( + init=False, factory=stdlib_queue.SimpleQueue + ) @disable_ki_protection - def unprotected_fn(self): + def unprotected_fn(self) -> RetT: ret = self.fn(*self.args) if inspect.iscoroutine(ret): @@ -114,16 +136,20 @@ def unprotected_fn(self): return ret - def run_sync(self): + def run_sync(self) -> None: self.context.run(current_async_library_cvar.set, "trio") result = outcome.capture(self.context.run, self.unprotected_fn) self.queue.put_nowait(result) -@enable_ki_protection -async def to_thread_run_sync( - sync_fn, *args, thread_name: Optional[str] = None, cancellable=False, limiter=None -): +@enable_ki_protection # Decorator used on function with Coroutine[Any, Any, RetT] +async def to_thread_run_sync( # type: ignore[misc] + sync_fn: Callable[..., RetT], + *args: object, + thread_name: str | None = None, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> RetT: """Convert a blocking operation into an async operation using a thread. These two lines are equivalent:: @@ -215,23 +241,23 @@ async def to_thread_run_sync( # Holds a reference to the task that's blocked in this function waiting # for the result – or None if this function was cancelled and we should # discard the result. 
-    task_register = [trio.lowlevel.current_task()]
+    task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()]

     # Holds a reference to the raise_cancel function provided if a cancellation
     # is attempted against this task - or None if no such delivery has happened.
-    cancel_register = [None]
+    cancel_register: list[RaiseCancelT | None] = [None]  # type: ignore[assignment]

     name = f"trio.to_thread.run_sync-{next(_thread_counter)}"
     placeholder = ThreadPlaceholder(name)

     # This function gets scheduled into the Trio run loop to deliver the
     # thread's result.
-    def report_back_in_trio_thread_fn(result):
-        def do_release_then_return_result():
+    def report_back_in_trio_thread_fn(result: outcome.Outcome[RetT]) -> None:
+        def do_release_then_return_result() -> RetT:
             # release_on_behalf_of is an arbitrary user-defined method, so it
             # might raise an error. If it does, we want that error to
             # replace the regular return value, and if the regular return was
             # already an exception then we want them to chain.
             try:
-                return result.unwrap()
+                return result.unwrap()  # type: ignore[no-any-return]  # Until outcome is typed
             finally:
                 limiter.release_on_behalf_of(placeholder)
@@ -246,11 +272,16 @@ def do_release_then_return_result():
     if thread_name is None:
         thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}"

-    def worker_fn():
+    def worker_fn() -> RetT:
+        # Trio doesn't use current_async_library_cvar, but if someone
+        # else set it, it would now shine through since
+        # sniffio.thread_local isn't set in the new thread. Make sure
+        # the new thread sees that it's not running in async context.
         current_async_library_cvar.set(None)
-        THREAD_LOCAL.token = current_trio_token
-        THREAD_LOCAL.cancel_register = cancel_register
-        THREAD_LOCAL.task_register = task_register
+
+        TOKEN_LOCAL.token = current_trio_token
+        TOKEN_LOCAL.cancel_register = cancel_register
+        TOKEN_LOCAL.task_register = task_register
         try:
             ret = sync_fn(*args)
@@ -264,14 +295,15 @@ def worker_fn():
             return ret
         finally:
-            del THREAD_LOCAL.token
-            del THREAD_LOCAL.cancel_register
-            del THREAD_LOCAL.task_register
+            del TOKEN_LOCAL.token
+            del TOKEN_LOCAL.cancel_register
+            del TOKEN_LOCAL.task_register

     context = contextvars.copy_context()
-    contextvars_aware_worker_fn = functools.partial(context.run, worker_fn)
+    # Partial confuses type checkers, coerce to a callable.
+ contextvars_aware_worker_fn: Callable[[], RetT] = functools.partial(context.run, worker_fn) # type: ignore[assignment] - def deliver_worker_fn_result(result): + def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: try: current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result) except trio.RunFinishedError: @@ -289,7 +321,7 @@ def deliver_worker_fn_result(result): limiter.release_on_behalf_of(placeholder) raise - def abort(raise_cancel): + def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: # fill so from_thread_check_cancelled can raise cancel_register[0] = raise_cancel if cancellable: @@ -300,12 +332,15 @@ def abort(raise_cancel): return trio.lowlevel.Abort.FAILED while True: - msg_from_thread = await trio.lowlevel.wait_task_rescheduled(abort) - if isinstance(msg_from_thread, ThreadDone): - return msg_from_thread.result.unwrap() - elif isinstance(msg_from_thread, Run): + # wait_task_rescheduled return value cannot be typed + msg_from_thread: ThreadMessage[ + RetT + ] = await trio.lowlevel.wait_task_rescheduled(abort) + if type(msg_from_thread) is ThreadDone: + return msg_from_thread.result.unwrap() # type: ignore[no-any-return] + elif type(msg_from_thread) is Run: await msg_from_thread.run() - elif isinstance(msg_from_thread, RunSync): + elif type(msg_from_thread) is RunSync: msg_from_thread.run_sync() else: # pragma: no cover, internal debugging guard raise TypeError( @@ -315,7 +350,7 @@ def abort(raise_cancel): del msg_from_thread -def from_thread_check_cancelled(): +def from_thread_check_cancelled() -> None: """Raise trio.Cancelled if the associated Trio task entered a cancelled status. Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow @@ -328,24 +363,24 @@ def from_thread_check_cancelled(): ``cancellable`` supplied as an argument to it. AttributeError: If this thread is not spawned from `trio.to_thread.run_sync`. """ - raise_cancel = THREAD_LOCAL.cancel_register[0] + raise_cancel = TOKEN_LOCAL.cancel_register[0] if raise_cancel is not None: raise_cancel() -def _check_token(trio_token): +def _check_token(trio_token: TrioToken | None) -> TrioToken: """Raise a RuntimeError if this function is called within a trio run. Avoids deadlock by making sure we're not called from inside a context that we might be waiting for and blocking it. """ - if trio_token and not isinstance(trio_token, TrioToken): + if trio_token is not None and not isinstance(trio_token, TrioToken): raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") - if not trio_token: + if trio_token is None: try: - trio_token = THREAD_LOCAL.token + trio_token = TOKEN_LOCAL.token except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, pass kwarg trio_token=..." 
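
[Editorial aside: the Run/RunSync message loop typed above is what lets code in the worker thread call back into the parent Trio run. A minimal sketch of that round-trip using only public trio APIs; the names fetch_timestamp and main are illustrative and not part of this diff.]

import trio

def fetch_timestamp() -> float:
    # Runs in the worker thread. from_thread.run_sync posts a RunSync
    # message to the host task, which executes it on the Trio side and
    # sends the result back through the message's queue.
    return trio.from_thread.run_sync(trio.current_time)

async def main() -> None:
    # to_thread.run_sync ships fetch_timestamp off to a worker thread and
    # waits in the message loop until a ThreadDone result arrives.
    now = await trio.to_thread.run_sync(fetch_timestamp)
    print("Trio clock, read from the worker thread:", now)

trio.run(main)
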
@@ -362,11 +397,13 @@ def _check_token(trio_token):
     return trio_token


-def _send_message_to_host_task(message, trio_token):
-    task_register = THREAD_LOCAL.task_register
-    cancel_register = THREAD_LOCAL.cancel_register
+def _send_message_to_host_task(
+    message: Run[RetT] | RunSync[RetT], trio_token: TrioToken
+) -> RetT:
+    task_register = TOKEN_LOCAL.task_register
+    cancel_register = TOKEN_LOCAL.cancel_register

-    def in_trio_thread():
+    def in_trio_thread() -> None:
         task = task_register[0]
         if task is None:
             raise_cancel = cancel_register[0]
@@ -375,15 +412,17 @@ def in_trio_thread():
             trio.lowlevel.reschedule(task, outcome.Value(message))

     trio_token.run_sync_soon(in_trio_thread)
-    return message.queue.get().unwrap()
+    return message.queue.get().unwrap()  # type: ignore[no-any-return]


-def _send_message_to_system_task(message, trio_token):
-    if isinstance(message, RunSync):
+def _send_message_to_system_task(
+    message: Run[RetT] | RunSync[RetT], trio_token: TrioToken
+) -> RetT:
+    if type(message) is RunSync:
         run_sync = message.run_sync
-    elif isinstance(message, Run):
+    elif type(message) is Run:

-        def run_sync():
+        def run_sync() -> None:
             try:
                 trio.lowlevel.spawn_system_task(
                     message.run_system, name=message.afn, context=message.context
@@ -400,10 +439,14 @@ def run_sync():
         )

     trio_token.run_sync_soon(run_sync)
-    return message.queue.get().unwrap()
+    return message.queue.get().unwrap()  # type: ignore[no-any-return]


-def from_thread_run(afn, *args, trio_token=None):
+def from_thread_run(
+    afn: Callable[..., Awaitable[RetT]],
+    *args: object,
+    trio_token: TrioToken | None = None,
+) -> RetT:
     """Run the given async function in the parent Trio thread, blocking until
     it is complete.

@@ -448,7 +491,11 @@ def from_thread_run(afn, *args, trio_token=None):
     return _send_message_to_host_task(message_to_trio, checked_token)


-def from_thread_run_sync(fn, *args, trio_token=None):
+def from_thread_run_sync(
+    fn: Callable[..., RetT],
+    *args: object,
+    trio_token: TrioToken | None = None,
+) -> RetT:
     """Run the given sync function in the parent Trio thread, blocking until
     it is complete.

diff --git a/trio/_timeouts.py b/trio/_timeouts.py
index 1f7878f89e..1d03b2f2e3 100644
--- a/trio/_timeouts.py
+++ b/trio/_timeouts.py
@@ -1,20 +1,29 @@
-from contextlib import contextmanager
+from __future__ import annotations
+
+import math
+from contextlib import AbstractContextManager, contextmanager
+from typing import TYPE_CHECKING

 import trio


-def move_on_at(deadline):
+def move_on_at(deadline: float) -> trio.CancelScope:
     """Use as a context manager to create a cancel scope with the given
     absolute deadline.

     Args:
       deadline (float): The deadline.

+    Raises:
+      ValueError: if deadline is NaN.
+
     """
+    if math.isnan(deadline):
+        raise ValueError("deadline must not be NaN")
     return trio.CancelScope(deadline=deadline)


-def move_on_after(seconds):
+def move_on_after(seconds: float) -> trio.CancelScope:
     """Use as a context manager to create a cancel scope whose deadline is
     set to now + *seconds*.

@@ -22,16 +31,15 @@ def move_on_after(seconds):
       seconds (float): The timeout.

     Raises:
-      ValueError: if timeout is less than zero.
+      ValueError: if timeout is less than zero or NaN.

     """
-
     if seconds < 0:
         raise ValueError("timeout must be non-negative")
     return move_on_at(trio.current_time() + seconds)


-async def sleep_forever():
+async def sleep_forever() -> None:
     """Pause execution of the current task forever (or until cancelled).

     Equivalent to calling ``await sleep(math.inf)``.
@@ -40,7 +48,7 @@ async def sleep_forever(): await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) -async def sleep_until(deadline): +async def sleep_until(deadline: float) -> None: """Pause execution of the current task until the given time. The difference between :func:`sleep` and :func:`sleep_until` is that the @@ -52,12 +60,15 @@ async def sleep_until(deadline): the past, in which case this function executes a checkpoint but does not block. + Raises: + ValueError: if deadline is NaN. + """ with move_on_at(deadline): await sleep_forever() -async def sleep(seconds): +async def sleep(seconds: float) -> None: """Pause execution of the current task for the given number of seconds. Args: @@ -65,7 +76,7 @@ async def sleep(seconds): insert a checkpoint without actually blocking. Raises: - ValueError: if *seconds* is negative. + ValueError: if *seconds* is negative or NaN. """ if seconds < 0: @@ -83,8 +94,9 @@ class TooSlowError(Exception): """ -@contextmanager -def fail_at(deadline): +# workaround for PyCharm not being able to infer return type from @contextmanager +# see https://youtrack.jetbrains.com/issue/PY-36444/PyCharm-doesnt-infer-types-when-using-contextlib.contextmanager-decorator +def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc] """Creates a cancel scope with the given deadline, and raises an error if it is actually cancelled. @@ -96,19 +108,26 @@ def fail_at(deadline): :func:`fail_at`, then it's caught and :exc:`TooSlowError` is raised in its place. + Args: + deadline (float): The deadline. + Raises: TooSlowError: if a :exc:`Cancelled` exception is raised in this scope and caught by the context manager. + ValueError: if deadline is NaN. """ - with move_on_at(deadline) as scope: yield scope if scope.cancelled_caught: raise TooSlowError -def fail_after(seconds): +if not TYPE_CHECKING: + fail_at = contextmanager(fail_at) + + +def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]: """Creates a cancel scope with the given timeout, and raises an error if it is actually cancelled. @@ -119,10 +138,13 @@ def fail_after(seconds): it's caught and discarded. When it reaches :func:`fail_after`, then it's caught and :exc:`TooSlowError` is raised in its place. + Args: + seconds (float): The timeout. + Raises: TooSlowError: if a :exc:`Cancelled` exception is raised in this scope and caught by the context manager. - ValueError: if *seconds* is less than zero. + ValueError: if *seconds* is less than zero or NaN. """ if seconds < 0: diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 3c18a86298..6549d473ab 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -3,28 +3,36 @@ Code generation script for class methods to be exported as public API """ +from __future__ import annotations + import argparse import ast -import astor import os -from pathlib import Path +import subprocess import sys - +import traceback +from collections.abc import Iterable, Iterator +from pathlib import Path from textwrap import indent +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + +import astor +import attr +import isort.api +import isort.exceptions PREFIX = "_generated" HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! 
ALL EDITS WILL BE LOST ******
# *************************************************************
-from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
-from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._instrumentation import Instrument
+from __future__ import annotations

-# fmt: off
-"""
-
-FOOTER = """# fmt: on
+from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
+from ._run import GLOBAL_RUN_CONTEXT
 """

 TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
@@ -35,7 +43,15 @@
 """


-def is_function(node):
+@attr.define
+class File:
+    path: Path
+    modname: str
+    platform: str = attr.field(default="", kw_only=True)
+    imports: str = attr.field(default="", kw_only=True)
+
+
+def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
     """Check if the AST node is either a function
     or an async function
     """
@@ -44,17 +60,18 @@ def is_function(node):
     return False


-def is_public(node):
+def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
     """Check if the AST node has a _public decorator"""
-    if not is_function(node):
-        return False
-    for decorator in node.decorator_list:
-        if isinstance(decorator, ast.Name) and decorator.id == "_public":
-            return True
+    if is_function(node):
+        for decorator in node.decorator_list:
+            if isinstance(decorator, ast.Name) and decorator.id == "_public":
+                return True
     return False


-def get_public_methods(tree):
+def get_public_methods(
+    tree: ast.AST,
+) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]:
     """Return a list of methods marked as public.
     The function walks the given tree and extracts
     all objects that are functions which are marked
@@ -65,7 +82,7 @@ def get_public_methods(tree):
             yield node


-def create_passthrough_args(funcdef):
+def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
     """Given a function definition, create a string that represents taking all
     the arguments from the function, and passing them through to another
     invocation of the same function.
@@ -85,18 +102,75 @@ def create_passthrough_args(funcdef):
     return "({})".format(", ".join(call_args))


-def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
+def run_linters(file: File, source: str) -> str:
+    """Run isort and black on the specified file, returning the new source.
+
+    :raises SystemExit: if either black or isort failed.
+    """
+    # Black has an undocumented API, but it doesn't easily allow reading configuration from
+    # pyproject.toml while simultaneously passing in / receiving the code as a string.
+    # https://github.com/psf/black/issues/779
+    try:
+        result = subprocess.run(
+            # "-" as a filename = use stdin, return on stdout.
+            [sys.executable, "-m", "black", "--stdin-filename", file.path, "-"],
+            input=source,
+            capture_output=True,
+            encoding="utf8",
+            check=True,
+        )
+    except subprocess.CalledProcessError as exc:
+        print("Failed to run black!")
+        traceback.print_exception(type(exc), exc, exc.__traceback__)
+        sys.exit(1)
+    # isort does have a public API, which makes things easy.
+    try:
+        isort_res = isort.api.sort_code_string(
+            result.stdout,
+            file_path=file.path,
+        )
+    except isort.exceptions.ISortError as exc:
+        print("Failed to run isort!")
+        traceback.print_exception(type(exc), exc, exc.__traceback__)
+        sys.exit(1)
+    return isort_res


+def gen_public_wrappers_source(file: File) -> str:
     """Scan the given .py file for @_public decorators, and generate wrapper
     functions.
""" - generated = [HEADER] - source = astor.code_to_ast.parse_file(source_path) + header = [HEADER] + + if file.imports: + header.append(file.imports) + if file.platform: + # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will + # just give errors. + if "TYPE_CHECKING" not in file.imports: + header.append("from typing import TYPE_CHECKING\n") + if "import sys" not in file.imports: # pragma: no cover + header.append("import sys\n") + header.append( + f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n' + ) + + generated = ["".join(header)] + + source = astor.code_to_ast.parse_file(file.path) for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] + for dec in method.decorator_list: # pragma: no cover + if isinstance(dec, ast.Name) and dec.id == "contextmanager": + is_cm = True + break + else: + is_cm = False + # Remove decorators method.decorator_list = [] @@ -113,10 +187,13 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: # Create the function definition including the body func = astor.to_source(method, indent_with=" " * 4) + if is_cm: # pragma: no cover + func = func.replace("->Iterator", "->ContextManager") + # Create export function body template = TEMPLATE.format( " await " if isinstance(method, ast.AsyncFunctionDef) else " ", - lookup_path, + file.modname, method.name + new_args, ) @@ -125,11 +202,10 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: # Append the snippet to the corresponding module generated.append(snippet) - generated.append(FOOTER) return "\n\n".join(generated) -def matches_disk_files(new_files): +def matches_disk_files(new_files: dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False @@ -140,12 +216,13 @@ def matches_disk_files(new_files): return True -def process(sources_and_lookups, *, do_test): +def process(files: Iterable[File], *, do_test: bool) -> None: new_files = {} - for source_path, lookup_path in sources_and_lookups: - print("Scanning:", source_path) - new_source = gen_public_wrappers_source(source_path, lookup_path) - dirname, basename = os.path.split(source_path) + for file in files: + print("Scanning:", file.path) + new_source = gen_public_wrappers_source(file) + new_source = run_linters(file, new_source) + dirname, basename = os.path.split(file.path) new_path = os.path.join(dirname, PREFIX + basename) new_files[new_path] = new_source if do_test: @@ -163,7 +240,7 @@ def process(sources_and_lookups, *, do_test): # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. 
-def main(): # pragma: no cover +def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers" ) @@ -177,15 +254,62 @@ def main(): # pragma: no cover assert (source_root / "LICENSE").exists() core = source_root / "trio/_core" to_wrap = [ - (core / "_run.py", "runner"), - (core / "_instrumentation.py", "runner.instruments"), - (core / "_io_windows.py", "runner.io_manager"), - (core / "_io_epoll.py", "runner.io_manager"), - (core / "_io_kqueue.py", "runner.io_manager"), + File(core / "_run.py", "runner", imports=IMPORTS_RUN), + File( + core / "_instrumentation.py", + "runner.instruments", + imports=IMPORTS_INSTRUMENT, + ), + File(core / "_io_windows.py", "runner.io_manager", platform="win32"), + File( + core / "_io_epoll.py", + "runner.io_manager", + platform="linux", + imports=IMPORTS_EPOLL, + ), + File( + core / "_io_kqueue.py", + "runner.io_manager", + platform="darwin", + imports=IMPORTS_KQUEUE, + ), ] process(to_wrap, do_test=parsed_args.test) +IMPORTS_RUN = """\ +from collections.abc import Awaitable, Callable +from typing import Any + +from outcome import Outcome +import contextvars + +from ._run import _NO_SEND, RunStatistics, Task +from ._entry_queue import TrioToken +from .._abc import Clock +""" +IMPORTS_INSTRUMENT = """\ +from ._instrumentation import Instrument +""" + +IMPORTS_EPOLL = """\ +from socket import socket +""" + +IMPORTS_KQUEUE = """\ +from typing import Callable, ContextManager, TYPE_CHECKING + +if TYPE_CHECKING: + import select + from socket import socket + + from ._traps import Abort, RaiseCancelT + + from .. import _core + +""" + + if __name__ == "__main__": # pragma: no cover main() diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index fa98e79521..1a389e12dd 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -1,16 +1,19 @@ from __future__ import annotations -import os import errno +import os +import sys from typing import TYPE_CHECKING +import trio + from ._abc import Stream from ._util import ConflictDetector, Final -import trio - if TYPE_CHECKING: - from typing_extensions import Final as FinalType + from typing import Final as FinalType + +assert not TYPE_CHECKING or sys.platform != "win32" if os.name != "posix": # We raise an error here rather than gating the import in lowlevel.py diff --git a/trio/_util.py b/trio/_util.py index b60e0104e8..73fc024831 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -1,18 +1,38 @@ # Little utilities we use internally +from __future__ import annotations -from abc import ABCMeta +import collections.abc +import inspect import os import signal -from functools import update_wrapper -import typing as t import threading -import collections -import inspect +import typing as t +from abc import ABCMeta +from functools import update_wrapper +from types import AsyncGeneratorType, TracebackType + +from sniffio import thread_local as sniffio_loop import trio +CallT = t.TypeVar("CallT", bound=t.Callable[..., t.Any]) +T = t.TypeVar("T") +RetT = t.TypeVar("RetT") + +if t.TYPE_CHECKING: + from typing_extensions import ParamSpec, Self + + ArgsT = ParamSpec("ArgsT") + + +if t.TYPE_CHECKING: + # Don't type check the implementation below, pthread_kill does not exist on Windows. + def signal_raise(signum: int) -> None: + ... + + # Equivalent to the C function raise(), which Python doesn't wrap -if os.name == "nt": +elif os.name == "nt": # On Windows, os.kill exists but is really weird. 
# # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver @@ -56,7 +76,7 @@ signal_raise = getattr(_lib, "raise") else: - def signal_raise(signum): + def signal_raise(signum: int) -> None: signal.pthread_kill(threading.get_ident(), signum) @@ -68,7 +88,7 @@ def signal_raise(signum): # Trying to use signal out of the main thread will fail, so we can then # reliably check if this is the main thread without relying on a # potentially modified threading. -def is_main_thread(): +def is_main_thread() -> bool: """Attempt to reliably check if we are in the main thread.""" try: signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT)) @@ -81,8 +101,11 @@ def is_main_thread(): # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. Returns coroutine object. ###### -def coroutine_or_error(async_fn, *args): - def _return_value_looks_like_wrong_library(value): +# TODO: Use TypeVarTuple here. +def coroutine_or_error( + async_fn: t.Callable[..., t.Awaitable[RetT]], *args: t.Any +) -> collections.abc.Coroutine[object, t.NoReturn, RetT]: + def _return_value_looks_like_wrong_library(value: object) -> bool: # Returned by legacy @asyncio.coroutine functions, which includes # a surprising proportion of asyncio builtins. if isinstance(value, collections.abc.Generator): @@ -97,6 +120,10 @@ def _return_value_looks_like_wrong_library(value): return True return False + # Make sure a sync-fn-that-returns-coroutine still sees itself as being + # in trio context + prev_loop, sniffio_loop.name = sniffio_loop.name, "trio" + try: coro = async_fn(*args) @@ -134,11 +161,15 @@ def _return_value_looks_like_wrong_library(value): raise + finally: + sniffio_loop.name = prev_loop + # We can't check iscoroutinefunction(async_fn), because that will fail # for things like functools.partial objects wrapping an async # function. So we have to just call it and then check whether the # return value is a coroutine object. # Note: will not be necessary on python>=3.8, see https://bugs.python.org/issue34890 + # TODO: python3.7 support is now dropped, so the above can be addressed. 
if not isinstance(coro, collections.abc.Coroutine): # Give good error for: nursery.start_soon(func_returning_future) if _return_value_looks_like_wrong_library(coro): @@ -177,24 +208,33 @@ class ConflictDetector: """ - def __init__(self, msg): + def __init__(self, msg: str) -> None: self._msg = msg self._held = False - def __enter__(self): + def __enter__(self) -> None: if self._held: raise trio.BusyResourceError(self._msg) else: self._held = True - def __exit__(self, *args): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self._held = False -def async_wraps(cls, wrapped_cls, attr_name): +def async_wraps( + cls: type[object], + wrapped_cls: type[object], + attr_name: str, +) -> t.Callable[[CallT], CallT]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func): + def decorator(func: CallT) -> CallT: func.__name__ = attr_name func.__qualname__ = ".".join((cls.__qualname__, attr_name)) @@ -209,10 +249,12 @@ def decorator(func): return decorator -def fixup_module_metadata(module_name, namespace): - seen_ids = set() +def fixup_module_metadata( + module_name: str, namespace: collections.abc.Mapping[str, object] +) -> None: + seen_ids: set[int] = set() - def fix_one(qualname, name, obj): + def fix_one(qualname: str, name: str, obj: object) -> None: # avoid infinite recursion (relevant when using # typing.Generic, for example) if id(obj) in seen_ids: @@ -227,7 +269,8 @@ def fix_one(qualname, name, obj): # rewriting these. if hasattr(obj, "__name__") and "." not in obj.__name__: obj.__name__ = name - obj.__qualname__ = qualname + if hasattr(obj, "__qualname__"): + obj.__qualname__ = qualname if isinstance(obj, type): for attr_name, attr_value in obj.__dict__.items(): fix_one(objname + "." + attr_name, attr_name, attr_value) @@ -237,7 +280,10 @@ def fix_one(qualname, name, obj): fix_one(objname, objname, obj) -class generic_function: +# We need ParamSpec to type this "properly", but that requires a runtime typing_extensions import +# to use as a class base. This is only used at runtime and isn't correct for type checkers anyway, +# so don't bother. +class generic_function(t.Generic[RetT]): """Decorator that makes a function indexable, to communicate non-inferrable generic type parameters to a static type checker. @@ -254,14 +300,14 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn): + def __init__(self, fn: t.Callable[..., RetT]) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args, **kwargs): + def __call__(self, *args: t.Any, **kwargs: t.Any) -> RetT: return self._fn(*args, **kwargs) - def __getitem__(self, _): + def __getitem__(self, subscript: object) -> Self: return self @@ -280,7 +326,12 @@ class SomeClass(metaclass=Final): - TypeError if a subclass is created """ - def __new__(cls, name, bases, cls_namespace): + def __new__( + cls, + name: str, + bases: tuple[type, ...], + cls_namespace: dict[str, object], + ) -> Final: for base in bases: if isinstance(base, Final): raise TypeError( @@ -290,9 +341,6 @@ def __new__(cls, name, bases, cls_namespace): return super().__new__(cls, name, bases, cls_namespace) -T = t.TypeVar("T") - - class NoPublicConstructor(Final): """Metaclass that enforces a class to be final (i.e., subclass not allowed) and ensures a private constructor. 
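
[Editorial aside: for readers unfamiliar with the metaclass pattern documented above, here is a self-contained sketch of the idea. It is illustrative only; the hunk that follows shows trio's actual implementation, which additionally inherits the Final behaviour.]

from typing import Any, Type, TypeVar

T = TypeVar("T")

class NoPublicConstructorSketch(type):
    # Instantiating a class that uses this metaclass raises...
    def __call__(cls, *args: object, **kwargs: object) -> None:
        raise TypeError(f"{cls.__name__} has no public constructor")

    # ...but library-internal code can still construct instances through
    # _create, which forwards to the normal type.__call__ machinery.
    def _create(cls: Type[T], *args: Any, **kwargs: Any) -> T:
        return super().__call__(*args, **kwargs)  # type: ignore

class Token(metaclass=NoPublicConstructorSketch):
    pass

token = Token._create()  # fine: the library's private escape hatch
# Token()  # raises TypeError: no public constructor
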
@@ -312,16 +360,16 @@ class SomeClass(metaclass=NoPublicConstructor): - TypeError if a subclass or an instance is created. """ - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: object, **kwargs: object) -> None: raise TypeError( f"{cls.__module__}.{cls.__qualname__} has no public constructor" ) - def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T: + def _create(cls: t.Type[T], *args: object, **kwargs: object) -> T: return super().__call__(*args, **kwargs) # type: ignore -def name_asyncgen(agen): +def name_asyncgen(agen: AsyncGeneratorType[object, t.NoReturn]) -> str: """Return the fully-qualified name of the async generator function that produced the async generator iterator *agen*. """ diff --git a/trio/_version.py b/trio/_version.py index 7111a4849d..65242863a9 100644 --- a/trio/_version.py +++ b/trio/_version.py @@ -1,3 +1,3 @@ # This file is imported from __init__.py and exec'd from setup.py -__version__ = "0.22.0+dev" +__version__ = "0.22.2+dev" diff --git a/trio/_wait_for_object.py b/trio/_wait_for_object.py index 2e24682444..50a9d13ff2 100644 --- a/trio/_wait_for_object.py +++ b/trio/_wait_for_object.py @@ -1,16 +1,20 @@ +from __future__ import annotations + import math -from . import _timeouts + import trio + from ._core._windows_cffi import ( + CData, + ErrorCodes, + _handle, ffi, kernel32, - ErrorCodes, raise_winerror, - _handle, ) -async def WaitForSingleObject(obj): +async def WaitForSingleObject(obj: int | CData) -> None: """Async and cancellable variant of WaitForSingleObject. Windows only. Args: @@ -50,7 +54,7 @@ async def WaitForSingleObject(obj): kernel32.CloseHandle(cancel_handle) -def WaitForMultipleObjects_sync(*handles): +def WaitForMultipleObjects_sync(*handles: int | CData) -> None: """Wait for any of the given Windows handles to be signaled.""" n = len(handles) handle_arr = ffi.new(f"HANDLE[{n}]") diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index 693792ba0e..c1c357b018 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -1,9 +1,10 @@ import sys from typing import TYPE_CHECKING + from . import _core -from ._abc import SendStream, ReceiveStream +from ._abc import ReceiveStream, SendStream +from ._core._windows_cffi import _handle, kernel32, raise_winerror from ._util import ConflictDetector, Final -from ._core._windows_cffi import _handle, raise_winerror, kernel32, ffi assert sys.platform == "win32" or not TYPE_CHECKING diff --git a/trio/abc.py b/trio/abc.py index ce0a1f6c00..439995640e 100644 --- a/trio/abc.py +++ b/trio/abc.py @@ -4,18 +4,20 @@ # temporaries, imports, etc. when implementing the module. So we put the # implementation in an underscored module, and then re-export the public parts # here. 
+ +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) from ._abc import ( - Clock, - Instrument, - AsyncResource, - SendStream, - ReceiveStream, - Stream, - HalfCloseableStream, - SocketFactory, - HostnameResolver, - Listener, - SendChannel, - ReceiveChannel, - Channel, + AsyncResource as AsyncResource, + Channel as Channel, + Clock as Clock, + HalfCloseableStream as HalfCloseableStream, + HostnameResolver as HostnameResolver, + Instrument as Instrument, + Listener as Listener, + ReceiveChannel as ReceiveChannel, + ReceiveStream as ReceiveStream, + SendChannel as SendChannel, + SendStream as SendStream, + SocketFactory as SocketFactory, + Stream as Stream, ) diff --git a/trio/from_thread.py b/trio/from_thread.py index 93dffd31f4..0de0023941 100644 --- a/trio/from_thread.py +++ b/trio/from_thread.py @@ -3,6 +3,12 @@ an external thread by means of a Trio Token present in Thread Local Storage """ -from ._threads import from_thread_run as run -from ._threads import from_thread_run_sync as run_sync -from ._threads import from_thread_check_cancelled as check_cancelled + +from ._threads import ( + from_thread_check_cancelled as check_cancelled, + from_thread_run as run, + from_thread_run_sync as run_sync, +) + +# need to use __all__ for pyright --verifytypes to see re-exports when renaming them +__all__ = ["check_cancelled", "run", "run_sync"] diff --git a/trio/lowlevel.py b/trio/lowlevel.py index 004692475f..25e64975e2 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -7,69 +7,72 @@ import sys import typing as _t -# This is the union of a subset of trio/_core/ and some things from trio/*.py. -# See comments in trio/__init__.py for details. To make static analysis easier, -# this lists all possible symbols from trio._core, and then we prune those that -# aren't available on this system. After that we add some symbols from trio/*.py. 
- # Generally available symbols from ._core import ( - cancel_shielded_checkpoint, - Abort, - RaiseCancelT, - wait_task_rescheduled, - enable_ki_protection, - disable_ki_protection, - currently_ki_protected, - Task, - checkpoint, - current_task, - ParkingLot, - UnboundedQueue, - RunVar, - TrioToken, - current_trio_token, - temporarily_detach_coroutine_object, - permanently_detach_coroutine_object, - reattach_detached_coroutine_object, - current_statistics, - reschedule, - remove_instrument, - add_instrument, - current_clock, - current_root_task, - checkpoint_if_cancelled, - spawn_system_task, - wait_readable, - wait_writable, - notify_closing, - start_thread_soon, - start_guest_run, + Abort as Abort, + ParkingLot as ParkingLot, + ParkingLotStatistics as ParkingLotStatistics, + RaiseCancelT as RaiseCancelT, + RunStatistics as RunStatistics, + RunVar as RunVar, + Task as Task, + TrioToken as TrioToken, + UnboundedQueue as UnboundedQueue, + UnboundedQueueStatistics as UnboundedQueueStatistics, + add_instrument as add_instrument, + cancel_shielded_checkpoint as cancel_shielded_checkpoint, + checkpoint as checkpoint, + checkpoint_if_cancelled as checkpoint_if_cancelled, + current_clock as current_clock, + current_root_task as current_root_task, + current_statistics as current_statistics, + current_task as current_task, + current_trio_token as current_trio_token, + currently_ki_protected as currently_ki_protected, + disable_ki_protection as disable_ki_protection, + enable_ki_protection as enable_ki_protection, + notify_closing as notify_closing, + permanently_detach_coroutine_object as permanently_detach_coroutine_object, + reattach_detached_coroutine_object as reattach_detached_coroutine_object, + remove_instrument as remove_instrument, + reschedule as reschedule, + spawn_system_task as spawn_system_task, + start_guest_run as start_guest_run, + start_thread_soon as start_thread_soon, + temporarily_detach_coroutine_object as temporarily_detach_coroutine_object, + wait_readable as wait_readable, + wait_task_rescheduled as wait_task_rescheduled, + wait_writable as wait_writable, ) +from ._subprocess import open_process as open_process + +# This is the union of a subset of trio/_core/ and some things from trio/*.py. +# See comments in trio/__init__.py for details. 
+ +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) -from ._subprocess import open_process if sys.platform == "win32": # Windows symbols from ._core import ( - current_iocp, - register_with_iocp, - wait_overlapped, - monitor_completion_key, - readinto_overlapped, - write_overlapped, + current_iocp as current_iocp, + monitor_completion_key as monitor_completion_key, + readinto_overlapped as readinto_overlapped, + register_with_iocp as register_with_iocp, + wait_overlapped as wait_overlapped, + write_overlapped as write_overlapped, ) - from ._wait_for_object import WaitForSingleObject + from ._wait_for_object import WaitForSingleObject as WaitForSingleObject else: # Unix symbols - from ._unix_pipes import FdStream + from ._unix_pipes import FdStream as FdStream # Kqueue-specific symbols if sys.platform != "linux" and (_t.TYPE_CHECKING or not hasattr(_select, "epoll")): from ._core import ( - current_kqueue, - monitor_kevent, - wait_kevent, + current_kqueue as current_kqueue, + monitor_kevent as monitor_kevent, + wait_kevent as wait_kevent, ) del sys diff --git a/trio/py.typed b/trio/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trio/socket.py b/trio/socket.py index 4bbc7d14f6..f8d0bc3fc2 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -6,127 +6,16 @@ # here. # We still have some underscore names though but only a few. -from . import _socket -import sys -import typing as _t -# The socket module exports a bunch of platform-specific constants. We want to -# re-export them. Since the exact set of constants varies depending on Python -# version, platform, the libc installed on the system where Python was built, -# etc., we figure out which constants to re-export dynamically at runtime (see -# below). But that confuses static analysis tools like jedi and mypy. So this -# import statement statically lists every constant that *could* be -# exported. It always fails at runtime, since no single Python build exports -# all these constants, but it lets static analysis tools understand what's -# going on. There's a test in test_exports.py to make sure that the list is -# kept up to date. 
-try: - # fmt: off - from socket import ( # type: ignore - CMSG_LEN, CMSG_SPACE, CAPI, AF_UNSPEC, AF_INET, AF_UNIX, AF_IPX, - AF_APPLETALK, AF_INET6, AF_ROUTE, AF_LINK, AF_SNA, PF_SYSTEM, - AF_SYSTEM, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, SOCK_SEQPACKET, SOCK_RDM, - SO_DEBUG, SO_ACCEPTCONN, SO_REUSEADDR, SO_KEEPALIVE, SO_DONTROUTE, - SO_BROADCAST, SO_USELOOPBACK, SO_LINGER, SO_OOBINLINE, SO_REUSEPORT, - SO_SNDBUF, SO_RCVBUF, SO_SNDLOWAT, SO_RCVLOWAT, SO_SNDTIMEO, - SO_RCVTIMEO, SO_ERROR, SO_TYPE, LOCAL_PEERCRED, SOMAXCONN, SCM_RIGHTS, - SCM_CREDS, MSG_OOB, MSG_PEEK, MSG_DONTROUTE, MSG_DONTWAIT, MSG_EOR, - MSG_TRUNC, MSG_CTRUNC, MSG_WAITALL, MSG_EOF, SOL_SOCKET, SOL_IP, - SOL_TCP, SOL_UDP, IPPROTO_IP, IPPROTO_HOPOPTS, IPPROTO_ICMP, - IPPROTO_IGMP, IPPROTO_GGP, IPPROTO_IPV4, IPPROTO_IPIP, IPPROTO_TCP, - IPPROTO_EGP, IPPROTO_PUP, IPPROTO_UDP, IPPROTO_IDP, IPPROTO_HELLO, - IPPROTO_ND, IPPROTO_TP, IPPROTO_ROUTING, IPPROTO_FRAGMENT, - IPPROTO_RSVP, IPPROTO_GRE, IPPROTO_ESP, IPPROTO_AH, IPPROTO_ICMPV6, - IPPROTO_NONE, IPPROTO_DSTOPTS, IPPROTO_XTP, IPPROTO_EON, IPPROTO_PIM, - IPPROTO_IPCOMP, IPPROTO_SCTP, IPPROTO_RAW, IPPROTO_MAX, IPPROTO_MPTCP, - SYSPROTO_CONTROL, IPPORT_RESERVED, IPPORT_USERRESERVED, INADDR_ANY, - INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_UNSPEC_GROUP, - INADDR_ALLHOSTS_GROUP, INADDR_MAX_LOCAL_GROUP, INADDR_NONE, IP_OPTIONS, - IP_HDRINCL, IP_TOS, IP_TTL, IP_RECVOPTS, IP_RECVRETOPTS, - IP_RECVDSTADDR, IP_RETOPTS, IP_MULTICAST_IF, IP_MULTICAST_TTL, - IP_MULTICAST_LOOP, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, - IP_DEFAULT_MULTICAST_TTL, IP_DEFAULT_MULTICAST_LOOP, - IP_MAX_MEMBERSHIPS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, - IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, - IPV6_UNICAST_HOPS, IPV6_V6ONLY, IPV6_CHECKSUM, IPV6_RECVTCLASS, - IPV6_RTHDR_TYPE_0, IPV6_TCLASS, TCP_NODELAY, TCP_MAXSEG, TCP_KEEPINTVL, - TCP_KEEPCNT, TCP_FASTOPEN, TCP_NOTSENT_LOWAT, EAI_ADDRFAMILY, - EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NODATA, - EAI_NONAME, EAI_OVERFLOW, EAI_SERVICE, EAI_SOCKTYPE, EAI_SYSTEM, - EAI_BADHINTS, EAI_PROTOCOL, EAI_MAX, AI_PASSIVE, AI_CANONNAME, - AI_NUMERICHOST, AI_NUMERICSERV, AI_MASK, AI_ALL, AI_V4MAPPED_CFG, - AI_ADDRCONFIG, AI_V4MAPPED, AI_DEFAULT, NI_MAXHOST, NI_MAXSERV, - NI_NOFQDN, NI_NUMERICHOST, NI_NAMEREQD, NI_NUMERICSERV, NI_DGRAM, - SHUT_RD, SHUT_WR, SHUT_RDWR, EBADF, EAGAIN, EWOULDBLOCK, AF_ASH, - AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_ECONET, - AF_IRDA, AF_KEY, AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, - AF_PPPOX, AF_ROSE, AF_SECURITY, AF_WANPIPE, AF_X25, BDADDR_ANY, - BDADDR_LOCAL, FD_SETSIZE, IPV6_DSTOPTS, IPV6_HOPLIMIT, IPV6_HOPOPTS, - IPV6_NEXTHOP, IPV6_PKTINFO, IPV6_RECVDSTOPTS, IPV6_RECVHOPLIMIT, - IPV6_RECVHOPOPTS, IPV6_RECVPKTINFO, IPV6_RECVRTHDR, IPV6_RTHDR, - IPV6_RTHDRDSTOPTS, MSG_ERRQUEUE, NETLINK_DNRTMSG, NETLINK_FIREWALL, - NETLINK_IP6_FW, NETLINK_NFLOG, NETLINK_ROUTE, NETLINK_USERSOCK, - NETLINK_XFRM, PACKET_BROADCAST, PACKET_FASTROUTE, PACKET_HOST, - PACKET_LOOPBACK, PACKET_MULTICAST, PACKET_OTHERHOST, PACKET_OUTGOING, - POLLERR, POLLHUP, POLLIN, POLLMSG, POLLNVAL, POLLOUT, POLLPRI, - POLLRDBAND, POLLRDNORM, POLLWRNORM, SIOCGIFINDEX, SIOCGIFNAME, - SOCK_CLOEXEC, TCP_CORK, TCP_DEFER_ACCEPT, TCP_INFO, TCP_KEEPIDLE, - TCP_LINGER2, TCP_QUICKACK, TCP_SYNCNT, TCP_WINDOW_CLAMP, AF_ALG, - AF_CAN, AF_RDS, AF_TIPC, AF_VSOCK, ALG_OP_DECRYPT, ALG_OP_ENCRYPT, - ALG_OP_SIGN, ALG_OP_VERIFY, ALG_SET_AEAD_ASSOCLEN, - ALG_SET_AEAD_AUTHSIZE, ALG_SET_IV, ALG_SET_KEY, ALG_SET_OP, - ALG_SET_PUBKEY, CAN_BCM, 
CAN_BCM_RX_CHANGED, CAN_BCM_RX_DELETE, - CAN_BCM_RX_READ, CAN_BCM_RX_SETUP, CAN_BCM_RX_STATUS, - CAN_BCM_RX_TIMEOUT, CAN_BCM_TX_DELETE, CAN_BCM_TX_EXPIRED, - CAN_BCM_TX_READ, CAN_BCM_TX_SEND, CAN_BCM_TX_SETUP, CAN_BCM_TX_STATUS, - CAN_EFF_FLAG, CAN_EFF_MASK, CAN_ERR_FLAG, CAN_ERR_MASK, CAN_ISOTP, - CAN_RAW, CAN_RAW_ERR_FILTER, CAN_RAW_FD_FRAMES, CAN_RAW_FILTER, - CAN_RAW_LOOPBACK, CAN_RAW_RECV_OWN_MSGS, CAN_RTR_FLAG, CAN_SFF_MASK, - IOCTL_VM_SOCKETS_GET_LOCAL_CID, IPV6_DONTFRAG, IPV6_PATHMTU, - IPV6_RECVPATHMTU, IP_TRANSPARENT, MSG_CMSG_CLOEXEC, MSG_CONFIRM, - MSG_FASTOPEN, MSG_MORE, MSG_NOSIGNAL, NETLINK_CRYPTO, PF_CAN, - PF_PACKET, PF_RDS, SCM_CREDENTIALS, SOCK_NONBLOCK, SOL_ALG, - SOL_CAN_BASE, SOL_CAN_RAW, SOL_TIPC, SO_BINDTODEVICE, SO_DOMAIN, - SO_MARK, SO_PASSCRED, SO_PASSSEC, SO_PEERCRED, SO_PEERSEC, SO_PRIORITY, - SO_PROTOCOL, SO_VM_SOCKETS_BUFFER_MAX_SIZE, - SO_VM_SOCKETS_BUFFER_MIN_SIZE, SO_VM_SOCKETS_BUFFER_SIZE, - TCP_CONGESTION, TCP_USER_TIMEOUT, TIPC_ADDR_ID, TIPC_ADDR_NAME, - TIPC_ADDR_NAMESEQ, TIPC_CFG_SRV, TIPC_CLUSTER_SCOPE, TIPC_CONN_TIMEOUT, - TIPC_CRITICAL_IMPORTANCE, TIPC_DEST_DROPPABLE, TIPC_HIGH_IMPORTANCE, - TIPC_IMPORTANCE, TIPC_LOW_IMPORTANCE, TIPC_MEDIUM_IMPORTANCE, - TIPC_NODE_SCOPE, TIPC_PUBLISHED, TIPC_SRC_DROPPABLE, - TIPC_SUBSCR_TIMEOUT, TIPC_SUB_CANCEL, TIPC_SUB_PORTS, TIPC_SUB_SERVICE, - TIPC_TOP_SRV, TIPC_WAIT_FOREVER, TIPC_WITHDRAWN, TIPC_ZONE_SCOPE, - VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_PORT_ANY, - VM_SOCKETS_INVALID_VERSION, MSG_BCAST, MSG_MCAST, RCVALL_MAX, - RCVALL_OFF, RCVALL_ON, RCVALL_SOCKETLEVELONLY, SIO_KEEPALIVE_VALS, - SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, SO_EXCLUSIVEADDRUSE, HCI_FILTER, - BTPROTO_SCO, BTPROTO_HCI, HCI_TIME_STAMP, SOL_RDS, BTPROTO_L2CAP, - BTPROTO_RFCOMM, HCI_DATA_DIR, SOL_HCI, CAN_BCM_RX_ANNOUNCE_RESUME, - CAN_BCM_RX_CHECK_DLC, CAN_BCM_RX_FILTER_ID, CAN_BCM_RX_NO_AUTOTIMER, - CAN_BCM_RX_RTR_FRAME, CAN_BCM_SETTIMER, CAN_BCM_STARTTIMER, - CAN_BCM_TX_ANNOUNCE, CAN_BCM_TX_COUNTEVT, CAN_BCM_TX_CP_CAN_ID, - CAN_BCM_TX_RESET_MULTI_IDX, IPPROTO_CBT, IPPROTO_ICLFXBM, IPPROTO_IGP, - IPPROTO_L2TP, IPPROTO_PGM, IPPROTO_RDP, IPPROTO_ST, AF_QIPCRTR, - CAN_BCM_CAN_FD_FRAME, IPPROTO_MOBILE, IPV6_USE_MIN_MTU, - MSG_NOTIFICATION, SO_SETFIB, CAN_J1939, CAN_RAW_JOIN_FILTERS, - IPPROTO_UDPLITE, J1939_EE_INFO_NONE, J1939_EE_INFO_TX_ABORT, - J1939_FILTER_MAX, J1939_IDLE_ADDR, J1939_MAX_UNICAST_ADDR, - J1939_NLA_BYTES_ACKED, J1939_NLA_PAD, J1939_NO_ADDR, J1939_NO_NAME, - J1939_NO_PGN, J1939_PGN_ADDRESS_CLAIMED, J1939_PGN_ADDRESS_COMMANDED, - J1939_PGN_MAX, J1939_PGN_PDU1_MAX, J1939_PGN_REQUEST, - SCM_J1939_DEST_ADDR, SCM_J1939_DEST_NAME, SCM_J1939_ERRQUEUE, - SCM_J1939_PRIO, SO_J1939_ERRQUEUE, SO_J1939_FILTER, SO_J1939_PROMISC, - SO_J1939_SEND_PRIO, UDPLITE_RECV_CSCOV, UDPLITE_SEND_CSCOV, IP_RECVTOS, - TCP_KEEPALIVE, SO_INCOMING_CPU - ) - # fmt: on -except ImportError: - pass +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) # Dynamically re-export whatever constants this particular Python happens to # have: import socket as _stdlib_socket +import sys +import typing as _t + +from . 
import _socket _bad_symbols: _t.Set[str] = set() if sys.platform == "win32": @@ -145,49 +34,53 @@ # import the overwrites from ._socket import ( - fromfd, - from_stdlib_socket, - getprotobyname, - socketpair, - getnameinfo, - socket, - getaddrinfo, - set_custom_hostname_resolver, - set_custom_socket_factory, - SocketType, + Address as Address, + SocketType as SocketType, + _SocketType as _SocketType, + from_stdlib_socket as from_stdlib_socket, + fromfd as fromfd, + getaddrinfo as getaddrinfo, + getnameinfo as getnameinfo, + getprotobyname as getprotobyname, + set_custom_hostname_resolver as set_custom_hostname_resolver, + set_custom_socket_factory as set_custom_socket_factory, + socket as socket, + socketpair as socketpair, ) # not always available so expose only if if sys.platform == "win32" or not _t.TYPE_CHECKING: try: - from ._socket import fromshare + from ._socket import fromshare as fromshare except ImportError: pass # expose these functions to trio.socket from socket import ( - gaierror, - herror, - gethostname, - ntohs, - htonl, - htons, - inet_aton, - inet_ntoa, - inet_pton, - inet_ntop, + gaierror as gaierror, + gethostname as gethostname, + herror as herror, + htonl as htonl, + htons as htons, + inet_aton as inet_aton, + inet_ntoa as inet_ntoa, + inet_ntop as inet_ntop, + inet_pton as inet_pton, + ntohs as ntohs, ) # not always available so expose only if if sys.platform != "win32" or not _t.TYPE_CHECKING: try: - from socket import sethostname, if_nameindex, if_nametoindex, if_indextoname + from socket import ( + if_indextoname as if_indextoname, + if_nameindex as if_nameindex, + if_nametoindex as if_nametoindex, + sethostname as sethostname, + ) except ImportError: pass -# get names used by Trio that we define on our own -from ._socket import IPPROTO_IPV6 - if _t.TYPE_CHECKING: IP_BIND_ADDRESS_NO_PORT: int else: @@ -198,3 +91,477 @@ IP_BIND_ADDRESS_NO_PORT = 24 del sys + + +# The socket module exports a bunch of platform-specific constants. We want to +# re-export them. Since the exact set of constants varies depending on Python +# version, platform, the libc installed on the system where Python was built, +# etc., we figure out which constants to re-export dynamically at runtime (see +# below). But that confuses static analysis tools like jedi and mypy. So this +# import statement statically lists every constant that *could* be +# exported. There's a test in test_exports.py to make sure that the list is +# kept up to date. 
+if _t.TYPE_CHECKING: + from socket import ( # type: ignore[attr-defined] + AF_ALG as AF_ALG, + AF_APPLETALK as AF_APPLETALK, + AF_ASH as AF_ASH, + AF_ATMPVC as AF_ATMPVC, + AF_ATMSVC as AF_ATMSVC, + AF_AX25 as AF_AX25, + AF_BLUETOOTH as AF_BLUETOOTH, + AF_BRIDGE as AF_BRIDGE, + AF_CAN as AF_CAN, + AF_ECONET as AF_ECONET, + AF_INET as AF_INET, + AF_INET6 as AF_INET6, + AF_IPX as AF_IPX, + AF_IRDA as AF_IRDA, + AF_KEY as AF_KEY, + AF_LINK as AF_LINK, + AF_LLC as AF_LLC, + AF_NETBEUI as AF_NETBEUI, + AF_NETLINK as AF_NETLINK, + AF_NETROM as AF_NETROM, + AF_PACKET as AF_PACKET, + AF_PPPOX as AF_PPPOX, + AF_QIPCRTR as AF_QIPCRTR, + AF_RDS as AF_RDS, + AF_ROSE as AF_ROSE, + AF_ROUTE as AF_ROUTE, + AF_SECURITY as AF_SECURITY, + AF_SNA as AF_SNA, + AF_SYSTEM as AF_SYSTEM, + AF_TIPC as AF_TIPC, + AF_UNIX as AF_UNIX, + AF_UNSPEC as AF_UNSPEC, + AF_VSOCK as AF_VSOCK, + AF_WANPIPE as AF_WANPIPE, + AF_X25 as AF_X25, + AI_ADDRCONFIG as AI_ADDRCONFIG, + AI_ALL as AI_ALL, + AI_CANONNAME as AI_CANONNAME, + AI_DEFAULT as AI_DEFAULT, + AI_MASK as AI_MASK, + AI_NUMERICHOST as AI_NUMERICHOST, + AI_NUMERICSERV as AI_NUMERICSERV, + AI_PASSIVE as AI_PASSIVE, + AI_V4MAPPED as AI_V4MAPPED, + AI_V4MAPPED_CFG as AI_V4MAPPED_CFG, + ALG_OP_DECRYPT as ALG_OP_DECRYPT, + ALG_OP_ENCRYPT as ALG_OP_ENCRYPT, + ALG_OP_SIGN as ALG_OP_SIGN, + ALG_OP_VERIFY as ALG_OP_VERIFY, + ALG_SET_AEAD_ASSOCLEN as ALG_SET_AEAD_ASSOCLEN, + ALG_SET_AEAD_AUTHSIZE as ALG_SET_AEAD_AUTHSIZE, + ALG_SET_IV as ALG_SET_IV, + ALG_SET_KEY as ALG_SET_KEY, + ALG_SET_OP as ALG_SET_OP, + ALG_SET_PUBKEY as ALG_SET_PUBKEY, + BDADDR_ANY as BDADDR_ANY, + BDADDR_LOCAL as BDADDR_LOCAL, + BTPROTO_HCI as BTPROTO_HCI, + BTPROTO_L2CAP as BTPROTO_L2CAP, + BTPROTO_RFCOMM as BTPROTO_RFCOMM, + BTPROTO_SCO as BTPROTO_SCO, + CAN_BCM as CAN_BCM, + CAN_BCM_CAN_FD_FRAME as CAN_BCM_CAN_FD_FRAME, + CAN_BCM_RX_ANNOUNCE_RESUME as CAN_BCM_RX_ANNOUNCE_RESUME, + CAN_BCM_RX_CHANGED as CAN_BCM_RX_CHANGED, + CAN_BCM_RX_CHECK_DLC as CAN_BCM_RX_CHECK_DLC, + CAN_BCM_RX_DELETE as CAN_BCM_RX_DELETE, + CAN_BCM_RX_FILTER_ID as CAN_BCM_RX_FILTER_ID, + CAN_BCM_RX_NO_AUTOTIMER as CAN_BCM_RX_NO_AUTOTIMER, + CAN_BCM_RX_READ as CAN_BCM_RX_READ, + CAN_BCM_RX_RTR_FRAME as CAN_BCM_RX_RTR_FRAME, + CAN_BCM_RX_SETUP as CAN_BCM_RX_SETUP, + CAN_BCM_RX_STATUS as CAN_BCM_RX_STATUS, + CAN_BCM_RX_TIMEOUT as CAN_BCM_RX_TIMEOUT, + CAN_BCM_SETTIMER as CAN_BCM_SETTIMER, + CAN_BCM_STARTTIMER as CAN_BCM_STARTTIMER, + CAN_BCM_TX_ANNOUNCE as CAN_BCM_TX_ANNOUNCE, + CAN_BCM_TX_COUNTEVT as CAN_BCM_TX_COUNTEVT, + CAN_BCM_TX_CP_CAN_ID as CAN_BCM_TX_CP_CAN_ID, + CAN_BCM_TX_DELETE as CAN_BCM_TX_DELETE, + CAN_BCM_TX_EXPIRED as CAN_BCM_TX_EXPIRED, + CAN_BCM_TX_READ as CAN_BCM_TX_READ, + CAN_BCM_TX_RESET_MULTI_IDX as CAN_BCM_TX_RESET_MULTI_IDX, + CAN_BCM_TX_SEND as CAN_BCM_TX_SEND, + CAN_BCM_TX_SETUP as CAN_BCM_TX_SETUP, + CAN_BCM_TX_STATUS as CAN_BCM_TX_STATUS, + CAN_EFF_FLAG as CAN_EFF_FLAG, + CAN_EFF_MASK as CAN_EFF_MASK, + CAN_ERR_FLAG as CAN_ERR_FLAG, + CAN_ERR_MASK as CAN_ERR_MASK, + CAN_ISOTP as CAN_ISOTP, + CAN_J1939 as CAN_J1939, + CAN_RAW as CAN_RAW, + CAN_RAW_ERR_FILTER as CAN_RAW_ERR_FILTER, + CAN_RAW_FD_FRAMES as CAN_RAW_FD_FRAMES, + CAN_RAW_FILTER as CAN_RAW_FILTER, + CAN_RAW_JOIN_FILTERS as CAN_RAW_JOIN_FILTERS, + CAN_RAW_LOOPBACK as CAN_RAW_LOOPBACK, + CAN_RAW_RECV_OWN_MSGS as CAN_RAW_RECV_OWN_MSGS, + CAN_RTR_FLAG as CAN_RTR_FLAG, + CAN_SFF_MASK as CAN_SFF_MASK, + CAPI as CAPI, + CMSG_LEN as CMSG_LEN, + CMSG_SPACE as CMSG_SPACE, + EAGAIN as EAGAIN, + EAI_ADDRFAMILY as EAI_ADDRFAMILY, + EAI_AGAIN as EAI_AGAIN, + 
EAI_BADFLAGS as EAI_BADFLAGS, + EAI_BADHINTS as EAI_BADHINTS, + EAI_FAIL as EAI_FAIL, + EAI_FAMILY as EAI_FAMILY, + EAI_MAX as EAI_MAX, + EAI_MEMORY as EAI_MEMORY, + EAI_NODATA as EAI_NODATA, + EAI_NONAME as EAI_NONAME, + EAI_OVERFLOW as EAI_OVERFLOW, + EAI_PROTOCOL as EAI_PROTOCOL, + EAI_SERVICE as EAI_SERVICE, + EAI_SOCKTYPE as EAI_SOCKTYPE, + EAI_SYSTEM as EAI_SYSTEM, + EBADF as EBADF, + ETH_P_ALL as ETH_P_ALL, + ETHERTYPE_ARP as ETHERTYPE_ARP, + ETHERTYPE_IP as ETHERTYPE_IP, + ETHERTYPE_IPV6 as ETHERTYPE_IPV6, + ETHERTYPE_VLAN as ETHERTYPE_VLAN, + EWOULDBLOCK as EWOULDBLOCK, + FD_ACCEPT as FD_ACCEPT, + FD_CLOSE as FD_CLOSE, + FD_CLOSE_BIT as FD_CLOSE_BIT, + FD_CONNECT as FD_CONNECT, + FD_CONNECT_BIT as FD_CONNECT_BIT, + FD_READ as FD_READ, + FD_SETSIZE as FD_SETSIZE, + FD_WRITE as FD_WRITE, + HCI_DATA_DIR as HCI_DATA_DIR, + HCI_FILTER as HCI_FILTER, + HCI_TIME_STAMP as HCI_TIME_STAMP, + INADDR_ALLHOSTS_GROUP as INADDR_ALLHOSTS_GROUP, + INADDR_ANY as INADDR_ANY, + INADDR_BROADCAST as INADDR_BROADCAST, + INADDR_LOOPBACK as INADDR_LOOPBACK, + INADDR_MAX_LOCAL_GROUP as INADDR_MAX_LOCAL_GROUP, + INADDR_NONE as INADDR_NONE, + INADDR_UNSPEC_GROUP as INADDR_UNSPEC_GROUP, + INFINITE as INFINITE, + IOCTL_VM_SOCKETS_GET_LOCAL_CID as IOCTL_VM_SOCKETS_GET_LOCAL_CID, + IP_ADD_MEMBERSHIP as IP_ADD_MEMBERSHIP, + IP_ADD_SOURCE_MEMBERSHIP as IP_ADD_SOURCE_MEMBERSHIP, + IP_BLOCK_SOURCE as IP_BLOCK_SOURCE, + IP_DEFAULT_MULTICAST_LOOP as IP_DEFAULT_MULTICAST_LOOP, + IP_DEFAULT_MULTICAST_TTL as IP_DEFAULT_MULTICAST_TTL, + IP_DROP_MEMBERSHIP as IP_DROP_MEMBERSHIP, + IP_DROP_SOURCE_MEMBERSHIP as IP_DROP_SOURCE_MEMBERSHIP, + IP_HDRINCL as IP_HDRINCL, + IP_MAX_MEMBERSHIPS as IP_MAX_MEMBERSHIPS, + IP_MULTICAST_IF as IP_MULTICAST_IF, + IP_MULTICAST_LOOP as IP_MULTICAST_LOOP, + IP_MULTICAST_TTL as IP_MULTICAST_TTL, + IP_OPTIONS as IP_OPTIONS, + IP_PKTINFO as IP_PKTINFO, + IP_RECVDSTADDR as IP_RECVDSTADDR, + IP_RECVOPTS as IP_RECVOPTS, + IP_RECVRETOPTS as IP_RECVRETOPTS, + IP_RECVTOS as IP_RECVTOS, + IP_RETOPTS as IP_RETOPTS, + IP_TOS as IP_TOS, + IP_TRANSPARENT as IP_TRANSPARENT, + IP_TTL as IP_TTL, + IP_UNBLOCK_SOURCE as IP_UNBLOCK_SOURCE, + IPPORT_RESERVED as IPPORT_RESERVED, + IPPORT_USERRESERVED as IPPORT_USERRESERVED, + IPPROTO_AH as IPPROTO_AH, + IPPROTO_CBT as IPPROTO_CBT, + IPPROTO_DSTOPTS as IPPROTO_DSTOPTS, + IPPROTO_EGP as IPPROTO_EGP, + IPPROTO_EON as IPPROTO_EON, + IPPROTO_ESP as IPPROTO_ESP, + IPPROTO_FRAGMENT as IPPROTO_FRAGMENT, + IPPROTO_GGP as IPPROTO_GGP, + IPPROTO_GRE as IPPROTO_GRE, + IPPROTO_HELLO as IPPROTO_HELLO, + IPPROTO_HOPOPTS as IPPROTO_HOPOPTS, + IPPROTO_ICLFXBM as IPPROTO_ICLFXBM, + IPPROTO_ICMP as IPPROTO_ICMP, + IPPROTO_ICMPV6 as IPPROTO_ICMPV6, + IPPROTO_IDP as IPPROTO_IDP, + IPPROTO_IGMP as IPPROTO_IGMP, + IPPROTO_IGP as IPPROTO_IGP, + IPPROTO_IP as IPPROTO_IP, + IPPROTO_IPCOMP as IPPROTO_IPCOMP, + IPPROTO_IPIP as IPPROTO_IPIP, + IPPROTO_IPV4 as IPPROTO_IPV4, + IPPROTO_IPV6 as IPPROTO_IPV6, + IPPROTO_L2TP as IPPROTO_L2TP, + IPPROTO_MAX as IPPROTO_MAX, + IPPROTO_MOBILE as IPPROTO_MOBILE, + IPPROTO_MPTCP as IPPROTO_MPTCP, + IPPROTO_ND as IPPROTO_ND, + IPPROTO_NONE as IPPROTO_NONE, + IPPROTO_PGM as IPPROTO_PGM, + IPPROTO_PIM as IPPROTO_PIM, + IPPROTO_PUP as IPPROTO_PUP, + IPPROTO_RAW as IPPROTO_RAW, + IPPROTO_RDP as IPPROTO_RDP, + IPPROTO_ROUTING as IPPROTO_ROUTING, + IPPROTO_RSVP as IPPROTO_RSVP, + IPPROTO_SCTP as IPPROTO_SCTP, + IPPROTO_ST as IPPROTO_ST, + IPPROTO_TCP as IPPROTO_TCP, + IPPROTO_TP as IPPROTO_TP, + IPPROTO_UDP as IPPROTO_UDP, + IPPROTO_UDPLITE as IPPROTO_UDPLITE, + 
IPPROTO_XTP as IPPROTO_XTP, + IPV6_CHECKSUM as IPV6_CHECKSUM, + IPV6_DONTFRAG as IPV6_DONTFRAG, + IPV6_DSTOPTS as IPV6_DSTOPTS, + IPV6_HOPLIMIT as IPV6_HOPLIMIT, + IPV6_HOPOPTS as IPV6_HOPOPTS, + IPV6_JOIN_GROUP as IPV6_JOIN_GROUP, + IPV6_LEAVE_GROUP as IPV6_LEAVE_GROUP, + IPV6_MULTICAST_HOPS as IPV6_MULTICAST_HOPS, + IPV6_MULTICAST_IF as IPV6_MULTICAST_IF, + IPV6_MULTICAST_LOOP as IPV6_MULTICAST_LOOP, + IPV6_NEXTHOP as IPV6_NEXTHOP, + IPV6_PATHMTU as IPV6_PATHMTU, + IPV6_PKTINFO as IPV6_PKTINFO, + IPV6_RECVDSTOPTS as IPV6_RECVDSTOPTS, + IPV6_RECVHOPLIMIT as IPV6_RECVHOPLIMIT, + IPV6_RECVHOPOPTS as IPV6_RECVHOPOPTS, + IPV6_RECVPATHMTU as IPV6_RECVPATHMTU, + IPV6_RECVPKTINFO as IPV6_RECVPKTINFO, + IPV6_RECVRTHDR as IPV6_RECVRTHDR, + IPV6_RECVTCLASS as IPV6_RECVTCLASS, + IPV6_RTHDR as IPV6_RTHDR, + IPV6_RTHDR_TYPE_0 as IPV6_RTHDR_TYPE_0, + IPV6_RTHDRDSTOPTS as IPV6_RTHDRDSTOPTS, + IPV6_TCLASS as IPV6_TCLASS, + IPV6_UNICAST_HOPS as IPV6_UNICAST_HOPS, + IPV6_USE_MIN_MTU as IPV6_USE_MIN_MTU, + IPV6_V6ONLY as IPV6_V6ONLY, + J1939_EE_INFO_NONE as J1939_EE_INFO_NONE, + J1939_EE_INFO_TX_ABORT as J1939_EE_INFO_TX_ABORT, + J1939_FILTER_MAX as J1939_FILTER_MAX, + J1939_IDLE_ADDR as J1939_IDLE_ADDR, + J1939_MAX_UNICAST_ADDR as J1939_MAX_UNICAST_ADDR, + J1939_NLA_BYTES_ACKED as J1939_NLA_BYTES_ACKED, + J1939_NLA_PAD as J1939_NLA_PAD, + J1939_NO_ADDR as J1939_NO_ADDR, + J1939_NO_NAME as J1939_NO_NAME, + J1939_NO_PGN as J1939_NO_PGN, + J1939_PGN_ADDRESS_CLAIMED as J1939_PGN_ADDRESS_CLAIMED, + J1939_PGN_ADDRESS_COMMANDED as J1939_PGN_ADDRESS_COMMANDED, + J1939_PGN_MAX as J1939_PGN_MAX, + J1939_PGN_PDU1_MAX as J1939_PGN_PDU1_MAX, + J1939_PGN_REQUEST as J1939_PGN_REQUEST, + LOCAL_PEERCRED as LOCAL_PEERCRED, + MSG_BCAST as MSG_BCAST, + MSG_CMSG_CLOEXEC as MSG_CMSG_CLOEXEC, + MSG_CONFIRM as MSG_CONFIRM, + MSG_CTRUNC as MSG_CTRUNC, + MSG_DONTROUTE as MSG_DONTROUTE, + MSG_DONTWAIT as MSG_DONTWAIT, + MSG_EOF as MSG_EOF, + MSG_EOR as MSG_EOR, + MSG_ERRQUEUE as MSG_ERRQUEUE, + MSG_FASTOPEN as MSG_FASTOPEN, + MSG_MCAST as MSG_MCAST, + MSG_MORE as MSG_MORE, + MSG_NOSIGNAL as MSG_NOSIGNAL, + MSG_NOTIFICATION as MSG_NOTIFICATION, + MSG_OOB as MSG_OOB, + MSG_PEEK as MSG_PEEK, + MSG_TRUNC as MSG_TRUNC, + MSG_WAITALL as MSG_WAITALL, + NETLINK_CRYPTO as NETLINK_CRYPTO, + NETLINK_DNRTMSG as NETLINK_DNRTMSG, + NETLINK_FIREWALL as NETLINK_FIREWALL, + NETLINK_IP6_FW as NETLINK_IP6_FW, + NETLINK_NFLOG as NETLINK_NFLOG, + NETLINK_ROUTE as NETLINK_ROUTE, + NETLINK_USERSOCK as NETLINK_USERSOCK, + NETLINK_XFRM as NETLINK_XFRM, + NI_DGRAM as NI_DGRAM, + NI_MAXHOST as NI_MAXHOST, + NI_MAXSERV as NI_MAXSERV, + NI_NAMEREQD as NI_NAMEREQD, + NI_NOFQDN as NI_NOFQDN, + NI_NUMERICHOST as NI_NUMERICHOST, + NI_NUMERICSERV as NI_NUMERICSERV, + PACKET_BROADCAST as PACKET_BROADCAST, + PACKET_FASTROUTE as PACKET_FASTROUTE, + PACKET_HOST as PACKET_HOST, + PACKET_LOOPBACK as PACKET_LOOPBACK, + PACKET_MULTICAST as PACKET_MULTICAST, + PACKET_OTHERHOST as PACKET_OTHERHOST, + PACKET_OUTGOING as PACKET_OUTGOING, + PF_CAN as PF_CAN, + PF_PACKET as PF_PACKET, + PF_RDS as PF_RDS, + PF_SYSTEM as PF_SYSTEM, + POLLERR as POLLERR, + POLLHUP as POLLHUP, + POLLIN as POLLIN, + POLLMSG as POLLMSG, + POLLNVAL as POLLNVAL, + POLLOUT as POLLOUT, + POLLPRI as POLLPRI, + POLLRDBAND as POLLRDBAND, + POLLRDNORM as POLLRDNORM, + POLLWRNORM as POLLWRNORM, + RCVALL_MAX as RCVALL_MAX, + RCVALL_OFF as RCVALL_OFF, + RCVALL_ON as RCVALL_ON, + RCVALL_SOCKETLEVELONLY as RCVALL_SOCKETLEVELONLY, + SCM_CREDENTIALS as SCM_CREDENTIALS, + SCM_CREDS as SCM_CREDS, + SCM_J1939_DEST_ADDR 
as SCM_J1939_DEST_ADDR, + SCM_J1939_DEST_NAME as SCM_J1939_DEST_NAME, + SCM_J1939_ERRQUEUE as SCM_J1939_ERRQUEUE, + SCM_J1939_PRIO as SCM_J1939_PRIO, + SCM_RIGHTS as SCM_RIGHTS, + SHUT_RD as SHUT_RD, + SHUT_RDWR as SHUT_RDWR, + SHUT_WR as SHUT_WR, + SIO_KEEPALIVE_VALS as SIO_KEEPALIVE_VALS, + SIO_LOOPBACK_FAST_PATH as SIO_LOOPBACK_FAST_PATH, + SIO_RCVALL as SIO_RCVALL, + SIOCGIFINDEX as SIOCGIFINDEX, + SIOCGIFNAME as SIOCGIFNAME, + SO_ACCEPTCONN as SO_ACCEPTCONN, + SO_BINDTODEVICE as SO_BINDTODEVICE, + SO_BROADCAST as SO_BROADCAST, + SO_DEBUG as SO_DEBUG, + SO_DOMAIN as SO_DOMAIN, + SO_DONTROUTE as SO_DONTROUTE, + SO_ERROR as SO_ERROR, + SO_EXCLUSIVEADDRUSE as SO_EXCLUSIVEADDRUSE, + SO_INCOMING_CPU as SO_INCOMING_CPU, + SO_J1939_ERRQUEUE as SO_J1939_ERRQUEUE, + SO_J1939_FILTER as SO_J1939_FILTER, + SO_J1939_PROMISC as SO_J1939_PROMISC, + SO_J1939_SEND_PRIO as SO_J1939_SEND_PRIO, + SO_KEEPALIVE as SO_KEEPALIVE, + SO_LINGER as SO_LINGER, + SO_MARK as SO_MARK, + SO_OOBINLINE as SO_OOBINLINE, + SO_PASSCRED as SO_PASSCRED, + SO_PASSSEC as SO_PASSSEC, + SO_PEERCRED as SO_PEERCRED, + SO_PEERSEC as SO_PEERSEC, + SO_PRIORITY as SO_PRIORITY, + SO_PROTOCOL as SO_PROTOCOL, + SO_RCVBUF as SO_RCVBUF, + SO_RCVLOWAT as SO_RCVLOWAT, + SO_RCVTIMEO as SO_RCVTIMEO, + SO_REUSEADDR as SO_REUSEADDR, + SO_REUSEPORT as SO_REUSEPORT, + SO_SETFIB as SO_SETFIB, + SO_SNDBUF as SO_SNDBUF, + SO_SNDLOWAT as SO_SNDLOWAT, + SO_SNDTIMEO as SO_SNDTIMEO, + SO_TYPE as SO_TYPE, + SO_USELOOPBACK as SO_USELOOPBACK, + SO_VM_SOCKETS_BUFFER_MAX_SIZE as SO_VM_SOCKETS_BUFFER_MAX_SIZE, + SO_VM_SOCKETS_BUFFER_MIN_SIZE as SO_VM_SOCKETS_BUFFER_MIN_SIZE, + SO_VM_SOCKETS_BUFFER_SIZE as SO_VM_SOCKETS_BUFFER_SIZE, + SOCK_CLOEXEC as SOCK_CLOEXEC, + SOCK_DGRAM as SOCK_DGRAM, + SOCK_NONBLOCK as SOCK_NONBLOCK, + SOCK_RAW as SOCK_RAW, + SOCK_RDM as SOCK_RDM, + SOCK_SEQPACKET as SOCK_SEQPACKET, + SOCK_STREAM as SOCK_STREAM, + SOL_ALG as SOL_ALG, + SOL_CAN_BASE as SOL_CAN_BASE, + SOL_CAN_RAW as SOL_CAN_RAW, + SOL_HCI as SOL_HCI, + SOL_IP as SOL_IP, + SOL_RDS as SOL_RDS, + SOL_SOCKET as SOL_SOCKET, + SOL_TCP as SOL_TCP, + SOL_TIPC as SOL_TIPC, + SOL_UDP as SOL_UDP, + SOMAXCONN as SOMAXCONN, + SYSPROTO_CONTROL as SYSPROTO_CONTROL, + TCP_CC_INFO as TCP_CC_INFO, + TCP_CONGESTION as TCP_CONGESTION, + TCP_CORK as TCP_CORK, + TCP_DEFER_ACCEPT as TCP_DEFER_ACCEPT, + TCP_FASTOPEN as TCP_FASTOPEN, + TCP_FASTOPEN_CONNECT as TCP_FASTOPEN_CONNECT, + TCP_FASTOPEN_KEY as TCP_FASTOPEN_KEY, + TCP_FASTOPEN_NO_COOKIE as TCP_FASTOPEN_NO_COOKIE, + TCP_INFO as TCP_INFO, + TCP_INQ as TCP_INQ, + TCP_KEEPALIVE as TCP_KEEPALIVE, + TCP_KEEPCNT as TCP_KEEPCNT, + TCP_KEEPIDLE as TCP_KEEPIDLE, + TCP_KEEPINTVL as TCP_KEEPINTVL, + TCP_LINGER2 as TCP_LINGER2, + TCP_MAXSEG as TCP_MAXSEG, + TCP_MD5SIG as TCP_MD5SIG, + TCP_MD5SIG_EXT as TCP_MD5SIG_EXT, + TCP_NODELAY as TCP_NODELAY, + TCP_NOTSENT_LOWAT as TCP_NOTSENT_LOWAT, + TCP_QUEUE_SEQ as TCP_QUEUE_SEQ, + TCP_QUICKACK as TCP_QUICKACK, + TCP_REPAIR as TCP_REPAIR, + TCP_REPAIR_OPTIONS as TCP_REPAIR_OPTIONS, + TCP_REPAIR_QUEUE as TCP_REPAIR_QUEUE, + TCP_REPAIR_WINDOW as TCP_REPAIR_WINDOW, + TCP_SAVE_SYN as TCP_SAVE_SYN, + TCP_SAVED_SYN as TCP_SAVED_SYN, + TCP_SYNCNT as TCP_SYNCNT, + TCP_THIN_DUPACK as TCP_THIN_DUPACK, + TCP_THIN_LINEAR_TIMEOUTS as TCP_THIN_LINEAR_TIMEOUTS, + TCP_TIMESTAMP as TCP_TIMESTAMP, + TCP_TX_DELAY as TCP_TX_DELAY, + TCP_ULP as TCP_ULP, + TCP_USER_TIMEOUT as TCP_USER_TIMEOUT, + TCP_WINDOW_CLAMP as TCP_WINDOW_CLAMP, + TCP_ZEROCOPY_RECEIVE as TCP_ZEROCOPY_RECEIVE, + TIPC_ADDR_ID as TIPC_ADDR_ID, + TIPC_ADDR_NAME 
as TIPC_ADDR_NAME, + TIPC_ADDR_NAMESEQ as TIPC_ADDR_NAMESEQ, + TIPC_CFG_SRV as TIPC_CFG_SRV, + TIPC_CLUSTER_SCOPE as TIPC_CLUSTER_SCOPE, + TIPC_CONN_TIMEOUT as TIPC_CONN_TIMEOUT, + TIPC_CRITICAL_IMPORTANCE as TIPC_CRITICAL_IMPORTANCE, + TIPC_DEST_DROPPABLE as TIPC_DEST_DROPPABLE, + TIPC_HIGH_IMPORTANCE as TIPC_HIGH_IMPORTANCE, + TIPC_IMPORTANCE as TIPC_IMPORTANCE, + TIPC_LOW_IMPORTANCE as TIPC_LOW_IMPORTANCE, + TIPC_MEDIUM_IMPORTANCE as TIPC_MEDIUM_IMPORTANCE, + TIPC_NODE_SCOPE as TIPC_NODE_SCOPE, + TIPC_PUBLISHED as TIPC_PUBLISHED, + TIPC_SRC_DROPPABLE as TIPC_SRC_DROPPABLE, + TIPC_SUB_CANCEL as TIPC_SUB_CANCEL, + TIPC_SUB_PORTS as TIPC_SUB_PORTS, + TIPC_SUB_SERVICE as TIPC_SUB_SERVICE, + TIPC_SUBSCR_TIMEOUT as TIPC_SUBSCR_TIMEOUT, + TIPC_TOP_SRV as TIPC_TOP_SRV, + TIPC_WAIT_FOREVER as TIPC_WAIT_FOREVER, + TIPC_WITHDRAWN as TIPC_WITHDRAWN, + TIPC_ZONE_SCOPE as TIPC_ZONE_SCOPE, + UDPLITE_RECV_CSCOV as UDPLITE_RECV_CSCOV, + UDPLITE_SEND_CSCOV as UDPLITE_SEND_CSCOV, + VM_SOCKETS_INVALID_VERSION as VM_SOCKETS_INVALID_VERSION, + VMADDR_CID_ANY as VMADDR_CID_ANY, + VMADDR_CID_HOST as VMADDR_CID_HOST, + VMADDR_PORT_ANY as VMADDR_PORT_ANY, + WSA_FLAG_OVERLAPPED as WSA_FLAG_OVERLAPPED, + WSA_INVALID_HANDLE as WSA_INVALID_HANDLE, + WSA_INVALID_PARAMETER as WSA_INVALID_PARAMETER, + WSA_IO_INCOMPLETE as WSA_IO_INCOMPLETE, + WSA_IO_PENDING as WSA_IO_PENDING, + WSA_NOT_ENOUGH_MEMORY as WSA_NOT_ENOUGH_MEMORY, + WSA_OPERATION_ABORTED as WSA_OPERATION_ABORTED, + WSA_WAIT_FAILED as WSA_WAIT_FAILED, + WSA_WAIT_TIMEOUT as WSA_WAIT_TIMEOUT, + ) diff --git a/trio/testing/__init__.py b/trio/testing/__init__.py index aa15c4743e..fa683e1145 100644 --- a/trio/testing/__init__.py +++ b/trio/testing/__init__.py @@ -1,32 +1,34 @@ -from .._core import wait_all_tasks_blocked, MockClock - -from ._trio_test import trio_test - -from ._checkpoints import assert_checkpoints, assert_no_checkpoints - -from ._sequencer import Sequencer +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) +from .._core import ( + MockClock as MockClock, + wait_all_tasks_blocked as wait_all_tasks_blocked, +) +from .._util import fixup_module_metadata from ._check_streams import ( - check_one_way_stream, - check_two_way_stream, - check_half_closeable_stream, + check_half_closeable_stream as check_half_closeable_stream, + check_one_way_stream as check_one_way_stream, + check_two_way_stream as check_two_way_stream, +) +from ._checkpoints import ( + assert_checkpoints as assert_checkpoints, + assert_no_checkpoints as assert_no_checkpoints, ) - from ._memory_streams import ( - MemorySendStream, - MemoryReceiveStream, - memory_stream_pump, - memory_stream_one_way_pair, - memory_stream_pair, - lockstep_stream_one_way_pair, - lockstep_stream_pair, + MemoryReceiveStream as MemoryReceiveStream, + MemorySendStream as MemorySendStream, + lockstep_stream_one_way_pair as lockstep_stream_one_way_pair, + lockstep_stream_pair as lockstep_stream_pair, + memory_stream_one_way_pair as memory_stream_one_way_pair, + memory_stream_pair as memory_stream_pair, + memory_stream_pump as memory_stream_pump, ) - -from ._network import open_stream_to_socket_listener +from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener +from ._sequencer import Sequencer as Sequencer +from ._trio_test import trio_test as trio_test ################################################################ -from .._util import fixup_module_metadata fixup_module_metadata(__name__, globals()) del fixup_module_metadata diff --git 
a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 0206f1f737..33947ccc55 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -1,30 +1,49 @@ # Generic stream tests +from __future__ import annotations -from contextlib import contextmanager import random +from collections.abc import Generator +from contextlib import contextmanager +from typing import TYPE_CHECKING, Awaitable, Callable, Generic, Tuple, TypeVar -from .. import _core +from .. import CancelScope, _core +from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully -from .._abc import SendStream, ReceiveStream, Stream, HalfCloseableStream from ._checkpoints import assert_checkpoints +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import ParamSpec, TypeAlias + + ArgsT = ParamSpec("ArgsT") -class _ForceCloseBoth: - def __init__(self, both): - self._both = list(both) +Res1 = TypeVar("Res1", bound=AsyncResource) +Res2 = TypeVar("Res2", bound=AsyncResource) +StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]] - async def __aenter__(self): - return self._both - async def __aexit__(self, *args): +class _ForceCloseBoth(Generic[Res1, Res2]): + def __init__(self, both: tuple[Res1, Res2]) -> None: + self._first, self._second = both + + async def __aenter__(self) -> tuple[Res1, Res2]: + return self._first, self._second + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: try: - await aclose_forcefully(self._both[0]) + await aclose_forcefully(self._first) finally: - await aclose_forcefully(self._both[1]) + await aclose_forcefully(self._second) @contextmanager -def _assert_raises(exc): +def _assert_raises(exc: type[BaseException]) -> Generator[None, None, None]: __tracebackhide__ = True try: yield @@ -34,7 +53,10 @@ def _assert_raises(exc): raise AssertionError(f"expected exception: {exc}") -async def check_one_way_stream(stream_maker, clogged_stream_maker): +async def check_one_way_stream( + stream_maker: StreamMaker[SendStream, ReceiveStream], + clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None, +) -> None: """Perform a number of generic tests on a custom one-way stream implementation. @@ -57,18 +79,18 @@ async def check_one_way_stream(stream_maker, clogged_stream_maker): assert isinstance(s, SendStream) assert isinstance(r, ReceiveStream) - async def do_send_all(data): - with assert_checkpoints(): - assert await s.send_all(data) is None + async def do_send_all(data: bytes | bytearray | memoryview) -> None: + with assert_checkpoints(): # We're testing that it doesn't return anything. 
+ assert await s.send_all(data) is None # type: ignore[func-returns-value] - async def do_receive_some(*args): + async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray: with assert_checkpoints(): - return await r.receive_some(*args) + return await r.receive_some(max_bytes) - async def checked_receive_1(expected): + async def checked_receive_1(expected: bytes) -> None: assert await do_receive_some(1) == expected - async def do_aclose(resource): + async def do_aclose(resource: AsyncResource) -> None: with assert_checkpoints(): await resource.aclose() @@ -77,7 +99,7 @@ async def do_aclose(resource): nursery.start_soon(do_send_all, b"x") nursery.start_soon(checked_receive_1, b"x") - async def send_empty_then_y(): + async def send_empty_then_y() -> None: # Streams should tolerate sending b"" without giving it any # special meaning. await do_send_all(b"") @@ -104,7 +126,7 @@ async def send_empty_then_y(): with _assert_raises(ValueError): await r.receive_some(0) with _assert_raises(TypeError): - await r.receive_some(1.5) + await r.receive_some(1.5) # type: ignore[arg-type] # it can also be missing or None async with _core.open_nursery() as nursery: nursery.start_soon(do_send_all, b"x") @@ -123,7 +145,9 @@ async def send_empty_then_y(): # for send_all to wait until receive_some is called to run, though; a # stream doesn't *have* to have any internal buffering. That's why we # start a concurrent receive_some call, then cancel it.) - async def simple_check_wait_send_all_might_not_block(scope): + async def simple_check_wait_send_all_might_not_block( + scope: CancelScope, + ) -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() scope.cancel() @@ -136,7 +160,7 @@ async def simple_check_wait_send_all_might_not_block(scope): # closing the r side leads to BrokenResourceError on the s side # (eventually) - async def expect_broken_stream_on_send(): + async def expect_broken_stream_on_send() -> None: with _assert_raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) @@ -179,11 +203,11 @@ async def expect_broken_stream_on_send(): async with _ForceCloseBoth(await stream_maker()) as (s, r): # if send-then-graceful-close, receiver gets data then b"" - async def send_then_close(): + async def send_then_close() -> None: await do_send_all(b"y") await do_aclose(s) - async def receive_send_then_close(): + async def receive_send_then_close() -> None: # We want to make sure that if the sender closes the stream before # we read anything, then we still get all the data. But some # streams might block on the do_send_all call. 
So we let the @@ -248,9 +272,13 @@ async def receive_send_then_close(): # https://github.com/python-trio/trio/issues/77 async with _ForceCloseBoth(await stream_maker()) as (s, r): - async def expect_cancelled(afn, *args): + async def expect_cancelled( + afn: Callable[ArgsT, Awaitable[object]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, + ) -> None: with _assert_raises(_core.Cancelled): - await afn(*args) + await afn(*args, **kwargs) with _core.CancelScope() as scope: scope.cancel() @@ -278,16 +306,16 @@ async def receive_expecting_closed(): # check wait_send_all_might_not_block, if we can if clogged_stream_maker is not None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): - record = [] + record: list[str] = [] - async def waiter(cancel_scope): + async def waiter(cancel_scope: CancelScope) -> None: record.append("waiter sleeping") with assert_checkpoints(): await s.wait_send_all_might_not_block() record.append("waiter wokeup") cancel_scope.cancel() - async def receiver(): + async def receiver() -> None: # give wait_send_all_might_not_block a chance to block await _core.wait_all_tasks_blocked() record.append("receiver starting") @@ -333,14 +361,14 @@ async def receiver(): # with or without an exception async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): - async def sender(): + async def sender() -> None: try: with assert_checkpoints(): await s.wait_send_all_might_not_block() except _core.BrokenResourceError: # pragma: no cover pass - async def receiver(): + async def receiver() -> None: await _core.wait_all_tasks_blocked() await aclose_forcefully(r) @@ -359,7 +387,7 @@ async def receiver(): # Check that if a task is blocked in a send-side method, then closing # the send stream causes it to wake up. - async def close_soon(s): + async def close_soon(s: SendStream) -> None: await _core.wait_all_tasks_blocked() await aclose_forcefully(s) @@ -376,7 +404,10 @@ async def close_soon(s): await s.wait_send_all_might_not_block() -async def check_two_way_stream(stream_maker, clogged_stream_maker): +async def check_two_way_stream( + stream_maker: StreamMaker[Stream, Stream], + clogged_stream_maker: StreamMaker[Stream, Stream] | None, +) -> None: """Perform a number of generic tests on a custom two-way stream implementation. 
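The `expect_cancelled` helper above, and `trio_test` further down in this patch, both lean on the same ParamSpec pattern: a wrapper that forwards `*args`/`**kwargs` while keeping the wrapped callable's exact signature visible to type checkers. A minimal standalone sketch of the pattern, assuming `typing_extensions` is installed (on Python 3.10+ `typing.ParamSpec` works as well); `call_and_log` is a hypothetical name, not part of this patch:

    from __future__ import annotations

    from typing import Awaitable, Callable, TypeVar

    from typing_extensions import ParamSpec

    ArgsT = ParamSpec("ArgsT")
    RetT = TypeVar("RetT")

    async def call_and_log(
        afn: Callable[ArgsT, Awaitable[RetT]],
        *args: ArgsT.args,
        **kwargs: ArgsT.kwargs,
    ) -> RetT:
        # ArgsT ties *args/**kwargs to afn's own parameters, so a checker
        # rejects mismatched arguments instead of treating them as Any.
        print(f"calling {getattr(afn, '__name__', afn)!r}")
        return await afn(*args, **kwargs)

Without the ParamSpec, the old untyped `expect_cancelled(afn, *args)` silently accepted any argument list; with it, a call like `expect_cancelled(s.send_all)` is flagged for the missing `data` argument.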
@@ -391,13 +422,15 @@ async def check_two_way_stream(stream_maker, clogged_stream_maker): """ await check_one_way_stream(stream_maker, clogged_stream_maker) - async def flipped_stream_maker(): - return reversed(await stream_maker()) + async def flipped_stream_maker() -> tuple[Stream, Stream]: + return (await stream_maker())[::-1] + + flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None if clogged_stream_maker is not None: - async def flipped_clogged_stream_maker(): - return reversed(await clogged_stream_maker()) + async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: + return (await clogged_stream_maker())[::-1] else: flipped_clogged_stream_maker = None @@ -415,7 +448,9 @@ async def flipped_clogged_stream_maker(): i = r.getrandbits(8 * DUPLEX_TEST_SIZE) test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") - async def sender(s, data, seed): + async def sender( + s: Stream, data: bytes | bytearray | memoryview, seed: int + ) -> None: r = random.Random(seed) m = memoryview(data) while m: @@ -423,7 +458,7 @@ async def sender(s, data, seed): await s.send_all(m[:chunk_size]) m = m[chunk_size:] - async def receiver(s, data, seed): + async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None: r = random.Random(seed) got = bytearray() while len(got) < len(data): @@ -438,7 +473,7 @@ async def receiver(s, data, seed): nursery.start_soon(receiver, s1, test_data[::-1], 2) nursery.start_soon(receiver, s2, test_data, 3) - async def expect_receive_some_empty(): + async def expect_receive_some_empty() -> None: assert await s2.receive_some(10) == b"" await s2.aclose() @@ -447,7 +482,10 @@ async def expect_receive_some_empty(): nursery.start_soon(s1.aclose) -async def check_half_closeable_stream(stream_maker, clogged_stream_maker): +async def check_half_closeable_stream( + stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream], + clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None, +) -> None: """Perform a number of generic tests on a custom half-closeable stream implementation. @@ -466,12 +504,12 @@ async def check_half_closeable_stream(stream_maker, clogged_stream_maker): assert isinstance(s1, HalfCloseableStream) assert isinstance(s2, HalfCloseableStream) - async def send_x_then_eof(s): + async def send_x_then_eof(s: HalfCloseableStream) -> None: await s.send_all(b"x") with assert_checkpoints(): await s.send_eof() - async def expect_x_then_eof(r): + async def expect_x_then_eof(r: HalfCloseableStream) -> None: await _core.wait_all_tasks_blocked() assert await r.receive_some(10) == b"x" assert await r.receive_some(10) == b"" diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 5804295300..4a4047813b 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -1,10 +1,14 @@ -from contextlib import contextmanager +from __future__ import annotations + +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager from .. 
import _core @contextmanager -def _assert_yields_or_not(expected): +def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]: + """Check if checkpoints are executed in a block of code.""" __tracebackhide__ = True task = _core.current_task() orig_cancel = task._cancel_points @@ -22,7 +26,7 @@ def _assert_yields_or_not(expected): raise AssertionError("assert_no_checkpoints block yielded!") -def assert_checkpoints(): +def assert_checkpoints() -> AbstractContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block either exits with an exception or executes at least one :ref:`checkpoint `. @@ -42,7 +46,7 @@ def assert_checkpoints(): return _assert_yields_or_not(True) -def assert_no_checkpoints(): +def assert_no_checkpoints() -> AbstractContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block does not execute any :ref:`checkpoints `. diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index f0ea927734..ddf46174f3 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -6,18 +6,22 @@ # - TCP # - UDP broadcast -import trio -import attr -import ipaddress -from collections import deque +from __future__ import annotations + import errno +import ipaddress import os -from typing import Union, List, Optional -import enum -from contextlib import contextmanager +from typing import TYPE_CHECKING, Optional, Union +import attr + +import trio from trio._util import Final, NoPublicConstructor +if TYPE_CHECKING: + from socket import AddressFamily, SocketKind + from types import TracebackType + IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -101,7 +105,7 @@ def reply(self, payload): class FakeSocketFactory(trio.abc.SocketFactory): fake_net: "FakeNet" - def socket(self, family: int, type: int, proto: int) -> "FakeSocket": + def socket(self, family: int, type: int, proto: int) -> FakeSocket: # type: ignore[override] return FakeSocket._create(self.fake_net, family, type, proto) @@ -110,22 +114,38 @@ class FakeHostnameResolver(trio.abc.HostnameResolver): fake_net: "FakeNet" async def getaddrinfo( - self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0 - ): + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: raise NotImplementedError("FakeNet doesn't do fake DNS yet") - async def getnameinfo(self, sockaddr, flags: int): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: raise NotImplementedError("FakeNet doesn't do fake DNS yet") class FakeNet(metaclass=Final): - def __init__(self): + def __init__(self) -> None: # When we need to pick an arbitrary unique ip address/port, use these: self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() - self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() + self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() # type: ignore[assignment] self._auto_port_iter = iter(range(50000, 65535)) - self._bound: Dict[UDPBinding, FakeSocket] = {} + self._bound: dict[UDPBinding, FakeSocket] = {} self.route_packet = None @@ -173,9 +193,9 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): self._closed = False - self._packet_sender, self._packet_receiver = trio.open_memory_channel( - float("inf") - 
) + self._packet_sender, self._packet_receiver = trio.open_memory_channel[ + UDPPacket + ](float("inf")) # This is the source-of-truth for what port etc. this socket is bound to self._binding: Optional[UDPBinding] = None @@ -203,7 +223,7 @@ async def _resolve_address_nocp(self, address, *, local): local=local, ) - def _deliver_packet(self, packet: UDPPacket): + def _deliver_packet(self, packet: UDPPacket) -> None: try: self._packet_sender.send_nowait(packet) except trio.BrokenResourceError: @@ -217,7 +237,7 @@ def _deliver_packet(self, packet: UDPPacket): async def bind(self, addr): self._check_closed() if self._binding is not None: - _fake_error(errno.EINVAL) + _fake_err(errno.EINVAL) await trio.lowlevel.checkpoint() ip_str, port = await self._resolve_address_nocp(addr, local=True) ip = ipaddress.ip_address(ip_str) @@ -340,7 +360,12 @@ def setsockopt(self, level, item, value): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.close() async def send(self, data, flags=0): diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 99ad7dfcaf..fc23fae842 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,9 +1,22 @@ +from __future__ import annotations + import operator +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar -from .. import _core +from .. import _core, _util from .._highlevel_generic import StapledStream -from .. import _util -from ..abc import SendStream, ReceiveStream +from ..abc import ReceiveStream, SendStream + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + +AsyncHook: TypeAlias = Callable[[], Awaitable[object]] +# Would be nice to exclude awaitable here, but currently not possible. +SyncHook: TypeAlias = Callable[[], object] +SendStreamT = TypeVar("SendStreamT", bound=SendStream) +ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream) + ################################################################ # In-memory streams - Unbounded buffer version @@ -11,7 +24,7 @@ class _UnboundedByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._closed = False self._lot = _core.ParkingLot() @@ -23,28 +36,28 @@ def __init__(self): # channel: so after close(), calling put() raises ClosedResourceError, and # calling the get() variants drains the buffer and then returns an empty # bytearray. 
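Stepping back to the `FakeSocket` hunk above: `trio.open_memory_channel[UDPPacket](...)` is subscripted at runtime, not only in annotations, which works because trio exposes `open_memory_channel` as a subscriptable generic. A small sketch of the same idiom with an `int` channel (a toy example, not from this patch):

    import trio

    async def main() -> None:
        # The subscript types the pair as MemorySendChannel[int] /
        # MemoryReceiveChannel[int]; at runtime it is a no-op.
        send, recv = trio.open_memory_channel[int](0)
        async with trio.open_nursery() as nursery:
            nursery.start_soon(send.send, 42)  # buffer size 0: a rendezvous
            assert await recv.receive() == 42

    trio.run(main)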
- def close(self): + def close(self) -> None: self._closed = True self._lot.unpark_all() - def close_and_wipe(self): + def close_and_wipe(self) -> None: self._data = bytearray() self.close() - def put(self, data): + def put(self, data: bytes | bytearray | memoryview) -> None: if self._closed: raise _core.ClosedResourceError("virtual connection closed") self._data += data self._lot.unpark_all() - def _check_max_bytes(self, max_bytes): + def _check_max_bytes(self, max_bytes: int | None) -> None: if max_bytes is None: return max_bytes = operator.index(max_bytes) if max_bytes < 1: raise ValueError("max_bytes must be >= 1") - def _get_impl(self, max_bytes): + def _get_impl(self, max_bytes: int | None) -> bytearray: assert self._closed or self._data if max_bytes is None: max_bytes = len(self._data) @@ -56,14 +69,14 @@ def _get_impl(self, max_bytes): else: return bytearray() - def get_nowait(self, max_bytes=None): + def get_nowait(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: self._check_max_bytes(max_bytes) if not self._closed and not self._data: raise _core.WouldBlock return self._get_impl(max_bytes) - async def get(self, max_bytes=None): + async def get(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: self._check_max_bytes(max_bytes) if not self._closed and not self._data: @@ -96,9 +109,9 @@ class MemorySendStream(SendStream, metaclass=_util.Final): def __init__( self, - send_all_hook=None, - wait_send_all_might_not_block_hook=None, - close_hook=None, + send_all_hook: AsyncHook | None = None, + wait_send_all_might_not_block_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -108,7 +121,7 @@ def __init__( self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook self.close_hook = close_hook - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Places the given data into the object's internal buffer, and then calls the :attr:`send_all_hook` (if any). @@ -122,12 +135,12 @@ async def send_all(self, data): if self.send_all_hook is not None: await self.send_all_hook() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and then returns immediately. """ - # Execute two checkpoints so we have more of a chance to detect + # Execute two checkpoints so that we have more of a chance to detect # buggy user code that calls this twice at the same time. with self._conflict_detector: await _core.checkpoint() @@ -137,7 +150,7 @@ async def wait_send_all_might_not_block(self): if self.wait_send_all_might_not_block_hook is not None: await self.wait_send_all_might_not_block_hook() - def close(self): + def close(self) -> None: """Marks this stream as closed, and then calls the :attr:`close_hook` (if any). @@ -154,12 +167,12 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() - async def get_data(self, max_bytes=None): + async def get_data(self, max_bytes: int | None = None) -> bytearray: """Retrieves data from the internal buffer, blocking if necessary. 
Args: @@ -175,7 +188,7 @@ async def get_data(self, max_bytes=None): """ return await self._outgoing.get(max_bytes) - def get_data_nowait(self, max_bytes=None): + def get_data_nowait(self, max_bytes: int | None = None) -> bytearray: """Retrieves data from the internal buffer, but doesn't block. See :meth:`get_data` for details. @@ -204,7 +217,11 @@ class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): """ - def __init__(self, receive_some_hook=None, close_hook=None): + def __init__( + self, + receive_some_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, + ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" ) @@ -213,7 +230,7 @@ def __init__(self, receive_some_hook=None, close_hook=None): self.receive_some_hook = receive_some_hook self.close_hook = close_hook - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytearray: """Calls the :attr:`receive_some_hook` (if any), and then retrieves data from the internal buffer, blocking if necessary. @@ -236,7 +253,7 @@ async def receive_some(self, max_bytes=None): raise _core.ClosedResourceError return data - def close(self): + def close(self) -> None: """Discards any pending data from the internal buffer, and marks this stream as closed. @@ -246,21 +263,26 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() - def put_data(self, data): + def put_data(self, data: bytes | bytearray | memoryview) -> None: """Appends the given data to the internal buffer.""" self._incoming.put(data) - def put_eof(self): + def put_eof(self) -> None: """Adds an end-of-file marker to the internal buffer.""" self._incoming.close() -def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=None): +def memory_stream_pump( + memory_send_stream: MemorySendStream, + memory_receive_stream: MemoryReceiveStream, + *, + max_bytes: int | None = None, +) -> bool: """Take data out of the given :class:`MemorySendStream`'s internal buffer, and put it into the given :class:`MemoryReceiveStream`'s internal buffer. @@ -293,7 +315,7 @@ def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=N return True -def memory_stream_one_way_pair(): +def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]: """Create a connected, pure-Python, unidirectional stream with infinite buffering and flexible configuration options. 
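As context for the newly annotated `memory_stream_pump` above: a sketch of the manual wiring that `memory_stream_one_way_pair` (whose docstring begins above) automates via `send_all_hook`, mirroring the example in trio's documentation:

    import trio
    from trio.testing import MemoryReceiveStream, MemorySendStream, memory_stream_pump

    async def main() -> None:
        send_stream = MemorySendStream()
        recv_stream = MemoryReceiveStream()

        async def pump() -> None:
            # Move whatever is sitting in send_stream's internal buffer
            # into recv_stream's internal buffer.
            memory_stream_pump(send_stream, recv_stream)

        send_stream.send_all_hook = pump
        await send_stream.send_all(b"hello")
        assert await recv_stream.receive_some(10) == b"hello"

    trio.run(main)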
@@ -320,10 +342,10 @@ def memory_stream_one_way_pair(): send_stream = MemorySendStream() recv_stream = MemoryReceiveStream() - def pump_from_send_stream_to_recv_stream(): + def pump_from_send_stream_to_recv_stream() -> None: memory_stream_pump(send_stream, recv_stream) - async def async_pump_from_send_stream_to_recv_stream(): + async def async_pump_from_send_stream_to_recv_stream() -> None: pump_from_send_stream_to_recv_stream() send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream @@ -331,7 +353,12 @@ async def async_pump_from_send_stream_to_recv_stream(): return send_stream, recv_stream -def _make_stapled_pair(one_way_pair): +def _make_stapled_pair( + one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]] +) -> tuple[ + StapledStream[SendStreamT, ReceiveStreamT], + StapledStream[SendStreamT, ReceiveStreamT], +]: pipe1_send, pipe1_recv = one_way_pair() pipe2_send, pipe2_recv = one_way_pair() stream1 = StapledStream(pipe1_send, pipe2_recv) @@ -339,7 +366,12 @@ def _make_stapled_pair(one_way_pair): return stream1, stream2 -def memory_stream_pair(): +def memory_stream_pair() -> ( + tuple[ + StapledStream[MemorySendStream, MemoryReceiveStream], + StapledStream[MemorySendStream, MemoryReceiveStream], + ] +): """Create a connected, pure-Python, bidirectional stream with infinite buffering and flexible configuration options. @@ -422,7 +454,7 @@ async def receiver(): class _LockstepByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._sender_closed = False self._receiver_closed = False @@ -435,12 +467,12 @@ def __init__(self): "another task is already receiving" ) - def _something_happened(self): + def _something_happened(self) -> None: self._waiters.unpark_all() # Always wakes up when one side is closed, because everyone always reacts # to that. 
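The `_something_happened` broadcast above and the `_wait_for` loop just below form a wake-everyone-and-recheck pattern on a `ParkingLot`. A condensed sketch of that pattern using only public trio APIs, with a module-level `done` flag standing in for the queue state (a toy condition, not part of this patch):

    import trio
    from trio.lowlevel import ParkingLot
    from trio.testing import wait_all_tasks_blocked

    lot = ParkingLot()
    done = False

    async def waiter() -> None:
        # Park and re-test the condition after every wakeup, like _wait_for.
        while not done:
            await lot.park()

    async def main() -> None:
        global done
        async with trio.open_nursery() as nursery:
            nursery.start_soon(waiter)
            await wait_all_tasks_blocked()
            done = True
            lot.unpark_all()  # wake all parked tasks; each re-tests `done`

    trio.run(main)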
- async def _wait_for(self, fn): + async def _wait_for(self, fn: Callable[[], bool]) -> None: while True: if fn(): break @@ -449,15 +481,15 @@ async def _wait_for(self, fn): await self._waiters.park() await _core.checkpoint() - def close_sender(self): + def close_sender(self) -> None: self._sender_closed = True self._something_happened() - def close_receiver(self): + def close_receiver(self) -> None: self._receiver_closed = True self._something_happened() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -466,13 +498,13 @@ async def send_all(self, data): assert not self._data self._data += data self._something_happened() - await self._wait_for(lambda: not self._data) + await self._wait_for(lambda: self._data == b"") if self._sender_closed: raise _core.ClosedResourceError if self._data and self._receiver_closed: raise _core.BrokenResourceError - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -483,7 +515,7 @@ async def wait_send_all_might_not_block(self): if self._sender_closed: raise _core.ClosedResourceError - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: with self._receive_conflict_detector: # Argument validation if max_bytes is not None: @@ -497,7 +529,7 @@ async def receive_some(self, max_bytes=None): self._receiver_waiting = True self._something_happened() try: - await self._wait_for(lambda: self._data) + await self._wait_for(lambda: self._data != b"") finally: self._receiver_waiting = False if self._receiver_closed: @@ -516,39 +548,39 @@ async def receive_some(self, max_bytes=None): class _LockstepSendStream(SendStream): - def __init__(self, lbq): + def __init__(self, lbq: _LockstepByteQueue): self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_sender() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: await self._lbq.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: await self._lbq.wait_send_all_might_not_block() class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq): + def __init__(self, lbq: _LockstepByteQueue): self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_receiver() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: return await self._lbq.receive_some(max_bytes) -def lockstep_stream_one_way_pair(): +def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]: """Create a connected, pure Python, unidirectional stream where data flows in lockstep. @@ -575,7 +607,12 @@ def lockstep_stream_one_way_pair(): return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) -def lockstep_stream_pair(): +def lockstep_stream_pair() -> ( + tuple[ + StapledStream[SendStream, ReceiveStream], + StapledStream[SendStream, ReceiveStream], + ] +): """Create a connected, pure-Python, bidirectional stream where data flows in lockstep. 
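Before the next file: the lockstep queue above has no buffer at all, so `send_all` only completes once the receiver has drained the bytes (`_wait_for(lambda: self._data == b"")`). A short sketch of the public pair built on it:

    import trio
    from trio.testing import lockstep_stream_one_way_pair

    async def main() -> None:
        s, r = lockstep_stream_one_way_pair()
        async with trio.open_nursery() as nursery:
            # send_all parks until the receiver takes the bytes.
            nursery.start_soon(s.send_all, b"ping")
            assert await r.receive_some(4) == b"ping"

    trio.run(main)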
diff --git a/trio/testing/_network.py b/trio/testing/_network.py index 615ce2effb..fddbbf0fdc 100644 --- a/trio/testing/_network.py +++ b/trio/testing/_network.py @@ -1,8 +1,10 @@ from .. import socket as tsocket -from .._highlevel_socket import SocketStream +from .._highlevel_socket import SocketListener, SocketStream -async def open_stream_to_socket_listener(socket_listener): +async def open_stream_to_socket_listener( + socket_listener: SocketListener, +) -> SocketStream: """Connect to the given :class:`~trio.SocketListener`. This is particularly useful in tests when you want to let a server pick diff --git a/trio/testing/_sequencer.py b/trio/testing/_sequencer.py index 3f4bda9cfc..137fd3c522 100644 --- a/trio/testing/_sequencer.py +++ b/trio/testing/_sequencer.py @@ -6,9 +6,7 @@ import attr -from .. import _core -from .. import _util -from .. import Event +from .. import Event, _core, _util if TYPE_CHECKING: from collections.abc import AsyncIterator diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 4fcaeae372..5619352846 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,20 +1,36 @@ -from functools import wraps, partial +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar from .. import _core from ..abc import Clock, Instrument +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + ArgsT = ParamSpec("ArgsT") + + +RetT = TypeVar("RetT") + + +def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]: + """Converts an async test function to be synchronous, running via Trio. + + Usage:: + + @trio_test + async def test_whatever(): + await ... + + If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or + :class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`. + """ -# Use: -# -# @trio_test -# async def test_whatever(): -# await ... -# -# Also: if a pytest fixture is passed in that subclasses the Clock abc, then -# that clock is passed to trio.run(). -def trio_test(fn): @wraps(fn) - def wrapper(**kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: __tracebackhide__ = True clocks = [c for c in kwargs.values() if isinstance(c, Clock)] if not clocks: @@ -24,6 +40,8 @@ def wrapper(**kwargs): else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments) + return _core.run( + partial(fn, *args, **kwargs), clock=clock, instruments=instruments + ) return wrapper diff --git a/trio/tests.py b/trio/tests.py new file mode 100644 index 0000000000..4ffb583a3a --- /dev/null +++ b/trio/tests.py @@ -0,0 +1,38 @@ +import importlib +import sys +from typing import Any + +from . import _tests +from ._deprecate import warn_deprecated + +warn_deprecated( + "trio.tests", + "0.22.1", + instead="trio._tests", + issue=274, +) + + +# This won't give deprecation warning on import, but will give a warning on use of any +# attribute in tests, and static analysis tools will also not see any content inside. +class TestsDeprecationWrapper: + __name__ = "trio.tests" + + def __getattr__(self, attr: str) -> Any: + warn_deprecated( + f"trio.tests.{attr}", + "0.22.1", + instead=f"trio._tests.{attr}", + issue=274, + ) + + # needed to access e.g. 
trio._tests.tools, although pytest doesn't need it + if not hasattr(_tests, attr): # pragma: no cover + importlib.import_module(f"trio._tests.{attr}", "trio._tests") + return attr + + return getattr(_tests, attr) + + +# https://stackoverflow.com/questions/2447353/getattr-on-a-module +sys.modules[__name__] = TestsDeprecationWrapper() # type: ignore[assignment] diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py deleted file mode 100644 index 026d6f5efa..0000000000 --- a/trio/tests/test_exports.py +++ /dev/null @@ -1,145 +0,0 @@ -import re -import sys -import importlib -import types -import inspect -import enum - -import pytest - -import trio -import trio.testing - -from .. import _core -from .. import _util - - -def test_core_is_properly_reexported(): - # Each export from _core should be re-exported by exactly one of these - # three modules: - sources = [trio, trio.lowlevel, trio.testing] - for symbol in dir(_core): - if symbol.startswith("_") or symbol == "tests": - continue - found = 0 - for source in sources: - if symbol in dir(source) and getattr(source, symbol) is getattr( - _core, symbol - ): - found += 1 - print(symbol, found) - assert found == 1 - - -def public_modules(module): - yield module - for name, class_ in module.__dict__.items(): - if name.startswith("_"): # pragma: no cover - continue - if not isinstance(class_, types.ModuleType): - continue - if not class_.__name__.startswith(module.__name__): # pragma: no cover - continue - if class_ is module: - continue - # We should rename the trio.tests module (#274), but until then we use - # a special-case hack: - if class_.__name__ == "trio.tests": - continue - yield from public_modules(class_) - - -PUBLIC_MODULES = list(public_modules(trio)) -PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES] - - -# It doesn't make sense for downstream redistributors to run this test, since -# they might be using a newer version of Python with additional symbols which -# won't be reflected in trio.socket, and this shouldn't cause downstream test -# runs to start failing. -@pytest.mark.redistributors_should_skip -# pylint/jedi often have trouble with alpha releases, where Python's internals -# are in flux, grammar may not have settled down, etc. -@pytest.mark.skipif( - sys.version_info.releaselevel == "alpha", - reason="skip static introspection tools on Python dev/alpha releases", -) -@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES) -@pytest.mark.parametrize("tool", ["pylint", "jedi"]) -@pytest.mark.filterwarnings( - # https://github.com/pypa/setuptools/issues/3274 - "ignore:module 'sre_constants' is deprecated:DeprecationWarning", -) -def test_static_tool_sees_all_symbols(tool, modname): - module = importlib.import_module(modname) - - def no_underscores(symbols): - return {symbol for symbol in symbols if not symbol.startswith("_")} - - runtime_names = no_underscores(dir(module)) - - # We should rename the trio.tests module (#274), but until then we use a - # special-case hack: - if modname == "trio": - runtime_names.remove("tests") - - if tool == "pylint": - from pylint.lint import PyLinter - - linter = PyLinter() - ast = linter.get_ast(module.__file__, modname) - static_names = no_underscores(ast) - elif tool == "jedi": - import jedi - - # Simulate typing "import trio; trio." 
- script = jedi.Script(f"import {modname}; {modname}.") - completions = script.complete() - static_names = no_underscores(c.name for c in completions) - else: # pragma: no cover - assert False - - # It's expected that the static set will contain more names than the - # runtime set: - # - static tools are sometimes sloppy and include deleted names - # - some symbols are platform-specific at runtime, but always show up in - # static analysis (e.g. in trio.socket or trio.lowlevel) - # So we check that the runtime names are a subset of the static names. - missing_names = runtime_names - static_names - if missing_names: # pragma: no cover - print(f"{tool} can't see the following names in {modname}:") - print() - for name in sorted(missing_names): - print(f" {name}") - assert False - - -def test_classes_are_final(): - for module in PUBLIC_MODULES: - for name, class_ in module.__dict__.items(): - if not isinstance(class_, type): - continue - # Deprecated classes are exported with a leading underscore - if name.startswith("_"): # pragma: no cover - continue - - # Abstract classes can be subclassed, because that's the whole - # point of ABCs - if inspect.isabstract(class_): - continue - # Exceptions are allowed to be subclassed, because exception - # subclassing isn't used to inherit behavior. - if issubclass(class_, BaseException): - continue - # These are classes that are conceptually abstract, but - # inspect.isabstract returns False for boring reasons. - if class_ in {trio.abc.Instrument, trio.socket.SocketType}: - continue - # Enums have their own metaclass, so we can't use our metaclasses. - # And I don't think there's a lot of risk from people subclassing - # enums... - if issubclass(class_, enum.Enum): - continue - # ... insert other special cases here ... - - assert isinstance(class_, _util.Final) diff --git a/trio/to_thread.py b/trio/to_thread.py index 6eec7b36c7..45ea5b480b 100644 --- a/trio/to_thread.py +++ b/trio/to_thread.py @@ -1,2 +1,4 @@ -from ._threads import to_thread_run_sync as run_sync -from ._threads import current_default_thread_limiter +from ._threads import current_default_thread_limiter, to_thread_run_sync as run_sync + +# need to use __all__ for pyright --verifytypes to see re-exports when renaming them +__all__ = ["current_default_thread_limiter", "run_sync"]
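Taken together with the `trio/testing/__init__.py` hunk earlier, this patch uses the two spellings that `pyright --verifytypes` treats as intentional re-exports. A condensed sketch of both, reusing symbols from the patch:

    # Style 1 (trio/testing/__init__.py): a redundant "as" alias marks the
    # name as a deliberate re-export without needing __all__.
    from trio._core import MockClock as MockClock

    # Style 2 (trio/to_thread.py): the alias trick can't express a *renamed*
    # export, so the public name goes in __all__ instead.
    from trio._threads import to_thread_run_sync as run_sync

    __all__ = ["run_sync"]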