Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Py typing #15446

Open
wants to merge 3 commits into
base: trunk
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 56 additions & 39 deletions py/selenium/webdriver/common/options.py
Original file line number Diff line number Diff line change
@@ -19,9 +19,9 @@
from abc import ABCMeta
from abc import abstractmethod
from enum import Enum
from typing import Any
from typing import Optional

from selenium.common.exceptions import InvalidArgumentException
from selenium.webdriver.common.proxy import Proxy


@@ -42,23 +42,28 @@ class PageLoadStrategy(str, Enum):


class _BaseOptionsDescriptor:
def __init__(self, name):
def __init__(self, name: str):
self.name = name

def __get__(self, obj, cls):
def __get__(self, obj: object, cls: type[object]):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
if self.name == "enableBidi":
# whether BiDi is or will be enabled
value = obj._caps.get("webSocketUrl")
value = obj.capabilities.get("webSocketUrl")
return value is True or isinstance(value, str)
if self.name == "webSocketUrl":
# Return socket url or None if not created yet
value = obj._caps.get(self.name)
value = obj.capabilities.get(self.name)
return None if not isinstance(value, str) else value
if self.name in ("acceptInsecureCerts", "strictFileInteractability", "setWindowRect", "se:downloadsEnabled"):
return obj._caps.get(self.name, False)
return obj._caps.get(self.name)
return obj.capabilities.get(self.name, False)
return obj.capabilities.get(self.name)

def __set__(self, obj: object, value: Any):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")

def __set__(self, obj, value):
if self.name == "enableBidi":
obj.set_capability("webSocketUrl", value)
else:
@@ -72,17 +77,20 @@ class _PageLoadStrategyDescriptor:
:param strategy: the strategy corresponding to a document readiness state
"""

def __init__(self, name):
def __init__(self, name: str):
self.name = name

def __get__(self, obj, cls):
return obj._caps.get(self.name)
def __get__(self, obj: object, cls: type[object]):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
return obj.capabilities.get(self.name)

def __set__(self, obj, value):
if value in ("normal", "eager", "none"):
obj.set_capability(self.name, value)
else:
def __set__(self, obj: object, value: str):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
if value not in ("normal", "eager", "none"):
raise ValueError("Strategy can only be one of the following: normal, eager, none")
obj.set_capability(self.name, value)


class _UnHandledPromptBehaviorDescriptor:
@@ -95,20 +103,23 @@ class _UnHandledPromptBehaviorDescriptor:
:returns: Values for implicit timeout, pageLoad timeout and script timeout if set (in milliseconds)
"""

def __init__(self, name):
def __init__(self, name: str):
self.name = name

def __get__(self, obj, cls):
return obj._caps.get(self.name)
def __get__(self, obj: object, cls: type[object]):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
return obj.capabilities.get(self.name)

def __set__(self, obj, value):
if value in ("dismiss", "accept", "dismiss and notify", "accept and notify", "ignore"):
obj.set_capability(self.name, value)
else:
def __set__(self, obj: object, value: str):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
if value not in ("dismiss", "accept", "dismiss and notify", "accept and notify", "ignore"):
raise ValueError(
"Behavior can only be one of the following: dismiss, accept, dismiss and notify, "
"accept and notify, ignore"
)
obj.set_capability(self.name, value)


class _TimeoutsDescriptor:
@@ -120,13 +131,17 @@ class _TimeoutsDescriptor:
:returns: Values for implicit timeout, pageLoad timeout and script timeout if set (in milliseconds)
"""

def __init__(self, name):
def __init__(self, name: str):
self.name = name

def __get__(self, obj, cls):
return obj._caps.get(self.name)
def __get__(self, obj: object, cls: type[object]):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
return obj.capabilities.get(self.name)

def __set__(self, obj, value):
def __set__(self, obj: object, value: dict[str, Any]):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
if all(x in ("implicit", "pageLoad", "script") for x in value.keys()):
obj.set_capability(self.name, value)
else:
@@ -136,17 +151,19 @@ def __set__(self, obj, value):
class _ProxyDescriptor:
""":Returns: Proxy if set, otherwise None."""

def __init__(self, name):
def __init__(self, name: str):
self.name = name

def __get__(self, obj, cls):
def __get__(self, obj: object, cls: type[object]):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
return obj._proxy

def __set__(self, obj, value):
if not isinstance(value, Proxy):
raise InvalidArgumentException("Only Proxy objects can be passed in.")
def __set__(self, obj: object, value: Proxy):
if not isinstance(obj, BaseOptions):
raise ValueError("Invalid object: Expected an instance of BaseOptions.")
obj._proxy = value
obj._caps[self.name] = value.to_capabilities()
obj.capabilities[self.name] = value.to_capabilities()


class BaseOptions(metaclass=ABCMeta):
@@ -443,7 +460,7 @@ class BaseOptions(metaclass=ABCMeta):
def __init__(self) -> None:
super().__init__()
self._caps = self.default_capabilities
self._proxy = None
self._proxy: Proxy | None = None
self.set_capability("pageLoadStrategy", PageLoadStrategy.normal)
self.mobile_options = None
self._ignore_local_proxy = False
@@ -452,7 +469,7 @@ def __init__(self) -> None:
def capabilities(self):
return self._caps

def set_capability(self, name, value) -> None:
def set_capability(self, name: Any, value: Any) -> None:
"""Sets a capability."""
self._caps[name] = value

@@ -476,12 +493,12 @@ def enable_mobile(
self.mobile_options["androidDeviceSerial"] = device_serial

@abstractmethod
def to_capabilities(self):
def to_capabilities(self) -> dict[Any, Any]:
"""Convert options into capabilities dictionary."""

@property
@abstractmethod
def default_capabilities(self):
def default_capabilities(self) -> dict[Any, Any]:
"""Return minimal capabilities necessary as a dictionary."""

def ignore_local_proxy_environment_variables(self) -> None:
@@ -497,14 +514,14 @@ class ArgOptions(BaseOptions):

def __init__(self) -> None:
super().__init__()
self._arguments = []
self._arguments: list[Any] = []

@property
def arguments(self):
""":Returns: A list of arguments needed for the browser."""
return self._arguments

def add_argument(self, argument) -> None:
def add_argument(self, argument: Any) -> None:
"""Adds an argument to the list.

:Args:
@@ -533,5 +550,5 @@ def to_capabilities(self):
return self._caps

@property
def default_capabilities(self):
def default_capabilities(self) -> dict[Any, Any]:
return {}
29 changes: 18 additions & 11 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,9 @@
import string
import warnings
from base64 import b64encode
from typing import Any
from typing import Optional
from typing import TypeVar
from urllib import parse
from urllib.parse import urlparse

@@ -35,6 +37,11 @@

LOGGER = logging.getLogger(__name__)

# TODO: Replace with 'Self' when Python 3.11+ is supported.
# from typing import Self

RemoteConnectionType = TypeVar("RemoteConnectionType", bound="RemoteConnection")

remote_commands = {
Command.NEW_SESSION: ("POST", "/session"),
Command.QUIT: ("DELETE", "/session/$sessionId"),
@@ -158,7 +165,7 @@ class RemoteConnection:
else socket.getdefaulttimeout()
)
_ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()
_client_config: ClientConfig = None
_client_config: ClientConfig | None = None

system = platform.system().lower()
if system == "darwin":
@@ -169,7 +176,7 @@ class RemoteConnection:
user_agent = f"selenium/{__version__} (python {system})"

@classmethod
def get_timeout(cls):
def get_timeout(cls) -> float | int | None:
""":Returns:

Timeout value in seconds for all http requests made to the
@@ -183,7 +190,7 @@ def get_timeout(cls):
return cls._client_config.timeout

@classmethod
def set_timeout(cls, timeout):
def set_timeout(cls, timeout: int | float):
"""Override the default timeout.

:Args:
@@ -207,7 +214,7 @@ def reset_timeout(cls):
cls._client_config.reset_timeout()

@classmethod
def get_certificate_bundle_path(cls):
def get_certificate_bundle_path(cls) -> str:
""":Returns:

Paths of the .pem encoded certificate to verify connection to
@@ -222,7 +229,7 @@ def get_certificate_bundle_path(cls):
return cls._client_config.ca_certs

@classmethod
def set_certificate_bundle_path(cls, path):
def set_certificate_bundle_path(cls, path: str):
"""Set the path to the certificate bundle to verify connection to
command executor. Can also be set to None to disable certificate
validation.
@@ -238,7 +245,7 @@ def set_certificate_bundle_path(cls, path):
cls._client_config.ca_certs = path

@classmethod
def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
def get_remote_connection_headers(cls, parsed_url: str, keep_alive: bool = False) -> dict[str, Any]:
"""Get headers for remote request.

:Args:
@@ -309,7 +316,7 @@ def __init__(
keep_alive: Optional[bool] = True,
ignore_proxy: Optional[bool] = False,
ignore_certificates: Optional[bool] = False,
init_args_for_pool_manager: Optional[dict] = None,
init_args_for_pool_manager: Optional[dict[Any, Any]] = None,
client_config: Optional[ClientConfig] = None,
):
self._client_config = client_config or ClientConfig(
@@ -370,15 +377,15 @@ def __init__(

extra_commands = {}

def add_command(self, name, method, url):
def add_command(self, name: str, method: str, url: str):
"""Register a new command."""
self._commands[name] = (method, url)

def get_command(self, name: str):
"""Retrieve a command if it exists."""
return self._commands.get(name)

def execute(self, command, params):
def execute(self, command: str, params: dict[Any, Any]) -> dict[str, Any]:
"""Send a command to the remote server.

Any path substitutions required for the URL mapped to the command should be
@@ -403,7 +410,7 @@ def execute(self, command, params):
LOGGER.debug("%s %s %s", command_info[0], url, str(trimmed))
return self._request(command_info[0], url, body=data)

def _request(self, method, url, body=None):
def _request(self, method: str, url: str, body: str | None = None) -> dict[Any, Any]:
"""Send an HTTP request to the remote server.

:Args:
@@ -470,7 +477,7 @@ def close(self):
if hasattr(self, "_conn"):
self._conn.clear()

def _trim_large_entries(self, input_dict, max_length=100):
def _trim_large_entries(self, input_dict: dict[Any, Any], max_length: int = 100) -> dict[str, str]:
"""Truncate string values in a dictionary if they exceed max_length.

:param dict: Dictionary with potentially large values
4 changes: 2 additions & 2 deletions py/selenium/webdriver/remote/script_key.py
Original file line number Diff line number Diff line change
@@ -19,14 +19,14 @@


class ScriptKey:
def __init__(self, id=None):
def __init__(self, id: uuid.UUID | str | None = None):
self._id = id or uuid.uuid4()

@property
def id(self):
return self._id

def __eq__(self, other):
def __eq__(self, other: object):
return self._id == other

def __repr__(self) -> str:
Loading
Oops, something went wrong.