From a124470fb883c932dd5dba7942ba6c3641015c87 Mon Sep 17 00:00:00 2001 From: Vaghinak Basentsyan Date: Mon, 14 Feb 2022 09:40:37 +0400 Subject: [PATCH] Changed init flow --- src/superannotate/__init__.py | 3 +- src/superannotate/lib/__init__.py | 20 ++- .../lib/app/interface/base_interface.py | 8 +- .../lib/app/interface/sdk_interface.py | 7 +- .../lib/infrastructure/controller.py | 126 ++++++------------ .../test_annotations_pre_processing.py | 2 +- tests/unit/test_controller_init.py | 34 +++-- 7 files changed, 91 insertions(+), 109 deletions(-) diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index c8ce59082..42057385a 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -4,8 +4,8 @@ import requests import superannotate.lib.core as constances +from lib import get_default_controller from packaging.version import parse -from superannotate.lib import get_default_controller from superannotate.lib.app.analytics.class_analytics import class_distribution from superannotate.lib.app.exceptions import AppException from superannotate.lib.app.input_converters.conversion import convert_json_version @@ -112,6 +112,7 @@ controller = get_default_controller() + __all__ = [ "__version__", "controller", diff --git a/src/superannotate/lib/__init__.py b/src/superannotate/lib/__init__.py index 89f5863a8..c4db8ec42 100644 --- a/src/superannotate/lib/__init__.py +++ b/src/superannotate/lib/__init__.py @@ -5,12 +5,22 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -controller = None +DEFAULT_CONTROLLER = None -def get_default_controller(): + +def get_default_controller(raise_exception=False): from lib.infrastructure.controller import Controller + try: + global DEFAULT_CONTROLLER + if not DEFAULT_CONTROLLER: + DEFAULT_CONTROLLER = Controller() + return DEFAULT_CONTROLLER + except Exception: + if raise_exception: + raise + - global controller - controller = Controller() - return controller +def set_default_controller(controller_obj): + # global DEFAULT_CONTROLLER + DEFAULT_CONTROLLER = controller_obj diff --git a/src/superannotate/lib/app/interface/base_interface.py b/src/superannotate/lib/app/interface/base_interface.py index 2a0381430..ba9fc8e2b 100644 --- a/src/superannotate/lib/app/interface/base_interface.py +++ b/src/superannotate/lib/app/interface/base_interface.py @@ -1,16 +1,14 @@ -from lib.infrastructure.controller import Controller +from lib import get_default_controller from lib.infrastructure.repositories import ConfigRepository class BaseInterfaceFacade: def __init__(self): self._config_path = None + self._controller = get_default_controller() @property def controller(self): if not ConfigRepository().get_one("token"): raise Exception("Config does not exists!") - controller = Controller() - if self._config_path: - controller.init(config_path=self._config_path) - return controller + return self._controller diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index 02079eb8b..32e5ff1ce 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -11,8 +11,8 @@ from typing import Union import boto3 +import lib import lib.core as constances -from lib import controller from lib.app.annotation_helpers import add_annotation_bbox_to_json from lib.app.annotation_helpers import add_annotation_comment_to_json from lib.app.annotation_helpers import add_annotation_point_to_json @@ -42,6 +42,7 @@ from lib.core.types import ClassesJson from lib.core.types import MLModel from lib.core.types import Project +from lib.infrastructure.controller import Controller from pydantic import conlist from pydantic import parse_obj_as from pydantic import StrictBool @@ -50,6 +51,7 @@ logger = get_default_logger() +controller = lib.DEFAULT_CONTROLLER @validate_arguments @@ -65,8 +67,7 @@ def init(path_to_config_json: Optional[str] = None, token: str = None): :param token: Team token :type token: str """ - global controller - controller.init(config_path=path_to_config_json, token=token) + lib.DEFAULT_CONTROLLER = Controller(config_path=path_to_config_json, token=token) @validate_arguments diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index 63f35ef25..c28430cd0 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -1,12 +1,12 @@ import copy import io import os +from abc import ABCMeta from os.path import expanduser from pathlib import Path from typing import Iterable from typing import List from typing import Optional -from typing import Tuple from typing import Union import lib.core as constances @@ -37,21 +37,7 @@ from superannotate_schemas.validators import AnnotationValidators -class SingleInstanceMetaClass(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in SingleInstanceMetaClass._instances: - SingleInstanceMetaClass._instances[cls] = super().__call__(*args, **kwargs) - return SingleInstanceMetaClass._instances[cls] - - def get_instance(cls): - if cls._instances: - return cls._instances[cls] - return cls() - - -class BaseController(metaclass=SingleInstanceMetaClass): +class BaseController(metaclass=ABCMeta): def __init__(self, config_path: str, token: str = None): self._team_data = None self._token = None @@ -69,25 +55,26 @@ def __init__(self, config_path: str, token: str = None): self._user_id = None self._team_name = None self._reporter = None - self._ssl_verify = not os.environ.get("SA_TESTING", False) - self._config_path = expanduser(config_path) if config_path else constances.CONFIG_FILE_LOCATION + self._ssl_verify = not (os.environ.get("SA_TESTING", "False").lower() == "false") self._backend_url = os.environ.get("SA_URL", constances.BACKEND_URL) - if not token and not config_path: + + if token: + self._token = self._validate_token(os.environ.get("SA_TOKEN")) + elif config_path: + config_path = expanduser(config_path) + self.retrieve_configs(Path(config_path), raise_exception=True) + else: env_token = os.environ.get("SA_TOKEN") if env_token: self._token = self._validate_token(os.environ.get("SA_TOKEN")) - if token: - self._token = self._validate_token(token) - - if not self._token: - self._token, self._backend_url, self._ssl_verify = self.retrieve_configs( - Path(self._config_path), raise_exception=False - ) + else: + config_path = expanduser(constances.CONFIG_FILE_LOCATION) + self.retrieve_configs(Path(config_path), raise_exception=False) + self.initialize_backend_client() - def retrieve_configs( - self, path: Path, raise_exception=True - ) -> Tuple[Optional[str], Optional[str], Optional[str]]: + def retrieve_configs(self, path: Path, raise_exception=True): + token, backend_url, ssl_verify = None, None, None if not path.is_file() or not os.access(path, os.R_OK): if raise_exception: raise AppException( @@ -97,7 +84,7 @@ def retrieve_configs( ) try: config_repo = ConfigRepository(str(path)) - return ( + token, backend_url, ssl_verify = ( self._validate_token(config_repo.get_one("token").value), config_repo.get_one("main_endpoint").value, config_repo.get_one("ssl_verify").value, @@ -107,54 +94,36 @@ def retrieve_configs( raise AppException( f"Incorrect config file: token is not present in the config file {path}" ) - return None, None, None + self._token = token + self._backend_url = backend_url or self._backend_url + self._ssl_verify = ssl_verify or self._ssl_verify - def init( - self, token: str = None, backend_url: str = None, config_path: str = None, - ): - if backend_url: - self._backend_url = backend_url - if token: - if self._validate_token(token): - self._token = token - return self - else: - raise AppException("Invalid token.") - if not config_path: - raise AppException( - " Please provide correct config file location to sa.init()." - ) - self._config_path = config_path - self._token, self._backend_url, ssl_verify = self.retrieve_configs( - Path(config_path), raise_exception=True + @staticmethod + def _validate_token(token: str): + try: + int(token.split("=")[-1]) + except ValueError: + raise AppException("Invalid token.") + return token + + def initialize_backend_client(self): + if not self._token: + raise AppException("Team token not provided") + self._backend_client = SuperannotateBackendService( + api_url=self._backend_url, + auth_token=self._token, + logger=self._logger, + verify_ssl=self._ssl_verify, ) - self._ssl_verify = ssl_verify - return self + self._backend_client.get_session.cache_clear() + return self._backend_client @property def backend_client(self): if not self._backend_client: - if not self._token: - raise AppException("Team token not provided") - self._backend_client = SuperannotateBackendService( - api_url=self._backend_url, - auth_token=self._token, - logger=self._logger, - verify_ssl=self._ssl_verify, - ) - self._backend_client._api_url = self._backend_url - self._backend_client._auth_token = self._token - self._backend_client.get_session.cache_clear() + self.initialize_backend_client() return self._backend_client - @staticmethod - def is_valid_token(token: str): - return int(token.split("=")[-1]) - - @property - def config_path(self): - return self._config_path - @property def user_id(self): if not self._user_id: @@ -167,18 +136,11 @@ def team_name(self): _, self._team_name = self.get_team() return self._team_name - @staticmethod - def _validate_token(token: str): - try: - int(token.split("=")[-1]) - except ValueError: - raise AppException("Invalid token.") - return token - def set_token(self, token: str, backend_url: str = constances.BACKEND_URL): self._token = self._validate_token(token) - self._backend_url = backend_url - self._backend_client = self.backend_client + if backend_url: + self._backend_url = backend_url + self.initialize_backend_client() @property def projects(self): @@ -262,8 +224,8 @@ def annotation_validators(self) -> AnnotationValidators: class Controller(BaseController): - def __init__(self, config_path: str = None): - super().__init__(config_path) + def __init__(self, config_path: str = None, token: str = None): + super().__init__(config_path, token) self._team = None def _get_project(self, name: str): diff --git a/tests/integration/annotations/test_annotations_pre_processing.py b/tests/integration/annotations/test_annotations_pre_processing.py index cc1328b89..45d485384 100644 --- a/tests/integration/annotations/test_annotations_pre_processing.py +++ b/tests/integration/annotations/test_annotations_pre_processing.py @@ -45,7 +45,7 @@ def test_annotation_last_action_and_creation_type(self, reporter): self.assertEqual(instance["creationType"], CreationTypeEnum.PRE_ANNOTATION.value) self.assertEqual( type(annotation["metadata"]["lastAction"]["email"]), - type(sa.controller.team_data.data.creator_id) + type(sa.get_default_controller().team_data.data.creator_id) ) self.assertEqual( type(annotation["metadata"]["lastAction"]["timestamp"]), diff --git a/tests/unit/test_controller_init.py b/tests/unit/test_controller_init.py index 3c065011a..fdfb16c3e 100644 --- a/tests/unit/test_controller_init.py +++ b/tests/unit/test_controller_init.py @@ -1,3 +1,4 @@ +import os from os.path import join import json import pkg_resources @@ -20,21 +21,18 @@ class CLITest(TestCase): + CONFIG_FILE_DATA = '{"main_endpoint": "https://amazonaws.com:3000","token": "c9c55ct=6085","ssl_verify": false}' + # @pytest.mark.skip(reason="Need to adjust") - @pytest.mark.skip(reason="Need to adjust") @patch('builtins.input') def test_init_update(self, input_mock): input_mock.side_effect = ["y", "token"] - with open(CONFIG_FILE_LOCATION) as f: - config_data = f.read() - with patch('builtins.open', mock_open(read_data=config_data)) as config_file: + with patch('builtins.open', mock_open(read_data=self.CONFIG_FILE_DATA)) as config_file: try: with catch_prints() as out: cli = CLIFacade() cli.init() except SystemExit: - input_args = [i[0][0] for i in input_mock.call_args_list] - self.assertIn(f"File {CONFIG_FILE_LOCATION} exists. Do you want to overwrite? [y/n] : ", input_args) input_mock.assert_called_with("Input the team SDK token from https://app.superannotate.com/team : ") config_file().write.assert_called_once_with( json.dumps( @@ -65,6 +63,8 @@ def test_init_create(self, input_mock): class SKDInitTest(TestCase): + TEST_TOKEN = "toke=123" + VALID_JSON = { "token": "a"*28 + "=1234" } @@ -74,15 +74,28 @@ class SKDInitTest(TestCase): FILE_NAME = "config.json" FILE_NAME_2 = "config.json" - def test_init_flow(self): + @patch.dict(os.environ, {"SA_TOKEN": TEST_TOKEN}) + def test_env_flow(self): + import superannotate as sa + self.assertEqual(sa.get_default_controller()._token, self.TEST_TOKEN) + + def test_init_via_config_file(self): with tempfile.TemporaryDirectory() as temp_dir: token_path = f"{temp_dir}/config.json" with open(token_path, "w") as temp_config: - json.dump({"token": "token=1234"}, temp_config) + json.dump({"token": self.TEST_TOKEN}, temp_config) temp_config.close() import src.superannotate as sa sa.init(token_path) + @patch("lib.infrastructure.controller.Controller.retrieve_configs") + def test_init_default_configs_open(self, retrieve_configs): + import src.superannotate as sa + try: + sa.init() + except Exception: + self.assertTrue(retrieve_configs.call_args[0], sa.constances.CONFIG_FILE_LOCATION) + def test_init(self): with tempfile.TemporaryDirectory() as temp_dir: path = join(temp_dir, self.FILE_NAME) @@ -90,7 +103,4 @@ def test_init(self): json.dump(self.VALID_JSON, config) import src.superannotate as sa sa.init(path) - self.assertEqual(sa.controller.team_id, 1234) - - def test_(self): - import superannotate as sa \ No newline at end of file + self.assertEqual(sa.get_default_controller().team_id, 1234)