Skip to content

Commit

Permalink
Type signatures for .create_table() and .create_table_sql() and .crea…
Browse files Browse the repository at this point in the history
…te() and Table.__init__

Closes #314
  • Loading branch information
simonw committed Aug 18, 2021
1 parent 282e813 commit c79737b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 43 deletions.
104 changes: 61 additions & 43 deletions sqlite_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@
Trigger = namedtuple("Trigger", ("name", "table", "sql"))


ForeignKeysType = Union[
Iterable[str],
Iterable[ForeignKey],
Iterable[Tuple[str, str]],
Iterable[Tuple[str, str, str]],
Iterable[Tuple[str, str, str, str]],
]


class Default:
pass

Expand Down Expand Up @@ -572,18 +581,22 @@ def execute_returning_dicts(
) -> List[dict]:
return list(self.query(sql, params))

def resolve_foreign_keys(self, name, foreign_keys):
# foreign_keys may be a list of strcolumn names, a list of ForeignKey tuples,
def resolve_foreign_keys(
self, name: str, foreign_keys: ForeignKeysType
) -> List[ForeignKey]:
# foreign_keys may be a list of column names, a list of ForeignKey tuples,
# a list of tuple-pairs or a list of tuple-triples. We want to turn
# it into a list of ForeignKey tuples
table = cast(Table, self[name])
if all(isinstance(fk, ForeignKey) for fk in foreign_keys):
return foreign_keys
return cast(List[ForeignKey], foreign_keys)
if all(isinstance(fk, str) for fk in foreign_keys):
# It's a list of columns
fks = []
for column in foreign_keys:
other_table = self[name].guess_foreign_table(column)
other_column = self[name].guess_foreign_column(other_table)
column = cast(str, column)
other_table = table.guess_foreign_table(column)
other_column = table.guess_foreign_column(other_table)
fks.append(ForeignKey(name, column, other_table, other_column))
return fks
assert all(
Expand All @@ -596,6 +609,7 @@ def resolve_foreign_keys(self, name, foreign_keys):
3,
), "foreign_keys= should be a list of tuple pairs or triples"
if len(tuple_or_list) == 3:
tuple_or_list = cast(Tuple[str, str, str], tuple_or_list)
fks.append(
ForeignKey(
name, tuple_or_list[0], tuple_or_list[1], tuple_or_list[2]
Expand All @@ -608,7 +622,7 @@ def resolve_foreign_keys(self, name, foreign_keys):
name,
tuple_or_list[0],
tuple_or_list[1],
self[name].guess_foreign_column(tuple_or_list[1]),
table.guess_foreign_column(tuple_or_list[1]),
)
)
return fks
Expand All @@ -618,12 +632,12 @@ def create_table_sql(
name: str,
columns: Dict[str, Any],
pk: Optional[Any] = None,
foreign_keys=None,
column_order=None,
not_null=None,
defaults=None,
hash_id=None,
extracts=None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Iterable[str] = None,
defaults: Optional[Dict[str, Any]] = None,
hash_id: Optional[Any] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
) -> str:
"Returns the SQL ``CREATE TABLE`` statement for creating the specified table."
foreign_keys = self.resolve_foreign_keys(name, foreign_keys or [])
Expand Down Expand Up @@ -656,9 +670,11 @@ def create_table_sql(
validate_column_names(columns.keys())
column_items = list(columns.items())
if column_order is not None:
column_items.sort(
key=lambda p: column_order.index(p[0]) if p[0] in column_order else 999
)

def sort_key(p):
return column_order.index(p[0]) if p[0] in column_order else 999

column_items.sort(key=sort_key)
if hash_id:
column_items.insert(0, (hash_id, str))
pk = hash_id
Expand Down Expand Up @@ -725,12 +741,12 @@ def create_table(
name: str,
columns: Dict[str, Any],
pk: Optional[Any] = None,
foreign_keys=None,
column_order=None,
not_null=None,
defaults=None,
hash_id=None,
extracts=None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Iterable[str] = None,
defaults: Optional[Dict[str, Any]] = None,
hash_id: Optional[Any] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
) -> "Table":
"""
Create a table with the specified name and the specified ``{column_name: type}`` columns.
Expand Down Expand Up @@ -1021,19 +1037,19 @@ def __init__(
self,
db: Database,
name: str,
pk=None,
foreign_keys=None,
column_order=None,
not_null=None,
defaults=None,
batch_size=100,
hash_id=None,
alter=False,
ignore=False,
replace=False,
extracts=None,
conversions=None,
columns=None,
pk: Optional[Any] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Iterable[str] = None,
defaults: Optional[Dict[str, Any]] = None,
batch_size: int = 100,
hash_id: Optional[Any] = None,
alter: bool = False,
ignore: bool = False,
replace: bool = False,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
conversions: Optional[dict] = None,
columns: Optional[Union[Dict[str, Any]]] = None,
):
super().__init__(db, name)
self._defaults = dict(
Expand Down Expand Up @@ -1202,14 +1218,14 @@ def triggers_dict(self) -> Dict[str, str]:

def create(
self,
columns,
pk=None,
foreign_keys=None,
column_order=None,
not_null=None,
defaults=None,
hash_id=None,
extracts=None,
columns: Dict[str, Any],
pk: Optional[Any] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Iterable[str] = None,
defaults: Optional[Dict[str, Any]] = None,
hash_id: Optional[Any] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
) -> "Table":
"""
Create a table with the specified columns.
Expand Down Expand Up @@ -2914,7 +2930,9 @@ def _hash(record):
).hexdigest()


def resolve_extracts(extracts):
def resolve_extracts(
extracts: Optional[Union[Dict[str, str], List[str], Tuple[str]]]
) -> dict:
if extracts is None:
extracts = {}
if isinstance(extracts, (list, tuple)):
Expand Down
1 change: 1 addition & 0 deletions tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_tracer():
("PRAGMA recursive_triggers=on;", None),
("select name from sqlite_master where type = 'view'", None),
("select name from sqlite_master where type = 'table'", None),
("select name from sqlite_master where type = 'view'", None),
("CREATE TABLE [dogs] (\n [name] TEXT\n);\n ", None),
("select name from sqlite_master where type = 'view'", None),
("INSERT INTO [dogs] ([name]) VALUES (?);", ["Cleopaws"]),
Expand Down

0 comments on commit c79737b

Please sign in to comment.