diff --git a/sdk/python/feast/infra/registry/proto_registry_utils.py b/sdk/python/feast/infra/registry/proto_registry_utils.py index c7eeea0f82..e93f513b69 100644 --- a/sdk/python/feast/infra/registry/proto_registry_utils.py +++ b/sdk/python/feast/infra/registry/proto_registry_utils.py @@ -1,4 +1,5 @@ import uuid +from functools import wraps from typing import List, Optional from feast import usage @@ -23,6 +24,26 @@ from feast.stream_feature_view import StreamFeatureView +def registry_proto_cache(func): + cache_key = None + cache_value = None + + @wraps(func) + def wrapper(registry_proto: RegistryProto, project: str): + nonlocal cache_key, cache_value + + key = tuple([id(registry_proto), registry_proto.version_id, project]) + + if key == cache_key: + return cache_value + else: + cache_value = func(registry_proto, project) + cache_key = key + return cache_value + + return wrapper + + def init_project_metadata(cached_registry_proto: RegistryProto, project: str): new_project_uuid = f"{uuid.uuid4()}" usage.set_current_project_uuid(new_project_uuid) @@ -137,8 +158,9 @@ def get_validation_reference( raise ValidationReferenceNotFound(name, project=project) +@registry_proto_cache def list_feature_services( - registry_proto: RegistryProto, project: str, allow_cache: bool = False + registry_proto: RegistryProto, project: str ) -> List[FeatureService]: feature_services = [] for feature_service_proto in registry_proto.feature_services: @@ -147,6 +169,7 @@ def list_feature_services( return feature_services +@registry_proto_cache def list_feature_views( registry_proto: RegistryProto, project: str ) -> List[FeatureView]: @@ -157,6 +180,7 @@ def list_feature_views( return feature_views +@registry_proto_cache def list_request_feature_views( registry_proto: RegistryProto, project: str ) -> List[RequestFeatureView]: @@ -169,6 +193,7 @@ def list_request_feature_views( return feature_views +@registry_proto_cache def list_stream_feature_views( registry_proto: RegistryProto, project: str ) -> List[StreamFeatureView]: @@ -181,6 +206,7 @@ def list_stream_feature_views( return stream_feature_views +@registry_proto_cache def list_on_demand_feature_views( registry_proto: RegistryProto, project: str ) -> List[OnDemandFeatureView]: @@ -193,6 +219,7 @@ def list_on_demand_feature_views( return on_demand_feature_views +@registry_proto_cache def list_entities(registry_proto: RegistryProto, project: str) -> List[Entity]: entities = [] for entity_proto in registry_proto.entities: @@ -201,6 +228,7 @@ def list_entities(registry_proto: RegistryProto, project: str) -> List[Entity]: return entities +@registry_proto_cache def list_data_sources(registry_proto: RegistryProto, project: str) -> List[DataSource]: data_sources = [] for data_source_proto in registry_proto.data_sources: @@ -209,6 +237,7 @@ def list_data_sources(registry_proto: RegistryProto, project: str) -> List[DataS return data_sources +@registry_proto_cache def list_saved_datasets( registry_proto: RegistryProto, project: str ) -> List[SavedDataset]: @@ -219,6 +248,7 @@ def list_saved_datasets( return saved_datasets +@registry_proto_cache def list_validation_references( registry_proto: RegistryProto, project: str ) -> List[ValidationReference]: @@ -231,6 +261,7 @@ def list_validation_references( return validation_references +@registry_proto_cache def list_project_metadata( registry_proto: RegistryProto, project: str ) -> List[ProjectMetadata]: