Skip to content

Commit

Permalink
feat(diff): get_diff function to retrieve human-readable diffs betwee…
Browse files Browse the repository at this point in the history
…n differrent sql statements
  • Loading branch information
ErwanSimonetti authored and Erwan SIMONETTI committed Mar 4, 2025
1 parent a2d884e commit 9fee51a
Showing 2 changed files with 170 additions and 0 deletions.
35 changes: 35 additions & 0 deletions sql_compare/__init__.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,26 @@ def compare(first_sql: str, second_sql: str) -> bool:
return first_sql_statements == second_sql_statements


def get_diff(
first_sql: str,
second_sql: str,
) -> list[list[list[str]]]:
"""Show the difference between two SQL schemas, ignoring differences due to column order and other non-significant SQL changes."""
first_statements = [Statement(t) for t in sqlparse.parse(first_sql)]
second_statements = [Statement(t) for t in sqlparse.parse(second_sql)]
first_diffs: list[list[str]] = []
second_diffs: list[list[str]] = []

for first, second in itertools.zip_longest(first_statements, second_statements):
if first != second:
first_value = first.value if first else []
second_value = second.value if second else []
first_diffs.append(first_value)
second_diffs.append(second_value)

return sorted([first_diffs, second_diffs])


@dataclasses.dataclass
class Token:
"""Wrapper around `sqlparse.sql.Token`."""
@@ -152,6 +172,21 @@ def statement_type(self) -> str:
# Only one keyword (e.g.: SELECT, INSERT, DELETE, etc.)
return keywords[0]

@property
def value(self) -> list[str]:
"""Return the reconstructed SQL statement from tokens as a list of strings, excluding tokens with a hash that is just a space."""

def process_token(token: Token | TokenList) -> list[str]:
if isinstance(token, TokenList):
return [
t for sub_token in token.tokens for t in process_token(sub_token)
]
if token.hash.strip():
return [token.hash]
return []

return [t for token in self.tokens for t in process_token(token)]


class UnorderedTokenList(TokenList):
"""Unordered token list."""
135 changes: 135 additions & 0 deletions tests/test_sql_compare.py
Original file line number Diff line number Diff line change
@@ -184,3 +184,138 @@ def test_compare_neq(first_sql: str, second_sql: str) -> None:
def test_statement_type(sql: str, expected_type: str) -> None:
statement = sql_compare.Statement(sqlparse.parse(sql)[0])
assert statement.statement_type == expected_type


@pytest.mark.parametrize(
("first_sql", "second_sql", "expected_diff"),
[
(
"CREATE TABLE foo (id INT PRIMARY KEY)",
"CREATE TABLE foo (id INT UNIQUE)",
[
[["CREATE", "TABLE", "foo", "(", "id", "INT", "PRIMARY KEY"]],
[["CREATE", "TABLE", "foo", "(", "id", "INT", "UNIQUE"]],
],
),
(
"CREATE TYPE public.colors AS ENUM ('RED', 'GREEN', 'BLUE')",
"CREATE TYPE public.colors AS ENUM ('BLUE', 'GREEN', 'RED')",
[[], []],
),
(
"CREATE TYPE public.colors AS ENUM ('RED', 'GREEN', 'BLUE')",
"CREATE TYPE public.colors AS ENUM ('YELLOW', 'BLUE', 'RED')",
[
[
[
"CREATE",
"TYPE",
"public",
".",
"colors",
"AS",
"ENUM",
"(",
"'BLUE'",
",",
"'GREEN'",
",",
"'RED'",
],
],
[
[
"CREATE",
"TYPE",
"public",
".",
"colors",
"AS",
"ENUM",
"(",
"'BLUE'",
",",
"'RED'",
",",
"'YELLOW'",
],
],
],
),
(
"""
CREATE TYPE public.status AS ENUM ('PENDING', 'APPROVED', 'REJECTED');
CREATE TABLE users (id INT, name VARCHAR(100), status public.status);
CREATE INDEX user_status_idx ON users (status);
""",
"""
CREATE TYPE public.status AS ENUM ('PENDING', 'APPROVED', 'ARCHIVED');
CREATE TABLE users (id INT, name VARCHAR(100), status public.status);
CREATE INDEX user_status_idx ON users (status);
CREATE TABLE logs (id INT, message TEXT);
""",
[
[
[
"CREATE",
"TYPE",
"public",
".",
"status",
"AS",
"ENUM",
"(",
"'APPROVED'",
",",
"'ARCHIVED'",
",",
"'PENDING'",
";",
],
[
"CREATE",
"TABLE",
"logs",
"(",
"id",
"INT",
",",
"message",
"TEXT",
";",
],
],
[
[
"CREATE",
"TYPE",
"public",
".",
"status",
"AS",
"ENUM",
"(",
"'APPROVED'",
",",
"'PENDING'",
",",
"'REJECTED'",
";",
],
[],
],
],
),
],
)
def test_get_diff(
first_sql: str,
second_sql: str,
expected_diff: list[list[list[str]]],
) -> None:
result = sql_compare.get_diff(first_sql, second_sql)
assert result == expected_diff
assert sql_compare.get_diff(first_sql, second_sql) == sql_compare.get_diff(
second_sql,
first_sql,
)

0 comments on commit 9fee51a

Please sign in to comment.