Skip to content

Commit

Permalink
fix(#1548): access datasets for superusers when workspace is not prov…
Browse files Browse the repository at this point in the history
…ided (#1572, #1608)

- create ds with same name in different ws and different task
  • Loading branch information
frascuchon committed Jul 8, 2022
1 parent 04e101c commit 0b04bc8
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 8 deletions.
14 changes: 10 additions & 4 deletions src/rubrix/server/services/datasets.py
Expand Up @@ -99,7 +99,9 @@ def find_by_name(
if found_ds is None:
raise EntityNotFoundError(name=name, type=Dataset)
if found_ds.owner and owner and found_ds.owner != owner:
raise ForbiddenOperationError()
raise EntityNotFoundError(
name=name, type=Dataset
) if user.is_superuser() else ForbiddenOperationError()

return cast(Dataset, found_ds)

Expand All @@ -115,9 +117,13 @@ def __find_by_name_with_superuser_fallback__(
name=name, owner=owner, task=task, as_dataset_class=as_dataset_class
)
if not found_ds and user.is_superuser():
found_ds = self.__dao__.find_by_name(
name=name, owner=None, task=task, as_dataset_class=as_dataset_class
)
try:
found_ds = self.__dao__.find_by_name(
name=name, owner=None, task=task, as_dataset_class=as_dataset_class
)
except WrongTaskError:
# A dataset exists in a different workspace and with a different task
pass
return found_ds

def delete(self, user: User, dataset: Dataset):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Expand Up @@ -9,7 +9,7 @@


@pytest.fixture
def mocked_client(monkeypatch):
def mocked_client(monkeypatch) -> SecuredClient:
with TestClient(app, raise_server_exceptions=False) as _client:
client = SecuredClient(_client)

Expand Down
1 change: 0 additions & 1 deletion tests/datasets/test_datasets.py
Expand Up @@ -3,7 +3,6 @@
import rubrix as rb
from rubrix import TextClassificationSettings, TokenClassificationSettings
from rubrix.client import api
from rubrix.client.sdk.commons.errors import AlreadyExistsApiError


@pytest.mark.parametrize(
Expand Down
27 changes: 27 additions & 0 deletions tests/functional_tests/test_log_for_text_classification.py
Expand Up @@ -3,6 +3,7 @@
import rubrix as rb
from rubrix.client.sdk.commons.errors import BadRequestApiError, ValidationApiError
from rubrix.server.apis.v0.settings.server import settings
from tests.helpers import SecuredClient


def test_log_records_with_multi_and_single_label_task(mocked_client):
Expand Down Expand Up @@ -50,6 +51,32 @@ def test_delete_and_create_for_different_task(mocked_client):
rb.load(dataset)


def test_log_data_in_several_workspaces(mocked_client: SecuredClient):

workspace = "test-ws"
dataset = "test_log_data_in_several_workspaces"
text = "This is a text"

mocked_client.add_workspaces_to_rubrix_user([workspace])

curr_ws = rb.get_workspace()
for ws in [curr_ws, workspace]:
rb.set_workspace(ws)
rb.delete(dataset)

rb.set_workspace(curr_ws)
rb.log(rb.TextClassificationRecord(id=0, inputs=text), name=dataset)

rb.set_workspace(workspace)
rb.log(rb.TextClassificationRecord(id=1, inputs=text), name=dataset)
ds = rb.load(dataset)
assert len(ds) == 1

rb.set_workspace(curr_ws)
ds = rb.load(dataset)
assert len(ds) == 1


def test_search_keywords(mocked_client):
dataset = "test_search_keywords"
from datasets import load_dataset
Expand Down
52 changes: 50 additions & 2 deletions tests/server/datasets/test_api.py
Expand Up @@ -12,9 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from rubrix.server.apis.v0.models.commons.model import TaskType
from rubrix.server.apis.v0.models.datasets import Dataset
from rubrix.server.apis.v0.models.text_classification import TextClassificationBulkData
from tests.helpers import SecuredClient


def test_delete_dataset(mocked_client):
Expand Down Expand Up @@ -62,6 +65,48 @@ def test_create_dataset(mocked_client):
assert response.status_code == 409


def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient):
ws = "mock-ws"
dataset_name = "test_fetch_dataset_using_workspaces"
mocked_client.add_workspaces_to_rubrix_user([ws])

delete_dataset(mocked_client, dataset_name, workspace=ws)
delete_dataset(mocked_client, dataset_name)
request = dict(
name=dataset_name,
task=TaskType.text_classification,
)
response = mocked_client.post(
f"/api/datasets?workspace={ws}",
json=request,
)

assert response.status_code == 200, response.json()
dataset = Dataset.parse_obj(response.json())
assert dataset.created_by == "rubrix"
assert dataset.name == dataset_name
assert dataset.owner == ws
assert dataset.task == TaskType.text_classification

response = mocked_client.post(
f"/api/datasets?workspace={ws}",
json=request,
)
assert response.status_code == 409, response.json()

response = mocked_client.post(
f"/api/datasets",
json=request,
)

assert response.status_code == 200, response.json()
dataset = Dataset.parse_obj(response.json())
assert dataset.created_by == "rubrix"
assert dataset.name == dataset_name
assert dataset.owner == "rubrix"
assert dataset.task == TaskType.text_classification


def test_dataset_naming_validation(mocked_client):
request = TextClassificationBulkData(records=[])
dataset = "Wrong dataset name"
Expand Down Expand Up @@ -166,8 +211,11 @@ def test_open_and_close_dataset(mocked_client):
)


def delete_dataset(client, dataset):
assert client.delete(f"/api/datasets/{dataset}").status_code == 200
def delete_dataset(client, dataset, workspace: Optional[str] = None):
url = f"/api/datasets/{dataset}"
if workspace:
url += f"?workspace={workspace}"
assert client.delete(url).status_code == 200


def create_mock_dataset(client, dataset):
Expand Down

0 comments on commit 0b04bc8

Please sign in to comment.