Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/generate_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
stub_c,
plugins_c,
sync_c,
download_c,
)
from splitgraph.commandline.engine import (
add_engine_c,
Expand Down Expand Up @@ -89,6 +90,7 @@
"cloud status",
"cloud logs",
"cloud upload",
"cloud download",
"cloud plugins",
"cloud stub",
"cloud validate",
Expand Down Expand Up @@ -127,6 +129,7 @@
"cloud status": status_c,
"cloud logs": logs_c,
"cloud upload": upload_c,
"cloud download": download_c,
"cloud sync": sync_c,
"cloud plugins": plugins_c,
"cloud stub": stub_c,
Expand Down
42 changes: 42 additions & 0 deletions splitgraph/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
AddExternalCredentialRequest,
AddExternalRepositoriesRequest,
AddExternalRepositoryRequest,
ExportJobStatus,
ExternalResponse,
IngestionJobStatus,
ListExternalCredentialsResponse,
Expand All @@ -48,6 +49,7 @@
BULK_UPSERT_REPO_PROFILES,
BULK_UPSERT_REPO_TOPICS,
CSV_URL,
EXPORT_JOB_STATUS,
FIND_REPO,
GET_PLUGIN,
GET_PLUGINS,
Expand All @@ -58,6 +60,7 @@
PROFILE_UPSERT,
REPO_CONDITIONS,
REPO_PARAMS,
START_EXPORT,
START_LOAD,
)
from splitgraph.config import CONFIG, create_config_dict, get_singleton
Expand Down Expand Up @@ -828,6 +831,45 @@ def get_latest_ingestion_job_status(
status=node.status,
)

def get_export_job_status(self, task_id: str) -> Optional[ExportJobStatus]:
    """Fetch the status of a single export job from the GQL API.

    :param task_id: Task UUID returned by `start_export`.
    :return: `ExportJobStatus` for the task, or None if the server
        reports no such job (null `exportJobStatus` in the response).
    """
    gql_response = self._gql(
        {
            "query": EXPORT_JOB_STATUS,
            "operationName": "ExportJobStatus",
            "variables": {"taskId": task_id},
        },
        handle_errors=True,
    )

    node = gql_response.json()["data"]["exportJobStatus"]
    if not node:
        return None

    # Map GraphQL camelCase field names onto the model's snake_case attributes.
    field_map = {
        "task_id": "taskId",
        "started": "started",
        "finished": "finished",
        "status": "status",
        "user_id": "userId",
        "export_format": "exportFormat",
        "output": "output",
    }
    return ExportJobStatus(**{attr: node[key] for attr, key in field_map.items()})

def start_export(self, query: str) -> str:
    """Start a server-side export of a query's results (CSV format).

    :param query: SQL query to run. A single trailing semicolon is
        stripped (with a warning), since the API rejects it.
    :return: Task ID of the export job, usable with `get_export_job_status`.
    """
    trimmed = query.strip()
    if trimmed.endswith(";"):
        logging.warning("The query ends with ';', automatically removing")
        trimmed = trimmed[:-1]

    gql_response = self._gql(
        {
            "query": START_EXPORT,
            "operationName": "StartExport",
            "variables": {"query": trimmed},
        },
        handle_errors=True,
    )
    return str(gql_response.json()["data"]["exportQuery"]["id"])

def get_ingestion_job_logs(self, namespace: str, repository: str, task_id: str) -> str:
response = self._gql(
{
Expand Down
13 changes: 11 additions & 2 deletions splitgraph/cloud/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,23 @@ def to_external(self) -> External:
)


class JobStatus(BaseModel):
    """Fields common to all asynchronous job status reports."""

    task_id: str
    started: datetime
    # None while the job is still running.
    finished: Optional[datetime]
    status: Optional[str]


class IngestionJobStatus(JobStatus):
    """Status of a data-load (ingestion) job."""

    # True if the load was triggered manually rather than on a schedule.
    is_manual: bool


class ExportJobStatus(JobStatus):
    """Status of a query-export job."""

    user_id: Optional[str]
    export_format: str
    # Job result payload; on success contains the download URL
    # (see wait_for_download, which reads output["url"]).
    output: Optional[Dict[str, Any]]


class RepositoryIngestionJobStatusResponse(BaseModel):
class RepositoryIngestionJobStatus(BaseModel):
class Node(BaseModel):
Expand Down
20 changes: 20 additions & 0 deletions splitgraph/cloud/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,23 @@
supportsLoad
}
}"""

# GraphQL mutation: start a server-side export of a query's results.
# The export format is fixed to CSV here; the response carries the task ID
# that is subsequently polled via EXPORT_JOB_STATUS.
START_EXPORT = """mutation StartExport($query: String!) {
exportQuery(query: $query, exportFormat: "csv") {
id
}
}
"""

# GraphQL query: look up a single export job by its task UUID.
# The client treats a null exportJobStatus in the response as "job not found".
EXPORT_JOB_STATUS = """query ExportJobStatus($taskId: UUID!) {
exportJobStatus(taskId: $taskId) {
taskId
started
finished
status
userId
exportFormat
output
}
}
"""
105 changes: 50 additions & 55 deletions splitgraph/commandline/cloud.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Command line routines related to registering/setting up connections to the Splitgraph registry."""
import hashlib
import itertools
import logging
import os
import shutil
import string
import subprocess
import sys
import time
from copy import copy
from glob import glob
from pathlib import Path
Expand All @@ -18,7 +16,14 @@
from click import wrap_text
from splitgraph.cloud.models import AddExternalRepositoryRequest
from splitgraph.cloud.project.models import Metadata, SplitgraphYAML
from splitgraph.commandline.common import ImageType, RepositoryType, emit_sql_results
from splitgraph.commandline.common import (
ImageType,
RepositoryType,
download_file,
emit_sql_results,
upload_file,
wait_for_job,
)
from splitgraph.commandline.engine import inject_config_into_engines
from splitgraph.config.config import get_from_subsection
from splitgraph.config.management import patch_and_save_config
Expand Down Expand Up @@ -915,10 +920,8 @@ def upload_c(remote, file_format, repository, files):
This uses the upload API to add data like CSV files to a remote Splitgraph instance,
trigger a load and wait for the data to load into a repository.
"""
import requests

from splitgraph.cloud import GQLAPIClient
from tqdm import tqdm
from tqdm.utils import CallbackIOWrapper

client = GQLAPIClient(remote)

Expand All @@ -932,12 +935,7 @@ def upload_c(remote, file_format, repository, files):
for file in files:
upload, download = client.get_csv_upload_download_urls()
download_urls.append(download)
size = os.fstat(file.fileno()).st_size

with tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024) as t:
wrapped_file = CallbackIOWrapper(t.update, file, "read")
t.set_description(os.path.basename(file.name))
requests.put(upload, data=wrapped_file)
upload_file(file, upload)

task_id = client.start_csv_load(
repository.namespace, repository.repository, download_urls, table_names
Expand All @@ -953,53 +951,49 @@ def upload_c(remote, file_format, repository, files):
click.echo(f' sgr cloud sql \'SELECT * FROM "{repository}"."{table_names[0]}"\'') # nosec


GQL_POLL_TIME = 5
SPINNER_FREQUENCY = 10
def wait_for_load(client: "GQLAPIClient", namespace: str, repository: str, task_id: str) -> None:
    """Block until the given ingestion job finishes.

    On failure, echoes the job's logs and raises ValueError.
    """
    status = wait_for_job(
        task_id, lambda: client.get_latest_ingestion_job_status(namespace, repository)
    )
    if status.status != "FAILURE":
        return

    # Surface the server-side logs to help diagnose the failed load.
    click.echo(
        client.get_ingestion_job_logs(namespace=namespace, repository=repository, task_id=task_id)
    )
    raise ValueError("Error loading data")

def wait_for_load(client: "GQLAPIClient", namespace: str, repository: str, task_id: str) -> None:
from splitgraph.config import SG_CMD_ASCII

chars = ["|", "/", "-", "\\"] if SG_CMD_ASCII else ["⣾", "⣽", "⣻", "⢿", "⡿", "⣟", "⣯", "⣷"]
spinner = itertools.cycle(chars)

interval = 0
poll_interval = max(int(SPINNER_FREQUENCY * GQL_POLL_TIME), 1)
status_str: Optional[str] = None
while True:
if interval % poll_interval == 0:
status = client.get_latest_ingestion_job_status(namespace, repository)
if not status:
raise AssertionError("Ingestion job not found")
if status.task_id != task_id:
raise AssertionError("Unexpected task ID")

if not sys.stdout.isatty() and status_str != status.status:
click.echo(
f" ({status.status}) Loading {namespace}/{repository}, task ID {task_id}",
)

status_str = status.status
def wait_for_download(client: "GQLAPIClient", task_id: str) -> str:
    """Block until an export job finishes and return its download URL.

    :raises ValueError: if the job ends in any state other than SUCCESS.
    """
    status = wait_for_job(task_id, lambda: client.get_export_job_status(task_id))
    if status.status != "SUCCESS":
        raise ValueError(
            "Error running query. This could be due to a syntax error. "
            "Run the query interactively with `sgr cloud sql` to investigate the cause."
        )
    # A successful export is expected to carry the result URL in its output.
    assert status.output
    return str(status.output["url"])

if sys.stdout.isatty():
click.echo(f"\033[2K\033[1G{next(spinner)}", nl=False)
click.echo(
f" ({status_str}) Loading {namespace}/{repository}, task ID {task_id}", nl=False
)
sys.stdout.flush()
time.sleep(1.0 / SPINNER_FREQUENCY)
interval += 1

if status_str == "SUCCESS":
click.echo()
return
if status_str == "FAILURE":
click.echo()
logs = client.get_ingestion_job_logs(
namespace=namespace, repository=repository, task_id=task_id
)
click.echo(logs)
raise ValueError("Error loading data")

@click.command("download")
@click.option("--remote", default="data.splitgraph.com", help="Name of the remote registry to use.")
@click.option("--file-format", default="csv", type=click.Choice(["csv"]))
@click.argument("query", type=str)
@click.argument("output_path", type=str, default=None, required=False)
def download_c(remote, file_format, query, output_path):
    """
    Download query results from Splitgraph.

    This runs a query on Splitgraph Cloud and exports the results into a csv.gz format.
    """
    # Deferred import to keep CLI startup fast.
    from splitgraph.cloud import GQLAPIClient

    api_client = GQLAPIClient(remote)

    # Kick off the export, poll until it completes, then fetch the result file.
    export_task_id = api_client.start_export(query=query)
    result_url = wait_for_download(api_client, export_task_id)
    download_file(result_url, output_path)


@click.command("sync")
Expand Down Expand Up @@ -1205,6 +1199,7 @@ def cloud_c():
cloud_c.add_command(status_c)
cloud_c.add_command(logs_c)
cloud_c.add_command(upload_c)
cloud_c.add_command(download_c)
cloud_c.add_command(sync_c)
cloud_c.add_command(plugins_c)
cloud_c.add_command(stub_c)
Expand Down
Loading