Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/sqlite3_to_mysql/transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,55 @@ def _valid_column_type(cls, column_type: str) -> t.Optional[t.Match[str]]:
return cls.COLUMN_PATTERN.match(column_type.strip())

def _translate_type_from_sqlite_to_mysql(self, column_type: str) -> str:
normalized: t.Optional[str] = self._normalize_sqlite_column_type(column_type)
if normalized and normalized.upper() != column_type.upper():
self._logger.info("Normalised SQLite column type %r -> %r", column_type, normalized)
try:
return self._translate_type_from_sqlite_to_mysql_legacy(normalized)
except ValueError:
pass
return self._translate_type_from_sqlite_to_mysql_legacy(column_type)

def _normalize_sqlite_column_type(self, column_type: str) -> t.Optional[str]:
clean_type: str = column_type.strip()
if not clean_type:
return None

normalized_for_parse: str = clean_type.upper().replace("UNSIGNED BIG INT", "BIGINT UNSIGNED")
try:
expression = sqlglot.parse_one(f"SELECT CAST(NULL AS {normalized_for_parse})", read="sqlite")
except sqlglot_errors.ParseError:
# Retry: strip UNSIGNED to aid parsing; we'll re-attach it below if present.
try:
no_unsigned = re.sub(r"\bUNSIGNED\b", "", normalized_for_parse).strip()
expression = sqlglot.parse_one(f"SELECT CAST(NULL AS {no_unsigned})", read="sqlite")
except sqlglot_errors.ParseError:
return None

cast: t.Optional[exp.Cast] = expression.find(exp.Cast)
if not cast or not isinstance(cast.to, exp.DataType):
return None

params: t.List[str] = []
for expr_param in cast.to.expressions or []:
value_expr = expr_param.this if isinstance(expr_param, exp.DataTypeParam) else expr_param
if value_expr is None:
continue
params.append(value_expr.sql(dialect="mysql"))

base_match: t.Optional[t.Match[str]] = self._valid_column_type(clean_type)
base = base_match.group(0).upper().strip() if base_match else clean_type.upper()

normalized = base
if params:
normalized += "(" + ",".join(param.strip("\"'") for param in params) + ")"

if "UNSIGNED" in clean_type.upper() and "UNSIGNED" not in normalized.upper().split():
normalized = f"{normalized} UNSIGNED"

return normalized

def _translate_type_from_sqlite_to_mysql_legacy(self, column_type: str) -> str:
"""This could be optimized even further, however is seems adequate."""
full_column_type: str = column_type.upper()
unsigned: bool = self.COLUMN_UNSIGNED_PATTERN.search(full_column_type) is not None
Expand Down Expand Up @@ -534,6 +583,19 @@ def _translate_default_for_mysql(self, column_type: str, default: str) -> str:
return s

# Fallback: return stripped expression (MySQL 8.0.13+ allows expression defaults)
if self._allow_expr_defaults:
try:
expr = sqlglot.parse_one(s, read="sqlite")
except sqlglot_errors.ParseError:
return s

expr = expr.transform(self._rewrite_sqlite_view_functions)

try:
return expr.sql(dialect="mysql")
except sqlglot_errors.SqlglotError:
return s

return s

@classmethod
Expand Down
Loading