From c472fee09dcea465a23edfc415d317045fd58424 Mon Sep 17 00:00:00 2001 From: Klemen Tusar Date: Sun, 16 Nov 2025 19:20:20 +0000 Subject: [PATCH] :recycle: refactor data transfer and schema writing for MySQL to SQLite migration --- src/mysql_to_sqlite3/data_transfer.py | 90 +++ src/mysql_to_sqlite3/schema_writer.py | 233 +++++++ src/mysql_to_sqlite3/transporter.py | 764 +++-------------------- src/mysql_to_sqlite3/type_translation.py | 437 +++++++++++++ src/mysql_to_sqlite3/types.py | 7 + tox.ini | 2 +- 6 files changed, 838 insertions(+), 695 deletions(-) create mode 100644 src/mysql_to_sqlite3/data_transfer.py create mode 100644 src/mysql_to_sqlite3/schema_writer.py create mode 100644 src/mysql_to_sqlite3/type_translation.py diff --git a/src/mysql_to_sqlite3/data_transfer.py b/src/mysql_to_sqlite3/data_transfer.py new file mode 100644 index 0000000..9c19c47 --- /dev/null +++ b/src/mysql_to_sqlite3/data_transfer.py @@ -0,0 +1,90 @@ +"""Chunked data transfer helpers.""" + +from __future__ import annotations + +import sqlite3 +import typing as t +from math import ceil + +import mysql.connector +from mysql.connector import errorcode +from tqdm import tqdm, trange + +from mysql_to_sqlite3.sqlite_utils import encode_data_for_sqlite + + +# pylint: disable=protected-access # Access transporter internals for efficiency + + +if t.TYPE_CHECKING: + from mysql_to_sqlite3.transporter import MySQLtoSQLite + + +class DataTransferManager: + """Handles moving table data from MySQL to SQLite.""" + + def __init__(self, ctx: "MySQLtoSQLite") -> None: + """Store transporter context for DB access, logging, and chunk options.""" + self._ctx = ctx + + def transfer_table_data( + self, table_name: str, sql: str, total_records: int = 0, attempting_reconnect: bool = False + ) -> None: + """Stream rows from MySQL and batch insert into SQLite, handling reconnects.""" + ctx = self._ctx + if attempting_reconnect: + ctx._mysql.reconnect() + try: + if ctx._chunk_size is not None and ctx._chunk_size > 0: + for chunk in trange( + ctx._current_chunk_number, + int(ceil(total_records / ctx._chunk_size)), + disable=ctx._quiet, + ): + ctx._current_chunk_number = chunk + ctx._sqlite_cur.executemany( + sql, + ( + tuple(encode_data_for_sqlite(col) if col is not None else None for col in row) + for row in ctx._mysql_cur.fetchmany(ctx._chunk_size) + ), + ) + else: + ctx._sqlite_cur.executemany( + sql, + ( + tuple(encode_data_for_sqlite(col) if col is not None else None for col in row) + for row in tqdm( + ctx._mysql_cur.fetchall(), + total=total_records, + disable=ctx._quiet, + ) + ), + ) + ctx._sqlite.commit() + except mysql.connector.Error as err: + if err.errno == errorcode.CR_SERVER_LOST: + if not attempting_reconnect: + ctx._logger.warning("Connection to MySQL server lost.\nAttempting to reconnect.") + self.transfer_table_data( + table_name=table_name, + sql=sql, + total_records=total_records, + attempting_reconnect=True, + ) + return + ctx._logger.warning("Connection to MySQL server lost.\nReconnection attempt aborted.") + raise + ctx._logger.error( + "MySQL transfer failed reading table data from table %s: %s", + table_name, + err, + ) + raise + except sqlite3.Error as err: + ctx._logger.error( + "SQLite transfer failed inserting data into table %s: %s", + table_name, + err, + ) + raise diff --git a/src/mysql_to_sqlite3/schema_writer.py b/src/mysql_to_sqlite3/schema_writer.py new file mode 100644 index 0000000..322f287 --- /dev/null +++ b/src/mysql_to_sqlite3/schema_writer.py @@ -0,0 +1,233 @@ +"""Schema creation helpers for the 
transporter.""" + +from __future__ import annotations + +import typing as t + +from mysql.connector.types import RowItemType + +from mysql_to_sqlite3.sqlite_utils import Integer_Types + + +# pylint: disable=protected-access # Helper intentionally uses transporter internals + + +if t.TYPE_CHECKING: + from mysql_to_sqlite3.transporter import MySQLtoSQLite + + +class SchemaWriter: + """Builds SQLite schemas (tables and views) from MySQL metadata.""" + + def __init__(self, ctx: "MySQLtoSQLite") -> None: + """Hold a reference to the transporter orchestrator.""" + self._ctx = ctx + + def _build_create_table_sql(self, table_name: str) -> str: + ctx = self._ctx + table_ident = ctx._quote_sqlite_identifier(table_name) + sql: str = f"CREATE TABLE IF NOT EXISTS {table_ident} (" + primary: str = "" + indices: str = "" + + safe_table = ctx._escape_mysql_backticks(table_name) + ctx._mysql_cur_dict.execute(f"SHOW COLUMNS FROM `{safe_table}`") + rows: t.Sequence[t.Optional[t.Dict[str, RowItemType]]] = ctx._mysql_cur_dict.fetchall() + + primary_keys: int = sum(1 for row in rows if row is not None and row["Key"] == "PRI") + + for row in rows: + if row is None: + continue + column_type = ctx._translate_type_from_mysql_to_sqlite( + column_type=row["Type"], # type: ignore[arg-type] + sqlite_json1_extension_enabled=ctx._sqlite_json1_extension_enabled, + ) + if row["Key"] == "PRI" and row["Extra"] == "auto_increment" and primary_keys == 1: + if column_type in Integer_Types: + sql += "\n\t{name} INTEGER PRIMARY KEY AUTOINCREMENT,".format( + name=ctx._quote_sqlite_identifier( + str(row["Field"].decode() if isinstance(row["Field"], (bytes, bytearray)) else row["Field"]) + ), + ) + else: + ctx._logger.warning( + 'Primary key "%s" in table "%s" is not an INTEGER type! Skipping.', + row["Field"], + table_name, + ) + else: + sql += "\n\t{name} {type} {notnull} {default} {collation},".format( + name=ctx._quote_sqlite_identifier( + str(row["Field"].decode() if isinstance(row["Field"], (bytes, bytearray)) else row["Field"]) + ), + type=column_type, + notnull="NULL" if row["Null"] == "YES" else "NOT NULL", + default=ctx._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]), + collation=ctx._data_type_collation_sequence(ctx._collation, column_type), + ) + + ctx._mysql_cur_dict.execute( + """ + SELECT s.INDEX_NAME AS `name`, + IF (NON_UNIQUE = 0 AND s.INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`, + IF (NON_UNIQUE = 0 AND s.INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`, + {auto_increment} + GROUP_CONCAT(s.COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns`, + GROUP_CONCAT(c.COLUMN_TYPE ORDER BY SEQ_IN_INDEX) AS `types` + FROM information_schema.STATISTICS AS s + JOIN information_schema.COLUMNS AS c + ON s.TABLE_SCHEMA = c.TABLE_SCHEMA + AND s.TABLE_NAME = c.TABLE_NAME + AND s.COLUMN_NAME = c.COLUMN_NAME + WHERE s.TABLE_SCHEMA = %s + AND s.TABLE_NAME = %s + GROUP BY s.INDEX_NAME, s.NON_UNIQUE {group_by_extra} + """.format( + auto_increment=( + "IF (c.EXTRA = 'auto_increment', 1, 0) AS `auto_increment`," + if primary_keys == 1 + else "0 as `auto_increment`," + ), + group_by_extra=" ,c.EXTRA" if primary_keys == 1 else "", + ), + (ctx._mysql_database, table_name), + ) + mysql_indices: t.Sequence[t.Optional[t.Dict[str, RowItemType]]] = ctx._mysql_cur_dict.fetchall() + for index in mysql_indices: + if index is None: + continue + if isinstance(index["name"], bytes): + index_name = index["name"].decode() + elif isinstance(index["name"], str): + index_name = index["name"] + else: + index_name = str(index["name"]) + 
+ ctx._mysql_cur_dict.execute( + """ + SELECT COUNT(*) AS `count` + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = %s + AND TABLE_NAME = %s + """, + (ctx._mysql_database, index_name), + ) + collision: t.Optional[t.Dict[str, RowItemType]] = ctx._mysql_cur_dict.fetchone() + table_collisions: int = 0 + if collision is not None: + table_collisions = int(collision["count"]) # type: ignore[arg-type] + + columns: str = ( + index["columns"].decode() if isinstance(index["columns"], (bytes, bytearray)) else str(index["columns"]) + ) + + types: str = "" + if isinstance(index["types"], bytes): + types = index["types"].decode() + elif isinstance(index["types"], str): + types = index["types"] + + if len(columns) > 0: + if index["primary"] in {1, "1"}: + if (index["auto_increment"] not in {1, "1"}) or any( + ctx._translate_type_from_mysql_to_sqlite( + column_type=_type, + sqlite_json1_extension_enabled=ctx._sqlite_json1_extension_enabled, + ) + not in Integer_Types + for _type in types.split(",") + ): + primary += "\n\tPRIMARY KEY ({columns})".format( + columns=", ".join( + ctx._quote_sqlite_identifier(column.strip()) for column in columns.split(",") + ) + ) + else: + proposed_index_name = ( + f"{table_name}_{index_name}" if (table_collisions > 0 or ctx._prefix_indices) else index_name + ) + if not ctx._prefix_indices: + unique_index_name = ctx._get_unique_index_name(proposed_index_name) + else: + unique_index_name = proposed_index_name + unique_kw = "UNIQUE " if index["unique"] in {1, "1"} else "" + indices += """CREATE {unique}INDEX IF NOT EXISTS {name} ON {table} ({columns});""".format( + unique=unique_kw, + name=ctx._quote_sqlite_identifier(unique_index_name), + table=ctx._quote_sqlite_identifier(table_name), + columns=", ".join( + ctx._quote_sqlite_identifier(column.strip()) for column in columns.split(",") + ), + ) + + sql += primary + sql = sql.rstrip(", ") + + if not ctx._without_tables and not ctx._without_foreign_keys: + server_version: t.Optional[t.Tuple[int, ...]] = ctx._mysql.get_server_version() + ctx._mysql_cur_dict.execute( + """ + SELECT k.COLUMN_NAME AS `column`, + k.REFERENCED_TABLE_NAME AS `ref_table`, + k.REFERENCED_COLUMN_NAME AS `ref_column`, + c.UPDATE_RULE AS `on_update`, + c.DELETE_RULE AS `on_delete` + FROM information_schema.TABLE_CONSTRAINTS AS i + {JOIN} information_schema.KEY_COLUMN_USAGE AS k + ON i.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND i.TABLE_NAME = k.TABLE_NAME + {JOIN} information_schema.REFERENTIAL_CONSTRAINTS AS c + ON c.CONSTRAINT_NAME = i.CONSTRAINT_NAME + AND c.TABLE_NAME = i.TABLE_NAME + WHERE i.TABLE_SCHEMA = %s + AND i.TABLE_NAME = %s + AND i.CONSTRAINT_TYPE = %s + GROUP BY i.CONSTRAINT_NAME, + k.COLUMN_NAME, + k.REFERENCED_TABLE_NAME, + k.REFERENCED_COLUMN_NAME, + c.UPDATE_RULE, + c.DELETE_RULE + """.format( + JOIN=( + "JOIN" + if (server_version is not None and server_version[0] == 8 and server_version[2] > 19) + else "LEFT JOIN" + ) + ), + (ctx._mysql_database, table_name, "FOREIGN KEY"), + ) + for foreign_key in ctx._mysql_cur_dict.fetchall(): + if foreign_key is None: + continue + col = ctx._quote_sqlite_identifier( + foreign_key["column"].decode() + if isinstance(foreign_key["column"], (bytes, bytearray)) + else str(foreign_key["column"]) # type: ignore[index] + ) + ref_table = ctx._quote_sqlite_identifier( + foreign_key["ref_table"].decode() + if isinstance(foreign_key["ref_table"], (bytes, bytearray)) + else str(foreign_key["ref_table"]) # type: ignore[index] + ) + ref_col = ctx._quote_sqlite_identifier( + foreign_key["ref_column"].decode() + 
if isinstance(foreign_key["ref_column"], (bytes, bytearray)) + else str(foreign_key["ref_column"]) # type: ignore[index] + ) + on_update = str(foreign_key["on_update"] or "NO ACTION").upper() # type: ignore[index] + on_delete = str(foreign_key["on_delete"] or "NO ACTION").upper() # type: ignore[index] + sql += ( + f",\n\tFOREIGN KEY({col}) REFERENCES {ref_table} ({ref_col}) " + f"ON UPDATE {on_update} " + f"ON DELETE {on_delete}" + ) + + sql += "\n)" + if ctx._sqlite_strict: + sql += " STRICT" + sql += ";\n" + sql += indices + + return sql diff --git a/src/mysql_to_sqlite3/transporter.py b/src/mysql_to_sqlite3/transporter.py index fb5400e..3e55865 100644 --- a/src/mysql_to_sqlite3/transporter.py +++ b/src/mysql_to_sqlite3/transporter.py @@ -1,5 +1,7 @@ """Use to transfer a MySQL database to SQLite.""" +from __future__ import annotations + import logging import os import re @@ -7,7 +9,6 @@ import typing as t from datetime import timedelta from decimal import Decimal -from math import ceil from os.path import realpath from sys import stdout @@ -17,7 +18,9 @@ from mysql.connector.types import RowItemType from sqlglot import Expression, exp, parse_one from sqlglot.errors import ParseError -from tqdm import tqdm, trange + + +# pylint: disable=protected-access # Legacy helpers intentionally exposed for tests try: @@ -27,16 +30,16 @@ # Python < 3.11 from typing_extensions import Unpack # type: ignore -from mysql_to_sqlite3.mysql_utils import CHARSET_INTRODUCERS +from mysql_to_sqlite3 import type_translation as _type_helpers +from mysql_to_sqlite3.data_transfer import DataTransferManager +from mysql_to_sqlite3.schema_writer import SchemaWriter from mysql_to_sqlite3.sqlite_utils import ( CollatingSequences, - Integer_Types, adapt_decimal, adapt_timedelta, convert_date, convert_decimal, convert_timedelta, - encode_data_for_sqlite, ) from mysql_to_sqlite3.types import MySQLtoSQLiteAttributes, MySQLtoSQLiteParams @@ -44,8 +47,7 @@ class MySQLtoSQLite(MySQLtoSQLiteAttributes): """Use this class to transfer a MySQL database to SQLite.""" - COLUMN_PATTERN: t.Pattern[str] = re.compile(r"^[^(]+") - COLUMN_LENGTH_PATTERN: t.Pattern[str] = re.compile(r"\(\d+\)$") + escape_mysql_backticks = staticmethod(_type_helpers.escape_mysql_backticks) def __init__(self, **kwargs: Unpack[MySQLtoSQLiteParams]) -> None: """Constructor.""" @@ -190,6 +192,9 @@ def __init__(self, **kwargs: Unpack[MySQLtoSQLiteParams]) -> None: self._logger.error(err) raise + self._schema_writer = SchemaWriter(self) + self._data_transfer = DataTransferManager(self) + @classmethod def _setup_logger( cls, log_file: t.Optional[t.Union[str, "os.PathLike[t.Any]"]] = None, quiet: bool = False @@ -214,243 +219,54 @@ def _setup_logger( @classmethod def _valid_column_type(cls, column_type: str) -> t.Optional[t.Match[str]]: - return cls.COLUMN_PATTERN.match(column_type.strip()) + return _type_helpers._valid_column_type(column_type) @classmethod def _column_type_length(cls, column_type: str) -> str: - suffix: t.Optional[t.Match[str]] = cls.COLUMN_LENGTH_PATTERN.search(column_type) - if suffix: - return suffix.group(0) - return "" + return _type_helpers._column_type_length(column_type) @staticmethod def _decode_column_type(column_type: t.Union[str, bytes]) -> str: - if isinstance(column_type, str): - return column_type - if isinstance(column_type, bytes): - try: - return column_type.decode() - except (UnicodeDecodeError, AttributeError): - pass - return str(column_type) + return _type_helpers._decode_column_type(column_type) @classmethod def 
_translate_type_from_mysql_to_sqlite( - cls, column_type: t.Union[str, bytes], sqlite_json1_extension_enabled=False + cls, column_type: t.Union[str, bytes], sqlite_json1_extension_enabled: bool = False ) -> str: - _column_type: str = cls._decode_column_type(column_type) - - # This could be optimized even further, however is seems adequate. - match: t.Optional[t.Match[str]] = cls._valid_column_type(_column_type) - if not match: - raise ValueError(f'"{_column_type}" is not a valid column_type!') - - data_type: str = match.group(0).upper() - - if data_type.endswith(" UNSIGNED"): - data_type = data_type.replace(" UNSIGNED", "") - - if data_type in { - "BIGINT", - "BLOB", - "BOOLEAN", - "DATE", - "DATETIME", - "DECIMAL", - "DOUBLE", - "FLOAT", - "INTEGER", - "MEDIUMINT", - "NUMERIC", - "REAL", - "SMALLINT", - "TIME", - "TINYINT", - "YEAR", - }: - return data_type - if data_type == "DOUBLE PRECISION": - return "DOUBLE" - if data_type == "FIXED": - return "DECIMAL" - if data_type in {"CHARACTER VARYING", "CHAR VARYING"}: - return "VARCHAR" + cls._column_type_length(_column_type) - if data_type in {"NATIONAL CHARACTER VARYING", "NATIONAL CHAR VARYING", "NATIONAL VARCHAR"}: - return "NVARCHAR" + cls._column_type_length(_column_type) - if data_type == "NATIONAL CHARACTER": - return "NCHAR" + cls._column_type_length(_column_type) - if data_type in { - "BIT", - "BINARY", - "LONGBLOB", - "MEDIUMBLOB", - "TINYBLOB", - "VARBINARY", - }: - return "BLOB" - if data_type in {"NCHAR", "NVARCHAR", "VARCHAR"}: - return data_type + cls._column_type_length(_column_type) - if data_type == "CHAR": - return "CHARACTER" + cls._column_type_length(_column_type) - if data_type == "INT": - return "INTEGER" - if data_type in "TIMESTAMP": - return "DATETIME" - if data_type == "JSON" and sqlite_json1_extension_enabled: - return "JSON" - # As a last resort, try sqlglot to derive a better SQLite-compatible type - sqlglot_type: t.Optional[str] = cls._transpile_mysql_type_to_sqlite( - _column_type, sqlite_json1_extension_enabled=sqlite_json1_extension_enabled + return _type_helpers.translate_type_from_mysql_to_sqlite( + column_type=column_type, + sqlite_json1_extension_enabled=sqlite_json1_extension_enabled, + decode_column_type=cls._decode_column_type, + valid_column_type=cls._valid_column_type, + transpile_mysql_type_to_sqlite=cls._transpile_mysql_type_to_sqlite, ) - if sqlglot_type: - return sqlglot_type - return "TEXT" @classmethod def _transpile_mysql_expr_to_sqlite(cls, expr_sql: str) -> t.Optional[str]: - """Transpile a MySQL scalar expression to SQLite using sqlglot. - - Returns the SQLite SQL string on success, or None on failure. 
- """ - cleaned: str = expr_sql.strip().rstrip(";") - try: - tree: Expression = parse_one(cleaned, read="mysql") - return tree.sql(dialect="sqlite") - except (ParseError, ValueError): - return None - except (AttributeError, TypeError): # pragma: no cover - unexpected sqlglot failure - logging.getLogger(cls.__name__ if hasattr(cls, "__name__") else "MySQLtoSQLite").debug( - "sqlglot failed to transpile expr: %r", expr_sql - ) - return None + return _type_helpers._transpile_mysql_expr_to_sqlite(expr_sql, parse_func=parse_one) @classmethod def _normalize_literal_with_sqlglot(cls, expr_sql: str) -> t.Optional[str]: - """Normalize a MySQL literal using sqlglot, returning SQLite SQL if literal-like.""" - cleaned: str = expr_sql.strip().rstrip(";") - try: - node: Expression = parse_one(cleaned, read="mysql") - except (ParseError, ValueError): - return None - if isinstance(node, exp.Literal): - return node.sql(dialect="sqlite") - if isinstance(node, exp.Paren) and isinstance(node.this, exp.Literal): - return node.this.sql(dialect="sqlite") - return None + return _type_helpers._normalize_literal_with_sqlglot(expr_sql) @staticmethod def _quote_sqlite_identifier(name: t.Union[str, bytes, bytearray]) -> str: - """Safely quote an identifier for SQLite using sqlglot. - - Always returns a double-quoted identifier, doubling any embedded quotes. - Accepts bytes and decodes them to str when needed. - """ - if isinstance(name, (bytes, bytearray)): - try: - s: str = name.decode() - except (UnicodeDecodeError, AttributeError): - s = str(name) - else: - s = str(name) - try: - # Normalize identifier using sqlglot, then wrap in quotes regardless - normalized: str = exp.to_identifier(s).name - except (AttributeError, ValueError, TypeError): # pragma: no cover - extremely unlikely - normalized = s - replaced: str = normalized.replace('"', '""') - return f'"{replaced}"' + return _type_helpers.quote_sqlite_identifier(name) @staticmethod def _escape_mysql_backticks(identifier: str) -> str: - """Escape backticks in a MySQL identifier for safe backtick quoting.""" - return identifier.replace("`", "``") + return _type_helpers.escape_mysql_backticks(identifier) @classmethod def _transpile_mysql_type_to_sqlite( cls, column_type: str, sqlite_json1_extension_enabled: bool = False ) -> t.Optional[str]: - """Attempt to derive a suitable SQLite column type using sqlglot. - - This is used as a last-resort fallback when the built-in mapper does not - recognize a MySQL type or synonym. It keeps existing behavior for known - types and only augments unknowns. - """ - # Wrap the type in a CAST expression so sqlglot can parse it consistently. - expr_sql: str = f"CAST(NULL AS {column_type.strip()})" - try: - tree: Expression = parse_one(expr_sql, read="mysql") - rendered: str = tree.sql(dialect="sqlite") - except (ParseError, ValueError, AttributeError, TypeError): - return None - - # Extract the type inside CAST(NULL AS ...) 
- m: t.Optional[t.Match[str]] = re.search(r"CAST\(NULL AS\s+([^)]+)\)", rendered, re.IGNORECASE) - if not m: - return None - extracted: str = m.group(1).strip() - upper: str = extracted.upper() - - # JSON handling: respect availability of JSON1 extension - if "JSON" in upper: - return "JSON" if sqlite_json1_extension_enabled else "TEXT" - - # Split out optional length suffix like (255) or (10,2) - base: str = upper - length_suffix: str = "" - paren: t.Optional[t.Match[str]] = re.match(r"^([A-Z ]+)(\(.*\))$", upper) - if paren: - base = paren.group(1).strip() - length_suffix = paren.group(2) - - # Minimal synonym normalization - synonyms: t.Dict[str, str] = { - "DOUBLE PRECISION": "DOUBLE", - "FIXED": "DECIMAL", - "CHAR VARYING": "VARCHAR", - "CHARACTER VARYING": "VARCHAR", - "NATIONAL VARCHAR": "NVARCHAR", - "NATIONAL CHARACTER VARYING": "NVARCHAR", - "NATIONAL CHAR VARYING": "NVARCHAR", - "NATIONAL CHARACTER": "NCHAR", - } - base = synonyms.get(base, base) - - # Decide the final SQLite type, aligning with existing conventions - if base in {"NCHAR", "NVARCHAR", "VARCHAR"} and length_suffix: - return f"{base}{length_suffix}" - if base in {"CHAR", "CHARACTER"}: - return f"CHARACTER{length_suffix}" - if base in {"DECIMAL", "NUMERIC"}: - # Keep without length to match the existing mapper's style - return base - if base in { - "DOUBLE", - "REAL", - "FLOAT", - "INTEGER", - "BIGINT", - "SMALLINT", - "MEDIUMINT", - "TINYINT", - "BLOB", - "DATE", - "DATETIME", - "TIME", - "YEAR", - "BOOLEAN", - }: - return base - if base in {"VARBINARY", "BINARY", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB"}: - return "BLOB" - if base in {"TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "CLOB"}: - return "TEXT" - - # ENUM/SET and other complex types -> keep default behavior (TEXT) - if base in {"ENUM", "SET"}: - return "TEXT" - - # If we reached here, we didn't find a better mapping - return None + return _type_helpers._transpile_mysql_type_to_sqlite( + column_type, + sqlite_json1_extension_enabled=sqlite_json1_extension_enabled, + parse_func=parse_one, + regex_search=re.search, + ) @classmethod def _translate_default_from_mysql_to_sqlite( @@ -459,187 +275,23 @@ def _translate_default_from_mysql_to_sqlite( column_type: t.Optional[str] = None, column_extra: RowItemType = None, ) -> str: - is_binary: bool - is_hex: bool - if isinstance(column_default, bytes): - if column_type in { - "BIT", - "BINARY", - "BLOB", - "LONGBLOB", - "MEDIUMBLOB", - "TINYBLOB", - "VARBINARY", - }: - if column_extra in {"DEFAULT_GENERATED", "default_generated"}: - for charset_introducer in CHARSET_INTRODUCERS: - if column_default.startswith(charset_introducer.encode()): - is_binary = False - is_hex = False - for b_prefix in ("B", "b"): - if column_default.startswith(rf"{charset_introducer} {b_prefix}\'".encode()): - is_binary = True - break - for x_prefix in ("X", "x"): - if column_default.startswith(rf"{charset_introducer} {x_prefix}\'".encode()): - is_hex = True - break - column_default = ( - column_default.replace(charset_introducer.encode(), b"") - .replace(rb"x\'", b"") - .replace(rb"X\'", b"") - .replace(rb"b\'", b"") - .replace(rb"B\'", b"") - .replace(rb"\'", b"") - .replace(rb"'", b"") - .strip() - ) - if is_binary: - return f"DEFAULT '{chr(int(column_default, 2))}'" - if is_hex: - return f"DEFAULT x'{column_default.decode()}'" - break - return f"DEFAULT x'{column_default.hex()}'" - try: - column_default = column_default.decode() - except (UnicodeDecodeError, AttributeError): - pass - if column_default is None: - return "" - if 
isinstance(column_default, bool): - if column_type == "BOOLEAN" and sqlite3.sqlite_version >= "3.23.0": - if column_default: - return "DEFAULT(TRUE)" - return "DEFAULT(FALSE)" - return f"DEFAULT '{int(column_default)}'" - if isinstance(column_default, str): - if column_default.lower() == "curtime()": - return "DEFAULT CURRENT_TIME" - if column_default.lower() == "curdate()": - return "DEFAULT CURRENT_DATE" - if column_default.lower() in {"current_timestamp()", "now()"}: - return "DEFAULT CURRENT_TIMESTAMP" - if column_extra in {"DEFAULT_GENERATED", "default_generated"}: - if column_default.upper() in { - "CURRENT_TIME", - "CURRENT_DATE", - "CURRENT_TIMESTAMP", - }: - return f"DEFAULT {column_default.upper()}" - for charset_introducer in CHARSET_INTRODUCERS: - if column_default.startswith(charset_introducer): - is_binary = False - is_hex = False - for b_prefix in ("B", "b"): - if column_default.startswith(rf"{charset_introducer} {b_prefix}\'"): - is_binary = True - break - for x_prefix in ("X", "x"): - if column_default.startswith(rf"{charset_introducer} {x_prefix}\'"): - is_hex = True - break - column_default = ( - column_default.replace(charset_introducer, "") - .replace(r"x\'", "") - .replace(r"X\'", "") - .replace(r"b\'", "") - .replace(r"B\'", "") - .replace(r"\'", "") - .replace(r"'", "") - .strip() - ) - if is_binary: - return f"DEFAULT '{chr(int(column_default, 2))}'" - if is_hex: - return f"DEFAULT x'{column_default}'" - return f"DEFAULT '{column_default}'" - transpiled: t.Optional[str] = cls._transpile_mysql_expr_to_sqlite(column_default) - if transpiled: - norm: str = transpiled.strip().rstrip(";") - upper: str = norm.upper() - if upper in {"CURRENT_TIME", "CURRENT_DATE", "CURRENT_TIMESTAMP"}: - return f"DEFAULT {upper}" - if upper == "NULL": - return "DEFAULT NULL" - # Allow blob hex literal X'..' 
- if re.match(r"^[Xx]'[0-9A-Fa-f]+'$", norm): - return f"DEFAULT {norm}" - # Support boolean tokens when provided as generated strings - if upper in {"TRUE", "FALSE"}: - if column_type == "BOOLEAN" and sqlite3.sqlite_version >= "3.23.0": - return f"DEFAULT({upper})" - return f"DEFAULT '{1 if upper == 'TRUE' else 0}'" - # Unwrap a single layer of parenthesis around a literal - if norm.startswith("(") and norm.endswith(")"): - inner = norm[1:-1].strip() - if (inner.startswith("'") and inner.endswith("'")) or re.match(r"^-?\d+(?:\.\d+)?$", inner): - return f"DEFAULT {inner}" - # If the expression is arithmetic-only over numeric literals, allow as-is - if re.match(r"^[\d\.\s\+\-\*/\(\)]+$", norm) and any(ch.isdigit() for ch in norm): - return f"DEFAULT {norm}" - # Allow numeric or single-quoted string literals as-is - if (norm.startswith("'") and norm.endswith("'")) or re.match(r"^-?\d+(?:\.\d+)?$", norm): - return f"DEFAULT {norm}" - # Allow simple arithmetic constant expressions composed of numbers and + - * / - if re.match(r"^[\d\.\s\+\-\*/\(\)]+$", norm) and any(ch.isdigit() for ch in norm): - return f"DEFAULT {norm}" - stripped_default = column_default.strip() - if stripped_default.startswith("'") or ( - stripped_default.startswith("(") and stripped_default.endswith(")") - ): - normalized_literal: t.Optional[str] = cls._normalize_literal_with_sqlglot(column_default) - if normalized_literal is not None: - return f"DEFAULT {normalized_literal}" - - # Fallback: robustly escape single quotes for plain string defaults - _escaped = column_default.replace("\\'", "'") - _escaped = _escaped.replace("'", "''") - return f"DEFAULT '{_escaped}'" - s = str(column_default) - s = s.replace("\\'", "'") - s = s.replace("'", "''") - return f"DEFAULT '{s}'" + return _type_helpers.translate_default_from_mysql_to_sqlite( + column_default=column_default, + column_type=column_type, + column_extra=column_extra, + normalize_literal=cls._normalize_literal_with_sqlglot, + transpile_expr=cls._transpile_mysql_expr_to_sqlite, + ) @classmethod def _data_type_collation_sequence( cls, collation: str = CollatingSequences.BINARY, column_type: t.Optional[str] = None ) -> str: - """Return a SQLite COLLATE clause for textual affinity types. - - Augmented with sqlglot: if the provided type string does not match the - quick textual prefixes, we attempt to transpile it to a SQLite type and - then apply SQLite's textual affinity rules (contains CHAR/CLOB/TEXT or - their NV*/VAR* variants). This improves handling of MySQL synonyms like - CHAR VARYING / CHARACTER VARYING / NATIONAL CHARACTER VARYING. 
- """ - if not column_type or collation == CollatingSequences.BINARY: - return "" - - ct: str = column_type.strip() - upper: str = ct.upper() - - # Fast-path for already normalized SQLite textual types - if upper.startswith(("CHARACTER", "NCHAR", "NVARCHAR", "TEXT", "VARCHAR")): - return f"COLLATE {collation}" - - # Avoid collations for JSON/BLOB explicitly - if "JSON" in upper or "BLOB" in upper: - return "" - - # If the type string obviously denotes text affinity, apply collation - if any(tok in upper for tok in ("VARCHAR", "NVARCHAR", "NCHAR", "CHAR", "TEXT", "CLOB", "CHARACTER")): - return f"COLLATE {collation}" - - # Try to map uncommon/synonym types to a SQLite type using sqlglot-based transpiler - mapped: t.Optional[str] = cls._transpile_mysql_type_to_sqlite(ct) - if mapped: - mu = mapped.upper() - if ( - "CHAR" in mu or "VARCHAR" in mu or "NCHAR" in mu or "NVARCHAR" in mu or "TEXT" in mu or "CLOB" in mu - ) and not ("JSON" in mu or "BLOB" in mu): - return f"COLLATE {collation}" - - return "" + return _type_helpers.data_type_collation_sequence( + collation=collation, + column_type=column_type, + transpile_mysql_type_to_sqlite=cls._transpile_mysql_type_to_sqlite, + ) def _check_sqlite_json1_extension_enabled(self) -> bool: try: @@ -648,24 +300,29 @@ def _check_sqlite_json1_extension_enabled(self) -> bool: except sqlite3.Error: return False - def _get_unique_index_name(self, base_name: str) -> str: - """Return a unique SQLite index name based on base_name. + def _get_schema_writer(self) -> SchemaWriter: + writer = getattr(self, "_schema_writer", None) + if writer is None: + writer = SchemaWriter(self) + self._schema_writer = writer + return writer + + def _get_data_transfer_manager(self) -> DataTransferManager: + manager = getattr(self, "_data_transfer", None) + if manager is None: + manager = DataTransferManager(self) + self._data_transfer = manager + return manager - If base_name has not been used yet, it is returned as-is and recorded. If it has been - used, a numeric suffix is appended starting from 2 (e.g., name_2, name_3, ...), and the - chosen name is recorded as used. This behavior is only intended for cases where index - prefixing is not enabled and SQLite requires global uniqueness for index names. 
- """ + def _get_unique_index_name(self, base_name: str) -> str: if base_name not in self._seen_sqlite_index_names: self._seen_sqlite_index_names.add(base_name) return base_name - # Base name already seen — assign next available counter next_num = self._sqlite_index_name_counters.get(base_name, 2) candidate = f"{base_name}_{next_num}" while candidate in self._seen_sqlite_index_names: next_num += 1 candidate = f"{base_name}_{next_num}" - # Record chosen candidate and bump counter for the base name self._seen_sqlite_index_names.add(candidate) self._sqlite_index_name_counters[base_name] = next_num + 1 self._logger.info( @@ -676,221 +333,7 @@ def _get_unique_index_name(self, base_name: str) -> str: return candidate def _build_create_table_sql(self, table_name: str) -> str: - table_ident = self._quote_sqlite_identifier(table_name) - sql: str = f"CREATE TABLE IF NOT EXISTS {table_ident} (" - primary: str = "" - indices: str = "" - - safe_table = self._escape_mysql_backticks(table_name) - self._mysql_cur_dict.execute(f"SHOW COLUMNS FROM `{safe_table}`") - rows: t.Sequence[t.Optional[t.Dict[str, RowItemType]]] = self._mysql_cur_dict.fetchall() - - primary_keys: int = sum(1 for row in rows if row is not None and row["Key"] == "PRI") - - for row in rows: - if row is not None: - column_type = self._translate_type_from_mysql_to_sqlite( - column_type=row["Type"], # type: ignore[arg-type] - sqlite_json1_extension_enabled=self._sqlite_json1_extension_enabled, - ) - if row["Key"] == "PRI" and row["Extra"] == "auto_increment" and primary_keys == 1: - if column_type in Integer_Types: - sql += "\n\t{name} INTEGER PRIMARY KEY AUTOINCREMENT,".format( - name=self._quote_sqlite_identifier( - str( - row["Field"].decode() - if isinstance(row["Field"], (bytes, bytearray)) - else row["Field"] - ) - ), - ) - else: - self._logger.warning( - 'Primary key "%s" in table "%s" is not an INTEGER type! 
Skipping.', - row["Field"], - table_name, - ) - else: - sql += "\n\t{name} {type} {notnull} {default} {collation},".format( - name=self._quote_sqlite_identifier( - str(row["Field"].decode() if isinstance(row["Field"], (bytes, bytearray)) else row["Field"]) - ), - type=column_type, - notnull="NULL" if row["Null"] == "YES" else "NOT NULL", - default=self._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]), - collation=self._data_type_collation_sequence(self._collation, column_type), - ) - - self._mysql_cur_dict.execute( - """ - SELECT s.INDEX_NAME AS `name`, - IF (NON_UNIQUE = 0 AND s.INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`, - IF (NON_UNIQUE = 0 AND s.INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`, - {auto_increment} - GROUP_CONCAT(s.COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns`, - GROUP_CONCAT(c.COLUMN_TYPE ORDER BY SEQ_IN_INDEX) AS `types` - FROM information_schema.STATISTICS AS s - JOIN information_schema.COLUMNS AS c - ON s.TABLE_SCHEMA = c.TABLE_SCHEMA - AND s.TABLE_NAME = c.TABLE_NAME - AND s.COLUMN_NAME = c.COLUMN_NAME - WHERE s.TABLE_SCHEMA = %s - AND s.TABLE_NAME = %s - GROUP BY s.INDEX_NAME, s.NON_UNIQUE {group_by_extra} - """.format( - auto_increment=( - "IF (c.EXTRA = 'auto_increment', 1, 0) AS `auto_increment`," - if primary_keys == 1 - else "0 as `auto_increment`," - ), - group_by_extra=" ,c.EXTRA" if primary_keys == 1 else "", - ), - (self._mysql_database, table_name), - ) - mysql_indices: t.Sequence[t.Optional[t.Dict[str, RowItemType]]] = self._mysql_cur_dict.fetchall() - for index in mysql_indices: - if index is not None: - index_name: str - if isinstance(index["name"], bytes): - index_name = index["name"].decode() - elif isinstance(index["name"], str): - index_name = index["name"] - else: - index_name = str(index["name"]) - - # check if the index name collides with any table name - self._mysql_cur_dict.execute( - """ - SELECT COUNT(*) AS `count` - FROM information_schema.TABLES - WHERE TABLE_SCHEMA = %s - AND TABLE_NAME = %s - """, - (self._mysql_database, index_name), - ) - collision: t.Optional[t.Dict[str, RowItemType]] = self._mysql_cur_dict.fetchone() - table_collisions: int = 0 - if collision is not None: - table_collisions = int(collision["count"]) # type: ignore[arg-type] - - columns: str = "" - if isinstance(index["columns"], bytes): - columns = index["columns"].decode() - elif isinstance(index["columns"], str): - columns = index["columns"] - - types: str = "" - if isinstance(index["types"], bytes): - types = index["types"].decode() - elif isinstance(index["types"], str): - types = index["types"] - - if len(columns) > 0: - if index["primary"] in {1, "1"}: - if (index["auto_increment"] not in {1, "1"}) or any( - self._translate_type_from_mysql_to_sqlite( - column_type=_type, - sqlite_json1_extension_enabled=self._sqlite_json1_extension_enabled, - ) - not in Integer_Types - for _type in types.split(",") - ): - primary += "\n\tPRIMARY KEY ({columns})".format( - columns=", ".join( - self._quote_sqlite_identifier(column.strip()) for column in columns.split(",") - ) - ) - else: - # Determine the SQLite index name, considering table name collisions and prefix option - proposed_index_name = ( - f"{table_name}_{index_name}" - if (table_collisions > 0 or self._prefix_indices) - else index_name - ) - # Ensure index name is unique across the whole SQLite database when prefixing is disabled - if not self._prefix_indices: - unique_index_name = self._get_unique_index_name(proposed_index_name) - else: - unique_index_name = proposed_index_name - 
unique_kw = "UNIQUE " if index["unique"] in {1, "1"} else "" - indices += """CREATE {unique}INDEX IF NOT EXISTS {name} ON {table} ({columns});""".format( - unique=unique_kw, - name=self._quote_sqlite_identifier(unique_index_name), - table=self._quote_sqlite_identifier(table_name), - columns=", ".join( - self._quote_sqlite_identifier(column.strip()) for column in columns.split(",") - ), - ) - - sql += primary - sql = sql.rstrip(", ") - - if not self._without_tables and not self._without_foreign_keys: - server_version: t.Optional[t.Tuple[int, ...]] = self._mysql.get_server_version() - self._mysql_cur_dict.execute( - """ - SELECT k.COLUMN_NAME AS `column`, - k.REFERENCED_TABLE_NAME AS `ref_table`, - k.REFERENCED_COLUMN_NAME AS `ref_column`, - c.UPDATE_RULE AS `on_update`, - c.DELETE_RULE AS `on_delete` - FROM information_schema.TABLE_CONSTRAINTS AS i - {JOIN} information_schema.KEY_COLUMN_USAGE AS k - ON i.CONSTRAINT_NAME = k.CONSTRAINT_NAME - AND i.TABLE_NAME = k.TABLE_NAME - {JOIN} information_schema.REFERENTIAL_CONSTRAINTS AS c - ON c.CONSTRAINT_NAME = i.CONSTRAINT_NAME - AND c.TABLE_NAME = i.TABLE_NAME - WHERE i.TABLE_SCHEMA = %s - AND i.TABLE_NAME = %s - AND i.CONSTRAINT_TYPE = %s - GROUP BY i.CONSTRAINT_NAME, - k.COLUMN_NAME, - k.REFERENCED_TABLE_NAME, - k.REFERENCED_COLUMN_NAME, - c.UPDATE_RULE, - c.DELETE_RULE - """.format( - JOIN=( - "JOIN" - if (server_version is not None and server_version[0] == 8 and server_version[2] > 19) - else "LEFT JOIN" - ) - ), - (self._mysql_database, table_name, "FOREIGN KEY"), - ) - for foreign_key in self._mysql_cur_dict.fetchall(): - if foreign_key is not None: - col = self._quote_sqlite_identifier( - foreign_key["column"].decode() - if isinstance(foreign_key["column"], (bytes, bytearray)) - else str(foreign_key["column"]) # type: ignore[index] - ) - ref_table = self._quote_sqlite_identifier( - foreign_key["ref_table"].decode() - if isinstance(foreign_key["ref_table"], (bytes, bytearray)) - else str(foreign_key["ref_table"]) # type: ignore[index] - ) - ref_col = self._quote_sqlite_identifier( - foreign_key["ref_column"].decode() - if isinstance(foreign_key["ref_column"], (bytes, bytearray)) - else str(foreign_key["ref_column"]) # type: ignore[index] - ) - on_update = str(foreign_key["on_update"] or "NO ACTION").upper() # type: ignore[index] - on_delete = str(foreign_key["on_delete"] or "NO ACTION").upper() # type: ignore[index] - sql += ( - f",\n\tFOREIGN KEY({col}) REFERENCES {ref_table} ({ref_col}) " - f"ON UPDATE {on_update} " - f"ON DELETE {on_delete}" - ) - - sql += "\n)" - if self._sqlite_strict: - sql += " STRICT" - sql += ";\n" - sql += indices - - return sql + return self._get_schema_writer()._build_create_table_sql(table_name) def _create_table(self, table_name: str, attempting_reconnect: bool = False) -> None: try: @@ -904,9 +347,8 @@ def _create_table(self, table_name: str, attempting_reconnect: bool = False) -> self._logger.warning("Connection to MySQL server lost.\nAttempting to reconnect.") self._create_table(table_name, True) return - else: - self._logger.warning("Connection to MySQL server lost.\nReconnection attempt aborted.") - raise + self._logger.warning("Connection to MySQL server lost.\nReconnection attempt aborted.") + raise self._logger.error( "MySQL failed reading table definition from table %s: %s", table_name, @@ -918,28 +360,22 @@ def _create_table(self, table_name: str, attempting_reconnect: bool = False) -> raise def _mysql_viewdef_to_sqlite(self, view_select_sql: str, view_name: str) -> str: - """Convert a MySQL 
VIEW_DEFINITION (a SELECT ...) to a SQLite CREATE VIEW statement.""" - # Normalize whitespace and avoid double semicolons in output cleaned_sql = view_select_sql.strip().rstrip(";") try: tree: Expression = parse_one(cleaned_sql, read="mysql") except (ParseError, ValueError, AttributeError, TypeError): - # Fallback: try to remove schema qualifiers if requested, then return stripped_sql = cleaned_sql - # Remove qualifiers `schema`.tbl or "schema".tbl or schema.tbl sn: str = re.escape(self._mysql_database) for pat in (rf"`{sn}`\.", rf'"{sn}"\.', rf"\b{sn}\."): stripped_sql = re.sub(pat, "", stripped_sql, flags=re.IGNORECASE) view_ident = self._quote_sqlite_identifier(view_name) return f"CREATE VIEW IF NOT EXISTS {view_ident} AS\n{stripped_sql};" - # Remove schema qualifiers that match schema_name on tables for tbl in tree.find_all(exp.Table): db = tbl.args.get("db") if db and db.name.strip('`"').lower() == self._mysql_database.lower(): tbl.set("db", None) - # Also remove schema qualifiers on fully-qualified columns (db.table.column) for col in tree.find_all(exp.Column): db = col.args.get("db") if db and db.name.strip('`"').lower() == self._mysql_database.lower(): @@ -950,8 +386,6 @@ def _mysql_viewdef_to_sqlite(self, view_select_sql: str, view_name: str) -> str: return f"CREATE VIEW IF NOT EXISTS {view_ident} AS\n{sqlite_select};" def _build_create_view_sql(self, view_name: str) -> str: - """Build a CREATE VIEW statement for SQLite from a MySQL VIEW definition.""" - # Try to obtain the view definition from information_schema.VIEWS definition: t.Optional[str] = None try: self._mysql_cur_dict.execute( @@ -974,13 +408,10 @@ def _build_create_view_sql(self, view_name: str) -> str: else: definition = t.cast(str, val) except mysql.connector.Error: - # Fall back to SHOW CREATE VIEW below definition = None if not definition: - # Fallback: use SHOW CREATE VIEW and extract the SELECT part try: - # Escape backticks in the MySQL view name for safe interpolation safe_view_name = view_name.replace("`", "``") self._mysql_cur.execute(f"SHOW CREATE VIEW `{safe_view_name}`") res = self._mysql_cur.fetchone() @@ -993,13 +424,10 @@ def _build_create_view_sql(self, view_name: str) -> str: create_stmt_str = str(create_stmt) else: create_stmt_str = t.cast(str, create_stmt) - # Extract the SELECT ... 
part after AS (supporting newlines) m = re.search(r"\bAS\b\s*(.*)$", create_stmt_str, re.IGNORECASE | re.DOTALL) if m: definition = m.group(1).strip().rstrip(";") else: - # As a last resort, try to use the full statement replacing the prefix - # Not ideal, but better than failing outright idx = create_stmt_str.upper().find(" AS ") if idx != -1: definition = create_stmt_str[idx + 4 :].strip().rstrip(";") @@ -1027,9 +455,8 @@ def _create_view(self, view_name: str, attempting_reconnect: bool = False) -> No self._logger.warning("Connection to MySQL server lost.\nAttempting to reconnect.") self._create_view(view_name, True) return - else: - self._logger.warning("Connection to MySQL server lost.\nReconnection attempt aborted.") - raise + self._logger.warning("Connection to MySQL server lost.\nReconnection attempt aborted.") + raise self._logger.error( "MySQL failed reading view definition from view %s: %s", view_name, @@ -1043,63 +470,12 @@ def _create_view(self, view_name: str, attempting_reconnect: bool = False) -> No def _transfer_table_data( self, table_name: str, sql: str, total_records: int = 0, attempting_reconnect: bool = False ) -> None: - if attempting_reconnect: - self._mysql.reconnect() - try: - if self._chunk_size is not None and self._chunk_size > 0: - for chunk in trange( - self._current_chunk_number, - int(ceil(total_records / self._chunk_size)), - disable=self._quiet, - ): - self._current_chunk_number = chunk - self._sqlite_cur.executemany( - sql, - ( - tuple(encode_data_for_sqlite(col) if col is not None else None for col in row) - for row in self._mysql_cur.fetchmany(self._chunk_size) - ), - ) - else: - self._sqlite_cur.executemany( - sql, - ( - tuple(encode_data_for_sqlite(col) if col is not None else None for col in row) - for row in tqdm( - self._mysql_cur.fetchall(), - total=total_records, - disable=self._quiet, - ) - ), - ) - self._sqlite.commit() - except mysql.connector.Error as err: - if err.errno == errorcode.CR_SERVER_LOST: - if not attempting_reconnect: - self._logger.warning("Connection to MySQL server lost.\nAttempting to reconnect.") - self._transfer_table_data( - table_name=table_name, - sql=sql, - total_records=total_records, - attempting_reconnect=True, - ) - return - else: - self._logger.warning("Connection to MySQL server lost.\nReconnection attempt aborted.") - raise - self._logger.error( - "MySQL transfer failed reading table data from table %s: %s", - table_name, - err, - ) - raise - except sqlite3.Error as err: - self._logger.error( - "SQLite transfer failed inserting data into table %s: %s", - table_name, - err, - ) - raise + self._get_data_transfer_manager().transfer_table_data( + table_name=table_name, + sql=sql, + total_records=total_records, + attempting_reconnect=attempting_reconnect, + ) def transfer(self) -> None: """The primary and only method with which we transfer all the data.""" @@ -1223,7 +599,7 @@ def _coerce_row(row: t.Any) -> t.Tuple[str, str]: fields=('"{}", ' * len(columns)).rstrip(" ,").format(*columns), placeholders=("?, " * len(columns)).rstrip(" ,"), ) - self._transfer_table_data( + self._data_transfer.transfer_table_data( table_name=table_name, # type: ignore[arg-type] sql=sql, total_records=total_records_count, diff --git a/src/mysql_to_sqlite3/type_translation.py b/src/mysql_to_sqlite3/type_translation.py new file mode 100644 index 0000000..2c7a09a --- /dev/null +++ b/src/mysql_to_sqlite3/type_translation.py @@ -0,0 +1,437 @@ +"""Utilities for translating MySQL schema definitions to SQLite constructs.""" + +from __future__ 
import annotations
+
+import logging
+import re
+import sqlite3
+import typing as t
+
+from mysql.connector.types import RowItemType
+from sqlglot import Expression, exp, parse_one
+from sqlglot.errors import ParseError
+
+from mysql_to_sqlite3.mysql_utils import CHARSET_INTRODUCERS
+from mysql_to_sqlite3.sqlite_utils import CollatingSequences
+
+
+LOGGER = logging.getLogger(__name__)
+
+COLUMN_PATTERN: t.Pattern[str] = re.compile(r"^[^(]+")
+COLUMN_LENGTH_PATTERN: t.Pattern[str] = re.compile(r"\(\d+\)$")
+
+ParseCallable = t.Callable[..., Expression]
+RegexSearchCallable = t.Callable[..., t.Optional[t.Match[str]]]
+
+
+def _valid_column_type(column_type: str) -> t.Optional[t.Match[str]]:
+    return COLUMN_PATTERN.match(column_type.strip())
+
+
+def _column_type_length(column_type: str) -> str:
+    suffix: t.Optional[t.Match[str]] = COLUMN_LENGTH_PATTERN.search(column_type)
+    if suffix:
+        return suffix.group(0)
+    return ""
+
+
+def _decode_column_type(column_type: t.Union[str, bytes]) -> str:
+    if isinstance(column_type, str):
+        return column_type
+    if isinstance(column_type, bytes):
+        try:
+            return column_type.decode()
+        except (UnicodeDecodeError, AttributeError):
+            pass
+    return str(column_type)
+
+
+def translate_type_from_mysql_to_sqlite(
+    column_type: t.Union[str, bytes],
+    sqlite_json1_extension_enabled: bool = False,
+    *,
+    decode_column_type: t.Optional[t.Callable[[t.Union[str, bytes]], str]] = None,
+    valid_column_type: t.Optional[t.Callable[[str], t.Optional[t.Match[str]]]] = None,
+    transpile_mysql_type_to_sqlite: t.Optional[t.Callable[[str, bool], t.Optional[str]]] = None,
+) -> str:
+    """Return a SQLite column definition for a MySQL column type."""
+    decoder = decode_column_type or _decode_column_type
+    validator = valid_column_type or _valid_column_type
+    transpiler = transpile_mysql_type_to_sqlite or _transpile_mysql_type_to_sqlite
+    _column_type: str = decoder(column_type)
+
+    match: t.Optional[t.Match[str]] = validator(_column_type)
+    if not match:
+        raise ValueError(f'"{_column_type}" is not a valid column_type!')
+
+    data_type: str = match.group(0).upper()
+
+    if data_type.endswith(" UNSIGNED"):
+        data_type = data_type.replace(" UNSIGNED", "")
+
+    if data_type in {
+        "BIGINT",
+        "BLOB",
+        "BOOLEAN",
+        "DATE",
+        "DATETIME",
+        "DECIMAL",
+        "DOUBLE",
+        "FLOAT",
+        "INTEGER",
+        "MEDIUMINT",
+        "NUMERIC",
+        "REAL",
+        "SMALLINT",
+        "TIME",
+        "TINYINT",
+        "YEAR",
+    }:
+        return data_type
+    if data_type == "DOUBLE PRECISION":
+        return "DOUBLE"
+    if data_type == "FIXED":
+        return "DECIMAL"
+    if data_type in {"CHARACTER VARYING", "CHAR VARYING"}:
+        return "VARCHAR" + _column_type_length(_column_type)
+    if data_type in {"NATIONAL CHARACTER VARYING", "NATIONAL CHAR VARYING", "NATIONAL VARCHAR"}:
+        return "NVARCHAR" + _column_type_length(_column_type)
+    if data_type == "NATIONAL CHARACTER":
+        return "NCHAR" + _column_type_length(_column_type)
+    if data_type in {
+        "BIT",
+        "BINARY",
+        "LONGBLOB",
+        "MEDIUMBLOB",
+        "TINYBLOB",
+        "VARBINARY",
+    }:
+        return "BLOB"
+    if data_type in {"NCHAR", "NVARCHAR", "VARCHAR"}:
+        return data_type + _column_type_length(_column_type)
+    if data_type == "CHAR":
+        return "CHARACTER" + _column_type_length(_column_type)
+    if data_type == "INT":
+        return "INTEGER"
+    if data_type == "TIMESTAMP":
+        return "DATETIME"
+    if data_type == "JSON" and sqlite_json1_extension_enabled:
+        return "JSON"
+    sqlglot_type: t.Optional[str] = transpiler(_column_type, sqlite_json1_extension_enabled)
+    if sqlglot_type:
+        return sqlglot_type
+    return "TEXT"
+
+
+def 
_transpile_mysql_expr_to_sqlite(expr_sql: str, parse_func: t.Optional[ParseCallable] = None) -> t.Optional[str]: + """Transpile a MySQL scalar expression to SQLite using sqlglot.""" + cleaned: str = expr_sql.strip().rstrip(";") + parser = parse_func or parse_one + try: + tree: Expression = parser(cleaned, read="mysql") + return tree.sql(dialect="sqlite") + except (ParseError, ValueError): + return None + except (AttributeError, TypeError): # pragma: no cover - unexpected sqlglot failure + LOGGER.debug("sqlglot failed to transpile expr: %r", expr_sql) + return None + + +def _normalize_literal_with_sqlglot(expr_sql: str) -> t.Optional[str]: + """Normalize a MySQL literal using sqlglot, returning SQLite SQL if literal-like.""" + cleaned: str = expr_sql.strip().rstrip(";") + try: + node: Expression = parse_one(cleaned, read="mysql") + except (ParseError, ValueError): + return None + if isinstance(node, exp.Literal): + return node.sql(dialect="sqlite") + if isinstance(node, exp.Paren) and isinstance(node.this, exp.Literal): + return node.this.sql(dialect="sqlite") + return None + + +def quote_sqlite_identifier(name: t.Union[str, bytes, bytearray]) -> str: + """Safely quote an identifier for SQLite using sqlglot.""" + if isinstance(name, (bytes, bytearray)): + try: + s: str = name.decode() + except (UnicodeDecodeError, AttributeError): + s = str(name) + else: + s = str(name) + try: + normalized: str = exp.to_identifier(s).name + except (AttributeError, ValueError, TypeError): # pragma: no cover - extremely unlikely + normalized = s + replaced: str = normalized.replace('"', '""') + return f'"{replaced}"' + + +def escape_mysql_backticks(identifier: str) -> str: + """Escape backticks in a MySQL identifier for safe backtick quoting.""" + return identifier.replace("`", "``") + + +def _transpile_mysql_type_to_sqlite( + column_type: str, + sqlite_json1_extension_enabled: bool = False, + *, + parse_func: t.Optional[ParseCallable] = None, + regex_search: t.Optional[RegexSearchCallable] = None, +) -> t.Optional[str]: + """Attempt to derive a suitable SQLite column type using sqlglot.""" + expr_sql: str = f"CAST(NULL AS {column_type.strip()})" + parser = parse_func or parse_one + search = regex_search or re.search + try: + tree: Expression = parser(expr_sql, read="mysql") + rendered: str = tree.sql(dialect="sqlite") + except (ParseError, ValueError, AttributeError, TypeError): + return None + + m: t.Optional[t.Match[str]] = search(r"CAST\(NULL AS\s+([^)]+)\)", rendered, re.IGNORECASE) + if not m: + return None + extracted: str = m.group(1).strip() + upper: str = extracted.upper() + + if "JSON" in upper: + return "JSON" if sqlite_json1_extension_enabled else "TEXT" + + base: str = upper + length_suffix: str = "" + paren: t.Optional[t.Match[str]] = re.match(r"^([A-Z ]+)(\(.*\))$", upper) + if paren: + base = paren.group(1).strip() + length_suffix = paren.group(2) + + synonyms: t.Dict[str, str] = { + "DOUBLE PRECISION": "DOUBLE", + "FIXED": "DECIMAL", + "CHAR VARYING": "VARCHAR", + "CHARACTER VARYING": "VARCHAR", + "NATIONAL VARCHAR": "NVARCHAR", + "NATIONAL CHARACTER VARYING": "NVARCHAR", + "NATIONAL CHAR VARYING": "NVARCHAR", + "NATIONAL CHARACTER": "NCHAR", + } + base = synonyms.get(base, base) + + if base in {"NCHAR", "NVARCHAR", "VARCHAR"} and length_suffix: + return f"{base}{length_suffix}" + if base in {"CHAR", "CHARACTER"}: + return f"CHARACTER{length_suffix}" + if base in {"DECIMAL", "NUMERIC"}: + return base + if base in { + "DOUBLE", + "REAL", + "FLOAT", + "INTEGER", + "BIGINT", + "SMALLINT", + 
"MEDIUMINT", + "TINYINT", + "BLOB", + "DATE", + "DATETIME", + "TIME", + "YEAR", + "BOOLEAN", + }: + return base + if base in {"VARBINARY", "BINARY", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB"}: + return "BLOB" + if base in {"TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "CLOB"}: + return "TEXT" + if base in {"ENUM", "SET"}: + return "TEXT" + return None + + +def translate_default_from_mysql_to_sqlite( + column_default: RowItemType = None, + column_type: t.Optional[str] = None, + column_extra: RowItemType = None, + *, + normalize_literal: t.Optional[t.Callable[[str], t.Optional[str]]] = None, + transpile_expr: t.Optional[t.Callable[[str], t.Optional[str]]] = None, +) -> str: + """Render a DEFAULT clause suitable for SQLite.""" + normalizer = normalize_literal or _normalize_literal_with_sqlglot + expr_transpiler = transpile_expr or _transpile_mysql_expr_to_sqlite + is_binary: bool + is_hex: bool + if isinstance(column_default, bytes): + if column_type in { + "BIT", + "BINARY", + "BLOB", + "LONGBLOB", + "MEDIUMBLOB", + "TINYBLOB", + "VARBINARY", + }: + if column_extra in {"DEFAULT_GENERATED", "default_generated"}: + for charset_introducer in CHARSET_INTRODUCERS: + if column_default.startswith(charset_introducer.encode()): + is_binary = False + is_hex = False + for b_prefix in ("B", "b"): + if column_default.startswith(rf"{charset_introducer} {b_prefix}\'".encode()): + is_binary = True + break + for x_prefix in ("X", "x"): + if column_default.startswith(rf"{charset_introducer} {x_prefix}\'".encode()): + is_hex = True + break + column_default = ( + column_default.replace(charset_introducer.encode(), b"") + .replace(rb"x\'", b"") + .replace(rb"X\'", b"") + .replace(rb"b\'", b"") + .replace(rb"B\'", b"") + .replace(rb"\'", b"") + .replace(rb"'", b"") + .strip() + ) + if is_binary: + return f"DEFAULT '{chr(int(column_default, 2))}'" + if is_hex: + return f"DEFAULT x'{column_default.decode()}'" + break + return f"DEFAULT x'{column_default.hex()}'" + try: + column_default = column_default.decode() + except (UnicodeDecodeError, AttributeError): + pass + if column_default is None: + return "" + if isinstance(column_default, bool): + if column_type == "BOOLEAN" and sqlite3.sqlite_version >= "3.23.0": + if column_default: + return "DEFAULT(TRUE)" + return "DEFAULT(FALSE)" + return f"DEFAULT '{int(column_default)}'" + if isinstance(column_default, str): + if column_default.lower() == "curtime()": + return "DEFAULT CURRENT_TIME" + if column_default.lower() == "curdate()": + return "DEFAULT CURRENT_DATE" + if column_default.lower() in {"current_timestamp()", "now()"}: + return "DEFAULT CURRENT_TIMESTAMP" + if column_extra in {"DEFAULT_GENERATED", "default_generated"}: + if column_default.upper() in { + "CURRENT_TIME", + "CURRENT_DATE", + "CURRENT_TIMESTAMP", + }: + return f"DEFAULT {column_default.upper()}" + for charset_introducer in CHARSET_INTRODUCERS: + if column_default.startswith(charset_introducer): + is_binary = False + is_hex = False + for b_prefix in ("B", "b"): + if column_default.startswith(rf"{charset_introducer} {b_prefix}\'"): + is_binary = True + break + for x_prefix in ("X", "x"): + if column_default.startswith(rf"{charset_introducer} {x_prefix}\'"): + is_hex = True + break + column_default = ( + column_default.replace(charset_introducer, "") + .replace(r"x\'", "") + .replace(r"X\'", "") + .replace(r"b\'", "") + .replace(r"B\'", "") + .replace(r"\'", "") + .replace(r"'", "") + .strip() + ) + if is_binary: + return f"DEFAULT '{chr(int(column_default, 2))}'" + if is_hex: + return f"DEFAULT 
x'{column_default}'" + return f"DEFAULT '{column_default}'" + transpiled: t.Optional[str] = expr_transpiler(column_default) + if transpiled: + norm: str = transpiled.strip().rstrip(";") + upper: str = norm.upper() + if upper in {"CURRENT_TIME", "CURRENT_DATE", "CURRENT_TIMESTAMP"}: + return f"DEFAULT {upper}" + if upper == "NULL": + return "DEFAULT NULL" + if re.match(r"^[Xx]'[0-9A-Fa-f]+'$", norm): + return f"DEFAULT {norm}" + if upper in {"TRUE", "FALSE"}: + if column_type == "BOOLEAN" and sqlite3.sqlite_version >= "3.23.0": + return f"DEFAULT({upper})" + return f"DEFAULT '{1 if upper == 'TRUE' else 0}'" + if norm.startswith("(") and norm.endswith(")"): + inner = norm[1:-1].strip() + if (inner.startswith("'") and inner.endswith("'")) or re.match(r"^-?\d+(?:\.\d+)?$", inner): + return f"DEFAULT {inner}" + if re.match(r"^[\d\.\s\+\-\*/\(\)]+$", norm) and any(ch.isdigit() for ch in norm): + return f"DEFAULT {norm}" + if (norm.startswith("'") and norm.endswith("'")) or re.match(r"^-?\d+(?:\.\d+)?$", norm): + return f"DEFAULT {norm}" + if re.match(r"^[\d\.\s\+\-\*/\(\)]+$", norm) and any(ch.isdigit() for ch in norm): + return f"DEFAULT {norm}" + stripped_default = column_default.strip() + if stripped_default.startswith("'") or (stripped_default.startswith("(") and stripped_default.endswith(")")): + normalized_literal: t.Optional[str] = normalizer(column_default) + if normalized_literal is not None: + return f"DEFAULT {normalized_literal}" + + _escaped = column_default.replace("\\'", "'") + _escaped = _escaped.replace("'", "''") + return f"DEFAULT '{_escaped}'" + s = str(column_default) + s = s.replace("\\'", "'") + s = s.replace("'", "''") + return f"DEFAULT '{s}'" + + +def data_type_collation_sequence( + collation: str = CollatingSequences.BINARY, + column_type: t.Optional[str] = None, + *, + transpile_mysql_type_to_sqlite: t.Optional[t.Callable[[str, bool], t.Optional[str]]] = None, +) -> str: + """Return a SQLite COLLATE clause for textual affinity types.""" + if not column_type or collation == CollatingSequences.BINARY: + return "" + + ct: str = column_type.strip() + upper: str = ct.upper() + + if upper.startswith(("CHARACTER", "NCHAR", "NVARCHAR", "TEXT", "VARCHAR")): + return f"COLLATE {collation}" + + if "JSON" in upper or "BLOB" in upper: + return "" + + if any(tok in upper for tok in ("VARCHAR", "NVARCHAR", "NCHAR", "CHAR", "TEXT", "CLOB", "CHARACTER")): + return f"COLLATE {collation}" + + transpiler = transpile_mysql_type_to_sqlite or _transpile_mysql_type_to_sqlite + mapped: t.Optional[str] = transpiler(ct, False) + if mapped: + mu = mapped.upper() + if ( + "CHAR" in mu or "VARCHAR" in mu or "NCHAR" in mu or "NVARCHAR" in mu or "TEXT" in mu or "CLOB" in mu + ) and not ("JSON" in mu or "BLOB" in mu): + return f"COLLATE {collation}" + + return "" + + +__all__ = [ + "data_type_collation_sequence", + "escape_mysql_backticks", + "quote_sqlite_identifier", + "translate_default_from_mysql_to_sqlite", + "translate_type_from_mysql_to_sqlite", +] diff --git a/src/mysql_to_sqlite3/types.py b/src/mysql_to_sqlite3/types.py index 8fe99b2..b32fda2 100644 --- a/src/mysql_to_sqlite3/types.py +++ b/src/mysql_to_sqlite3/types.py @@ -9,6 +9,11 @@ from mysql.connector.cursor import MySQLCursorDict, MySQLCursorPrepared, MySQLCursorRaw +if t.TYPE_CHECKING: + from mysql_to_sqlite3.data_transfer import DataTransferManager + from mysql_to_sqlite3.schema_writer import SchemaWriter + + try: # Python 3.11+ from typing import TypedDict # type: ignore[attr-defined] @@ -86,3 +91,5 @@ class 
MySQLtoSQLiteAttributes: # Tracking of SQLite index names and counters to ensure uniqueness when prefixing is disabled _seen_sqlite_index_names: t.Set[str] _sqlite_index_name_counters: t.Dict[str, int] + _schema_writer: "SchemaWriter" + _data_transfer: "DataTransferManager" diff --git a/tox.ini b/tox.ini index 8f6b995..c928273 100644 --- a/tox.ini +++ b/tox.ini @@ -103,7 +103,7 @@ commands = {[testenv:mypy]commands} [flake8] -ignore = I100,I201,I202,D203,D401,W503,E203,F401,F403,C901,E501 +ignore = I100,I201,I202,I300,D203,D401,W503,E203,F401,F403,C901,E501 exclude = *__init__.py *__version__.py
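
The helpers extracted into type_translation.py are pure functions, so they can
be sanity-checked without a transporter instance or a live MySQL connection.
A minimal sketch; the input values below are illustrative, not taken from this
patch:

    from mysql_to_sqlite3.type_translation import (
        quote_sqlite_identifier,
        translate_default_from_mysql_to_sqlite,
        translate_type_from_mysql_to_sqlite,
    )

    translate_type_from_mysql_to_sqlite("INT(11) UNSIGNED")  # -> "INTEGER"
    translate_type_from_mysql_to_sqlite("VARCHAR(255)")      # -> "VARCHAR(255)"
    translate_type_from_mysql_to_sqlite(
        "JSON", sqlite_json1_extension_enabled=True
    )  # -> "JSON"

    # Generated CURRENT_TIMESTAMP defaults map straight through:
    translate_default_from_mysql_to_sqlite(
        "CURRENT_TIMESTAMP", "DATETIME", "DEFAULT_GENERATED"
    )  # -> "DEFAULT CURRENT_TIMESTAMP"

    # Embedded double quotes are doubled (expected, given sqlglot's to_identifier):
    quote_sqlite_identifier('weird "name"')  # -> '"weird ""name"""'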
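
The bookkeeping behind _get_unique_index_name guarantees globally unique
SQLite index names when index prefixing is disabled. The same contract as a
standalone sketch (a hypothetical reimplementation for illustration, not the
class method itself):

    seen = set()       # index names already emitted
    counters = {}      # next numeric suffix per base name

    def unique_index_name(base: str) -> str:
        if base not in seen:
            seen.add(base)
            return base
        n = counters.get(base, 2)
        candidate = f"{base}_{n}"
        while candidate in seen:
            n += 1
            candidate = f"{base}_{n}"
        seen.add(candidate)
        counters[base] = n + 1
        return candidate

    unique_index_name("idx_user")  # -> "idx_user"
    unique_index_name("idx_user")  # -> "idx_user_2"
    unique_index_name("idx_user")  # -> "idx_user_3"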
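
DataTransferManager's chunked path reduces to a fetchmany/executemany loop
over ceil(total_records / chunk_size) batches. A stripped-down sketch,
assuming a MySQL cursor with an executed SELECT and a matching parameterized
INSERT; reconnect handling, tqdm progress, and encode_data_for_sqlite are
deliberately omitted here:

    from math import ceil

    def copy_in_chunks(mysql_cur, sqlite_cur, insert_sql, total_records, chunk_size):
        # One executemany() per batch keeps memory bounded on large tables.
        for _ in range(int(ceil(total_records / chunk_size))):
            batch = [tuple(row) for row in mysql_cur.fetchmany(chunk_size)]
            sqlite_cur.executemany(insert_sql, batch)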