Skip to content

Commit

Permalink
refactor(datastore): remove wrapper project and eval_version useless …
Browse files Browse the repository at this point in the history
…env (#1274)
  • Loading branch information
tianweidut committed Sep 22, 2022
1 parent 0819efd commit 10a9204
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 32 deletions.
24 changes: 16 additions & 8 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from starwhale.utils.fs import ensure_dir
from starwhale.consts.env import SWEnv
from starwhale.utils.error import MissingFieldError
from starwhale.utils.error import MissingFieldError, FieldTypeOrValueError
from starwhale.utils.retry import http_retry
from starwhale.utils.config import SWCliConfigMixed

Expand Down Expand Up @@ -900,11 +900,15 @@ def dump(self) -> None:


class RemoteDataStore:
def __init__(self, instance_uri: str, token: str = "") -> None:
def __init__(self, instance_uri: str, token: str) -> None:
if not instance_uri:
raise FieldTypeOrValueError("instance_uri not set")

if not token:
raise FieldTypeOrValueError("token not set")

self.instance_uri = instance_uri
self.token = token or os.getenv(SWEnv.instance_token)
if self.token is None:
raise RuntimeError("SW_TOKEN is not found in environment")
self.token = token

@http_retry
def update_table(
Expand Down Expand Up @@ -1033,15 +1037,19 @@ def scan_tables(
...


def get_data_store(instance_uri: str = "") -> DataStore:
def get_data_store(instance_uri: str = "", token: str = "") -> DataStore:
_instance_uri = instance_uri or os.getenv(SWEnv.instance_uri)
if _instance_uri is None or _instance_uri == "local":
return LocalDataStore.get_instance()
else:
print(f"instance:{instance_uri}")
token = (
token
or SWCliConfigMixed().get_sw_token(instance=instance_uri)
or os.getenv(SWEnv.instance_token, "")
)
return RemoteDataStore(
instance_uri=_instance_uri,
token=SWCliConfigMixed().get_sw_token(instance=instance_uri),
token=token,
)


Expand Down
22 changes: 9 additions & 13 deletions client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
import re
import threading
from typing import Any, Dict, List, Union, Iterator, Optional

from loguru import logger

from starwhale.consts import VERSION_PREFIX_CNT
from starwhale.consts.env import SWEnv

from . import data_store

Expand Down Expand Up @@ -47,20 +45,18 @@ def _log(self, table_name: str, record: Dict[str, Any]) -> None:


class Evaluation(Logger):
def __init__(self, eval_id: str = "", project: str = "", instance: str = ""):
eval_id = eval_id or os.getenv(SWEnv.eval_version, "")
def __init__(self, eval_id: str, project: str, instance: str = ""):
if not eval_id:
raise RuntimeError("eval id should not be None")
if re.match(r"^[A-Za-z0-9-_]+$", eval_id) is None:
raise RuntimeError(
f"invalid eval id {eval_id}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed"
)
self.eval_id = eval_id

self.project = project or os.getenv(SWEnv.project, "")
if not self.project:
raise RuntimeError(f"{SWEnv.project} is not set")
if not project:
raise RuntimeError("project is not set")

self.eval_id = eval_id
self.project = project
self._results_table_name = self._get_datastore_table_name("results")
self._summary_table_name = f"project/{self.project}/eval/summary"
self._init_writers([self._results_table_name, self._summary_table_name])
Expand Down Expand Up @@ -117,15 +113,15 @@ def get(self, table_name: str) -> Iterator[Dict[str, Any]]:


class Dataset(Logger):
def __init__(self, dataset_id: str, project: str = "") -> None:
def __init__(self, dataset_id: str, project: str) -> None:
if not dataset_id:
raise RuntimeError("id should not be None")

self.dataset_id = dataset_id
self.project = project or os.getenv(SWEnv.project)
if not self.project:
if not project:
raise RuntimeError("project is not set")

self.dataset_id = dataset_id
self.project = project
self._meta_table_name = f"project/{self.project}/dataset/{self.dataset_id}/meta"
self._data_store = data_store.get_data_store()
self._init_writers([self._meta_table_name])
Expand Down
7 changes: 3 additions & 4 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,7 @@ def test_data_store_scan(self) -> None:

class TestRemoteDataStore(unittest.TestCase):
def setUp(self) -> None:
os.environ["SW_TOKEN"] = "tt"
self.ds = data_store.RemoteDataStore("http://test")
self.ds = data_store.RemoteDataStore("http://test", "tt")

@patch("starwhale.api._impl.data_store.requests.post")
def test_update_table(self, mock_post: Mock) -> None:
Expand Down Expand Up @@ -1500,7 +1499,7 @@ def test_run_thread_exception_limit(self, request_mock: Mocker) -> None:
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_store = data_store.RemoteDataStore("http://1.1.1.1", "tt")
remote_writer = data_store.TableWriter(
"p/test", "k", remote_store, run_exceptions_limits=0
)
Expand Down Expand Up @@ -1538,7 +1537,7 @@ def test_run_thread_exception(self, request_mock: Mocker) -> None:
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_store = data_store.RemoteDataStore("http://1.1.1.1", "tt")
remote_writer = data_store.TableWriter("p/test", "k", remote_store)

assert remote_writer.is_alive()
Expand Down
11 changes: 4 additions & 7 deletions client/tests/sdk/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
class TestEvaluation(BaseTestCase):
def setUp(self) -> None:
super().setUp()
os.environ[SWEnv.project] = "test"
os.environ[SWEnv.eval_version] = "tt"

def tearDown(self) -> None:
super().tearDown()
os.environ.pop(SWEnv.instance_uri, None)
os.environ.pop(SWEnv.instance_token, None)

def test_log_results_and_scan(self) -> None:
eval = wrapper.Evaluation("test")
eval = wrapper.Evaluation("tt", "test")
eval.log_result("0", 3)
eval.log_result("1", 4)
eval.log_result("2", 5, a="0", B="1")
Expand All @@ -38,7 +36,7 @@ def test_log_results_and_scan(self) -> None:
)

def test_log_metrics(self) -> None:
eval = wrapper.Evaluation()
eval = wrapper.Evaluation("tt", "test")
eval.log_metrics(a=0, B=1, c=None)
eval.log_metrics({"a/b": 2})
eval.close()
Expand All @@ -61,7 +59,7 @@ def test_exception_close(self, request_mock: Mocker) -> None:

os.environ[SWEnv.instance_token] = "abcd"
os.environ[SWEnv.instance_uri] = "http://1.1.1.1"
eval = wrapper.Evaluation("test")
eval = wrapper.Evaluation("tt", "test")
eval.log_result("0", 3)
eval.log_metrics({"a/b": 2})

Expand All @@ -83,10 +81,9 @@ def test_exception_close(self, request_mock: Mocker) -> None:
class TestDataset(BaseTestCase):
def setUp(self) -> None:
super().setUp()
os.environ[SWEnv.project] = "test"

def test_put_and_scan(self) -> None:
dataset = wrapper.Dataset("dt")
dataset = wrapper.Dataset("dt", "test")
dataset.put("0", a=1, b=2)
dataset.put("1", a=2, b=3)
dataset.put("2", a=3, b=4)
Expand Down

0 comments on commit 10a9204

Please sign in to comment.