Skip to content

Commit

Permalink
Add more typing conversions for pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
perklet committed Mar 17, 2024
1 parent c1e71bf commit d86353e
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion curl_cffi/_asyncio_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class _HasFileno(Protocol):
def fileno(self) -> int:
pass
return 0


_FileDescriptorLike = Union[int, _HasFileno]
Expand Down
4 changes: 1 addition & 3 deletions curl_cffi/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from typing import Any, Dict, Set
from weakref import WeakKeyDictionary, WeakSet

import cffi

from ._wrapper import ffi, lib
from .const import CurlMOpt
from .curl import DEFAULT_CACERT, Curl
Expand Down Expand Up @@ -134,7 +132,7 @@ def __init__(self, cacert: str = "", loop=None):
self._curlm = lib.curl_multi_init()
self._cacert = cacert or DEFAULT_CACERT
self._curl2future: Dict[Curl, asyncio.Future] = {} # curl to future map
self._curl2curl: Dict[cffi.CData, Curl] = {} # c curl to Curl
self._curl2curl: Dict[ffi.CData, Curl] = {} # c curl to Curl
self._sockfds: Set[int] = set() # sockfds
self.loop = _get_selector(loop if loop is not None else asyncio.get_running_loop())
self._checker = self.loop.create_task(self._force_timeout())
Expand Down
13 changes: 8 additions & 5 deletions curl_cffi/requests/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

HeaderTypes = Union[
"Headers",
Mapping[str, str],
Mapping[bytes, bytes],
Mapping[str, Optional[str]],
Mapping[bytes, Optional[bytes]],
Sequence[Tuple[str, str]],
Sequence[Tuple[bytes, bytes]],
Sequence[Union[str, bytes]],
Expand All @@ -35,7 +35,7 @@ def to_str(value: Union[str, bytes], encoding: str = "utf-8") -> str:


def to_bytes_or_str(value: str, match_type_of: AnyStr) -> AnyStr:
return value if isinstance(match_type_of, str) else value.encode()
return value if isinstance(match_type_of, str) else value.encode() # pyright: ignore [reportGeneralTypeIssues]


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
Expand Down Expand Up @@ -102,7 +102,10 @@ def __init__(
elif isinstance(headers, list):
if isinstance(headers[0], (str, bytes)):
sep = ":" if isinstance(headers[0], str) else b":"
h = [(k, v.lstrip()) for line in headers for k, v in [line.split(sep, maxsplit=1)]]
h = []
for line in headers:
k, v = line.split(sep, maxsplit=1) # pyright: ignore
h.append((k, v.strip()))
elif isinstance(headers[0], tuple):
h = headers
self._list = [
Expand All @@ -111,7 +114,7 @@ def __init__(
normalize_header_key(k, lower=True, encoding=encoding),
normalize_header_value(v, encoding),
)
for k, v in h
for k, v in h # pyright: ignore
]

self._encoding = encoding
Expand Down
19 changes: 14 additions & 5 deletions curl_cffi/requests/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import queue
import warnings
from concurrent.futures import Future
from json import loads
from typing import Any, Dict, List, Optional
from typing import Any, Awaitable, Dict, List, Optional

from .. import Curl
from .cookies import Cookies
Expand Down Expand Up @@ -65,7 +66,8 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non
self.history: List[Dict[str, Any]] = []
self.infos: Dict[str, Any] = {}
self.queue: Optional[queue.Queue] = None
self.stream_task = None
self.stream_task: Optional[Future] = None
self.astream_task: Optional[Awaitable] = None
self.quit_now = None

def _decode(self, content: bytes) -> str:
Expand Down Expand Up @@ -117,6 +119,9 @@ def iter_content(self, chunk_size=None, decode_unicode=False):
warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
if decode_unicode:
raise NotImplementedError()

assert self.queue and self.curl, "stream mode is not enabled."

while True:
chunk = self.queue.get()

Expand All @@ -133,11 +138,12 @@ def iter_content(self, chunk_size=None, decode_unicode=False):
yield chunk

def json(self, **kw):
"""return a prased json object of the content."""
"""return a parsed json object of the content."""
return loads(self.content, **kw)

def close(self):
"""Close the streaming connection, only valid in stream mode."""

if self.quit_now:
self.quit_now.set()
if self.stream_task:
Expand Down Expand Up @@ -179,6 +185,8 @@ async def aiter_content(self, chunk_size=None, decode_unicode=False):
if decode_unicode:
raise NotImplementedError()

assert self.queue and self.curl, "stream mode is not enabled."

while True:
chunk = await self.queue.get()

Expand Down Expand Up @@ -209,8 +217,9 @@ async def acontent(self) -> bytes:

async def aclose(self):
"""Close the streaming connection, only valid in stream mode."""
if self.stream_task:
await self.stream_task

if self.astream_task:
await self.astream_task

# It prints the status code of the response instead of
# the object's memory location.
Expand Down
30 changes: 15 additions & 15 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ def _set_curl_options(
username, password = self.auth
if auth:
username, password = auth
c.setopt(CurlOpt.USERNAME, username.encode())
c.setopt(CurlOpt.PASSWORD, password.encode())
c.setopt(CurlOpt.USERNAME, username.encode()) # pyright: ignore [reportPossiblyUnboundVariable=none]
c.setopt(CurlOpt.PASSWORD, password.encode()) # pyright: ignore [reportPossiblyUnboundVariable=none]

# timeout
if timeout is not_set:
Expand Down Expand Up @@ -813,12 +813,12 @@ def perform():
except CurlError as e:
rsp = self._parse_response(c, buffer, header_buffer)
rsp.request = req
q.put_nowait(RequestsError(str(e), e.code, rsp))
cast(queue.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp))
finally:
if not header_recved.is_set():
header_recved.set()
if not cast(threading.Event, header_recved).is_set():
cast(threading.Event, header_recved).set()
# None acts as a sentinel
q.put(None)
cast(queue.Queue, q).put(None)

def cleanup(fut):
header_parsed.wait()
Expand All @@ -828,12 +828,12 @@ def cleanup(fut):
stream_task.add_done_callback(cleanup)

# Wait for the first chunk
header_recved.wait()
cast(threading.Event, header_recved).wait()
rsp = self._parse_response(c, buffer, header_buffer)
header_parsed.set()

# Raise the exception if something wrong happens when receiving the header.
first_element = _peek_queue(q)
first_element = _peek_queue(cast(queue.Queue, q))
if isinstance(first_element, RequestsError):
c.reset()
raise first_element
Expand Down Expand Up @@ -1080,33 +1080,33 @@ async def perform():
except CurlError as e:
rsp = self._parse_response(curl, buffer, header_buffer)
rsp.request = req
q.put_nowait(RequestsError(str(e), e.code, rsp))
cast(asyncio.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp))
finally:
if not header_recved.is_set():
header_recved.set()
if not cast(asyncio.Event, header_recved).is_set():
cast(asyncio.Event, header_recved).set()
# None acts as a sentinel
await q.put(None)
await cast(asyncio.Queue, q).put(None)

def cleanup(fut):
self.release_curl(curl)

stream_task = asyncio.create_task(perform())
stream_task.add_done_callback(cleanup)

await header_recved.wait()
await cast(asyncio.Event, header_recved).wait()

# Unlike threads, coroutines does not use preemptive scheduling.
# For asyncio, there is no need for a header_parsed event, the
# _parse_response will execute in the foreground, no background tasks running.
rsp = self._parse_response(curl, buffer, header_buffer)

first_element = _peek_aio_queue(q)
first_element = _peek_aio_queue(cast(asyncio.Queue, q))
if isinstance(first_element, RequestsError):
self.release_curl(curl)
raise first_element

rsp.request = req
rsp.stream_task = stream_task
rsp.astream_task = stream_task
rsp.quit_now = quit_now
rsp.queue = q
return rsp
Expand Down
2 changes: 1 addition & 1 deletion examples/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

try:
# Python 3.10+
from contextlib import aclosing
from contextlib import aclosing # pyright: ignore
except ImportError:
from contextlib import asynccontextmanager

Expand Down
2 changes: 1 addition & 1 deletion tests/unittest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def __init__(self, port):
def run(self):
async def serve(port):
# GitHub actions only likes 127, not localhost, wtf...
async with websockets.serve(echo, "127.0.0.1", port):
async with websockets.serve(echo, "127.0.0.1", port): # pyright: ignore
await asyncio.Future() # run forever

asyncio.run(serve(self.port))
Expand Down
3 changes: 2 additions & 1 deletion tests/unittest/test_curl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import json
from io import BytesIO
from typing import cast

import pytest

Expand Down Expand Up @@ -309,7 +310,7 @@ def test_elapsed(server):
url = str(server.url)
c.setopt(CurlOpt.URL, url.encode())
c.perform()
assert c.getinfo(CurlInfo.TOTAL_TIME) > 0
assert cast(int, c.getinfo(CurlInfo.TOTAL_TIME)) > 0


def test_reason(server):
Expand Down
7 changes: 6 additions & 1 deletion tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from curl_cffi import CurlOpt, requests
from curl_cffi.const import CurlECode, CurlInfo
from curl_cffi.requests.errors import SessionClosed
from curl_cffi.requests.models import Response


def test_head(server):
Expand Down Expand Up @@ -190,6 +191,7 @@ def test_too_many_redirects(server):
with pytest.raises(requests.RequestsError) as e:
requests.get(str(server.url.copy_with(path="/redirect_loop")), max_redirects=2)
assert e.value.code == CurlECode.TOO_MANY_REDIRECTS
assert isinstance(e.value.response, Response)
assert e.value.response.status_code == 301


Expand Down Expand Up @@ -548,6 +550,7 @@ def test_stream_redirect_loop(server):
with s.stream("GET", url, max_redirects=2):
pass
assert e.value.code == CurlECode.TOO_MANY_REDIRECTS
assert isinstance(e.value.response, Response)
assert e.value.response.status_code == 301


Expand All @@ -559,6 +562,7 @@ def test_stream_redirect_loop_without_close(server):
s.get(url, max_redirects=2, stream=True)

assert e.value.code == CurlECode.TOO_MANY_REDIRECTS
assert isinstance(e.value.response, Response)
assert e.value.response.status_code == 301


Expand Down Expand Up @@ -588,6 +592,7 @@ def test_stream_auto_close_with_header_errors(server):
with pytest.raises(requests.RequestsError) as e:
s.get(url, max_redirects=2, stream=True)
assert e.value.code == CurlECode.TOO_MANY_REDIRECTS
assert isinstance(e.value.response, Response)
assert e.value.response.status_code == 301

url = str(server.url.copy_with(path="/"))
Expand Down Expand Up @@ -646,4 +651,4 @@ def test_curl_infos(server):

r = s.get(str(server.url))

assert r.infos[CurlInfo.PRIMARY_IP] == b"127.0.0.1"
assert r.infos[CurlInfo.PRIMARY_IP] == b"127.0.0.1" # pyright: ignore

0 comments on commit d86353e

Please sign in to comment.