From e730f2c9f21b9800ee75eaab7559d29be8747181 Mon Sep 17 00:00:00 2001 From: Keegan Cordeiro Date: Fri, 17 Jan 2025 15:06:53 +1300 Subject: [PATCH 1/2] add ArgKwargResult set for custom pagination --- predicthq/endpoints/decorators.py | 9 +++++---- predicthq/endpoints/schemas.py | 9 +++++---- predicthq/endpoints/v1/beam/schemas.py | 14 ++++++++++++-- predicthq/endpoints/v1/features/schemas.py | 11 ++++++++--- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/predicthq/endpoints/decorators.py b/predicthq/endpoints/decorators.py index 5ff425b..0500eab 100644 --- a/predicthq/endpoints/decorators.py +++ b/predicthq/endpoints/decorators.py @@ -4,6 +4,8 @@ from predicthq.exceptions import ValidationError +from predicthq.endpoints.schemas import ArgKwargResultSet + def _kwargs_to_key_list_mapping(kwargs, separator="__"): """ @@ -41,7 +43,7 @@ def _to_url_params(key_list_mapping, glue=".", separator=",", parent_key=""): return params -def _to_json(key_list_mapping, json = None): +def _to_json(key_list_mapping, json=None): """ Converts key_list_mapping to json """ @@ -81,7 +83,6 @@ def returns(model_class): def decorator(f): @functools.wraps(f) def wrapper(endpoint, *args, **kwargs): - model = getattr(endpoint.Meta, f.__name__, {}).get("returns", model_class) data = f(endpoint, *args, **kwargs) @@ -89,8 +90,8 @@ def wrapper(endpoint, *args, **kwargs): loaded_model = model(**data) loaded_model._more = functools.partial(wrapper, endpoint) loaded_model._endpoint = endpoint - # This is a temporary solution to get the next page for Features API - if hasattr(loaded_model, "_kwargs"): + if isinstance(loaded_model, ArgKwargResultSet): + loaded_model._args = args loaded_model._kwargs = kwargs return loaded_model except PydanticValidationError as e: diff --git a/predicthq/endpoints/schemas.py b/predicthq/endpoints/schemas.py index b66124b..a3066eb 100644 --- a/predicthq/endpoints/schemas.py +++ b/predicthq/endpoints/schemas.py @@ -32,10 +32,6 @@ def get_next(self): if not self.has_next() or not hasattr(self, "_more"): return params = self._parse_params(self.next) - # This is a temporary solution to get the next page for Features API - # where the post request requires a json body as well as query params - if kwargs := getattr(self, "_kwargs", {}): - return self._more(_params=params, _json=kwargs.get("_json", {}) or kwargs) return self._more(**params) def get_previous(self): @@ -62,3 +58,8 @@ def iter_all(self): def __iter__(self): return self.iter_items() + + +class ArgKwargResultSet(ResultSet): + _args: Optional[dict] = None + _kwargs: Optional[dict] = None diff --git a/predicthq/endpoints/v1/beam/schemas.py b/predicthq/endpoints/v1/beam/schemas.py index e1b65a9..b0d8b90 100644 --- a/predicthq/endpoints/v1/beam/schemas.py +++ b/predicthq/endpoints/v1/beam/schemas.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field, ConfigDict from datetime import datetime -from predicthq.endpoints.schemas import ResultSet +from predicthq.endpoints.schemas import ArgKwargResultSet, ResultSet from typing import Optional, List @@ -103,9 +103,19 @@ class Analysis(AllowExtra): label: Optional[List[str]] = None -class AnalysisResultSet(ResultSet): +class AnalysisResultSet(ArgKwargResultSet): results: List[Analysis] = Field(alias="analyses") + def has_next(self): + return self._kwargs.get("offset", 0) + len(self.results) < self.count + + def get_next(self): + if "offset" in self._kwargs: + self._kwargs["offset"] = self._kwargs.get("offset") + len(self.results) + else: + self._kwargs["offset"] = len(self.results) + return self._more(**self._kwargs) + class FeatureGroup(AllowExtra): feature_group: str diff --git a/predicthq/endpoints/v1/features/schemas.py b/predicthq/endpoints/v1/features/schemas.py index cdbb32a..2ea0da0 100644 --- a/predicthq/endpoints/v1/features/schemas.py +++ b/predicthq/endpoints/v1/features/schemas.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, RootModel -from predicthq.endpoints.schemas import ResultSet +from predicthq.endpoints.schemas import ArgKwargResultSet class CsvMixin: @@ -58,6 +58,11 @@ def __getattr__(self, name: str) -> Union[date, FeatureStat, FeatureRankLevel]: return self.root[name] -class FeatureResultSet(ResultSet, CsvMixin): - _kwargs: Optional[Dict] = None # temporary solution to get the next page +class FeatureResultSet(ArgKwargResultSet, CsvMixin): results: List[Optional[Feature]] + + def get_next(self): + if not self.has_next() or not hasattr(self, "_more"): + return + params = self._parse_params(self.next) + return self._more(_params=params, _json=self._kwargs.get("_json", {}) or self._kwargs) From 25bd6261d732fec8deaadf78a84f376ba1a24cde Mon Sep 17 00:00:00 2001 From: Keegan Cordeiro Date: Fri, 17 Jan 2025 15:25:19 +1300 Subject: [PATCH 2/2] use custom pagination on all supported beam endpoints --- predicthq/endpoints/v1/beam/schemas.py | 30 ++++++++++++++------------ 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/predicthq/endpoints/v1/beam/schemas.py b/predicthq/endpoints/v1/beam/schemas.py index b0d8b90..12d056d 100644 --- a/predicthq/endpoints/v1/beam/schemas.py +++ b/predicthq/endpoints/v1/beam/schemas.py @@ -1,9 +1,21 @@ from pydantic import BaseModel, Field, ConfigDict from datetime import datetime -from predicthq.endpoints.schemas import ArgKwargResultSet, ResultSet +from predicthq.endpoints.schemas import ArgKwargResultSet from typing import Optional, List +class BeamPaginationResultSet(ArgKwargResultSet): + def has_next(self): + return self._kwargs.get("offset", 0) + len(self.results) < self.count + + def get_next(self): + if "offset" in self._kwargs: + self._kwargs["offset"] = self._kwargs.get("offset") + len(self.results) + else: + self._kwargs["offset"] = len(self.results) + return self._more(**self._kwargs) + + class AllowExtra(BaseModel): model_config: ConfigDict = ConfigDict(extra="allow") @@ -103,19 +115,9 @@ class Analysis(AllowExtra): label: Optional[List[str]] = None -class AnalysisResultSet(ArgKwargResultSet): +class AnalysisResultSet(BeamPaginationResultSet): results: List[Analysis] = Field(alias="analyses") - def has_next(self): - return self._kwargs.get("offset", 0) + len(self.results) < self.count - - def get_next(self): - if "offset" in self._kwargs: - self._kwargs["offset"] = self._kwargs.get("offset") + len(self.results) - else: - self._kwargs["offset"] = len(self.results) - return self._more(**self._kwargs) - class FeatureGroup(AllowExtra): feature_group: str @@ -128,7 +130,7 @@ class FeatureImportance(AllowExtra): feature_importance: List[FeatureGroup] -class CorrelationResultSet(ResultSet): +class CorrelationResultSet(BeamPaginationResultSet): model_version: str version: int results: List[dict] = Field(alias="dates") @@ -164,5 +166,5 @@ class AnalysisGroup(AllowExtra): processed_dt: Optional[datetime] = None -class AnalysisGroupResultSet(ResultSet): +class AnalysisGroupResultSet(BeamPaginationResultSet): results: List[AnalysisGroup] = Field(alias="groups")