Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints for sqlalchemy #322

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from sqlalchemy.sql import compiler
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.schema import Table

# https://trino.io/docs/current/language/reserved.html
RESERVED_WORDS = {
Expand Down Expand Up @@ -92,7 +95,7 @@


class TrinoSQLCompiler(compiler.SQLCompiler):
def limit_clause(self, select, **kw):
def limit_clause(self, select: Any, **kw: Dict[str, Any]) -> str:
"""
Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
"""
Expand All @@ -103,15 +106,15 @@ def limit_clause(self, select, **kw):
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
return text

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
def visit_table(self, table: Table, asfrom: bool = False, iscrud: bool = False, ashint: bool = False,
fromhints: Optional[Any] = None, use_schema: bool = True, **kwargs: Any) -> str:
sql = super(TrinoSQLCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
return self.add_catalog(sql, table)

@staticmethod
def add_catalog(sql, table):
def add_catalog(sql: str, table: Table) -> str:
if table is None or not isinstance(table, DialectKWArgs):
return sql

Expand All @@ -131,7 +134,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):


class TrinoTypeCompiler(compiler.GenericTypeCompiler):
def visit_FLOAT(self, type_, **kw):
def visit_FLOAT(self, type_: Any, **kw: Dict[str, Any]) -> str:
precision = type_.precision or 32
if 0 <= precision <= 32:
return self.visit_REAL(type_, **kw)
Expand All @@ -140,37 +143,37 @@ def visit_FLOAT(self, type_, **kw):
else:
raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}")

def visit_DOUBLE(self, type_, **kw):
def visit_DOUBLE(self, type_: Any, **kw: Dict[str, Any]) -> str:
return "DOUBLE"

def visit_NUMERIC(self, type_, **kw):
def visit_NUMERIC(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_DECIMAL(type_, **kw)

def visit_NCHAR(self, type_, **kw):
def visit_NCHAR(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_CHAR(type_, **kw)

def visit_NVARCHAR(self, type_, **kw):
def visit_NVARCHAR(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_TEXT(self, type_, **kw):
def visit_TEXT(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_BINARY(self, type_, **kw):
def visit_BINARY(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_VARBINARY(type_, **kw)

def visit_CLOB(self, type_, **kw):
def visit_CLOB(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_NCLOB(self, type_, **kw):
def visit_NCLOB(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_BLOB(self, type_, **kw):
def visit_BLOB(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_VARBINARY(type_, **kw)

def visit_DATETIME(self, type_, **kw):
def visit_DATETIME(self, type_: Any, **kw: Dict[str, Any]) -> str:
return self.visit_TIMESTAMP(type_, **kw)

def visit_TIMESTAMP(self, type_, **kw):
def visit_TIMESTAMP(self, type_: Any, **kw: Dict[str, Any]) -> str:
datatype = "TIMESTAMP"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
Expand All @@ -182,7 +185,7 @@ def visit_TIMESTAMP(self, type_, **kw):

return datatype

def visit_TIME(self, type_, **kw):
def visit_TIME(self, type_: Any, **kw: Dict[str, Any]) -> str:
datatype = "TIME"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
Expand All @@ -193,13 +196,13 @@ def visit_TIME(self, type_, **kw):
datatype += " WITH TIME ZONE"
return datatype

def visit_JSON(self, type_, **kw):
def visit_JSON(self, type_: Any, **kw: Dict[str, Any]) -> str:
return 'JSON'


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS

def format_table(self, table, use_schema=True, name=None):
def format_table(self, table: Table, use_schema: bool = True, name: Optional[str] = None) -> str:
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
return TrinoSQLCompiler.add_catalog(result, table)
20 changes: 12 additions & 8 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
# limitations under the License.
import json
import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional
from typing import Text as typing_Text
from typing import Tuple, Type, TypeVar, Union

from sqlalchemy import util
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
from sqlalchemy.types import String

SQLType = Union[TypeEngine, Type[TypeEngine]]
_T = TypeVar('_T')


class DOUBLE(sqltypes.Float):
Expand All @@ -38,7 +42,7 @@ def __init__(self, key_type: SQLType, value_type: SQLType):
self.value_type: TypeEngine = value_type

@property
def python_type(self):
def python_type(self) -> type:
return dict


Expand All @@ -53,36 +57,36 @@ def __init__(self, attr_types: List[Tuple[Optional[str], SQLType]]):
self.attr_types.append((attr_name, attr_type))

@property
def python_type(self):
def python_type(self) -> type:
return list


class TIME(sqltypes.TIME):
__visit_name__ = "TIME"

def __init__(self, precision=None, timezone=False):
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision


class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"

def __init__(self, precision=None, timezone=False):
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision


class JSON(TypeDecorator):
impl = String

def process_bind_param(self, value, dialect):
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]:
return json.dumps(value)

def process_result_value(self, value, dialect):
def process_result_value(self, value: Union[str, bytes], dialect: Dialect) -> Optional[_T]:
return json.loads(value)

def get_col_spec(self, **kw):
def get_col_spec(self, **kw: Dict[str, Any]) -> str:
return 'JSON'


Expand Down
Loading