diff --git a/src/sqlite3_to_mysql/transporter.py b/src/sqlite3_to_mysql/transporter.py index 290c4a5..3c80029 100644 --- a/src/sqlite3_to_mysql/transporter.py +++ b/src/sqlite3_to_mysql/transporter.py @@ -297,6 +297,25 @@ def _sqlite_quote_ident(name: str) -> str: """Return a SQLite identifier with internal quotes escaped.""" return str(name).replace('"', '""') + def _get_table_info(self, table_name: str) -> t.List[t.Dict[str, t.Any]]: + """Fetch SQLite PRAGMA table information for a table.""" + quoted_table_name: str = self._sqlite_quote_ident(table_name) + pragma: str = "table_xinfo" if self._sqlite_table_xinfo_support else "table_info" + self._sqlite_cur.execute(f'PRAGMA {pragma}("{quoted_table_name}")') + return [dict(row) for row in self._sqlite_cur.fetchall()] + + def _get_table_primary_key_columns(self, table_name: str) -> t.List[str]: + """Return visible primary key columns ordered by their PK sequence.""" + primary_key_rows: t.List[t.Dict[str, t.Any]] = sorted( + ( + column + for column in self._get_table_info(table_name) + if column.get("pk", 0) > 0 and column.get("hidden", 0) not in (1, 2, 3) + ), + key=lambda column: column.get("pk", 0), + ) + return [safe_identifier_length(column["name"]) for column in primary_key_rows] + def _sqlite_table_has_rowid(self, table: str) -> bool: try: quoted_table: str = self._sqlite_quote_ident(table) @@ -1159,51 +1178,79 @@ def _add_foreign_keys(self, table_name: str) -> None: quoted_table_name: str = self._sqlite_quote_ident(table_name) self._sqlite_cur.execute(f'PRAGMA foreign_key_list("{quoted_table_name}")') + foreign_keys: t.Dict[int, t.List[t.Dict[str, t.Any]]] = {} for row in self._sqlite_cur.fetchall(): foreign_key: t.Dict[str, t.Any] = dict(row) + foreign_keys.setdefault(int(foreign_key["id"]), []).append(foreign_key) + + for fk_id, fk_rows in foreign_keys.items(): + fk_rows.sort(key=lambda fk_row: fk_row["seq"]) + ref_table: str = fk_rows[0]["table"] + from_columns: t.List[str] = [safe_identifier_length(fk_row["from"]) for fk_row in fk_rows] + + referenced_columns: t.List[str] + missing_references: t.List[t.Dict[str, t.Any]] = [fk_row for fk_row in fk_rows if not fk_row["to"]] + if missing_references: + if len(missing_references) != len(fk_rows): + self._logger.warning( + 'Skipping foreign key "%s" in table %s: partially defined reference columns.', + safe_identifier_length(fk_rows[0]["from"]), + safe_identifier_length(table_name), + ) + continue + + primary_keys: t.List[str] = self._get_table_primary_key_columns(ref_table) + if not primary_keys or len(primary_keys) != len(from_columns): + self._logger.warning( + 'Skipping foreign key "%s" in table %s: unable to resolve referenced primary key columns from table %s.', + safe_identifier_length(fk_rows[0]["from"]), + safe_identifier_length(table_name), + safe_identifier_length(ref_table), + ) + continue + referenced_columns = primary_keys + else: + referenced_columns = [safe_identifier_length(fk_row["to"]) for fk_row in fk_rows] + sql = """ ALTER TABLE `{table}` ADD CONSTRAINT `{table}_FK_{id}_{seq}` - FOREIGN KEY (`{column}`) - REFERENCES `{ref_table}`(`{ref_column}`) + FOREIGN KEY ({columns}) + REFERENCES `{ref_table}`({ref_columns}) ON DELETE {on_delete} ON UPDATE {on_update} """.format( - id=foreign_key["id"], - seq=foreign_key["seq"], + id=fk_id, + seq=fk_rows[0]["seq"], table=safe_identifier_length(table_name), - column=safe_identifier_length(foreign_key["from"]), - ref_table=safe_identifier_length(foreign_key["table"]), - ref_column=safe_identifier_length(foreign_key["to"]), + columns=", ".join(f"`{column}`" for column in from_columns), + ref_table=safe_identifier_length(ref_table), + ref_columns=", ".join(f"`{column}`" for column in referenced_columns), on_delete=( - foreign_key["on_delete"].upper() - if foreign_key["on_delete"].upper() != "SET DEFAULT" - else "NO ACTION" + fk_rows[0]["on_delete"].upper() if fk_rows[0]["on_delete"].upper() != "SET DEFAULT" else "NO ACTION" ), on_update=( - foreign_key["on_update"].upper() - if foreign_key["on_update"].upper() != "SET DEFAULT" - else "NO ACTION" + fk_rows[0]["on_update"].upper() if fk_rows[0]["on_update"].upper() != "SET DEFAULT" else "NO ACTION" ), ) try: self._logger.info( - "Adding foreign key to %s.%s referencing %s.%s", + "Adding foreign key to %s.(%s) referencing %s.(%s)", safe_identifier_length(table_name), - safe_identifier_length(foreign_key["from"]), - safe_identifier_length(foreign_key["table"]), - safe_identifier_length(foreign_key["to"]), + ", ".join(from_columns), + safe_identifier_length(ref_table), + ", ".join(referenced_columns), ) self._mysql_cur.execute(sql) self._mysql.commit() except mysql.connector.Error as err: self._logger.error( - "MySQL failed adding foreign key to %s.%s referencing %s.%s: %s", + "MySQL failed adding foreign key to %s.(%s) referencing %s.(%s): %s", safe_identifier_length(table_name), - safe_identifier_length(foreign_key["from"]), - safe_identifier_length(foreign_key["table"]), - safe_identifier_length(foreign_key["to"]), + ", ".join(from_columns), + safe_identifier_length(ref_table), + ", ".join(referenced_columns), err, ) raise diff --git a/tests/unit/sqlite3_to_mysql_test.py b/tests/unit/sqlite3_to_mysql_test.py index 358dd20..0d1af2e 100644 --- a/tests/unit/sqlite3_to_mysql_test.py +++ b/tests/unit/sqlite3_to_mysql_test.py @@ -1519,6 +1519,99 @@ def execute(self, statement): sqlite_cnx.close() sqlite_engine.dispose() + def test_add_foreign_keys_shorthand_references_primary_key( + self, + sqlite_database: str, + mysql_database: Engine, + mysql_credentials: MySQLCredentials, + mocker: MockFixture, + ) -> None: + proc = SQLite3toMySQL( # type: ignore[call-arg] + sqlite_file=sqlite_database, + mysql_user=mysql_credentials.user, + mysql_password=mysql_credentials.password, + mysql_host=mysql_credentials.host, + mysql_port=mysql_credentials.port, + mysql_database=mysql_credentials.database, + ) + sqlite_cursor = mocker.MagicMock() + sqlite_cursor.fetchall.side_effect = [ + [ + { + "id": 0, + "seq": 0, + "table": "parent", + "from": "parent_id", + "to": "", + "on_delete": "NO ACTION", + "on_update": "NO ACTION", + } + ], + [ + {"name": "id", "pk": 1, "hidden": 0}, + ], + ] + proc._sqlite_cur = sqlite_cursor + proc._sqlite_table_xinfo_support = False + proc._mysql_cur = mocker.MagicMock() + proc._mysql = mocker.MagicMock() + proc._logger = mocker.MagicMock() + + proc._add_foreign_keys("child") + + assert proc._mysql_cur.execute.call_count == 1 + executed_sql: str = proc._mysql_cur.execute.call_args[0][0] + assert "FOREIGN KEY (`parent_id`)" in executed_sql + assert "REFERENCES `parent`(`id`)" in executed_sql + proc._mysql.commit.assert_called_once() + + def test_add_foreign_keys_shorthand_pk_mismatch_is_skipped( + self, + sqlite_database: str, + mysql_database: Engine, + mysql_credentials: MySQLCredentials, + mocker: MockFixture, + ) -> None: + proc = SQLite3toMySQL( # type: ignore[call-arg] + sqlite_file=sqlite_database, + mysql_user=mysql_credentials.user, + mysql_password=mysql_credentials.password, + mysql_host=mysql_credentials.host, + mysql_port=mysql_credentials.port, + mysql_database=mysql_credentials.database, + ) + sqlite_cursor = mocker.MagicMock() + sqlite_cursor.fetchall.side_effect = [ + [ + { + "id": 1, + "seq": 0, + "table": "parent", + "from": "parent_id", + "to": "", + "on_delete": "NO ACTION", + "on_update": "NO ACTION", + } + ], + [ + {"name": "id", "pk": 1, "hidden": 0}, + {"name": "second", "pk": 2, "hidden": 0}, + ], + ] + proc._sqlite_cur = sqlite_cursor + proc._sqlite_table_xinfo_support = False + proc._mysql_cur = mocker.MagicMock() + proc._mysql = mocker.MagicMock() + proc._logger = mocker.MagicMock() + + proc._add_foreign_keys("child") + + proc._mysql_cur.execute.assert_not_called() + assert any( + "unable to resolve referenced primary key columns" in call.args[0] + for call in proc._logger.warning.call_args_list + ) + @pytest.mark.parametrize("quiet", [False, True]) def test_add_index_duplicate_key_error( self,