Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor WebSocket support into separate sync/async implementations #206

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion curl_cffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._wrapper import ffi, lib # type: ignore

from .const import CurlInfo, CurlMOpt, CurlOpt, CurlECode, CurlHttpVersion
from .curl import Curl, CurlError
from .curl import Curl, CurlError, CurlWsFrame
from .aio import AsyncCurl

from .__version__ import __title__, __version__, __description__, __curl_version__
44 changes: 29 additions & 15 deletions curl_cffi/curl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
from __future__ import annotations

import ctypes
import re
import warnings
from http.cookies import SimpleCookie
from typing import Any, List, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Tuple, Union

import certifi

from ._wrapper import ffi, lib # type: ignore
from .const import CurlHttpVersion, CurlInfo, CurlOpt, CurlWsFlag


DEFAULT_CACERT = certifi.where()


class CurlWsFrame(ctypes.Structure):
_fields_ = [
("age", ctypes.c_int),
("flags", ctypes.c_int),
("offset", ctypes.c_uint64),
("bytesleft", ctypes.c_uint64),
("len", ctypes.c_size_t),
]

if TYPE_CHECKING:
age: int
flags: int
offset: int
bytesleft: int
len: int


class CurlError(Exception):
"""Base exception for curl_cffi package"""

Expand Down Expand Up @@ -50,11 +71,13 @@ def buffer_callback(ptr, size, nmemb, userdata):
buffer.write(ffi.buffer(ptr, nmemb)[:])
return nmemb * size


def ensure_int(s):
if not s:
return 0
return int(s)


@ffi.def_extern()
def write_callback(ptr, size, nmemb, userdata):
# although similar enough to the function above, kept here for performance reasons
Expand Down Expand Up @@ -85,7 +108,7 @@ class Curl:
Wrapper for `curl_easy_*` functions of libcurl.
"""

def __init__(self, cacert: str = DEFAULT_CACERT, debug: bool = False, handle = None):
def __init__(self, cacert: str = DEFAULT_CACERT, debug: bool = False, handle=None):
"""
Parameters:
cacert: CA cert path to use, by default, curl_cffi uses its own bundled cert.
Expand Down Expand Up @@ -159,15 +182,11 @@ def setopt(self, option: CurlOpt, value: Any):
elif option == CurlOpt.WRITEDATA:
c_value = ffi.new_handle(value)
self._write_handle = c_value
lib._curl_easy_setopt(
self._curl, CurlOpt.WRITEFUNCTION, lib.buffer_callback
)
lib._curl_easy_setopt(self._curl, CurlOpt.WRITEFUNCTION, lib.buffer_callback)
elif option == CurlOpt.HEADERDATA:
c_value = ffi.new_handle(value)
self._header_handle = c_value
lib._curl_easy_setopt(
self._curl, CurlOpt.HEADERFUNCTION, lib.buffer_callback
)
lib._curl_easy_setopt(self._curl, CurlOpt.HEADERFUNCTION, lib.buffer_callback)
elif option == CurlOpt.WRITEFUNCTION:
c_value = ffi.new_handle(value)
self._write_handle = c_value
Expand Down Expand Up @@ -246,9 +265,7 @@ def impersonate(self, target: str, default_headers: bool = True) -> int:
target: browser to impersonate.
default_headers: whether to add default headers, like User-Agent.
"""
return lib.curl_easy_impersonate(
self._curl, target.encode(), int(default_headers)
)
return lib.curl_easy_impersonate(self._curl, target.encode(), int(default_headers))

def _ensure_cacert(self):
if not self._is_cert_set:
Expand Down Expand Up @@ -346,7 +363,7 @@ def close(self):
ffi.release(self._error_buffer)
self._resolve = ffi.NULL

def ws_recv(self, n: int = 1024):
def ws_recv(self, n: int = 1024) -> Tuple[bytes, CurlWsFrame]:
buffer = ffi.new("char[]", n)
n_recv = ffi.new("int *")
p_frame = ffi.new("struct curl_ws_frame **")
Expand All @@ -365,6 +382,3 @@ def ws_send(self, payload: bytes, flags: CurlWsFlag = CurlWsFlag.BINARY) -> int:
ret = lib.curl_ws_send(self._curl, buffer, len(buffer), n_sent, 0, flags)
self._check_error(ret, "WS_SEND")
return n_sent[0]

def ws_close(self):
self.ws_send(b"", CurlWsFlag.CLOSE)
9 changes: 5 additions & 4 deletions curl_cffi/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"Headers",
"Request",
"Response",
"AsyncWebSocket",
"WebSocket",
"WebSocketError",
"WsCloseCode",
Expand All @@ -27,11 +28,11 @@

from ..const import CurlHttpVersion, CurlWsFlag
from .cookies import Cookies, CookieTypes
from .models import Request, Response
from .models import BrowserType, Request, Response
from .errors import RequestsError
from .headers import Headers, HeaderTypes
from .session import AsyncSession, BrowserType, Session, ProxySpec
from .websockets import WebSocket, WebSocketError, WsCloseCode
from .session import AsyncSession, Session, ProxySpec
from .websockets import AsyncWebSocket, WebSocket, WebSocketError, WsCloseCode

# ThreadType = Literal["eventlet", "gevent", None]

Expand All @@ -52,7 +53,7 @@ def request(
proxies: Optional[ProxySpec] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[Tuple[str, str]] = None,
verify: Optional[bool] = None,
verify: Optional[Union[bool, str]] = None,
referer: Optional[str] = None,
accept_encoding: Optional[str] = "gzip, deflate, br",
content_callback: Optional[Callable] = None,
Expand Down
38 changes: 31 additions & 7 deletions curl_cffi/requests/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import warnings
from enum import Enum
from json import loads
from typing import Optional
import queue
import warnings

from .. import Curl
from .headers import Headers
Expand All @@ -16,6 +17,33 @@ def clear_queue(q: queue.Queue):
q.unfinished_tasks = 0


class BrowserType(str, Enum):
edge99 = "edge99"
edge101 = "edge101"
chrome99 = "chrome99"
chrome100 = "chrome100"
chrome101 = "chrome101"
chrome104 = "chrome104"
chrome107 = "chrome107"
chrome110 = "chrome110"
chrome116 = "chrome116"
chrome119 = "chrome119"
chrome120 = "chrome120"
chrome99_android = "chrome99_android"
safari15_3 = "safari15_3"
safari15_5 = "safari15_5"
safari17_0 = "safari17_0"
safari17_2_ios = "safari17_2_ios"

chrome = "chrome120"
safari = "safari17_0"
safari_ios = "safari17_2_ios"

@classmethod
def has(cls, item):
return item in cls.__members__


class Request:
def __init__(self, url: str, headers: Headers, method: str):
self.url = url
Expand Down Expand Up @@ -86,9 +114,7 @@ def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
"""
pending = None

for chunk in self.iter_content(
chunk_size=chunk_size, decode_unicode=decode_unicode
):
for chunk in self.iter_content(chunk_size=chunk_size, decode_unicode=decode_unicode):
if pending is not None:
chunk = pending + chunk
if delimiter:
Expand Down Expand Up @@ -139,9 +165,7 @@ async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=Non
"""
pending = None

async for chunk in self.aiter_content(
chunk_size=chunk_size, decode_unicode=decode_unicode
):
async for chunk in self.aiter_content(chunk_size=chunk_size, decode_unicode=decode_unicode):
if pending is not None:
chunk = pending + chunk
if delimiter:
Expand Down
Loading