From e07eb44d5498a312953cf7db41328cc0dfa723d1 Mon Sep 17 00:00:00 2001 From: Carlos Villavicencio Date: Tue, 28 Oct 2025 15:51:51 -0500 Subject: [PATCH 1/3] Add type annotations (from #393) --- setup.py | 6 +- shotgun_api3/lib/mockgun/mockgun.py | 5 +- shotgun_api3/lib/mockgun/schema.py | 2 +- shotgun_api3/py.typed | 0 shotgun_api3/shotgun.py | 585 +++++++++++++++++----------- 5 files changed, 372 insertions(+), 226 deletions(-) create mode 100644 shotgun_api3/py.typed diff --git a/setup.py b/setup.py index 8d903f5f1..0ddda9d79 100644 --- a/setup.py +++ b/setup.py @@ -30,17 +30,17 @@ packages=find_packages(exclude=("tests",)), script_args=sys.argv[1:], include_package_data=True, - package_data={"": ["cacerts.txt", "cacert.pem"]}, + package_data={"": ["cacerts.txt", "cacert.pem", "py.typed"]}, zip_safe=False, - python_requires=">=3.7.0", + python_requires=">=3.9.0", classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Programming Language :: Python", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Operating System :: OS Independent", ], ) diff --git a/shotgun_api3/lib/mockgun/mockgun.py b/shotgun_api3/lib/mockgun/mockgun.py index 45d0b2aa5..522e162d9 100644 --- a/shotgun_api3/lib/mockgun/mockgun.py +++ b/shotgun_api3/lib/mockgun/mockgun.py @@ -115,6 +115,7 @@ """ import datetime +from typing import Any from ... import ShotgunError from ...shotgun import _Config @@ -580,7 +581,7 @@ def _get_new_row(self, entity_type): row[field] = default_value return row - def _compare(self, field_type, lval, operator, rval): + def _compare(self, field_type: str, lval: Any, operator: str, rval: Any) -> bool: """ Compares a field using the operator and value provide by the filter. @@ -797,7 +798,7 @@ def _row_matches_filter(self, entity_type, row, sg_filter, retired_only): return self._compare(field_type, lval, operator, rval) - def _rearrange_filters(self, filters): + def _rearrange_filters(self, filters: list) -> None: """ Modifies the filter syntax to turn it into a list of three items regardless of the actual filter. Most of the filters are list of three elements, so this doesn't change much. diff --git a/shotgun_api3/lib/mockgun/schema.py b/shotgun_api3/lib/mockgun/schema.py index ab671629d..f5d9312cc 100644 --- a/shotgun_api3/lib/mockgun/schema.py +++ b/shotgun_api3/lib/mockgun/schema.py @@ -47,7 +47,7 @@ class SchemaFactory(object): _schema_cache_path = None @classmethod - def get_schemas(cls, schema_path, schema_entity_path): + def get_schemas(cls, schema_path: str, schema_entity_path: str) -> tuple: """ Retrieves the schemas from disk. diff --git a/shotgun_api3/py.typed b/shotgun_api3/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index 6400ff373..5f98e781b 100644 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -50,6 +50,18 @@ import urllib.request import uuid # used for attachment upload import xml.etree.ElementTree +from typing import ( + Any, + BinaryIO, + Iterable, + Literal, + NoReturn, + Optional, + TypedDict, + TypeVar, + Union, + TYPE_CHECKING, +) # Import Error and ResponseError (even though they're unused in this file) since they need # to be exposed as part of the API. @@ -83,6 +95,30 @@ # Version __version__ = "3.9.0" + +# ---------------------------------------------------------------------------- +# Types + + +T = TypeVar("T") + + +class OrderItem(TypedDict): + field_name: str + direction: str + + +class GroupingItem(TypedDict): + field: str + type: str + direction: str + + +class BaseEntity(TypedDict, total=False): + id: int + type: str + + # ---------------------------------------------------------------------------- # Errors @@ -168,7 +204,7 @@ class ServerCapabilities(object): the future. Therefore, usage of this class is discouraged. """ - def __init__(self, host, meta): + def __init__(self, host: str, meta: dict[str, Any]) -> None: """ ServerCapabilities.__init__ @@ -208,14 +244,14 @@ def __init__(self, host, meta): self.version = tuple(self.version[:3]) self._ensure_json_supported() - def _ensure_python_version_supported(self): + def _ensure_python_version_supported(self) -> None: """ Checks the if current Python version is supported. """ if sys.version_info < (3, 7): raise ShotgunError("This module requires Python version 3.7 or higher.") - def _ensure_support(self, feature, raise_hell=True): + def _ensure_support(self, feature: dict[str, Any], raise_hell: bool = True) -> bool: """ Checks the server version supports a given feature, raises an exception if it does not. @@ -243,13 +279,13 @@ def _ensure_support(self, feature, raise_hell=True): else: return True - def _ensure_json_supported(self): + def _ensure_json_supported(self) -> None: """ Ensures server has support for JSON API endpoint added in v2.4.0. """ self._ensure_support({"version": (2, 4, 0), "label": "JSON API"}) - def ensure_include_archived_projects(self): + def ensure_include_archived_projects(self) -> None: """ Ensures server has support for archived Projects feature added in v5.3.14. """ @@ -257,7 +293,7 @@ def ensure_include_archived_projects(self): {"version": (5, 3, 14), "label": "include_archived_projects parameter"} ) - def ensure_per_project_customization(self): + def ensure_per_project_customization(self) -> bool: """ Ensures server has support for per-project customization feature added in v5.4.4. """ @@ -265,7 +301,7 @@ def ensure_per_project_customization(self): {"version": (5, 4, 4), "label": "project parameter"}, True ) - def ensure_support_for_additional_filter_presets(self): + def ensure_support_for_additional_filter_presets(self) -> bool: """ Ensures server has support for additional filter presets feature added in v7.0.0. """ @@ -273,7 +309,7 @@ def ensure_support_for_additional_filter_presets(self): {"version": (7, 0, 0), "label": "additional_filter_presets parameter"}, True ) - def ensure_user_following_support(self): + def ensure_user_following_support(self) -> bool: """ Ensures server has support for listing items a user is following, added in v7.0.12. """ @@ -281,7 +317,7 @@ def ensure_user_following_support(self): {"version": (7, 0, 12), "label": "user_following parameter"}, True ) - def ensure_paging_info_without_counts_support(self): + def ensure_paging_info_without_counts_support(self) -> bool: """ Ensures server has support for optimized pagination, added in v7.4.0. """ @@ -289,7 +325,7 @@ def ensure_paging_info_without_counts_support(self): {"version": (7, 4, 0), "label": "optimized pagination"}, False ) - def ensure_return_image_urls_support(self): + def ensure_return_image_urls_support(self) -> bool: """ Ensures server has support for returning thumbnail URLs without additional round-trips, added in v3.3.0. """ @@ -297,7 +333,7 @@ def ensure_return_image_urls_support(self): {"version": (3, 3, 0), "label": "return thumbnail URLs"}, False ) - def __str__(self): + def __str__(self) -> str: return "ServerCapabilities: host %s, version %s, is_dev %s" % ( self.host, self.version, @@ -355,7 +391,7 @@ class _Config(object): Container for the client configuration. """ - def __init__(self, sg): + def __init__(self, sg: "Shotgun"): """ :param sg: Shotgun connection. """ @@ -376,41 +412,41 @@ def __init__(self, sg): # If the optional timeout parameter is given, blocking operations # (like connection attempts) will timeout after that many seconds # (if it is not given, the global default timeout setting is used) - self.timeout_secs = None + self.timeout_secs: Optional[float] = None self.api_ver = "api3" self.convert_datetimes_to_utc = True - self._records_per_page = None - self.api_key = None - self.script_name = None - self.user_login = None - self.user_password = None - self.auth_token = None - self.sudo_as_login = None + self._records_per_page: Optional[int] = None + self.api_key: Optional[str] = None + self.script_name: Optional[str] = None + self.user_login: Optional[str] = None + self.user_password: Optional[str] = None + self.auth_token: Optional[str] = None + self.sudo_as_login: Optional[str] = None # Authentication parameters to be folded into final auth_params dict - self.extra_auth_params = None + self.extra_auth_params: Optional[dict[str, Any]] = None # uuid as a string - self.session_uuid = None - self.scheme = None - self.server = None - self.api_path = None + self.session_uuid: Optional[str] = None + self.scheme: Optional[str] = None + self.server: Optional[str] = None + self.api_path: Optional[str] = None # The raw_http_proxy reflects the exact string passed in # to the Shotgun constructor. This can be useful if you # need to construct a Shotgun API instance based on # another Shotgun API instance. - self.raw_http_proxy = None + self.raw_http_proxy: Optional[str] = None # if a proxy server is being used, the proxy_handler # below will contain a urllib2.ProxyHandler instance # which can be used whenever a request needs to be made. - self.proxy_handler = None - self.proxy_server = None + self.proxy_handler: Optional["urllib.request.ProxyHandler"] = None + self.proxy_server: Optional[str] = None self.proxy_port = 8080 - self.proxy_user = None - self.proxy_pass = None - self.session_token = None - self.authorization = None + self.proxy_user: Optional[str] = None + self.proxy_pass: Optional[str] = None + self.session_token: Optional[str] = None + self.authorization: Optional[str] = None self.localized = False - def set_server_params(self, base_url): + def set_server_params(self, base_url: str) -> None: """ Set the different server related fields based on the passed in URL. @@ -432,7 +468,7 @@ def set_server_params(self, base_url): ) @property - def records_per_page(self): + def records_per_page(self) -> int: """ The records per page value from the server. """ @@ -465,19 +501,19 @@ class Shotgun(object): def __init__( self, - base_url, - script_name=None, - api_key=None, - convert_datetimes_to_utc=True, - http_proxy=None, - connect=True, - ca_certs=None, - login=None, - password=None, - sudo_as_login=None, - session_token=None, - auth_token=None, - ): + base_url: str, + script_name: Optional[str] = None, + api_key: Optional[str] = None, + convert_datetimes_to_utc: bool = True, + http_proxy: Optional[str] = None, + connect: bool = True, + ca_certs: Optional[str] = None, + login: Optional[str] = None, + password: Optional[str] = None, + sudo_as_login: Optional[str] = None, + session_token: Optional[str] = None, + auth_token: Optional[str] = None, + ) -> None: """ Initializes a new instance of the Shotgun client. @@ -589,7 +625,7 @@ def __init__( "must provide login/password, session_token or script_name/api_key" ) - self.config = _Config(self) + self.config: _Config = _Config(self) self.config.api_key = api_key self.config.script_name = script_name self.config.user_login = login @@ -625,7 +661,7 @@ def __init__( ): SHOTGUN_API_DISABLE_ENTITY_OPTIMIZATION = True - self._connection = None + self._connection: Optional[Http] = None self.__ca_certs = self._get_certs_file(ca_certs) @@ -690,7 +726,7 @@ def __init__( # this relies on self.client_caps being set first self.reset_user_agent() - self._server_caps = None + self._server_caps: Optional[ServerCapabilities] = None # test to ensure the the server supports the json API # call to server will only be made once and will raise error if connect: @@ -704,7 +740,7 @@ def __init__( self.config.user_password = None self.config.auth_token = None - def _split_url(self, base_url): + def _split_url(self, base_url: str) -> tuple[str, str]: """ Extract the hostname:port and username/password/token from base_url sent when connect to the API. @@ -736,7 +772,7 @@ def _split_url(self, base_url): # API Functions @property - def server_info(self): + def server_info(self) -> dict[str, Any]: """ Property containing server information. @@ -754,7 +790,7 @@ def server_info(self): return self.server_caps.server_info @property - def server_caps(self): + def server_caps(self) -> ServerCapabilities: """ Property containing :class:`ServerCapabilities` object. @@ -769,7 +805,7 @@ def server_caps(self): self._server_caps = ServerCapabilities(self.config.server, self.info()) return self._server_caps - def connect(self): + def connect(self) -> None: """ Connect client to the server if it is not already connected. @@ -780,7 +816,7 @@ def connect(self): self.info() return - def close(self): + def close(self) -> None: """ Close the current connection to the server. @@ -789,7 +825,7 @@ def close(self): self._close_connection() return - def info(self): + def info(self) -> dict[str, Any]: """ Get API-related metadata from the Shotgun server. @@ -822,15 +858,15 @@ def info(self): def find_one( self, - entity_type, - filters, - fields=None, - order=None, - filter_operator=None, - retired_only=False, - include_archived_projects=True, - additional_filter_presets=None, - ): + entity_type: str, + filters: Union[list, tuple, dict[str, Any]], + fields: Optional[list[str]] = None, + order: Optional[list[OrderItem]] = None, + filter_operator: Optional[Literal["all", "any"]] = None, + retired_only: bool = False, + include_archived_projects: bool = True, + additional_filter_presets: Optional[list[dict[str, Any]]] = None, + ) -> Optional[BaseEntity]: """ Shortcut for :meth:`~shotgun_api3.Shotgun.find` with ``limit=1`` so it returns a single result. @@ -845,7 +881,7 @@ def find_one( :param list fields: Optional list of fields to include in each entity record returned. Defaults to ``["id"]``. - :param int order: Optional list of fields to order the results by. List has the format:: + :param list order: Optional list of fields to order the results by. List has the format:: [ {'field_name':'foo', 'direction':'asc'}, @@ -862,7 +898,7 @@ def find_one( same query. :param bool include_archived_projects: Optional boolean flag to include entities whose projects have been archived. Defaults to ``True``. - :param additional_filter_presets: Optional list of presets to further filter the result + :param list additional_filter_presets: Optional list of presets to further filter the result set, list has the form:: [{ @@ -902,17 +938,17 @@ def find_one( def find( self, - entity_type, - filters, - fields=None, - order=None, - filter_operator=None, - limit=0, - retired_only=False, - page=0, - include_archived_projects=True, - additional_filter_presets=None, - ): + entity_type: str, + filters: Union[list, tuple, dict[str, Any]], + fields: Optional[list[str]] = None, + order: Optional[list[OrderItem]] = None, + filter_operator: Optional[Literal["all", "any"]] = None, + limit: int = 0, + retired_only: bool = False, + page: int = 0, + include_archived_projects: bool = True, + additional_filter_presets: Optional[list[dict[str, Any]]] = None, + ) -> list[BaseEntity]: """ Find entities matching the given filters. @@ -990,7 +1026,7 @@ def find( same query. :param bool include_archived_projects: Optional boolean flag to include entities whose projects have been archived. Defaults to ``True``. - :param additional_filter_presets: Optional list of presets to further filter the result + :param list additional_filter_presets: Optional list of presets to further filter the result set, list has the form:: [{ @@ -1101,15 +1137,15 @@ def find( def _construct_read_parameters( self, - entity_type, - fields, - filters, - retired_only, - order, - include_archived_projects, - additional_filter_presets, - ): - params = {} + entity_type: str, + fields: Optional[list[str]], + filters: dict[str, Any], + retired_only: bool, + order: Optional[list[dict[str, Any]]], + include_archived_projects: bool, + additional_filter_presets: Optional[list[dict[str, Any]]], + ) -> dict[str, Any]: + params: dict[str, Any] = {} params["type"] = entity_type params["return_fields"] = fields or ["id"] params["filters"] = filters @@ -1139,7 +1175,9 @@ def _construct_read_parameters( params["sorts"] = sort_list return params - def _add_project_param(self, params, project_entity): + def _add_project_param( + self, params: dict[str, Any], project_entity + ) -> dict[str, Any]: if project_entity and self.server_caps.ensure_per_project_customization(): params["project"] = project_entity @@ -1147,8 +1185,12 @@ def _add_project_param(self, params, project_entity): return params def _translate_update_params( - self, entity_type, entity_id, data, multi_entity_update_modes - ): + self, + entity_type: str, + entity_id: int, + data: dict, + multi_entity_update_modes: Optional[dict], + ) -> dict[str, Any]: global SHOTGUN_API_DISABLE_ENTITY_OPTIMIZATION def optimize_field(field_dict): @@ -1170,13 +1212,13 @@ def optimize_field(field_dict): def summarize( self, - entity_type, - filters, - summary_fields, - filter_operator=None, - grouping=None, - include_archived_projects=True, - ): + entity_type: str, + filters: Union[list, dict[str, Any]], + summary_fields: list[dict[str, str]], + filter_operator: Optional[str] = None, + grouping: Optional[list[GroupingItem]] = None, + include_archived_projects: bool = True, + ) -> dict[str, Any]: """ Summarize field data returned by a query. @@ -1376,7 +1418,12 @@ def summarize( records = self._call_rpc("summarize", params) return records - def create(self, entity_type, data, return_fields=None): + def create( + self, + entity_type: str, + data: dict[str, Any], + return_fields: Optional[list] = None, + ) -> dict[str, Any]: """ Create a new entity of the specified ``entity_type``. @@ -1459,7 +1506,13 @@ def create(self, entity_type, data, return_fields=None): return result - def update(self, entity_type, entity_id, data, multi_entity_update_modes=None): + def update( + self, + entity_type: str, + entity_id: int, + data: dict[str, Any], + multi_entity_update_modes: Optional[dict[str, Any]] = None, + ) -> BaseEntity: """ Update the specified entity with the supplied data. @@ -1538,7 +1591,7 @@ def update(self, entity_type, entity_id, data, multi_entity_update_modes=None): return result - def delete(self, entity_type, entity_id): + def delete(self, entity_type: str, entity_id: int) -> bool: """ Retire the specified entity. @@ -1562,7 +1615,7 @@ def delete(self, entity_type, entity_id): return self._call_rpc("delete", params) - def revive(self, entity_type, entity_id): + def revive(self, entity_type: str, entity_id: int) -> bool: """ Revive an entity that has previously been deleted. @@ -1580,7 +1633,7 @@ def revive(self, entity_type, entity_id): return self._call_rpc("revive", params) - def batch(self, requests): + def batch(self, requests: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Make a batch request of several :meth:`~shotgun_api3.Shotgun.create`, :meth:`~shotgun_api3.Shotgun.update`, and :meth:`~shotgun_api3.Shotgun.delete` calls. @@ -1695,7 +1748,13 @@ def _required_keys(message, required_keys, data): records = self._call_rpc("batch", calls) return self._parse_records(records) - def work_schedule_read(self, start_date, end_date, project=None, user=None): + def work_schedule_read( + self, + start_date: str, + end_date: str, + project: Optional[dict[str, Any]] = None, + user: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: """ Return the work day rules for a given date range. @@ -1766,13 +1825,13 @@ def work_schedule_read(self, start_date, end_date, project=None, user=None): def work_schedule_update( self, - date, - working, - description=None, - project=None, - user=None, - recalculate_field=None, - ): + date: str, + working: bool, + description: Optional[str] = None, + project: Optional[dict[str, Any]] = None, + user: Optional[dict[str, Any]] = None, + recalculate_field: Optional[str] = None, + ) -> dict[str, Any]: """ Update the work schedule for a given date. @@ -1826,7 +1885,7 @@ def work_schedule_update( return self._call_rpc("work_schedule_update", params) - def follow(self, user, entity): + def follow(self, user: dict[str, Any], entity: dict[str, Any]) -> dict[str, Any]: """ Add the entity to the user's followed entities. @@ -1854,7 +1913,7 @@ def follow(self, user, entity): return self._call_rpc("follow", params) - def unfollow(self, user, entity): + def unfollow(self, user: dict[str, Any], entity: dict[str, Any]) -> dict[str, Any]: """ Remove entity from the user's followed entities. @@ -1881,7 +1940,7 @@ def unfollow(self, user, entity): return self._call_rpc("unfollow", params) - def followers(self, entity): + def followers(self, entity: dict[str, Any]) -> list[dict[str, Any]]: """ Return all followers for an entity. @@ -1909,7 +1968,12 @@ def followers(self, entity): return self._call_rpc("followers", params) - def following(self, user, project=None, entity_type=None): + def following( + self, + user: dict[str, Any], + project: Optional[dict[str, Any]] = None, + entity_type: Optional[str] = None, + ) -> list[BaseEntity]: """ Return all entity instances a user is following. @@ -1940,7 +2004,9 @@ def following(self, user, project=None, entity_type=None): return self._call_rpc("following", params) - def schema_entity_read(self, project_entity=None): + def schema_entity_read( + self, project_entity: Optional[BaseEntity] = None + ) -> dict[str, dict[str, Any]]: """ Return all active entity types, their display names, and their visibility. @@ -1984,7 +2050,9 @@ def schema_entity_read(self, project_entity=None): else: return self._call_rpc("schema_entity_read", None) - def schema_read(self, project_entity=None): + def schema_read( + self, project_entity: Optional[BaseEntity] = None + ) -> dict[str, dict[str, Any]]: """ Get the schema for all fields on all entities. @@ -2056,7 +2124,12 @@ def schema_read(self, project_entity=None): else: return self._call_rpc("schema_read", None) - def schema_field_read(self, entity_type, field_name=None, project_entity=None): + def schema_field_read( + self, + entity_type: str, + field_name: Optional[str] = None, + project_entity: Optional[BaseEntity] = None, + ) -> dict[str, dict[str, Any]]: """ Get schema for all fields on the specified entity type or just the field name specified if provided. @@ -2121,8 +2194,12 @@ def schema_field_read(self, entity_type, field_name=None, project_entity=None): return self._call_rpc("schema_field_read", params) def schema_field_create( - self, entity_type, data_type, display_name, properties=None - ): + self, + entity_type: str, + data_type: str, + display_name: str, + properties: Optional[dict[str, Any]] = None, + ) -> str: """ Create a field for the specified entity type. @@ -2160,8 +2237,12 @@ def schema_field_create( return self._call_rpc("schema_field_create", params) def schema_field_update( - self, entity_type, field_name, properties, project_entity=None - ): + self, + entity_type: str, + field_name: str, + properties: dict[str, Any], + project_entity: Optional[BaseEntity] = None, + ) -> bool: """ Update the properties for the specified field on an entity. @@ -2175,9 +2256,9 @@ def schema_field_update( >>> sg.schema_field_update("Asset", "sg_test_number", properties) True - :param entity_type: Entity type of field to update. - :param field_name: Internal Shotgun name of the field to update. - :param properties: Dictionary with key/value pairs where the key is the property to be + :param str entity_type: Entity type of field to update. + :param str field_name: Internal Shotgun name of the field to update. + :param dict properties: Dictionary with key/value pairs where the key is the property to be updated and the value is the new value. :param dict project_entity: Optional Project entity specifying which project to modify the ``visible`` property for. If ``visible`` is present in ``properties`` and @@ -2202,7 +2283,7 @@ def schema_field_update( params = self._add_project_param(params, project_entity) return self._call_rpc("schema_field_update", params) - def schema_field_delete(self, entity_type, field_name): + def schema_field_delete(self, entity_type: str, field_name: str) -> bool: """ Delete the specified field from the entity type. @@ -2219,7 +2300,7 @@ def schema_field_delete(self, entity_type, field_name): return self._call_rpc("schema_field_delete", params) - def add_user_agent(self, agent): + def add_user_agent(self, agent: str) -> None: """ Add agent to the user-agent header. @@ -2231,7 +2312,7 @@ def add_user_agent(self, agent): """ self._user_agents.append(agent) - def reset_user_agent(self): + def reset_user_agent(self) -> None: """ Reset user agent to the default value. @@ -2251,7 +2332,7 @@ def reset_user_agent(self): "ssl %s" % (self.client_caps.ssl_version), ] - def set_session_uuid(self, session_uuid): + def set_session_uuid(self, session_uuid: str) -> None: """ Set the browser session_uuid in the current Shotgun API instance. @@ -2269,12 +2350,12 @@ def set_session_uuid(self, session_uuid): def share_thumbnail( self, - entities, - thumbnail_path=None, - source_entity=None, - filmstrip_thumbnail=False, - **kwargs, - ): + entities: list[dict[str, Any]], + thumbnail_path: Optional[str] = None, + source_entity: Optional[BaseEntity] = None, + filmstrip_thumbnail: bool = False, + **kwargs: Any, + ) -> int: """ Associate a thumbnail with more than one Shotgun entity. @@ -2413,7 +2494,9 @@ def share_thumbnail( return attachment_id - def upload_thumbnail(self, entity_type, entity_id, path, **kwargs): + def upload_thumbnail( + self, entity_type: str, entity_id: int, path: str, **kwargs: Any + ) -> int: """ Upload a file from a local path and assign it as the thumbnail for the specified entity. @@ -2438,12 +2521,15 @@ def upload_thumbnail(self, entity_type, entity_id, path, **kwargs): :param int entity_id: Id of the entity to set the thumbnail for. :param str path: Full path to the thumbnail file on disk. :returns: Id of the new attachment + :rtype: int """ return self.upload( entity_type, entity_id, path, field_name="thumb_image", **kwargs ) - def upload_filmstrip_thumbnail(self, entity_type, entity_id, path, **kwargs): + def upload_filmstrip_thumbnail( + self, entity_type: str, entity_id: int, path: str, **kwargs: Any + ) -> int: """ Upload filmstrip thumbnail to specified entity. @@ -2494,13 +2580,13 @@ def upload_filmstrip_thumbnail(self, entity_type, entity_id, path, **kwargs): def upload( self, - entity_type, - entity_id, - path, - field_name=None, - display_name=None, - tag_list=None, - ): + entity_type: str, + entity_id: int, + path: str, + field_name: Optional[str] = None, + display_name: Optional[str] = None, + tag_list: Optional[str] = None, + ) -> int: """ Upload a file to the specified entity. @@ -2583,14 +2669,14 @@ def upload( def _upload_to_storage( self, - entity_type, - entity_id, - path, - field_name, - display_name, - tag_list, - is_thumbnail, - ): + entity_type: str, + entity_id: int, + path: str, + field_name: Optional[str], + display_name: Optional[str], + tag_list: Optional[str], + is_thumbnail: bool, + ) -> int: """ Internal function to upload a file to the Cloud storage and link it to the specified entity. @@ -2673,14 +2759,14 @@ def _upload_to_storage( def _upload_to_sg( self, - entity_type, - entity_id, - path, - field_name, - display_name, - tag_list, - is_thumbnail, - ): + entity_type: str, + entity_id: int, + path: str, + field_name: Optional[str], + display_name: Optional[str], + tag_list: Optional[str], + is_thumbnail: bool, + ) -> int: """ Internal function to upload a file to Shotgun and link it to the specified entity. @@ -2752,7 +2838,9 @@ def _upload_to_sg( attachment_id = int(result.split(":", 2)[1].split("\n", 1)[0]) return attachment_id - def _get_attachment_upload_info(self, is_thumbnail, filename, is_multipart_upload): + def _get_attachment_upload_info( + self, is_thumbnail: bool, filename: str, is_multipart_upload: bool + ) -> dict[str, Any]: """ Internal function to get the information needed to upload a file to Cloud storage. @@ -2799,7 +2887,12 @@ def _get_attachment_upload_info(self, is_thumbnail, filename, is_multipart_uploa "upload_info": upload_info, } - def download_attachment(self, attachment=False, file_path=None, attachment_id=None): + def download_attachment( + self, + attachment: Union[dict[str, Any], Literal[False]] = False, + file_path: Optional[str] = None, + attachment_id: Optional[int] = None, + ) -> Union[str, bytes, None]: """ Download the file associated with a Shotgun Attachment. @@ -2915,7 +3008,7 @@ def download_attachment(self, attachment=False, file_path=None, attachment_id=No else: return attachment - def get_auth_cookie_handler(self): + def get_auth_cookie_handler(self) -> urllib.request.HTTPCookieProcessor: """ Return an urllib cookie handler containing a cookie for FPTR authentication. @@ -2947,7 +3040,9 @@ def get_auth_cookie_handler(self): cj.set_cookie(c) return urllib.request.HTTPCookieProcessor(cj) - def get_attachment_download_url(self, attachment): + def get_attachment_download_url( + self, attachment: Optional[Union[int, dict[str, Any]]] + ) -> str: """ Return the URL for downloading provided Attachment. @@ -3005,7 +3100,9 @@ def get_attachment_download_url(self, attachment): ) return url - def authenticate_human_user(self, user_login, user_password, auth_token=None): + def authenticate_human_user( + self, user_login: str, user_password: str, auth_token: Optional[str] = None + ) -> dict[str, Any]: """ Authenticate Shotgun HumanUser. @@ -3064,7 +3161,9 @@ def authenticate_human_user(self, user_login, user_password, auth_token=None): self.config.auth_token = original_auth_token raise - def update_project_last_accessed(self, project, user=None): + def update_project_last_accessed( + self, project: dict[str, Any], user: Optional[dict[str, Any]] = None + ) -> None: """ Update a Project's ``last_accessed_by_current_user`` field to the current timestamp. @@ -3110,7 +3209,9 @@ def update_project_last_accessed(self, project, user=None): record = self._call_rpc("update_project_last_accessed_by_current_user", params) self._parse_records(record)[0] - def note_thread_read(self, note_id, entity_fields=None): + def note_thread_read( + self, note_id: int, entity_fields: Optional[dict[str, Any]] = None + ) -> list[dict[str, Any]]: """ Return the full conversation for a given note, including Replies and Attachments. @@ -3185,7 +3286,13 @@ def note_thread_read(self, note_id, entity_fields=None): result = self._parse_records(record) return result - def text_search(self, text, entity_types, project_ids=None, limit=None): + def text_search( + self, + text: str, + entity_types: dict[str, Any], + project_ids: Optional[list] = None, + limit: Optional[int] = None, + ) -> dict[str, Any]: """ Search across the specified entity types for the given text. @@ -3279,13 +3386,13 @@ def text_search(self, text, entity_types, project_ids=None, limit=None): def activity_stream_read( self, - entity_type, - entity_id, - entity_fields=None, - min_id=None, - max_id=None, - limit=None, - ): + entity_type: str, + entity_id: int, + entity_fields: Optional[dict[str, Any]] = None, + min_id: Optional[int] = None, + max_id: Optional[int] = None, + limit: Optional[int] = None, + ) -> dict[str, Any]: """ Retrieve activity stream data from Shotgun. @@ -3375,7 +3482,7 @@ def activity_stream_read( result = self._parse_records(record)[0] return result - def nav_expand(self, path, seed_entity_field=None, entity_fields=None): + def nav_expand(self, path: str, seed_entity_field=None, entity_fields=None): """ Expand the navigation hierarchy for the supplied path. @@ -3395,7 +3502,9 @@ def nav_expand(self, path, seed_entity_field=None, entity_fields=None): }, ) - def nav_search_string(self, root_path, search_string, seed_entity_field=None): + def nav_search_string( + self, root_path: str, search_string: str, seed_entity_field=None + ): """ Search function adapted to work with the navigation hierarchy. @@ -3414,7 +3523,12 @@ def nav_search_string(self, root_path, search_string, seed_entity_field=None): }, ) - def nav_search_entity(self, root_path, entity, seed_entity_field=None): + def nav_search_entity( + self, + root_path: str, + entity: dict[str, Any], + seed_entity_field: Optional[dict[str, Any]] = None, + ): """ Search function adapted to work with the navigation hierarchy. @@ -3434,7 +3548,7 @@ def nav_search_entity(self, root_path, entity, seed_entity_field=None): }, ) - def get_session_token(self): + def get_session_token(self) -> str: """ Get the session token associated with the current session. @@ -3458,7 +3572,7 @@ def get_session_token(self): return session_token - def preferences_read(self, prefs=None): + def preferences_read(self, prefs: Optional[list] = None) -> dict[str, Any]: """ Get a subset of the site preferences. @@ -3481,7 +3595,7 @@ def preferences_read(self, prefs=None): return self._call_rpc("preferences_read", {"prefs": prefs}) - def user_subscriptions_read(self): + def user_subscriptions_read(self) -> list: """ Get the list of user subscriptions. @@ -3493,8 +3607,9 @@ def user_subscriptions_read(self): return self._call_rpc("user_subscriptions_read", None) - def user_subscriptions_create(self, users): - # type: (list[dict[str, Union[str, list[str], None]) -> bool + def user_subscriptions_create( + self, users: list[dict[str, Union[str, list[str], None]]] + ) -> bool: """ Assign subscriptions to users. @@ -3515,7 +3630,7 @@ def user_subscriptions_create(self, users): return response.get("status") == "success" - def _build_opener(self, handler): + def _build_opener(self, handler) -> urllib.request.OpenerDirector: """ Build urllib2 opener with appropriate proxy handler. """ @@ -3616,7 +3731,13 @@ def entity_types(self): # ======================================================================== # RPC Functions - def _call_rpc(self, method, params, include_auth_params=True, first=False): + def _call_rpc( + self, + method: str, + params: Any, + include_auth_params: bool = True, + first: bool = False, + ) -> Any: """ Call the specified method on the Shotgun Server sending the supplied payload. """ @@ -3680,7 +3801,7 @@ def _call_rpc(self, method, params, include_auth_params=True, first=False): return results[0] return results - def _auth_params(self): + def _auth_params(self) -> dict[str, Any]: """ Return a dictionary of the authentication parameters being used. """ @@ -3735,7 +3856,7 @@ def _auth_params(self): return auth_params - def _sanitize_auth_params(self, params): + def _sanitize_auth_params(self, params: dict[str, Any]) -> dict[str, Any]: """ Given an authentication parameter dictionary, sanitize any sensitive information and return the sanitized dict copy. @@ -3746,7 +3867,9 @@ def _sanitize_auth_params(self, params): sanitized_params[k] = "********" return sanitized_params - def _build_payload(self, method, params, include_auth_params=True): + def _build_payload( + self, method: str, params, include_auth_params: bool = True + ) -> dict[str, Any]: """ Build the payload to be send to the rpc endpoint. """ @@ -3764,7 +3887,7 @@ def _build_payload(self, method, params, include_auth_params=True): return {"method_name": method, "params": call_params} - def _encode_payload(self, payload): + def _encode_payload(self, payload) -> bytes: """ Encode the payload to a string to be passed to the rpc endpoint. @@ -3775,7 +3898,9 @@ def _encode_payload(self, payload): return json.dumps(payload, ensure_ascii=False).encode("utf-8") - def _make_call(self, verb, path, body, headers): + def _make_call( + self, verb: str, path: str, body, headers: Optional[dict[str, Any]] + ) -> tuple[tuple[int, str], dict[str, Any], str]: """ Make an HTTP call to the server. @@ -3825,7 +3950,9 @@ def _make_call(self, verb, path, body, headers): ) time.sleep(rpc_attempt_interval) - def _http_request(self, verb, path, body, headers): + def _http_request( + self, verb: str, path: str, body, headers: dict[str, Any] + ) -> tuple[tuple[int, str], dict[str, Any], str]: """ Make the actual HTTP request. """ @@ -3849,7 +3976,9 @@ def _http_request(self, verb, path, body, headers): return (http_status, resp_headers, resp_body) - def _make_upload_request(self, request, opener): + def _make_upload_request( + self, request, opener: "urllib.request.OpenerDirector" + ) -> "urllib.request._UrlopenRet": """ Open the given request object, return the response, raises URLError on protocol errors. @@ -3861,7 +3990,7 @@ def _make_upload_request(self, request, opener): raise return result - def _parse_http_status(self, status): + def _parse_http_status(self, status: tuple) -> None: """ Parse the status returned from the http request. @@ -3879,7 +4008,9 @@ def _parse_http_status(self, status): return - def _decode_response(self, headers, body): + def _decode_response( + self, headers: dict[str, Any], body: str + ) -> Union[str, dict[str, Any]]: """ Decode the response from the server from the wire format to a python data structure. @@ -3900,7 +4031,7 @@ def _decode_response(self, headers, body): return self._json_loads(body) return body - def _json_loads(self, body): + def _json_loads(self, body: str) -> Any: return json.loads(body) def _response_errors(self, sg_response): @@ -3949,7 +4080,7 @@ def _response_errors(self, sg_response): raise Fault(sg_response.get("message", "Unknown Error")) return - def _visit_data(self, data, visitor): + def _visit_data(self, data: T, visitor) -> T: """ Walk the data (simple python types) and call the visitor. """ @@ -3959,17 +4090,17 @@ def _visit_data(self, data, visitor): recursive = self._visit_data if isinstance(data, list): - return [recursive(i, visitor) for i in data] + return [recursive(i, visitor) for i in data] # type: ignore[return-value] if isinstance(data, tuple): - return tuple(recursive(i, visitor) for i in data) + return tuple(recursive(i, visitor) for i in data) # type: ignore[return-value] if isinstance(data, dict): - return dict((k, recursive(v, visitor)) for k, v in data.items()) + return dict((k, recursive(v, visitor)) for k, v in data.items()) # type: ignore[return-value] return visitor(data) - def _transform_outbound(self, data): + def _transform_outbound(self, data: T) -> T: """ Transform data types or values before they are sent by the client. @@ -4016,7 +4147,7 @@ def _outbound_visitor(value): return self._visit_data(data, _outbound_visitor) - def _transform_inbound(self, data): + def _transform_inbound(self, data: T) -> T: """ Transforms data types or values after they are received from the server. """ @@ -4052,7 +4183,7 @@ def _inbound_visitor(value): # ======================================================================== # Connection Functions - def _get_connection(self): + def _get_connection(self) -> Http: """ Return the current connection or creates a new connection to the current server. """ @@ -4081,7 +4212,7 @@ def _get_connection(self): return self._connection - def _close_connection(self): + def _close_connection(self) -> None: """ Close the current connection. """ @@ -4100,7 +4231,7 @@ def _close_connection(self): # ======================================================================== # Utility - def _parse_records(self, records): + def _parse_records(self, records: list) -> list: """ Parse 'records' returned from the api to do local modifications: @@ -4156,7 +4287,7 @@ def _parse_records(self, records): return records - def _build_thumb_url(self, entity_type, entity_id): + def _build_thumb_url(self, entity_type: str, entity_id: int) -> str: """ Return the URL for the thumbnail of an entity given the entity type and the entity id. @@ -4204,8 +4335,12 @@ def _build_thumb_url(self, entity_type, entity_id): raise RuntimeError("Unknown code %s %s" % (code, thumb_url)) def _dict_to_list( - self, d, key_name="field_name", value_name="value", extra_data=None - ): + self, + d: Optional[dict[str, Any]], + key_name: str = "field_name", + value_name: str = "value", + extra_data=None, + ) -> list[dict[str, Any]]: """ Utility function to convert a dict into a list dicts using the key_name and value_name keys. @@ -4222,7 +4357,7 @@ def _dict_to_list( ret.append(d) return ret - def _dict_to_extra_data(self, d, key_name="value"): + def _dict_to_extra_data(self, d: Optional[dict], key_name="value") -> dict: """ Utility function to convert a dict into a dict compatible with the extra_data arg of _dict_to_list. @@ -4231,7 +4366,7 @@ def _dict_to_extra_data(self, d, key_name="value"): """ return dict([(k, {key_name: v}) for (k, v) in (d or {}).items()]) - def _upload_file_to_storage(self, path, storage_url): + def _upload_file_to_storage(self, path: str, storage_url: str) -> None: """ Internal function to upload an entire file to the Cloud storage. @@ -4251,7 +4386,9 @@ def _upload_file_to_storage(self, path, storage_url): LOG.debug("File uploaded to Cloud storage: %s", filename) - def _multipart_upload_file_to_storage(self, path, upload_info): + def _multipart_upload_file_to_storage( + self, path: str, upload_info: dict[str, Any] + ) -> None: """ Internal function to upload a file to the Cloud storage in multiple parts. @@ -4293,7 +4430,9 @@ def _multipart_upload_file_to_storage(self, path, upload_info): LOG.debug("File uploaded in multiple parts to Cloud storage: %s", path) - def _get_upload_part_link(self, upload_info, filename, part_number): + def _get_upload_part_link( + self, upload_info: dict[str, Any], filename: str, part_number: int + ) -> str: """ Internal function to get the url to upload the next part of a file to the Cloud storage, in a multi-part upload process. @@ -4333,7 +4472,9 @@ def _get_upload_part_link(self, upload_info, filename, part_number): LOG.debug("Got next upload link from server for multipart upload.") return result.split("\n", 2)[1] - def _upload_data_to_storage(self, data, content_type, size, storage_url): + def _upload_data_to_storage( + self, data: BinaryIO, content_type: str, size: int, storage_url: str + ) -> str: """ Internal function to upload data to Cloud storage. @@ -4388,13 +4529,15 @@ def _upload_data_to_storage(self, data, content_type, size, storage_url): LOG.debug("Part upload completed successfully.") return etag - def _complete_multipart_upload(self, upload_info, filename, etags): + def _complete_multipart_upload( + self, upload_info: dict[str, Any], filename: str, etags: Iterable[str] + ) -> None: """ Internal function to complete a multi-part upload to the Cloud storage. :param dict upload_info: Contains details received from the server, about the upload. :param str filename: Name of the file for which we want to complete the upload. - :param tupple etags: Contains the etag of each uploaded file part. + :param tuple etags: Contains the etag of each uploaded file part. """ params = { @@ -4421,7 +4564,9 @@ def _complete_multipart_upload(self, upload_info, filename, etags): if not result.startswith("1"): raise ShotgunError("Unable get upload part link: %s" % result) - def _requires_direct_s3_upload(self, entity_type, field_name): + def _requires_direct_s3_upload( + self, entity_type: str, field_name: Optional[str] + ) -> bool: """ Internal function that determines if an entity_type + field_name combination should be uploaded to cloud storage. @@ -4462,7 +4607,7 @@ def _requires_direct_s3_upload(self, entity_type, field_name): else: return False - def _send_form(self, url, params): + def _send_form(self, url: str, params: dict[str, Any]) -> str: """ Utility function to send a Form to Shotgun and process any HTTP errors that could occur. @@ -4594,7 +4739,7 @@ def https_request(self, request): return self.http_request(request) -def _translate_filters(filters, filter_operator): +def _translate_filters(filters: Union[list, tuple], filter_operator) -> dict[str, Any]: """ Translate filters params into data structure expected by rpc call. """ @@ -4603,7 +4748,7 @@ def _translate_filters(filters, filter_operator): return _translate_filters_dict(wrapped_filters) -def _translate_filters_dict(sg_filter): +def _translate_filters_dict(sg_filter: dict[str, Any]) -> dict[str, Any]: new_filters = {} filter_operator = sg_filter.get("filter_operator") @@ -4663,14 +4808,14 @@ def _translate_filters_simple(sg_filter): return condition -def _version_str(version): +def _version_str(version: tuple[int]) -> str: """ Convert a tuple of int's to a '.' separated str. """ return ".".join(map(str, version)) -def _get_type_and_id_from_value(value): +def _get_type_and_id_from_value(value: T) -> T: """ For an entity dictionary, returns a new dictionary with only the type and id keys. If any of these keys are not present, the original dictionary is returned. From dfee11df09cbe3accb865862bc6108565da142bc Mon Sep 17 00:00:00 2001 From: Carlos Villavicencio Date: Tue, 28 Oct 2025 16:32:22 -0500 Subject: [PATCH 2/3] Add more fixes --- shotgun_api3/shotgun.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index 5f98e781b..eb7f2e7bf 100644 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -740,7 +740,7 @@ def __init__( self.config.user_password = None self.config.auth_token = None - def _split_url(self, base_url: str) -> tuple[str, str]: + def _split_url(self, base_url: str) -> tuple[Optional[str], Optional[str]]: """ Extract the hostname:port and username/password/token from base_url sent when connect to the API. @@ -2041,7 +2041,7 @@ def schema_entity_read( The returned display names for this method will be localized when the ``localize`` Shotgun config property is set to ``True``. See :ref:`localization` for more information. """ - params = {} + params: dict[str, Any] = {} params = self._add_project_param(params, project_entity) @@ -2115,7 +2115,7 @@ def schema_read( The returned display names for this method will be localized when the ``localize`` Shotgun config property is set to ``True``. See :ref:`localization` for more information. """ - params = {} + params: dict[str, Any] = {} params = self._add_project_param(params, project_entity) @@ -3102,7 +3102,7 @@ def get_attachment_download_url( def authenticate_human_user( self, user_login: str, user_password: str, auth_token: Optional[str] = None - ) -> dict[str, Any]: + ) -> Union[dict[str, Any], None]: """ Authenticate Shotgun HumanUser. @@ -4808,14 +4808,14 @@ def _translate_filters_simple(sg_filter): return condition -def _version_str(version: tuple[int]) -> str: +def _version_str(version) -> str: """ Convert a tuple of int's to a '.' separated str. """ return ".".join(map(str, version)) -def _get_type_and_id_from_value(value: T) -> T: +def _get_type_and_id_from_value(value): """ For an entity dictionary, returns a new dictionary with only the type and id keys. If any of these keys are not present, the original dictionary is returned. From 9fee1adcb55f4c32a647677558cd997f036a905c Mon Sep 17 00:00:00 2001 From: Carlos Villavicencio Date: Wed, 5 Nov 2025 11:59:38 -0500 Subject: [PATCH 3/3] Remove unused types --- shotgun_api3/shotgun.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index eb7f2e7bf..0c0c9cd5c 100644 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -55,12 +55,10 @@ BinaryIO, Iterable, Literal, - NoReturn, Optional, TypedDict, TypeVar, Union, - TYPE_CHECKING, ) # Import Error and ResponseError (even though they're unused in this file) since they need