Skip to content

Commit

Permalink
Simplify SQLiteCache.bulk_commit() and add a test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
JWCook committed Mar 1, 2021
1 parent de2fa1d commit a0c03c6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 46 deletions.
2 changes: 1 addition & 1 deletion aiohttp_client_cache/backends/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _scan_table(self) -> Dict:

@staticmethod
def unpickle(response_item: Dict) -> ResponseOrKey:
return super().unpickle((response_item or {}).get('value'))
return BaseCache.unpickle((response_item or {}).get('value'))

async def clear(self):
response = self._scan_table()
Expand Down
69 changes: 24 additions & 45 deletions aiohttp_client_cache/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,58 +51,35 @@ class SQLiteCache(BaseCache):
def __init__(self, filename: str, table_name: str):
self.filename = filename
self.table_name = table_name
self.can_commit = True # Transactions can be committed if this is set to `True`
self._can_commit = True # Transactions can be committed if this is set to `True`

self._bulk_commit = False
self._initialized = False
self._pending_connection = None
self._connection = None
self._lock = asyncio.Lock()

async def _get_pending_connection(self):
"""Use/create pending connection if doing a bulk commit"""
if not self._pending_connection:
self._pending_connection = await aiosqlite.connect(self.filename)
return self._pending_connection

async def _close_pending_connection(self):
if self._pending_connection:
await self._pending_connection.close()
self._pending_connection = None

async def _init_connection(self, db: aiosqlite.Connection):
"""Create table if this is the first connection opened, and set fast save if specified"""
await db.execute('PRAGMA synchronous = 0;')
if not self._initialized:
await db.execute(
f'CREATE TABLE IF NOT EXISTS `{self.table_name}` (key PRIMARY KEY, value)'
)
self._initialized = True
return db

@asynccontextmanager
async def get_connection(self, autocommit: bool = False) -> AsyncIterator[aiosqlite.Connection]:
async with self._lock:
if self._bulk_commit:
db = await self._get_pending_connection()
else:
db = await aiosqlite.connect(self.filename)
db = self._connection if self._connection else await aiosqlite.connect(self.filename)
try:
yield await self._init_connection(db)
if autocommit and self.can_commit:
yield await self._init_db(db)
if autocommit and self._can_commit:
await db.commit()
finally:
if not self._bulk_commit:
await db.close()

async def commit(self, force: bool = False):
"""
Commits pending transaction if :attr:`can_commit` or `force` is `True`
Args:
force: force commit, ignore :attr:`can_commit`
"""
if (force or self.can_commit) and self._pending_connection:
await self._pending_connection.commit()
async def _init_db(self, db: aiosqlite.Connection):
"""Create table if this is the first connection opened, and set fast save if possible"""
if not self._bulk_commit:
await db.execute('PRAGMA synchronous = 0;')
if not self._initialized:
await db.execute(
f'CREATE TABLE IF NOT EXISTS `{self.table_name}` (key PRIMARY KEY, value)'
)
self._initialized = True
return db

@asynccontextmanager
async def bulk_commit(self):
Expand All @@ -111,21 +88,23 @@ async def bulk_commit(self):
Example:
>>> d1 = SQLiteCache('test')
>>> async with d1.bulk_commit():
>>> cache = SQLiteCache('test')
>>> async with cache.bulk_commit():
... for i in range(1000):
... d1[i] = i * 2
... await cache.write(f'key_{i}', str(i * 2))
"""
self._bulk_commit = True
self.can_commit = False
self._can_commit = False
self._connection = await aiosqlite.connect(self.filename)
try:
yield
await self.commit(force=True)
await self._connection.commit()
finally:
self._bulk_commit = False
self.can_commit = True
await self._close_pending_connection()
self._can_commit = True
await self._connection.close()
self._connection = None

async def clear(self):
async with self.get_connection(autocommit=True) as db:
Expand Down
8 changes: 8 additions & 0 deletions test/integration/test_sqlite_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ async def test_write_read(cache_client):
assert await cache_client.read(k) == v


async def test_bulk_commit(cache_client):
async with cache_client.bulk_commit():
for i in range(1000):
await cache_client.write(f'key_{i}', str(i * 2))

assert await cache_client.size() == 1000


async def test_delete(cache_client):
for k, v in test_data.items():
await cache_client.write(k, v)
Expand Down

0 comments on commit a0c03c6

Please sign in to comment.