Skip to content

Commit

Permalink
feat: add ci for mssql and make all test pass for mssql (#1107)
Browse files Browse the repository at this point in the history
* feat: add ci for mssql and make all test pass for mssql

* ci: fix mssql options

* ci: fix mssql options add -C

* ci: fix mssql options

* ci: fix mssql options

* ci: fix mssql options

* ci: fix mssql options

* ci: add sudo

* ci: add sudo

* ci: add sudo

* ci: add -o

* style: fix flake

* ci: fix ci.yml

* test: fix tests

* test: skip TestFuzz

* fix odbc _expire_connections

* fix asyncodbc

* test: fix test for mssql
  • Loading branch information
long2ice committed Apr 20, 2022
1 parent 78cc110 commit 316d711
Show file tree
Hide file tree
Showing 30 changed files with 369 additions and 233 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,24 @@ jobs:
env:
MYSQL_ROOT_PASSWORD: 123456
options: --health-cmd="mysqladmin ping" --health-interval 10s --health-timeout 5s --health-retries 5
mssql:
image: mcr.microsoft.com/mssql/server:2019-CU15-ubuntu-20.04
ports:
- 1433:1433
env:
ACCEPT_EULA: Y
SA_PASSWORD: Abcd12345678
options: >-
--health-cmd "/opt/mssql-tools/bin/sqlcmd -U sa -P Abcd12345678 -Q 'select 1' -b -o /dev/null"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
TORTOISE_TEST_MODULES: tests.testmodels
TORTOISE_MYSQL_PASS: 123456
TORTOISE_POSTGRES_PASS: 123456
TORTOISE_MSSQL_PASS: Abcd12345678
TORTOISE_MSSQL_DRIVER: ODBC Driver 18 for SQL Server
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
Expand All @@ -49,6 +63,12 @@ jobs:
poetry config virtualenvs.create false
- name: Install requirements
run: make deps
- name: Install ODBC driver
run: |
sudo curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -
sudo curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list -o /etc/apt/sources.list.d/mssql-release.list
sudo apt-get update
ACCEPT_EULA=Y sudo apt-get install -y msodbcsql18
- name: Run ci
run: make ci
- name: Upload Coverage
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ test_mysql:
$(py_warn) TORTOISE_TEST_DB="mysql://root:$(TORTOISE_MYSQL_PASS)@127.0.0.1:3306/test_\{\}" pytest $(pytest_opts) --cov-append --cov-report=

test_mssql:
$(py_warn) TORTOISE_TEST_DB="mssql://sa:$(TORTOISE_MSSQL_PASS)@127.0.0.1:1433/test_\{\}?driver=${TORTOISE_MSSQL_DRIVER}&TrustServerCertificate=YES&autocommit=1" pytest $(pytest_opts) --cov-append --cov-report=
$(py_warn) TORTOISE_TEST_DB="mssql://sa:$(TORTOISE_MSSQL_PASS)@127.0.0.1:1433/test_\{\}?driver=$(TORTOISE_MSSQL_DRIVER)&TrustServerCertificate=YES" pytest $(pytest_opts) --cov-append --cov-report=

test_oracle:
$(py_warn) TORTOISE_TEST_DB="oracle://SYSTEM:$(TORTOISE_ORACLE_PASS)@127.0.0.1:1521/test_\{\}?driver=${TORTOISE_ORACLE_DRIVER}" pytest $(pytest_opts) --cov-append --cov-report=
$(py_warn) TORTOISE_TEST_DB="oracle://SYSTEM:$(TORTOISE_ORACLE_PASS)@127.0.0.1:1521/test_\{\}?driver=$(TORTOISE_ORACLE_DRIVER)" pytest $(pytest_opts) --cov-append --cov-report=

_testall: test_sqlite test_postgres_asyncpg test_postgres_psycopg test_mysql_myisam test_mysql
_testall: test_sqlite test_postgres_asyncpg test_postgres_psycopg test_mysql_myisam test_mysql test_mssql

testall: deps _testall
coverage report
Expand Down
9 changes: 3 additions & 6 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class Event(Model):
id = fields.IntField(pk=True)
name = fields.CharField(max_length=200)
name = fields.TextField()
datetime = fields.DatetimeField(null=True)

class Meta:
Expand All @@ -18,11 +18,8 @@ def __str__(self):


async def run():
await Tortoise.init(
db_url="oracle://test:123456@127.0.0.1:1521/test?driver=Oracle",
modules={"models": ["__main__"]},
)
# await Tortoise.generate_schemas()
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
await Tortoise.generate_schemas()

event = await Event.create(name="Test")
await Event.filter(id=event.id).update(name="Updated name")
Expand Down
315 changes: 160 additions & 155 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ asyncpg = { version = "*", optional = true }
aiomysql = { version = "*", optional = true }
asyncmy = { version = "^0.2.5", optional = true }
psycopg = { extras = ["pool", "binary"], version = "*", optional = true }
asyncodbc = { version = "*", optional = true }
asyncodbc = { git = "https://github.com/tortoise/asyncodbc.git", branch = "main", optional = true }

[tool.poetry.dev-dependencies]
# Linter tools
Expand Down
2 changes: 2 additions & 0 deletions tests/backends/test_explain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from tests.testmodels import Tournament
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ


class TestExplain(test.TestCase):
@test.requireCapability(dialect=NotEQ("mssql"))
async def test_explain(self):
# NOTE: we do not provide any guarantee on the format of the value
# returned by `.explain()`, as it heavily depends on the database.
Expand Down
2 changes: 1 addition & 1 deletion tests/fields/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ async def test_values_list(self):
self.assertEqual(values, timedelta(days=35, seconds=8, microseconds=1))

async def test_get(self):
delta = timedelta(days=35, seconds=8, microseconds=1)
delta = timedelta(days=35, seconds=8, microseconds=2)
await testmodels.TimeDeltaFields.create(timedelta=delta)
obj = await testmodels.TimeDeltaFields.get(timedelta=delta)
self.assertEqual(obj.timedelta, delta)
4 changes: 2 additions & 2 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tests.testmodels import Author, Book, Event, MinRelation, Team, Tournament
from tortoise.contrib import test
from tortoise.contrib.test.condition import In
from tortoise.exceptions import ConfigurationError
from tortoise.expressions import Q
from tortoise.functions import Avg, Coalesce, Concat, Count, Lower, Max, Min, Sum, Trim
Expand Down Expand Up @@ -149,8 +150,7 @@ async def test_nested_functions(self):
ret = await Book.all().annotate(max_name=Lower(Max("name"))).values("max_name")
self.assertEqual(ret, [{"max_name": "third!"}])

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
@test.requireCapability(dialect=In("postgres", "mssql"))
async def test_concat_functions(self):
author = await Author.create(name="Some One")
await Book.create(name="Physics Book", author=author, rating=4, subject="physics ")
Expand Down
11 changes: 10 additions & 1 deletion tests/test_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from tests.testmodels import UniqueName, UUIDPkModel
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.exceptions import IntegrityError
from tortoise.transactions import in_transaction

Expand All @@ -15,6 +16,7 @@ async def test_bulk_create(self):
all_, [{"id": val + inc, "name": None} for val in range(1000)], sorted_key="id"
)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_update_fields(self):
await UniqueName.bulk_create([UniqueName(name="name")])
await UniqueName.bulk_create(
Expand All @@ -25,6 +27,7 @@ async def test_bulk_create_update_fields(self):
all_ = await UniqueName.all().values("name", "optional")
self.assertListSortEqual(all_, [{"name": "name", "optional": "optional"}])

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_more_that_one_update_fields(self):
await UniqueName.bulk_create([UniqueName(name="name")])
await UniqueName.bulk_create(
Expand All @@ -37,6 +40,7 @@ async def test_bulk_create_more_that_one_update_fields(self):
all_, [{"name": "name", "optional": "optional", "other_optional": "other_optional"}]
)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_with_batch_size(self):
await UniqueName.bulk_create(
[UniqueName(id=id_ + 1) for id_ in range(1000)], batch_size=100
Expand All @@ -46,13 +50,15 @@ async def test_bulk_create_with_batch_size(self):
all_, [{"id": val + 1, "name": None} for val in range(1000)], sorted_key="id"
)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_with_specified(self):
await UniqueName.bulk_create([UniqueName(id=id_) for id_ in range(1000, 2000)])
all_ = await UniqueName.all().values("id", "name")
self.assertListSortEqual(
all_, [{"id": id_, "name": None} for id_ in range(1000, 2000)], sorted_key="id"
)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_mix_specified(self):
await UniqueName.bulk_create(
[UniqueName(id=id_) for id_ in range(10000, 11000)]
Expand All @@ -77,6 +83,7 @@ async def test_bulk_create_uuidpk(self):
self.assertIsInstance(res[0], UUID)

@test.requireCapability(supports_transactions=True)
@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_in_transaction(self):
async with in_transaction():
await UniqueName.bulk_create([UniqueName() for _ in range(1000)])
Expand All @@ -92,6 +99,7 @@ async def test_bulk_create_uuidpk_in_transaction(self):
self.assertEqual(len(res), 1000)
self.assertIsInstance(res[0], UUID)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_fail(self):
with self.assertRaises(IntegrityError):
await UniqueName.bulk_create(
Expand All @@ -104,7 +112,7 @@ async def test_bulk_create_uuidpk_fail(self):
with self.assertRaises(IntegrityError):
await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)])

@test.requireCapability(supports_transactions=True)
@test.requireCapability(supports_transactions=True, dialect=NotEQ("mssql"))
async def test_bulk_create_in_transaction_fail(self):
with self.assertRaises(IntegrityError):
async with in_transaction():
Expand All @@ -120,6 +128,7 @@ async def test_bulk_create_uuidpk_in_transaction_fail(self):
async with in_transaction():
await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)])

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_create_ignore_conflicts(self):
name1 = UniqueName(name="name1")
name2 = UniqueName(name="name2")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from tests.testmodels import Tournament, UniqueName
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.transactions import in_transaction


Expand Down Expand Up @@ -30,6 +31,7 @@ async def test_nonconcurrent_get_or_create(self):
self.assertEqual(una[0], unas[0][0])

@test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
@test.requireCapability(dialect=NotEQ("mssql"))
async def test_concurrent_get_or_create(self):
unas = await asyncio.gather(*[UniqueName.get_or_create(name="d") for _ in range(10)])
una_created = [una[1] for una in unas if una[1] is True]
Expand Down
15 changes: 8 additions & 7 deletions tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from tests.testmodels import DefaultModel
from tortoise.backends.asyncpg import AsyncpgDBClient
from tortoise.backends.mssql import MSSQLClient
from tortoise.backends.mysql import MySQLClient
from tortoise.backends.psycopg import PsycopgClient
from tortoise.backends.sqlite import SqliteClient
Expand All @@ -14,17 +15,17 @@
class TestDefault(test.TestCase):
async def asyncSetUp(self) -> None:
await super(TestDefault, self).asyncSetUp()
connection = self._db
if isinstance(connection, MySQLClient):
await connection.execute_query(
db = self._db
if isinstance(db, MySQLClient):
await db.execute_query(
"insert into defaultmodel (`int_default`,`float_default`,`decimal_default`,`bool_default`,`char_default`,`date_default`,`datetime_default`) values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)",
)
elif isinstance(connection, SqliteClient):
await connection.execute_query(
elif isinstance(db, SqliteClient):
await db.execute_query(
"insert into defaultmodel default values",
)
elif isinstance(connection, (AsyncpgDBClient, PsycopgClient)):
await connection.execute_query(
elif isinstance(db, (AsyncpgDBClient, PsycopgClient, MSSQLClient)):
await db.execute_query(
'insert into defaultmodel ("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)',
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_early_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_early_init(self):
"constraints": {"readOnly": True},
"db_field_types": {
"": "TIMESTAMP",
"mssql": "DATETIME",
"mssql": "DATETIME2",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
"oracle": "TIMESTAMP WITH TIME ZONE",
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_early_init(self):
"db_column": "created_at",
"db_field_types": {
"": "TIMESTAMP",
"mssql": "DATETIME",
"mssql": "DATETIME2",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
"oracle": "TIMESTAMP WITH TIME ZONE",
Expand Down
3 changes: 3 additions & 0 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from tests.testmodels import DatetimeFields, Event, IntFields, Reporter, Team, Tournament
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.expressions import F, Q
from tortoise.functions import Coalesce, Count, Length, Lower, Max, Trim, Upper

Expand Down Expand Up @@ -299,6 +300,7 @@ async def test_filter_by_aggregation_field_trim(self):
self.assertEqual(len(tournaments), 1)
self.assertSetEqual({(t.name, t.trimmed_name) for t in tournaments}, {(" 1 ", "1")})

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_filter_by_aggregation_field_length(self):
await Tournament.create(name="12345")
await Tournament.create(name="123")
Expand Down Expand Up @@ -343,6 +345,7 @@ async def test_filter_by_aggregation_field_comparison_coalesce_numeric(self):
self.assertEqual(len(ints), 2)
self.assertSetEqual({i.clean_intnum_null for i in ints}, {10, 4})

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_filter_by_aggregation_field_comparison_length(self):
t1 = await Tournament.create(name="Tournament")
await Event.create(name="event1", tournament=t1)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_fuzz.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tests.testmodels import CharFields
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ

DODGY_STRINGS = [
"a/",
Expand Down Expand Up @@ -95,10 +96,11 @@


class TestFuzz(test.TestCase):
@test.requireCapability(dialect=NotEQ("mssql"))
async def test_char_fuzz(self):
for char in DODGY_STRINGS:
# print(repr(char))
if "\x00" in char and self._db.capabilities.dialect == "postgres":
if "\x00" in char and self._db.capabilities.dialect in ["postgres"]:
# PostgreSQL doesn't support null values as text. Ever. So skip these.
continue

Expand Down
3 changes: 3 additions & 0 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
UUIDFkRelatedNullModel,
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.exceptions import (
ConfigurationError,
DoesNotExist,
Expand All @@ -38,6 +39,7 @@ async def test_save_non_generated(self):
mdl2 = await UUIDFkRelatedNullModel.get(id=mdl.id)
self.assertEqual(mdl, mdl2)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_save_generated_custom_id(self):
cid = 12345
mdl = await Tournament.create(id=cid, name="Test")
Expand All @@ -52,6 +54,7 @@ async def test_save_non_generated_custom_id(self):
mdl2 = await UUIDFkRelatedNullModel.get(id=cid)
self.assertEqual(mdl, mdl2)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_save_generated_duplicate_custom_id(self):
cid = 12345
await Tournament.create(id=cid, name="TestOriginal")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tests.testmodels import Event, IntFields, MinRelation, Node, Reporter, Tournament, Tree
from tortoise import connections
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.exceptions import (
DoesNotExist,
FieldError,
Expand Down Expand Up @@ -583,6 +584,7 @@ async def test_raw_sql_count(self):
ret = await Tournament.filter(pk=t1.pk).annotate(count=RawSQL("count(*)")).values("count")
self.assertEqual(ret, [{"count": 1}])

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_raw_sql_select(self):
t1 = await Tournament.create(id=1, name="1")
ret = (
Expand Down
2 changes: 2 additions & 0 deletions tests/test_source_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from tests.testmodels import NumberSourceField, SourceFields, StraightFields
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.expressions import F, Q
from tortoise.functions import Coalesce, Count, Length, Lower, Trim, Upper

Expand Down Expand Up @@ -218,6 +219,7 @@ async def test_filter_by_aggregation_field_count(self):
self.assertEqual(len(obj), 1)
self.assertEqual(obj[0].chars, "aaa")

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_filter_by_aggregation_field_length(self):
await self.model.create(chars="aaa")
await self.model.create(chars="bbbbb")
Expand Down
5 changes: 3 additions & 2 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
UUIDFields,
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import In, NotEQ
from tortoise.expressions import F


Expand Down Expand Up @@ -76,6 +77,7 @@ async def test_bulk_update_json_value(self):
self.assertEqual((await JSONFields.get(pk=objs[0].pk)).data, objs[0].data)
self.assertEqual((await JSONFields.get(pk=objs[1].pk)).data, objs[1].data)

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_bulk_update_smallint_none(self):
objs = [
await SmallIntFields.create(smallintnum=1, smallintnum_null=1),
Expand Down Expand Up @@ -107,8 +109,7 @@ async def test_update_relation(self):
event = await Event.first()
self.assertEqual(event.tournament_id, tournament_second.id)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="sqlite")
@test.requireCapability(dialect=In("mysql", "sqlite"))
async def test_update_with_custom_function(self):
class JsonSet(Function):
def __init__(self, field: F, expression: str, value: Any):
Expand Down
Loading

0 comments on commit 316d711

Please sign in to comment.