diff --git a/engine/src/Multicorn b/engine/src/Multicorn index 5edb6814..0f95b96c 160000 --- a/engine/src/Multicorn +++ b/engine/src/Multicorn @@ -1 +1 @@ -Subproject commit 5edb681438530479516c4eb60a2b2e2d5c707efb +Subproject commit 0f95b96c7a92d4cac8ae83f9e24fe82111219395 diff --git a/examples/cross-db-analytics/mounting/matomo.json b/examples/cross-db-analytics/mounting/matomo.json index f3a304e6..52ee55fc 100644 --- a/examples/cross-db-analytics/mounting/matomo.json +++ b/examples/cross-db-analytics/mounting/matomo.json @@ -1,5 +1,5 @@ { - "remote_schema": "matomo", + "dbname": "matomo", "tables": { "matomo_access": { "schema": { diff --git a/examples/custom_fdw/src/hn_fdw/mount.py b/examples/custom_fdw/src/hn_fdw/mount.py index bad22038..cb25cb99 100644 --- a/examples/custom_fdw/src/hn_fdw/mount.py +++ b/examples/custom_fdw/src/hn_fdw/mount.py @@ -1,6 +1,10 @@ -from typing import Dict, Mapping +from typing import Dict, Optional -from splitgraph.core.types import TableColumn, TableSchema +from splitgraph.core.types import ( + TableColumn, + TableInfo, + IntrospectionResult, +) # Define the schema of the foreign table we wish to create # We're only going to be fetching stories, so limit the columns to the ones that @@ -54,7 +58,9 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Query Hacker News stories through the Firebase API" - def get_table_options(self, table_name: str) -> Mapping[str, str]: + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: # Pass the endpoint name into the FDW return {"table": table_name} @@ -69,7 +75,7 @@ def get_server_options(self): "wrapper": "hn_fdw.fdw.HNForeignDataWrapper", } - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> IntrospectionResult: # Return a list of this FDW's tables and their schema. endpoints = self.params.get("endpoints") or _all_endpoints - return {e: _story_schema_spec for e in endpoints} + return {e: (_story_schema_spec, {}) for e in endpoints} diff --git a/examples/dbt_two_databases/README.md b/examples/dbt_two_databases/README.md index a25224bd..51bbdefb 100644 --- a/examples/dbt_two_databases/README.md +++ b/examples/dbt_two_databases/README.md @@ -72,8 +72,8 @@ $ sgr mount mongo_fdw order_data -c originro:originpass@mongo:27017 -o @- < Tuple[List[str], str]: """ Get the ordered list of .sql files to apply to the database""" version_tuples = get_version_tuples(schema_files) - target_version = target_version or max([v[1] for v in version_tuples]) + target_version = target_version or max([v[1] for v in version_tuples], key=Version) if static: # For schemata like splitgraph_api which we want to apply in any # case, bypassing the upgrade mechanism, we just run the latest diff --git a/splitgraph/core/types.py b/splitgraph/core/types.py index 385a2738..f97ba569 100644 --- a/splitgraph/core/types.py +++ b/splitgraph/core/types.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Dict, Tuple, Any, NamedTuple, Optional, List, Sequence, Union +from typing import Dict, Tuple, Any, NamedTuple, Optional, List, Sequence, Union, TypeVar Changeset = Dict[Tuple[str, ...], Tuple[bool, Dict[str, Any], Dict[str, Any]]] @@ -18,6 +18,38 @@ class TableColumn(NamedTuple): SourcesList = List[Dict[str, str]] ProvenanceLine = Dict[str, Union[str, List[str], List[bool], SourcesList]] +# Ingestion-related params +Credentials = Dict[str, Any] +Params = Dict[str, Any] +TableParams = Dict[str, Any] +TableInfo = Union[List[str], Dict[str, Tuple[TableSchema, TableParams]]] +SyncState = Dict[str, Any] + + +class MountError(NamedTuple): + table_name: str + error: str + error_text: str + + +PreviewResult = Dict[str, Union[MountError, List[Dict[str, Any]]]] +IntrospectionResult = Dict[str, Union[Tuple[TableSchema, TableParams], MountError]] + +T = TypeVar("T") + + +def unwrap( + result: Dict[str, Union[MountError, T]], +) -> Tuple[Dict[str, T], Dict[str, MountError]]: + good = {} + bad = {} + for k, v in result.items(): + if isinstance(v, MountError): + bad[k] = v + else: + good[k] = v + return good, bad + class Comparable(metaclass=ABCMeta): @abstractmethod @@ -25,15 +57,28 @@ def __lt__(self, other: Any) -> bool: ... -def dict_to_tableschema(tables: Dict[str, Dict[str, Any]]) -> Dict[str, TableSchema]: +def dict_to_table_schema_params( + tables: Dict[str, Dict[str, Any]] +) -> Dict[str, Tuple[TableSchema, TableParams]]: return { - t: [ - TableColumn(i + 1, cname, ctype, False, None) - for (i, (cname, ctype)) in enumerate(ts.items()) - ] - for t, ts in tables.items() + t: ( + [ + TableColumn(i + 1, cname, ctype, False, None) + for (i, (cname, ctype)) in enumerate(tsp["schema"].items()) + ], + tsp.get("options", {}), + ) + for t, tsp in tables.items() } -def tableschema_to_dict(tables: Dict[str, TableSchema]) -> Dict[str, Dict[str, str]]: - return {t: {c.name: c.pg_type for c in ts} for t, ts in tables.items()} +def table_schema_params_to_dict( + tables: Dict[str, Tuple[TableSchema, TableParams]] +) -> Dict[str, Dict[str, Dict[str, str]]]: + return { + t: { + "schema": {c.name: c.pg_type for c in ts}, + "options": {tpk: str(tpv) for tpk, tpv in tp.items()}, + } + for t, (ts, tp) in tables.items() + } diff --git a/splitgraph/engine/postgres/engine.py b/splitgraph/engine/postgres/engine.py index 0cbc3327..93e34617 100644 --- a/splitgraph/engine/postgres/engine.py +++ b/splitgraph/engine/postgres/engine.py @@ -228,6 +228,9 @@ def __init__( self.registry = registry self.in_fdw = in_fdw + """List of notices issued by the server during the previous execution of run_sql.""" + self.notices: List[str] = [] + if conn_params: self.conn_params = conn_params @@ -505,13 +508,17 @@ def run_sql( while True: with connection.cursor(**cursor_kwargs) as cur: try: + self.notices = [] cur.execute(statement, _convert_vals(arguments) if arguments else None) - if connection.notices and self.registry: - # Forward NOTICE messages from the registry back to the user - # (e.g. to nag them to upgrade etc). - for notice in connection.notices: - logging.info("%s says: %s", self.name, notice) + if connection.notices: + self.notices = connection.notices[:] del connection.notices[:] + + if self.registry: + # Forward NOTICE messages from the registry back to the user + # (e.g. to nag them to upgrade etc). + for notice in self.notices: + logging.info("%s says: %s", self.name, notice) except Exception as e: # Rollback the transaction (to a savepoint if we're inside the savepoint() context manager) self.rollback() diff --git a/splitgraph/exceptions.py b/splitgraph/exceptions.py index aae209c7..55db97fa 100644 --- a/splitgraph/exceptions.py +++ b/splitgraph/exceptions.py @@ -99,7 +99,9 @@ class IncompleteObjectDownloadError(SplitGraphError): cleanup and reraise `reason` at the earliest opportunity.""" def __init__( - self, reason: Optional[BaseException], successful_objects: List[str], + self, + reason: Optional[BaseException], + successful_objects: List[str], ): self.reason = reason self.successful_objects = successful_objects @@ -119,3 +121,10 @@ class GQLUnauthenticatedError(GQLAPIError): class GQLRepoDoesntExistError(GQLAPIError): """Repository doesn't exist""" + + +def get_exception_name(o): + module = o.__class__.__module__ + if module is None or module == str.__class__.__module__: + return o.__class__.__name__ + return module + "." + o.__class__.__name__ diff --git a/splitgraph/hooks/data_source/base.py b/splitgraph/hooks/data_source/base.py index 037f102a..430c3b60 100644 --- a/splitgraph/hooks/data_source/base.py +++ b/splitgraph/hooks/data_source/base.py @@ -1,27 +1,27 @@ -import json from abc import ABC, abstractmethod from random import getrandbits -from typing import Dict, Any, Union, List, Optional, TYPE_CHECKING, cast, Tuple +from typing import Dict, Any, Optional, TYPE_CHECKING, cast, Tuple, List from psycopg2._json import Json from psycopg2.sql import SQL, Identifier from splitgraph.core.engine import repository_exists from splitgraph.core.image import Image -from splitgraph.core.types import TableSchema, TableColumn +from splitgraph.core.types import ( + TableColumn, + Credentials, + Params, + TableInfo, + SyncState, + MountError, + IntrospectionResult, +) from splitgraph.engine import ResultShape if TYPE_CHECKING: from splitgraph.engine.postgres.engine import PostgresEngine from splitgraph.core.repository import Repository -Credentials = Dict[str, Any] -Params = Dict[str, Any] -TableInfo = Union[List[str], Dict[str, TableSchema]] -SyncState = Dict[str, Any] -PreviewResult = Dict[str, Union[str, List[Dict[str, Any]]]] - - INGESTION_STATE_TABLE = "_sg_ingestion_state" INGESTION_STATE_SCHEMA = [ TableColumn(1, "timestamp", "timestamp", True, None), @@ -32,6 +32,7 @@ class DataSource(ABC): params_schema: Dict[str, Any] credentials_schema: Dict[str, Any] + table_params_schema: Dict[str, Any] supports_mount = False supports_sync = False @@ -47,29 +48,34 @@ def get_name(cls) -> str: def get_description(cls) -> str: raise NotImplementedError - def __init__(self, engine: "PostgresEngine", credentials: Credentials, params: Params): + def __init__( + self, + engine: "PostgresEngine", + credentials: Credentials, + params: Params, + tables: Optional[TableInfo] = None, + ): import jsonschema self.engine = engine + if "tables" in params: + tables = params.pop("tables") + jsonschema.validate(instance=credentials, schema=self.credentials_schema) jsonschema.validate(instance=params, schema=self.params_schema) self.credentials = credentials self.params = params + if isinstance(tables, dict): + for _, table_params in tables.values(): + jsonschema.validate(instance=table_params, schema=self.table_params_schema) + + self.tables = tables + @abstractmethod - def introspect(self) -> Dict[str, TableSchema]: - # TODO here: dict str -> [tableschema, dict of suggested options] - # params -- add table options as a separate field? - # separate table schema - - # When going through the repo addition loop: - # * add separate options field mapping table names to options - # * return separate schema for table options - # * table options are optional for introspection - # * how to introspect: do import foreign schema; check fdw params - # * + def introspect(self) -> IntrospectionResult: raise NotImplementedError @@ -78,8 +84,11 @@ class MountableDataSource(DataSource, ABC): @abstractmethod def mount( - self, schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True, - ): + self, + schema: str, + tables: Optional[TableInfo] = None, + overwrite: bool = True, + ) -> Optional[List[MountError]]: """Instantiate the data source as foreign tables in a schema""" raise NotImplementedError @@ -98,7 +107,8 @@ def load(self, repository: "Repository", tables: Optional[TableInfo] = None) -> image_hash = "{:064x}".format(getrandbits(256)) tmp_schema = "{:064x}".format(getrandbits(256)) repository.images.add( - parent_id=None, image=image_hash, + parent_id=None, + image=image_hash, ) repository.object_engine.create_schema(tmp_schema) diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index 29139143..0ce23c93 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -1,17 +1,25 @@ +import json import logging +import re from abc import ABC, abstractmethod from copy import deepcopy -from typing import Optional, Mapping, Dict, List, Any, cast, Union, TYPE_CHECKING +from typing import Optional, Mapping, Dict, List, Any, cast, TYPE_CHECKING, Tuple import psycopg2 from psycopg2.sql import SQL, Identifier -from splitgraph.core.types import dict_to_tableschema, TableSchema, TableColumn -from splitgraph.hooks.data_source.base import ( - Credentials, - Params, +from splitgraph.core.types import ( + TableSchema, + TableColumn, + dict_to_table_schema_params, + TableParams, TableInfo, PreviewResult, + MountError, + IntrospectionResult, +) +from splitgraph.exceptions import get_exception_name +from splitgraph.hooks.data_source.base import ( MountableDataSource, LoadableDataSource, ) @@ -20,24 +28,17 @@ from splitgraph.engine.postgres.engine import PostgresEngine -_table_options_schema = { - "type": "object", - "additionalProperties": { - "options": {"type": "object", "additionalProperties": {"type": "string"}}, - }, -} - - class ForeignDataWrapperDataSource(MountableDataSource, LoadableDataSource, ABC): - credentials_schema = { + credentials_schema: Dict[str, Any] = { "type": "object", - "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, - "required": ["username", "password"], } - params_schema = { + params_schema: Dict[str, Any] = { + "type": "object", + } + + table_params_schema: Dict[str, Any] = { "type": "object", - "properties": {"tables": _table_options_schema}, } commandline_help: str = "" @@ -46,15 +47,6 @@ class ForeignDataWrapperDataSource(MountableDataSource, LoadableDataSource, ABC) supports_mount = True supports_load = True - def __init__( - self, - engine: "PostgresEngine", - credentials: Optional[Credentials] = None, - params: Optional[Params] = None, - ): - self.tables: Optional[TableInfo] = None - super().__init__(engine, credentials or {}, params or {}) - @classmethod def from_commandline(cls, engine, commandline_kwargs) -> "ForeignDataWrapperDataSource": """Instantiate an FDW data source from commandline arguments.""" @@ -64,28 +56,19 @@ def from_commandline(cls, engine, commandline_kwargs) -> "ForeignDataWrapperData # By convention, the "tables" object can be: # * A list of tables to import - # * A dictionary table_name -> {"schema": schema, **mount_options} - table_kwargs = params.pop("tables", None) - tables: Optional[TableInfo] - - if isinstance(table_kwargs, dict): - tables = dict_to_tableschema( - {t: to["schema"] for t, to in table_kwargs.items() if "schema" in to} - ) - params["tables"] = { - t: {"options": to["options"]} for t, to in table_kwargs.items() if "options" in to - } - elif isinstance(table_kwargs, list): - tables = table_kwargs - else: - tables = None + # * A dictionary table_name -> {"schema": schema, "options": table options} + tables = params.pop("tables", None) + if isinstance(tables, dict): + tables = dict_to_table_schema_params(tables) + # Extract credentials from the cmdline params credentials: Dict[str, Any] = {} for k in cast(Dict[str, Any], cls.credentials_schema["properties"]).keys(): if k in params: credentials[k] = params[k] - result = cls(engine, credentials, params) - result.tables = tables + + result = cls(engine, credentials, params, tables) + return result def get_user_options(self) -> Mapping[str, str]: @@ -94,10 +77,19 @@ def get_user_options(self) -> Mapping[str, str]: def get_server_options(self) -> Mapping[str, str]: return {} - def get_table_options(self, table_name: str) -> Mapping[str, str]: + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: + tables = tables or self.tables + + if not isinstance(tables, dict): + return {} + return { k: str(v) - for k, v in self.params.get("tables", {}).get(table_name, {}).get("options", {}).items() + for k, v in tables.get(table_name, cast(Tuple[TableSchema, TableParams], ({}, {})))[ + 1 + ].items() } def get_table_schema(self, table_name: str, table_schema: TableSchema) -> TableSchema: @@ -113,8 +105,11 @@ def get_fdw_name(self): pass def mount( - self, schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True, - ): + self, + schema: str, + tables: Optional[TableInfo] = None, + overwrite: bool = True, + ) -> Optional[List[MountError]]: tables = tables or self.tables or [] fdw = self.get_fdw_name() @@ -131,9 +126,15 @@ def mount( # Allow mounting tables into existing schemas self.engine.run_sql(SQL("CREATE SCHEMA IF NOT EXISTS {}").format(Identifier(schema))) - self._create_foreign_tables(schema, server_id, tables) + errors = self._create_foreign_tables(schema, server_id, tables) + if errors: + for e in errors: + logging.warning("Error mounting %s: %s: %s", e.table_name, e.error, e.error_text) + return errors - def _create_foreign_tables(self, schema, server_id, tables): + def _create_foreign_tables( + self, schema: str, server_id: str, tables: TableInfo + ) -> List[MountError]: if isinstance(tables, list): try: remote_schema = self.get_remote_schema_name() @@ -141,33 +142,50 @@ def _create_foreign_tables(self, schema, server_id, tables): raise NotImplementedError( "The FDW does not support IMPORT FOREIGN SCHEMA! Pass a tables dictionary." ) - _import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) + return import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) else: - for table_name, table_schema in tables.items(): + for table_name, (table_schema, _) in tables.items(): logging.info("Mounting table %s", table_name) query, args = create_foreign_table( schema, server_id, table_name, schema_spec=self.get_table_schema(table_name, table_schema), - extra_options=self.get_table_options(table_name), + extra_options=self.get_table_options(table_name, tables), ) self.engine.run_sql(query, args) + return [] - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> IntrospectionResult: # Ability to override introspection by e.g. contacting the remote database without having # to mount the actual table. By default, just call out into mount() # Local import here since this data source gets imported by the commandline entry point from splitgraph.core.common import get_temporary_table_id + import jsonschema tmp_schema = get_temporary_table_id() try: - self.mount(tmp_schema) - result = { - t: self.engine.get_full_table_schema(tmp_schema, t) + mount_errors = self.mount(tmp_schema) or [] + + table_options = dict(self._get_foreign_table_options(tmp_schema)) + + # Sanity check for adapters: validate the foreign table options that we get back + # to make sure they're still appropriate. + for v in table_options.values(): + jsonschema.validate(v, self.table_params_schema) + + result: IntrospectionResult = { + t: ( + self.engine.get_full_table_schema(tmp_schema, t), + cast(TableParams, table_options.get(t, {})), + ) for t in self.engine.get_all_tables(tmp_schema) } + + # Add errors to the result as well + for error in mount_errors: + result[error.table_name] = error return result finally: self.engine.rollback() @@ -189,7 +207,36 @@ def _preview_table(self, schema: str, table: str, limit: int = 10) -> List[Dict[ ) return result_json - def preview(self, schema: Dict[str, TableSchema]) -> PreviewResult: + def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, Any]]]: + """ + Get a list of options the foreign tables in this schema were instantiated with + :return: List of tables and their options + """ + # We use this to suggest table options during introspection. With FDWs, we do this by + # mounting the data source first (using IMPORT FOREIGN SCHEMA) and then scraping the + # foreign table options it inferred. + + # Downstream FDWs can override this: if they remap some table options into different + # FDW options (e.g. "remote_schema" on the data source side turns into "schema" on the + # FDW side), they have to map them back in this routine (otherwise the introspection will + # suggest "schema", which is wrong. + + # This is also used for type remapping, since table params can only be strings on PG. + # TODO: this will lead to a bunch of serialization/deserialization code duplication + # in data sources. One potential solution is, at least in Multicorn-backed wrappers + # that we control, passing a JSON as a single table param instead. + + return cast( + List[Tuple[str, Dict[str, str]]], + self.engine.run_sql( + "SELECT foreign_table_name, json_object_agg(option_name, option_value) " + "FROM information_schema.foreign_table_options " + "WHERE foreign_table_schema = %s GROUP BY foreign_table_name", + (schema,), + ), + ) + + def preview(self, tables: Optional[TableInfo]) -> PreviewResult: # Preview data in tables mounted by this FDW / data source # Local import here since this data source gets imported by the commandline entry point @@ -197,16 +244,23 @@ def preview(self, schema: Dict[str, TableSchema]) -> PreviewResult: tmp_schema = get_temporary_table_id() try: - self.mount(tmp_schema, tables=schema) + # Seed the result with the tables that failed to mount + result: PreviewResult = { + e.table_name: e for e in self.mount(tmp_schema, tables=tables) or [] + } + # Commit so that errors don't cancel the mount self.engine.commit() - result: Dict[str, Union[str, List[Dict[str, Any]]]] = {} for t in self.engine.get_all_tables(tmp_schema): + if t in result: + continue try: result[t] = self._preview_table(tmp_schema, t) except psycopg2.DatabaseError as e: logging.warning("Could not preview data for table %s", t, exc_info=e) - result[t] = str(e) + result[t] = MountError( + table_name=t, error=get_exception_name(e), error_text=str(e).strip() + ) return result finally: self.engine.rollback() @@ -286,9 +340,14 @@ def _format_options(option_names): ) -def _import_foreign_schema( - engine: "PostgresEngine", mountpoint: str, remote_schema: str, server_id: str, tables: List[str] -) -> None: +def import_foreign_schema( + engine: "PostgresEngine", + mountpoint: str, + remote_schema: str, + server_id: str, + tables: List[str], + options: Optional[Dict[str, str]] = None, +) -> List[MountError]: from psycopg2.sql import Identifier, SQL # Construct a query: import schema limit to (%s, %s, ...) from server mountpoint_server into mountpoint @@ -296,7 +355,33 @@ def _import_foreign_schema( if tables: query += SQL("LIMIT TO (") + SQL(",").join(Identifier(t) for t in tables) + SQL(")") query += SQL("FROM SERVER {} INTO {}").format(Identifier(server_id), Identifier(mountpoint)) - engine.run_sql(query) + args: List[str] = [] + + if options: + query += _format_options(options.keys()) + args.extend(options.values()) + + engine.run_sql(query, args) + + # Some of our FDWs use the PG notices as a side channel to pass errors and information back + # to the user, as IMPORT FOREIGN SCHEMA is all-or-nothing. If some of the tables failed to + # mount, we return those + the reasons. + import_errors: List[MountError] = [] + if engine.notices: + for notice in engine.notices: + match = re.match(r".*SPLITGRAPH: (.*)\n", notice) + if not match: + continue + notice_j = json.loads(match.group(1)) + import_errors.append( + MountError( + table_name=notice_j["table_name"], + error=notice_j["error"], + error_text=notice_j["error_text"], + ) + ) + + return import_errors def create_foreign_table( @@ -324,7 +409,9 @@ def create_foreign_table( for col in schema_spec: if col.comment: query += SQL("COMMENT ON COLUMN {}.{}.{} IS %s;").format( - Identifier(schema), Identifier(table_name), Identifier(col.name), + Identifier(schema), + Identifier(table_name), + Identifier(col.name), ) args.append(col.comment) return query, args @@ -342,6 +429,12 @@ def get_description(cls) -> str: "based on postgres_fdw" ) + credentials_schema = { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + "required": ["username", "password"], + } + params_schema = { "type": "object", "properties": { @@ -349,11 +442,12 @@ def get_description(cls) -> str: "port": {"type": "integer", "description": "Port"}, "dbname": {"type": "string", "description": "Database name"}, "remote_schema": {"type": "string", "description": "Remote schema name"}, - "tables": _table_options_schema, }, "required": ["host", "port", "dbname", "remote_schema"], } + table_params_schema = {"type": "object"} + commandline_help: str = """Mount a Postgres database. Mounts a schema on a remote Postgres database as a set of foreign tables locally.""" @@ -378,8 +472,10 @@ def get_server_options(self): def get_user_options(self): return {"user": self.credentials["username"], "password": self.credentials["password"]} - def get_table_options(self, table_name: str): - return {"schema_name": self.params["remote_schema"]} + def get_table_options(self, table_name: str, tables: Optional[TableInfo] = None): + options = super().get_table_options(table_name, tables) + options["schema_name"] = self.params["remote_schema"] + return options def get_fdw_name(self): return "postgres_fdw" @@ -389,27 +485,28 @@ def get_remote_schema_name(self) -> str: class MongoDataSource(ForeignDataWrapperDataSource): + credentials_schema = { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + "required": ["username", "password"], + } + params_schema = { "type": "object", "properties": { "host": {"type": "string"}, "port": {"type": "integer"}, - "tables": { - "type": "object", - "additionalProperties": { - "options": { - "type": "object", - "properties": { - "db": {"type": "string"}, - "coll": {"type": "string"}, - "required": ["db", "coll"], - }, - }, - "required": ["options"], - }, - }, }, - "required": ["host", "port", "tables"], + "required": ["host", "port"], + } + + table_params_schema = { + "type": "object", + "properties": { + "database": {"type": "string"}, + "collection": {"type": "string"}, + }, + "required": ["database", "collection"], } commandline_help = """Mount a Mongo database. @@ -421,7 +518,7 @@ class MongoDataSource(ForeignDataWrapperDataSource): { "table_name": { "schema": {"col1": "type1"...}, - "options": {"db": , "coll": } + "options": {"database": , "collection": } } } ``` @@ -444,14 +541,6 @@ def get_server_options(self): def get_user_options(self): return {"username": self.credentials["username"], "password": self.credentials["password"]} - def get_table_options(self, table_name: str): - try: - table_params = self.params["tables"][table_name]["options"] - except KeyError: - raise ValueError("No options specified for table %s!" % table_name) - - return {"database": table_params["db"], "collection": table_params["coll"]} - def get_fdw_name(self): return "mongo_fdw" @@ -464,22 +553,27 @@ def get_table_schema(self, table_name, table_schema): class MySQLDataSource(ForeignDataWrapperDataSource): + credentials_schema = { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + "required": ["username", "password"], + } + params_schema = { "type": "object", "properties": { "host": {"type": "string"}, "port": {"type": "integer"}, - "remote_schema": {"type": "string"}, - "tables": _table_options_schema, + "dbname": {"type": "string"}, }, - "required": ["host", "port", "remote_schema"], + "required": ["host", "port", "dbname"], } commandline_help: str = """Mount a MySQL database. Mounts a schema on a remote MySQL database as a set of foreign tables locally.""" - commandline_kwargs_help: str = """remote_schema: Remote schema name (required) + commandline_kwargs_help: str = """dbname: Remote MySQL database name (required) tables: Tables to mount (default all). If a list, then will use IMPORT FOREIGN SCHEMA. If a dictionary, must have the format {"table_name": {"schema": {"col_1": "type_1", ...}, @@ -492,7 +586,7 @@ def get_name(cls) -> str: @classmethod def get_description(cls) -> str: - return "Data source for MySQL databases that supports live querying, " "based on mysql_fdw" + return "Data source for MySQL databases that supports live querying, based on mysql_fdw" def get_server_options(self): return { @@ -504,14 +598,16 @@ def get_server_options(self): def get_user_options(self): return {"username": self.credentials["username"], "password": self.credentials["password"]} - def get_table_options(self, table_name: str): - return {"dbname": self.params["remote_schema"]} + def get_table_options(self, table_name: str, tables: Optional[TableInfo] = None): + options = super().get_table_options(table_name, tables) + options["dbname"] = options.get("dbname", self.params["dbname"]) + return options def get_fdw_name(self): return "mysql_fdw" def get_remote_schema_name(self) -> str: - return str(self.params["remote_schema"]) + return str(self.params["dbname"]) class ElasticSearchDataSource(ForeignDataWrapperDataSource): @@ -528,42 +624,39 @@ class ElasticSearchDataSource(ForeignDataWrapperDataSource): "properties": { "host": {"type": "string"}, "port": {"type": "integer"}, - "tables": { - "type": "object", - "additionalProperties": { - "options": { - "type": "object", - "properties": { - "index": { - "type": "string", - "description": 'ES index name or pattern to use, for example, "events-*"', - }, - "type": { - "type": "string", - "description": "Pre-ES7 doc_type, not required in ES7 or later", - }, - "query_column": { - "type": "string", - "description": "Name of the column to use to pass queries in", - }, - "score_column": { - "type": "string", - "description": "Name of the column with the document score", - }, - "scroll_size": { - "type": "integer", - "description": "Fetch size, default 1000", - }, - "scroll_duration": { - "type": "string", - "description": "How long to hold the scroll context open for, default 10m", - }, - }, - }, - }, + }, + "required": ["host", "port"], + } + + table_params_schema = { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": 'ES index name or pattern to use, for example, "events-*"', + }, + "type": { + "type": "string", + "description": "Pre-ES7 doc_type, not required in ES7 or later", + }, + "query_column": { + "type": "string", + "description": "Name of the column to use to pass queries in", + }, + "score_column": { + "type": "string", + "description": "Name of the column with the document score", + }, + "scroll_size": { + "type": "integer", + "description": "Fetch size, default 1000", + }, + "scroll_duration": { + "type": "string", + "description": "How long to hold the scroll context open for, default 10m", }, }, - "required": ["host", "port", "tables"], + "required": ["index"], } commandline_help = """Mount an ElasticSearch instance. diff --git a/splitgraph/hooks/mount_handlers.py b/splitgraph/hooks/mount_handlers.py index 51ab60a6..023cb07a 100644 --- a/splitgraph/hooks/mount_handlers.py +++ b/splitgraph/hooks/mount_handlers.py @@ -6,7 +6,7 @@ from splitgraph.exceptions import DataSourceError if TYPE_CHECKING: - from splitgraph.hooks.data_source.base import TableInfo + from splitgraph.core.types import TableInfo def mount_postgres(mountpoint, **kwargs) -> None: diff --git a/splitgraph/ingestion/csv/__init__.py b/splitgraph/ingestion/csv/__init__.py index a3480226..261bf213 100644 --- a/splitgraph/ingestion/csv/__init__.py +++ b/splitgraph/ingestion/csv/__init__.py @@ -1,9 +1,11 @@ +import json from copy import deepcopy -from typing import Optional, TYPE_CHECKING, Dict, Mapping, cast +from typing import Optional, TYPE_CHECKING, Dict, List, Tuple, Any from psycopg2.sql import SQL, Identifier -from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource +from splitgraph.core.types import TableInfo, MountError +from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource, import_foreign_schema from splitgraph.ingestion.common import IngestionAdapter, build_commandline_help if TYPE_CHECKING: @@ -118,6 +120,15 @@ class CSVDataSource(ForeignDataWrapperDataSource): }, "quotechar": {"type": "string", "description": "Character used to quote fields"}, }, + "oneOf": [{"required": ["url"]}, {"required": ["s3_endpoint", "s3_bucket"]}], + } + + table_params_schema = { + "type": "object", + "properties": { + "url": {"type": "string", "description": "HTTP URL to the CSV file"}, + "s3_object": {"type": "string", "description": "S3 object of the CSV file"}, + }, } supports_mount = True @@ -173,13 +184,72 @@ def from_commandline(cls, engine, commandline_kwargs) -> "CSVDataSource": credentials[k] = params[k] return cls(engine, credentials, params) - def get_table_options(self, table_name: str) -> Mapping[str, str]: - result = cast(Dict[str, str], super().get_table_options(table_name)) - result["s3_object"] = result.get( - "s3_object", self.params.get("s3_object_prefix", "") + table_name - ) + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: + result = super().get_table_options(table_name, tables) + + # Set a default s3_object if we're using S3 and not HTTP + if "url" not in result: + result["s3_object"] = result.get( + "s3_object", self.params.get("s3_object_prefix", "") + table_name + ) return result + def _create_foreign_tables( + self, schema: str, server_id: str, tables: TableInfo + ) -> List[MountError]: + # Override _create_foreign_tables (actual mounting code) to support TableInfo structs + # where the schema is empty. This is so that we can call this data source with a limited + # list of CSV files and their delimiters / other params and have it introspect just + # those tables (instead of e.g. scanning the whole bucket). + errors: List[MountError] = [] + + if isinstance(tables, dict): + to_introspect = { + table_name: table_options + for table_name, (table_schema, table_options) in tables.items() + if not table_schema + } + if to_introspect: + # This FDW's implementation of IMPORT FOREIGN SCHEMA supports passing a JSON of table + # options in options + errors.extend( + import_foreign_schema( + self.engine, + schema, + self.get_remote_schema_name(), + server_id, + tables=list(to_introspect.keys()), + options={"table_options": json.dumps(to_introspect)}, + ) + ) + + # Create the remaining tables (that have a schema) as usual. + tables = { + table_name: (table_schema, table_options) + for table_name, (table_schema, table_options) in tables.items() + if table_schema + } + if not tables: + return errors + + errors.extend(super()._create_foreign_tables(schema, server_id, tables)) + return errors + + def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, Any]]]: + options = super()._get_foreign_table_options(schema) + + # Deserialize things like booleans from the foreign table options that the FDW inferred + # for us. + def _destring_table_options(table_options): + for k in ["autodetect_header", "autodetect_encoding", "autodetect_dialect", "header"]: + if k in table_options: + table_options[k] = table_options[k].lower() == "true" + return table_options + + return [(t, _destring_table_options(d)) for t, d in options] + def get_server_options(self): options: Dict[str, Optional[str]] = { "wrapper": "splitgraph.ingestion.csv.fdw.CSVForeignDataWrapper" diff --git a/splitgraph/ingestion/csv/common.py b/splitgraph/ingestion/csv/common.py index 3039cb2f..7af33645 100644 --- a/splitgraph/ingestion/csv/common.py +++ b/splitgraph/ingestion/csv/common.py @@ -1,6 +1,6 @@ import csv import io -from typing import Optional, Dict, Tuple, NamedTuple, Union, Type, TYPE_CHECKING +from typing import Dict, Tuple, NamedTuple, TYPE_CHECKING, Any if TYPE_CHECKING: import _csv @@ -18,7 +18,6 @@ class CSVOptions(NamedTuple): schema_inference_rows: int = 10000 delimiter: str = "," quotechar: str = '"' - dialect: Optional[Union[str, Type[csv.Dialect]]] = None header: bool = True encoding: str = "utf-8" ignore_decode_errors: bool = False @@ -34,16 +33,34 @@ def from_fdw_options(cls, fdw_options): header=get_bool(fdw_options, "header"), delimiter=fdw_options.get("delimiter", ","), quotechar=fdw_options.get("quotechar", '"'), - dialect=fdw_options.get("dialect"), encoding=fdw_options.get("encoding", "utf-8"), ignore_decode_errors=get_bool(fdw_options, "ignore_decode_errors", default=False), ) def to_csv_kwargs(self): - if self.dialect: - return {"dialect": self.dialect} return {"delimiter": self.delimiter, "quotechar": self.quotechar} + def to_table_options(self): + """ + Turn this into a dict of table options that can be plugged back into CSVDataSource. + """ + + # The purpose is to return to the user the CSV dialect options that we inferred + # so that they can freeze them in the table options (instead of rescanning the CSV + # on every mount) + iterate on them. + + # We flip the autodetect flags to False here so that if we merge the new params with + # the old params again, it won't rerun CSV dialect detection. + return { + "autodetect_header": "false", + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "header": bool_to_str(self.header), + "delimiter": self.delimiter, + "quotechar": self.quotechar, + "encoding": self.encoding, + } + def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions: """Autodetect the CSV dialect, encoding, header etc.""" @@ -77,7 +94,10 @@ def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions: if csv_options.autodetect_dialect: dialect = csv.Sniffer().sniff(sample) - csv_options = csv_options._replace(dialect=dialect) + # These are meant to be set, but mypy claims they might not be. + csv_options = csv_options._replace( + delimiter=dialect.delimiter or ",", quotechar=dialect.quotechar or '"' + ) if csv_options.autodetect_header: has_header = csv.Sniffer().has_header(sample) @@ -86,10 +106,16 @@ def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions: return csv_options -def get_bool(params: Dict[str, str], key: str, default: bool = True) -> bool: +def get_bool(params: Dict[str, Any], key: str, default: bool = True) -> bool: if key not in params: return default - return params[key].lower() == "true" + if isinstance(params[key], bool): + return bool(params[key]) + return bool(params[key].lower() == "true") + + +def bool_to_str(boolean: bool) -> str: + return "true" if boolean else "false" def make_csv_reader( diff --git a/splitgraph/ingestion/csv/fdw.py b/splitgraph/ingestion/csv/fdw.py index 201da8a1..f7eba951 100644 --- a/splitgraph/ingestion/csv/fdw.py +++ b/splitgraph/ingestion/csv/fdw.py @@ -1,15 +1,17 @@ import gzip +import json import logging import os -from copy import deepcopy +from contextlib import contextmanager +from copy import copy from itertools import islice -from typing import Tuple +from typing import Tuple, Optional import requests from minio import Minio import splitgraph.config -from splitgraph.commandline import get_exception_name +from splitgraph.exceptions import get_exception_name from splitgraph.ingestion.common import generate_column_names from splitgraph.ingestion.csv.common import CSVOptions, get_bool, make_csv_reader from splitgraph.ingestion.inference import infer_sg_schema @@ -37,6 +39,10 @@ def _get_table_definition(response, fdw_options, table_name, table_options): + # Allow overriding introspection options with per-table params (e.g. encoding, delimiter...) + fdw_options = copy(fdw_options) + fdw_options.update(table_options) + csv_options, reader = make_csv_reader(response, CSVOptions.from_fdw_options(fdw_options)) sample = list(islice(reader, csv_options.schema_inference_rows)) @@ -48,17 +54,49 @@ def _get_table_definition(response, fdw_options, table_name, table_options): # For nonexistent column names: replace with autogenerated ones (can't have empty column names) sg_schema = generate_column_names(sg_schema) + # Merge the autodetected table options with the ones passed to us originally (e.g. + # S3 object etc) + new_table_options = copy(table_options) + new_table_options.update(csv_options.to_table_options()) + # Build Multicorn TableDefinition. ColumnDefinition takes in type OIDs, # typmods and other internal PG stuff but other FDWs seem to get by with just # the textual type name. return TableDefinition( - table_name=table_name[len(fdw_options.get("s3_object_prefix", "")) :], + table_name=table_name, schema=None, columns=[ColumnDefinition(column_name=c.name, type_name=c.pg_type) for c in sg_schema], - options=table_options, + options=new_table_options, ) +@contextmanager +def report_errors(table_name: str): + """Context manager that ignores exceptions and serializes them to JSON using PG's notice + mechanism instead. The data source is meant to load these to report on partial failures + (e.g. failed to load one table, but not others).""" + try: + yield + except Exception as e: + logging.error( + "Error scanning %s, ignoring: %s: %s", + table_name, + get_exception_name(e), + e, + exc_info=e, + ) + log_to_postgres( + "SPLITGRAPH: " + + json.dumps( + { + "table_name": table_name, + "error": get_exception_name(e), + "error_text": str(e), + } + ) + ) + + class CSVForeignDataWrapper(ForeignDataWrapper): """Foreign data wrapper for CSV files stored in S3 buckets or HTTP""" @@ -133,72 +171,115 @@ def import_schema(cls, schema, srv_options, options, restriction_type, restricts # Implement IMPORT FOREIGN SCHEMA to instead scan an S3 bucket for CSV files # and infer their CSV schema. - # Merge server options and options passed to IMPORT FOREIGN SCHEMA - fdw_options = deepcopy(srv_options) - for k, v in options.items(): - fdw_options[k] = v - - if fdw_options.get("url"): - # Infer from HTTP -- singular table with name "data" - with requests.get( - fdw_options["url"], stream=True, verify=os.environ.get("SSL_CERT_FILE", True) - ) as response: - response.raise_for_status() - stream = response.raw - if response.headers.get("Content-Encoding") == "gzip": - stream = gzip.GzipFile(fileobj=stream) - return [_get_table_definition(stream, fdw_options, "data", None)] - - # Get S3 options - client, bucket, prefix = cls._get_s3_params(fdw_options) - - # Note that we ignore the "schema" here (e.g. IMPORT FOREIGN SCHEMA some_schema) - # and take all interesting parameters through FDW options. - - # Allow just introspecting one object - if "s3_object" in fdw_options: - objects = [fdw_options["s3_object"]] + # 1) if we don't have options["table_options"], do a full scan as normal, + # treat LIMIT TO as a list of S3 objects + # 2) if we do, go through these tables and treat each one as a partial override + # of server options + if "table_options" in options: + table_options = json.loads(options["table_options"]) else: - objects = [ - o.object_name - for o in client.list_objects( - bucket_name=bucket, prefix=prefix or None, recursive=True - ) - ] - - result = [] - - for o in objects: - if restriction_type == "limit" and o not in restricts: - continue - - if restriction_type == "except" and o in restricts: - continue + table_options = None + + if not table_options: + # Do a full scan of the file at URL / S3 bucket w. prefix + if srv_options.get("url"): + # Infer from HTTP -- singular table with name "data" + return [cls._introspect_url(srv_options, srv_options["url"])] + else: + # Get S3 options + client, bucket, prefix = cls._get_s3_params(srv_options) + + # Note that we ignore the "schema" here (e.g. IMPORT FOREIGN SCHEMA some_schema) + # and take all interesting parameters through FDW options. + + # Allow just introspecting one object + if "s3_object" in srv_options: + objects = [srv_options["s3_object"]] + elif restriction_type == "limit": + objects = restricts + else: + objects = [ + o.object_name + for o in client.list_objects( + bucket_name=bucket, prefix=prefix or None, recursive=True + ) + ] + + result = [] + + for o in objects: + if restriction_type == "except" and o in restricts: + continue + result.append(cls._introspect_s3(client, bucket, o, srv_options)) + + return [r for r in result if r] + else: + result = [] + + # Note we ignore LIMIT/EXCEPT here. There's no point in using them if the user + # is passing a dict of table options anyway. + for table_name, this_table_options in table_options.items(): + if "s3_object" in this_table_options: + # TODO: we can support overriding S3 params per-table here, but currently + # we don't do it. + client, bucket, _ = cls._get_s3_params(srv_options) + result.append( + cls._introspect_s3( + client, + bucket, + this_table_options["s3_object"], + srv_options, + table_name, + this_table_options, + ) + ) + else: + result.append( + cls._introspect_url( + srv_options, this_table_options["url"], table_name, this_table_options + ) + ) + return [r for r in result if r] - response = None + @classmethod + def _introspect_s3( + cls, client, bucket, object_id, srv_options, table_name=None, table_options=None + ) -> Optional[TableDefinition]: + response = None + # Default table name: truncate S3 object key up to the prefix + table_name = table_name or object_id[len(srv_options.get("s3_object_prefix", "")) :] + table_options = table_options or {} + table_options.update({"s3_object": object_id}) + with report_errors(table_name): try: - response = client.get_object(bucket, o) - result.append( - _get_table_definition( - response, - fdw_options, - o, - {"s3_object": o}, - ) - ) - except Exception as e: - logging.error( - "Error scanning object %s, ignoring: %s: %s", o, get_exception_name(e), e - ) - log_to_postgres( - "Error scanning object %s, ignoring: %s: %s" % (o, get_exception_name(e), e) + response = client.get_object(bucket, object_id) + return _get_table_definition( + response, + srv_options, + table_name, + table_options, ) finally: if response: response.close() response.release_conn() - return result + @classmethod + def _introspect_url( + cls, srv_options, url, table_name=None, table_options=None + ) -> Optional[TableDefinition]: + table_name = table_name or "data" + table_options = table_options or {} + + with report_errors(table_name): + with requests.get( + url, stream=True, verify=os.environ.get("SSL_CERT_FILE", True) + ) as response: + response.raise_for_status() + stream = response.raw + if response.headers.get("Content-Encoding") == "gzip": + stream = gzip.GzipFile(fileobj=stream) + return _get_table_definition(stream, srv_options, table_name, table_options) @classmethod def _get_s3_params(cls, fdw_options) -> Tuple[Minio, str, str]: @@ -242,6 +323,4 @@ def __init__(self, fdw_options, fdw_columns): self.mode = "s3" self.s3_client, self.s3_bucket, self.s3_object_prefix = self._get_s3_params(fdw_options) - # TODO need a way to pass the table params (e.g. the actual S3 object name which - # might be different from table) into the preview and return it from introspection. self.s3_object = fdw_options["s3_object"] diff --git a/splitgraph/ingestion/singer/data_source.py b/splitgraph/ingestion/singer/data_source.py index 760bc21c..8f5832d0 100644 --- a/splitgraph/ingestion/singer/data_source.py +++ b/splitgraph/ingestion/singer/data_source.py @@ -12,16 +12,14 @@ from psycopg2.sql import Identifier, SQL from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema +from splitgraph.core.types import TableParams, TableInfo, SyncState, IntrospectionResult from splitgraph.exceptions import DataSourceError from splitgraph.hooks.data_source.base import ( get_ingestion_state, INGESTION_STATE_TABLE, INGESTION_STATE_SCHEMA, - TableInfo, prepare_new_image, SyncableDataSource, - SyncState, ) from splitgraph.ingestion.singer.db_sync import ( get_table_name, @@ -197,15 +195,15 @@ def build_singer_catalog( catalog, tables=tables, use_legacy_stream_selection=self.use_legacy_stream_selection ) - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> IntrospectionResult: config = self.get_singer_config() singer_schema = self._run_singer_discovery(config) - result = {} + result: IntrospectionResult = {} for stream in singer_schema["streams"]: stream_name = get_table_name(stream) stream_schema = get_sg_schema(stream) - result[stream_name] = stream_schema + result[stream_name] = (stream_schema, cast(TableParams, {})) return result diff --git a/splitgraph/ingestion/snowflake/__init__.py b/splitgraph/ingestion/snowflake/__init__.py index ac3c2b27..01cae739 100644 --- a/splitgraph/ingestion/snowflake/__init__.py +++ b/splitgraph/ingestion/snowflake/__init__.py @@ -1,8 +1,9 @@ import base64 import json import urllib.parse -from typing import Dict, Optional, cast, Mapping, Any +from typing import Dict, Optional, Any +from splitgraph.core.types import TableInfo from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource from splitgraph.ingestion.common import build_commandline_help @@ -50,12 +51,6 @@ class SnowflakeDataSource(ForeignDataWrapperDataSource): params_schema = { "type": "object", "properties": { - "tables": { - "type": "object", - "additionalProperties": { - "options": {"type": "object", "additionalProperties": {"type": "string"}}, - }, - }, "database": {"type": "string", "description": "Snowflake database name"}, "schema": {"type": "string", "description": "Snowflake schema"}, "warehouse": {"type": "string", "description": "Warehouse name"}, @@ -72,6 +67,15 @@ class SnowflakeDataSource(ForeignDataWrapperDataSource): "required": ["database"], } + table_params_schema = { + "type": "object", + "properties": { + "subquery": { + "type": "string", + "description": "Subquery for this table to run on the server side", + } + }, + } supports_mount = True supports_load = True supports_sync = False @@ -135,8 +139,10 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Schema, table or a subquery from a Snowflake database" - def get_table_options(self, table_name: str) -> Mapping[str, str]: - result = cast(Dict[str, str], super().get_table_options(table_name)) + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: + result = super().get_table_options(table_name, tables) result["tablename"] = result.get("tablename", table_name) return result diff --git a/splitgraph/ingestion/socrata/mount.py b/splitgraph/ingestion/socrata/mount.py index 1d3f04eb..8252ecf7 100644 --- a/splitgraph/ingestion/socrata/mount.py +++ b/splitgraph/ingestion/socrata/mount.py @@ -2,10 +2,11 @@ import json import logging from copy import deepcopy -from typing import Optional, Dict +from typing import Optional, Dict, List from psycopg2.sql import SQL, Identifier +from splitgraph.core.types import TableInfo, MountError from splitgraph.exceptions import RepositoryNotFoundError from splitgraph.hooks.data_source.fdw import create_foreign_table, ForeignDataWrapperDataSource @@ -29,17 +30,17 @@ class SocrataDataSource(ForeignDataWrapperDataSource): "type": "integer", "description": "Amount of rows to fetch from Socrata per request (limit parameter). Maximum 50000.", }, - "tables": { - "type": "object", - }, }, "required": ["domain"], } - """ - tables: A dictionary mapping PostgreSQL table names to Socrata table IDs. For example, - {"salaries": "xzkq-xp2w"}. If skipped, ALL tables in the Socrata endpoint will be mounted. - """ + table_params_schema = { + "type": "object", + "properties": { + "socrata_id": {"type": "string", "description": "Socrata dataset ID, e.g. xzkq-xp2w"} + }, + "required": ["socrata_id"], + } @classmethod def get_name(cls) -> str: @@ -55,8 +56,19 @@ def get_fdw_name(self): @classmethod def from_commandline(cls, engine, commandline_kwargs) -> "SocrataDataSource": params = deepcopy(commandline_kwargs) + + # Convert the old-style "tables" param ({"table_name": "some_id"}) + # to the schema this data source expects + # {"table_name": [table schema], {"socrata_id": "some_id"}} + # Note that we don't actually care about the table schema in this data source, + # since it reintrospects on every mount. + + tables = params.pop("tables", []) + if isinstance(tables, dict) and isinstance(next(iter(tables.values())), str): + tables = {k: ([], {"socrata_id": v}) for k, v in tables.items()} + credentials = {"app_token": params.pop("app_token", None)} - return cls(engine, credentials, params) + return cls(engine, credentials, params, tables) def get_server_options(self): options: Dict[str, Optional[str]] = { @@ -69,14 +81,20 @@ def get_server_options(self): options["app_token"] = str(self.credentials["app_token"]) return options - def _create_foreign_tables(self, schema, server_id, tables): + def _create_foreign_tables( + self, schema: str, server_id: str, tables: TableInfo + ) -> List[MountError]: from sodapy import Socrata from psycopg2.sql import SQL logging.info("Getting Socrata metadata") client = Socrata(domain=self.params["domain"], app_token=self.credentials.get("app_token")) - tables = self.params.get("tables") - sought_ids = tables.values() if tables else [] + + tables = self.tables or tables + if isinstance(tables, list): + sought_ids = tables + else: + sought_ids = [t[1]["socrata_id"] for t in tables.values()] try: datasets = client.datasets(ids=sought_ids, only=["dataset"]) @@ -96,30 +114,20 @@ def _create_foreign_tables(self, schema, server_id, tables): ) self.engine.run_sql(SQL(";").join(mount_statements), mount_args) + return [] -def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, tables): +def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, tables: TableInfo): # Local imports since this module gets run from commandline entrypoint on startup. from splitgraph.core.output import slugify - from splitgraph.core.output import truncate_list from splitgraph.core.output import pluralise from splitgraph.ingestion.socrata.querying import socrata_to_sg_schema found_ids = set(d["resource"]["id"] for d in datasets) logging.info("Loaded metadata for %s", pluralise("Socrata table", len(found_ids))) - if tables: - missing_ids = [d for d in found_ids if d not in sought_ids] - if missing_ids: - raise ValueError( - "Some Socrata tables couldn't be found! Missing tables: %s" - % truncate_list(missing_ids) - ) - - tables_inv = {s: p for p, s in tables.items()} - else: - tables_inv = {} + tables_inv = _get_table_map(found_ids, sought_ids, tables) mount_statements = [] mount_args = [] @@ -148,3 +156,20 @@ def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, mount_args.extend(args) return mount_statements, mount_args + + +def _get_table_map(found_ids, sought_ids, tables: TableInfo) -> Dict[str, str]: + """Get a map of Socrata ID -> local table name""" + from splitgraph.core.output import truncate_list + + if isinstance(tables, (dict, list)) and tables: + missing_ids = [d for d in found_ids if d not in sought_ids] + if missing_ids: + raise ValueError( + "Some Socrata tables couldn't be found! Missing tables: %s" + % truncate_list(missing_ids) + ) + + if isinstance(tables, dict): + return {to["socrata_id"]: p for p, (ts, to) in tables.items()} + return {} diff --git a/test/resources/custom_plugin_dir/some_plugin/plugin.py b/test/resources/custom_plugin_dir/some_plugin/plugin.py index 861cc34c..2d7e3a8b 100644 --- a/test/resources/custom_plugin_dir/some_plugin/plugin.py +++ b/test/resources/custom_plugin_dir/some_plugin/plugin.py @@ -1,12 +1,10 @@ -from typing import Dict - -from splitgraph.core.types import TableSchema +from splitgraph.core.types import IntrospectionResult from splitgraph.hooks.data_source import DataSource class TestDataSource(DataSource): - def introspect(self) -> Dict[str, TableSchema]: - return {"some_table": []} + def introspect(self) -> IntrospectionResult: + return {"some_table": ([], {})} @classmethod def get_name(cls) -> str: diff --git a/test/splitgraph/commandline/test_mount.py b/test/splitgraph/commandline/test_mount.py index 64b971a1..2a85d281 100644 --- a/test/splitgraph/commandline/test_mount.py +++ b/test/splitgraph/commandline/test_mount.py @@ -23,8 +23,15 @@ _MONGO_PARAMS = { "tables": { "stuff": { - "options": {"db": "origindb", "coll": "stuff",}, - "schema": {"name": "text", "duration": "numeric", "happy": "boolean",}, + "options": { + "database": "origindb", + "collection": "stuff", + }, + "schema": { + "name": "text", + "duration": "numeric", + "happy": "boolean", + }, } } } @@ -55,7 +62,8 @@ def test_misc_mountpoint_management(pg_repo_local, mg_repo_local): # sgr mount with a file with tempfile.NamedTemporaryFile("w") as f: json.dump( - _MONGO_PARAMS, f, + _MONGO_PARAMS, + f, ) f.flush() @@ -150,5 +158,5 @@ def test_mount_plugin_dir(): plugin = plugin_class( engine=None, credentials={"access_token": "abc"}, params={"some_field": "some_value"} ) - assert plugin.introspect() == {"some_table": []} + assert plugin.introspect() == {"some_table": ([], {})} assert plugin.get_name() == "Test Data Source" diff --git a/test/splitgraph/commands/test_mounting.py b/test/splitgraph/commands/test_mounting.py index cd5de8a7..634fb781 100644 --- a/test/splitgraph/commands/test_mounting.py +++ b/test/splitgraph/commands/test_mounting.py @@ -63,26 +63,34 @@ def test_mount_introspection_preview(local_engine_empty): params={"host": "pgorigin", "port": 5432, "dbname": "origindb", "remote_schema": "public"}, ) - schema = handler.introspect() - - assert schema == { - "fruits": [ - TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), - TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], - "vegetables": [ - TableColumn( - ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None - ), - TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], + tables = handler.introspect() + + assert tables == { + "fruits": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {"schema_name": "public", "table_name": "fruits"}, + ), + "vegetables": ( + [ + TableColumn( + ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {"schema_name": "public", "table_name": "vegetables"}, + ), } - preview = handler.preview(schema=schema) + preview = handler.preview(tables=tables) assert preview == { "fruits": [{"fruit_id": 1, "name": "apple"}, {"fruit_id": 2, "name": "orange"}], "vegetables": [ @@ -92,6 +100,38 @@ def test_mount_introspection_preview(local_engine_empty): } +@pytest.mark.mounting +def test_mount_rename_table(local_engine_empty): + tables = { + "fruits_renamed": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, + ), + ], + {"table_name": "fruits"}, + ) + } + handler = PostgreSQLDataSource( + engine=local_engine_empty, + credentials={"username": "originro", "password": "originpass"}, + params={"host": "pgorigin", "port": 5432, "dbname": "origindb", "remote_schema": "public"}, + tables=tables, + ) + + preview = handler.preview(tables) + assert preview == { + "fruits_renamed": [{"fruit_id": 1, "name": "apple"}, {"fruit_id": 2, "name": "orange"}], + } + + @pytest.mark.mounting def test_cross_joins(local_engine_empty): _mount_postgres(PG_MNT) @@ -131,9 +171,11 @@ def test_mount_elasticsearch(local_engine_empty): "col_1": "text", "col_2": "boolean", }, - "index": "index-pattern*", - "rowid_column": "id", - "query_column": "query", + "options": { + "index": "index-pattern*", + "rowid_column": "id", + "query_column": "query", + }, } }, ), diff --git a/test/splitgraph/conftest.py b/test/splitgraph/conftest.py index 48da84ba..6fbf8e1d 100644 --- a/test/splitgraph/conftest.py +++ b/test/splitgraph/conftest.py @@ -98,8 +98,8 @@ def _mount_mongo(repository): tables=dict( stuff={ "options": { - "db": "origindb", - "coll": "stuff", + "database": "origindb", + "collection": "stuff", }, "schema": {"name": "text", "duration": "numeric", "happy": "boolean"}, }, @@ -120,17 +120,20 @@ def _mount_mysql(repository): port=3306, username="originuser", password="originpass", - remote_schema="mysqlschema", + dbname="mysqlschema", ), tables={ - "mushrooms": [ - TableColumn(1, "mushroom_id", "integer", False), - TableColumn(2, "name", "character varying (20)", False), - TableColumn(3, "discovery", "timestamp", False), - TableColumn(4, "friendly", "boolean", False), - TableColumn(5, "binary_data", "bytea", False), - TableColumn(6, "varbinary_data", "bytea", False), - ] + "mushrooms": ( + [ + TableColumn(1, "mushroom_id", "integer", False), + TableColumn(2, "name", "character varying (20)", False), + TableColumn(3, "discovery", "timestamp", False), + TableColumn(4, "friendly", "boolean", False), + TableColumn(5, "binary_data", "bytea", False), + TableColumn(6, "varbinary_data", "bytea", False), + ], + {}, + ) }, ) diff --git a/test/splitgraph/ingestion/test_common.py b/test/splitgraph/ingestion/test_common.py index d5772f7d..e5a84826 100644 --- a/test/splitgraph/ingestion/test_common.py +++ b/test/splitgraph/ingestion/test_common.py @@ -1,15 +1,25 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Tuple, cast from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema, TableColumn +from splitgraph.core.types import ( + TableSchema, + TableColumn, + TableInfo, + SyncState, + TableParams, + IntrospectionResult, +) from splitgraph.engine import ResultShape -from splitgraph.hooks.data_source.base import SyncState, TableInfo, SyncableDataSource - -SCHEMA = { - "test_table": [ - TableColumn(1, "key", "integer", True), - TableColumn(2, "value", "character varying", False), - ] +from splitgraph.hooks.data_source.base import SyncableDataSource + +SCHEMA: Dict[str, Tuple[TableSchema, TableParams]] = { + "test_table": ( + [ + TableColumn(1, "key", "integer", True), + TableColumn(2, "value", "character varying", False), + ], + {}, + ) } TEST_REPO = "test/generic_sync" @@ -26,26 +36,22 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Test ingestion" - def introspect(self) -> Dict[str, TableSchema]: - return SCHEMA + def introspect(self) -> IntrospectionResult: + return cast(IntrospectionResult, SCHEMA) def _sync( self, schema: str, state: Optional[SyncState] = None, tables: Optional[TableInfo] = None ) -> SyncState: if not self.engine.table_exists(schema, "test_table"): - self.engine.create_table(schema, "test_table", SCHEMA["test_table"]) + self.engine.create_table(schema, "test_table", SCHEMA["test_table"][0]) if not state: - self.engine.run_sql_in( - schema, "INSERT INTO test_table (key, value) " "VALUES (1, 'one')" - ) + self.engine.run_sql_in(schema, "INSERT INTO test_table (key, value) VALUES (1, 'one')") return {"last_value": 1} else: last_value = state["last_value"] assert last_value == 1 - self.engine.run_sql_in( - schema, "INSERT INTO test_table (key, value) " "VALUES (2, 'two')" - ) + self.engine.run_sql_in(schema, "INSERT INTO test_table (key, value) VALUES (2, 'two')") return {"last_value": 2} diff --git a/test/splitgraph/ingestion/test_csv.py b/test/splitgraph/ingestion/test_csv.py index 51a79945..05ef92ba 100644 --- a/test/splitgraph/ingestion/test_csv.py +++ b/test/splitgraph/ingestion/test_csv.py @@ -1,9 +1,11 @@ +import json import os from io import BytesIO +from unittest import mock import pytest -from splitgraph.core.types import TableColumn +from splitgraph.core.types import TableColumn, MountError, unwrap from splitgraph.engine import ResultShape from splitgraph.hooks.s3_server import MINIO from splitgraph.ingestion.common import generate_column_names @@ -13,6 +15,28 @@ from splitgraph.ingestion.inference import infer_sg_schema from test.splitgraph.conftest import INGESTION_RESOURCES_CSV +_s3_win_1252_opts = { + "s3_object": "some_prefix/encoding-win-1252.csv", + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ";", + "encoding": "Windows-1252", + "header": "true", + "quotechar": '"', +} + +_s3_fruits_opts = { + "s3_object": "some_prefix/fruits.csv", + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ",", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', +} + def test_csv_introspection_s3(): fdw_options = { @@ -37,7 +61,7 @@ def test_csv_introspection_s3(): {"column_name": "DATE", "type_name": "character varying"}, {"column_name": "TEXT", "type_name": "character varying"}, ], - "options": {"s3_object": "some_prefix/encoding-win-1252.csv"}, + "options": _s3_win_1252_opts, "schema": None, "table_name": "encoding-win-1252.csv", } @@ -52,15 +76,11 @@ def test_csv_introspection_s3(): {"column_name": "bignumber", "type_name": "bigint"}, {"column_name": "vbignumber", "type_name": "numeric"}, ], - "options": {"s3_object": "some_prefix/fruits.csv"}, + "options": _s3_fruits_opts, } assert schema[2]["table_name"] == "rdu-weather-history.csv" assert schema[2]["columns"][0] == {"column_name": "date", "type_name": "date"} - # TODO we need a way to pass suggested table options in the inference / preview response, - # since we need to somehow decouple the table name from the S3 object name and/or customize - # delimiter/quotechar - def test_csv_introspection_http(): # Pre-sign the S3 URL for an easy HTTP URL to test this @@ -84,7 +104,87 @@ def test_csv_introspection_http(): {"column_name": "bignumber", "type_name": "bigint"}, {"column_name": "vbignumber", "type_name": "numeric"}, ], - "options": None, + "options": { + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ",", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', + }, + } + + +def test_csv_introspection_multiple(): + # Test running the introspection passing the table options as CREATE FOREIGN SCHEMA params. + # In effect, we specify the table names, S3 key/URL and expect the FDW to figure out + # the rest. + + fdw_options = { + "s3_endpoint": "objectstorage:9000", + "s3_secure": "false", + "s3_access_key": "minioclient", + "s3_secret_key": "supersecure", + "s3_bucket": "test_csv", + "s3_object_prefix": "some_prefix/", + } + + url = MINIO.presigned_get_object("test_csv", "some_prefix/rdu-weather-history.csv") + schema = CSVForeignDataWrapper.import_schema( + schema=None, + srv_options=fdw_options, + options={ + "table_options": json.dumps( + { + "from_url": {"url": url}, + "from_s3_rdu": {"s3_object": "some_prefix/rdu-weather-history.csv"}, + "from_s3_encoding": {"s3_object": "some_prefix/encoding-win-1252.csv"}, + } + ) + }, + restriction_type=None, + restricts=[], + ) + + assert len(schema) == 3 + schema = sorted(schema, key=lambda s: s["table_name"]) + + assert schema[0] == { + "table_name": "from_s3_encoding", + "schema": None, + "columns": mock.ANY, + "options": _s3_win_1252_opts, + } + assert schema[1] == { + "table_name": "from_s3_rdu", + "schema": None, + "columns": mock.ANY, + "options": { + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ";", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', + "s3_object": "some_prefix/rdu-weather-history.csv", + }, + } + assert schema[2] == { + "table_name": "from_url", + "schema": None, + "columns": mock.ANY, + "options": { + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ";", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', + "url": url, + }, } @@ -105,33 +205,90 @@ def test_csv_data_source_s3(local_engine_empty): schema = source.introspect() - assert len(schema.keys()) == 3 - assert schema["fruits.csv"] == [ - TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), - TableColumn( - ordinal=2, - name="timestamp", - pg_type="timestamp without time zone", - is_pk=False, - comment=None, - ), - TableColumn(ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None), - TableColumn(ordinal=4, name="number", pg_type="integer", is_pk=False, comment=None), - TableColumn(ordinal=5, name="bignumber", pg_type="bigint", is_pk=False, comment=None), - TableColumn(ordinal=6, name="vbignumber", pg_type="numeric", is_pk=False, comment=None), - ] - assert schema["encoding-win-1252.csv"] == [ - TableColumn(ordinal=1, name="col_1", pg_type="integer", is_pk=False, comment=None), - TableColumn(ordinal=2, name="DATE", pg_type="character varying", is_pk=False, comment=None), - TableColumn(ordinal=3, name="TEXT", pg_type="character varying", is_pk=False, comment=None), - ] - assert len(schema["rdu-weather-history.csv"]) == 28 + assert len(schema.keys()) == 4 + assert schema["fruits.csv"] == ( + [ + TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), + TableColumn( + ordinal=2, + name="timestamp", + pg_type="timestamp without time zone", + is_pk=False, + comment=None, + ), + TableColumn( + ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None + ), + TableColumn(ordinal=4, name="number", pg_type="integer", is_pk=False, comment=None), + TableColumn(ordinal=5, name="bignumber", pg_type="bigint", is_pk=False, comment=None), + TableColumn(ordinal=6, name="vbignumber", pg_type="numeric", is_pk=False, comment=None), + ], + { + "s3_object": "some_prefix/fruits.csv", + "autodetect_dialect": False, + "autodetect_encoding": False, + "autodetect_header": False, + "delimiter": ",", + "encoding": "utf-8", + "header": True, + "quotechar": '"', + }, + ) + assert schema["encoding-win-1252.csv"] == ( + [ + TableColumn(ordinal=1, name="col_1", pg_type="integer", is_pk=False, comment=None), + TableColumn( + ordinal=2, name="DATE", pg_type="character varying", is_pk=False, comment=None + ), + TableColumn( + ordinal=3, name="TEXT", pg_type="character varying", is_pk=False, comment=None + ), + ], + { + "s3_object": "some_prefix/encoding-win-1252.csv", + "autodetect_dialect": False, + "autodetect_encoding": False, + "autodetect_header": False, + "delimiter": ";", + "encoding": "Windows-1252", + "header": True, + "quotechar": '"', + }, + ) + assert len(schema["rdu-weather-history.csv"][0]) == 28 + + assert schema["not_a_csv.txt"] == MountError( + table_name="not_a_csv.txt", + error="ValueError", + error_text="Malformed CSV: header has 7 columns, rows have 0 columns", + ) + + schema = unwrap(schema)[0] + + # Add a nonexistent file to the schema with malformed params to check preview error reporting + schema["doesnt_exist"] = ( + [], + {"s3_object": "doesnt_exist"}, + ) + schema["exists_but_broken"] = ( + # Force a schema that doesn't work for this CSV + [TableColumn(1, "col_1", "date", False)], + {"s3_object": "some_prefix/fruits.csv"}, + ) preview = source.preview(schema) - assert len(preview.keys()) == 3 + assert len(preview.keys()) == 5 assert len(preview["fruits.csv"]) == 4 assert len(preview["encoding-win-1252.csv"]) == 3 assert len(preview["rdu-weather-history.csv"]) == 10 + assert preview["doesnt_exist"] == MountError( + table_name="doesnt_exist", error="minio.error.S3Error", error_text=mock.ANY + ) + assert preview["exists_but_broken"] == MountError( + table_name="exists_but_broken", + error="psycopg2.errors.InvalidDatetimeFormat", + error_text='invalid input syntax for type date: "1"', + ) try: source.mount("temp_data") @@ -158,6 +315,148 @@ def test_csv_data_source_s3(local_engine_empty): local_engine_empty.delete_schema("temp_data") +def test_csv_data_source_multiple(local_engine_empty): + # End-to-end version for test_csv_introspection_multiple to check things like table params + # getting serialized and deserialized properly. + + url = MINIO.presigned_get_object("test_csv", "some_prefix/rdu-weather-history.csv") + + credentials = { + "s3_access_key": "minioclient", + "s3_secret_key": "supersecure", + } + + params = { + "s3_endpoint": "objectstorage:9000", + "s3_secure": False, + "s3_bucket": "test_csv", + # Put this delimiter in as a canary to make sure table params override server params. + "delimiter": ",", + } + + tables = { + # Pass an empty table schema to denote we want to introspect it + "from_url": ([], {"url": url}), + "from_s3_rdu": ([], {"s3_object": "some_prefix/rdu-weather-history.csv"}), + "from_s3_encoding": ([], {"s3_object": "some_prefix/encoding-win-1252.csv"}), + "from_url_broken": ([], {"url": "invalid_url"}), + "from_s3_broken": ([], {"s3_object": "invalid_object"}), + } + + source = CSVDataSource( + local_engine_empty, + credentials, + params, + tables, + ) + + schema = source.introspect() + + assert schema == { + "from_url": ( + mock.ANY, + { + "autodetect_dialect": False, + "url": url, + "quotechar": '"', + "header": True, + "encoding": "utf-8", + "delimiter": ";", + "autodetect_header": False, + "autodetect_encoding": False, + }, + ), + "from_s3_rdu": ( + mock.ANY, + { + "encoding": "utf-8", + "autodetect_dialect": False, + "autodetect_encoding": False, + "autodetect_header": False, + "delimiter": ";", + "header": True, + "quotechar": '"', + "s3_object": "some_prefix/rdu-weather-history.csv", + }, + ), + "from_s3_encoding": ( + mock.ANY, + { + "s3_object": "some_prefix/encoding-win-1252.csv", + "quotechar": '"', + "header": True, + "encoding": "Windows-1252", + "autodetect_dialect": False, + "delimiter": ";", + "autodetect_header": False, + "autodetect_encoding": False, + }, + ), + "from_url_broken": MountError( + table_name="from_url_broken", + error="requests.exceptions.MissingSchema", + error_text="Invalid URL 'invalid_url': No schema supplied. Perhaps you meant http://invalid_url?", + ), + "from_s3_broken": MountError( + table_name="from_s3_broken", + error="minio.error.S3Error", + error_text=mock.ANY, + ), + } + + # Mount the datasets with this introspected schema. + schema = unwrap(schema)[0] + try: + source.mount("temp_data", tables=schema) + rows = local_engine_empty.run_sql("SELECT * FROM temp_data.from_s3_encoding") + assert len(rows) == 3 + assert len(rows[0]) == 3 + finally: + local_engine_empty.delete_schema("temp_data") + + # Override the delimiter and blank out the schema for a single table + schema["from_s3_encoding"] = ( + [], + { + "s3_object": "some_prefix/encoding-win-1252.csv", + "quotechar": '"', + "header": True, + "encoding": "Windows-1252", + "autodetect_dialect": False, + # We force a delimiter "," here which will make the CSV a single-column one + # (to test we can actually override these) + "delimiter": ",", + "autodetect_header": False, + "autodetect_encoding": False, + }, + ) + + # Reintrospect the source with the new table parameters + source = CSVDataSource(local_engine_empty, credentials, params, schema) + new_schema = source.introspect() + assert len(new_schema) == 3 + # Check other tables are unchanged + assert new_schema["from_url"] == schema["from_url"] + assert new_schema["from_s3_rdu"] == schema["from_s3_rdu"] + + # Table with a changed separator only has one column (since we have , for delimiter + # instead of ;) + assert new_schema["from_s3_encoding"][0] == [ + TableColumn( + ordinal=1, name=";DATE;TEXT", pg_type="character varying", is_pk=False, comment=None + ) + ] + + try: + source.mount("temp_data", tables=new_schema) + rows = local_engine_empty.run_sql("SELECT * FROM temp_data.from_s3_encoding") + assert len(rows) == 3 + # Check we get one column now + assert rows[0] == ("1;01/07/2021;PaƱamao",) + finally: + local_engine_empty.delete_schema("temp_data") + + def test_csv_data_source_http(local_engine_empty): source = CSVDataSource( local_engine_empty, @@ -169,7 +468,7 @@ def test_csv_data_source_http(local_engine_empty): schema = source.introspect() assert len(schema.keys()) == 1 - assert len(schema["data"]) == 28 + assert len(schema["data"][0]) == 28 preview = source.preview(schema) assert len(preview.keys()) == 1 @@ -189,15 +488,9 @@ def test_csv_dialect_encoding_inference(): assert options.encoding == "Windows-1252" assert options.header is True - - # TODO: we keep these in the dialect struct rather than extract back out into the - # CSVOptions. Might need to do the latter if we want to return the proposed FDW table - # params to the user. - - # Note this line terminator is always "\r\n" since CSV assumes we use the - # universal newlines mode. - assert options.dialect.lineterminator == "\r\n" - assert options.dialect.delimiter == ";" + # NB we don't extract everything from the sniffed dialect, just the delimiter and the + # quotechar. The sniffer also returns doublequote and skipinitialspace. + assert options.delimiter == ";" data = list(reader) diff --git a/test/splitgraph/ingestion/test_singer.py b/test/splitgraph/ingestion/test_singer.py index 7c59521b..7af7bb8a 100644 --- a/test/splitgraph/ingestion/test_singer.py +++ b/test/splitgraph/ingestion/test_singer.py @@ -22,7 +22,11 @@ _STARGAZERS_SCHEMA = [ TableColumn( - ordinal=0, name="_sdc_repository", pg_type="character varying", is_pk=False, comment=None, + ordinal=0, + name="_sdc_repository", + pg_type="character varying", + is_pk=False, + comment=None, ), TableColumn( ordinal=1, @@ -38,7 +42,11 @@ _RELEASES_SCHEMA = [ TableColumn( - ordinal=0, name="_sdc_repository", pg_type="character varying", is_pk=False, comment=None, + ordinal=0, + name="_sdc_repository", + pg_type="character varying", + is_pk=False, + comment=None, ), TableColumn(ordinal=1, name="author", pg_type="jsonb", is_pk=False, comment=None), TableColumn(ordinal=2, name="body", pg_type="character varying", is_pk=False, comment=None), @@ -65,7 +73,11 @@ ordinal=10, name="tag_name", pg_type="character varying", is_pk=False, comment=None ), TableColumn( - ordinal=11, name="target_commitish", pg_type="character varying", is_pk=False, comment=None, + ordinal=11, + name="target_commitish", + pg_type="character varying", + is_pk=False, + comment=None, ), TableColumn(ordinal=12, name="url", pg_type="character varying", is_pk=False, comment=None), ] @@ -210,7 +222,9 @@ def test_singer_ingestion_schema_change(local_engine_empty): assert json.loads(result.stdout) == { "bookmarks": { - "splitgraph/splitgraph": {"stargazers": {"since": "2020-10-14T11:06:42.565793Z"},} + "splitgraph/splitgraph": { + "stargazers": {"since": "2020-10-14T11:06:42.565793Z"}, + } } } repo = Repository.from_schema(TEST_REPO) @@ -337,7 +351,7 @@ def test_singer_data_source_introspect(local_engine_empty): ) schema = source.introspect() - assert schema == {"releases": _RELEASES_SCHEMA, "stargazers": _STARGAZERS_SCHEMA} + assert schema == {"releases": (_RELEASES_SCHEMA, {}), "stargazers": (_STARGAZERS_SCHEMA, {})} def test_singer_data_source_sync(local_engine_empty): @@ -390,8 +404,15 @@ def test_singer_data_source_sync(local_engine_empty): def _source(local_engine_empty): return MySQLSingerDataSource( engine=local_engine_empty, - params={"replication_method": "INCREMENTAL", "host": "localhost", "port": 3306,}, - credentials={"user": "originuser", "password": "originpass",}, + params={ + "replication_method": "INCREMENTAL", + "host": "localhost", + "port": 3306, + }, + credentials={ + "user": "originuser", + "password": "originpass", + }, ) @@ -400,20 +421,27 @@ def _source(local_engine_empty): def test_singer_tap_mysql_introspection(local_engine_empty): source = _source(local_engine_empty) assert source.introspect() == { - "mushrooms": [ - TableColumn( - ordinal=0, - name="discovery", - pg_type="timestamp without time zone", - is_pk=False, - comment=None, - ), - TableColumn(ordinal=1, name="friendly", pg_type="boolean", is_pk=False, comment=None), - TableColumn(ordinal=2, name="mushroom_id", pg_type="integer", is_pk=True, comment=None), - TableColumn( - ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], + "mushrooms": ( + [ + TableColumn( + ordinal=0, + name="discovery", + pg_type="timestamp without time zone", + is_pk=False, + comment=None, + ), + TableColumn( + ordinal=1, name="friendly", pg_type="boolean", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="mushroom_id", pg_type="integer", is_pk=True, comment=None + ), + TableColumn( + ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {}, + ) } singer_config = source.get_singer_config() @@ -498,7 +526,10 @@ def test_singer_tap_mysql_introspection(local_engine_empty): }, { "breadcrumb": ["properties", "name"], - "metadata": {"selected-by-default": True, "sql-datatype": "varchar(20)",}, + "metadata": { + "selected-by-default": True, + "sql-datatype": "varchar(20)", + }, }, { "breadcrumb": ["properties", "varbinary_data"], diff --git a/test/splitgraph/ingestion/test_snowflake.py b/test/splitgraph/ingestion/test_snowflake.py index c062e652..7b863a99 100644 --- a/test/splitgraph/ingestion/test_snowflake.py +++ b/test/splitgraph/ingestion/test_snowflake.py @@ -4,6 +4,7 @@ import pytest +from splitgraph.core.types import dict_to_table_schema_params from splitgraph.ingestion.snowflake import SnowflakeDataSource _sample_privkey = """-----BEGIN PRIVATE KEY----- @@ -70,7 +71,10 @@ def test_snowflake_data_source_dburl_conversion_no_warehouse(): "password": "password", "account": "abcdef.eu-west-1.aws", }, - params={"database": "SOME_DB", "schema": "TPCH_SF100",}, + params={ + "database": "SOME_DB", + "schema": "TPCH_SF100", + }, ) assert source.get_server_options() == { @@ -89,7 +93,10 @@ def test_snowflake_data_source_private_key(private_key): "private_key": private_key, "account": "abcdef.eu-west-1.aws", }, - params={"database": "SOME_DB", "schema": "TPCH_SF100",}, + params={ + "database": "SOME_DB", + "schema": "TPCH_SF100", + }, ) opts = source.get_server_options() @@ -114,13 +121,15 @@ def test_snowflake_data_source_table_options(): }, params={ "database": "SOME_DB", - "tables": { + }, + tables=dict_to_table_schema_params( + { "test_table": { "schema": {"col_1": "int", "col_2": "varchar"}, "options": {"subquery": "SELECT col_1, col_2 FROM other_table"}, } - }, - }, + } + ), ) assert source.get_table_options("test_table") == { diff --git a/test/splitgraph/test_misc.py b/test/splitgraph/test_misc.py index ecf888d1..26ca1160 100644 --- a/test/splitgraph/test_misc.py +++ b/test/splitgraph/test_misc.py @@ -9,7 +9,11 @@ from splitgraph.core.engine import lookup_repository from splitgraph.core.metadata_manager import Object from splitgraph.core.repository import Repository -from splitgraph.core.types import TableColumn, tableschema_to_dict, dict_to_tableschema +from splitgraph.core.types import ( + TableColumn, + dict_to_table_schema_params, + table_schema_params_to_dict, +) from splitgraph.engine.postgres.engine import API_MAX_QUERY_LENGTH from splitgraph.exceptions import RepositoryNotFoundError from splitgraph.hooks.s3 import get_object_upload_urls @@ -272,7 +276,12 @@ def test_large_api_calls(unprivileged_pg_repo): unprivileged_pg_repo, [ (image_hash, "small_table", [(1, "key", "integer", True)], [all_ids[0]]), - (image_hash, "table", [(1, "key", "integer", True)], all_ids,), + ( + image_hash, + "table", + [(1, "key", "integer", True)], + all_ids, + ), ], ) @@ -310,51 +319,85 @@ def test_parse_dt(): parse_dt("not a dt") -def test_tableschema_to_dict(): - assert tableschema_to_dict( +def test_table_schema_params_to_dict(): + assert table_schema_params_to_dict( { - "fruits": [ + "fruits": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, + ), + ], + {"key": "value"}, + ), + "vegetables": ( + [ + TableColumn( + ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, + ), + ], + {"key": "value"}, + ), + } + ) == { + "fruits": { + "schema": {"fruit_id": "integer", "name": "character varying"}, + "options": {"key": "value"}, + }, + "vegetables": { + "schema": {"name": "character varying", "vegetable_id": "integer"}, + "options": {"key": "value"}, + }, + } + + +def test_dict_to_table_schema_params(): + assert dict_to_table_schema_params( + { + "fruits": { + "schema": {"fruit_id": "integer", "name": "character varying"}, + "options": {"key": "value"}, + }, + "vegetables": { + "schema": {"name": "character varying", "vegetable_id": "integer"}, + "options": {"key": "value"}, + }, + } + ) == { + "fruits": ( + [ TableColumn( ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None ), TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None, + ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None ), ], - "vegetables": [ + {"key": "value"}, + ), + "vegetables": ( + [ TableColumn( - ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None + ordinal=1, name="name", pg_type="character varying", is_pk=False, comment=None ), TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None, + ordinal=2, name="vegetable_id", pg_type="integer", is_pk=False, comment=None ), ], - } - ) == { - "fruits": {"fruit_id": "integer", "name": "character varying"}, - "vegetables": {"name": "character varying", "vegetable_id": "integer"}, - } - - -def test_dict_to_tableschema(): - assert dict_to_tableschema( - { - "fruits": {"fruit_id": "integer", "name": "character varying"}, - "vegetables": {"name": "character varying", "vegetable_id": "integer"}, - } - ) == { - "fruits": [ - TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), - TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], - "vegetables": [ - TableColumn( - ordinal=1, name="name", pg_type="character varying", is_pk=False, comment=None - ), - TableColumn( - ordinal=2, name="vegetable_id", pg_type="integer", is_pk=False, comment=None - ), - ], + {"key": "value"}, + ), }