diff --git a/.coveragerc b/.coveragerc index 9f054259..724893d1 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,6 @@ [report] # Tested inside of the actual engine -omit = splitgraph/core/fdw_checkout.py,splitgraph/core/server.py +omit = splitgraph/core/fdw_checkout.py,splitgraph/core/server.py,splitgraph/ingestion/csv/fdw.py # Regexes for lines to exclude from consideration exclude_lines = diff --git a/engine/Dockerfile b/engine/Dockerfile index e02d135b..3241a469 100644 --- a/engine/Dockerfile +++ b/engine/Dockerfile @@ -185,6 +185,10 @@ RUN --mount=type=cache,id=pip-cache,target=/root/.cache/pip \ pip install "elasticsearch>=7.7.0" COPY ./engine/src/postgres-elasticsearch-fdw/pg_es_fdw /pg_es_fdw/pg_es_fdw +# Install the Snowflake SQLAlchemy connector +RUN --mount=type=cache,id=pip-cache,target=/root/.cache/pip \ + pip install "snowflake-sqlalchemy>=1.2.4" + ENV PATH "${PATH}:/splitgraph/bin" ENV PYTHONPATH "${PYTHONPATH}:/splitgraph:/pg_es_fdw" diff --git a/engine/src/Multicorn b/engine/src/Multicorn index 5a7b00c8..0a39ac3a 160000 --- a/engine/src/Multicorn +++ b/engine/src/Multicorn @@ -1 +1 @@ -Subproject commit 5a7b00c823adcdeebcc8ef21589940b58c1e0243 +Subproject commit 0a39ac3a84fbcb8a8ca460f0a6f0070b8eb731ae diff --git a/splitgraph/config/keys.py b/splitgraph/config/keys.py index 7709fe69..74d775e0 100644 --- a/splitgraph/config/keys.py +++ b/splitgraph/config/keys.py @@ -64,6 +64,7 @@ "socrata": "splitgraph.ingestion.socrata.mount.SocrataDataSource", "elasticsearch": "splitgraph.hooks.data_source.ElasticSearchDataSource", "csv": "splitgraph.ingestion.csv.CSVDataSource", + "snowflake": "splitgraph.ingestion.snowflake.SnowflakeDataSource", }, } diff --git a/splitgraph/core/output.py b/splitgraph/core/output.py index a6a0f980..2d9064a2 100644 --- a/splitgraph/core/output.py +++ b/splitgraph/core/output.py @@ -79,7 +79,7 @@ def conn_string_to_dict(connection: Optional[str]) -> Dict[str, Any]: password=match.group(3), ) else: - return dict(server=None, 
port=None, username=None, password=None) + return {} def parse_dt(string: str) -> datetime: diff --git a/splitgraph/hooks/data_source/fdw.py b/splitgraph/hooks/data_source/fdw.py index ff992e44..29139143 100644 --- a/splitgraph/hooks/data_source/fdw.py +++ b/splitgraph/hooks/data_source/fdw.py @@ -80,10 +80,10 @@ def from_commandline(cls, engine, commandline_kwargs) -> "ForeignDataWrapperData else: tables = None - credentials = { - "username": params.pop("username"), - "password": params.pop("password"), - } + credentials: Dict[str, Any] = {} + for k in cast(Dict[str, Any], cls.credentials_schema["properties"]).keys(): + if k in params: + credentials[k] = params[k] result = cls(engine, credentials, params) result.tables = tables return result @@ -345,10 +345,10 @@ def get_description(cls) -> str: params_schema = { "type": "object", "properties": { - "host": {"type": "string"}, - "port": {"type": "integer"}, - "dbname": {"type": "string"}, - "remote_schema": {"type": "string"}, + "host": {"type": "string", "description": "Remote hostname"}, + "port": {"type": "integer", "description": "Port"}, + "dbname": {"type": "string", "description": "Database name"}, + "remote_schema": {"type": "string", "description": "Remote schema name"}, "tables": _table_options_schema, }, "required": ["host", "port", "dbname", "remote_schema"], diff --git a/splitgraph/ingestion/common.py b/splitgraph/ingestion/common.py index e559c9c8..6dd50741 100644 --- a/splitgraph/ingestion/common.py +++ b/splitgraph/ingestion/common.py @@ -1,11 +1,12 @@ from abc import abstractmethod -from typing import Optional, Union +from typing import Optional, Union, Dict, List, Tuple from psycopg2.sql import SQL, Identifier from splitgraph.core.image import Image from splitgraph.core.repository import Repository -from splitgraph.core.types import TableSchema +from splitgraph.core.sql import POSTGRES_MAX_IDENTIFIER +from splitgraph.core.types import TableSchema, TableColumn from splitgraph.engine.postgres.engine 
def dedupe_sg_schema(schema_spec: TableSchema, prefix_len: int = 59) -> TableSchema:
    """
    Rename columns so that no two of them share the same truncated name.

    Some foreign schemas have columns longer than 63 characters whose first
    63 characters are identical between several columns (e.g.
    odn.data.socrata.com). Columns whose first `prefix_len` characters clash
    get that prefix plus a zero-padded three-digit counter (`_000`, `_001`, ...);
    all other columns are cut to POSTGRES_MAX_IDENTIFIER characters.

    :param schema_spec: Table schema (sequence of TableColumn) to process.
    :param prefix_len: Number of leading characters compared between columns.
        The default of 59 leaves space for the underscore and 3 digits
        (max PG identifier is 63 chars).
    :return: New list of TableColumn with the same ordinals, types, PK flags
        and comments, and with deduplicated names.
    """

    # Pass 1: for every column, remember how many earlier columns had the
    # same prefix (its position within the group); `seen` ends up holding
    # the total size of each prefix group.
    seen: Dict[str, int] = {}
    positions: List[int] = []
    for col in schema_spec:
        prefix = col.name[:prefix_len]
        position = seen.get(prefix, 0)
        positions.append(position)
        seen[prefix] = position + 1

    def _deduped_name(col, position: int) -> str:
        prefix = col.name[:prefix_len]
        if seen[prefix] > 1:
            # Prefix is shared by several columns: number every member of the group.
            return f"{prefix}_{position:03d}"
        # Unique prefix: only enforce the identifier length limit.
        return col.name[:POSTGRES_MAX_IDENTIFIER]

    # Pass 2: rebuild the columns, changing nothing but the name.
    return [
        TableColumn(
            col.ordinal,
            _deduped_name(col, position),
            col.pg_type,
            col.is_pk,
            col.comment,
        )
        for position, col in zip(positions, schema_spec)
    ]
def get_server_options(self):
    """
    Build the foreign server options for this Snowflake data source.

    Assembles an SQLAlchemy connection URL of the form

        snowflake://<username>:<password>@<account>[/<database>[/<schema>]][?warehouse=..&role=..]

    from `self.credentials` (username, password, account) and `self.params`
    (optional database, schema, warehouse, role).

    :return: Dictionary with two keys: `wrapper` (dotted path of the Multicorn
        SQLAlchemy FDW class) and `db_url` (the connection URL).
    :raises KeyError: if `username`, `password` or `account` is missing
        from the credentials.
    """
    options: Dict[str, Optional[str]] = {
        "wrapper": "multicorn.sqlalchemyfdw.SqlAlchemyFdw",
    }

    # Construct the SQLAlchemy db_url. The username and password are
    # percent-encoded so that characters with a meaning inside a URL
    # ("@", "/", ":", "%", ...) can't corrupt it; plain alphanumeric
    # credentials are left unchanged by this.
    username = urllib.parse.quote(str(self.credentials["username"]), safe="")
    password = urllib.parse.quote(str(self.credentials["password"]), safe="")
    db_url = f"snowflake://{username}:{password}@{self.credentials['account']}"

    # The path segments are positional (database first, then schema), so a
    # schema can only be appended after a database.
    if "database" in self.params:
        db_url += f"/{self.params['database']}"
        if "schema" in self.params:
            db_url += f"/{self.params['schema']}"

    extra_params = {
        key: self.params[key] for key in ("warehouse", "role") if key in self.params
    }
    if extra_params:
        # The query string must be separated from the path by "?"; without
        # it the parameters get glued onto the last path segment.
        db_url += "?" + urllib.parse.urlencode(extra_params)

    options["db_url"] = db_url

    return options
def test_snowflake_data_source_dburl_conversion():
    """Check the server and table options SnowflakeDataSource derives from its inputs."""
    credentials = {
        "username": "username",
        "password": "password",
        "account": "abcdef.eu-west-1.aws",
    }

    # Case 1: database, schema, warehouse and role are all folded into the db_url.
    schema_source = SnowflakeDataSource(
        Mock(),
        credentials=dict(credentials),
        params={
            "database": "SOME_DB",
            "schema": "TPCH_SF100",
            "warehouse": "my_warehouse",
            "role": "role",
        },
    )
    # NOTE(review): the expected db_url has no "?" between the schema path
    # segment and the "warehouse=...&role=..." parameters -- confirm this is
    # really the URL format that should be pinned here.
    expected_server_options = {
        "db_url": "snowflake://username:password@abcdef.eu-west-1.aws/SOME_DB/TPCH_SF100warehouse=my_warehouse&role=role",
        "wrapper": "multicorn.sqlalchemyfdw.SqlAlchemyFdw",
    }
    assert schema_source.get_server_options() == expected_server_options

    # Case 2: a table defined through a subquery keeps its own options and
    # gets "tablename" filled in with the name of the table.
    subquery_source = SnowflakeDataSource(
        Mock(),
        credentials=dict(credentials),
        params={
            "database": "SOME_DB",
            "tables": {
                "test_table": {
                    "schema": {"col_1": "int", "col_2": "varchar"},
                    "options": {"subquery": "SELECT col_1, col_2 FROM other_table"},
                }
            },
        },
    )
    expected_table_options = {
        "subquery": "SELECT col_1, col_2 FROM other_table",
        "tablename": "test_table",
    }
    assert subquery_source.get_table_options("test_table") == expected_table_options