3 changes: 3 additions & 0 deletions examples/postgrest/docker-compose.yml
@@ -13,6 +13,9 @@ services:
       - 5432
   postgrest:
     image: postgrest/postgrest:latest
+    command:
+      - postgrest
+      - /etc/postgrest.conf
     ports:
       - '0.0.0.0:8080:8080'
     volumes:
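Note: the new command pins the container to run postgrest /etc/postgrest.conf explicitly instead of relying on the image's default startup arguments, so the service reads the config file that (presumably, given the truncated volumes section) is mounted into the container.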
160 changes: 149 additions & 11 deletions splitgraph/cloud/__init__.py
@@ -16,6 +16,7 @@

from splitgraph.__version__ import __version__
from splitgraph.cloud.models import (
Credential,
Metadata,
MetadataResponse,
External,
@@ -27,6 +28,7 @@
AddExternalCredentialRequest,
UpdateExternalCredentialResponse,
AddExternalRepositoryRequest,
AddExternalRepositoriesRequest,
)
from splitgraph.commandline.engine import patch_and_save_config
from splitgraph.config import create_config_dict, get_singleton, CONFIG
@@ -57,6 +59,69 @@ def get_headers():
}
}

_BULK_UPSERT_REPO_PROFILES_QUERY = """mutation BulkUpsertRepoProfilesMutation(
$namespaces: [String!]
$repositories: [String!]
$descriptions: [String]
$readmes: [String]
$licenses: [String]
$metadata: [JSON]
) {
__typename
bulkUpsertRepoProfiles(
input: {
namespaces: $namespaces
repositories: $repositories
descriptions: $descriptions
readmes: $readmes
licenses: $licenses
metadata: $metadata
}
) {
clientMutationId
__typename
}
}
"""

_BULK_UPDATE_REPO_SOURCES_QUERY = """mutation BulkUpdateRepoSourcesMutation(
$namespaces: [String!]
$repositories: [String!]
$sources: [DatasetSourceInput]
) {
__typename
bulkUpdateRepoSources(
input: {
namespaces: $namespaces
repositories: $repositories
sources: $sources
}
) {
clientMutationId
__typename
}
}
"""

_BULK_UPSERT_REPO_TOPICS_QUERY = """mutation BulkUpsertRepoTopicsMutation(
$namespaces: [String!]
$repositories: [String!]
$topics: [String]
) {
__typename
bulkUpsertRepoTopics(
input: {
namespaces: $namespaces
repositories: $repositories
topics: $topics
}
) {
clientMutationId
__typename
}
}
"""

_PROFILE_UPSERT_QUERY = """mutation UpsertRepoProfile(
$namespace: String!
$repository: String!
@@ -567,17 +632,11 @@ def ensure_external_credential(
assert credential
return credential.credential_id

-    def upsert_external(
-        self,
-        namespace: str,
-        repository: str,
-        external: External,
-        credentials_map: Optional[Dict[str, str]] = None,
-    ):
-        request = AddExternalRepositoryRequest.from_external(
-            namespace, repository, external, credentials_map
+    def bulk_upsert_external(self, repositories: List[AddExternalRepositoryRequest]):
+        request = AddExternalRepositoriesRequest(repositories=repositories)
+        self._perform_request(
+            "/bulk-add", self.access_token, request, endpoint=self.externals_endpoint
         )
-        self._perform_request("/add", self.access_token, request, endpoint=self.externals_endpoint)


def AuthAPIClient(*args, **kwargs):
@@ -633,7 +692,7 @@ def _gql(self, query: Dict, endpoint=None, handle_errors=False) -> requests.Response:
return result

     @staticmethod
-    def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False):
+    def _validate_metadata(namespace: str, repository: str, metadata: Metadata):
# Pre-flight validation
if metadata.description and len(metadata.description) > 160:
raise ValueError("The description should be 160 characters or shorter!")
@@ -669,6 +728,12 @@ def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False):
if "readme" in variables and isinstance(variables["readme"], dict):
variables["readme"] = variables["readme"]["text"]

return variables

@staticmethod
def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False):
variables = GQLAPIClient._validate_metadata(namespace, repository, metadata)

gql_query = _PROFILE_UPSERT_QUERY
if v1:
gql_query = gql_query.replace("createRepoTopicsAgg", "createRepoTopic").replace(
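The validation and variable-building logic previously inlined here is factored out into _validate_metadata so that this single-repository path and the new bulk_upsert_metadata below can share it. A sketch of the dict it returns (keys are only present for fields actually set; the values shown are illustrative):

    variables = {
        "namespace": "acme",
        "repository": "weather",
        "description": "Hourly weather readings",
        "readme": "# Weather\n...",  # flattened to plain text if given as {"text": ...}
        "topics": ["weather", "climate"],
    }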
@@ -706,6 +771,79 @@ def upsert_metadata(self, namespace: str, repository: str, metadata: Metadata):
)
return response

def bulk_upsert_metadata(
self, namespace_list: List[str], repository_list: List[str], metadata_list: List[Metadata]
):
repo_profiles: Dict[str, List[Any]] = dict(
namespaces=namespace_list,
repositories=repository_list,
descriptions=[],
readmes=[],
licenses=[],
metadata=[],
)
repo_sources: Dict[str, List[Any]] = dict(namespaces=[], repositories=[], sources=[])
repo_topics: Dict[str, List[str]] = dict(namespaces=[], repositories=[], topics=[])

# populate mutation payloads
for ind, metadata in enumerate(metadata_list):
validated_metadata = GQLAPIClient._validate_metadata(
namespace_list[ind], repository_list[ind], metadata
)

repo_profiles["descriptions"].append(validated_metadata.get("description"))
repo_profiles["readmes"].append(validated_metadata.get("readme"))
repo_profiles["licenses"].append(validated_metadata.get("license"))
repo_profiles["metadata"].append(validated_metadata.get("metadata"))

# flatten sources, which will be aggregated on the server side
if len(validated_metadata.get("sources", [])) > 0:
for source in validated_metadata["sources"]:
repo_sources["namespaces"].append(namespace_list[ind])
repo_sources["repositories"].append(repository_list[ind])
repo_sources["sources"].append(source)

# flatten topics, which will be aggregated on the server side
if len(validated_metadata.get("topics", [])) > 0:
for topic in validated_metadata["topics"]:
repo_topics["namespaces"].append(namespace_list[ind])
repo_topics["repositories"].append(repository_list[ind])
repo_topics["topics"].append(topic)

self._bulk_upsert_repo_profiles(repo_profiles)
self._bulk_upsert_repo_sources(repo_sources)
self._bulk_upsert_repo_topics(repo_topics)
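Since a repository can have any number of sources and topics, those cannot ride along in the per-repository parallel arrays; instead, each (namespace, repository, source) and (namespace, repository, topic) pair becomes its own element, and the server re-aggregates them by repository. A hypothetical illustration of the flattened topics payload:

    # "acme/weather" has two topics, "acme/trade-data" has one; after
    # flattening, the three parallel lists line up element-wise:
    repo_topics = {
        "namespaces": ["acme", "acme", "acme"],
        "repositories": ["weather", "weather", "trade-data"],
        "topics": ["weather", "climate", "finance"],
    }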

@handle_gql_errors
def _bulk_upsert_repo_profiles(self, repo_profiles: Dict[str, List[Any]]):
repo_profiles_query = {
"operationName": "BulkUpsertRepoProfilesMutation",
"variables": repo_profiles,
"query": _BULK_UPSERT_REPO_PROFILES_QUERY,
}
response = self._gql(repo_profiles_query)
return response

@handle_gql_errors
def _bulk_upsert_repo_sources(self, repo_sources: Dict[str, List[Any]]):
repo_sources_query = {
"operationName": "BulkUpdateRepoSourcesMutation",
"variables": repo_sources,
"query": _BULK_UPDATE_REPO_SOURCES_QUERY,
}
response = self._gql(repo_sources_query)
return response

@handle_gql_errors
def _bulk_upsert_repo_topics(self, repo_topics: Dict[str, List[str]]):
repo_topics_query = {
"operationName": "BulkUpsertRepoTopicsMutation",
"variables": repo_topics,
"query": _BULK_UPSERT_REPO_TOPICS_QUERY,
}
response = self._gql(repo_topics_query)
return response

def upsert_readme(self, namespace: str, repository: str, readme: str):
return self.upsert_metadata(namespace, repository, Metadata(readme=readme))

12 changes: 12 additions & 0 deletions splitgraph/cloud/models.py
@@ -27,12 +27,18 @@ class Credential(BaseModel):
data: Dict[str, Any]


class IngestionSchedule(BaseModel):
schedule: str
enabled = True


class External(BaseModel):
credential_id: Optional[str]
credential: Optional[str]
plugin: str
params: Dict[str, Any]
tables: Dict[str, Table]
schedule: Optional[IngestionSchedule]


# Models for the catalog metadata (description, README, topics etc)
@@ -226,6 +232,7 @@ class AddExternalRepositoryRequest(BaseModel):
is_live: bool
tables: Optional[Dict[str, ExternalTableRequest]]
credential_id: Optional[str]
schedule: Optional[IngestionSchedule]

@classmethod
def from_external(
@@ -259,4 +266,9 @@ def from_external(
},
credential_id=credential_id,
is_live=True,
schedule=external.schedule,
)


class AddExternalRepositoriesRequest(BaseModel):
repositories: List[AddExternalRepositoryRequest]
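A sketch of constructing an External with the new schedule field (the plugin name and params are illustrative, and the cron-style schedule string is an assumption about what the backend accepts):

    external = External(
        plugin="csv",
        params={"url": "https://example.com/data.csv"},
        tables={},
        schedule=IngestionSchedule(schedule="0 */6 * * *"),  # enabled defaults to True
    )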
37 changes: 26 additions & 11 deletions splitgraph/commandline/cloud.py
@@ -14,14 +14,15 @@
from click import wrap_text
from tqdm import tqdm

-from splitgraph.cloud.models import Metadata, RepositoriesYAML
+from splitgraph.cloud.models import Metadata, RepositoriesYAML, AddExternalRepositoryRequest
from splitgraph.commandline.common import (
ImageType,
RepositoryType,
emit_sql_results,
Color,
)
from splitgraph.commandline.engine import patch_and_save_config, inject_config_into_engines
+from splitgraph.core.output import pluralise

# Hardcoded database name for the Splitgraph DDN (ddn instead of sgregistry)
from splitgraph.config.config import get_from_subsection
@@ -675,16 +676,30 @@ def load_c(remote, readme_dir, repositories_file, limit_repositories):
r for r in repositories if f"{r.namespace}/{r.repository}" in limit_repositories
]

-    with tqdm(repositories) as t:
-        for repository in t:
-            t.set_description(f"{repository.namespace}/{repository.repository}")
-            if repository.external:
-                rest_client.upsert_external(
-                    repository.namespace, repository.repository, repository.external, credential_map
-                )
-            if repository.metadata:
-                metadata = _prepare_metadata(repository.metadata, readme_basedir=readme_dir)
-                gql_client.upsert_metadata(repository.namespace, repository.repository, metadata)
+    logging.info("Uploading images...")
+    external_repositories = []
+    for repository in repositories:
+        if repository.external:
+            external_repository = AddExternalRepositoryRequest.from_external(
+                repository.namespace, repository.repository, repository.external, credential_map
+            )
+            external_repositories.append(external_repository)
+    rest_client.bulk_upsert_external(repositories=external_repositories)
+    logging.info(f"Uploaded images for {pluralise('repository', len(external_repositories))}")
+
+    logging.info("Updating metadata...")
+    namespace_list = []
+    repository_list = []
+    metadata_list = []
+    for repository in repositories:
+        if repository.metadata:
+            namespace_list.append(repository.namespace)
+            repository_list.append(repository.repository)
+
+            metadata = _prepare_metadata(repository.metadata, readme_basedir=readme_dir)
+            metadata_list.append(metadata)
+    gql_client.bulk_upsert_metadata(namespace_list, repository_list, metadata_list)
+    logging.info(f"Updated metadata for {pluralise('repository', len(repository_list))}")


def _build_credential_map(auth_client, credentials=None):
2 changes: 2 additions & 0 deletions splitgraph/core/output.py
@@ -21,6 +21,8 @@ def pretty_size(size: Union[int, float]) -> str:

def pluralise(word: str, number: int) -> str:
"""1 banana, 2 bananas"""
if word.endswith("y"):
return "%d %s" % (number, word if number == 1 else word[:-1] + "ies")
return "%d %s%s" % (number, word, "" if number == 1 else "s")


13 changes: 4 additions & 9 deletions splitgraph/engine/postgres/engine.py
@@ -59,6 +59,8 @@
# the connection property otherwise
from psycopg2._psycopg import connection as Connection

psycopg2.extensions.register_adapter(dict, Json)

_AUDIT_SCHEMA = "splitgraph_audit"
_AUDIT_TRIGGER = "resources/static/audit_trigger.sql"
_PUSH_PULL = "resources/static/splitgraph_api.sql"
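Registering Json as the process-wide psycopg2 adapter for dict means any dict passed as a query parameter is serialized to jsonb automatically, which is what allows the _convert_vals wrapper (deleted below) to go away. A minimal sketch of the mechanism, with a hypothetical connection and table:

    import psycopg2
    from psycopg2.extras import Json

    psycopg2.extensions.register_adapter(dict, Json)

    conn = psycopg2.connect("dbname=test")  # hypothetical DSN
    with conn.cursor() as cur:
        # The dict is adapted to Json automatically -- no manual wrapping needed.
        cur.execute("INSERT INTO docs (body) VALUES (%s)", ({"a": 1, "b": [2, 3]},))
    conn.commit()

One caveat: the adapter is global to the process, so it affects every psycopg2 connection, not just Splitgraph's engine connections.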
@@ -511,7 +513,7 @@ def run_sql(
with connection.cursor(**cursor_kwargs) as cur:
try:
self.notices = []
-                cur.execute(statement, _convert_vals(arguments) if arguments else None)
+                cur.execute(statement, arguments)
if connection.notices:
self.notices = connection.notices[:]
del connection.notices[:]
@@ -603,7 +605,7 @@ def run_sql_batch(
batches = _paginate_by_size(
cur,
statement,
-            (_convert_vals(a) for a in arguments),
+            arguments,
max_size=API_MAX_QUERY_LENGTH,
)
for batch in batches:
@@ -1603,13 +1605,6 @@ def _convert_audit_change(
_KIND = {"I": 0, "D": 1, "U": 2}


-def _convert_vals(vals: Any) -> Any:
-    """Psycopg returns jsonb objects as dicts/lists but doesn't actually accept them directly
-    as a query param (or, in the case of lists, coerces them into an array).
-    Hence, we have to wrap them in the Json datatype when doing a dump + load."""
-    return [Json(v) if isinstance(v, dict) else v for v in vals]


def _generate_where_clause(table: str, cols: List[str], table_2: str) -> Composed:
return SQL(" AND ").join(
SQL("{}.{} = {}.{}").format(