diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py index 0bdcbfcba9..62151a3bf5 100644 --- a/client/starwhale/api/_impl/data_store.py +++ b/client/starwhale/api/_impl/data_store.py @@ -21,7 +21,7 @@ from starwhale.utils.fs import ensure_dir from starwhale.consts.env import SWEnv -from starwhale.utils.error import MissingFieldError +from starwhale.utils.error import MissingFieldError, FieldTypeOrValueError from starwhale.utils.retry import http_retry from starwhale.utils.config import SWCliConfigMixed @@ -900,11 +900,15 @@ def dump(self) -> None: class RemoteDataStore: - def __init__(self, instance_uri: str, token: str = "") -> None: + def __init__(self, instance_uri: str, token: str) -> None: + if not instance_uri: + raise FieldTypeOrValueError("instance_uri not set") + + if not token: + raise FieldTypeOrValueError("token not set") + self.instance_uri = instance_uri - self.token = token or os.getenv(SWEnv.instance_token) - if self.token is None: - raise RuntimeError("SW_TOKEN is not found in environment") + self.token = token @http_retry def update_table( @@ -1033,15 +1037,19 @@ def scan_tables( ... -def get_data_store(instance_uri: str = "") -> DataStore: +def get_data_store(instance_uri: str = "", token: str = "") -> DataStore: _instance_uri = instance_uri or os.getenv(SWEnv.instance_uri) if _instance_uri is None or _instance_uri == "local": return LocalDataStore.get_instance() else: - print(f"instance:{instance_uri}") + token = ( + token + or SWCliConfigMixed().get_sw_token(instance=instance_uri) + or os.getenv(SWEnv.instance_token, "") + ) return RemoteDataStore( instance_uri=_instance_uri, - token=SWCliConfigMixed().get_sw_token(instance=instance_uri), + token=token, ) diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index 4cdd749591..e04221062a 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -1,4 +1,3 @@ -import os import re import threading from typing import Any, Dict, List, Union, Iterator, Optional @@ -6,7 +5,6 @@ from loguru import logger from starwhale.consts import VERSION_PREFIX_CNT -from starwhale.consts.env import SWEnv from . import data_store @@ -47,20 +45,18 @@ def _log(self, table_name: str, record: Dict[str, Any]) -> None: class Evaluation(Logger): - def __init__(self, eval_id: str = "", project: str = "", instance: str = ""): - eval_id = eval_id or os.getenv(SWEnv.eval_version, "") + def __init__(self, eval_id: str, project: str, instance: str = ""): if not eval_id: raise RuntimeError("eval id should not be None") if re.match(r"^[A-Za-z0-9-_]+$", eval_id) is None: raise RuntimeError( f"invalid eval id {eval_id}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed" ) - self.eval_id = eval_id - - self.project = project or os.getenv(SWEnv.project, "") - if not self.project: - raise RuntimeError(f"{SWEnv.project} is not set") + if not project: + raise RuntimeError("project is not set") + self.eval_id = eval_id + self.project = project self._results_table_name = self._get_datastore_table_name("results") self._summary_table_name = f"project/{self.project}/eval/summary" self._init_writers([self._results_table_name, self._summary_table_name]) @@ -117,15 +113,15 @@ def get(self, table_name: str) -> Iterator[Dict[str, Any]]: class Dataset(Logger): - def __init__(self, dataset_id: str, project: str = "") -> None: + def __init__(self, dataset_id: str, project: str) -> None: if not dataset_id: raise RuntimeError("id should not be None") - self.dataset_id = dataset_id - self.project = project or os.getenv(SWEnv.project) - if not self.project: + if not project: raise RuntimeError("project is not set") + self.dataset_id = dataset_id + self.project = project self._meta_table_name = f"project/{self.project}/dataset/{self.dataset_id}/meta" self._data_store = data_store.get_data_store() self._init_writers([self._meta_table_name]) diff --git a/client/tests/sdk/test_data_store.py b/client/tests/sdk/test_data_store.py index 4deb835063..6942fdeb86 100644 --- a/client/tests/sdk/test_data_store.py +++ b/client/tests/sdk/test_data_store.py @@ -1083,8 +1083,7 @@ def test_data_store_scan(self) -> None: class TestRemoteDataStore(unittest.TestCase): def setUp(self) -> None: - os.environ["SW_TOKEN"] = "tt" - self.ds = data_store.RemoteDataStore("http://test") + self.ds = data_store.RemoteDataStore("http://test", "tt") @patch("starwhale.api._impl.data_store.requests.post") def test_update_table(self, mock_post: Mock) -> None: @@ -1500,7 +1499,7 @@ def test_run_thread_exception_limit(self, request_mock: Mocker) -> None: url="http://1.1.1.1/api/v1/datastore/updateTable", status_code=400, ) - remote_store = data_store.RemoteDataStore("http://1.1.1.1") + remote_store = data_store.RemoteDataStore("http://1.1.1.1", "tt") remote_writer = data_store.TableWriter( "p/test", "k", remote_store, run_exceptions_limits=0 ) @@ -1538,7 +1537,7 @@ def test_run_thread_exception(self, request_mock: Mocker) -> None: url="http://1.1.1.1/api/v1/datastore/updateTable", status_code=400, ) - remote_store = data_store.RemoteDataStore("http://1.1.1.1") + remote_store = data_store.RemoteDataStore("http://1.1.1.1", "tt") remote_writer = data_store.TableWriter("p/test", "k", remote_store) assert remote_writer.is_alive() diff --git a/client/tests/sdk/test_wrapper.py b/client/tests/sdk/test_wrapper.py index 77dde2be0d..7fc01117cd 100644 --- a/client/tests/sdk/test_wrapper.py +++ b/client/tests/sdk/test_wrapper.py @@ -12,8 +12,6 @@ class TestEvaluation(BaseTestCase): def setUp(self) -> None: super().setUp() - os.environ[SWEnv.project] = "test" - os.environ[SWEnv.eval_version] = "tt" def tearDown(self) -> None: super().tearDown() @@ -21,7 +19,7 @@ def tearDown(self) -> None: os.environ.pop(SWEnv.instance_token, None) def test_log_results_and_scan(self) -> None: - eval = wrapper.Evaluation("test") + eval = wrapper.Evaluation("tt", "test") eval.log_result("0", 3) eval.log_result("1", 4) eval.log_result("2", 5, a="0", B="1") @@ -38,7 +36,7 @@ def test_log_results_and_scan(self) -> None: ) def test_log_metrics(self) -> None: - eval = wrapper.Evaluation() + eval = wrapper.Evaluation("tt", "test") eval.log_metrics(a=0, B=1, c=None) eval.log_metrics({"a/b": 2}) eval.close() @@ -61,7 +59,7 @@ def test_exception_close(self, request_mock: Mocker) -> None: os.environ[SWEnv.instance_token] = "abcd" os.environ[SWEnv.instance_uri] = "http://1.1.1.1" - eval = wrapper.Evaluation("test") + eval = wrapper.Evaluation("tt", "test") eval.log_result("0", 3) eval.log_metrics({"a/b": 2}) @@ -83,10 +81,9 @@ def test_exception_close(self, request_mock: Mocker) -> None: class TestDataset(BaseTestCase): def setUp(self) -> None: super().setUp() - os.environ[SWEnv.project] = "test" def test_put_and_scan(self) -> None: - dataset = wrapper.Dataset("dt") + dataset = wrapper.Dataset("dt", "test") dataset.put("0", a=1, b=2) dataset.put("1", a=2, b=3) dataset.put("2", a=3, b=4)