diff --git a/openapi/db/dbmodel.py b/openapi/db/dbmodel.py index 7e19d56..8bd363d 100644 --- a/openapi/db/dbmodel.py +++ b/openapi/db/dbmodel.py @@ -17,6 +17,13 @@ async def db_delete(self, table, filters, *, conn=None, consumer=None): async with self.ensure_connection(conn) as conn: return await conn.fetch(sql, *args) + async def db_count(self, table, filters, *, conn=None, consumer=None): + query = self.get_query(table, table.select(), consumer, filters) + sql, args = compile_query(query.alias('inner').count()) + async with self.ensure_connection(conn) as conn: + total = await conn.fetchrow(sql, *args) + return total['tbl_row_count'] + async def db_insert(self, table, data, *, conn=None): async with self.ensure_connection(conn) as conn: statement, args = self.get_insert(table, data) diff --git a/tests/test_db_model.py b/tests/test_db_model.py index 8b10db9..d489d5b 100644 --- a/tests/test_db_model.py +++ b/tests/test_db_model.py @@ -7,3 +7,12 @@ async def test_get_attr(cli): with pytest.raises(AttributeError) as ex_info: db.fooooo assert 'fooooo' in str(ex_info.value) + + +async def test_db_count(cli): + db = cli.app['db'] + n = await db.db_count(db.tasks, {}) + assert n == 0 + await db.db_insert(db.tasks, dict(title='testing rollback')) + n = await db.db_count(db.tasks, {}) + assert n == 1