Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dmypy inspect on Windows #16355

Merged
merged 2 commits into from Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 16 additions & 8 deletions mypy/inspections.py
Expand Up @@ -219,13 +219,6 @@ def __init__(
# Module for which inspection was requested.
self.module: State | None = None

def parse_location(self, location: str) -> tuple[str, list[int]]:
if location.count(":") not in [2, 4]:
raise ValueError("Format should be file:line:column[:end_line:end_column]")
parts = location.split(":")
module, *rest = parts
return module, [int(p) for p in rest]

def reload_module(self, state: State) -> None:
"""Reload given module while temporary exporting types."""
old = self.fg_manager.manager.options.export_types
Expand Down Expand Up @@ -581,7 +574,7 @@ def run_inspection(
This can be re-used by various simple inspections.
"""
try:
file, pos = self.parse_location(location)
file, pos = parse_location(location)
except ValueError as err:
return {"error": str(err)}

Expand Down Expand Up @@ -623,3 +616,18 @@ def get_definition(self, location: str) -> dict[str, object]:
result["out"] = f"No name or member expressions at {location}"
result["status"] = 1
return result


def parse_location(location: str) -> tuple[str, list[int]]:
if location.count(":") < 2:
raise ValueError("Format should be file:line:column[:end_line:end_column]")
parts = location.rsplit(":", maxsplit=2)
start, *rest = parts
# Note: we must allow drive prefix like `C:` on Windows.
if start.count(":") < 2:
return start, [int(p) for p in rest]
parts = start.rsplit(":", maxsplit=2)
start, *start_rest = parts
if start.count(":") < 2:
return start, [int(p) for p in start_rest + rest]
raise ValueError("Format should be file:line:column[:end_line:end_column]")
5 changes: 5 additions & 0 deletions mypy/test/testutil.py
Expand Up @@ -3,6 +3,7 @@
import os
from unittest import TestCase, mock

from mypy.inspections import parse_location
from mypy.util import get_terminal_width


Expand All @@ -15,3 +16,7 @@ def test_get_terminal_size_in_pty_defaults_to_80(self) -> None:
with mock.patch.object(os, "get_terminal_size", return_value=ret):
with mock.patch.dict(os.environ, values=mock_environ, clear=True):
assert get_terminal_width() == 80

def test_parse_location_windows(self) -> None:
assert parse_location(r"C:\test.py:1:1") == (r"C:\test.py", [1, 1])
assert parse_location(r"C:\test.py:1:1:1:1") == (r"C:\test.py", [1, 1, 1, 1])
3 changes: 3 additions & 0 deletions test-data/unit/daemon.test
Expand Up @@ -372,6 +372,9 @@ foo.py:3: error: Incompatible types in assignment (expression has type "str", va
$ dmypy inspect foo:1
Format should be file:line:column[:end_line:end_column]
== Return code: 2
$ dmypy inspect foo:1:2:3
Source file is not a Python file
== Return code: 2
$ dmypy inspect foo.py:1:2:a:b
invalid literal for int() with base 10: 'a'
== Return code: 2
Expand Down