diff --git a/mypy/dmypy_server.py b/mypy/dmypy_server.py index b4c3fe8fe0dcb..3d337eedbf1ca 100644 --- a/mypy/dmypy_server.py +++ b/mypy/dmypy_server.py @@ -221,8 +221,8 @@ def serve(self) -> None: while True: with server: data = receive(server) - sys.stdout = WriteToConn(server, "stdout") # type: ignore[assignment] - sys.stderr = WriteToConn(server, "stderr") # type: ignore[assignment] + sys.stdout = WriteToConn(server, "stdout", sys.stdout.isatty()) + sys.stderr = WriteToConn(server, "stderr", sys.stderr.isatty()) resp: dict[str, Any] = {} if "command" not in data: resp = {"error": "No command found in request"} diff --git a/mypy/dmypy_util.py b/mypy/dmypy_util.py index 6ddef3ecc0a4b..10df1020ed5d0 100644 --- a/mypy/dmypy_util.py +++ b/mypy/dmypy_util.py @@ -6,7 +6,8 @@ from __future__ import annotations import json -from typing import Any, Final, Iterable +from types import TracebackType +from typing import Any, Final, Iterable, Iterator, TextIO from mypy.ipc import IPCBase @@ -40,12 +41,66 @@ def send(connection: IPCBase, data: Any) -> None: connection.write(json.dumps(data)) -class WriteToConn: +class WriteToConn(TextIO): """Helper class to write to a connection instead of standard output.""" - def __init__(self, server: IPCBase, output_key: str) -> None: + def __init__(self, server: IPCBase, output_key: str, isatty: bool) -> None: self.server = server self.output_key = output_key + self._isatty = isatty + + def __enter__(self) -> TextIO: + return self + + def __exit__( + self, + t: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def __iter__(self) -> Iterator[str]: + raise NotImplementedError + + def __next__(self) -> str: + raise NotImplementedError + + def close(self) -> None: + pass + + def fileno(self) -> int: + raise OSError + + def flush(self) -> None: + pass + + def isatty(self) -> bool: + return self._isatty + + def read(self, n: int = 0) -> str: + raise NotImplementedError + + def readable(self) -> bool: + return False + + def readline(self, limit: int = 0) -> str: + raise NotImplementedError + + def readlines(self, hint: int = 0) -> list[str]: + raise NotImplementedError + + def seek(self, offset: int, whence: int = 0) -> int: + raise NotImplementedError + + def seekable(self) -> bool: + return False + + def tell(self) -> int: + raise NotImplementedError + + def truncate(self, size: int | None = 0) -> int: + raise NotImplementedError def write(self, output: str) -> int: resp: dict[str, Any] = {} @@ -53,6 +108,9 @@ def write(self, output: str) -> int: send(self.server, resp) return len(output) + def writable(self) -> bool: + return True + def writelines(self, lines: Iterable[str]) -> None: for s in lines: self.write(s)