Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 88 additions & 13 deletions splitgraph/ingestion/csv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from psycopg2.sql import SQL, Identifier
from splitgraph.core.types import Credentials, MountError, TableInfo
from splitgraph.core.types import Credentials, MountError, Params, TableInfo
from splitgraph.hooks.data_source.fdw import (
ForeignDataWrapperDataSource,
import_foreign_schema,
Expand Down Expand Up @@ -78,16 +78,50 @@ class CSVDataSource(ForeignDataWrapperDataSource):
params_schema: Dict[str, Any] = {
"type": "object",
"properties": {
"url": {"type": "string", "description": "HTTP URL to the CSV file"},
"s3_endpoint": {
"type": "string",
"description": "S3 endpoint (including port if required)",
"connection": {
"type": "object",
"oneOf": [
{
"type": "object",
"required": ["connection_type", "url"],
"properties": {
"connection_type": {"type": "string", "const": "http"},
"url": {"type": "string", "description": "HTTP URL to the CSV file"},
},
},
{
"type": "object",
"required": ["connection_type", "s3_endpoint", "s3_bucket"],
"properties": {
"connection_type": {"type": "string", "const": "s3"},
"s3_endpoint": {
"type": "string",
"description": "S3 endpoint (including port if required)",
},
"s3_region": {
"type": "string",
"description": "Region of the S3 bucket",
},
"s3_secure": {
"type": "boolean",
"description": "Whether to use HTTPS for S3 access",
},
"s3_bucket": {
"type": "string",
"description": "Bucket the object is in",
},
"s3_object": {
"type": "string",
"description": "Limit the import to a single object",
},
"s3_object_prefix": {
"type": "string",
"description": "Prefix for object in S3 bucket",
},
},
},
],
},
"s3_region": {"type": "string", "description": "Region of the S3 bucket"},
"s3_secure": {"type": "boolean", "description": "Whether to use HTTPS for S3 access"},
"s3_bucket": {"type": "string", "description": "Bucket the object is in"},
"s3_object": {"type": "string", "description": "Limit the import to a single object"},
"s3_object_prefix": {"type": "string", "description": "Prefix for object in S3 bucket"},
"autodetect_header": {
"type": "boolean",
"description": "Detect whether the CSV file has a header automatically",
Expand Down Expand Up @@ -123,7 +157,6 @@ class CSVDataSource(ForeignDataWrapperDataSource):
},
"quotechar": {"type": "string", "description": "Character used to quote fields"},
},
"oneOf": [{"required": ["url"]}, {"required": ["s3_endpoint", "s3_bucket"]}],
}

table_params_schema: Dict[str, Any] = {
Expand Down Expand Up @@ -167,6 +200,18 @@ class CSVDataSource(ForeignDataWrapperDataSource):
build_commandline_help(credentials_schema) + "\n" + build_commandline_help(params_schema)
)

def __init__(
    self,
    engine: "PsycopgEngine",
    credentials: Credentials,
    params: Params,
    tables: Optional[TableInfo] = None,
):
    """Initialize the CSV data source, accepting legacy parameter formats.

    TODO: this is a hack to automatically accept both old and new versions
    of CSV params; a more robust data source config migration system may
    be needed.
    """
    migrated_params = CSVDataSource.migrate_params(params)
    super().__init__(engine, credentials, migrated_params, tables)

def get_fdw_name(self):
    """Return the name of the foreign data wrapper backing this source."""
    return "multicorn"

Expand All @@ -187,6 +232,30 @@ def from_commandline(cls, engine, commandline_kwargs) -> "CSVDataSource":
credentials[k] = params[k]
return cls(engine, credentials, params)

@classmethod
def migrate_params(cls, params: Params) -> Params:
params = deepcopy(params)
if "url" in params:
params["connection"] = {"connection_type": "http", "url": params["url"]}
del params["url"]
else:
connection = {"connection_type": "s3"}
for key in [
"s3_endpoint",
"s3_region",
"s3_secure",
"s3_bucket",
"s3_object",
"s3_object_prefix",
]:
try:
connection[key] = params.pop(key)
except KeyError:
pass

params["connection"] = connection
return params

def get_table_options(
self, table_name: str, tables: Optional[TableInfo] = None
) -> Dict[str, str]:
Expand Down Expand Up @@ -258,8 +327,13 @@ def get_server_options(self):
"wrapper": "splitgraph.ingestion.csv.fdw.CSVForeignDataWrapper"
}
for k in self.params_schema["properties"].keys():
# Flatten the options and extract connection parameters
if k in self.params:
options[k] = str(self.params[k])
if k != "connection":
options[k] = str(self.params[k])
else:
options.update(self.params[k])

for k in self.credentials_schema["properties"].keys():
if k in self.credentials:
options[k] = str(self.credentials[k])
Expand Down Expand Up @@ -303,9 +377,10 @@ def get_raw_url(

# Merge the table options to take care of overrides and use them to get URLs
# for each table.
server_options = self.get_server_options()
for table in tables:
table_options = super().get_table_options(table, tables)
full_options = {**self.credentials, **self.params, **table_options}
full_options = {**server_options, **table_options}
result[table] = self._get_url(full_options)

return result
74 changes: 62 additions & 12 deletions splitgraph/ingestion/snowflake/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import base64
import json
import urllib.parse
from typing import Any, Dict, Optional
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional

from splitgraph.core.types import TableInfo
from splitgraph.core.types import Credentials, Params, TableInfo
from splitgraph.hooks.data_source.fdw import ForeignDataWrapperDataSource
from splitgraph.ingestion.common import build_commandline_help

if TYPE_CHECKING:
from splitgraph.engine.postgres.engine import PsycopgEngine


def _encode_private_key(privkey: str):
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -34,18 +38,36 @@ class SnowflakeDataSource(ForeignDataWrapperDataSource):
"type": "object",
"properties": {
"username": {"type": "string", "description": "Username"},
"password": {"type": "string", "description": "Password"},
"secret": {
"type": "object",
"oneOf": [
{
"type": "object",
"required": ["secret_type", "password"],
"properties": {
"secret_type": {"type": "string", "const": "password"},
"password": {"type": "string", "description": "Password"},
},
},
{
"type": "object",
"required": ["secret_type", "private_key"],
"properties": {
"secret_type": {"type": "string", "const": "private_key"},
"private_key": {
"type": "string",
"description": "Private key in PEM format",
},
},
},
],
},
"account": {
"type": "string",
"description": "Account Locator, e.g. xy12345.us-east-2.aws. For more information, see https://docs.snowflake.com/en/user-guide/connecting.html",
},
"private_key": {
"type": "string",
"description": "Private key in PEM format",
},
},
"required": ["username", "account"],
"oneOf": [{"required": ["password"]}, {"required": ["private_key"]}],
}

params_schema = {
Expand Down Expand Up @@ -128,6 +150,18 @@ class SnowflakeDataSource(ForeignDataWrapperDataSource):
+ "The schema parameter is required when subquery isn't used."
)

def __init__(
    self,
    engine: "PsycopgEngine",
    credentials: Credentials,
    params: Params,
    tables: Optional[TableInfo] = None,
):
    """Initialize the Snowflake data source.

    Legacy flat credentials ("password"/"private_key" at the top level)
    are migrated into the nested "secret" format before construction.
    """
    # TODO this is a hack to automatically accept both old and new versions
    # of Snowflake credentials (NOTE(review): earlier comment said "CSV
    # params" — a copy-paste from CSVDataSource; this migrates credentials).
    # We might need a more robust data source config migration system.
    credentials = SnowflakeDataSource.migrate_credentials(credentials)
    super().__init__(engine, credentials, params, tables)

def get_fdw_name(self):
    """Name of the foreign data wrapper extension used by this source."""
    fdw_name = "multicorn"
    return fdw_name

Expand Down Expand Up @@ -165,19 +199,35 @@ def get_server_options(self):
if "batch_size" in self.params:
options["batch_size"] = str(self.params["batch_size"])

if "private_key" in self.credentials:
if self.credentials["secret"]["secret_type"] == "private_key":
options["connect_args"] = json.dumps(
{"private_key": _encode_private_key(self.credentials["private_key"])}
{"private_key": _encode_private_key(self.credentials["secret"]["private_key"])}
)

return options

@classmethod
def migrate_credentials(cls, credentials: Credentials) -> Credentials:
credentials = deepcopy(credentials)
if "private_key" in credentials:
credentials["secret"] = {
"secret_type": "private_key",
"private_key": credentials.pop("private_key"),
}
elif "password" in credentials:
credentials["secret"] = {
"secret_type": "password",
"password": credentials.pop("password"),
}

return credentials

def _build_db_url(self) -> str:
"""Construct the SQLAlchemy Snowflake db_url"""

uname = self.credentials["username"]
if "password" in self.credentials:
uname += f":{self.credentials['password']}"
if self.credentials["secret"]["secret_type"] == "password":
uname += f":{self.credentials['secret']['password']}"

db_url = f"snowflake://{uname}@{self.credentials['account']}"
if "database" in self.params:
Expand Down
11 changes: 6 additions & 5 deletions splitgraph/ingestion/socrata/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
class SocrataDataSource(ForeignDataWrapperDataSource):
credentials_schema = {
"type": "object",
"properties": {
"app_token": {"type": ["string", "null"], "description": "Socrata app token, optional"}
},
"properties": {"app_token": {"type": "string", "description": "Socrata app token"}},
}

params_schema = {
Expand All @@ -30,7 +28,10 @@ class SocrataDataSource(ForeignDataWrapperDataSource):
},
"batch_size": {
"type": "integer",
"description": "Amount of rows to fetch from Socrata per request (limit parameter). Maximum 50000.",
"description": "Amount of rows to fetch from Socrata per request (limit parameter)",
"minimum": 1,
"default": 1000,
"maximum": 50000,
},
},
"required": ["domain"],
Expand Down Expand Up @@ -69,7 +70,7 @@ def from_commandline(cls, engine, commandline_kwargs) -> "SocrataDataSource":
if isinstance(tables, dict) and isinstance(next(iter(tables.values())), str):
tables = {k: ([], {"socrata_id": v}) for k, v in tables.items()}

credentials = Credentials({"app_token": params.pop("app_token", None)})
credentials = Credentials({})
return cls(engine, credentials, params, tables)

def get_server_options(self):
Expand Down
29 changes: 28 additions & 1 deletion test/splitgraph/ingestion/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from unittest import mock

import pytest
from splitgraph.core.types import MountError, TableColumn, unwrap
from splitgraph.core.types import MountError, Params, TableColumn, unwrap
from splitgraph.engine import ResultShape
from splitgraph.hooks.s3_server import MINIO
from splitgraph.ingestion.common import generate_column_names
Expand Down Expand Up @@ -37,6 +37,33 @@
}


def test_csv_param_migration():
    """Old-style flat CSV params are upgraded to the nested connection format."""
    old_s3_params = Params(
        {
            "s3_endpoint": "objectstorage:9000",
            "s3_secure": False,
            "s3_bucket": "test_csv",
            "delimiter": ",",
        }
    )
    expected_s3 = Params(
        {
            "delimiter": ",",
            "connection": {
                "connection_type": "s3",
                "s3_endpoint": "objectstorage:9000",
                "s3_secure": False,
                "s3_bucket": "test_csv",
            },
        }
    )
    assert CSVDataSource.migrate_params(old_s3_params) == expected_s3

    old_http_params = Params({"url": "some-url"})
    expected_http = Params({"connection": {"connection_type": "http", "url": "some-url"}})
    assert CSVDataSource.migrate_params(old_http_params) == expected_http


def test_csv_introspection_s3():
fdw_options = {
"s3_endpoint": "objectstorage:9000",
Expand Down