From d6fce0812bcfd3ca2d734d3ab6086bd115b76143 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 16:17:42 +0100 Subject: [PATCH 01/16] Rename some parameters for mongo_fdw and mysql_fdw to match the actual fdw_params (breaking change): * mysql: remote_schema -> dbname (MySQL uses `dbname`) * mongo: coll -> collection, db -> database --- examples/cross-db-analytics/mounting/matomo.json | 2 +- examples/dbt_two_databases/README.md | 4 ++-- examples/import-from-mongo/example.yaml | 4 ++-- examples/import-from-mongo/mongo_import.splitfile | 4 ++-- splitgraph/hooks/data_source/fdw.py | 13 ++++++------- test/splitgraph/commandline/test_mount.py | 2 +- test/splitgraph/conftest.py | 2 +- 7 files changed, 15 insertions(+), 16 deletions(-) diff --git a/examples/cross-db-analytics/mounting/matomo.json b/examples/cross-db-analytics/mounting/matomo.json index f3a304e6..52ee55fc 100644 --- a/examples/cross-db-analytics/mounting/matomo.json +++ b/examples/cross-db-analytics/mounting/matomo.json @@ -1,5 +1,5 @@ { - "remote_schema": "matomo", + "dbname": "matomo", "tables": { "matomo_access": { "schema": { diff --git a/examples/dbt_two_databases/README.md b/examples/dbt_two_databases/README.md index a25224bd..51bbdefb 100644 --- a/examples/dbt_two_databases/README.md +++ b/examples/dbt_two_databases/README.md @@ -72,8 +72,8 @@ $ sgr mount mongo_fdw order_data -c originro:originpass@mongo:27017 -o @- <, "coll": } + "options": {"database": , "collection": } } } ``` @@ -469,17 +469,16 @@ class MySQLDataSource(ForeignDataWrapperDataSource): "properties": { "host": {"type": "string"}, "port": {"type": "integer"}, - "remote_schema": {"type": "string"}, - "tables": _table_options_schema, + "dbname": {"type": "string"}, }, - "required": ["host", "port", "remote_schema"], + "required": ["host", "port", "dbname"], } commandline_help: str = """Mount a MySQL database. Mounts a schema on a remote MySQL database as a set of foreign tables locally.""" - commandline_kwargs_help: str = """remote_schema: Remote schema name (required) + commandline_kwargs_help: str = """dbname: Remote MySQL database name (required) tables: Tables to mount (default all). If a list, then will use IMPORT FOREIGN SCHEMA. If a dictionary, must have the format {"table_name": {"schema": {"col_1": "type_1", ...}, @@ -505,13 +504,13 @@ def get_user_options(self): return {"username": self.credentials["username"], "password": self.credentials["password"]} def get_table_options(self, table_name: str): - return {"dbname": self.params["remote_schema"]} + return {"dbname": self.params["dbname"]} def get_fdw_name(self): return "mysql_fdw" def get_remote_schema_name(self) -> str: - return str(self.params["remote_schema"]) + return str(self.params["dbname"]) class ElasticSearchDataSource(ForeignDataWrapperDataSource): diff --git a/test/splitgraph/commandline/test_mount.py b/test/splitgraph/commandline/test_mount.py index 64b971a1..fe62f0af 100644 --- a/test/splitgraph/commandline/test_mount.py +++ b/test/splitgraph/commandline/test_mount.py @@ -23,7 +23,7 @@ _MONGO_PARAMS = { "tables": { "stuff": { - "options": {"db": "origindb", "coll": "stuff",}, + "options": {"database": "origindb", "collection": "stuff",}, "schema": {"name": "text", "duration": "numeric", "happy": "boolean",}, } } diff --git a/test/splitgraph/conftest.py b/test/splitgraph/conftest.py index 48da84ba..965bbe68 100644 --- a/test/splitgraph/conftest.py +++ b/test/splitgraph/conftest.py @@ -120,7 +120,7 @@ def _mount_mysql(repository): port=3306, username="originuser", password="originpass", - remote_schema="mysqlschema", + dbname="mysqlschema", ), tables={ "mushrooms": [ From ad69339eb893abc432df27f8c59fda76ed6bfef3 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 16:36:54 +0100 Subject: [PATCH 02/16] Bunch of breaking changes (from the Py API POV -- commandline sgr mount invocations are unchanged) to the data source class: * Factor the table params out into a separate toplevel parameter, with its own JSONSchema (parameters for each table, e.g. index for ElasticSearch) * As a compat layer with sgr mount / existing invocations, look for `tables` in the current data source params and hoist it up into table params at init time * `TableInfo` type, besides a list of strings (table names), can now be a dict of table name -> (table schema, table params) instead of just table schema. This is accepted throughout various data source methods (mount, introspect, load, sync etc) * Introspection now returns the table schema and the table params (this is a TODO -- need to actually scrape the table params from the FDW options of the mounted tables) --- splitgraph/core/types.py | 25 ++- splitgraph/hooks/data_source/base.py | 48 +++-- splitgraph/hooks/data_source/fdw.py | 218 ++++++++++----------- splitgraph/hooks/mount_handlers.py | 2 +- splitgraph/ingestion/csv/__init__.py | 4 +- splitgraph/ingestion/singer/data_source.py | 10 +- splitgraph/ingestion/snowflake/__init__.py | 4 +- test/splitgraph/commands/test_mounting.py | 52 +++-- test/splitgraph/conftest.py | 23 ++- test/splitgraph/ingestion/test_common.py | 33 ++-- test/splitgraph/ingestion/test_singer.py | 73 +++++-- test/splitgraph/test_misc.py | 71 ++++--- 12 files changed, 328 insertions(+), 235 deletions(-) diff --git a/splitgraph/core/types.py b/splitgraph/core/types.py index 385a2738..ac58224e 100644 --- a/splitgraph/core/types.py +++ b/splitgraph/core/types.py @@ -18,6 +18,14 @@ class TableColumn(NamedTuple): SourcesList = List[Dict[str, str]] ProvenanceLine = Dict[str, Union[str, List[str], List[bool], SourcesList]] +# Ingestion-related params +Credentials = Dict[str, Any] +Params = Dict[str, Any] +TableParams = Dict[str, Any] +TableInfo = Union[List[str], Dict[str, Tuple[TableSchema, TableParams]]] +SyncState = Dict[str, Any] +PreviewResult = Dict[str, Union[str, List[Dict[str, Any]]]] + class Comparable(metaclass=ABCMeta): @abstractmethod @@ -25,13 +33,18 @@ def __lt__(self, other: Any) -> bool: ... -def dict_to_tableschema(tables: Dict[str, Dict[str, Any]]) -> Dict[str, TableSchema]: +def dict_to_table_schema_params( + tables: Dict[str, Dict[str, Any]] +) -> Dict[str, Tuple[TableSchema, TableParams]]: return { - t: [ - TableColumn(i + 1, cname, ctype, False, None) - for (i, (cname, ctype)) in enumerate(ts.items()) - ] - for t, ts in tables.items() + t: ( + [ + TableColumn(i + 1, cname, ctype, False, None) + for (i, (cname, ctype)) in enumerate(tsp["schema"].items()) + ], + tsp.get("options", {}), + ) + for t, tsp in tables.items() } diff --git a/splitgraph/hooks/data_source/base.py b/splitgraph/hooks/data_source/base.py index 037f102a..88a2b9e1 100644 --- a/splitgraph/hooks/data_source/base.py +++ b/splitgraph/hooks/data_source/base.py @@ -1,27 +1,27 @@ -import json from abc import ABC, abstractmethod from random import getrandbits -from typing import Dict, Any, Union, List, Optional, TYPE_CHECKING, cast, Tuple +from typing import Dict, Any, Optional, TYPE_CHECKING, cast, Tuple from psycopg2._json import Json from psycopg2.sql import SQL, Identifier from splitgraph.core.engine import repository_exists from splitgraph.core.image import Image -from splitgraph.core.types import TableSchema, TableColumn +from splitgraph.core.types import ( + TableSchema, + TableColumn, + Credentials, + Params, + TableParams, + TableInfo, + SyncState, +) from splitgraph.engine import ResultShape if TYPE_CHECKING: from splitgraph.engine.postgres.engine import PostgresEngine from splitgraph.core.repository import Repository -Credentials = Dict[str, Any] -Params = Dict[str, Any] -TableInfo = Union[List[str], Dict[str, TableSchema]] -SyncState = Dict[str, Any] -PreviewResult = Dict[str, Union[str, List[Dict[str, Any]]]] - - INGESTION_STATE_TABLE = "_sg_ingestion_state" INGESTION_STATE_SCHEMA = [ TableColumn(1, "timestamp", "timestamp", True, None), @@ -32,6 +32,7 @@ class DataSource(ABC): params_schema: Dict[str, Any] credentials_schema: Dict[str, Any] + table_params_schema: Dict[str, Any] supports_mount = False supports_sync = False @@ -47,19 +48,34 @@ def get_name(cls) -> str: def get_description(cls) -> str: raise NotImplementedError - def __init__(self, engine: "PostgresEngine", credentials: Credentials, params: Params): + def __init__( + self, + engine: "PostgresEngine", + credentials: Credentials, + params: Params, + tables: Optional[TableInfo] = None, + ): import jsonschema self.engine = engine + if "tables" in params: + tables = params.pop("tables") + jsonschema.validate(instance=credentials, schema=self.credentials_schema) jsonschema.validate(instance=params, schema=self.params_schema) self.credentials = credentials self.params = params + if isinstance(tables, dict): + for _, table_params in tables.values(): + jsonschema.validate(instance=table_params, schema=self.table_params_schema) + + self.tables = tables + @abstractmethod - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: # TODO here: dict str -> [tableschema, dict of suggested options] # params -- add table options as a separate field? # separate table schema @@ -78,7 +94,10 @@ class MountableDataSource(DataSource, ABC): @abstractmethod def mount( - self, schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True, + self, + schema: str, + tables: Optional[TableInfo] = None, + overwrite: bool = True, ): """Instantiate the data source as foreign tables in a schema""" raise NotImplementedError @@ -98,7 +117,8 @@ def load(self, repository: "Repository", tables: Optional[TableInfo] = None) -> image_hash = "{:064x}".format(getrandbits(256)) tmp_schema = "{:064x}".format(getrandbits(256)) repository.images.add( - parent_id=None, image=image_hash, + parent_id=None, + image=image_hash, ) repository.object_engine.create_schema(tmp_schema) diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index f1156b9b..52097f85 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -1,17 +1,20 @@ import logging from abc import ABC, abstractmethod from copy import deepcopy -from typing import Optional, Mapping, Dict, List, Any, cast, Union, TYPE_CHECKING +from typing import Optional, Mapping, Dict, List, Any, cast, Union, TYPE_CHECKING, Tuple import psycopg2 from psycopg2.sql import SQL, Identifier -from splitgraph.core.types import dict_to_tableschema, TableSchema, TableColumn -from splitgraph.hooks.data_source.base import ( - Credentials, - Params, +from splitgraph.core.types import ( + TableSchema, + TableColumn, + dict_to_table_schema_params, + TableParams, TableInfo, PreviewResult, +) +from splitgraph.hooks.data_source.base import ( MountableDataSource, LoadableDataSource, ) @@ -20,24 +23,17 @@ from splitgraph.engine.postgres.engine import PostgresEngine -_table_options_schema = { - "type": "object", - "additionalProperties": { - "options": {"type": "object", "additionalProperties": {"type": "string"}}, - }, -} - - class ForeignDataWrapperDataSource(MountableDataSource, LoadableDataSource, ABC): - credentials_schema = { + credentials_schema: Dict[str, Any] = { "type": "object", - "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, - "required": ["username", "password"], } - params_schema = { + params_schema: Dict[str, Any] = { + "type": "object", + } + + table_params_schema: Dict[str, Any] = { "type": "object", - "properties": {"tables": _table_options_schema}, } commandline_help: str = "" @@ -46,15 +42,6 @@ class ForeignDataWrapperDataSource(MountableDataSource, LoadableDataSource, ABC) supports_mount = True supports_load = True - def __init__( - self, - engine: "PostgresEngine", - credentials: Optional[Credentials] = None, - params: Optional[Params] = None, - ): - self.tables: Optional[TableInfo] = None - super().__init__(engine, credentials or {}, params or {}) - @classmethod def from_commandline(cls, engine, commandline_kwargs) -> "ForeignDataWrapperDataSource": """Instantiate an FDW data source from commandline arguments.""" @@ -64,28 +51,19 @@ def from_commandline(cls, engine, commandline_kwargs) -> "ForeignDataWrapperData # By convention, the "tables" object can be: # * A list of tables to import - # * A dictionary table_name -> {"schema": schema, **mount_options} - table_kwargs = params.pop("tables", None) - tables: Optional[TableInfo] - - if isinstance(table_kwargs, dict): - tables = dict_to_tableschema( - {t: to["schema"] for t, to in table_kwargs.items() if "schema" in to} - ) - params["tables"] = { - t: {"options": to["options"]} for t, to in table_kwargs.items() if "options" in to - } - elif isinstance(table_kwargs, list): - tables = table_kwargs - else: - tables = None + # * A dictionary table_name -> {"schema": schema, "options": table options} + tables = params.pop("tables", None) + if isinstance(tables, dict): + tables = dict_to_table_schema_params(tables) + # Extract credentials from the cmdline params credentials: Dict[str, Any] = {} for k in cast(Dict[str, Any], cls.credentials_schema["properties"]).keys(): if k in params: credentials[k] = params[k] - result = cls(engine, credentials, params) - result.tables = tables + + result = cls(engine, credentials, params, tables) + return result def get_user_options(self) -> Mapping[str, str]: @@ -94,10 +72,15 @@ def get_user_options(self) -> Mapping[str, str]: def get_server_options(self) -> Mapping[str, str]: return {} - def get_table_options(self, table_name: str) -> Mapping[str, str]: + def get_table_options(self, table_name: str) -> Dict[str, str]: + if not isinstance(self.tables, dict): + return {} + return { k: str(v) - for k, v in self.params.get("tables", {}).get(table_name, {}).get("options", {}).items() + for k, v in self.tables.get( + table_name, cast(Tuple[TableSchema, TableParams], ({}, {})) + )[1].items() } def get_table_schema(self, table_name: str, table_schema: TableSchema) -> TableSchema: @@ -113,7 +96,10 @@ def get_fdw_name(self): pass def mount( - self, schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True, + self, + schema: str, + tables: Optional[TableInfo] = None, + overwrite: bool = True, ): tables = tables or self.tables or [] @@ -133,7 +119,7 @@ def mount( self._create_foreign_tables(schema, server_id, tables) - def _create_foreign_tables(self, schema, server_id, tables): + def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo): if isinstance(tables, list): try: remote_schema = self.get_remote_schema_name() @@ -143,7 +129,7 @@ def _create_foreign_tables(self, schema, server_id, tables): ) _import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) else: - for table_name, table_schema in tables.items(): + for table_name, (table_schema, _) in tables.items(): logging.info("Mounting table %s", table_name) query, args = create_foreign_table( schema, @@ -154,7 +140,7 @@ def _create_foreign_tables(self, schema, server_id, tables): ) self.engine.run_sql(query, args) - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: # Ability to override introspection by e.g. contacting the remote database without having # to mount the actual table. By default, just call out into mount() @@ -165,7 +151,8 @@ def introspect(self) -> Dict[str, TableSchema]: try: self.mount(tmp_schema) result = { - t: self.engine.get_full_table_schema(tmp_schema, t) + # TODO extract table_params from the mount + t: (self.engine.get_full_table_schema(tmp_schema, t), cast(TableParams, {})) for t in self.engine.get_all_tables(tmp_schema) } return result @@ -189,7 +176,7 @@ def _preview_table(self, schema: str, table: str, limit: int = 10) -> List[Dict[ ) return result_json - def preview(self, schema: Dict[str, TableSchema]) -> PreviewResult: + def preview(self, tables: Optional[TableInfo]) -> PreviewResult: # Preview data in tables mounted by this FDW / data source # Local import here since this data source gets imported by the commandline entry point @@ -197,7 +184,7 @@ def preview(self, schema: Dict[str, TableSchema]) -> PreviewResult: tmp_schema = get_temporary_table_id() try: - self.mount(tmp_schema, tables=schema) + self.mount(tmp_schema, tables=tables) # Commit so that errors don't cancel the mount self.engine.commit() result: Dict[str, Union[str, List[Dict[str, Any]]]] = {} @@ -324,7 +311,9 @@ def create_foreign_table( for col in schema_spec: if col.comment: query += SQL("COMMENT ON COLUMN {}.{}.{} IS %s;").format( - Identifier(schema), Identifier(table_name), Identifier(col.name), + Identifier(schema), + Identifier(table_name), + Identifier(col.name), ) args.append(col.comment) return query, args @@ -342,6 +331,12 @@ def get_description(cls) -> str: "based on postgres_fdw" ) + credentials_schema = { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + "required": ["username", "password"], + } + params_schema = { "type": "object", "properties": { @@ -349,11 +344,12 @@ def get_description(cls) -> str: "port": {"type": "integer", "description": "Port"}, "dbname": {"type": "string", "description": "Database name"}, "remote_schema": {"type": "string", "description": "Remote schema name"}, - "tables": _table_options_schema, }, "required": ["host", "port", "dbname", "remote_schema"], } + table_params_schema = {"type": "object"} + commandline_help: str = """Mount a Postgres database. Mounts a schema on a remote Postgres database as a set of foreign tables locally.""" @@ -389,27 +385,28 @@ def get_remote_schema_name(self) -> str: class MongoDataSource(ForeignDataWrapperDataSource): + credentials_schema = { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + "required": ["username", "password"], + } + params_schema = { "type": "object", "properties": { "host": {"type": "string"}, "port": {"type": "integer"}, - "tables": { - "type": "object", - "additionalProperties": { - "options": { - "type": "object", - "properties": { - "db": {"type": "string"}, - "coll": {"type": "string"}, - "required": ["db", "coll"], - }, - }, - "required": ["options"], - }, - }, }, - "required": ["host", "port", "tables"], + "required": ["host", "port"], + } + + table_params_schema = { + "type": "object", + "properties": { + "database": {"type": "string"}, + "collection": {"type": "string"}, + }, + "required": ["database", "collection"], } commandline_help = """Mount a Mongo database. @@ -444,14 +441,6 @@ def get_server_options(self): def get_user_options(self): return {"username": self.credentials["username"], "password": self.credentials["password"]} - def get_table_options(self, table_name: str): - try: - table_params = self.params["tables"][table_name]["options"] - except KeyError: - raise ValueError("No options specified for table %s!" % table_name) - - return {"database": table_params["db"], "collection": table_params["coll"]} - def get_fdw_name(self): return "mongo_fdw" @@ -464,6 +453,12 @@ def get_table_schema(self, table_name, table_schema): class MySQLDataSource(ForeignDataWrapperDataSource): + credentials_schema = { + "type": "object", + "properties": {"username": {"type": "string"}, "password": {"type": "string"}}, + "required": ["username", "password"], + } + params_schema = { "type": "object", "properties": { @@ -491,7 +486,7 @@ def get_name(cls) -> str: @classmethod def get_description(cls) -> str: - return "Data source for MySQL databases that supports live querying, " "based on mysql_fdw" + return "Data source for MySQL databases that supports live querying, based on mysql_fdw" def get_server_options(self): return { @@ -527,42 +522,39 @@ class ElasticSearchDataSource(ForeignDataWrapperDataSource): "properties": { "host": {"type": "string"}, "port": {"type": "integer"}, - "tables": { - "type": "object", - "additionalProperties": { - "options": { - "type": "object", - "properties": { - "index": { - "type": "string", - "description": 'ES index name or pattern to use, for example, "events-*"', - }, - "type": { - "type": "string", - "description": "Pre-ES7 doc_type, not required in ES7 or later", - }, - "query_column": { - "type": "string", - "description": "Name of the column to use to pass queries in", - }, - "score_column": { - "type": "string", - "description": "Name of the column with the document score", - }, - "scroll_size": { - "type": "integer", - "description": "Fetch size, default 1000", - }, - "scroll_duration": { - "type": "string", - "description": "How long to hold the scroll context open for, default 10m", - }, - }, - }, - }, + }, + "required": ["host", "port"], + } + + table_params_schema = { + "type": "object", + "properties": { + "index": { + "type": "string", + "description": 'ES index name or pattern to use, for example, "events-*"', + }, + "type": { + "type": "string", + "description": "Pre-ES7 doc_type, not required in ES7 or later", + }, + "query_column": { + "type": "string", + "description": "Name of the column to use to pass queries in", + }, + "score_column": { + "type": "string", + "description": "Name of the column with the document score", + }, + "scroll_size": { + "type": "integer", + "description": "Fetch size, default 1000", + }, + "scroll_duration": { + "type": "string", + "description": "How long to hold the scroll context open for, default 10m", }, }, - "required": ["host", "port", "tables"], + "required": ["index"], } commandline_help = """Mount an ElasticSearch instance. diff --git a/splitgraph/hooks/mount_handlers.py b/splitgraph/hooks/mount_handlers.py index 51ab60a6..023cb07a 100644 --- a/splitgraph/hooks/mount_handlers.py +++ b/splitgraph/hooks/mount_handlers.py @@ -6,7 +6,7 @@ from splitgraph.exceptions import DataSourceError if TYPE_CHECKING: - from splitgraph.hooks.data_source.base import TableInfo + from splitgraph.core.types import TableInfo def mount_postgres(mountpoint, **kwargs) -> None: diff --git a/splitgraph/ingestion/csv/__init__.py b/splitgraph/ingestion/csv/__init__.py index a3480226..7b820088 100644 --- a/splitgraph/ingestion/csv/__init__.py +++ b/splitgraph/ingestion/csv/__init__.py @@ -173,8 +173,8 @@ def from_commandline(cls, engine, commandline_kwargs) -> "CSVDataSource": credentials[k] = params[k] return cls(engine, credentials, params) - def get_table_options(self, table_name: str) -> Mapping[str, str]: - result = cast(Dict[str, str], super().get_table_options(table_name)) + def get_table_options(self, table_name: str) -> Dict[str, str]: + result = super().get_table_options(table_name) result["s3_object"] = result.get( "s3_object", self.params.get("s3_object_prefix", "") + table_name ) diff --git a/splitgraph/ingestion/singer/data_source.py b/splitgraph/ingestion/singer/data_source.py index 760bc21c..63df025e 100644 --- a/splitgraph/ingestion/singer/data_source.py +++ b/splitgraph/ingestion/singer/data_source.py @@ -7,21 +7,19 @@ from contextlib import contextmanager from io import StringIO from threading import Thread -from typing import Dict, Any, Optional, cast +from typing import Dict, Any, Optional, cast, Tuple from psycopg2.sql import Identifier, SQL from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema +from splitgraph.core.types import TableSchema, TableParams, TableInfo, SyncState from splitgraph.exceptions import DataSourceError from splitgraph.hooks.data_source.base import ( get_ingestion_state, INGESTION_STATE_TABLE, INGESTION_STATE_SCHEMA, - TableInfo, prepare_new_image, SyncableDataSource, - SyncState, ) from splitgraph.ingestion.singer.db_sync import ( get_table_name, @@ -197,7 +195,7 @@ def build_singer_catalog( catalog, tables=tables, use_legacy_stream_selection=self.use_legacy_stream_selection ) - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: config = self.get_singer_config() singer_schema = self._run_singer_discovery(config) @@ -205,7 +203,7 @@ def introspect(self) -> Dict[str, TableSchema]: for stream in singer_schema["streams"]: stream_name = get_table_name(stream) stream_schema = get_sg_schema(stream) - result[stream_name] = stream_schema + result[stream_name] = (stream_schema, cast(TableParams, {})) return result diff --git a/splitgraph/ingestion/snowflake/__init__.py b/splitgraph/ingestion/snowflake/__init__.py index ac3c2b27..7afe230e 100644 --- a/splitgraph/ingestion/snowflake/__init__.py +++ b/splitgraph/ingestion/snowflake/__init__.py @@ -135,8 +135,8 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Schema, table or a subquery from a Snowflake database" - def get_table_options(self, table_name: str) -> Mapping[str, str]: - result = cast(Dict[str, str], super().get_table_options(table_name)) + def get_table_options(self, table_name: str) -> Dict[str, str]: + result = super().get_table_options(table_name) result["tablename"] = result.get("tablename", table_name) return result diff --git a/test/splitgraph/commands/test_mounting.py b/test/splitgraph/commands/test_mounting.py index cd5de8a7..0606ee87 100644 --- a/test/splitgraph/commands/test_mounting.py +++ b/test/splitgraph/commands/test_mounting.py @@ -63,26 +63,34 @@ def test_mount_introspection_preview(local_engine_empty): params={"host": "pgorigin", "port": 5432, "dbname": "origindb", "remote_schema": "public"}, ) - schema = handler.introspect() - - assert schema == { - "fruits": [ - TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), - TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], - "vegetables": [ - TableColumn( - ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None - ), - TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], + tables = handler.introspect() + + assert tables == { + "fruits": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {}, + ), + "vegetables": ( + [ + TableColumn( + ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {}, + ), } - preview = handler.preview(schema=schema) + preview = handler.preview(tables=tables) assert preview == { "fruits": [{"fruit_id": 1, "name": "apple"}, {"fruit_id": 2, "name": "orange"}], "vegetables": [ @@ -131,9 +139,11 @@ def test_mount_elasticsearch(local_engine_empty): "col_1": "text", "col_2": "boolean", }, - "index": "index-pattern*", - "rowid_column": "id", - "query_column": "query", + "options": { + "index": "index-pattern*", + "rowid_column": "id", + "query_column": "query", + }, } }, ), diff --git a/test/splitgraph/conftest.py b/test/splitgraph/conftest.py index 965bbe68..6fbf8e1d 100644 --- a/test/splitgraph/conftest.py +++ b/test/splitgraph/conftest.py @@ -98,8 +98,8 @@ def _mount_mongo(repository): tables=dict( stuff={ "options": { - "db": "origindb", - "coll": "stuff", + "database": "origindb", + "collection": "stuff", }, "schema": {"name": "text", "duration": "numeric", "happy": "boolean"}, }, @@ -123,14 +123,17 @@ def _mount_mysql(repository): dbname="mysqlschema", ), tables={ - "mushrooms": [ - TableColumn(1, "mushroom_id", "integer", False), - TableColumn(2, "name", "character varying (20)", False), - TableColumn(3, "discovery", "timestamp", False), - TableColumn(4, "friendly", "boolean", False), - TableColumn(5, "binary_data", "bytea", False), - TableColumn(6, "varbinary_data", "bytea", False), - ] + "mushrooms": ( + [ + TableColumn(1, "mushroom_id", "integer", False), + TableColumn(2, "name", "character varying (20)", False), + TableColumn(3, "discovery", "timestamp", False), + TableColumn(4, "friendly", "boolean", False), + TableColumn(5, "binary_data", "bytea", False), + TableColumn(6, "varbinary_data", "bytea", False), + ], + {}, + ) }, ) diff --git a/test/splitgraph/ingestion/test_common.py b/test/splitgraph/ingestion/test_common.py index d5772f7d..dbf46c2c 100644 --- a/test/splitgraph/ingestion/test_common.py +++ b/test/splitgraph/ingestion/test_common.py @@ -1,15 +1,18 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Tuple from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema, TableColumn +from splitgraph.core.types import TableSchema, TableColumn, TableInfo, SyncState, TableParams from splitgraph.engine import ResultShape -from splitgraph.hooks.data_source.base import SyncState, TableInfo, SyncableDataSource - -SCHEMA = { - "test_table": [ - TableColumn(1, "key", "integer", True), - TableColumn(2, "value", "character varying", False), - ] +from splitgraph.hooks.data_source.base import SyncableDataSource + +SCHEMA: Dict[str, Tuple[TableSchema, TableParams]] = { + "test_table": ( + [ + TableColumn(1, "key", "integer", True), + TableColumn(2, "value", "character varying", False), + ], + {}, + ) } TEST_REPO = "test/generic_sync" @@ -26,26 +29,22 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Test ingestion" - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: return SCHEMA def _sync( self, schema: str, state: Optional[SyncState] = None, tables: Optional[TableInfo] = None ) -> SyncState: if not self.engine.table_exists(schema, "test_table"): - self.engine.create_table(schema, "test_table", SCHEMA["test_table"]) + self.engine.create_table(schema, "test_table", SCHEMA["test_table"][0]) if not state: - self.engine.run_sql_in( - schema, "INSERT INTO test_table (key, value) " "VALUES (1, 'one')" - ) + self.engine.run_sql_in(schema, "INSERT INTO test_table (key, value) VALUES (1, 'one')") return {"last_value": 1} else: last_value = state["last_value"] assert last_value == 1 - self.engine.run_sql_in( - schema, "INSERT INTO test_table (key, value) " "VALUES (2, 'two')" - ) + self.engine.run_sql_in(schema, "INSERT INTO test_table (key, value) VALUES (2, 'two')") return {"last_value": 2} diff --git a/test/splitgraph/ingestion/test_singer.py b/test/splitgraph/ingestion/test_singer.py index 7c59521b..fd5c82d1 100644 --- a/test/splitgraph/ingestion/test_singer.py +++ b/test/splitgraph/ingestion/test_singer.py @@ -22,7 +22,11 @@ _STARGAZERS_SCHEMA = [ TableColumn( - ordinal=0, name="_sdc_repository", pg_type="character varying", is_pk=False, comment=None, + ordinal=0, + name="_sdc_repository", + pg_type="character varying", + is_pk=False, + comment=None, ), TableColumn( ordinal=1, @@ -38,7 +42,11 @@ _RELEASES_SCHEMA = [ TableColumn( - ordinal=0, name="_sdc_repository", pg_type="character varying", is_pk=False, comment=None, + ordinal=0, + name="_sdc_repository", + pg_type="character varying", + is_pk=False, + comment=None, ), TableColumn(ordinal=1, name="author", pg_type="jsonb", is_pk=False, comment=None), TableColumn(ordinal=2, name="body", pg_type="character varying", is_pk=False, comment=None), @@ -65,7 +73,11 @@ ordinal=10, name="tag_name", pg_type="character varying", is_pk=False, comment=None ), TableColumn( - ordinal=11, name="target_commitish", pg_type="character varying", is_pk=False, comment=None, + ordinal=11, + name="target_commitish", + pg_type="character varying", + is_pk=False, + comment=None, ), TableColumn(ordinal=12, name="url", pg_type="character varying", is_pk=False, comment=None), ] @@ -210,7 +222,9 @@ def test_singer_ingestion_schema_change(local_engine_empty): assert json.loads(result.stdout) == { "bookmarks": { - "splitgraph/splitgraph": {"stargazers": {"since": "2020-10-14T11:06:42.565793Z"},} + "splitgraph/splitgraph": { + "stargazers": {"since": "2020-10-14T11:06:42.565793Z"}, + } } } repo = Repository.from_schema(TEST_REPO) @@ -390,8 +404,15 @@ def test_singer_data_source_sync(local_engine_empty): def _source(local_engine_empty): return MySQLSingerDataSource( engine=local_engine_empty, - params={"replication_method": "INCREMENTAL", "host": "localhost", "port": 3306,}, - credentials={"user": "originuser", "password": "originpass",}, + params={ + "replication_method": "INCREMENTAL", + "host": "localhost", + "port": 3306, + }, + credentials={ + "user": "originuser", + "password": "originpass", + }, ) @@ -400,20 +421,27 @@ def _source(local_engine_empty): def test_singer_tap_mysql_introspection(local_engine_empty): source = _source(local_engine_empty) assert source.introspect() == { - "mushrooms": [ - TableColumn( - ordinal=0, - name="discovery", - pg_type="timestamp without time zone", - is_pk=False, - comment=None, - ), - TableColumn(ordinal=1, name="friendly", pg_type="boolean", is_pk=False, comment=None), - TableColumn(ordinal=2, name="mushroom_id", pg_type="integer", is_pk=True, comment=None), - TableColumn( - ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], + "mushrooms": ( + [ + TableColumn( + ordinal=0, + name="discovery", + pg_type="timestamp without time zone", + is_pk=False, + comment=None, + ), + TableColumn( + ordinal=1, name="friendly", pg_type="boolean", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="mushroom_id", pg_type="integer", is_pk=True, comment=None + ), + TableColumn( + ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {}, + ) } singer_config = source.get_singer_config() @@ -498,7 +526,10 @@ def test_singer_tap_mysql_introspection(local_engine_empty): }, { "breadcrumb": ["properties", "name"], - "metadata": {"selected-by-default": True, "sql-datatype": "varchar(20)",}, + "metadata": { + "selected-by-default": True, + "sql-datatype": "varchar(20)", + }, }, { "breadcrumb": ["properties", "varbinary_data"], diff --git a/test/splitgraph/test_misc.py b/test/splitgraph/test_misc.py index ecf888d1..7ff7c83b 100644 --- a/test/splitgraph/test_misc.py +++ b/test/splitgraph/test_misc.py @@ -9,7 +9,7 @@ from splitgraph.core.engine import lookup_repository from splitgraph.core.metadata_manager import Object from splitgraph.core.repository import Repository -from splitgraph.core.types import TableColumn, tableschema_to_dict, dict_to_tableschema +from splitgraph.core.types import TableColumn, tableschema_to_dict, dict_to_table_schema_params from splitgraph.engine.postgres.engine import API_MAX_QUERY_LENGTH from splitgraph.exceptions import RepositoryNotFoundError from splitgraph.hooks.s3 import get_object_upload_urls @@ -272,7 +272,12 @@ def test_large_api_calls(unprivileged_pg_repo): unprivileged_pg_repo, [ (image_hash, "small_table", [(1, "key", "integer", True)], [all_ids[0]]), - (image_hash, "table", [(1, "key", "integer", True)], all_ids,), + ( + image_hash, + "table", + [(1, "key", "integer", True)], + all_ids, + ), ], ) @@ -318,7 +323,11 @@ def test_tableschema_to_dict(): ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None ), TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None, + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, ), ], "vegetables": [ @@ -326,7 +335,11 @@ def test_tableschema_to_dict(): ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None ), TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None, + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, ), ], } @@ -336,25 +349,39 @@ def test_tableschema_to_dict(): } -def test_dict_to_tableschema(): - assert dict_to_tableschema( +def test_dict_to_table_schema_params(): + assert dict_to_table_schema_params( { - "fruits": {"fruit_id": "integer", "name": "character varying"}, - "vegetables": {"name": "character varying", "vegetable_id": "integer"}, + "fruits": { + "schema": {"fruit_id": "integer", "name": "character varying"}, + "options": {"key": "value"}, + }, + "vegetables": { + "schema": {"name": "character varying", "vegetable_id": "integer"}, + "options": {"key": "value"}, + }, } ) == { - "fruits": [ - TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), - TableColumn( - ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None - ), - ], - "vegetables": [ - TableColumn( - ordinal=1, name="name", pg_type="character varying", is_pk=False, comment=None - ), - TableColumn( - ordinal=2, name="vegetable_id", pg_type="integer", is_pk=False, comment=None - ), - ], + "fruits": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None + ), + ], + {"key": "value"}, + ), + "vegetables": ( + [ + TableColumn( + ordinal=1, name="name", pg_type="character varying", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, name="vegetable_id", pg_type="integer", is_pk=False, comment=None + ), + ], + {"key": "value"}, + ), } From c76c06e3cd8f80838a4b170c46e8d585a151146b Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 16:48:03 +0100 Subject: [PATCH 03/16] Fix the Snowflake adapter (extract table options out as well) --- splitgraph/ingestion/snowflake/__init__.py | 15 +++++++++------ test/splitgraph/ingestion/test_snowflake.py | 19 ++++++++++++++----- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/splitgraph/ingestion/snowflake/__init__.py b/splitgraph/ingestion/snowflake/__init__.py index 7afe230e..a357df70 100644 --- a/splitgraph/ingestion/snowflake/__init__.py +++ b/splitgraph/ingestion/snowflake/__init__.py @@ -50,12 +50,6 @@ class SnowflakeDataSource(ForeignDataWrapperDataSource): params_schema = { "type": "object", "properties": { - "tables": { - "type": "object", - "additionalProperties": { - "options": {"type": "object", "additionalProperties": {"type": "string"}}, - }, - }, "database": {"type": "string", "description": "Snowflake database name"}, "schema": {"type": "string", "description": "Snowflake schema"}, "warehouse": {"type": "string", "description": "Warehouse name"}, @@ -72,6 +66,15 @@ class SnowflakeDataSource(ForeignDataWrapperDataSource): "required": ["database"], } + table_params_schema = { + "type": "object", + "properties": { + "subquery": { + "type": "string", + "description": "Subquery for this table to run on the server side", + } + }, + } supports_mount = True supports_load = True supports_sync = False diff --git a/test/splitgraph/ingestion/test_snowflake.py b/test/splitgraph/ingestion/test_snowflake.py index c062e652..7b863a99 100644 --- a/test/splitgraph/ingestion/test_snowflake.py +++ b/test/splitgraph/ingestion/test_snowflake.py @@ -4,6 +4,7 @@ import pytest +from splitgraph.core.types import dict_to_table_schema_params from splitgraph.ingestion.snowflake import SnowflakeDataSource _sample_privkey = """-----BEGIN PRIVATE KEY----- @@ -70,7 +71,10 @@ def test_snowflake_data_source_dburl_conversion_no_warehouse(): "password": "password", "account": "abcdef.eu-west-1.aws", }, - params={"database": "SOME_DB", "schema": "TPCH_SF100",}, + params={ + "database": "SOME_DB", + "schema": "TPCH_SF100", + }, ) assert source.get_server_options() == { @@ -89,7 +93,10 @@ def test_snowflake_data_source_private_key(private_key): "private_key": private_key, "account": "abcdef.eu-west-1.aws", }, - params={"database": "SOME_DB", "schema": "TPCH_SF100",}, + params={ + "database": "SOME_DB", + "schema": "TPCH_SF100", + }, ) opts = source.get_server_options() @@ -114,13 +121,15 @@ def test_snowflake_data_source_table_options(): }, params={ "database": "SOME_DB", - "tables": { + }, + tables=dict_to_table_schema_params( + { "test_table": { "schema": {"col_1": "int", "col_2": "varchar"}, "options": {"subquery": "SELECT col_1, col_2 FROM other_table"}, } - }, - }, + } + ), ) assert source.get_table_options("test_table") == { From ce7ad51cd0d95ad3cb822d7437b5fbda93fd6030 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 12 Apr 2021 11:01:02 +0100 Subject: [PATCH 04/16] Fix CSV / Singer ingestion tests --- test/splitgraph/ingestion/test_csv.py | 60 ++++++++++++++---------- test/splitgraph/ingestion/test_singer.py | 2 +- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/test/splitgraph/ingestion/test_csv.py b/test/splitgraph/ingestion/test_csv.py index 51a79945..a7401043 100644 --- a/test/splitgraph/ingestion/test_csv.py +++ b/test/splitgraph/ingestion/test_csv.py @@ -57,10 +57,6 @@ def test_csv_introspection_s3(): assert schema[2]["table_name"] == "rdu-weather-history.csv" assert schema[2]["columns"][0] == {"column_name": "date", "type_name": "date"} - # TODO we need a way to pass suggested table options in the inference / preview response, - # since we need to somehow decouple the table name from the S3 object name and/or customize - # delimiter/quotechar - def test_csv_introspection_http(): # Pre-sign the S3 URL for an easy HTTP URL to test this @@ -105,27 +101,39 @@ def test_csv_data_source_s3(local_engine_empty): schema = source.introspect() - assert len(schema.keys()) == 3 - assert schema["fruits.csv"] == [ - TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), - TableColumn( - ordinal=2, - name="timestamp", - pg_type="timestamp without time zone", - is_pk=False, - comment=None, - ), - TableColumn(ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None), - TableColumn(ordinal=4, name="number", pg_type="integer", is_pk=False, comment=None), - TableColumn(ordinal=5, name="bignumber", pg_type="bigint", is_pk=False, comment=None), - TableColumn(ordinal=6, name="vbignumber", pg_type="numeric", is_pk=False, comment=None), - ] - assert schema["encoding-win-1252.csv"] == [ - TableColumn(ordinal=1, name="col_1", pg_type="integer", is_pk=False, comment=None), - TableColumn(ordinal=2, name="DATE", pg_type="character varying", is_pk=False, comment=None), - TableColumn(ordinal=3, name="TEXT", pg_type="character varying", is_pk=False, comment=None), - ] - assert len(schema["rdu-weather-history.csv"]) == 28 + assert len(schema.keys()) == 4 + assert schema["fruits.csv"] == ( + [ + TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), + TableColumn( + ordinal=2, + name="timestamp", + pg_type="timestamp without time zone", + is_pk=False, + comment=None, + ), + TableColumn( + ordinal=3, name="name", pg_type="character varying", is_pk=False, comment=None + ), + TableColumn(ordinal=4, name="number", pg_type="integer", is_pk=False, comment=None), + TableColumn(ordinal=5, name="bignumber", pg_type="bigint", is_pk=False, comment=None), + TableColumn(ordinal=6, name="vbignumber", pg_type="numeric", is_pk=False, comment=None), + ], + {}, + ) + assert schema["encoding-win-1252.csv"] == ( + [ + TableColumn(ordinal=1, name="col_1", pg_type="integer", is_pk=False, comment=None), + TableColumn( + ordinal=2, name="DATE", pg_type="character varying", is_pk=False, comment=None + ), + TableColumn( + ordinal=3, name="TEXT", pg_type="character varying", is_pk=False, comment=None + ), + ], + {}, + ) + assert len(schema["rdu-weather-history.csv"][0]) == 28 preview = source.preview(schema) assert len(preview.keys()) == 3 @@ -169,7 +177,7 @@ def test_csv_data_source_http(local_engine_empty): schema = source.introspect() assert len(schema.keys()) == 1 - assert len(schema["data"]) == 28 + assert len(schema["data"][0]) == 28 preview = source.preview(schema) assert len(preview.keys()) == 1 diff --git a/test/splitgraph/ingestion/test_singer.py b/test/splitgraph/ingestion/test_singer.py index fd5c82d1..7af7bb8a 100644 --- a/test/splitgraph/ingestion/test_singer.py +++ b/test/splitgraph/ingestion/test_singer.py @@ -351,7 +351,7 @@ def test_singer_data_source_introspect(local_engine_empty): ) schema = source.introspect() - assert schema == {"releases": _RELEASES_SCHEMA, "stargazers": _STARGAZERS_SCHEMA} + assert schema == {"releases": (_RELEASES_SCHEMA, {}), "stargazers": (_STARGAZERS_SCHEMA, {})} def test_singer_data_source_sync(local_engine_empty): From e8e0e78f203205ce1b64488f0ab446f3a760643a Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 17:07:00 +0100 Subject: [PATCH 05/16] Fix Socrata data source --- splitgraph/ingestion/socrata/mount.py | 43 +++++++++++++++++++-------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/splitgraph/ingestion/socrata/mount.py b/splitgraph/ingestion/socrata/mount.py index 1d3f04eb..7ed3ce95 100644 --- a/splitgraph/ingestion/socrata/mount.py +++ b/splitgraph/ingestion/socrata/mount.py @@ -6,6 +6,7 @@ from psycopg2.sql import SQL, Identifier +from splitgraph.core.types import TableInfo from splitgraph.exceptions import RepositoryNotFoundError from splitgraph.hooks.data_source.fdw import create_foreign_table, ForeignDataWrapperDataSource @@ -29,17 +30,17 @@ class SocrataDataSource(ForeignDataWrapperDataSource): "type": "integer", "description": "Amount of rows to fetch from Socrata per request (limit parameter). Maximum 50000.", }, - "tables": { - "type": "object", - }, }, "required": ["domain"], } - """ - tables: A dictionary mapping PostgreSQL table names to Socrata table IDs. For example, - {"salaries": "xzkq-xp2w"}. If skipped, ALL tables in the Socrata endpoint will be mounted. - """ + table_params_schema = { + "type": "object", + "properties": { + "socrata_id": {"type": "string", "description": "Socrata dataset ID, e.g. xzkq-xp2w"} + }, + "required": ["socrata_id"], + } @classmethod def get_name(cls) -> str: @@ -55,8 +56,19 @@ def get_fdw_name(self): @classmethod def from_commandline(cls, engine, commandline_kwargs) -> "SocrataDataSource": params = deepcopy(commandline_kwargs) + + # Convert the old-style "tables" param ({"table_name": "some_id"}) + # to the schema this data source expects + # {"table_name": [table schema], {"socrata_id": "some_id"}} + # Note that we don't actually care about the table schema in this data source, + # since it reintrospects on every mount. + + tables = params.pop("tables", []) + if isinstance(tables, dict) and isinstance(next(iter(tables.values())), str): + tables = {k: ([], {"socrata_id": v}) for k, v in tables.items()} + credentials = {"app_token": params.pop("app_token", None)} - return cls(engine, credentials, params) + return cls(engine, credentials, params, tables) def get_server_options(self): options: Dict[str, Optional[str]] = { @@ -75,8 +87,12 @@ def _create_foreign_tables(self, schema, server_id, tables): logging.info("Getting Socrata metadata") client = Socrata(domain=self.params["domain"], app_token=self.credentials.get("app_token")) - tables = self.params.get("tables") - sought_ids = tables.values() if tables else [] + + tables = tables or self.tables + if isinstance(tables, list): + sought_ids = tables + else: + sought_ids = [t[1]["socrata_id"] for t in tables.values()] try: datasets = client.datasets(ids=sought_ids, only=["dataset"]) @@ -98,7 +114,7 @@ def _create_foreign_tables(self, schema, server_id, tables): self.engine.run_sql(SQL(";").join(mount_statements), mount_args) -def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, tables): +def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, tables: TableInfo): # Local imports since this module gets run from commandline entrypoint on startup. from splitgraph.core.output import slugify @@ -109,7 +125,7 @@ def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, found_ids = set(d["resource"]["id"] for d in datasets) logging.info("Loaded metadata for %s", pluralise("Socrata table", len(found_ids))) - if tables: + if isinstance(tables, (dict, list)): missing_ids = [d for d in found_ids if d not in sought_ids] if missing_ids: raise ValueError( @@ -117,7 +133,8 @@ def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, % truncate_list(missing_ids) ) - tables_inv = {s: p for p, s in tables.items()} + if isinstance(tables, dict): + tables_inv = {to["socrata_id"]: p for p, (ts, to) in tables.items()} else: tables_inv = {} From 1a9fd3a9e3469f7a03bf39568f8ded3fa90fbc8d Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 17:13:35 +0100 Subject: [PATCH 06/16] Refactor the table inversion code for Socrata a bit. --- splitgraph/ingestion/socrata/mount.py | 31 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/splitgraph/ingestion/socrata/mount.py b/splitgraph/ingestion/socrata/mount.py index 7ed3ce95..646121b8 100644 --- a/splitgraph/ingestion/socrata/mount.py +++ b/splitgraph/ingestion/socrata/mount.py @@ -118,25 +118,13 @@ def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, # Local imports since this module gets run from commandline entrypoint on startup. from splitgraph.core.output import slugify - from splitgraph.core.output import truncate_list from splitgraph.core.output import pluralise from splitgraph.ingestion.socrata.querying import socrata_to_sg_schema found_ids = set(d["resource"]["id"] for d in datasets) logging.info("Loaded metadata for %s", pluralise("Socrata table", len(found_ids))) - if isinstance(tables, (dict, list)): - missing_ids = [d for d in found_ids if d not in sought_ids] - if missing_ids: - raise ValueError( - "Some Socrata tables couldn't be found! Missing tables: %s" - % truncate_list(missing_ids) - ) - - if isinstance(tables, dict): - tables_inv = {to["socrata_id"]: p for p, (ts, to) in tables.items()} - else: - tables_inv = {} + tables_inv = _get_table_map(found_ids, sought_ids, tables) mount_statements = [] mount_args = [] @@ -165,3 +153,20 @@ def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, mount_args.extend(args) return mount_statements, mount_args + + +def _get_table_map(found_ids, sought_ids, tables: TableInfo) -> Dict[str, str]: + """Get a map of Socrata ID -> local table name""" + from splitgraph.core.output import truncate_list + + if isinstance(tables, (dict, list)): + missing_ids = [d for d in found_ids if d not in sought_ids] + if missing_ids: + raise ValueError( + "Some Socrata tables couldn't be found! Missing tables: %s" + % truncate_list(missing_ids) + ) + + if isinstance(tables, dict): + return {to["socrata_id"]: p for p, (ts, to) in tables.items()} + return {} From b16bb914fe0249eb97e615e9792643b969f4bb38 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 17:44:34 +0100 Subject: [PATCH 07/16] Change `tableschema_to_dict` to `table_schema_params_to_dict` (utility conversion function) --- splitgraph/core/types.py | 12 ++++- splitgraph/ingestion/socrata/mount.py | 2 +- test/splitgraph/test_misc.py | 74 ++++++++++++++++----------- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/splitgraph/core/types.py b/splitgraph/core/types.py index ac58224e..97a6262e 100644 --- a/splitgraph/core/types.py +++ b/splitgraph/core/types.py @@ -48,5 +48,13 @@ def dict_to_table_schema_params( } -def tableschema_to_dict(tables: Dict[str, TableSchema]) -> Dict[str, Dict[str, str]]: - return {t: {c.name: c.pg_type for c in ts} for t, ts in tables.items()} +def table_schema_params_to_dict( + tables: Dict[str, Tuple[TableSchema, TableParams]] +) -> Dict[str, Dict[str, Dict[str, str]]]: + return { + t: { + "schema": {c.name: c.pg_type for c in ts}, + "options": {tpk: str(tpv) for tpk, tpv in tp.items()}, + } + for t, (ts, tp) in tables.items() + } diff --git a/splitgraph/ingestion/socrata/mount.py b/splitgraph/ingestion/socrata/mount.py index 646121b8..95fc61c3 100644 --- a/splitgraph/ingestion/socrata/mount.py +++ b/splitgraph/ingestion/socrata/mount.py @@ -159,7 +159,7 @@ def _get_table_map(found_ids, sought_ids, tables: TableInfo) -> Dict[str, str]: """Get a map of Socrata ID -> local table name""" from splitgraph.core.output import truncate_list - if isinstance(tables, (dict, list)): + if isinstance(tables, (dict, list)) and tables: missing_ids = [d for d in found_ids if d not in sought_ids] if missing_ids: raise ValueError( diff --git a/test/splitgraph/test_misc.py b/test/splitgraph/test_misc.py index 7ff7c83b..26ca1160 100644 --- a/test/splitgraph/test_misc.py +++ b/test/splitgraph/test_misc.py @@ -9,7 +9,11 @@ from splitgraph.core.engine import lookup_repository from splitgraph.core.metadata_manager import Object from splitgraph.core.repository import Repository -from splitgraph.core.types import TableColumn, tableschema_to_dict, dict_to_table_schema_params +from splitgraph.core.types import ( + TableColumn, + dict_to_table_schema_params, + table_schema_params_to_dict, +) from splitgraph.engine.postgres.engine import API_MAX_QUERY_LENGTH from splitgraph.exceptions import RepositoryNotFoundError from splitgraph.hooks.s3 import get_object_upload_urls @@ -315,37 +319,49 @@ def test_parse_dt(): parse_dt("not a dt") -def test_tableschema_to_dict(): - assert tableschema_to_dict( +def test_table_schema_params_to_dict(): + assert table_schema_params_to_dict( { - "fruits": [ - TableColumn( - ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None - ), - TableColumn( - ordinal=2, - name="name", - pg_type="character varying", - is_pk=False, - comment=None, - ), - ], - "vegetables": [ - TableColumn( - ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None - ), - TableColumn( - ordinal=2, - name="name", - pg_type="character varying", - is_pk=False, - comment=None, - ), - ], + "fruits": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, + ), + ], + {"key": "value"}, + ), + "vegetables": ( + [ + TableColumn( + ordinal=1, name="vegetable_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, + ), + ], + {"key": "value"}, + ), } ) == { - "fruits": {"fruit_id": "integer", "name": "character varying"}, - "vegetables": {"name": "character varying", "vegetable_id": "integer"}, + "fruits": { + "schema": {"fruit_id": "integer", "name": "character varying"}, + "options": {"key": "value"}, + }, + "vegetables": { + "schema": {"name": "character varying", "vegetable_id": "integer"}, + "options": {"key": "value"}, + }, } From 6459b8230feac4f982e7649d609e92117b78b317 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 22:50:20 +0100 Subject: [PATCH 08/16] Allow the Postgres data source to receive arbitrary table params (e.g. to override the table name). --- splitgraph/hooks/data_source/fdw.py | 4 ++- test/splitgraph/commands/test_mounting.py | 32 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index 52097f85..7f54129b 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -375,7 +375,9 @@ def get_user_options(self): return {"user": self.credentials["username"], "password": self.credentials["password"]} def get_table_options(self, table_name: str): - return {"schema_name": self.params["remote_schema"]} + options = super().get_table_options(table_name) + options["schema_name"] = self.params["remote_schema"] + return options def get_fdw_name(self): return "postgres_fdw" diff --git a/test/splitgraph/commands/test_mounting.py b/test/splitgraph/commands/test_mounting.py index 0606ee87..a22da7d6 100644 --- a/test/splitgraph/commands/test_mounting.py +++ b/test/splitgraph/commands/test_mounting.py @@ -100,6 +100,38 @@ def test_mount_introspection_preview(local_engine_empty): } +@pytest.mark.mounting +def test_mount_rename_table(local_engine_empty): + tables = { + "fruits_renamed": ( + [ + TableColumn( + ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None + ), + TableColumn( + ordinal=2, + name="name", + pg_type="character varying", + is_pk=False, + comment=None, + ), + ], + {"table_name": "fruits"}, + ) + } + handler = PostgreSQLDataSource( + engine=local_engine_empty, + credentials={"username": "originro", "password": "originpass"}, + params={"host": "pgorigin", "port": 5432, "dbname": "origindb", "remote_schema": "public"}, + tables=tables, + ) + + preview = handler.preview(tables) + assert preview == { + "fruits_renamed": [{"fruit_id": 1, "name": "apple"}, {"fruit_id": 2, "name": "orange"}], + } + + @pytest.mark.mounting def test_cross_joins(local_engine_empty): _mount_postgres(PG_MNT) From 14940317f703be61c72c66b6493377f0d75f4f08 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 5 Apr 2021 23:15:38 +0100 Subject: [PATCH 09/16] Use semver ordering when figuring out the latest migration to apply (instead of lexicographical in which 0.0.9 > 0.0.10) --- splitgraph/core/migration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/splitgraph/core/migration.py b/splitgraph/core/migration.py index 13c19b6d..37bb1915 100644 --- a/splitgraph/core/migration.py +++ b/splitgraph/core/migration.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Optional, Tuple, cast, TYPE_CHECKING, List, TypeVar, Dict, DefaultDict +from packaging.version import Version from psycopg2.sql import SQL, Identifier from splitgraph.core.sql import select, insert @@ -116,7 +117,7 @@ def source_files_to_apply( ) -> Tuple[List[str], str]: """ Get the ordered list of .sql files to apply to the database""" version_tuples = get_version_tuples(schema_files) - target_version = target_version or max([v[1] for v in version_tuples]) + target_version = target_version or max([v[1] for v in version_tuples], key=Version) if static: # For schemata like splitgraph_api which we want to apply in any # case, bypassing the upgrade mechanism, we just run the latest From 6cd33ae1a27950400600ee3cdac63cf552115ae5 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Tue, 6 Apr 2021 10:46:50 +0100 Subject: [PATCH 10/16] Get introspection to suggest foreign table options (by scraping them from the engine after a mount operation). --- splitgraph/hooks/data_source/fdw.py | 43 +++++++++++++++++++++-- test/splitgraph/commands/test_mounting.py | 4 +-- test/splitgraph/ingestion/test_csv.py | 6 ++-- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index 7f54129b..0b8d8586 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -146,13 +146,24 @@ def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: # Local import here since this data source gets imported by the commandline entry point from splitgraph.core.common import get_temporary_table_id + import jsonschema tmp_schema = get_temporary_table_id() try: self.mount(tmp_schema) + + table_options = dict(self._get_foreign_table_options(tmp_schema)) + + # Sanity check for adapters: validate the foreign table options that we get back + # to make sure they're still appropriate. + for v in table_options.values(): + jsonschema.validate(v, self.table_params_schema) + result = { - # TODO extract table_params from the mount - t: (self.engine.get_full_table_schema(tmp_schema, t), cast(TableParams, {})) + t: ( + self.engine.get_full_table_schema(tmp_schema, t), + cast(TableParams, table_options.get(t, {})), + ) for t in self.engine.get_all_tables(tmp_schema) } return result @@ -176,6 +187,30 @@ def _preview_table(self, schema: str, table: str, limit: int = 10) -> List[Dict[ ) return result_json + def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, str]]]: + """ + Get a list of options the foreign tables in this schema were instantiated with + :return: List of tables and their options + """ + # We use this to suggest table options during introspection. With FDWs, we do this by + # mounting the data source first (using IMPORT FOREIGN SCHEMA) and then scraping the + # foreign table options it inferred. + + # Downstream FDWs can override this: if they remap some table options into different + # FDW options (e.g. "remote_schema" on the data source side turns into "schema" on the + # FDW side), they have to map them back in this routine (otherwise the introspection will + # suggest "schema", which is wrong. + + return cast( + List[Tuple[str, Dict[str, str]]], + self.engine.run_sql( + "SELECT foreign_table_name, json_object_agg(option_name, option_value) " + "FROM information_schema.foreign_table_options " + "WHERE foreign_table_schema = %s GROUP BY foreign_table_name", + (schema,), + ), + ) + def preview(self, tables: Optional[TableInfo]) -> PreviewResult: # Preview data in tables mounted by this FDW / data source @@ -501,7 +536,9 @@ def get_user_options(self): return {"username": self.credentials["username"], "password": self.credentials["password"]} def get_table_options(self, table_name: str): - return {"dbname": self.params["dbname"]} + options = super().get_table_options(table_name) + options["dbname"] = options.get("dbname", self.params["dbname"]) + return options def get_fdw_name(self): return "mysql_fdw" diff --git a/test/splitgraph/commands/test_mounting.py b/test/splitgraph/commands/test_mounting.py index a22da7d6..634fb781 100644 --- a/test/splitgraph/commands/test_mounting.py +++ b/test/splitgraph/commands/test_mounting.py @@ -75,7 +75,7 @@ def test_mount_introspection_preview(local_engine_empty): ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None ), ], - {}, + {"schema_name": "public", "table_name": "fruits"}, ), "vegetables": ( [ @@ -86,7 +86,7 @@ def test_mount_introspection_preview(local_engine_empty): ordinal=2, name="name", pg_type="character varying", is_pk=False, comment=None ), ], - {}, + {"schema_name": "public", "table_name": "vegetables"}, ), } diff --git a/test/splitgraph/ingestion/test_csv.py b/test/splitgraph/ingestion/test_csv.py index a7401043..e5c0fb2c 100644 --- a/test/splitgraph/ingestion/test_csv.py +++ b/test/splitgraph/ingestion/test_csv.py @@ -101,7 +101,7 @@ def test_csv_data_source_s3(local_engine_empty): schema = source.introspect() - assert len(schema.keys()) == 4 + assert len(schema.keys()) == 3 assert schema["fruits.csv"] == ( [ TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), @@ -119,7 +119,7 @@ def test_csv_data_source_s3(local_engine_empty): TableColumn(ordinal=5, name="bignumber", pg_type="bigint", is_pk=False, comment=None), TableColumn(ordinal=6, name="vbignumber", pg_type="numeric", is_pk=False, comment=None), ], - {}, + {"s3_object": "some_prefix/fruits.csv"}, ) assert schema["encoding-win-1252.csv"] == ( [ @@ -131,7 +131,7 @@ def test_csv_data_source_s3(local_engine_empty): ordinal=3, name="TEXT", pg_type="character varying", is_pk=False, comment=None ), ], - {}, + {"s3_object": "some_prefix/encoding-win-1252.csv"}, ) assert len(schema["rdu-weather-history.csv"][0]) == 28 From 69eb285cafced3c62edfd438724f1900bbdd307e Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 12 Apr 2021 16:39:45 +0100 Subject: [PATCH 11/16] Make import_foreign_schema public and allow passing a dict of table options. --- splitgraph/hooks/data_source/fdw.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index 0b8d8586..4b76b051 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -127,7 +127,7 @@ def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo) raise NotImplementedError( "The FDW does not support IMPORT FOREIGN SCHEMA! Pass a tables dictionary." ) - _import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) + import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) else: for table_name, (table_schema, _) in tables.items(): logging.info("Mounting table %s", table_name) @@ -308,8 +308,13 @@ def _format_options(option_names): ) -def _import_foreign_schema( - engine: "PostgresEngine", mountpoint: str, remote_schema: str, server_id: str, tables: List[str] +def import_foreign_schema( + engine: "PostgresEngine", + mountpoint: str, + remote_schema: str, + server_id: str, + tables: List[str], + options: Optional[Dict[str, str]] = None, ) -> None: from psycopg2.sql import Identifier, SQL @@ -318,7 +323,13 @@ def _import_foreign_schema( if tables: query += SQL("LIMIT TO (") + SQL(",").join(Identifier(t) for t in tables) + SQL(")") query += SQL("FROM SERVER {} INTO {}").format(Identifier(server_id), Identifier(mountpoint)) - engine.run_sql(query) + args: List[str] = [] + + if options: + query += _format_options(options.keys()) + args.extend(options.values()) + + engine.run_sql(query, args) def create_foreign_table( From 4c2a8d2020d546a7f2a6c596cbd604ebe91e3079 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 12 Apr 2021 16:40:01 +0100 Subject: [PATCH 12/16] Bump Multicorn to pick up the fix for table params overriding server params. --- engine/src/Multicorn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/src/Multicorn b/engine/src/Multicorn index 5edb6814..0f95b96c 160000 --- a/engine/src/Multicorn +++ b/engine/src/Multicorn @@ -1 +1 @@ -Subproject commit 5edb681438530479516c4eb60a2b2e2d5c707efb +Subproject commit 0f95b96c7a92d4cac8ae83f9e24fe82111219395 From b480ae01195248eecf206aca237deb342be68dee Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 12 Apr 2021 16:43:25 +0100 Subject: [PATCH 13/16] CSV inference: instead of passing the dialect into the `CSVOptions` struct, unpack it and set the correct CSV parsing options (delimiter, quotechar) --- splitgraph/ingestion/csv/common.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/splitgraph/ingestion/csv/common.py b/splitgraph/ingestion/csv/common.py index 3039cb2f..7d936318 100644 --- a/splitgraph/ingestion/csv/common.py +++ b/splitgraph/ingestion/csv/common.py @@ -18,7 +18,6 @@ class CSVOptions(NamedTuple): schema_inference_rows: int = 10000 delimiter: str = "," quotechar: str = '"' - dialect: Optional[Union[str, Type[csv.Dialect]]] = None header: bool = True encoding: str = "utf-8" ignore_decode_errors: bool = False @@ -34,14 +33,11 @@ def from_fdw_options(cls, fdw_options): header=get_bool(fdw_options, "header"), delimiter=fdw_options.get("delimiter", ","), quotechar=fdw_options.get("quotechar", '"'), - dialect=fdw_options.get("dialect"), encoding=fdw_options.get("encoding", "utf-8"), ignore_decode_errors=get_bool(fdw_options, "ignore_decode_errors", default=False), ) def to_csv_kwargs(self): - if self.dialect: - return {"dialect": self.dialect} return {"delimiter": self.delimiter, "quotechar": self.quotechar} @@ -77,7 +73,10 @@ def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions: if csv_options.autodetect_dialect: dialect = csv.Sniffer().sniff(sample) - csv_options = csv_options._replace(dialect=dialect) + # These are meant to be set, but mypy claims they might not be. + csv_options = csv_options._replace( + delimiter=dialect.delimiter or ",", quotechar=dialect.quotechar or '"' + ) if csv_options.autodetect_header: has_header = csv.Sniffer().has_header(sample) From 5ae39534af8b8603601f49988e71ce28ec93f08a Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Mon, 12 Apr 2021 16:49:29 +0100 Subject: [PATCH 14/16] CSV data source: allow passing a partially initialized list of table options without a table schema (just the options like the S3 object key, not the table schema). This will make it introspect just those keys and fill out the missing table options (e.g. can pass the encoding and it'll infer the rest). Add initial support for passing these table options to `IMPORT FOREIGN SCHEMA` (as a JSON option). --- splitgraph/hooks/data_source/fdw.py | 7 +- splitgraph/ingestion/csv/__init__.py | 71 ++++++- splitgraph/ingestion/csv/common.py | 33 +++- splitgraph/ingestion/csv/fdw.py | 194 +++++++++++------- test/splitgraph/ingestion/test_csv.py | 273 ++++++++++++++++++++++++-- 5 files changed, 487 insertions(+), 91 deletions(-) diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index 4b76b051..fc61732f 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -187,7 +187,7 @@ def _preview_table(self, schema: str, table: str, limit: int = 10) -> List[Dict[ ) return result_json - def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, str]]]: + def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, Any]]]: """ Get a list of options the foreign tables in this schema were instantiated with :return: List of tables and their options @@ -201,6 +201,11 @@ def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, s # FDW side), they have to map them back in this routine (otherwise the introspection will # suggest "schema", which is wrong. + # This is also used for type remapping, since table params can only be strings on PG. + # TODO: this will lead to a bunch of serialization/deserialization code duplication + # in data sources. One potential solution is, at least in Multicorn-backed wrappers + # that we control, passing a JSON as a single table param instead. + return cast( List[Tuple[str, Dict[str, str]]], self.engine.run_sql( diff --git a/splitgraph/ingestion/csv/__init__.py b/splitgraph/ingestion/csv/__init__.py index 7b820088..18b50bc9 100644 --- a/splitgraph/ingestion/csv/__init__.py +++ b/splitgraph/ingestion/csv/__init__.py @@ -1,9 +1,11 @@ +import json from copy import deepcopy -from typing import Optional, TYPE_CHECKING, Dict, Mapping, cast +from typing import Optional, TYPE_CHECKING, Dict, List, Tuple, Any from psycopg2.sql import SQL, Identifier -from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource +from splitgraph.core.types import TableInfo +from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource, import_foreign_schema from splitgraph.ingestion.common import IngestionAdapter, build_commandline_help if TYPE_CHECKING: @@ -118,6 +120,15 @@ class CSVDataSource(ForeignDataWrapperDataSource): }, "quotechar": {"type": "string", "description": "Character used to quote fields"}, }, + "oneOf": [{"required": ["url"]}, {"required": ["s3_endpoint", "s3_bucket"]}], + } + + table_params_schema = { + "type": "object", + "properties": { + "url": {"type": "string", "description": "HTTP URL to the CSV file"}, + "s3_object": {"type": "string", "description": "S3 object of the CSV file"}, + }, } supports_mount = True @@ -175,11 +186,61 @@ def from_commandline(cls, engine, commandline_kwargs) -> "CSVDataSource": def get_table_options(self, table_name: str) -> Dict[str, str]: result = super().get_table_options(table_name) - result["s3_object"] = result.get( - "s3_object", self.params.get("s3_object_prefix", "") + table_name - ) + + # Set a default s3_object if we're using S3 and not HTTP + if "url" not in result: + result["s3_object"] = result.get( + "s3_object", self.params.get("s3_object_prefix", "") + table_name + ) return result + def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo): + # Override _create_foreign_tables (actual mounting code) to support TableInfo structs + # where the schema is empty. This is so that we can call this data source with a limited + # list of CSV files and their delimiters / other params and have it introspect just + # those tables (instead of e.g. scanning the whole bucket). + if isinstance(tables, dict): + to_introspect = { + table_name: table_options + for table_name, (table_schema, table_options) in tables.items() + if not table_schema + } + if to_introspect: + # This FDW's implementation of IMPORT FOREIGN SCHEMA supports passing a JSON of table + # options in options + import_foreign_schema( + self.engine, + schema, + self.get_remote_schema_name(), + server_id, + tables=list(to_introspect.keys()), + options={"table_options": json.dumps(to_introspect)}, + ) + + # Create the remaining tables (that have a schema) as usual. + tables = { + table_name: (table_schema, table_options) + for table_name, (table_schema, table_options) in tables.items() + if table_schema + } + if not tables: + return + + super()._create_foreign_tables(schema, server_id, tables) + + def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, Any]]]: + options = super()._get_foreign_table_options(schema) + + # Deserialize things like booleans from the foreign table options that the FDW inferred + # for us. + def _destring_table_options(table_options): + for k in ["autodetect_header", "autodetect_encoding", "autodetect_dialect", "header"]: + if k in table_options: + table_options[k] = table_options[k].lower() == "true" + return table_options + + return [(t, _destring_table_options(d)) for t, d in options] + def get_server_options(self): options: Dict[str, Optional[str]] = { "wrapper": "splitgraph.ingestion.csv.fdw.CSVForeignDataWrapper" diff --git a/splitgraph/ingestion/csv/common.py b/splitgraph/ingestion/csv/common.py index 7d936318..7af33645 100644 --- a/splitgraph/ingestion/csv/common.py +++ b/splitgraph/ingestion/csv/common.py @@ -1,6 +1,6 @@ import csv import io -from typing import Optional, Dict, Tuple, NamedTuple, Union, Type, TYPE_CHECKING +from typing import Dict, Tuple, NamedTuple, TYPE_CHECKING, Any if TYPE_CHECKING: import _csv @@ -40,6 +40,27 @@ def from_fdw_options(cls, fdw_options): def to_csv_kwargs(self): return {"delimiter": self.delimiter, "quotechar": self.quotechar} + def to_table_options(self): + """ + Turn this into a dict of table options that can be plugged back into CSVDataSource. + """ + + # The purpose is to return to the user the CSV dialect options that we inferred + # so that they can freeze them in the table options (instead of rescanning the CSV + # on every mount) + iterate on them. + + # We flip the autodetect flags to False here so that if we merge the new params with + # the old params again, it won't rerun CSV dialect detection. + return { + "autodetect_header": "false", + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "header": bool_to_str(self.header), + "delimiter": self.delimiter, + "quotechar": self.quotechar, + "encoding": self.encoding, + } + def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions: """Autodetect the CSV dialect, encoding, header etc.""" @@ -85,10 +106,16 @@ def autodetect_csv(stream: io.RawIOBase, csv_options: CSVOptions) -> CSVOptions: return csv_options -def get_bool(params: Dict[str, str], key: str, default: bool = True) -> bool: +def get_bool(params: Dict[str, Any], key: str, default: bool = True) -> bool: if key not in params: return default - return params[key].lower() == "true" + if isinstance(params[key], bool): + return bool(params[key]) + return bool(params[key].lower() == "true") + + +def bool_to_str(boolean: bool) -> str: + return "true" if boolean else "false" def make_csv_reader( diff --git a/splitgraph/ingestion/csv/fdw.py b/splitgraph/ingestion/csv/fdw.py index 201da8a1..cb3883b4 100644 --- a/splitgraph/ingestion/csv/fdw.py +++ b/splitgraph/ingestion/csv/fdw.py @@ -1,7 +1,8 @@ import gzip +import json import logging import os -from copy import deepcopy +from copy import copy from itertools import islice from typing import Tuple @@ -37,6 +38,10 @@ def _get_table_definition(response, fdw_options, table_name, table_options): + # Allow overriding introspection options with per-table params (e.g. encoding, delimiter...) + fdw_options = copy(fdw_options) + fdw_options.update(table_options) + csv_options, reader = make_csv_reader(response, CSVOptions.from_fdw_options(fdw_options)) sample = list(islice(reader, csv_options.schema_inference_rows)) @@ -48,14 +53,19 @@ def _get_table_definition(response, fdw_options, table_name, table_options): # For nonexistent column names: replace with autogenerated ones (can't have empty column names) sg_schema = generate_column_names(sg_schema) + # Merge the autodetected table options with the ones passed to us originally (e.g. + # S3 object etc) + new_table_options = copy(table_options) + new_table_options.update(csv_options.to_table_options()) + # Build Multicorn TableDefinition. ColumnDefinition takes in type OIDs, # typmods and other internal PG stuff but other FDWs seem to get by with just # the textual type name. return TableDefinition( - table_name=table_name[len(fdw_options.get("s3_object_prefix", "")) :], + table_name=table_name, schema=None, columns=[ColumnDefinition(column_name=c.name, type_name=c.pg_type) for c in sg_schema], - options=table_options, + options=new_table_options, ) @@ -133,72 +143,122 @@ def import_schema(cls, schema, srv_options, options, restriction_type, restricts # Implement IMPORT FOREIGN SCHEMA to instead scan an S3 bucket for CSV files # and infer their CSV schema. - # Merge server options and options passed to IMPORT FOREIGN SCHEMA - fdw_options = deepcopy(srv_options) - for k, v in options.items(): - fdw_options[k] = v - - if fdw_options.get("url"): - # Infer from HTTP -- singular table with name "data" - with requests.get( - fdw_options["url"], stream=True, verify=os.environ.get("SSL_CERT_FILE", True) - ) as response: - response.raise_for_status() - stream = response.raw - if response.headers.get("Content-Encoding") == "gzip": - stream = gzip.GzipFile(fileobj=stream) - return [_get_table_definition(stream, fdw_options, "data", None)] - - # Get S3 options - client, bucket, prefix = cls._get_s3_params(fdw_options) - - # Note that we ignore the "schema" here (e.g. IMPORT FOREIGN SCHEMA some_schema) - # and take all interesting parameters through FDW options. - - # Allow just introspecting one object - if "s3_object" in fdw_options: - objects = [fdw_options["s3_object"]] + # 1) if we don't have options["table_options"], do a full scan as normal, + # treat LIMIT TO as a list of S3 objects + # 2) if we do, go through these tables and treat each one as a partial override + # of server options + if "table_options" in options: + table_options = json.loads(options["table_options"]) else: - objects = [ - o.object_name - for o in client.list_objects( - bucket_name=bucket, prefix=prefix or None, recursive=True - ) - ] - - result = [] - - for o in objects: - if restriction_type == "limit" and o not in restricts: - continue - - if restriction_type == "except" and o in restricts: - continue - - response = None - try: - response = client.get_object(bucket, o) - result.append( - _get_table_definition( - response, - fdw_options, - o, - {"s3_object": o}, + table_options = None + + if not table_options: + # Do a full scan of the file at URL / S3 bucket w. prefix + if srv_options.get("url"): + # Infer from HTTP -- singular table with name "data" + return [cls._introspect_url(srv_options, srv_options["url"])] + else: + # Get S3 options + client, bucket, prefix = cls._get_s3_params(srv_options) + + # Note that we ignore the "schema" here (e.g. IMPORT FOREIGN SCHEMA some_schema) + # and take all interesting parameters through FDW options. + + # Allow just introspecting one object + if "s3_object" in srv_options: + objects = [srv_options["s3_object"]] + elif restriction_type == "limit": + objects = restricts + else: + objects = [ + o.object_name + for o in client.list_objects( + bucket_name=bucket, prefix=prefix or None, recursive=True + ) + ] + + result = [] + + for o in objects: + if restriction_type == "except" and o in restricts: + continue + result.append(cls._introspect_s3(client, bucket, o, srv_options)) + + return [r for r in result if r] + else: + result = [] + + # Note we ignore LIMIT/EXCEPT here. There's no point in using them if the user + # is passing a dict of table options anyway. + for table_name, this_table_options in table_options.items(): + if "s3_object" in this_table_options: + # TODO: we can support overriding S3 params per-table here, but currently + # we don't do it. + client, bucket, _ = cls._get_s3_params(srv_options) + result.append( + cls._introspect_s3( + client, + bucket, + this_table_options["s3_object"], + srv_options, + table_name, + this_table_options, + ) ) - ) - except Exception as e: - logging.error( - "Error scanning object %s, ignoring: %s: %s", o, get_exception_name(e), e - ) - log_to_postgres( - "Error scanning object %s, ignoring: %s: %s" % (o, get_exception_name(e), e) - ) - finally: - if response: - response.close() - response.release_conn() + else: + result.append( + cls._introspect_url( + srv_options, this_table_options["url"], table_name, this_table_options + ) + ) + return [r for r in result if r] + + @classmethod + def _introspect_s3( + cls, client, bucket, object_id, srv_options, table_name=None, table_options=None + ): + response = None + # Default table name: truncate S3 object key up to the prefix + table_name = table_name or object_id[len(srv_options.get("s3_object_prefix", "")) :] + table_options = table_options or {} + table_options.update({"s3_object": object_id}) + try: + response = client.get_object(bucket, object_id) + return _get_table_definition( + response, + srv_options, + table_name, + table_options, + ) + except Exception as e: + logging.error( + "Error scanning object %s, ignoring: %s: %s", + object_id, + get_exception_name(e), + e, + exc_info=e, + ) + log_to_postgres( + "Error scanning object %s, ignoring: %s: %s" % (object_id, get_exception_name(e), e) + ) + finally: + if response: + response.close() + response.release_conn() - return result + @classmethod + def _introspect_url(cls, srv_options, url, table_name=None, table_options=None): + table_name = table_name or "data" + table_options = table_options or {} + + with requests.get( + url, stream=True, verify=os.environ.get("SSL_CERT_FILE", True) + ) as response: + response.raise_for_status() + stream = response.raw + if response.headers.get("Content-Encoding") == "gzip": + stream = gzip.GzipFile(fileobj=stream) + return _get_table_definition(stream, srv_options, table_name, table_options) @classmethod def _get_s3_params(cls, fdw_options) -> Tuple[Minio, str, str]: @@ -242,6 +302,4 @@ def __init__(self, fdw_options, fdw_columns): self.mode = "s3" self.s3_client, self.s3_bucket, self.s3_object_prefix = self._get_s3_params(fdw_options) - # TODO need a way to pass the table params (e.g. the actual S3 object name which - # might be different from table) into the preview and return it from introspection. self.s3_object = fdw_options["s3_object"] diff --git a/test/splitgraph/ingestion/test_csv.py b/test/splitgraph/ingestion/test_csv.py index e5c0fb2c..6a28a31f 100644 --- a/test/splitgraph/ingestion/test_csv.py +++ b/test/splitgraph/ingestion/test_csv.py @@ -1,5 +1,7 @@ +import json import os from io import BytesIO +from unittest import mock import pytest @@ -13,6 +15,28 @@ from splitgraph.ingestion.inference import infer_sg_schema from test.splitgraph.conftest import INGESTION_RESOURCES_CSV +_s3_win_1252_opts = { + "s3_object": "some_prefix/encoding-win-1252.csv", + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ";", + "encoding": "Windows-1252", + "header": "true", + "quotechar": '"', +} + +_s3_fruits_opts = { + "s3_object": "some_prefix/fruits.csv", + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ",", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', +} + def test_csv_introspection_s3(): fdw_options = { @@ -37,7 +61,7 @@ def test_csv_introspection_s3(): {"column_name": "DATE", "type_name": "character varying"}, {"column_name": "TEXT", "type_name": "character varying"}, ], - "options": {"s3_object": "some_prefix/encoding-win-1252.csv"}, + "options": _s3_win_1252_opts, "schema": None, "table_name": "encoding-win-1252.csv", } @@ -52,7 +76,7 @@ def test_csv_introspection_s3(): {"column_name": "bignumber", "type_name": "bigint"}, {"column_name": "vbignumber", "type_name": "numeric"}, ], - "options": {"s3_object": "some_prefix/fruits.csv"}, + "options": _s3_fruits_opts, } assert schema[2]["table_name"] == "rdu-weather-history.csv" assert schema[2]["columns"][0] == {"column_name": "date", "type_name": "date"} @@ -80,7 +104,87 @@ def test_csv_introspection_http(): {"column_name": "bignumber", "type_name": "bigint"}, {"column_name": "vbignumber", "type_name": "numeric"}, ], - "options": None, + "options": { + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ",", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', + }, + } + + +def test_csv_introspection_multiple(): + # Test running the introspection passing the table options as CREATE FOREIGN SCHEMA params. + # In effect, we specify the table names, S3 key/URL and expect the FDW to figure out + # the rest. + + fdw_options = { + "s3_endpoint": "objectstorage:9000", + "s3_secure": "false", + "s3_access_key": "minioclient", + "s3_secret_key": "supersecure", + "s3_bucket": "test_csv", + "s3_object_prefix": "some_prefix/", + } + + url = MINIO.presigned_get_object("test_csv", "some_prefix/rdu-weather-history.csv") + schema = CSVForeignDataWrapper.import_schema( + schema=None, + srv_options=fdw_options, + options={ + "table_options": json.dumps( + { + "from_url": {"url": url}, + "from_s3_rdu": {"s3_object": "some_prefix/rdu-weather-history.csv"}, + "from_s3_encoding": {"s3_object": "some_prefix/encoding-win-1252.csv"}, + } + ) + }, + restriction_type=None, + restricts=[], + ) + + assert len(schema) == 3 + schema = sorted(schema, key=lambda s: s["table_name"]) + + assert schema[0] == { + "table_name": "from_s3_encoding", + "schema": None, + "columns": mock.ANY, + "options": _s3_win_1252_opts, + } + assert schema[1] == { + "table_name": "from_s3_rdu", + "schema": None, + "columns": mock.ANY, + "options": { + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ";", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', + "s3_object": "some_prefix/rdu-weather-history.csv", + }, + } + assert schema[2] == { + "table_name": "from_url", + "schema": None, + "columns": mock.ANY, + "options": { + "autodetect_dialect": "false", + "autodetect_encoding": "false", + "autodetect_header": "false", + "delimiter": ";", + "encoding": "utf-8", + "header": "true", + "quotechar": '"', + "url": url, + }, } @@ -119,7 +223,16 @@ def test_csv_data_source_s3(local_engine_empty): TableColumn(ordinal=5, name="bignumber", pg_type="bigint", is_pk=False, comment=None), TableColumn(ordinal=6, name="vbignumber", pg_type="numeric", is_pk=False, comment=None), ], - {"s3_object": "some_prefix/fruits.csv"}, + { + "s3_object": "some_prefix/fruits.csv", + "autodetect_dialect": False, + "autodetect_encoding": False, + "autodetect_header": False, + "delimiter": ",", + "encoding": "utf-8", + "header": True, + "quotechar": '"', + }, ) assert schema["encoding-win-1252.csv"] == ( [ @@ -131,7 +244,16 @@ def test_csv_data_source_s3(local_engine_empty): ordinal=3, name="TEXT", pg_type="character varying", is_pk=False, comment=None ), ], - {"s3_object": "some_prefix/encoding-win-1252.csv"}, + { + "s3_object": "some_prefix/encoding-win-1252.csv", + "autodetect_dialect": False, + "autodetect_encoding": False, + "autodetect_header": False, + "delimiter": ";", + "encoding": "Windows-1252", + "header": True, + "quotechar": '"', + }, ) assert len(schema["rdu-weather-history.csv"][0]) == 28 @@ -166,6 +288,135 @@ def test_csv_data_source_s3(local_engine_empty): local_engine_empty.delete_schema("temp_data") +def test_csv_data_source_multiple(local_engine_empty): + # End-to-end version for test_csv_introspection_multiple to check things like table params + # getting serialized and deserialized properly. + + url = MINIO.presigned_get_object("test_csv", "some_prefix/rdu-weather-history.csv") + + credentials = { + "s3_access_key": "minioclient", + "s3_secret_key": "supersecure", + } + + params = { + "s3_endpoint": "objectstorage:9000", + "s3_secure": False, + "s3_bucket": "test_csv", + # Put this delimiter in as a canary to make sure table params override server params. + "delimiter": ",", + } + + tables = { + # Pass an empty table schema to denote we want to introspect it + "from_url": ([], {"url": url}), + "from_s3_rdu": ([], {"s3_object": "some_prefix/rdu-weather-history.csv"}), + "from_s3_encoding": ([], {"s3_object": "some_prefix/encoding-win-1252.csv"}), + } + + source = CSVDataSource( + local_engine_empty, + credentials, + params, + tables, + ) + + schema = source.introspect() + + assert schema == { + "from_url": ( + mock.ANY, + { + "autodetect_dialect": False, + "url": url, + "quotechar": '"', + "header": True, + "encoding": "utf-8", + "delimiter": ";", + "autodetect_header": False, + "autodetect_encoding": False, + }, + ), + "from_s3_rdu": ( + mock.ANY, + { + "encoding": "utf-8", + "autodetect_dialect": False, + "autodetect_encoding": False, + "autodetect_header": False, + "delimiter": ";", + "header": True, + "quotechar": '"', + "s3_object": "some_prefix/rdu-weather-history.csv", + }, + ), + "from_s3_encoding": ( + mock.ANY, + { + "s3_object": "some_prefix/encoding-win-1252.csv", + "quotechar": '"', + "header": True, + "encoding": "Windows-1252", + "autodetect_dialect": False, + "delimiter": ";", + "autodetect_header": False, + "autodetect_encoding": False, + }, + ), + } + + # Mount the datasets with this introspected schema. + try: + source.mount("temp_data", tables=schema) + rows = local_engine_empty.run_sql("SELECT * FROM temp_data.from_s3_encoding") + assert len(rows) == 3 + assert len(rows[0]) == 3 + finally: + local_engine_empty.delete_schema("temp_data") + + # Override the delimiter and blank out the schema for a single table + schema["from_s3_encoding"] = ( + [], + { + "s3_object": "some_prefix/encoding-win-1252.csv", + "quotechar": '"', + "header": True, + "encoding": "Windows-1252", + "autodetect_dialect": False, + # We force a delimiter "," here which will make the CSV a single-column one + # (to test we can actually override these) + "delimiter": ",", + "autodetect_header": False, + "autodetect_encoding": False, + }, + ) + + # Reintrospect the source with the new table parameters + source = CSVDataSource(local_engine_empty, credentials, params, schema) + new_schema = source.introspect() + assert len(new_schema) == 3 + # Check other tables are unchanged + assert new_schema["from_url"] == schema["from_url"] + assert new_schema["from_s3_rdu"] == schema["from_s3_rdu"] + + # Table with a changed separator only has one column (since we have , for delimiter + # instead of ;) + assert new_schema["from_s3_encoding"][0] == [ + TableColumn( + ordinal=1, name=";DATE;TEXT", pg_type="character varying", is_pk=False, comment=None + ) + ] + + try: + source.mount("temp_data", tables=new_schema) + rows = local_engine_empty.run_sql("SELECT * FROM temp_data.from_s3_encoding") + assert len(rows) == 3 + # Check we get one column now + assert rows[0] == ("1;01/07/2021;PaƱamao",) + finally: + local_engine_empty.delete_schema("temp_data") + + def test_csv_data_source_http(local_engine_empty): source = CSVDataSource( local_engine_empty, @@ -197,15 +448,9 @@ def test_csv_dialect_encoding_inference(): assert options.encoding == "Windows-1252" assert options.header is True - - # TODO: we keep these in the dialect struct rather than extract back out into the - # CSVOptions. Might need to do the latter if we want to return the proposed FDW table - # params to the user. - - # Note this line terminator is always "\r\n" since CSV assumes we use the - # universal newlines mode. - assert options.dialect.lineterminator == "\r\n" - assert options.dialect.delimiter == ";" + # NB we don't extract everything from the sniffed dialect, just the delimiter and the + # quotechar. The sniffer also returns doublequote and skipinitialspace. + assert options.delimiter == ";" data = list(reader) From eea118182c76cb4b47263f7c94f80209cb567855 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Tue, 13 Apr 2021 12:21:21 +0100 Subject: [PATCH 15/16] Record PG notices for last run_sql inside of the Engine class instead of discarding them --- splitgraph/engine/postgres/engine.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/splitgraph/engine/postgres/engine.py b/splitgraph/engine/postgres/engine.py index 0cbc3327..93e34617 100644 --- a/splitgraph/engine/postgres/engine.py +++ b/splitgraph/engine/postgres/engine.py @@ -228,6 +228,9 @@ def __init__( self.registry = registry self.in_fdw = in_fdw + """List of notices issued by the server during the previous execution of run_sql.""" + self.notices: List[str] = [] + if conn_params: self.conn_params = conn_params @@ -505,13 +508,17 @@ def run_sql( while True: with connection.cursor(**cursor_kwargs) as cur: try: + self.notices = [] cur.execute(statement, _convert_vals(arguments) if arguments else None) - if connection.notices and self.registry: - # Forward NOTICE messages from the registry back to the user - # (e.g. to nag them to upgrade etc). - for notice in connection.notices: - logging.info("%s says: %s", self.name, notice) + if connection.notices: + self.notices = connection.notices[:] del connection.notices[:] + + if self.registry: + # Forward NOTICE messages from the registry back to the user + # (e.g. to nag them to upgrade etc). + for notice in self.notices: + logging.info("%s says: %s", self.name, notice) except Exception as e: # Rollback the transaction (to a savepoint if we're inside the savepoint() context manager) self.rollback() From a17030577dde7961f7ea767b77a7733bf1432a53 Mon Sep 17 00:00:00 2001 From: Artjoms Iskovs Date: Tue, 13 Apr 2021 12:25:22 +0100 Subject: [PATCH 16/16] Improve data source error handling: * CSV data source: pass errors introspecting / scanning files back to the Python process over the PG notice mechanism * Mount/Preview data source methods: return a better defined error struct (instead of an ad hoc string) * Clean up some types --- examples/custom_fdw/src/hn_fdw/mount.py | 16 +++- splitgraph/commandline/__init__.py | 8 +- splitgraph/core/types.py | 28 +++++- splitgraph/exceptions.py | 11 ++- splitgraph/hooks/data_source/base.py | 20 +--- splitgraph/hooks/data_source/fdw.py | 91 ++++++++++++++----- splitgraph/ingestion/csv/__init__.py | 35 ++++--- splitgraph/ingestion/csv/fdw.py | 91 ++++++++++++------- splitgraph/ingestion/singer/data_source.py | 8 +- splitgraph/ingestion/snowflake/__init__.py | 9 +- splitgraph/ingestion/socrata/mount.py | 11 ++- .../custom_plugin_dir/some_plugin/plugin.py | 8 +- test/splitgraph/commandline/test_mount.py | 16 +++- test/splitgraph/ingestion/test_common.py | 15 ++- test/splitgraph/ingestion/test_csv.py | 46 +++++++++- 15 files changed, 286 insertions(+), 127 deletions(-) diff --git a/examples/custom_fdw/src/hn_fdw/mount.py b/examples/custom_fdw/src/hn_fdw/mount.py index bad22038..cb25cb99 100644 --- a/examples/custom_fdw/src/hn_fdw/mount.py +++ b/examples/custom_fdw/src/hn_fdw/mount.py @@ -1,6 +1,10 @@ -from typing import Dict, Mapping +from typing import Dict, Optional -from splitgraph.core.types import TableColumn, TableSchema +from splitgraph.core.types import ( + TableColumn, + TableInfo, + IntrospectionResult, +) # Define the schema of the foreign table we wish to create # We're only going to be fetching stories, so limit the columns to the ones that @@ -54,7 +58,9 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Query Hacker News stories through the Firebase API" - def get_table_options(self, table_name: str) -> Mapping[str, str]: + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: # Pass the endpoint name into the FDW return {"table": table_name} @@ -69,7 +75,7 @@ def get_server_options(self): "wrapper": "hn_fdw.fdw.HNForeignDataWrapper", } - def introspect(self) -> Dict[str, TableSchema]: + def introspect(self) -> IntrospectionResult: # Return a list of this FDW's tables and their schema. endpoints = self.params.get("endpoints") or _all_endpoints - return {e: _story_schema_spec for e in endpoints} + return {e: (_story_schema_spec, {}) for e in endpoints} diff --git a/splitgraph/commandline/__init__.py b/splitgraph/commandline/__init__.py index e60ae4fe..7e19ce75 100644 --- a/splitgraph/commandline/__init__.py +++ b/splitgraph/commandline/__init__.py @@ -40,6 +40,7 @@ from splitgraph.commandline.mount import mount_c from splitgraph.commandline.push_pull import pull_c, clone_c, push_c, upstream_c from splitgraph.commandline.splitfile import build_c, provenance_c, rebuild_c, dependents_c +from splitgraph.exceptions import get_exception_name from splitgraph.ingestion.singer.commandline import singer_group logger = logging.getLogger() @@ -61,13 +62,6 @@ def emit(self, record): logger.handlers[0].formatter = ColorFormatter() -def get_exception_name(o): - module = o.__class__.__module__ - if module is None or module == str.__class__.__module__: - return o.__class__.__name__ - return module + "." + o.__class__.__name__ - - def _patch_wrap_text(): # Patch click's formatter to strip ``` # etc which are used to format help text on the website. diff --git a/splitgraph/core/types.py b/splitgraph/core/types.py index 97a6262e..f97ba569 100644 --- a/splitgraph/core/types.py +++ b/splitgraph/core/types.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Dict, Tuple, Any, NamedTuple, Optional, List, Sequence, Union +from typing import Dict, Tuple, Any, NamedTuple, Optional, List, Sequence, Union, TypeVar Changeset = Dict[Tuple[str, ...], Tuple[bool, Dict[str, Any], Dict[str, Any]]] @@ -24,7 +24,31 @@ class TableColumn(NamedTuple): TableParams = Dict[str, Any] TableInfo = Union[List[str], Dict[str, Tuple[TableSchema, TableParams]]] SyncState = Dict[str, Any] -PreviewResult = Dict[str, Union[str, List[Dict[str, Any]]]] + + +class MountError(NamedTuple): + table_name: str + error: str + error_text: str + + +PreviewResult = Dict[str, Union[MountError, List[Dict[str, Any]]]] +IntrospectionResult = Dict[str, Union[Tuple[TableSchema, TableParams], MountError]] + +T = TypeVar("T") + + +def unwrap( + result: Dict[str, Union[MountError, T]], +) -> Tuple[Dict[str, T], Dict[str, MountError]]: + good = {} + bad = {} + for k, v in result.items(): + if isinstance(v, MountError): + bad[k] = v + else: + good[k] = v + return good, bad class Comparable(metaclass=ABCMeta): diff --git a/splitgraph/exceptions.py b/splitgraph/exceptions.py index aae209c7..55db97fa 100644 --- a/splitgraph/exceptions.py +++ b/splitgraph/exceptions.py @@ -99,7 +99,9 @@ class IncompleteObjectDownloadError(SplitGraphError): cleanup and reraise `reason` at the earliest opportunity.""" def __init__( - self, reason: Optional[BaseException], successful_objects: List[str], + self, + reason: Optional[BaseException], + successful_objects: List[str], ): self.reason = reason self.successful_objects = successful_objects @@ -119,3 +121,10 @@ class GQLUnauthenticatedError(GQLAPIError): class GQLRepoDoesntExistError(GQLAPIError): """Repository doesn't exist""" + + +def get_exception_name(o): + module = o.__class__.__module__ + if module is None or module == str.__class__.__module__: + return o.__class__.__name__ + return module + "." + o.__class__.__name__ diff --git a/splitgraph/hooks/data_source/base.py b/splitgraph/hooks/data_source/base.py index 88a2b9e1..430c3b60 100644 --- a/splitgraph/hooks/data_source/base.py +++ b/splitgraph/hooks/data_source/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from random import getrandbits -from typing import Dict, Any, Optional, TYPE_CHECKING, cast, Tuple +from typing import Dict, Any, Optional, TYPE_CHECKING, cast, Tuple, List from psycopg2._json import Json from psycopg2.sql import SQL, Identifier @@ -8,13 +8,13 @@ from splitgraph.core.engine import repository_exists from splitgraph.core.image import Image from splitgraph.core.types import ( - TableSchema, TableColumn, Credentials, Params, - TableParams, TableInfo, SyncState, + MountError, + IntrospectionResult, ) from splitgraph.engine import ResultShape @@ -75,17 +75,7 @@ def __init__( self.tables = tables @abstractmethod - def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: - # TODO here: dict str -> [tableschema, dict of suggested options] - # params -- add table options as a separate field? - # separate table schema - - # When going through the repo addition loop: - # * add separate options field mapping table names to options - # * return separate schema for table options - # * table options are optional for introspection - # * how to introspect: do import foreign schema; check fdw params - # * + def introspect(self) -> IntrospectionResult: raise NotImplementedError @@ -98,7 +88,7 @@ def mount( schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True, - ): + ) -> Optional[List[MountError]]: """Instantiate the data source as foreign tables in a schema""" raise NotImplementedError diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index fc61732f..0ce23c93 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -1,7 +1,9 @@ +import json import logging +import re from abc import ABC, abstractmethod from copy import deepcopy -from typing import Optional, Mapping, Dict, List, Any, cast, Union, TYPE_CHECKING, Tuple +from typing import Optional, Mapping, Dict, List, Any, cast, TYPE_CHECKING, Tuple import psycopg2 from psycopg2.sql import SQL, Identifier @@ -13,7 +15,10 @@ TableParams, TableInfo, PreviewResult, + MountError, + IntrospectionResult, ) +from splitgraph.exceptions import get_exception_name from splitgraph.hooks.data_source.base import ( MountableDataSource, LoadableDataSource, @@ -72,15 +77,19 @@ def get_user_options(self) -> Mapping[str, str]: def get_server_options(self) -> Mapping[str, str]: return {} - def get_table_options(self, table_name: str) -> Dict[str, str]: - if not isinstance(self.tables, dict): + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: + tables = tables or self.tables + + if not isinstance(tables, dict): return {} return { k: str(v) - for k, v in self.tables.get( - table_name, cast(Tuple[TableSchema, TableParams], ({}, {})) - )[1].items() + for k, v in tables.get(table_name, cast(Tuple[TableSchema, TableParams], ({}, {})))[ + 1 + ].items() } def get_table_schema(self, table_name: str, table_schema: TableSchema) -> TableSchema: @@ -100,7 +109,7 @@ def mount( schema: str, tables: Optional[TableInfo] = None, overwrite: bool = True, - ): + ) -> Optional[List[MountError]]: tables = tables or self.tables or [] fdw = self.get_fdw_name() @@ -117,9 +126,15 @@ def mount( # Allow mounting tables into existing schemas self.engine.run_sql(SQL("CREATE SCHEMA IF NOT EXISTS {}").format(Identifier(schema))) - self._create_foreign_tables(schema, server_id, tables) + errors = self._create_foreign_tables(schema, server_id, tables) + if errors: + for e in errors: + logging.warning("Error mounting %s: %s: %s", e.table_name, e.error, e.error_text) + return errors - def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo): + def _create_foreign_tables( + self, schema: str, server_id: str, tables: TableInfo + ) -> List[MountError]: if isinstance(tables, list): try: remote_schema = self.get_remote_schema_name() @@ -127,7 +142,7 @@ def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo) raise NotImplementedError( "The FDW does not support IMPORT FOREIGN SCHEMA! Pass a tables dictionary." ) - import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) + return import_foreign_schema(self.engine, schema, remote_schema, server_id, tables) else: for table_name, (table_schema, _) in tables.items(): logging.info("Mounting table %s", table_name) @@ -136,11 +151,12 @@ def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo) server_id, table_name, schema_spec=self.get_table_schema(table_name, table_schema), - extra_options=self.get_table_options(table_name), + extra_options=self.get_table_options(table_name, tables), ) self.engine.run_sql(query, args) + return [] - def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: + def introspect(self) -> IntrospectionResult: # Ability to override introspection by e.g. contacting the remote database without having # to mount the actual table. By default, just call out into mount() @@ -150,7 +166,7 @@ def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: tmp_schema = get_temporary_table_id() try: - self.mount(tmp_schema) + mount_errors = self.mount(tmp_schema) or [] table_options = dict(self._get_foreign_table_options(tmp_schema)) @@ -159,13 +175,17 @@ def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: for v in table_options.values(): jsonschema.validate(v, self.table_params_schema) - result = { + result: IntrospectionResult = { t: ( self.engine.get_full_table_schema(tmp_schema, t), cast(TableParams, table_options.get(t, {})), ) for t in self.engine.get_all_tables(tmp_schema) } + + # Add errors to the result as well + for error in mount_errors: + result[error.table_name] = error return result finally: self.engine.rollback() @@ -224,16 +244,23 @@ def preview(self, tables: Optional[TableInfo]) -> PreviewResult: tmp_schema = get_temporary_table_id() try: - self.mount(tmp_schema, tables=tables) + # Seed the result with the tables that failed to mount + result: PreviewResult = { + e.table_name: e for e in self.mount(tmp_schema, tables=tables) or [] + } + # Commit so that errors don't cancel the mount self.engine.commit() - result: Dict[str, Union[str, List[Dict[str, Any]]]] = {} for t in self.engine.get_all_tables(tmp_schema): + if t in result: + continue try: result[t] = self._preview_table(tmp_schema, t) except psycopg2.DatabaseError as e: logging.warning("Could not preview data for table %s", t, exc_info=e) - result[t] = str(e) + result[t] = MountError( + table_name=t, error=get_exception_name(e), error_text=str(e).strip() + ) return result finally: self.engine.rollback() @@ -320,7 +347,7 @@ def import_foreign_schema( server_id: str, tables: List[str], options: Optional[Dict[str, str]] = None, -) -> None: +) -> List[MountError]: from psycopg2.sql import Identifier, SQL # Construct a query: import schema limit to (%s, %s, ...) from server mountpoint_server into mountpoint @@ -336,6 +363,26 @@ def import_foreign_schema( engine.run_sql(query, args) + # Some of our FDWs use the PG notices as a side channel to pass errors and information back + # to the user, as IMPORT FOREIGN SCHEMA is all-or-nothing. If some of the tables failed to + # mount, we return those + the reasons. + import_errors: List[MountError] = [] + if engine.notices: + for notice in engine.notices: + match = re.match(r".*SPLITGRAPH: (.*)\n", notice) + if not match: + continue + notice_j = json.loads(match.group(1)) + import_errors.append( + MountError( + table_name=notice_j["table_name"], + error=notice_j["error"], + error_text=notice_j["error_text"], + ) + ) + + return import_errors + def create_foreign_table( schema: str, @@ -425,8 +472,8 @@ def get_server_options(self): def get_user_options(self): return {"user": self.credentials["username"], "password": self.credentials["password"]} - def get_table_options(self, table_name: str): - options = super().get_table_options(table_name) + def get_table_options(self, table_name: str, tables: Optional[TableInfo] = None): + options = super().get_table_options(table_name, tables) options["schema_name"] = self.params["remote_schema"] return options @@ -551,8 +598,8 @@ def get_server_options(self): def get_user_options(self): return {"username": self.credentials["username"], "password": self.credentials["password"]} - def get_table_options(self, table_name: str): - options = super().get_table_options(table_name) + def get_table_options(self, table_name: str, tables: Optional[TableInfo] = None): + options = super().get_table_options(table_name, tables) options["dbname"] = options.get("dbname", self.params["dbname"]) return options diff --git a/splitgraph/ingestion/csv/__init__.py b/splitgraph/ingestion/csv/__init__.py index 18b50bc9..261bf213 100644 --- a/splitgraph/ingestion/csv/__init__.py +++ b/splitgraph/ingestion/csv/__init__.py @@ -4,7 +4,7 @@ from psycopg2.sql import SQL, Identifier -from splitgraph.core.types import TableInfo +from splitgraph.core.types import TableInfo, MountError from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource, import_foreign_schema from splitgraph.ingestion.common import IngestionAdapter, build_commandline_help @@ -184,8 +184,10 @@ def from_commandline(cls, engine, commandline_kwargs) -> "CSVDataSource": credentials[k] = params[k] return cls(engine, credentials, params) - def get_table_options(self, table_name: str) -> Dict[str, str]: - result = super().get_table_options(table_name) + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: + result = super().get_table_options(table_name, tables) # Set a default s3_object if we're using S3 and not HTTP if "url" not in result: @@ -194,11 +196,15 @@ def get_table_options(self, table_name: str) -> Dict[str, str]: ) return result - def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo): + def _create_foreign_tables( + self, schema: str, server_id: str, tables: TableInfo + ) -> List[MountError]: # Override _create_foreign_tables (actual mounting code) to support TableInfo structs # where the schema is empty. This is so that we can call this data source with a limited # list of CSV files and their delimiters / other params and have it introspect just # those tables (instead of e.g. scanning the whole bucket). + errors: List[MountError] = [] + if isinstance(tables, dict): to_introspect = { table_name: table_options @@ -208,13 +214,15 @@ def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo) if to_introspect: # This FDW's implementation of IMPORT FOREIGN SCHEMA supports passing a JSON of table # options in options - import_foreign_schema( - self.engine, - schema, - self.get_remote_schema_name(), - server_id, - tables=list(to_introspect.keys()), - options={"table_options": json.dumps(to_introspect)}, + errors.extend( + import_foreign_schema( + self.engine, + schema, + self.get_remote_schema_name(), + server_id, + tables=list(to_introspect.keys()), + options={"table_options": json.dumps(to_introspect)}, + ) ) # Create the remaining tables (that have a schema) as usual. @@ -224,9 +232,10 @@ def _create_foreign_tables(self, schema: str, server_id: str, tables: TableInfo) if table_schema } if not tables: - return + return errors - super()._create_foreign_tables(schema, server_id, tables) + errors.extend(super()._create_foreign_tables(schema, server_id, tables)) + return errors def _get_foreign_table_options(self, schema: str) -> List[Tuple[str, Dict[str, Any]]]: options = super()._get_foreign_table_options(schema) diff --git a/splitgraph/ingestion/csv/fdw.py b/splitgraph/ingestion/csv/fdw.py index cb3883b4..f7eba951 100644 --- a/splitgraph/ingestion/csv/fdw.py +++ b/splitgraph/ingestion/csv/fdw.py @@ -2,15 +2,16 @@ import json import logging import os +from contextlib import contextmanager from copy import copy from itertools import islice -from typing import Tuple +from typing import Tuple, Optional import requests from minio import Minio import splitgraph.config -from splitgraph.commandline import get_exception_name +from splitgraph.exceptions import get_exception_name from splitgraph.ingestion.common import generate_column_names from splitgraph.ingestion.csv.common import CSVOptions, get_bool, make_csv_reader from splitgraph.ingestion.inference import infer_sg_schema @@ -69,6 +70,33 @@ def _get_table_definition(response, fdw_options, table_name, table_options): ) +@contextmanager +def report_errors(table_name: str): + """Context manager that ignores exceptions and serializes them to JSON using PG's notice + mechanism instead. The data source is meant to load these to report on partial failures + (e.g. failed to load one table, but not others).""" + try: + yield + except Exception as e: + logging.error( + "Error scanning %s, ignoring: %s: %s", + table_name, + get_exception_name(e), + e, + exc_info=e, + ) + log_to_postgres( + "SPLITGRAPH: " + + json.dumps( + { + "table_name": table_name, + "error": get_exception_name(e), + "error_text": str(e), + } + ) + ) + + class CSVForeignDataWrapper(ForeignDataWrapper): """Foreign data wrapper for CSV files stored in S3 buckets or HTTP""" @@ -216,49 +244,42 @@ def import_schema(cls, schema, srv_options, options, restriction_type, restricts @classmethod def _introspect_s3( cls, client, bucket, object_id, srv_options, table_name=None, table_options=None - ): + ) -> Optional[TableDefinition]: response = None # Default table name: truncate S3 object key up to the prefix table_name = table_name or object_id[len(srv_options.get("s3_object_prefix", "")) :] table_options = table_options or {} table_options.update({"s3_object": object_id}) - try: - response = client.get_object(bucket, object_id) - return _get_table_definition( - response, - srv_options, - table_name, - table_options, - ) - except Exception as e: - logging.error( - "Error scanning object %s, ignoring: %s: %s", - object_id, - get_exception_name(e), - e, - exc_info=e, - ) - log_to_postgres( - "Error scanning object %s, ignoring: %s: %s" % (object_id, get_exception_name(e), e) - ) - finally: - if response: - response.close() - response.release_conn() + with report_errors(table_name): + try: + response = client.get_object(bucket, object_id) + return _get_table_definition( + response, + srv_options, + table_name, + table_options, + ) + finally: + if response: + response.close() + response.release_conn() @classmethod - def _introspect_url(cls, srv_options, url, table_name=None, table_options=None): + def _introspect_url( + cls, srv_options, url, table_name=None, table_options=None + ) -> Optional[TableDefinition]: table_name = table_name or "data" table_options = table_options or {} - with requests.get( - url, stream=True, verify=os.environ.get("SSL_CERT_FILE", True) - ) as response: - response.raise_for_status() - stream = response.raw - if response.headers.get("Content-Encoding") == "gzip": - stream = gzip.GzipFile(fileobj=stream) - return _get_table_definition(stream, srv_options, table_name, table_options) + with report_errors(table_name): + with requests.get( + url, stream=True, verify=os.environ.get("SSL_CERT_FILE", True) + ) as response: + response.raise_for_status() + stream = response.raw + if response.headers.get("Content-Encoding") == "gzip": + stream = gzip.GzipFile(fileobj=stream) + return _get_table_definition(stream, srv_options, table_name, table_options) @classmethod def _get_s3_params(cls, fdw_options) -> Tuple[Minio, str, str]: diff --git a/splitgraph/ingestion/singer/data_source.py b/splitgraph/ingestion/singer/data_source.py index 63df025e..8f5832d0 100644 --- a/splitgraph/ingestion/singer/data_source.py +++ b/splitgraph/ingestion/singer/data_source.py @@ -7,12 +7,12 @@ from contextlib import contextmanager from io import StringIO from threading import Thread -from typing import Dict, Any, Optional, cast, Tuple +from typing import Dict, Any, Optional, cast from psycopg2.sql import Identifier, SQL from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema, TableParams, TableInfo, SyncState +from splitgraph.core.types import TableParams, TableInfo, SyncState, IntrospectionResult from splitgraph.exceptions import DataSourceError from splitgraph.hooks.data_source.base import ( get_ingestion_state, @@ -195,11 +195,11 @@ def build_singer_catalog( catalog, tables=tables, use_legacy_stream_selection=self.use_legacy_stream_selection ) - def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: + def introspect(self) -> IntrospectionResult: config = self.get_singer_config() singer_schema = self._run_singer_discovery(config) - result = {} + result: IntrospectionResult = {} for stream in singer_schema["streams"]: stream_name = get_table_name(stream) stream_schema = get_sg_schema(stream) diff --git a/splitgraph/ingestion/snowflake/__init__.py b/splitgraph/ingestion/snowflake/__init__.py index a357df70..01cae739 100644 --- a/splitgraph/ingestion/snowflake/__init__.py +++ b/splitgraph/ingestion/snowflake/__init__.py @@ -1,8 +1,9 @@ import base64 import json import urllib.parse -from typing import Dict, Optional, cast, Mapping, Any +from typing import Dict, Optional, Any +from splitgraph.core.types import TableInfo from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource from splitgraph.ingestion.common import build_commandline_help @@ -138,8 +139,10 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Schema, table or a subquery from a Snowflake database" - def get_table_options(self, table_name: str) -> Dict[str, str]: - result = super().get_table_options(table_name) + def get_table_options( + self, table_name: str, tables: Optional[TableInfo] = None + ) -> Dict[str, str]: + result = super().get_table_options(table_name, tables) result["tablename"] = result.get("tablename", table_name) return result diff --git a/splitgraph/ingestion/socrata/mount.py b/splitgraph/ingestion/socrata/mount.py index 95fc61c3..8252ecf7 100644 --- a/splitgraph/ingestion/socrata/mount.py +++ b/splitgraph/ingestion/socrata/mount.py @@ -2,11 +2,11 @@ import json import logging from copy import deepcopy -from typing import Optional, Dict +from typing import Optional, Dict, List from psycopg2.sql import SQL, Identifier -from splitgraph.core.types import TableInfo +from splitgraph.core.types import TableInfo, MountError from splitgraph.exceptions import RepositoryNotFoundError from splitgraph.hooks.data_source.fdw import create_foreign_table, ForeignDataWrapperDataSource @@ -81,14 +81,16 @@ def get_server_options(self): options["app_token"] = str(self.credentials["app_token"]) return options - def _create_foreign_tables(self, schema, server_id, tables): + def _create_foreign_tables( + self, schema: str, server_id: str, tables: TableInfo + ) -> List[MountError]: from sodapy import Socrata from psycopg2.sql import SQL logging.info("Getting Socrata metadata") client = Socrata(domain=self.params["domain"], app_token=self.credentials.get("app_token")) - tables = tables or self.tables + tables = self.tables or tables if isinstance(tables, list): sought_ids = tables else: @@ -112,6 +114,7 @@ def _create_foreign_tables(self, schema, server_id, tables): ) self.engine.run_sql(SQL(";").join(mount_statements), mount_args) + return [] def generate_socrata_mount_queries(sought_ids, datasets, mountpoint, server_id, tables: TableInfo): diff --git a/test/resources/custom_plugin_dir/some_plugin/plugin.py b/test/resources/custom_plugin_dir/some_plugin/plugin.py index 861cc34c..2d7e3a8b 100644 --- a/test/resources/custom_plugin_dir/some_plugin/plugin.py +++ b/test/resources/custom_plugin_dir/some_plugin/plugin.py @@ -1,12 +1,10 @@ -from typing import Dict - -from splitgraph.core.types import TableSchema +from splitgraph.core.types import IntrospectionResult from splitgraph.hooks.data_source import DataSource class TestDataSource(DataSource): - def introspect(self) -> Dict[str, TableSchema]: - return {"some_table": []} + def introspect(self) -> IntrospectionResult: + return {"some_table": ([], {})} @classmethod def get_name(cls) -> str: diff --git a/test/splitgraph/commandline/test_mount.py b/test/splitgraph/commandline/test_mount.py index fe62f0af..2a85d281 100644 --- a/test/splitgraph/commandline/test_mount.py +++ b/test/splitgraph/commandline/test_mount.py @@ -23,8 +23,15 @@ _MONGO_PARAMS = { "tables": { "stuff": { - "options": {"database": "origindb", "collection": "stuff",}, - "schema": {"name": "text", "duration": "numeric", "happy": "boolean",}, + "options": { + "database": "origindb", + "collection": "stuff", + }, + "schema": { + "name": "text", + "duration": "numeric", + "happy": "boolean", + }, } } } @@ -55,7 +62,8 @@ def test_misc_mountpoint_management(pg_repo_local, mg_repo_local): # sgr mount with a file with tempfile.NamedTemporaryFile("w") as f: json.dump( - _MONGO_PARAMS, f, + _MONGO_PARAMS, + f, ) f.flush() @@ -150,5 +158,5 @@ def test_mount_plugin_dir(): plugin = plugin_class( engine=None, credentials={"access_token": "abc"}, params={"some_field": "some_value"} ) - assert plugin.introspect() == {"some_table": []} + assert plugin.introspect() == {"some_table": ([], {})} assert plugin.get_name() == "Test Data Source" diff --git a/test/splitgraph/ingestion/test_common.py b/test/splitgraph/ingestion/test_common.py index dbf46c2c..e5a84826 100644 --- a/test/splitgraph/ingestion/test_common.py +++ b/test/splitgraph/ingestion/test_common.py @@ -1,7 +1,14 @@ -from typing import Optional, Dict, Any, Tuple +from typing import Optional, Dict, Any, Tuple, cast from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema, TableColumn, TableInfo, SyncState, TableParams +from splitgraph.core.types import ( + TableSchema, + TableColumn, + TableInfo, + SyncState, + TableParams, + IntrospectionResult, +) from splitgraph.engine import ResultShape from splitgraph.hooks.data_source.base import SyncableDataSource @@ -29,8 +36,8 @@ def get_name(cls) -> str: def get_description(cls) -> str: return "Test ingestion" - def introspect(self) -> Dict[str, Tuple[TableSchema, TableParams]]: - return SCHEMA + def introspect(self) -> IntrospectionResult: + return cast(IntrospectionResult, SCHEMA) def _sync( self, schema: str, state: Optional[SyncState] = None, tables: Optional[TableInfo] = None diff --git a/test/splitgraph/ingestion/test_csv.py b/test/splitgraph/ingestion/test_csv.py index 6a28a31f..05ef92ba 100644 --- a/test/splitgraph/ingestion/test_csv.py +++ b/test/splitgraph/ingestion/test_csv.py @@ -5,7 +5,7 @@ import pytest -from splitgraph.core.types import TableColumn +from splitgraph.core.types import TableColumn, MountError, unwrap from splitgraph.engine import ResultShape from splitgraph.hooks.s3_server import MINIO from splitgraph.ingestion.common import generate_column_names @@ -205,7 +205,7 @@ def test_csv_data_source_s3(local_engine_empty): schema = source.introspect() - assert len(schema.keys()) == 3 + assert len(schema.keys()) == 4 assert schema["fruits.csv"] == ( [ TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None), @@ -257,11 +257,38 @@ def test_csv_data_source_s3(local_engine_empty): ) assert len(schema["rdu-weather-history.csv"][0]) == 28 + assert schema["not_a_csv.txt"] == MountError( + table_name="not_a_csv.txt", + error="ValueError", + error_text="Malformed CSV: header has 7 columns, rows have 0 columns", + ) + + schema = unwrap(schema)[0] + + # Add a nonexistent file to the schema with malformed params to check preview error reporting + schema["doesnt_exist"] = ( + [], + {"s3_object": "doesnt_exist"}, + ) + schema["exists_but_broken"] = ( + # Force a schema that doesn't work for this CSV + [TableColumn(1, "col_1", "date", False)], + {"s3_object": "some_prefix/fruits.csv"}, + ) + preview = source.preview(schema) - assert len(preview.keys()) == 3 + assert len(preview.keys()) == 5 assert len(preview["fruits.csv"]) == 4 assert len(preview["encoding-win-1252.csv"]) == 3 assert len(preview["rdu-weather-history.csv"]) == 10 + assert preview["doesnt_exist"] == MountError( + table_name="doesnt_exist", error="minio.error.S3Error", error_text=mock.ANY + ) + assert preview["exists_but_broken"] == MountError( + table_name="exists_but_broken", + error="psycopg2.errors.InvalidDatetimeFormat", + error_text='invalid input syntax for type date: "1"', + ) try: source.mount("temp_data") @@ -312,6 +339,8 @@ def test_csv_data_source_multiple(local_engine_empty): "from_url": ([], {"url": url}), "from_s3_rdu": ([], {"s3_object": "some_prefix/rdu-weather-history.csv"}), "from_s3_encoding": ([], {"s3_object": "some_prefix/encoding-win-1252.csv"}), + "from_url_broken": ([], {"url": "invalid_url"}), + "from_s3_broken": ([], {"s3_object": "invalid_object"}), } source = CSVDataSource( @@ -363,9 +392,20 @@ def test_csv_data_source_multiple(local_engine_empty): "autodetect_encoding": False, }, ), + "from_url_broken": MountError( + table_name="from_url_broken", + error="requests.exceptions.MissingSchema", + error_text="Invalid URL 'invalid_url': No schema supplied. Perhaps you meant http://invalid_url?", + ), + "from_s3_broken": MountError( + table_name="from_s3_broken", + error="minio.error.S3Error", + error_text=mock.ANY, + ), } # Mount the datasets with this introspected schema. + schema = unwrap(schema)[0] try: source.mount("temp_data", tables=schema) rows = local_engine_empty.run_sql("SELECT * FROM temp_data.from_s3_encoding")