Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions predicthq/endpoints/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from predicthq.exceptions import ValidationError

from predicthq.endpoints.schemas import ArgKwargResultSet


def _kwargs_to_key_list_mapping(kwargs, separator="__"):
"""
Expand Down Expand Up @@ -41,7 +43,7 @@ def _to_url_params(key_list_mapping, glue=".", separator=",", parent_key=""):
return params


def _to_json(key_list_mapping, json = None):
def _to_json(key_list_mapping, json=None):
"""
Converts key_list_mapping to json
"""
Expand Down Expand Up @@ -81,16 +83,15 @@ def returns(model_class):
def decorator(f):
@functools.wraps(f)
def wrapper(endpoint, *args, **kwargs):

model = getattr(endpoint.Meta, f.__name__, {}).get("returns", model_class)

data = f(endpoint, *args, **kwargs)
try:
loaded_model = model(**data)
loaded_model._more = functools.partial(wrapper, endpoint)
loaded_model._endpoint = endpoint
# This is a temporary solution to get the next page for Features API
if hasattr(loaded_model, "_kwargs"):
if isinstance(loaded_model, ArgKwargResultSet):
loaded_model._args = args
loaded_model._kwargs = kwargs
return loaded_model
except PydanticValidationError as e:
Expand Down
9 changes: 5 additions & 4 deletions predicthq/endpoints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ def get_next(self):
if not self.has_next() or not hasattr(self, "_more"):
return
params = self._parse_params(self.next)
# This is a temporary solution to get the next page for Features API
# where the post request requires a json body as well as query params
if kwargs := getattr(self, "_kwargs", {}):
return self._more(_params=params, _json=kwargs.get("_json", {}) or kwargs)
return self._more(**params)

def get_previous(self):
Expand All @@ -62,3 +58,8 @@ def iter_all(self):

def __iter__(self):
return self.iter_items()


class ArgKwargResultSet(ResultSet):
_args: Optional[dict] = None
_kwargs: Optional[dict] = None
20 changes: 16 additions & 4 deletions predicthq/endpoints/v1/beam/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime
from predicthq.endpoints.schemas import ResultSet
from predicthq.endpoints.schemas import ArgKwargResultSet
from typing import Optional, List


class BeamPaginationResultSet(ArgKwargResultSet):
def has_next(self):
return self._kwargs.get("offset", 0) + len(self.results) < self.count

def get_next(self):
if "offset" in self._kwargs:
self._kwargs["offset"] = self._kwargs.get("offset") + len(self.results)
else:
self._kwargs["offset"] = len(self.results)
return self._more(**self._kwargs)


class AllowExtra(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow")

Expand Down Expand Up @@ -103,7 +115,7 @@ class Analysis(AllowExtra):
label: Optional[List[str]] = None


class AnalysisResultSet(ResultSet):
class AnalysisResultSet(BeamPaginationResultSet):
results: List[Analysis] = Field(alias="analyses")


Expand All @@ -118,7 +130,7 @@ class FeatureImportance(AllowExtra):
feature_importance: List[FeatureGroup]


class CorrelationResultSet(ResultSet):
class CorrelationResultSet(BeamPaginationResultSet):
model_version: str
version: int
results: List[dict] = Field(alias="dates")
Expand Down Expand Up @@ -154,5 +166,5 @@ class AnalysisGroup(AllowExtra):
processed_dt: Optional[datetime] = None


class AnalysisGroupResultSet(ResultSet):
class AnalysisGroupResultSet(BeamPaginationResultSet):
results: List[AnalysisGroup] = Field(alias="groups")
11 changes: 8 additions & 3 deletions predicthq/endpoints/v1/features/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, RootModel

from predicthq.endpoints.schemas import ResultSet
from predicthq.endpoints.schemas import ArgKwargResultSet


class CsvMixin:
Expand Down Expand Up @@ -58,6 +58,11 @@ def __getattr__(self, name: str) -> Union[date, FeatureStat, FeatureRankLevel]:
return self.root[name]


class FeatureResultSet(ResultSet, CsvMixin):
_kwargs: Optional[Dict] = None # temporary solution to get the next page
class FeatureResultSet(ArgKwargResultSet, CsvMixin):
results: List[Optional[Feature]]

def get_next(self):
if not self.has_next() or not hasattr(self, "_more"):
return
params = self._parse_params(self.next)
return self._more(_params=params, _json=self._kwargs.get("_json", {}) or self._kwargs)
Loading