Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/sqlite3_to_mysql/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@
from .click_utils import OptionEatAll, prompt_password
from .debug_info import info
from .mysql_utils import MYSQL_INSERT_METHOD, MYSQL_TEXT_COLUMN_TYPES, mysql_supported_character_sets
from .transporter import MYSQL_TABLE_PREFIX_PATTERN


_copyright_header: str = f"sqlite3mysql version {package_version} Copyright (c) 2018-{datetime.now().year} Klemen Tusar"


def _validate_mysql_table_prefix(_: t.Any, __: t.Any, value: t.Optional[str]) -> str:
"""Validate the optional MySQL table prefix supplied via CLI."""
if not value:
return ""
if not MYSQL_TABLE_PREFIX_PATTERN.match(value):
raise click.BadParameter(
"Table prefix must start with a letter, contain only letters, numbers, or underscores, "
"and be at most 32 characters long."
)
return value


@click.command(
name="sqlite3mysql",
help=_copyright_header,
Expand Down Expand Up @@ -116,6 +129,14 @@
default="TEXT",
help="MySQL default text field type. Defaults to TEXT.",
)
@click.option(
"-b",
"--mysql-table-prefix",
default="",
callback=_validate_mysql_table_prefix,
help="MySQL table prefix must start with a letter and contain only letters, numbers, or underscores "
"with a maximum length of 32 characters.",
)
@click.option(
"--mysql-charset",
metavar="TEXT",
Expand Down Expand Up @@ -169,6 +190,7 @@ def cli(
mysql_integer_type: str,
mysql_string_type: str,
mysql_text_type: str,
mysql_table_prefix: str,
mysql_charset: str,
mysql_collation: str,
use_fulltext: bool,
Expand Down Expand Up @@ -224,6 +246,7 @@ def cli(
mysql_integer_type=mysql_integer_type,
mysql_string_type=mysql_string_type,
mysql_text_type=mysql_text_type,
mysql_table_prefix=mysql_table_prefix,
mysql_charset=mysql_charset.lower() if mysql_charset else "utf8mb4",
mysql_collation=mysql_collation.lower() if mysql_collation else None,
ignore_duplicate_keys=ignore_duplicate_keys,
Expand Down
92 changes: 56 additions & 36 deletions src/sqlite3_to_mysql/transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
key: value for key, value in sqlglot_mysql.MySQL.INVERSE_TIME_MAPPING.items() if key != "%H:%M:%S"
}
SQLGLOT_MYSQL_INVERSE_TIME_TRIE: t.Dict[str, t.Any] = new_trie(SQLGLOT_MYSQL_INVERSE_TIME_MAPPING)
MYSQL_TABLE_PREFIX_PATTERN: t.Pattern[str] = re.compile(r"^[A-Za-z][A-Za-z0-9_]{0,31}$")


class SQLite3toMySQL(SQLite3toMySQLAttributes):
Expand All @@ -91,6 +92,7 @@ class SQLite3toMySQL(SQLite3toMySQLAttributes):
re.IGNORECASE,
)
NUMERIC_LITERAL_PATTERN: t.Pattern[str] = re.compile(r"^[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?$")
TABLE_PREFIX_PATTERN: t.Pattern[str] = MYSQL_TABLE_PREFIX_PATTERN

MYSQL_CONNECTOR_VERSION: version.Version = version.parse(mysql_connector_version_string)

Expand Down Expand Up @@ -174,6 +176,14 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):
if not kwargs.get("mysql_collation") and self._mysql_collation == "utf8mb4_0900_ai_ci":
self._mysql_collation = "utf8mb4_unicode_ci"

mysql_table_prefix: str = str(kwargs.get("mysql_table_prefix", "") or "")
if mysql_table_prefix and not self.TABLE_PREFIX_PATTERN.match(mysql_table_prefix):
raise ValueError(
"MySQL table prefix must start with a letter and contain only letters, numbers, or underscores "
"with a maximum length of 32 characters."
)
self._mysql_table_prefix = mysql_table_prefix

self._ignore_duplicate_keys = kwargs.get("ignore_duplicate_keys", False) or False

self._use_fulltext = kwargs.get("use_fulltext", False) or False
Expand Down Expand Up @@ -339,6 +349,13 @@ def _sqlite_table_has_rowid(self, table: str) -> bool:
except sqlite3.OperationalError:
return False

def _mysql_table_name(self, table_name: str) -> str:
"""Return the MySQL table name with any configured prefix applied."""
prefix: str = getattr(self, "_mysql_table_prefix", "")
if prefix:
return safe_identifier_length(f"{prefix}{table_name}")
return safe_identifier_length(table_name)

def _create_database(self) -> None:
try:
self._mysql_cur.execute(
Expand Down Expand Up @@ -881,8 +898,9 @@ def _create_mysql_view(self, view_name: str, view_sql: str) -> None:

def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_default: bool = False) -> None:
primary_keys: t.List[t.Dict[str, str]] = []
mysql_table_name: str = self._mysql_table_name(table_name)

sql: str = f"CREATE TABLE IF NOT EXISTS `{safe_identifier_length(table_name)}` ( "
sql: str = f"CREATE TABLE IF NOT EXISTS `{mysql_table_name}` ( "

if transfer_rowid:
sql += " `rowid` BIGINT NOT NULL, "
Expand Down Expand Up @@ -912,23 +930,21 @@ def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_defa
column["pk"] > 0 and column_type.startswith(("INT", "BIGINT")) and not compound_primary_key
)

allow_expr_defaults: bool = getattr(self, "_allow_expr_defaults", False)
is_mariadb: bool = getattr(self, "_is_mariadb", False)
base_type: str = self._base_mysql_column_type(column_type)

# Build DEFAULT clause safely (preserve falsy defaults like 0/'')
default_clause: str = ""
if (
not skip_default
and column["dflt_value"] is not None
and self._column_type_supports_default(base_type, allow_expr_defaults)
and self._column_type_supports_default(base_type, self._allow_expr_defaults)
and not auto_increment
):
td: str = self._translate_default_for_mysql(column_type, str(column["dflt_value"]))
if td != "":
stripped_td: str = td.strip()
if base_type in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON and stripped_td.upper() != "NULL":
td = self._format_textual_default(stripped_td, allow_expr_defaults, is_mariadb)
td = self._format_textual_default(stripped_td, self._allow_expr_defaults, self._is_mariadb)
else:
td = stripped_td
default_clause = "DEFAULT " + td
Expand Down Expand Up @@ -960,7 +976,7 @@ def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_defa
)

if transfer_rowid:
sql += f", CONSTRAINT `{safe_identifier_length(table_name)}_rowid` UNIQUE (`rowid`)"
sql += f", CONSTRAINT `{mysql_table_name}_rowid` UNIQUE (`rowid`)"

sql += f" ) ENGINE=InnoDB DEFAULT CHARSET={self._mysql_charset} COLLATE={self._mysql_collation}"

Expand All @@ -971,19 +987,20 @@ def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_defa
if err.errno == errorcode.ER_INVALID_DEFAULT and not skip_default:
self._logger.warning(
"MySQL failed creating table %s with DEFAULT values: %s. Retrying without DEFAULT values ...",
safe_identifier_length(table_name),
mysql_table_name,
err,
)
return self._create_table(table_name, transfer_rowid, skip_default=True)
else:
self._logger.error(
"MySQL failed creating table %s: %s",
safe_identifier_length(table_name),
mysql_table_name,
err,
)
raise

def _truncate_table(self, table_name: str) -> None:
mysql_table_name: str = self._mysql_table_name(table_name)
self._mysql_cur.execute(
"""
SELECT `TABLE_NAME`
Expand All @@ -992,14 +1009,15 @@ def _truncate_table(self, table_name: str) -> None:
AND `TABLE_NAME` = %s
LIMIT 1
""",
(self._mysql_database, safe_identifier_length(table_name)),
(self._mysql_database, mysql_table_name),
)
if len(self._mysql_cur.fetchall()) > 0:
self._logger.info("Truncating table %s", safe_identifier_length(table_name))
self._mysql_cur.execute(f"TRUNCATE TABLE `{safe_identifier_length(table_name)}`")
self._logger.info("Truncating table %s", mysql_table_name)
self._mysql_cur.execute(f"TRUNCATE TABLE `{mysql_table_name}`")

def _add_indices(self, table_name: str) -> None:
quoted_table_name: str = self._sqlite_quote_ident(table_name)
mysql_table_name: str = self._mysql_table_name(table_name)

self._sqlite_cur.execute(f'PRAGMA table_info("{quoted_table_name}")')
table_columns: t.Dict[str, str] = {}
Expand Down Expand Up @@ -1063,7 +1081,7 @@ def _add_indices(self, table_name: str) -> None:
self._logger.warning(
"""Failed adding index to column "%s" in table %s: Column not found!""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
continue

Expand Down Expand Up @@ -1107,12 +1125,13 @@ def _add_index(
index_infos: t.Tuple[t.Dict[str, t.Any], ...],
index_iteration: int = 0,
) -> None:
mysql_table_name: str = self._mysql_table_name(table_name)
sql: str = (
"""
ALTER TABLE `{table}`
ADD {index_type} `{name}`({columns})
""".format(
table=safe_identifier_length(table_name),
table=mysql_table_name,
index_type=index_type,
name=(
safe_identifier_length(index["name"])
Expand All @@ -1128,14 +1147,13 @@ def _add_index(
"""Adding %s to column "%s" in table %s""",
"unique index" if int(index["unique"]) == 1 else "index",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
self._mysql_cur.execute(sql)
self._mysql.commit()
except mysql.connector.Error as err:
if err.errno == errorcode.ER_DUP_KEYNAME:
if not self._ignore_duplicate_keys:
# handle a duplicate key name
self._add_index(
table_name=table_name,
index_type=index_type,
Expand All @@ -1147,59 +1165,59 @@ def _add_index(
self._logger.warning(
"""Duplicate key "%s" in table %s detected! Trying to create new key "%s_%s" ...""",
safe_identifier_length(index["name"]),
safe_identifier_length(table_name),
mysql_table_name,
safe_identifier_length(index["name"]),
index_iteration + 1,
)
else:
self._logger.warning(
"""Ignoring duplicate key "%s" in table %s!""",
safe_identifier_length(index["name"]),
safe_identifier_length(table_name),
mysql_table_name,
)
elif err.errno == errorcode.ER_DUP_ENTRY:
self._logger.warning(
"""Ignoring duplicate entry when adding index to column "%s" in table %s!""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
elif err.errno == errorcode.ER_DUP_FIELDNAME:
self._logger.warning(
"""Failed adding index to column "%s" in table %s: Duplicate field name! Ignoring...""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
elif err.errno == errorcode.ER_TOO_MANY_KEYS:
self._logger.warning(
"""Failed adding index to column "%s" in table %s: Too many keys! Ignoring...""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
elif err.errno == errorcode.ER_TOO_LONG_KEY:
self._logger.warning(
"""Failed adding index to column "%s" in table %s: Key length too long! Ignoring...""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
elif err.errno == errorcode.ER_BAD_FT_COLUMN:
# handle bad FULLTEXT index
self._logger.warning(
"""Failed adding FULLTEXT index to column "%s" in table %s. Retrying without FULLTEXT ...""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
)
raise
else:
self._logger.error(
"""MySQL failed adding index to column "%s" in table %s: %s""",
", ".join(safe_identifier_length(index_info["name"]) for index_info in index_infos),
safe_identifier_length(table_name),
mysql_table_name,
err,
)
raise

def _add_foreign_keys(self, table_name: str) -> None:
quoted_table_name: str = self._sqlite_quote_ident(table_name)
mysql_table_name: str = self._mysql_table_name(table_name)
self._sqlite_cur.execute(f'PRAGMA foreign_key_list("{quoted_table_name}")')

foreign_keys: t.Dict[int, t.List[t.Dict[str, t.Any]]] = {}
Expand All @@ -1219,7 +1237,7 @@ def _add_foreign_keys(self, table_name: str) -> None:
self._logger.warning(
'Skipping foreign key "%s" in table %s: partially defined reference columns.',
safe_identifier_length(fk_rows[0]["from"]),
safe_identifier_length(table_name),
mysql_table_name,
)
continue

Expand All @@ -1228,14 +1246,15 @@ def _add_foreign_keys(self, table_name: str) -> None:
self._logger.warning(
'Skipping foreign key "%s" in table %s: unable to resolve referenced primary key columns from table %s.',
safe_identifier_length(fk_rows[0]["from"]),
safe_identifier_length(table_name),
safe_identifier_length(ref_table),
mysql_table_name,
self._mysql_table_name(ref_table),
)
continue
referenced_columns = primary_keys
else:
referenced_columns = [safe_identifier_length(fk_row["to"]) for fk_row in fk_rows]

mysql_ref_table_name: str = self._mysql_table_name(ref_table)
sql = """
ALTER TABLE `{table}`
ADD CONSTRAINT `{table}_FK_{id}_{seq}`
Expand All @@ -1246,9 +1265,9 @@ def _add_foreign_keys(self, table_name: str) -> None:
""".format(
id=fk_id,
seq=fk_rows[0]["seq"],
table=safe_identifier_length(table_name),
table=mysql_table_name,
columns=", ".join(f"`{column}`" for column in from_columns),
ref_table=safe_identifier_length(ref_table),
ref_table=mysql_ref_table_name,
ref_columns=", ".join(f"`{column}`" for column in referenced_columns),
on_delete=(
fk_rows[0]["on_delete"].upper() if fk_rows[0]["on_delete"].upper() != "SET DEFAULT" else "NO ACTION"
Expand All @@ -1261,19 +1280,19 @@ def _add_foreign_keys(self, table_name: str) -> None:
try:
self._logger.info(
"Adding foreign key to %s.(%s) referencing %s.(%s)",
safe_identifier_length(table_name),
mysql_table_name,
", ".join(from_columns),
safe_identifier_length(ref_table),
mysql_ref_table_name,
", ".join(referenced_columns),
)
self._mysql_cur.execute(sql)
self._mysql.commit()
except mysql.connector.Error as err:
self._logger.error(
"MySQL failed adding foreign key to %s.(%s) referencing %s.(%s): %s",
safe_identifier_length(table_name),
mysql_table_name,
", ".join(from_columns),
safe_identifier_length(ref_table),
mysql_ref_table_name,
", ".join(referenced_columns),
err,
)
Expand Down Expand Up @@ -1320,6 +1339,7 @@ def transfer(self) -> None:
table_name: str = table["name"]
object_type: str = table.get("type", "table")
quoted_table_name: str = self._sqlite_quote_ident(table_name)
mysql_table_name: str = self._mysql_table_name(table_name)

# check if we're transferring rowid
transfer_rowid: bool = self._with_rowid and self._sqlite_table_has_rowid(table_name)
Expand Down Expand Up @@ -1391,7 +1411,7 @@ def transfer(self) -> None:
{values_clause}
ON DUPLICATE KEY UPDATE {field_updates}
""".format(
table=safe_identifier_length(table_name),
table=mysql_table_name,
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
values_clause=(
"VALUES ({placeholders}) AS `__new__`"
Expand All @@ -1411,7 +1431,7 @@ def transfer(self) -> None:
VALUES ({placeholders})
""".format(
ignore="IGNORE" if self._mysql_insert_method.upper() == "IGNORE" else "",
table=safe_identifier_length(table_name),
table=mysql_table_name,
fields=("`{}`, " * len(columns)).rstrip(" ,").format(*columns),
placeholders=("%s, " * len(columns)).rstrip(" ,"),
)
Expand All @@ -1421,7 +1441,7 @@ def transfer(self) -> None:
self._logger.error(
"MySQL transfer failed inserting data into %s %s: %s",
"view" if object_type == "view" else "table",
safe_identifier_length(table_name),
mysql_table_name,
err,
)
raise
Expand Down
Loading
Loading