Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/superannotate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import requests
import superannotate.lib.core as constances
from lib import get_default_controller
from packaging.version import parse
from superannotate.lib import get_default_controller
from superannotate.lib.app.analytics.class_analytics import class_distribution
from superannotate.lib.app.exceptions import AppException
from superannotate.lib.app.input_converters.conversion import convert_json_version
Expand Down Expand Up @@ -112,6 +112,7 @@

controller = get_default_controller()


__all__ = [
"__version__",
"controller",
Expand Down
20 changes: 15 additions & 5 deletions src/superannotate/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

controller = None

DEFAULT_CONTROLLER = None

def get_default_controller():

def get_default_controller(raise_exception=False):
from lib.infrastructure.controller import Controller
try:
global DEFAULT_CONTROLLER
if not DEFAULT_CONTROLLER:
DEFAULT_CONTROLLER = Controller()
return DEFAULT_CONTROLLER
except Exception:
if raise_exception:
raise


global controller
controller = Controller()
return controller
def set_default_controller(controller_obj):
# global DEFAULT_CONTROLLER
DEFAULT_CONTROLLER = controller_obj
8 changes: 3 additions & 5 deletions src/superannotate/lib/app/interface/base_interface.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from lib.infrastructure.controller import Controller
from lib import get_default_controller
from lib.infrastructure.repositories import ConfigRepository


class BaseInterfaceFacade:
def __init__(self):
self._config_path = None
self._controller = get_default_controller()

@property
def controller(self):
if not ConfigRepository().get_one("token"):
raise Exception("Config does not exists!")
controller = Controller()
if self._config_path:
controller.init(config_path=self._config_path)
return controller
return self._controller
7 changes: 4 additions & 3 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from typing import Union

import boto3
import lib
import lib.core as constances
from lib import controller
from lib.app.annotation_helpers import add_annotation_bbox_to_json
from lib.app.annotation_helpers import add_annotation_comment_to_json
from lib.app.annotation_helpers import add_annotation_point_to_json
Expand Down Expand Up @@ -41,6 +41,7 @@
from lib.core.types import AttributeGroup
from lib.core.types import MLModel
from lib.core.types import Project
from lib.infrastructure.controller import Controller
from pydantic import conlist
from pydantic import parse_obj_as
from pydantic import StrictBool
Expand All @@ -49,6 +50,7 @@


logger = get_default_logger()
controller = lib.DEFAULT_CONTROLLER


@validate_arguments
Expand All @@ -64,8 +66,7 @@ def init(path_to_config_json: Optional[str] = None, token: str = None):
:param token: Team token
:type token: str
"""
global controller
controller.init(config_path=path_to_config_json, token=token)
lib.DEFAULT_CONTROLLER = Controller(config_path=path_to_config_json, token=token)


@validate_arguments
Expand Down
129 changes: 45 additions & 84 deletions src/superannotate/lib/infrastructure/controller.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import copy
import io
import os
from abc import ABCMeta
from os.path import expanduser
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import lib.core as constances
Expand Down Expand Up @@ -37,21 +37,7 @@
from superannotate_schemas.validators import AnnotationValidators


class SingleInstanceMetaClass(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in SingleInstanceMetaClass._instances:
SingleInstanceMetaClass._instances[cls] = super().__call__(*args, **kwargs)
return SingleInstanceMetaClass._instances[cls]

def get_instance(cls):
if cls._instances:
return cls._instances[cls]
return cls()


class BaseController(metaclass=SingleInstanceMetaClass):
class BaseController(metaclass=ABCMeta):
def __init__(self, config_path: str, token: str = None):
self._team_data = None
self._token = None
Expand All @@ -69,27 +55,26 @@ def __init__(self, config_path: str, token: str = None):
self._user_id = None
self._team_name = None
self._reporter = None
self._ssl_verify = not os.environ.get("SA_TESTING", False)
self._config_path = (
expanduser(config_path) if config_path else constances.CONFIG_FILE_LOCATION
)
self._ssl_verify = not (os.environ.get("SA_TESTING", "False").lower() == "false")

self._backend_url = os.environ.get("SA_URL", constances.BACKEND_URL)
if not token and not config_path:

if token:
self._token = self._validate_token(os.environ.get("SA_TOKEN"))
elif config_path:
config_path = expanduser(config_path)
self.retrieve_configs(Path(config_path), raise_exception=True)
else:
env_token = os.environ.get("SA_TOKEN")
if env_token:
self._token = self._validate_token(os.environ.get("SA_TOKEN"))
if token:
self._token = self._validate_token(token)

if not self._token:
self._token, self._backend_url, self._ssl_verify = self.retrieve_configs(
Path(self._config_path), raise_exception=False
)
else:
config_path = expanduser(constances.CONFIG_FILE_LOCATION)
self.retrieve_configs(Path(config_path), raise_exception=False)
self.initialize_backend_client()

def retrieve_configs(
self, path: Path, raise_exception=True
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
def retrieve_configs(self, path: Path, raise_exception=True):
token, backend_url, ssl_verify = None, None, None
if not path.is_file() or not os.access(path, os.R_OK):
if raise_exception:
raise AppException(
Expand All @@ -99,7 +84,7 @@ def retrieve_configs(
)
try:
config_repo = ConfigRepository(str(path))
return (
token, backend_url, ssl_verify = (
self._validate_token(config_repo.get_one("token").value),
config_repo.get_one("main_endpoint").value,
config_repo.get_one("ssl_verify").value,
Expand All @@ -109,54 +94,36 @@ def retrieve_configs(
raise AppException(
f"Incorrect config file: token is not present in the config file {path}"
)
return None, None, None
self._token = token
self._backend_url = backend_url or self._backend_url
self._ssl_verify = ssl_verify or self._ssl_verify

def init(
self, token: str = None, backend_url: str = None, config_path: str = None,
):
if backend_url:
self._backend_url = backend_url
if token:
if self._validate_token(token):
self._token = token
return self
else:
raise AppException("Invalid token.")
if not config_path:
raise AppException(
" Please provide correct config file location to sa.init(<path>)."
)
self._config_path = config_path
self._token, self._backend_url, ssl_verify = self.retrieve_configs(
Path(config_path), raise_exception=True
@staticmethod
def _validate_token(token: str):
try:
int(token.split("=")[-1])
except ValueError:
raise AppException("Invalid token.")
return token

def initialize_backend_client(self):
if not self._token:
raise AppException("Team token not provided")
self._backend_client = SuperannotateBackendService(
api_url=self._backend_url,
auth_token=self._token,
logger=self._logger,
verify_ssl=self._ssl_verify,
)
self._ssl_verify = ssl_verify
return self
self._backend_client.get_session.cache_clear()
return self._backend_client

@property
def backend_client(self):
if not self._backend_client:
if not self._token:
raise AppException("Team token not provided")
self._backend_client = SuperannotateBackendService(
api_url=self._backend_url,
auth_token=self._token,
logger=self._logger,
verify_ssl=self._ssl_verify,
)
self._backend_client._api_url = self._backend_url
self._backend_client._auth_token = self._token
self._backend_client.get_session.cache_clear()
self.initialize_backend_client()
return self._backend_client

@staticmethod
def is_valid_token(token: str):
return int(token.split("=")[-1])

@property
def config_path(self):
return self._config_path

@property
def user_id(self):
if not self._user_id:
Expand All @@ -169,18 +136,11 @@ def team_name(self):
_, self._team_name = self.get_team()
return self._team_name

@staticmethod
def _validate_token(token: str):
try:
int(token.split("=")[-1])
except ValueError:
raise AppException("Invalid token.")
return token

def set_token(self, token: str, backend_url: str = constances.BACKEND_URL):
self._token = self._validate_token(token)
self._backend_url = backend_url
self._backend_client = self.backend_client
if backend_url:
self._backend_url = backend_url
self.initialize_backend_client()

@property
def projects(self):
Expand Down Expand Up @@ -263,8 +223,9 @@ def annotation_validators(self) -> AnnotationValidators:


class Controller(BaseController):
def __init__(self, config_path: str = None):
super().__init__(config_path)

def __init__(self, config_path: str = None, token: str = None):
super().__init__(config_path, token)
self._team = None

def _get_project(self, name: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_annotation_last_action_and_creation_type(self, reporter):
self.assertEqual(instance["creationType"], CreationTypeEnum.PRE_ANNOTATION.value)
self.assertEqual(
type(annotation["metadata"]["lastAction"]["email"]),
type(sa.controller.team_data.data.creator_id)
type(sa.get_default_controller().team_data.data.creator_id)
)
self.assertEqual(
type(annotation["metadata"]["lastAction"]["timestamp"]),
Expand Down
34 changes: 22 additions & 12 deletions tests/unit/test_controller_init.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from os.path import join
import json
import pkg_resources
Expand All @@ -20,21 +21,18 @@


class CLITest(TestCase):
CONFIG_FILE_DATA = '{"main_endpoint": "https://amazonaws.com:3000","token": "c9c55ct=6085","ssl_verify": false}'
# @pytest.mark.skip(reason="Need to adjust")

@pytest.mark.skip(reason="Need to adjust")
@patch('builtins.input')
def test_init_update(self, input_mock):
input_mock.side_effect = ["y", "token"]
with open(CONFIG_FILE_LOCATION) as f:
config_data = f.read()
with patch('builtins.open', mock_open(read_data=config_data)) as config_file:
with patch('builtins.open', mock_open(read_data=self.CONFIG_FILE_DATA)) as config_file:
try:
with catch_prints() as out:
cli = CLIFacade()
cli.init()
except SystemExit:
input_args = [i[0][0] for i in input_mock.call_args_list]
self.assertIn(f"File {CONFIG_FILE_LOCATION} exists. Do you want to overwrite? [y/n] : ", input_args)
input_mock.assert_called_with("Input the team SDK token from https://app.superannotate.com/team : ")
config_file().write.assert_called_once_with(
json.dumps(
Expand Down Expand Up @@ -65,6 +63,8 @@ def test_init_create(self, input_mock):


class SKDInitTest(TestCase):
TEST_TOKEN = "toke=123"

VALID_JSON = {
"token": "a"*28 + "=1234"
}
Expand All @@ -74,23 +74,33 @@ class SKDInitTest(TestCase):
FILE_NAME = "config.json"
FILE_NAME_2 = "config.json"

def test_init_flow(self):
@patch.dict(os.environ, {"SA_TOKEN": TEST_TOKEN})
def test_env_flow(self):
import superannotate as sa
self.assertEqual(sa.get_default_controller()._token, self.TEST_TOKEN)

def test_init_via_config_file(self):
with tempfile.TemporaryDirectory() as temp_dir:
token_path = f"{temp_dir}/config.json"
with open(token_path, "w") as temp_config:
json.dump({"token": "token=1234"}, temp_config)
json.dump({"token": self.TEST_TOKEN}, temp_config)
temp_config.close()
import src.superannotate as sa
sa.init(token_path)

@patch("lib.infrastructure.controller.Controller.retrieve_configs")
def test_init_default_configs_open(self, retrieve_configs):
import src.superannotate as sa
try:
sa.init()
except Exception:
self.assertTrue(retrieve_configs.call_args[0], sa.constances.CONFIG_FILE_LOCATION)

def test_init(self):
with tempfile.TemporaryDirectory() as temp_dir:
path = join(temp_dir, self.FILE_NAME)
with open(path, "w") as config:
json.dump(self.VALID_JSON, config)
import src.superannotate as sa
sa.init(path)
self.assertEqual(sa.controller.team_id, 1234)

def test_(self):
import superannotate as sa
self.assertEqual(sa.get_default_controller().team_id, 1234)