diff --git a/examples/postgrest/docker-compose.yml b/examples/postgrest/docker-compose.yml index af400656..0d5adb3f 100644 --- a/examples/postgrest/docker-compose.yml +++ b/examples/postgrest/docker-compose.yml @@ -13,6 +13,9 @@ services: - 5432 postgrest: image: postgrest/postgrest:latest + command: + - postgrest + - /etc/postgrest.conf ports: - '0.0.0.0:8080:8080' volumes: diff --git a/splitgraph/cloud/__init__.py b/splitgraph/cloud/__init__.py index 1ca147b9..4843c6d9 100644 --- a/splitgraph/cloud/__init__.py +++ b/splitgraph/cloud/__init__.py @@ -16,6 +16,7 @@ from splitgraph.__version__ import __version__ from splitgraph.cloud.models import ( + Credential, Metadata, MetadataResponse, External, @@ -27,6 +28,7 @@ AddExternalCredentialRequest, UpdateExternalCredentialResponse, AddExternalRepositoryRequest, + AddExternalRepositoriesRequest, ) from splitgraph.commandline.engine import patch_and_save_config from splitgraph.config import create_config_dict, get_singleton, CONFIG @@ -57,6 +59,69 @@ def get_headers(): } } +_BULK_UPSERT_REPO_PROFILES_QUERY = """mutation BulkUpsertRepoProfilesMutation( + $namespaces: [String!] + $repositories: [String!] + $descriptions: [String] + $readmes: [String] + $licenses: [String] + $metadata: [JSON] +) { + __typename + bulkUpsertRepoProfiles( + input: { + namespaces: $namespaces + repositories: $repositories + descriptions: $descriptions + readmes: $readmes + licenses: $licenses + metadata: $metadata + } + ) { + clientMutationId + __typename + } +} +""" + +_BULK_UPDATE_REPO_SOURCES_QUERY = """mutation BulkUpdateRepoSourcesMutation( + $namespaces: [String!] + $repositories: [String!] + $sources: [DatasetSourceInput] +) { + __typename + bulkUpdateRepoSources( + input: { + namespaces: $namespaces + repositories: $repositories + sources: $sources + } + ) { + clientMutationId + __typename + } +} +""" + +_BULK_UPSERT_REPO_TOPICS_QUERY = """mutation BulkUpsertRepoTopicsMutation( + $namespaces: [String!] + $repositories: [String!] + $topics: [String] +) { + __typename + bulkUpsertRepoTopics( + input: { + namespaces: $namespaces + repositories: $repositories + topics: $topics + } + ) { + clientMutationId + __typename + } +} +""" + _PROFILE_UPSERT_QUERY = """mutation UpsertRepoProfile( $namespace: String! $repository: String! @@ -567,17 +632,11 @@ def ensure_external_credential( assert credential return credential.credential_id - def upsert_external( - self, - namespace: str, - repository: str, - external: External, - credentials_map: Optional[Dict[str, str]] = None, - ): - request = AddExternalRepositoryRequest.from_external( - namespace, repository, external, credentials_map + def bulk_upsert_external(self, repositories: List[AddExternalRepositoryRequest]): + request = AddExternalRepositoriesRequest(repositories=repositories) + self._perform_request( + "/bulk-add", self.access_token, request, endpoint=self.externals_endpoint ) - self._perform_request("/add", self.access_token, request, endpoint=self.externals_endpoint) def AuthAPIClient(*args, **kwargs): @@ -633,7 +692,7 @@ def _gql(self, query: Dict, endpoint=None, handle_errors=False) -> requests.Resp return result @staticmethod - def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False): + def _validate_metadata(namespace: str, repository: str, metadata: Metadata): # Pre-flight validation if metadata.description and len(metadata.description) > 160: raise ValueError("The description should be 160 characters or shorter!") @@ -669,6 +728,12 @@ def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Meta if "readme" in variables and isinstance(variables["readme"], dict): variables["readme"] = variables["readme"]["text"] + return variables + + @staticmethod + def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False): + variables = GQLAPIClient._validate_metadata(namespace, repository, metadata) + gql_query = _PROFILE_UPSERT_QUERY if v1: gql_query = gql_query.replace("createRepoTopicsAgg", "createRepoTopic").replace( @@ -706,6 +771,79 @@ def upsert_metadata(self, namespace: str, repository: str, metadata: Metadata): ) return response + def bulk_upsert_metadata( + self, namespace_list: List[str], repository_list: List[str], metadata_list: List[Metadata] + ): + repo_profiles: Dict[str, List[Any]] = dict( + namespaces=namespace_list, + repositories=repository_list, + descriptions=[], + readmes=[], + licenses=[], + metadata=[], + ) + repo_sources: Dict[str, List[Any]] = dict(namespaces=[], repositories=[], sources=[]) + repo_topics: Dict[str, List[str]] = dict(namespaces=[], repositories=[], topics=[]) + + # populate mutation payloads + for ind, metadata in enumerate(metadata_list): + validated_metadata = GQLAPIClient._validate_metadata( + namespace_list[ind], repository_list[ind], metadata + ) + + repo_profiles["descriptions"].append(validated_metadata.get("description")) + repo_profiles["readmes"].append(validated_metadata.get("readme")) + repo_profiles["licenses"].append(validated_metadata.get("license")) + repo_profiles["metadata"].append(validated_metadata.get("metadata")) + + # flatten sources, which will be aggregated on the server side + if len(validated_metadata.get("sources", [])) > 0: + for source in validated_metadata["sources"]: + repo_sources["namespaces"].append(namespace_list[ind]) + repo_sources["repositories"].append(repository_list[ind]) + repo_sources["sources"].append(source) + + # flatten topics, which will be aggregated on the server side + if len(validated_metadata.get("topics", [])) > 0: + for topic in validated_metadata["topics"]: + repo_topics["namespaces"].append(namespace_list[ind]) + repo_topics["repositories"].append(repository_list[ind]) + repo_topics["topics"].append(topic) + + self._bulk_upsert_repo_profiles(repo_profiles) + self._bulk_upsert_repo_sources(repo_sources) + self._bulk_upsert_repo_topics(repo_topics) + + @handle_gql_errors + def _bulk_upsert_repo_profiles(self, repo_profiles: Dict[str, List[Any]]): + repo_profiles_query = { + "operationName": "BulkUpsertRepoProfilesMutation", + "variables": repo_profiles, + "query": _BULK_UPSERT_REPO_PROFILES_QUERY, + } + response = self._gql(repo_profiles_query) + return response + + @handle_gql_errors + def _bulk_upsert_repo_sources(self, repo_sources: Dict[str, List[Any]]): + repo_sources_query = { + "operationName": "BulkUpdateRepoSourcesMutation", + "variables": repo_sources, + "query": _BULK_UPDATE_REPO_SOURCES_QUERY, + } + response = self._gql(repo_sources_query) + return response + + @handle_gql_errors + def _bulk_upsert_repo_topics(self, repo_topics: Dict[str, List[str]]): + repo_topics_query = { + "operationName": "BulkUpsertRepoTopicsMutation", + "variables": repo_topics, + "query": _BULK_UPSERT_REPO_TOPICS_QUERY, + } + response = self._gql(repo_topics_query) + return response + def upsert_readme(self, namespace: str, repository: str, readme: str): return self.upsert_metadata(namespace, repository, Metadata(readme=readme)) diff --git a/splitgraph/cloud/models.py b/splitgraph/cloud/models.py index 39ea6a83..266a59f6 100644 --- a/splitgraph/cloud/models.py +++ b/splitgraph/cloud/models.py @@ -27,12 +27,18 @@ class Credential(BaseModel): data: Dict[str, Any] +class IngestionSchedule(BaseModel): + schedule: str + enabled = True + + class External(BaseModel): credential_id: Optional[str] credential: Optional[str] plugin: str params: Dict[str, Any] tables: Dict[str, Table] + schedule: Optional[IngestionSchedule] # Models for the catalog metadata (description, README, topics etc) @@ -226,6 +232,7 @@ class AddExternalRepositoryRequest(BaseModel): is_live: bool tables: Optional[Dict[str, ExternalTableRequest]] credential_id: Optional[str] + schedule: Optional[IngestionSchedule] @classmethod def from_external( @@ -259,4 +266,9 @@ def from_external( }, credential_id=credential_id, is_live=True, + schedule=external.schedule, ) + + +class AddExternalRepositoriesRequest(BaseModel): + repositories: List[AddExternalRepositoryRequest] diff --git a/splitgraph/commandline/cloud.py b/splitgraph/commandline/cloud.py index 269ab500..c8f2244d 100644 --- a/splitgraph/commandline/cloud.py +++ b/splitgraph/commandline/cloud.py @@ -14,7 +14,7 @@ from click import wrap_text from tqdm import tqdm -from splitgraph.cloud.models import Metadata, RepositoriesYAML +from splitgraph.cloud.models import Metadata, RepositoriesYAML, AddExternalRepositoryRequest from splitgraph.commandline.common import ( ImageType, RepositoryType, @@ -22,6 +22,7 @@ Color, ) from splitgraph.commandline.engine import patch_and_save_config, inject_config_into_engines +from splitgraph.core.output import pluralise # Hardcoded database name for the Splitgraph DDN (ddn instead of sgregistry) from splitgraph.config.config import get_from_subsection @@ -675,16 +676,30 @@ def load_c(remote, readme_dir, repositories_file, limit_repositories): r for r in repositories if f"{r.namespace}/{r.repository}" in limit_repositories ] - with tqdm(repositories) as t: - for repository in t: - t.set_description(f"{repository.namespace}/{repository.repository}") - if repository.external: - rest_client.upsert_external( - repository.namespace, repository.repository, repository.external, credential_map - ) - if repository.metadata: - metadata = _prepare_metadata(repository.metadata, readme_basedir=readme_dir) - gql_client.upsert_metadata(repository.namespace, repository.repository, metadata) + logging.info("Uploading images...") + external_repositories = [] + for repository in repositories: + if repository.external: + external_repository = AddExternalRepositoryRequest.from_external( + repository.namespace, repository.repository, repository.external, credential_map + ) + external_repositories.append(external_repository) + rest_client.bulk_upsert_external(repositories=external_repositories) + logging.info(f"Uploaded images for {pluralise('repository', len(external_repositories))}") + + logging.info("Updating metadata...") + namespace_list = [] + repository_list = [] + metadata_list = [] + for repository in repositories: + if repository.metadata: + namespace_list.append(repository.namespace) + repository_list.append(repository.repository) + + metadata = _prepare_metadata(repository.metadata, readme_basedir=readme_dir) + metadata_list.append(metadata) + gql_client.bulk_upsert_metadata(namespace_list, repository_list, metadata_list) + logging.info(f"Updated metadata for {pluralise('repository', len(repository_list))}") def _build_credential_map(auth_client, credentials=None): diff --git a/splitgraph/core/output.py b/splitgraph/core/output.py index 2d9064a2..fc2f750a 100644 --- a/splitgraph/core/output.py +++ b/splitgraph/core/output.py @@ -21,6 +21,8 @@ def pretty_size(size: Union[int, float]) -> str: def pluralise(word: str, number: int) -> str: """1 banana, 2 bananas""" + if word.endswith("y"): + return "%d %s" % (number, word if number == 1 else word[:-1] + "ies") return "%d %s%s" % (number, word, "" if number == 1 else "s") diff --git a/splitgraph/engine/postgres/engine.py b/splitgraph/engine/postgres/engine.py index 41c77df1..22a1bc69 100644 --- a/splitgraph/engine/postgres/engine.py +++ b/splitgraph/engine/postgres/engine.py @@ -59,6 +59,8 @@ # the connection property otherwise from psycopg2._psycopg import connection as Connection +psycopg2.extensions.register_adapter(dict, Json) + _AUDIT_SCHEMA = "splitgraph_audit" _AUDIT_TRIGGER = "resources/static/audit_trigger.sql" _PUSH_PULL = "resources/static/splitgraph_api.sql" @@ -511,7 +513,7 @@ def run_sql( with connection.cursor(**cursor_kwargs) as cur: try: self.notices = [] - cur.execute(statement, _convert_vals(arguments) if arguments else None) + cur.execute(statement, arguments) if connection.notices: self.notices = connection.notices[:] del connection.notices[:] @@ -603,7 +605,7 @@ def run_sql_batch( batches = _paginate_by_size( cur, statement, - (_convert_vals(a) for a in arguments), + arguments, max_size=API_MAX_QUERY_LENGTH, ) for batch in batches: @@ -1603,13 +1605,6 @@ def _convert_audit_change( _KIND = {"I": 0, "D": 1, "U": 2} -def _convert_vals(vals: Any) -> Any: - """Psycopg returns jsonb objects as dicts/lists but doesn't actually accept them directly - as a query param (or in the case of lists coerces them into an array. - Hence, we have to wrap them in the Json datatype when doing a dump + load.""" - return [Json(v) if isinstance(v, dict) else v for v in vals] - - def _generate_where_clause(table: str, cols: List[str], table_2: str) -> Composed: return SQL(" AND ").join( SQL("{}.{} = {}.{}").format( diff --git a/test/splitgraph/commandline/http_fixtures.py b/test/splitgraph/commandline/http_fixtures.py index 1af12a28..7a4df032 100644 --- a/test/splitgraph/commandline/http_fixtures.py +++ b/test/splitgraph/commandline/http_fixtures.py @@ -1,6 +1,11 @@ import json -from splitgraph.cloud import _PROFILE_UPSERT_QUERY +from splitgraph.cloud import ( + _PROFILE_UPSERT_QUERY, + _BULK_UPSERT_REPO_PROFILES_QUERY, + _BULK_UPDATE_REPO_SOURCES_QUERY, + _BULK_UPSERT_REPO_TOPICS_QUERY, +) REMOTE = "remote_engine" AUTH_ENDPOINT = "http://some-auth-service.example.com" @@ -332,28 +337,9 @@ def add_external_credential(request, uri, response_headers): def add_external_repo(request, uri, response_headers): data = json.loads(request.body) - if data["namespace"] == "someuser" and data["repository"] == "somerepo_1": - assert data == { - "namespace": "someuser", - "repository": "somerepo_1", - "plugin_name": "plugin_2", - "params": {}, - "is_live": True, - "tables": {}, - "credential_id": "123e4567-e89b-12d3-a456-426655440000", - } - elif data["namespace"] == "someuser" and data["repository"] == "somerepo_2": - assert data == { - "namespace": "someuser", - "repository": "somerepo_2", - "plugin_name": "plugin_3", - "params": {}, - "is_live": True, - "tables": {}, - "credential_id": "00000000-0000-0000-0000-000000000000", - } - elif data["namespace"] == "otheruser" and data["repository"] == "somerepo_2": - assert data == { + assert data["repositories"] is not None + assert data["repositories"] == [ + { "credential_id": "98765432-aaaa-bbbb-a456-000000000000", "is_live": True, "namespace": "otheruser", @@ -368,77 +354,85 @@ def add_external_repo(request, uri, response_headers): "table_2": {"options": {"param_1": "val_2"}, "schema": {}}, "table_3": {"options": {}, "schema": {"id": "text", "val": "text"}}, }, - } - else: - raise AssertionError("Unknown repository %s/%s!" % (data["namespace"], data["repository"])) + "schedule": None, + }, + { + "namespace": "someuser", + "repository": "somerepo_1", + "plugin_name": "plugin_2", + "params": {}, + "is_live": True, + "tables": {}, + "credential_id": "123e4567-e89b-12d3-a456-426655440000", + "schedule": None, + }, + { + "namespace": "someuser", + "repository": "somerepo_2", + "plugin_name": "plugin_3", + "params": {}, + "is_live": True, + "tables": {}, + "credential_id": "00000000-0000-0000-0000-000000000000", + "schedule": None, + }, + ] return [ 200, response_headers, - json.dumps({"live_image_hash": "abcdef12" * 8}), + json.dumps({"live_image_hashes": ["abcdef12" * 8, "ghijkl34" * 8, "mnoprs56" * 8]}), ] -def upsert_repository_metadata(request, uri, response_headers): +def assert_repository_profiles(request): data = json.loads(request.body) - assert data["operationName"] == "UpsertRepoProfile" - assert data["query"] == _PROFILE_UPSERT_QUERY + assert data["operationName"] == "BulkUpsertRepoProfilesMutation" + assert data["query"] == _BULK_UPSERT_REPO_PROFILES_QUERY variables = data["variables"] - if variables["namespace"] == "someuser" and variables["repository"] == "somerepo_1": - assert variables == { - "namespace": "someuser", - "repository": "somerepo_1", - "readme": "# Readme 1", - "description": "Repository Description 1", - "topics": [], - "sources": [ - { - "anchor": "test data source", - "href": "https://example.com", - "isCreator": True, - "isSameAs": False, - } - ], - "license": "Public Domain", - } - elif variables["namespace"] == "someuser" and variables["repository"] == "somerepo_2": - assert variables == { - "description": "Another Repository", - "namespace": "someuser", - "repository": "somerepo_2", - } - elif variables["namespace"] == "otheruser" and variables["repository"] == "somerepo_2": - assert variables == { - "description": "Repository Description 2", - "namespace": "otheruser", - "readme": "# Readme 2", - "repository": "somerepo_2", - "sources": [{"anchor": "test data source", "href": "https://example.com"}], - "topics": ["topic_1", "topic_2"], - } - else: - raise AssertionError( - "Unknown repository %s/%s!" % (variables["namespace"], variables["repository"]) - ) - - success_response = { - "data": { - "__typename": "Mutation", - "upsertRepoProfileByNamespaceAndRepository": { - "clientMutationId": None, - "__typename": "UpsertRepoProfilePayload", - }, - } - } + assert variables["namespaces"] == ["otheruser", "someuser", "someuser"] + assert variables["repositories"] == ["somerepo_2", "somerepo_1", "somerepo_2"] + assert variables["readmes"] == ["# Readme 2", "# Readme 1", None] + assert variables["descriptions"] == [ + "Repository Description 2", + "Repository Description 1", + "Another Repository", + ] + assert variables["licenses"] == [None, "Public Domain", None] + assert variables["metadata"] == [None, None, None] - return [ - 200, - response_headers, - json.dumps(success_response), + +def assert_repository_sources(request): + data = json.loads(request.body) + assert data["operationName"] == "BulkUpdateRepoSourcesMutation" + assert data["query"] == _BULK_UPDATE_REPO_SOURCES_QUERY + + variables = data["variables"] + assert variables["namespaces"] == ["otheruser", "someuser"] + assert variables["repositories"] == ["somerepo_2", "somerepo_1"] + assert variables["sources"] == [ + {"anchor": "test data source", "href": "https://example.com"}, + { + "anchor": "test data source", + "href": "https://example.com", + "isCreator": True, + "isSameAs": False, + }, ] +def assert_repository_topics(request): + data = json.loads(request.body) + assert data["operationName"] == "BulkUpsertRepoTopicsMutation" + assert data["query"] == _BULK_UPSERT_REPO_TOPICS_QUERY + + variables = data["variables"] + assert variables["namespaces"] == ["otheruser", "otheruser"] + assert variables["repositories"] == ["somerepo_2", "somerepo_2"] + assert variables["topics"] == ["topic_1", "topic_2"] + + def register_user(request, uri, response_headers): assert json.loads(request.body) == { "username": "someuser", diff --git a/test/splitgraph/commandline/test_cloud_metadata.py b/test/splitgraph/commandline/test_cloud_metadata.py index c96ea283..4b298f0b 100644 --- a/test/splitgraph/commandline/test_cloud_metadata.py +++ b/test/splitgraph/commandline/test_cloud_metadata.py @@ -34,7 +34,9 @@ update_external_credential, add_external_credential, add_external_repo, - upsert_repository_metadata, + assert_repository_profiles, + assert_repository_sources, + assert_repository_topics, AUTH_ENDPOINT, ) from test.splitgraph.conftest import RESOURCES @@ -363,13 +365,11 @@ def test_commandline_load(): httpretty.register_uri( httpretty.HTTPretty.POST, - QUERY_ENDPOINT + "/api/external/add", + QUERY_ENDPOINT + "/api/external/bulk-add", body=add_external_repo, ) - httpretty.register_uri( - httpretty.HTTPretty.POST, GQL_ENDPOINT + "/", body=upsert_repository_metadata - ) + httpretty.register_uri(httpretty.HTTPretty.POST, GQL_ENDPOINT + "/") def get_remote_param(remote, param): if param == "SG_AUTH_API": @@ -398,4 +398,11 @@ def get_remote_param(remote, param): catch_exceptions=False, ) assert result.exit_code == 0 - assert "someuser/somerepo_1" in result.output + + reqs = httpretty.latest_requests() + + assert_repository_topics(reqs.pop()) + reqs.pop() # discard duplicate request + assert_repository_sources(reqs.pop()) + reqs.pop() # discard duplicate request + assert_repository_profiles(reqs.pop())