diff --git a/src/mysql_to_sqlite3/transporter.py b/src/mysql_to_sqlite3/transporter.py
index 1778d6a..bea3f86 100644
--- a/src/mysql_to_sqlite3/transporter.py
+++ b/src/mysql_to_sqlite3/transporter.py
@@ -15,7 +15,7 @@
 from mysql.connector import CharacterSet, errorcode
 from mysql.connector.abstracts import MySQLConnectionAbstract
 from mysql.connector.types import RowItemType
-from sqlglot import exp, parse_one
+from sqlglot import Expression, exp, parse_one
 from sqlglot.errors import ParseError
 from tqdm import tqdm, trange

@@ -269,6 +269,16 @@ def _translate_type_from_mysql_to_sqlite(
             "YEAR",
         }:
             return data_type
+        if data_type == "DOUBLE PRECISION":
+            return "DOUBLE"
+        if data_type == "FIXED":
+            return "DECIMAL"
+        if data_type in {"CHARACTER VARYING", "CHAR VARYING"}:
+            return "VARCHAR" + cls._column_type_length(_column_type)
+        if data_type in {"NATIONAL CHARACTER VARYING", "NATIONAL CHAR VARYING", "NATIONAL VARCHAR"}:
+            return "NVARCHAR" + cls._column_type_length(_column_type)
+        if data_type == "NATIONAL CHARACTER":
+            return "NCHAR" + cls._column_type_length(_column_type)
         if data_type in {
             "BIT",
             "BINARY",
@@ -288,8 +298,146 @@
             return "DATETIME"
         if data_type == "JSON" and sqlite_json1_extension_enabled:
             return "JSON"
+        # As a last resort, try sqlglot to derive a better SQLite-compatible type
+        sqlglot_type: t.Optional[str] = cls._transpile_mysql_type_to_sqlite(
+            _column_type, sqlite_json1_extension_enabled=sqlite_json1_extension_enabled
+        )
+        if sqlglot_type:
+            return sqlglot_type
         return "TEXT"

+    @classmethod
+    def _transpile_mysql_expr_to_sqlite(cls, expr_sql: str) -> t.Optional[str]:
+        """Transpile a MySQL scalar expression to SQLite using sqlglot.
+
+        Returns the SQLite SQL string on success, or None on failure.
+        """
+        cleaned: str = expr_sql.strip().rstrip(";")
+        try:
+            tree: Expression = parse_one(cleaned, read="mysql")
+            return tree.sql(dialect="sqlite")
+        except (ParseError, ValueError):
+            return None
+        except (AttributeError, TypeError):  # pragma: no cover - unexpected sqlglot failure
+            logging.getLogger(cls.__name__ if hasattr(cls, "__name__") else "MySQLtoSQLite").debug(
+                "sqlglot failed to transpile expr: %r", expr_sql
+            )
+            return None
+
+    @staticmethod
+    def _quote_sqlite_identifier(name: t.Union[str, bytes, bytearray]) -> str:
+        """Safely quote an identifier for SQLite using sqlglot.
+
+        Always returns a double-quoted identifier, doubling any embedded quotes.
+        Accepts bytes and decodes them to str when needed.
+        """
+        if isinstance(name, (bytes, bytearray)):
+            try:
+                s: str = name.decode()
+            except (UnicodeDecodeError, AttributeError):
+                s = str(name)
+        else:
+            s = str(name)
+        try:
+            # Normalize identifier using sqlglot, then wrap in quotes regardless
+            normalized: str = exp.to_identifier(s).name
+        except (AttributeError, ValueError, TypeError):  # pragma: no cover - extremely unlikely
+            normalized = s
+        replaced: str = normalized.replace('"', '""')
+        return f'"{replaced}"'
+
+    @staticmethod
+    def _escape_mysql_backticks(identifier: str) -> str:
+        """Escape backticks in a MySQL identifier for safe backtick quoting."""
+        return identifier.replace("`", "``")
+
+    @classmethod
+    def _transpile_mysql_type_to_sqlite(
+        cls, column_type: str, sqlite_json1_extension_enabled: bool = False
+    ) -> t.Optional[str]:
+        """Attempt to derive a suitable SQLite column type using sqlglot.
+
+        This is used as a last-resort fallback when the built-in mapper does not
+        recognize a MySQL type or synonym.
+        It keeps existing behavior for known types and only augments unknowns.
+        """
+        # Wrap the type in a CAST expression so sqlglot can parse it consistently.
+        expr_sql: str = f"CAST(NULL AS {column_type.strip()})"
+        try:
+            tree: Expression = parse_one(expr_sql, read="mysql")
+            rendered: str = tree.sql(dialect="sqlite")
+        except (ParseError, ValueError, AttributeError, TypeError):
+            return None
+
+        # Extract the type inside CAST(NULL AS ...)
+        m: t.Optional[t.Match[str]] = re.search(r"CAST\(NULL AS\s+([^)]+)\)", rendered, re.IGNORECASE)
+        if not m:
+            return None
+        extracted: str = m.group(1).strip()
+        upper: str = extracted.upper()
+
+        # JSON handling: respect availability of the JSON1 extension
+        if "JSON" in upper:
+            return "JSON" if sqlite_json1_extension_enabled else "TEXT"
+
+        # Split out an optional length suffix like (255) or (10,2)
+        base: str = upper
+        length_suffix: str = ""
+        paren: t.Optional[t.Match[str]] = re.match(r"^([A-Z ]+)(\(.*\))$", upper)
+        if paren:
+            base = paren.group(1).strip()
+            length_suffix = paren.group(2)
+
+        # Minimal synonym normalization
+        synonyms: t.Dict[str, str] = {
+            "DOUBLE PRECISION": "DOUBLE",
+            "FIXED": "DECIMAL",
+            "CHAR VARYING": "VARCHAR",
+            "CHARACTER VARYING": "VARCHAR",
+            "NATIONAL VARCHAR": "NVARCHAR",
+            "NATIONAL CHARACTER VARYING": "NVARCHAR",
+            "NATIONAL CHAR VARYING": "NVARCHAR",
+            "NATIONAL CHARACTER": "NCHAR",
+        }
+        base = synonyms.get(base, base)
+
+        # Decide the final SQLite type, aligning with existing conventions
+        if base in {"NCHAR", "NVARCHAR", "VARCHAR"} and length_suffix:
+            return f"{base}{length_suffix}"
+        if base in {"CHAR", "CHARACTER"}:
+            return f"CHARACTER{length_suffix}"
+        if base in {"DECIMAL", "NUMERIC"}:
+            # Keep without length to match the existing mapper's style
+            return base
+        if base in {
+            "DOUBLE",
+            "REAL",
+            "FLOAT",
+            "INTEGER",
+            "BIGINT",
+            "SMALLINT",
+            "MEDIUMINT",
+            "TINYINT",
+            "BLOB",
+            "DATE",
+            "DATETIME",
+            "TIME",
+            "YEAR",
+            "BOOLEAN",
+        }:
+            return base
+        if base in {"VARBINARY", "BINARY", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB"}:
+            return "BLOB"
+        if base in {"TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "CLOB"}:
+            return "TEXT"
+
+        # ENUM/SET and other complex types -> keep default behavior (TEXT)
+        if base in {"ENUM", "SET"}:
+            return "TEXT"
+
+        # If we reached here, we didn't find a better mapping
+        return None
+
     @classmethod
     def _translate_default_from_mysql_to_sqlite(
         cls,
@@ -391,24 +539,84 @@
             if is_hex:
                 return f"DEFAULT x'{column_default}'"
             return f"DEFAULT '{column_default}'"
-            return "DEFAULT '{}'".format(column_default.replace(r"\'", r"''"))
-        return "DEFAULT '{}'".format(str(column_default).replace(r"\'", r"''"))
+            transpiled: t.Optional[str] = cls._transpile_mysql_expr_to_sqlite(column_default)
+            if transpiled:
+                norm: str = transpiled.strip().rstrip(";")
+                upper: str = norm.upper()
+                if upper in {"CURRENT_TIME", "CURRENT_DATE", "CURRENT_TIMESTAMP"}:
+                    return f"DEFAULT {upper}"
+                if upper == "NULL":
+                    return "DEFAULT NULL"
+                # Allow blob hex literal X'..'
+                if re.match(r"^[Xx]'[0-9A-Fa-f]+'$", norm):
+                    return f"DEFAULT {norm}"
+                # Support boolean tokens when provided as generated strings
+                if upper in {"TRUE", "FALSE"}:
+                    if column_type == "BOOLEAN" and sqlite3.sqlite_version_info >= (3, 23, 0):
+                        return f"DEFAULT({upper})"
+                    return f"DEFAULT '{1 if upper == 'TRUE' else 0}'"
+                # Unwrap a single layer of parentheses around a literal
+                if norm.startswith("(") and norm.endswith(")"):
+                    inner = norm[1:-1].strip()
+                    if (inner.startswith("'") and inner.endswith("'")) or re.match(r"^-?\d+(?:\.\d+)?$", inner):
+                        return f"DEFAULT {inner}"
+                # If the expression is arithmetic-only over numeric literals, allow it as-is
+                if re.match(r"^[\d\.\s\+\-\*/\(\)]+$", norm) and any(ch.isdigit() for ch in norm):
+                    return f"DEFAULT {norm}"
+                # Allow numeric or single-quoted string literals as-is
+                if (norm.startswith("'") and norm.endswith("'")) or re.match(r"^-?\d+(?:\.\d+)?$", norm):
+                    return f"DEFAULT {norm}"
+            # Robustly escape single quotes for plain string defaults
+            _escaped = column_default.replace("\\'", "'")
+            _escaped = _escaped.replace("'", "''")
+            return f"DEFAULT '{_escaped}'"
+        s = str(column_default)
+        s = s.replace("\\'", "'")
+        s = s.replace("'", "''")
+        return f"DEFAULT '{s}'"

     @classmethod
     def _data_type_collation_sequence(
         cls, collation: str = CollatingSequences.BINARY, column_type: t.Optional[str] = None
     ) -> str:
-        if column_type and collation != CollatingSequences.BINARY:
-            if column_type.startswith(
-                (
-                    "CHARACTER",
-                    "NCHAR",
-                    "NVARCHAR",
-                    "TEXT",
-                    "VARCHAR",
-                )
-            ):
+        """Return a SQLite COLLATE clause for textual affinity types.
+
+        Augmented with sqlglot: if the provided type string does not match the
+        quick textual prefixes, we attempt to transpile it to a SQLite type and
+        then apply SQLite's textual affinity rules (contains CHAR/CLOB/TEXT or
+        their NV*/VAR* variants). This improves handling of MySQL synonyms like
+        CHAR VARYING / CHARACTER VARYING / NATIONAL CHARACTER VARYING.
+ """ + if not column_type or collation == CollatingSequences.BINARY: + return "" + + ct: str = column_type.strip() + upper: str = ct.upper() + + # Fast-path for already normalized SQLite textual types + if upper.startswith(("CHARACTER", "NCHAR", "NVARCHAR", "TEXT", "VARCHAR")): + return f"COLLATE {collation}" + + # Avoid collations for JSON/BLOB explicitly + if "JSON" in upper or "BLOB" in upper: + return "" + + # If the type string obviously denotes text affinity, apply collation + if any(tok in upper for tok in ("VARCHAR", "NVARCHAR", "NCHAR", "CHAR", "TEXT", "CLOB", "CHARACTER")): + return f"COLLATE {collation}" + + # Try to map uncommon/synonym types to a SQLite type using sqlglot-based transpiler + mapped: t.Optional[str] = cls._transpile_mysql_type_to_sqlite(ct) + if mapped: + mu = mapped.upper() + if ( + "CHAR" in mu or "VARCHAR" in mu or "NCHAR" in mu or "NVARCHAR" in mu or "TEXT" in mu or "CLOB" in mu + ) and not ("JSON" in mu or "BLOB" in mu): return f"COLLATE {collation}" + return "" def _check_sqlite_json1_extension_enabled(self) -> bool: @@ -446,11 +654,13 @@ def _get_unique_index_name(self, base_name: str) -> str: return candidate def _build_create_table_sql(self, table_name: str) -> str: - sql: str = f'CREATE TABLE IF NOT EXISTS "{table_name}" (' + table_ident = self._quote_sqlite_identifier(table_name) + sql: str = f"CREATE TABLE IF NOT EXISTS {table_ident} (" primary: str = "" indices: str = "" - self._mysql_cur_dict.execute(f"SHOW COLUMNS FROM `{table_name}`") + safe_table = self._escape_mysql_backticks(table_name) + self._mysql_cur_dict.execute(f"SHOW COLUMNS FROM `{safe_table}`") rows: t.Sequence[t.Optional[t.Dict[str, RowItemType]]] = self._mysql_cur_dict.fetchall() primary_keys: int = sum(1 for row in rows if row is not None and row["Key"] == "PRI") @@ -463,8 +673,14 @@ def _build_create_table_sql(self, table_name: str) -> str: ) if row["Key"] == "PRI" and row["Extra"] == "auto_increment" and primary_keys == 1: if column_type in Integer_Types: - sql += '\n\t"{name}" INTEGER PRIMARY KEY AUTOINCREMENT,'.format( - name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"], + sql += "\n\t{name} INTEGER PRIMARY KEY AUTOINCREMENT,".format( + name=self._quote_sqlite_identifier( + str( + row["Field"].decode() + if isinstance(row["Field"], (bytes, bytearray)) + else row["Field"] + ) + ), ) else: self._logger.warning( @@ -473,8 +689,10 @@ def _build_create_table_sql(self, table_name: str) -> str: table_name, ) else: - sql += '\n\t"{name}" {type} {notnull} {default} {collation},'.format( - name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"], + sql += "\n\t{name} {type} {notnull} {default} {collation},".format( + name=self._quote_sqlite_identifier( + str(row["Field"].decode() if isinstance(row["Field"], (bytes, bytearray)) else row["Field"]) + ), type=column_type, notnull="NULL" if row["Null"] == "YES" else "NOT NULL", default=self._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]), @@ -556,7 +774,9 @@ def _build_create_table_sql(self, table_name: str) -> str: for _type in types.split(",") ): primary += "\n\tPRIMARY KEY ({columns})".format( - columns=", ".join(f'"{column}"' for column in columns.split(",")) + columns=", ".join( + self._quote_sqlite_identifier(column.strip()) for column in columns.split(",") + ) ) else: # Determine the SQLite index name, considering table name collisions and prefix option @@ -570,11 +790,14 @@ def _build_create_table_sql(self, table_name: str) -> str: 
                     unique_index_name = self._get_unique_index_name(proposed_index_name)
                 else:
                     unique_index_name = proposed_index_name
-                indices += """CREATE {unique} INDEX IF NOT EXISTS "{name}" ON "{table}" ({columns});""".format(
-                    unique="UNIQUE" if index["unique"] in {1, "1"} else "",
-                    name=unique_index_name,
-                    table=table_name,
-                    columns=", ".join(f'"{column}"' for column in columns.split(",")),
+                unique_kw = "UNIQUE " if index["unique"] in {1, "1"} else ""
+                indices += """CREATE {unique}INDEX IF NOT EXISTS {name} ON {table} ({columns});""".format(
+                    unique=unique_kw,
+                    name=self._quote_sqlite_identifier(unique_index_name),
+                    table=self._quote_sqlite_identifier(table_name),
+                    columns=", ".join(
+                        self._quote_sqlite_identifier(column.strip()) for column in columns.split(",")
+                    ),
                 )

         sql += primary
@@ -616,10 +839,27 @@
             )
             for foreign_key in self._mysql_cur_dict.fetchall():
                 if foreign_key is not None:
+                    col = self._quote_sqlite_identifier(
+                        foreign_key["column"].decode()
+                        if isinstance(foreign_key["column"], (bytes, bytearray))
+                        else str(foreign_key["column"])  # type: ignore[index]
+                    )
+                    ref_table = self._quote_sqlite_identifier(
+                        foreign_key["ref_table"].decode()
+                        if isinstance(foreign_key["ref_table"], (bytes, bytearray))
+                        else str(foreign_key["ref_table"])  # type: ignore[index]
+                    )
+                    ref_col = self._quote_sqlite_identifier(
+                        foreign_key["ref_column"].decode()
+                        if isinstance(foreign_key["ref_column"], (bytes, bytearray))
+                        else str(foreign_key["ref_column"])  # type: ignore[index]
+                    )
+                    on_update = str(foreign_key["on_update"] or "NO ACTION").upper()  # type: ignore[index]
+                    on_delete = str(foreign_key["on_delete"] or "NO ACTION").upper()  # type: ignore[index]
                     sql += (
-                        ',\n\tFOREIGN KEY("{column}") REFERENCES "{ref_table}" ("{ref_column}") '
-                        "ON UPDATE {on_update} "
-                        "ON DELETE {on_delete}".format(**foreign_key)  # type: ignore[str-bytes-safe]
+                        f",\n\tFOREIGN KEY({col}) REFERENCES {ref_table} ({ref_col}) "
+                        f"ON UPDATE {on_update} "
+                        f"ON DELETE {on_delete}"
                     )

         sql += "\n)"
@@ -925,13 +1165,15 @@ def _coerce_row(row: t.Any) -> t.Tuple[str, str]:
         # get the size of the data
         if self._limit_rows > 0:
             # limit to the requested number of rows
+            safe_table = self._escape_mysql_backticks(table_name)
             self._mysql_cur_dict.execute(
                 "SELECT COUNT(*) AS `total_records` "
-                f"FROM (SELECT * FROM `{table_name}` LIMIT {self._limit_rows}) AS `table`"
+                f"FROM (SELECT * FROM `{safe_table}` LIMIT {self._limit_rows}) AS `table`"
             )
         else:
             # get all rows
-            self._mysql_cur_dict.execute(f"SELECT COUNT(*) AS `total_records` FROM `{table_name}`")
+            safe_table = self._escape_mysql_backticks(table_name)
+            self._mysql_cur_dict.execute(f"SELECT COUNT(*) AS `total_records` FROM `{safe_table}`")

         total_records: t.Optional[t.Dict[str, RowItemType]] = self._mysql_cur_dict.fetchone()
         if total_records is not None:
@@ -942,9 +1184,10 @@
         # only continue if there is anything to transfer
         if total_records_count > 0:
             # populate it
+            safe_table = self._escape_mysql_backticks(table_name)
             self._mysql_cur.execute(
                 "SELECT * FROM `{table_name}` {limit}".format(
-                    table_name=table_name,
+                    table_name=safe_table,
                     limit=f"LIMIT {self._limit_rows}" if self._limit_rows > 0 else "",
                 )
             )
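Note (reviewer): the type fallback above hinges on round-tripping CAST(NULL AS <type>) through sqlglot and harvesting the rendered type. A minimal standalone sketch of that idea, assuming a reasonably recent sqlglot release (the exact rendered string varies by version, which is why the helper regex-extracts and then normalizes it):

from sqlglot import parse_one

for mysql_type in ("DOUBLE PRECISION", "NATIONAL CHARACTER VARYING(15)", "FIXED(10,2)"):
    # Parse as MySQL, render as SQLite; the helper then pulls the type out of CAST(NULL AS ...).
    rendered = parse_one(f"CAST(NULL AS {mysql_type})", read="mysql").sql(dialect="sqlite")
    print(mysql_type, "->", rendered)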
diff --git a/tests/unit/test_build_create_table_sql_sqlglot_identifiers.py b/tests/unit/test_build_create_table_sql_sqlglot_identifiers.py
new file mode 100644
index 0000000..63a4a37
--- /dev/null
+++ b/tests/unit/test_build_create_table_sql_sqlglot_identifiers.py
@@ -0,0 +1,97 @@
+import re
+from unittest.mock import MagicMock, patch
+
+from mysql_to_sqlite3.transporter import MySQLtoSQLite
+
+
+def _make_base_instance():
+    with patch.object(MySQLtoSQLite, "__init__", return_value=None):
+        inst = MySQLtoSQLite()  # type: ignore[call-arg]
+    inst._mysql_cur_dict = MagicMock()
+    inst._mysql_database = "db"
+    inst._sqlite_json1_extension_enabled = False
+    inst._collation = "BINARY"
+    inst._prefix_indices = False
+    inst._without_tables = False
+    inst._without_foreign_keys = True
+    inst._logger = MagicMock()
+    inst._sqlite_strict = False
+    # Track index names for uniqueness
+    inst._seen_sqlite_index_names = set()
+    inst._sqlite_index_name_counters = {}
+    return inst
+
+
+def test_show_columns_backticks_are_escaped_in_mysql_query() -> None:
+    inst = _make_base_instance()
+
+    # Capture executed SQL
+    executed_sql = []
+
+    def capture_execute(sql: str, *_, **__):
+        executed_sql.append(sql)
+
+    inst._mysql_cur_dict.execute.side_effect = capture_execute
+
+    # SHOW COLUMNS -> then STATISTICS query
+    inst._mysql_cur_dict.fetchall.side_effect = [
+        [
+            {
+                "Field": "id",
+                "Type": "INT",
+                "Null": "NO",
+                "Default": None,
+                "Key": "PRI",
+                "Extra": "",
+            }
+        ],
+        [],
+    ]
+    # TABLE collision check -> 0
+    inst._mysql_cur_dict.fetchone.return_value = {"count": 0}
+
+    sql = inst._build_create_table_sql("we`ird")
+    assert sql.startswith('CREATE TABLE IF NOT EXISTS "we`ird" (')
+
+    # First executed SQL should be SHOW COLUMNS with backticks escaped
+    assert executed_sql
+    assert executed_sql[0] == "SHOW COLUMNS FROM `we``ird`"
+
+
+def test_identifiers_with_double_quotes_are_safely_quoted_in_create_and_index() -> None:
+    inst = _make_base_instance()
+    inst._prefix_indices = True  # ensure an index is emitted with a deterministic name prefix
+
+    # SHOW COLUMNS first call, then STATISTICS rows
+    inst._mysql_cur_dict.fetchall.side_effect = [
+        [
+            {
+                "Field": 'na"me',
+                "Type": "VARCHAR(10)",
+                "Null": "YES",
+                "Default": None,
+                "Key": "",
+                "Extra": "",
+            },
+        ],
+        [
+            {
+                "name": "idx",
+                "primary": 0,
+                "unique": 0,
+                "auto_increment": 0,
+                "columns": 'na"me',
+                "types": "VARCHAR(10)",
+            }
+        ],
+    ]
+    inst._mysql_cur_dict.fetchone.return_value = {"count": 0}
+
+    sql = inst._build_create_table_sql('ta"ble')
+
+    # Column should be quoted with doubled quotes inside
+    assert '"na""me" VARCHAR(10)' in sql or '"na""me" TEXT' in sql
+
+    # Index should quote table and column names with doubled quotes
+    norm = re.sub(r"\s+", " ", sql)
+    assert 'CREATE INDEX IF NOT EXISTS "ta""ble_idx" ON "ta""ble" ("na""me")' in norm
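Note (reviewer): the quoting rule these tests pin down is SQLite's standard one: double any embedded double quote and wrap the identifier in double quotes; backticks are ordinary characters and pass through untouched. A tiny sketch with a hypothetical helper name (the real method additionally normalizes the name through sqlglot's exp.to_identifier):

def quote_sqlite_identifier(name: str) -> str:
    # 'ta"ble' -> '"ta""ble"'; 'we`ird' -> '"we`ird"'
    return '"' + name.replace('"', '""') + '"'

assert quote_sqlite_identifier('ta"ble') == '"ta""ble"'
assert quote_sqlite_identifier("we`ird") == '"we`ird"'  # backticks need no escaping in SQLite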
diff --git a/tests/unit/test_collation_sqlglot_augmented.py b/tests/unit/test_collation_sqlglot_augmented.py
new file mode 100644
index 0000000..12ab9d5
--- /dev/null
+++ b/tests/unit/test_collation_sqlglot_augmented.py
@@ -0,0 +1,51 @@
+import pytest
+
+from mysql_to_sqlite3.sqlite_utils import CollatingSequences
+from mysql_to_sqlite3.transporter import MySQLtoSQLite
+
+
+class TestCollationSqlglotAugmented:
+    @pytest.mark.parametrize(
+        "mysql_type",
+        [
+            "char varying(12)",
+            "CHARACTER VARYING(12)",
+        ],
+    )
+    def test_collation_applied_for_char_varying_synonyms(self, mysql_type: str) -> None:
+        out = MySQLtoSQLite._data_type_collation_sequence(collation=CollatingSequences.NOCASE, column_type=mysql_type)
+        assert out == f"COLLATE {CollatingSequences.NOCASE}"
+
+    def test_collation_applied_for_national_character_varying(self) -> None:
+        out = MySQLtoSQLite._data_type_collation_sequence(
+            collation=CollatingSequences.NOCASE, column_type="national character varying(15)"
+        )
+        assert out == f"COLLATE {CollatingSequences.NOCASE}"
+
+    def test_no_collation_for_json(self) -> None:
+        # Regardless of case or synonym handling, JSON should not have collation applied
+        assert (
+            MySQLtoSQLite._data_type_collation_sequence(collation=CollatingSequences.NOCASE, column_type="json") == ""
+        )
+
+    def test_no_collation_when_binary_collation(self) -> None:
+        # BINARY collation disables the COLLATE clause entirely
+        assert (
+            MySQLtoSQLite._data_type_collation_sequence(collation=CollatingSequences.BINARY, column_type="VARCHAR(10)")
+            == ""
+        )
+
+    @pytest.mark.parametrize(
+        "numeric_synonym",
+        [
+            "double precision",
+            "FIXED(10,2)",
+        ],
+    )
+    def test_no_collation_for_numeric_synonyms(self, numeric_synonym: str) -> None:
+        assert (
+            MySQLtoSQLite._data_type_collation_sequence(
+                collation=CollatingSequences.NOCASE, column_type=numeric_synonym
+            )
+            == ""
+        )
diff --git a/tests/unit/test_defaults_sqlglot_enhanced.py b/tests/unit/test_defaults_sqlglot_enhanced.py
new file mode 100644
index 0000000..4179d62
--- /dev/null
+++ b/tests/unit/test_defaults_sqlglot_enhanced.py
@@ -0,0 +1,60 @@
+import pytest
+
+from mysql_to_sqlite3.transporter import MySQLtoSQLite
+
+
+class TestDefaultsSqlglotEnhanced:
+    @pytest.mark.parametrize(
+        "expr,expected",
+        [
+            ("CURRENT_TIME", "DEFAULT CURRENT_TIME"),
+            ("CURRENT_DATE", "DEFAULT CURRENT_DATE"),
+            ("CURRENT_TIMESTAMP", "DEFAULT CURRENT_TIMESTAMP"),
+        ],
+    )
+    def test_current_tokens_passthrough(self, expr: str, expected: str) -> None:
+        assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite(expr, column_extra="DEFAULT_GENERATED") == expected
+
+    def test_null_literal_generated(self) -> None:
+        assert (
+            MySQLtoSQLite._translate_default_from_mysql_to_sqlite("NULL", column_extra="DEFAULT_GENERATED")
+            == "DEFAULT NULL"
+        )
+
+    @pytest.mark.parametrize(
+        "expr,boolean_type,expected",
+        [
+            ("true", "BOOLEAN", {"DEFAULT(TRUE)", "DEFAULT '1'"}),
+            ("false", "BOOLEAN", {"DEFAULT(FALSE)", "DEFAULT '0'"}),
+            ("true", "INTEGER", {"DEFAULT '1'"}),
+            ("false", "INTEGER", {"DEFAULT '0'"}),
+        ],
+    )
+    def test_boolean_tokens_generated(self, expr: str, boolean_type: str, expected: set) -> None:
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite(
+            expr, column_type=boolean_type, column_extra="DEFAULT_GENERATED"
+        )
+        assert out in expected
+
+    def test_parenthesized_string_literal_generated(self) -> None:
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite("('abc')", column_extra="DEFAULT_GENERATED")
+        # Either DEFAULT 'abc' or DEFAULT ('abc') depending on normalization
+        assert out in {"DEFAULT 'abc'", "DEFAULT ('abc')"}
+
+    def test_parenthesized_numeric_literal_generated(self) -> None:
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite("(42)", column_extra="DEFAULT_GENERATED")
+        assert out in {"DEFAULT 42", "DEFAULT (42)"}
+
+    def test_constant_arithmetic_expression_generated(self) -> None:
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite("1+2*3", column_extra="DEFAULT_GENERATED")
+        # sqlglot formats with spaces for the sqlite dialect
+        assert out in {"DEFAULT 1 + 2 * 3", "DEFAULT (1 + 2 * 3)"}
+
+    def test_hex_blob_literal_generated(self) -> None:
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite("x'41'", column_extra="DEFAULT_GENERATED")
+        # Should recognize a blob literal and keep it as-is
+        assert out.upper() == "DEFAULT X'41'"
+
+    def test_plain_string_escaping_single_quote(self) -> None:
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite("O'Reilly")
+        assert out == "DEFAULT 'O''Reilly'"
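Note (reviewer): the default-translation behavior asserted above can be reproduced directly against the classmethod; a short usage sketch whose expected outputs mirror the tests, so they should hold on any sqlglot version the suite passes with:

from mysql_to_sqlite3.transporter import MySQLtoSQLite

# Generated expressions that reduce to safe keywords or literals pass through...
print(MySQLtoSQLite._translate_default_from_mysql_to_sqlite("NULL", column_extra="DEFAULT_GENERATED"))  # DEFAULT NULL
# ...while plain string defaults are quoted with embedded single quotes doubled.
print(MySQLtoSQLite._translate_default_from_mysql_to_sqlite("O'Reilly"))  # DEFAULT 'O''Reilly'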
diff --git a/tests/unit/test_indices_prefix_and_uniqueness.py b/tests/unit/test_indices_prefix_and_uniqueness.py
index c0b5c01..d40560a 100644
--- a/tests/unit/test_indices_prefix_and_uniqueness.py
+++ b/tests/unit/test_indices_prefix_and_uniqueness.py
@@ -49,7 +49,7 @@ def test_build_create_table_sql_prefix_indices_true_prefixes_index_names() -> No
     sql = inst._build_create_table_sql("users")

     # With prefix_indices=True, the index name should be prefixed with table name
-    assert 'CREATE INDEX IF NOT EXISTS "users_idx_name" ON "users" ("name");' in sql
+    assert 'CREATE INDEX IF NOT EXISTS "users_idx_name" ON "users" ("name");' in sql


 def test_build_create_table_sql_collision_renamed_and_uniqueness_suffix() -> None:
diff --git a/tests/unit/test_types_and_defaults_extra.py b/tests/unit/test_types_and_defaults_extra.py
index 3317601..ee7b74c 100644
--- a/tests/unit/test_types_and_defaults_extra.py
+++ b/tests/unit/test_types_and_defaults_extra.py
@@ -40,6 +40,18 @@ def test_data_type_collation_sequence(self) -> None:
     def test_translate_default_common_keywords(self, default: str, expected: str) -> None:
         assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite(default) == expected

+    def test_translate_default_current_timestamp_precision_transpiled(self) -> None:
+        # MySQL allows fractional seconds: CURRENT_TIMESTAMP(6). Ensure it's normalized to the SQLite token.
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite(
+            "CURRENT_TIMESTAMP(6)", column_extra="DEFAULT_GENERATED"
+        )
+        assert out == "DEFAULT CURRENT_TIMESTAMP"
+
+    def test_translate_default_generated_expr_fallback_quotes(self) -> None:
+        # Unknown expressions should fall back to a quoted string default for safety
+        out = MySQLtoSQLite._translate_default_from_mysql_to_sqlite("uuid()", column_extra="DEFAULT_GENERATED")
+        assert out == "DEFAULT 'uuid()'"
+
     def test_translate_default_charset_introducer_str_hex_and_bin(self) -> None:
         # DEFAULT_GENERATED with charset introducer and hex (escaped as in MySQL)
         s = "_utf8mb4 X\\'41\\'"  # hex for 'A'
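Note (reviewer): before the new type-translation suite below, a quick orientation on the mapper's contract; these calls simply mirror assertions from that suite:

from mysql_to_sqlite3.transporter import MySQLtoSQLite

# Synonyms now resolve to concrete SQLite types instead of falling through to TEXT,
# while genuinely unknown types still degrade to TEXT.
assert MySQLtoSQLite._translate_type_from_mysql_to_sqlite("character varying(20)") == "VARCHAR(20)"
assert MySQLtoSQLite._translate_type_from_mysql_to_sqlite("fixed(10,2)") == "DECIMAL"
assert MySQLtoSQLite._translate_type_from_mysql_to_sqlite("geography") == "TEXT"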
diff --git a/tests/unit/test_types_sqlglot_augmented.py b/tests/unit/test_types_sqlglot_augmented.py
new file mode 100644
index 0000000..1e5ab4c
--- /dev/null
+++ b/tests/unit/test_types_sqlglot_augmented.py
@@ -0,0 +1,72 @@
+import pytest
+
+from mysql_to_sqlite3.transporter import MySQLtoSQLite
+
+
+class TestSqlglotAugmentedTypeTranslation:
+    @pytest.mark.parametrize("mysql_type", ["double precision", "DOUBLE PRECISION", "DoUbLe PrEcIsIoN"])
+    def test_double_precision_maps_to_numeric_type(self, mysql_type: str) -> None:
+        # The prior mapper would resolve this to TEXT; the sqlglot fallback should improve it
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite(mysql_type)
+        assert out in {"DOUBLE", "REAL"}
+
+    def test_fixed_maps_to_decimal(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("fixed(10,2)")
+        # Normalized to DECIMAL (without length) to match the existing style
+        assert out == "DECIMAL"
+
+    def test_character_varying_keeps_length_as_varchar(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("character varying(20)")
+        assert out == "VARCHAR(20)"
+
+    def test_char_varying_keeps_length_as_varchar(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("char varying(12)")
+        assert out == "VARCHAR(12)"
+
+    def test_national_character_varying_maps_to_nvarchar(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("national character varying(15)")
+        assert out == "NVARCHAR(15)"
+
+    def test_national_character_maps_to_nchar(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("national character(5)")
+        assert out == "NCHAR(5)"
+
+    @pytest.mark.parametrize(
+        "mysql_type,expected",
+        [
+            ("int unsigned", "INTEGER"),
+            ("mediumint unsigned", "MEDIUMINT"),
+            ("smallint unsigned", "SMALLINT"),
+            ("tinyint unsigned", "TINYINT"),
+            ("bigint unsigned", "BIGINT"),
+        ],
+    )
+    def test_unsigned_variants_strip_unsigned(self, mysql_type: str, expected: str) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite(mysql_type)
+        assert out == expected
+
+    def test_timestamp_maps_to_datetime(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("timestamp")
+        assert out == "DATETIME"
+
+    def test_varbinary_and_blobs_map_to_blob(self) -> None:
+        assert MySQLtoSQLite._translate_type_from_mysql_to_sqlite("varbinary(16)") == "BLOB"
+        assert MySQLtoSQLite._translate_type_from_mysql_to_sqlite("mediumblob") == "BLOB"
+
+    def test_char_maps_to_character_with_length(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("char(3)")
+        assert out == "CHARACTER(3)"
+
+    def test_json_mapping_respects_json1(self) -> None:
+        assert (
+            MySQLtoSQLite._translate_type_from_mysql_to_sqlite("json", sqlite_json1_extension_enabled=False) == "TEXT"
+        )
+        assert MySQLtoSQLite._translate_type_from_mysql_to_sqlite("json", sqlite_json1_extension_enabled=True) == "JSON"
+
+    def test_fallback_to_text_on_unknown_type(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("geography")
+        assert out == "TEXT"
+
+    def test_enum_remains_text(self) -> None:
+        out = MySQLtoSQLite._translate_type_from_mysql_to_sqlite("enum('a','b')")
+        assert out == "TEXT"
diff --git a/tox.ini b/tox.ini
index e377116..f22dc55 100644
--- a/tox.ini
+++ b/tox.ini
@@ -112,4 +112,4 @@ import-order-style = pycharm
 application-import-names = flake8

 [pylint]
-disable = C0209,C0301,C0411,R,W0107,W0622,C0103
\ No newline at end of file
+disable = C0209,C0301,C0411,R,W0107,W0622,C0103,C0302
\ No newline at end of file
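Note (reviewer): for completeness, the collation rules covered by test_collation_sqlglot_augmented.py can be sanity-checked the same way; a brief sketch using the project's CollatingSequences constants, with expected values taken from those tests:

from mysql_to_sqlite3.sqlite_utils import CollatingSequences
from mysql_to_sqlite3.transporter import MySQLtoSQLite

# Textual synonyms now receive a COLLATE clause; numeric synonyms and JSON never do.
assert MySQLtoSQLite._data_type_collation_sequence(
    collation=CollatingSequences.NOCASE, column_type="char varying(12)"
) == f"COLLATE {CollatingSequences.NOCASE}"
assert MySQLtoSQLite._data_type_collation_sequence(
    collation=CollatingSequences.NOCASE, column_type="double precision"
) == ""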