Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/src/Multicorn
Submodule Multicorn updated 1 files
+2 −2 src/python.c
2 changes: 1 addition & 1 deletion examples/cross-db-analytics/mounting/matomo.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"remote_schema": "matomo",
"dbname": "matomo",
"tables": {
"matomo_access": {
"schema": {
Expand Down
16 changes: 11 additions & 5 deletions examples/custom_fdw/src/hn_fdw/mount.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Dict, Mapping
from typing import Dict, Optional

from splitgraph.core.types import TableColumn, TableSchema
from splitgraph.core.types import (
TableColumn,
TableInfo,
IntrospectionResult,
)

# Define the schema of the foreign table we wish to create
# We're only going to be fetching stories, so limit the columns to the ones that
Expand Down Expand Up @@ -54,7 +58,9 @@ def get_name(cls) -> str:
def get_description(cls) -> str:
return "Query Hacker News stories through the Firebase API"

def get_table_options(self, table_name: str) -> Mapping[str, str]:
def get_table_options(
self, table_name: str, tables: Optional[TableInfo] = None
) -> Dict[str, str]:
# Pass the endpoint name into the FDW
return {"table": table_name}

Expand All @@ -69,7 +75,7 @@ def get_server_options(self):
"wrapper": "hn_fdw.fdw.HNForeignDataWrapper",
}

def introspect(self) -> Dict[str, TableSchema]:
def introspect(self) -> IntrospectionResult:
# Return a list of this FDW's tables and their schema.
endpoints = self.params.get("endpoints") or _all_endpoints
return {e: _story_schema_spec for e in endpoints}
return {e: (_story_schema_spec, {}) for e in endpoints}
4 changes: 2 additions & 2 deletions examples/dbt_two_databases/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ $ sgr mount mongo_fdw order_data -c originro:originpass@mongo:27017 -o @- <<EOF
{
"orders":
{
"db": "origindb",
"coll": "orders",
"database": "origindb",
"collection": "orders",
"schema":
{
"name": "text",
Expand Down
4 changes: 2 additions & 2 deletions examples/import-from-mongo/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
sgr mount mongo_fdw staging -c originro:originpass@mongo:27017 -o '{"tables": {"stuff": {
"options":
{
"db": "origindb",
"coll": "stuff"
"database": "origindb",
"collection": "stuff"
},
"schema": {
"name": "text",
Expand Down
4 changes: 2 additions & 2 deletions examples/import-from-mongo/mongo_import.splitfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
FROM MOUNT mongo_fdw originro:originpass@mongo:27017 '{"tables": {"stuff": {
"options": {
"db": "origindb",
"coll": "stuff"
"database": "origindb",
"collection": "stuff"
},
"schema": {
"name": "text",
Expand Down
8 changes: 1 addition & 7 deletions splitgraph/commandline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from splitgraph.commandline.mount import mount_c
from splitgraph.commandline.push_pull import pull_c, clone_c, push_c, upstream_c
from splitgraph.commandline.splitfile import build_c, provenance_c, rebuild_c, dependents_c
from splitgraph.exceptions import get_exception_name
from splitgraph.ingestion.singer.commandline import singer_group

logger = logging.getLogger()
Expand All @@ -61,13 +62,6 @@ def emit(self, record):
logger.handlers[0].formatter = ColorFormatter()


def get_exception_name(o):
    """
    Build a printable name for the class of `o` (typically an exception instance).

    Classes that report the builtins module, or no module at all, are rendered
    as the bare class name (e.g. ``ValueError``). Every other class is
    qualified with its module path (e.g. ``json.decoder.JSONDecodeError``).

    :param o: Object whose class should be named.
    :return: The class name, prefixed with ``<module>.`` for non-builtin classes.
    """
    cls = o.__class__
    module_name = cls.__module__
    # `str.__module__` is simply the name of the builtins module.
    if module_name is not None and module_name != str.__module__:
        return ".".join((module_name, cls.__name__))
    return cls.__name__


def _patch_wrap_text():
# Patch click's formatter to strip ```
# etc which are used to format help text on the website.
Expand Down
3 changes: 2 additions & 1 deletion splitgraph/core/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import Optional, Tuple, cast, TYPE_CHECKING, List, TypeVar, Dict, DefaultDict

from packaging.version import Version
from psycopg2.sql import SQL, Identifier

from splitgraph.core.sql import select, insert
Expand Down Expand Up @@ -116,7 +117,7 @@ def source_files_to_apply(
) -> Tuple[List[str], str]:
""" Get the ordered list of .sql files to apply to the database"""
version_tuples = get_version_tuples(schema_files)
target_version = target_version or max([v[1] for v in version_tuples])
target_version = target_version or max([v[1] for v in version_tuples], key=Version)
if static:
# For schemata like splitgraph_api which we want to apply in any
# case, bypassing the upgrade mechanism, we just run the latest
Expand Down
63 changes: 54 additions & 9 deletions splitgraph/core/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Dict, Tuple, Any, NamedTuple, Optional, List, Sequence, Union
from typing import Dict, Tuple, Any, NamedTuple, Optional, List, Sequence, Union, TypeVar

Changeset = Dict[Tuple[str, ...], Tuple[bool, Dict[str, Any], Dict[str, Any]]]

Expand All @@ -18,22 +18,67 @@ class TableColumn(NamedTuple):
SourcesList = List[Dict[str, str]]
ProvenanceLine = Dict[str, Union[str, List[str], List[bool], SourcesList]]

# Ingestion-related params
Credentials = Dict[str, Any]
Params = Dict[str, Any]
TableParams = Dict[str, Any]
TableInfo = Union[List[str], Dict[str, Tuple[TableSchema, TableParams]]]
SyncState = Dict[str, Any]


class MountError(NamedTuple):
    """
    Record describing a failure tied to a single table.

    Instances appear as values in ``PreviewResult`` / ``IntrospectionResult``
    mappings in place of the regular per-table payload; ``unwrap`` separates
    them from the successful entries.
    """

    # Name of the table this error relates to.
    table_name: str
    # NOTE(review): presumably a short error identifier such as the exception
    # class name -- confirm against the code that constructs MountError.
    error: str
    # NOTE(review): presumably the human-readable error message -- confirm.
    error_text: str


PreviewResult = Dict[str, Union[MountError, List[Dict[str, Any]]]]
IntrospectionResult = Dict[str, Union[Tuple[TableSchema, TableParams], MountError]]

T = TypeVar("T")


def unwrap(
    result: Dict[str, Union[MountError, T]],
) -> Tuple[Dict[str, T], Dict[str, MountError]]:
    """
    Partition a result mapping into its successful and failed entries.

    :param result: Mapping whose values are either a regular payload or a
        ``MountError`` (keys are presumably table names -- verify against callers).
    :return: Tuple of ``(good, bad)``: ``good`` holds every entry whose value is
        not a ``MountError``, ``bad`` holds every entry whose value is one.
        Both keep the iteration order of ``result``.
    """
    failed = {name: item for name, item in result.items() if isinstance(item, MountError)}
    # Whatever was not classified as a failure is a success.
    succeeded = {name: item for name, item in result.items() if name not in failed}
    return succeeded, failed


class Comparable(metaclass=ABCMeta):
    """Abstract base for types that support the ``<`` operator."""

    @abstractmethod
    def __lt__(self, other: Any) -> bool:
        """Return whether this object orders before ``other``."""
        ...


def dict_to_tableschema(tables: Dict[str, Dict[str, Any]]) -> Dict[str, TableSchema]:
def dict_to_table_schema_params(
tables: Dict[str, Dict[str, Any]]
) -> Dict[str, Tuple[TableSchema, TableParams]]:
return {
t: [
TableColumn(i + 1, cname, ctype, False, None)
for (i, (cname, ctype)) in enumerate(ts.items())
]
for t, ts in tables.items()
t: (
[
TableColumn(i + 1, cname, ctype, False, None)
for (i, (cname, ctype)) in enumerate(tsp["schema"].items())
],
tsp.get("options", {}),
)
for t, tsp in tables.items()
}


def tableschema_to_dict(tables: Dict[str, TableSchema]) -> Dict[str, Dict[str, str]]:
return {t: {c.name: c.pg_type for c in ts} for t, ts in tables.items()}
def table_schema_params_to_dict(
tables: Dict[str, Tuple[TableSchema, TableParams]]
) -> Dict[str, Dict[str, Dict[str, str]]]:
return {
t: {
"schema": {c.name: c.pg_type for c in ts},
"options": {tpk: str(tpv) for tpk, tpv in tp.items()},
}
for t, (ts, tp) in tables.items()
}
17 changes: 12 additions & 5 deletions splitgraph/engine/postgres/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def __init__(
self.registry = registry
self.in_fdw = in_fdw

"""List of notices issued by the server during the previous execution of run_sql."""
self.notices: List[str] = []

if conn_params:
self.conn_params = conn_params

Expand Down Expand Up @@ -505,13 +508,17 @@ def run_sql(
while True:
with connection.cursor(**cursor_kwargs) as cur:
try:
self.notices = []
cur.execute(statement, _convert_vals(arguments) if arguments else None)
if connection.notices and self.registry:
# Forward NOTICE messages from the registry back to the user
# (e.g. to nag them to upgrade etc).
for notice in connection.notices:
logging.info("%s says: %s", self.name, notice)
if connection.notices:
self.notices = connection.notices[:]
del connection.notices[:]

if self.registry:
# Forward NOTICE messages from the registry back to the user
# (e.g. to nag them to upgrade etc).
for notice in self.notices:
logging.info("%s says: %s", self.name, notice)
except Exception as e:
# Rollback the transaction (to a savepoint if we're inside the savepoint() context manager)
self.rollback()
Expand Down
11 changes: 10 additions & 1 deletion splitgraph/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class IncompleteObjectDownloadError(SplitGraphError):
cleanup and reraise `reason` at the earliest opportunity."""

def __init__(
self, reason: Optional[BaseException], successful_objects: List[str],
self,
reason: Optional[BaseException],
successful_objects: List[str],
):
self.reason = reason
self.successful_objects = successful_objects
Expand All @@ -119,3 +121,10 @@ class GQLUnauthenticatedError(GQLAPIError):

class GQLRepoDoesntExistError(GQLAPIError):
"""Repository doesn't exist"""


def get_exception_name(o):
    """
    Build a printable name for the class of `o` (typically an exception instance).

    Classes that report the builtins module, or no module at all, are rendered
    as the bare class name (e.g. ``ValueError``). Every other class is
    qualified with its module path (e.g. ``json.decoder.JSONDecodeError``).

    :param o: Object whose class should be named.
    :return: The class name, prefixed with ``<module>.`` for non-builtin classes.
    """
    cls = o.__class__
    module_name = cls.__module__
    # `str.__module__` is simply the name of the builtins module.
    if module_name is not None and module_name != str.__module__:
        return ".".join((module_name, cls.__name__))
    return cls.__name__
60 changes: 35 additions & 25 deletions splitgraph/hooks/data_source/base.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
import json
from abc import ABC, abstractmethod
from random import getrandbits
from typing import Dict, Any, Union, List, Optional, TYPE_CHECKING, cast, Tuple
from typing import Dict, Any, Optional, TYPE_CHECKING, cast, Tuple, List

from psycopg2._json import Json
from psycopg2.sql import SQL, Identifier

from splitgraph.core.engine import repository_exists
from splitgraph.core.image import Image
from splitgraph.core.types import TableSchema, TableColumn
from splitgraph.core.types import (
TableColumn,
Credentials,
Params,
TableInfo,
SyncState,
MountError,
IntrospectionResult,
)
from splitgraph.engine import ResultShape

if TYPE_CHECKING:
from splitgraph.engine.postgres.engine import PostgresEngine
from splitgraph.core.repository import Repository

Credentials = Dict[str, Any]
Params = Dict[str, Any]
TableInfo = Union[List[str], Dict[str, TableSchema]]
SyncState = Dict[str, Any]
PreviewResult = Dict[str, Union[str, List[Dict[str, Any]]]]


INGESTION_STATE_TABLE = "_sg_ingestion_state"
INGESTION_STATE_SCHEMA = [
TableColumn(1, "timestamp", "timestamp", True, None),
Expand All @@ -32,6 +32,7 @@
class DataSource(ABC):
params_schema: Dict[str, Any]
credentials_schema: Dict[str, Any]
table_params_schema: Dict[str, Any]

supports_mount = False
supports_sync = False
Expand All @@ -47,29 +48,34 @@ def get_name(cls) -> str:
def get_description(cls) -> str:
raise NotImplementedError

def __init__(self, engine: "PostgresEngine", credentials: Credentials, params: Params):
def __init__(
self,
engine: "PostgresEngine",
credentials: Credentials,
params: Params,
tables: Optional[TableInfo] = None,
):
import jsonschema

self.engine = engine

if "tables" in params:
tables = params.pop("tables")

jsonschema.validate(instance=credentials, schema=self.credentials_schema)
jsonschema.validate(instance=params, schema=self.params_schema)

self.credentials = credentials
self.params = params

if isinstance(tables, dict):
for _, table_params in tables.values():
jsonschema.validate(instance=table_params, schema=self.table_params_schema)

self.tables = tables

@abstractmethod
def introspect(self) -> Dict[str, TableSchema]:
# TODO here: dict str -> [tableschema, dict of suggested options]
# params -- add table options as a separate field?
# separate table schema

# When going through the repo addition loop:
# * add separate options field mapping table names to options
# * return separate schema for table options
# * table options are optional for introspection
# * how to introspect: do import foreign schema; check fdw params
# *
def introspect(self) -> IntrospectionResult:
raise NotImplementedError


Expand All @@ -78,8 +84,11 @@ class MountableDataSource(DataSource, ABC):

@abstractmethod
def mount(
self, schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True,
):
self,
schema: str,
tables: Optional[TableInfo] = None,
overwrite: bool = True,
) -> Optional[List[MountError]]:
"""Instantiate the data source as foreign tables in a schema"""
raise NotImplementedError

Expand All @@ -98,7 +107,8 @@ def load(self, repository: "Repository", tables: Optional[TableInfo] = None) ->
image_hash = "{:064x}".format(getrandbits(256))
tmp_schema = "{:064x}".format(getrandbits(256))
repository.images.add(
parent_id=None, image=image_hash,
parent_id=None,
image=image_hash,
)
repository.object_engine.create_schema(tmp_schema)

Expand Down
Loading