Fix FDW previews when column names have percentage signs
Such column names cause issues during argument interpolation, as psycopg2 treats
the stray "%" as the start of a placeholder and raises an IndexError (not enough
arguments). Fix by escaping percentage signs with double percentage signs.

Note that this isn't fixed everywhere, just on the FDW data source preview code
path.

Also add a test using a CSV file with a percentage sign, as this is how this bug
was first discovered.
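
For context, a minimal sketch of the failure mode (not part of the commit; the DSN, table and column are hypothetical). Whenever arguments are passed, psycopg2 scans the whole rendered statement for %-placeholders, so a bare "%" inside a quoted identifier is misread:

import psycopg2
from psycopg2.sql import SQL, Identifier

conn = psycopg2.connect("dbname=splitgraph")  # hypothetical DSN
query = SQL("COMMENT ON COLUMN t.{} IS %s").format(Identifier("VAT (10%)"))

with conn.cursor() as cur:
    try:
        # The rendered SQL contains "VAT (10%)"; psycopg2 parses the stray
        # "%" as a placeholder and fails (the commit observed an IndexError:
        # not enough arguments).
        cur.execute(query, ("tax rate",))
    except (IndexError, ValueError):
        # Doubling the "%" inside the identifier avoids the clash:
        fixed = SQL("COMMENT ON COLUMN t.{} IS %s").format(Identifier("VAT (10%%)"))
        cur.execute(fixed, ("tax rate",))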
mildbyte committed Jan 24, 2022
1 parent 62984e7 commit b32da5a
Showing 3 changed files with 115 additions and 10 deletions.
18 changes: 13 additions & 5 deletions splitgraph/hooks/data_source/fdw.py
@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, cast
 
 import psycopg2
-from psycopg2.sql import SQL, Identifier
+from psycopg2.sql import SQL, Composed, Identifier
 
 from splitgraph.core.types import (
     Credentials,
@@ -398,34 +398,42 @@ def import_foreign_schema(
     return import_errors
 
 
+def _identifier(c: str) -> Identifier:
+    """
+    Create a psycopg2 composable Identifier escaping percentage signs in column names. These
+    cause issues when using argument interpolation, as psycopg2 treats them as arguments.
+    """
+    return Identifier(c.replace("%", "%%").replace("{", "{{").replace("}", "}}"))
+
+
 def create_foreign_table(
     schema: str,
     server: str,
     table_name: str,
     schema_spec: TableSchema,
     extra_options: Optional[Dict[str, str]] = None,
-):
+) -> Tuple[Composed, List[str]]:
     table_options = extra_options or {}
 
     query = SQL("CREATE FOREIGN TABLE {}.{} (").format(Identifier(schema), Identifier(table_name))
     query += SQL(",".join("{} %s " % col.pg_type for col in schema_spec)).format(
-        *(Identifier(col.name) for col in schema_spec)
+        *(_identifier(col.name) for col in schema_spec)
     )
     query += SQL(") SERVER {}").format(Identifier(server))
 
     args: List[str] = []
     if table_options:
         table_opts, table_optvals = zip(*table_options.items())
         query += SQL(" OPTIONS(")
-        query += SQL(",").join(Identifier(o) + SQL(" %s") for o in table_opts) + SQL(");")
+        query += SQL(",").join(_identifier(o) + SQL(" %s") for o in table_opts) + SQL(");")
         args.extend(table_optvals)
 
     for col in schema_spec:
         if col.comment:
             query += SQL("COMMENT ON COLUMN {}.{}.{} IS %s;").format(
                 Identifier(schema),
                 Identifier(table_name),
-                Identifier(col.name),
+                _identifier(col.name),
             )
             args.append(col.comment)
     return query, args
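
As a quick sanity check of the new helper (a standalone sketch, not from the commit; Identifier.string needs psycopg2 >= 2.8, and the brace doubling is presumably defensive against call sites that push the rendered SQL through str.format-style templating):

from psycopg2.sql import Identifier

def _identifier(c: str) -> Identifier:
    # Same escaping as the commit: "%" doubled for psycopg2's placeholder
    # scanner, "{" / "}" doubled for str.format-style templating.
    return Identifier(c.replace("%", "%%").replace("{", "{{").replace("}", "}}"))

assert _identifier("VAT (10%)").string == "VAT (10%%)"
assert _identifier("a{b}c").string == "a{{b}}c"

When the composed query is later executed with arguments, psycopg2 collapses "%%" back to a literal "%", so the server sees the original column name.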
@@ -0,0 +1,2 @@
+"Id","Submit time","Profile","Status","Source currency","Amount paid by","Fee","Amount converted","Excess refund","Target currency","Converted and sent to","Exchange rate","Exchange Rate Date","Payout time","Name","Account details","Reference","VAT (10%)"
+"123456789","2022/02/19 18:52:18","business","transferred","USD","15000.20","75.00","15000.00","0.0","GBP","12000.0","0.7336","2022/02/21 19:00:20","2022/02/21 19:00:36","Some Company","","",
105 changes: 100 additions & 5 deletions test/splitgraph/ingestion/test_csv.py
@@ -88,7 +88,7 @@ def test_csv_introspection_s3():
         restricts=[],
     )
 
-    assert len(schema) == 3
+    assert len(schema) == 4
     schema = sorted(schema, key=lambda s: s["table_name"])
 
     assert schema[0] == {
@@ -116,8 +116,10 @@ def test_csv_introspection_s3():
         "options": mock.ANY,
     }
     assert load_options(schema[1]["options"]) == _s3_fruits_opts
-    assert schema[2]["table_name"] == "rdu-weather-history.csv"
-    assert schema[2]["columns"][0] == {"column_name": "date", "type_name": "date"}
+    assert schema[2]["table_name"] == "percentage_sign.csv"
+    assert schema[2]["columns"][0] == {"column_name": "Id", "type_name": "integer"}
+    assert schema[3]["table_name"] == "rdu-weather-history.csv"
+    assert schema[3]["columns"][0] == {"column_name": "date", "type_name": "date"}
 
 
 def test_csv_introspection_http():
@@ -251,7 +253,7 @@ def test_csv_data_source_s3(local_engine_empty):
 
     schema = source.introspect()
 
-    assert len(schema.keys()) == 4
+    assert len(schema.keys()) == 5
     assert schema["fruits.csv"] == (
         [
             TableColumn(ordinal=1, name="fruit_id", pg_type="integer", is_pk=False, comment=None),
@@ -302,6 +304,98 @@ def test_csv_data_source_s3(local_engine_empty):
         },
     )
     assert len(schema["rdu-weather-history.csv"][0]) == 28
+    assert schema["percentage_sign.csv"] == (
+        [
+            TableColumn(ordinal=1, name="Id", pg_type="integer", is_pk=False, comment=None),
+            TableColumn(
+                ordinal=2,
+                name="Submit time",
+                pg_type="character varying",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=3, name="Profile", pg_type="character varying", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=4, name="Status", pg_type="character varying", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=5,
+                name="Source currency",
+                pg_type="character varying",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=6, name="Amount paid by", pg_type="numeric", is_pk=False, comment=None
+            ),
+            TableColumn(ordinal=7, name="Fee", pg_type="numeric", is_pk=False, comment=None),
+            TableColumn(
+                ordinal=8, name="Amount converted", pg_type="numeric", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=9, name="Excess refund", pg_type="numeric", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=10,
+                name="Target currency",
+                pg_type="character varying",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=11,
+                name="Converted and sent to",
+                pg_type="numeric",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=12, name="Exchange rate", pg_type="numeric", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=13,
+                name="Exchange Rate Date",
+                pg_type="character varying",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=14,
+                name="Payout time",
+                pg_type="character varying",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=15, name="Name", pg_type="character varying", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=16,
+                name="Account details",
+                pg_type="character varying",
+                is_pk=False,
+                comment=None,
+            ),
+            TableColumn(
+                ordinal=17, name="Reference", pg_type="character varying", is_pk=False, comment=None
+            ),
+            TableColumn(
+                ordinal=18, name="VAT (10%)", pg_type="character varying", is_pk=False, comment=None
+            ),
+        ],
+        {
+            "autodetect_dialect": False,
+            "autodetect_encoding": False,
+            "autodetect_header": False,
+            "delimiter": ",",
+            "encoding": "utf-8",
+            "header": True,
+            "quotechar": '"',
+            "s3_object": "some_prefix/percentage_sign.csv",
+        },
+    )
+
     assert schema["not_a_csv.txt"] == MountError(
         table_name="not_a_csv.txt",
@@ -323,10 +417,11 @@ def test_csv_data_source_s3(local_engine_empty):
     )
 
     preview = source.preview(schema)
-    assert len(preview.keys()) == 5
+    assert len(preview.keys()) == 6
     assert len(preview["fruits.csv"]) == 4
     assert len(preview["encoding-win-1252.csv"]) == 3
     assert len(preview["rdu-weather-history.csv"]) == 10
+    assert len(preview["percentage_sign.csv"]) == 1
     assert preview["doesnt_exist"] == MountError.construct(
         table_name="doesnt_exist", error="minio.error.S3Error", error_text=mock.ANY
     )
