Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(datastore): remove the useless project and eval_version env fallbacks from wrapper #1274

Merged
merged 2 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from starwhale.utils.fs import ensure_dir
from starwhale.consts.env import SWEnv
from starwhale.utils.error import MissingFieldError
from starwhale.utils.error import MissingFieldError, FieldTypeOrValueError
from starwhale.utils.retry import http_retry
from starwhale.utils.config import SWCliConfigMixed

Expand Down Expand Up @@ -900,11 +900,15 @@ def dump(self) -> None:


class RemoteDataStore:
def __init__(self, instance_uri: str, token: str = "") -> None:
def __init__(self, instance_uri: str, token: str) -> None:
if not instance_uri:
raise FieldTypeOrValueError("instance_uri not set")

if not token:
raise FieldTypeOrValueError("token not set")

self.instance_uri = instance_uri
self.token = token or os.getenv(SWEnv.instance_token)
if self.token is None:
raise RuntimeError("SW_TOKEN is not found in environment")
self.token = token

@http_retry
def update_table(
Expand Down Expand Up @@ -1033,15 +1037,19 @@ def scan_tables(
...


def get_data_store(instance_uri: str = "") -> DataStore:
def get_data_store(instance_uri: str = "", token: str = "") -> DataStore:
_instance_uri = instance_uri or os.getenv(SWEnv.instance_uri)
if _instance_uri is None or _instance_uri == "local":
return LocalDataStore.get_instance()
else:
print(f"instance:{instance_uri}")
token = (
token
or SWCliConfigMixed().get_sw_token(instance=instance_uri)
or os.getenv(SWEnv.instance_token, "")
)
return RemoteDataStore(
instance_uri=_instance_uri,
token=SWCliConfigMixed().get_sw_token(instance=instance_uri),
token=token,
)


Expand Down
22 changes: 9 additions & 13 deletions client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
import re
import threading
from typing import Any, Dict, List, Union, Iterator, Optional

from loguru import logger

from starwhale.consts import VERSION_PREFIX_CNT
from starwhale.consts.env import SWEnv

from . import data_store

Expand Down Expand Up @@ -47,20 +45,18 @@ def _log(self, table_name: str, record: Dict[str, Any]) -> None:


class Evaluation(Logger):
def __init__(self, eval_id: str = "", project: str = "", instance: str = ""):
eval_id = eval_id or os.getenv(SWEnv.eval_version, "")
def __init__(self, eval_id: str, project: str, instance: str = ""):
if not eval_id:
raise RuntimeError("eval id should not be None")
if re.match(r"^[A-Za-z0-9-_]+$", eval_id) is None:
raise RuntimeError(
f"invalid eval id {eval_id}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed"
)
self.eval_id = eval_id

self.project = project or os.getenv(SWEnv.project, "")
if not self.project:
raise RuntimeError(f"{SWEnv.project} is not set")
if not project:
raise RuntimeError("project is not set")

self.eval_id = eval_id
self.project = project
self._results_table_name = self._get_datastore_table_name("results")
self._summary_table_name = f"project/{self.project}/eval/summary"
self._init_writers([self._results_table_name, self._summary_table_name])
Expand Down Expand Up @@ -117,15 +113,15 @@ def get(self, table_name: str) -> Iterator[Dict[str, Any]]:


class Dataset(Logger):
def __init__(self, dataset_id: str, project: str = "") -> None:
def __init__(self, dataset_id: str, project: str) -> None:
if not dataset_id:
raise RuntimeError("id should not be None")

self.dataset_id = dataset_id
self.project = project or os.getenv(SWEnv.project)
if not self.project:
if not project:
raise RuntimeError("project is not set")

self.dataset_id = dataset_id
self.project = project
self._meta_table_name = f"project/{self.project}/dataset/{self.dataset_id}/meta"
self._data_store = data_store.get_data_store()
self._init_writers([self._meta_table_name])
Expand Down
7 changes: 3 additions & 4 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,7 @@ def test_data_store_scan(self) -> None:

class TestRemoteDataStore(unittest.TestCase):
def setUp(self) -> None:
os.environ["SW_TOKEN"] = "tt"
self.ds = data_store.RemoteDataStore("http://test")
self.ds = data_store.RemoteDataStore("http://test", "tt")

@patch("starwhale.api._impl.data_store.requests.post")
def test_update_table(self, mock_post: Mock) -> None:
Expand Down Expand Up @@ -1500,7 +1499,7 @@ def test_run_thread_exception_limit(self, request_mock: Mocker) -> None:
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_store = data_store.RemoteDataStore("http://1.1.1.1", "tt")
remote_writer = data_store.TableWriter(
"p/test", "k", remote_store, run_exceptions_limits=0
)
Expand Down Expand Up @@ -1538,7 +1537,7 @@ def test_run_thread_exception(self, request_mock: Mocker) -> None:
url="http://1.1.1.1/api/v1/datastore/updateTable",
status_code=400,
)
remote_store = data_store.RemoteDataStore("http://1.1.1.1")
remote_store = data_store.RemoteDataStore("http://1.1.1.1", "tt")
remote_writer = data_store.TableWriter("p/test", "k", remote_store)

assert remote_writer.is_alive()
Expand Down
11 changes: 4 additions & 7 deletions client/tests/sdk/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
class TestEvaluation(BaseTestCase):
def setUp(self) -> None:
super().setUp()
os.environ[SWEnv.project] = "test"
os.environ[SWEnv.eval_version] = "tt"

def tearDown(self) -> None:
super().tearDown()
os.environ.pop(SWEnv.instance_uri, None)
os.environ.pop(SWEnv.instance_token, None)

def test_log_results_and_scan(self) -> None:
eval = wrapper.Evaluation("test")
eval = wrapper.Evaluation("tt", "test")
eval.log_result("0", 3)
eval.log_result("1", 4)
eval.log_result("2", 5, a="0", B="1")
Expand All @@ -38,7 +36,7 @@ def test_log_results_and_scan(self) -> None:
)

def test_log_metrics(self) -> None:
eval = wrapper.Evaluation()
eval = wrapper.Evaluation("tt", "test")
eval.log_metrics(a=0, B=1, c=None)
eval.log_metrics({"a/b": 2})
eval.close()
Expand All @@ -61,7 +59,7 @@ def test_exception_close(self, request_mock: Mocker) -> None:

os.environ[SWEnv.instance_token] = "abcd"
os.environ[SWEnv.instance_uri] = "http://1.1.1.1"
eval = wrapper.Evaluation("test")
eval = wrapper.Evaluation("tt", "test")
eval.log_result("0", 3)
eval.log_metrics({"a/b": 2})

Expand All @@ -83,10 +81,9 @@ def test_exception_close(self, request_mock: Mocker) -> None:
class TestDataset(BaseTestCase):
def setUp(self) -> None:
super().setUp()
os.environ[SWEnv.project] = "test"

def test_put_and_scan(self) -> None:
dataset = wrapper.Dataset("dt")
dataset = wrapper.Dataset("dt", "test")
dataset.put("0", a=1, b=2)
dataset.put("1", a=2, b=3)
dataset.put("2", a=3, b=4)
Expand Down