Skip to content

Commit

Permalink
Merge pull request #616 from splitgraph/bugfix/get-primary-key
Browse files Browse the repository at this point in the history
Fix potential SQL injection in `get_primary_key`.
  • Loading branch information
mildbyte committed Jan 21, 2022
2 parents 190fe80 + 379afc3 commit bfb1a23
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 26 deletions.
16 changes: 7 additions & 9 deletions splitgraph/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,13 @@ def get_full_table_schema(self, schema: str, table_name: str) -> "TableSchema":
assert schema != "pg_temp"

results = self.run_sql(
SQL(
"SELECT c.attnum, c.attname, "
"pg_catalog.format_type(c.atttypid, c.atttypmod), "
"col_description('{}.{}'::regclass, c.attnum) "
"FROM pg_attribute c JOIN pg_class t ON c.attrelid = t.oid "
"JOIN pg_namespace n ON t.relnamespace = n.oid "
"WHERE n.nspname = %s AND t.relname = %s AND NOT c.attisdropped "
"AND c.attnum >= 0 ORDER BY c.attnum "
).format(Identifier(schema), Identifier(table_name)),
"SELECT c.attnum, c.attname, "
"pg_catalog.format_type(c.atttypid, c.atttypmod), pgd.description "
"FROM pg_attribute c JOIN pg_class t ON c.attrelid = t.oid "
"JOIN pg_namespace n ON t.relnamespace = n.oid "
"LEFT JOIN pg_description pgd ON pgd.objoid = t.oid AND pgd.objsubid = c.attnum "
"WHERE n.nspname = %s AND t.relname = %s AND NOT c.attisdropped "
"AND c.attnum >= 0 ORDER BY c.attnum ",
(schema, table_name),
)

Expand Down
28 changes: 18 additions & 10 deletions splitgraph/engine/postgres/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ def get_conn_str(conn_params: Dict[str, Optional[str]]) -> str:
return f"postgresql://{username}:{password}@{server}:{port}/{dbname}"


def _quote_ident(val: str) -> str:
return '"%s"' % val.replace('"', '""')


class PsycopgEngine(SQLEngine):
"""Postgres SQL engine backed by a Psycopg connection."""

Expand Down Expand Up @@ -588,12 +592,16 @@ def get_primary_keys(self, schema: str, table: str) -> List[Tuple[str, str]]:
return cast(
List[Tuple[str, str]],
self.run_sql(
SQL(
"""SELECT a.attname, format_type(a.atttypid, a.atttypmod)
FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid
AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = '{}.{}'::regclass AND i.indisprimary"""
).format(Identifier(schema), Identifier(table)),
"""SELECT c.column_name, c.data_type
FROM information_schema.table_constraints tc
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema
AND tc.table_name = c.table_name AND ccu.column_name = c.column_name
WHERE constraint_type = 'PRIMARY KEY'
AND tc.constraint_schema = %s
AND tc.table_name = %s
""",
(schema, table),
return_shape=ResultShape.MANY_MANY,
),
)
Expand Down Expand Up @@ -828,11 +836,11 @@ def track_tables(self, tables: List[Tuple[str, str]]) -> None:
"""Install the audit trigger on the required tables"""
self.run_sql(
SQL(";").join(
SQL("SELECT {}.audit_table('{}.{}')").format(
Identifier(_AUDIT_SCHEMA), Identifier(s), Identifier(t)
itertools.repeat(
SQL("SELECT {}.audit_table(%s)").format(Identifier(_AUDIT_SCHEMA)), len(tables)
)
for s, t in tables
)
),
["{}.{}".format(_quote_ident(s), _quote_ident(t)) for s, t in tables],
)

def untrack_tables(self, tables: List[Tuple[str, str]]) -> None:
Expand Down
14 changes: 7 additions & 7 deletions test/splitgraph/commands/test_commit_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,15 +1068,15 @@ def test_create_object_out_of_band(local_engine_empty):
)


def test_unicode_columns(local_engine_empty):
def test_unicode_columns_and_quotes_in_table_names(local_engine_empty):
OUTPUT.init()
OUTPUT.run_sql("CREATE TABLE таблица (key INTEGER PRIMARY KEY, столбец VARCHAR)")
OUTPUT.run_sql("COMMENT ON COLUMN таблица.столбец IS 'комментарий';")
OUTPUT.run_sql("INSERT INTO таблица (key, столбец) VALUES (1, 'one'), (2, 'two')")
OUTPUT.run_sql('CREATE TABLE "таблица\'" (key INTEGER PRIMARY KEY, столбец VARCHAR)')
OUTPUT.run_sql("COMMENT ON COLUMN \"таблица'\".столбец IS 'комментарий';")
OUTPUT.run_sql("INSERT INTO \"таблица'\" (key, столбец) VALUES (1, 'one'), (2, 'two')")

image = OUTPUT.commit()

assert image.get_table("таблица").table_schema == [
assert image.get_table("таблица'").table_schema == [
TableColumn(ordinal=1, name="key", pg_type="integer", is_pk=True, comment=None),
TableColumn(
ordinal=2,
Expand All @@ -1087,10 +1087,10 @@ def test_unicode_columns(local_engine_empty):
),
]
image.checkout()
assert OUTPUT.run_sql("SELECT * FROM таблица WHERE столбец = 'two'") == [(2, "two")]
assert OUTPUT.run_sql("SELECT * FROM \"таблица'\" WHERE столбец = 'two'") == [(2, "two")]

image.checkout(layered=True)
assert OUTPUT.run_sql("SELECT * FROM таблица WHERE столбец = 'one'") == [(1, "one")]
assert OUTPUT.run_sql("SELECT * FROM \"таблица'\" WHERE столбец = 'one'") == [(1, "one")]


def test_commit_diff_views(pg_repo_local):
Expand Down

0 comments on commit bfb1a23

Please sign in to comment.