Skip to content

Commit

Permalink
test: oracle improve
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice committed Apr 20, 2022
1 parent 316d711 commit c7a03c1
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 32 deletions.
7 changes: 6 additions & 1 deletion tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tortoise.backends.asyncpg import AsyncpgDBClient
from tortoise.backends.mssql import MSSQLClient
from tortoise.backends.mysql import MySQLClient
from tortoise.backends.oracle import OracleClient
from tortoise.backends.psycopg import PsycopgClient
from tortoise.backends.sqlite import SqliteClient
from tortoise.contrib import test
Expand All @@ -28,6 +29,10 @@ async def asyncSetUp(self) -> None:
await db.execute_query(
'insert into defaultmodel ("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)',
)
elif isinstance(db, OracleClient):
await db.execute_query(
'insert into "defaultmodel" ("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)',
)

async def test_default(self):
default_model = await DefaultModel.first()
Expand All @@ -36,7 +41,7 @@ async def test_default(self):
self.assertEqual(default_model.decimal_default, Decimal(1))
self.assertTrue(default_model.bool_default)
self.assertEqual(default_model.char_default, "tortoise")
self.assertEqual(default_model.date_default, datetime.date(year=2020, month=5, day=20))
self.assertEqual(default_model.date_default, datetime.date(year=2020, month=5, day=21))
self.assertEqual(
default_model.datetime_default,
datetime.datetime(year=2020, month=5, day=20, tzinfo=pytz.utc),
Expand Down
3 changes: 3 additions & 0 deletions tests/test_order_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Tournament,
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.exceptions import ConfigurationError, FieldError
from tortoise.functions import Count, Sum

Expand Down Expand Up @@ -84,13 +85,15 @@ async def test_order_by_aggregation_reversed(self):


class TestDefaultOrdering(test.TestCase):
@test.requireCapability(dialect=NotEQ("oracle"))
async def test_default_order(self):
await DefaultOrdered.create(one="2", second=1)
await DefaultOrdered.create(one="1", second=1)

instance_list = await DefaultOrdered.all()
self.assertEqual([i.one for i in instance_list], ["1", "2"])

@test.requireCapability(dialect=NotEQ("oracle"))
async def test_default_order_desc(self):
await DefaultOrderedDesc.create(one="1", second=1)
await DefaultOrderedDesc.create(one="2", second=1)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_order_by_nested.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from tests.testmodels import Event, Tournament
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ


class TestOrderByNested(test.TestCase):
@test.requireCapability(dialect=NotEQ("oracle"))
async def test_basic(self):
await Event.create(
name="Event 1", tournament=await Tournament.create(name="Tournament 1", desc="B")
Expand Down
2 changes: 1 addition & 1 deletion tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ class DefaultModel(Model):
decimal_default = fields.DecimalField(max_digits=8, decimal_places=2, default=Decimal(1))
bool_default = fields.BooleanField(default=True)
char_default = fields.CharField(max_length=20, default="tortoise")
date_default = fields.DateField(default=datetime.date(year=2020, month=5, day=20))
date_default = fields.DateField(default=datetime.date(year=2020, month=5, day=21))
datetime_default = fields.DatetimeField(
default=datetime.datetime(year=2020, month=5, day=20, tzinfo=pytz.utc)
)
Expand Down
81 changes: 64 additions & 17 deletions tortoise/backends/oracle/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
from typing import Any, SupportsInt
import datetime
import functools
from typing import Any, SupportsInt, Union

import pyodbc
import pytz

try:
from ciso8601 import parse_datetime
except ImportError: # pragma: nocoverage
from iso8601 import parse_date

parse_datetime = functools.partial(parse_date, default_timezone=None)
from pypika import OracleQuery

from tortoise.backends.base.client import TransactionContext, TransactionContextPooled
from tortoise.backends.base.client import (
Capabilities,
ConnectionWrapper,
PoolConnectionWrapper,
TransactionContext,
TransactionContextPooled,
)
from tortoise.backends.odbc.client import ODBCClient, ODBCTransactionWrapper, translate_exceptions
from tortoise.backends.oracle.executor import OracleExecutor
from tortoise.backends.oracle.schema_generator import OracleSchemaGenerator
Expand All @@ -12,6 +29,7 @@ class OracleClient(ODBCClient):
query_class = OracleQuery
schema_generator = OracleSchemaGenerator
executor_class = OracleExecutor
capabilities = Capabilities(dialect="oracle")

def __init__(
self,
Expand All @@ -30,27 +48,34 @@ def __init__(
def _in_transaction(self) -> "TransactionContext":
return TransactionContextPooled(TransactionWrapper(self))

async def create_connection(self, with_db: bool) -> None:
await super(OracleClient, self).create_connection(with_db=with_db)
if with_db:
await self.execute_query(f'ALTER SESSION SET CURRENT_SCHEMA = "{self.database}"')
await self.execute_query("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'")
await self.execute_query(
"ALTER SESSION SET NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD\"T\"HH24:MI:SSTZH:TZM'"
)
def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
return OraclePoolConnectionWrapper(self)

async def db_create(self) -> None:
await self.create_connection(with_db=False)
await self.execute_script(
f'CREATE USER "{self.database}" IDENTIFIED BY "{self.password}";GRANT ALL PRIVILEGES TO "{self.database}";'
)
await self.execute_script(f'CREATE USER "{self.database}" IDENTIFIED BY "{self.password}"')
await self.execute_script(f'GRANT ALL PRIVILEGES TO "{self.database}"')
await self.close()

async def db_delete(self) -> None:
await self.create_connection(with_db=False)
await self.execute_script(f'DROP USER "{self.database}" CASCADE')
try:
await self.execute_script(f'DROP USER "{self.database}" CASCADE')
except pyodbc.Error as e:
if "does not exist" not in str(e):
raise
await self.close()

@translate_exceptions
async def execute_script(self, query: str) -> None:
async with self.acquire_connection() as connection:
self.log.debug(query)
async with connection.cursor() as cursor:
for q in query.split(";"):
if not q.strip():
continue
await cursor.execute(q)

@translate_exceptions
async def execute_insert(self, query: str, values: list) -> int:
async with self.acquire_connection() as connection:
Expand All @@ -59,6 +84,28 @@ async def execute_insert(self, query: str, values: list) -> int:
return 0


class TransactionWrapper(OracleClient, ODBCTransactionWrapper):
def __init__(self, connection: ODBCClient) -> None:
ODBCTransactionWrapper.__init__(self, connection=connection)
class OraclePoolConnectionWrapper(PoolConnectionWrapper):
def _timestamp_convert(self, value: bytes) -> datetime.date:
try:
return parse_datetime(value.decode()).date()
except ValueError:
return parse_datetime(value.decode()[:-32]).astimezone(tz=pytz.utc)

async def __aenter__(self):
connection = await super(OraclePoolConnectionWrapper, self).__aenter__() # type: ignore
if self.client._template.get("database"):
await connection.execute(f'ALTER SESSION SET CURRENT_SCHEMA = "{self.client.database}"')
await connection.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'")
await connection.execute(
"ALTER SESSION SET NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD\"T\"HH24:MI:SSTZH:TZM'"
)
await connection.add_output_converter(
pyodbc.SQL_TYPE_TIMESTAMP, self._timestamp_convert
)
return connection


class TransactionWrapper(ODBCTransactionWrapper, OracleClient):
async def start(self) -> None:
await self._connection.execute("SET TRANSACTION READ WRITE")
await super().start()
6 changes: 3 additions & 3 deletions tortoise/backends/oracle/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

class OracleExecutor(ODBCExecutor):
async def _process_insert_result(self, instance: Model, results: int) -> None:
sql = "SELECT sequence_name FROM ALL_TAB_IDENTITY_COLS where TABLE_NAME = ? and OWNER = ?"
sql = "SELECT SEQUENCE_NAME FROM ALL_TAB_IDENTITY_COLS where TABLE_NAME = ? and OWNER = ?"
ret = await self.db.execute_query_dict(
sql, values=[instance._meta.db_table, self.db.database] # type: ignore
)
seq = ret[0]["sequence_name"]
seq = ret[0]["SEQUENCE_NAME"]
sql = f"SELECT {seq}.CURRVAL FROM DUAL"
ret = await self.db.execute_query_dict(sql)
await super(OracleExecutor, self)._process_insert_result(instance, ret[0]["currval"])
await super(OracleExecutor, self)._process_insert_result(instance, ret[0]["CURRVAL"])
2 changes: 1 addition & 1 deletion tortoise/backends/oracle/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _create_fk_string(
) -> str:
if on_delete not in [CASCADE, SET_NULL]:
on_delete = CASCADE
constraint = f"CONSTRAINT {constraint_name} " if constraint_name else ""
constraint = f'CONSTRAINT "{constraint_name}" ' if constraint_name else ""
fk = self.FK_TEMPLATE.format(
constraint=constraint,
db_column=db_column,
Expand Down
21 changes: 13 additions & 8 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,19 @@ def constraints(self) -> dict:
return {}

def _get_dialects(self) -> Dict[str, dict]:
return {
dialect[4:]: {
key: val
for key, val in getattr(self, dialect).__dict__.items()
if not key.startswith("_")
}
for dialect in [key for key in dir(self) if key.startswith("_db_")]
}
ret = {}
for dialect in [key for key in dir(self) if key.startswith("_db_")]:
item = {}
cls = getattr(self, dialect)
try:
cls = cls(self)
except TypeError:
pass
for key, val in cls.__dict__.items():
if not key.startswith("_"):
item[key] = val
ret[dialect[4:]] = item
return ret

def get_db_field_types(self) -> Optional[Dict[str, str]]:
"""
Expand Down
10 changes: 9 additions & 1 deletion tortoise/fields/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ def constraints(self) -> dict:
def SQL_TYPE(self) -> str: # type: ignore
return f"VARCHAR({self.max_length})"

class _db_oracle:
def __init__(self, field: "CharField") -> None:
self.field = field

@property
def SQL_TYPE(self) -> str:
return f"NVARCHAR2({self.field.max_length})"


class TextField(Field, str): # type: ignore
"""
Expand Down Expand Up @@ -429,7 +437,7 @@ class TimeField(Field, datetime.time):
SQL_TYPE = "TIME"

class _db_oracle:
SQL_TYPE = "VARCHAR(8)"
SQL_TYPE = "NVARCHAR2(8)"

def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None:
if auto_now_add and auto_now:
Expand Down

0 comments on commit c7a03c1

Please sign in to comment.