forked from MagicStack/asyncpg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_timeout.py
154 lines (123 loc) · 6.34 KB
/
test_timeout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import asyncpg
from asyncpg import connection as pg_connection
from asyncpg import _testbase as tb
# Wall-clock budget handed to assertRunUnder() in the tests below: an
# operation that is expected to time out (or be cancelled) must give up
# within this long rather than waiting for the full pg_sleep() to finish.
MAX_RUNTIME = 0.5
class TestTimeout(tb.ConnectedTestCase):
    """Per-call ``timeout=`` handling on connection, statement and cursor APIs.

    Each test provokes a timeout (or a cancellation) on a long-running
    ``pg_sleep`` query and then checks that the connection is still usable
    by running ``select 1`` on it.
    """

    async def test_timeout_01(self):
        # Connection query methods honour an explicit per-call timeout.
        # A tuple (not a set) keeps the iteration order deterministic so
        # that a failure reproduces in the same order on every run.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                meth = getattr(self.con, methname)
                await meth('select pg_sleep(10)', timeout=0.02)
            self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_02(self):
        # Same as test_timeout_01, but for prepared-statement methods.
        st = await self.con.prepare('select pg_sleep(10)')

        for methname in ('fetch', 'fetchrow', 'fetchval'):
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                meth = getattr(st, methname)
                await meth(timeout=0.02)
            self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_03(self):
        # Cancelling the task before its timeout fires must surface as
        # CancelledError (not TimeoutError) and leave the connection usable.
        task = self.loop.create_task(
            self.con.fetch('select pg_sleep(10)', timeout=0.2))
        await asyncio.sleep(0.05)
        task.cancel()
        with self.assertRaises(asyncio.CancelledError), \
                self.assertRunUnder(MAX_RUNTIME):
            await task
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_04(self):
        # Timeouts on cursors created from a prepared statement: first via
        # async iteration, then via an explicit Cursor.fetch() call.
        st = await self.con.prepare('select pg_sleep(10)', timeout=0.1)
        with self.assertRaises(asyncio.TimeoutError), \
                self.assertRunUnder(MAX_RUNTIME):
            async with self.con.transaction():
                async for _ in st.cursor(timeout=0.1):  # NOQA
                    pass
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

        st = await self.con.prepare('select pg_sleep(10)', timeout=0.1)
        async with self.con.transaction():
            cur = await st.cursor()
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetch(1, timeout=0.1)
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_05(self):
        # Stress-test timeouts - try to trigger a race condition
        # between a cancellation request to Postgres and next
        # query (SELECT 1)
        for _ in range(500):
            with self.assertRaises(asyncio.TimeoutError):
                await self.con.fetch('SELECT pg_sleep(1)', timeout=1e-10)
            self.assertEqual(await self.con.fetch('SELECT 1'), [(1,)])

    async def test_timeout_06(self):
        # Timeouts on cursors created directly from the connection, one
        # transaction per cursor method.  The sanity check stays outside the
        # transaction, which is aborted by the cancelled query.
        async with self.con.transaction():
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                async for _ in self.con.cursor(  # NOQA
                        'select pg_sleep(10)', timeout=0.1):
                    pass
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetch(1, timeout=0.1)

        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.forward(1, timeout=1e-10)

        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetchrow(timeout=0.1)

        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetchrow(timeout=0.1)

            # The timed-out query aborted the transaction, so any further
            # use of the cursor inside it must fail.
            with self.assertRaises(asyncpg.InFailedSQLTransactionError):
                await cur.fetch(1)

        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_invalid_timeout(self):
        # Invalid connection-level command_timeout values are rejected.
        for command_timeout in ('a', False, -1):
            with self.subTest(command_timeout=command_timeout):
                with self.assertRaisesRegex(ValueError,
                                            'invalid command_timeout'):
                    await self.connect(command_timeout=command_timeout)

        # Note: negative timeouts are OK for method calls.
        # Each query method is exercised individually; previously the loop
        # variable was ignored and only execute() was ever tested.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            meth = getattr(self.con, methname)
            for timeout in ('a', False):
                with self.subTest(methname=methname, timeout=timeout):
                    with self.assertRaisesRegex(ValueError, 'invalid timeout'):
                        await meth('SELECT 1', timeout=timeout)
class TestConnectionCommandTimeout(tb.ConnectedTestCase):
    """The connection-level ``command_timeout`` applies to query methods."""

    @tb.with_connection_options(command_timeout=0.2)
    async def test_command_timeout_01(self):
        # No per-call timeout is passed here: the command_timeout=0.2 given
        # through with_connection_options must cut the 10s sleep short.
        # A tuple (not a set) keeps the iteration order deterministic.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                meth = getattr(self.con, methname)
                await meth('select pg_sleep(10)')
            self.assertEqual(await self.con.fetch('select 1'), [(1,)])
class SlowPrepareConnection(pg_connection.Connection):
    """Connection subclass whose statement lookup is artificially slow.

    Every call to ``_get_statement`` stalls for 0.3 seconds before
    delegating to the regular implementation, so tests can observe whether
    that time is charged against a query's timeout.
    """

    async def _get_statement(self, query, timeout, **options):
        # Burn part of the caller's time budget before doing the real work.
        await asyncio.sleep(0.3)
        statement = await super()._get_statement(query, timeout, **options)
        return statement
class TestTimeoutCoversPrepare(tb.ConnectedTestCase):
    """A query's timeout budget must include statement preparation time."""

    @tb.with_connection_options(connection_class=SlowPrepareConnection,
                                command_timeout=0.3)
    async def test_timeout_covers_prepare_01(self):
        # SlowPrepareConnection spends 0.3s in _get_statement and the query
        # then sleeps another 0.2s.  The 0.2s query alone fits within the
        # 0.3s command_timeout, so TimeoutError is only raised if the time
        # spent preparing is charged against the timeout as well.
        # A tuple (not a set) keeps the iteration order deterministic.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            with self.assertRaises(asyncio.TimeoutError):
                meth = getattr(self.con, methname)
                await meth('select pg_sleep($1)', 0.2)