Skip to content

Commit

Permalink
Merge 874bf13 into 9dc53ae
Browse files Browse the repository at this point in the history
  • Loading branch information
LilSpazJoekp committed Jun 18, 2021
2 parents 9dc53ae + 874bf13 commit 8fb5fb8
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Unreleased
- :meth:`.trusted` to retrieve a :class:`.RedditorList` of trusted users.
- :meth:`.trust` to add a user to the trusted list.
- :meth:`.distrust` to remove a user from the trusted list.
- :class:`.SQLiteTokenManager` (may not work on Windows)

**Changed**

Expand Down
138 changes: 134 additions & 4 deletions asyncpraw/util/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
There should be a 1-to-1 mapping between an instance of a subclass of
:class:`.BaseTokenManager` and a :class:`.Reddit` instance.
A few trivial token manager classes are provided here, but it is expected that Async
PRAW users will create their own token manager classes suitable for their needs.
A few proof of concept token manager classes are provided here, but it is expected that
Async PRAW users will create their own token manager classes suitable for their needs.
See ref:`using_refresh_tokens` for examples on how to leverage these classes.
See :ref:`using_refresh_tokens` for examples on how to leverage these classes.
"""
from abc import ABC, abstractmethod

import aiofiles
import aiosqlite
from asyncio_extras import async_contextmanager


class BaseTokenManager(ABC):
Expand Down Expand Up @@ -60,7 +62,19 @@ def pre_refresh_callback(self, authorizer):


class FileTokenManager(BaseTokenManager):
"""Provides a trivial single-file based token manager."""
"""Provides a single-file based token manager.
It is expected that the file with the initial ``refresh_token`` is created prior to
use.
.. warning::
The same ``file`` should not be used by more than one instance of this class
concurrently. Doing so may result in data corruption. Consider using
:class:`.SQLiteTokenManager` if you want more than one instance of PRAW to
concurrently manage a specific ``refresh_token`` chain.
"""

def __init__(self, filename):
"""Load and save refresh tokens from a file.
Expand All @@ -81,3 +95,119 @@ async def pre_refresh_callback(self, authorizer):
if authorizer.refresh_token is None:
async with aiofiles.open(self._filename) as fp:
authorizer.refresh_token = (await fp.read()).strip()


class SQLiteTokenManager(BaseTokenManager):
"""Provides a SQLite3 based token manager.
Unlike, :class:`.FileTokenManager`, the initial database need not be created ahead
of time, as it'll automatically be created on first use. However, initial
``refresh_tokens`` will need to be registered via :meth:`.register` prior to use.
See :ref:`sqlite_token_manager` for an example of use.
.. warning::
This class is untested on Windows because we encountered file locking issues in
the test environment.
"""

def __init__(self, database, key):
"""Load and save refresh tokens from a SQLite database.
:param database: The path to the SQLite database.
:param key: The key used to locate the ``refresh_token``. This ``key`` can be
anything. You might use the ``client_id`` if you expect to have unique
``refresh_tokens`` for each ``client_id``, or you might use a Redditor's
``username`` if you're manage multiple users' authentications.
"""
super().__init__()
self._connection = None
self._database = database
self._setup_ran = False
self.key = key

@async_contextmanager
async def connection(self):
"""Asynchronously setup and provide the sqlite3 connection."""
if self._connection is None:
self._connection = await aiosqlite.connect(self._database)
if not self._setup_ran:
await self._connection.execute(
"CREATE TABLE IF NOT EXISTS tokens (id, refresh_token, updated_at)"
)
await self._connection.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ux_tokens_id on tokens(id)"
)
await self._connection.commit()
self._setup_ran = True
yield self._connection

async def _get(self):
async with self.connection() as conn:
cursor = await conn.execute(
"SELECT refresh_token FROM tokens WHERE id=?", (self.key,)
)
result = await cursor.fetchone()
if result is None:
raise KeyError
return result[0]

async def _set(self, refresh_token):
"""Set the refresh token in the database.
This function will overwrite an existing value if the corresponding ``key``
already exists.
"""
async with self.connection() as conn:
await conn.execute(
"REPLACE INTO tokens VALUES (?, ?, datetime('now'))",
(self.key, refresh_token),
)
await conn.commit()

async def close(self):
"""Close the sqlite3 connection."""
await self._connection.close()

async def is_registered(self):
"""Return whether ``key`` already has a ``refresh_token``."""
async with self.connection() as conn:
cursor = await conn.execute(
"SELECT refresh_token FROM tokens WHERE id=?", (self.key,)
)
result = await cursor.fetchone()
return result is not None

async def post_refresh_callback(self, authorizer):
"""Update the refresh token in the database."""
await self._set(authorizer.refresh_token)

# While the following line is not strictly necessary, it ensures that the
# refresh token is not used elsewhere. And also forces the pre_refresh_callback
# to always load the latest refresh_token from the database.
authorizer.refresh_token = None

async def pre_refresh_callback(self, authorizer):
"""Load the refresh token from the database."""
assert authorizer.refresh_token is None
authorizer.refresh_token = await self._get()

async def register(self, refresh_token):
"""Register the initial refresh token in the database.
:returns: ``True`` if ``refresh_token`` is saved to the database, otherwise,
``False`` if there is already a ``refresh_token`` for the associated
``key``.
"""
async with self.connection() as conn:
cursor = await conn.execute(
"INSERT OR IGNORE INTO tokens VALUES (?, ?, datetime('now'))",
(self.key, refresh_token),
)
await conn.commit()
row_count = cursor.rowcount
return row_count == 1
71 changes: 71 additions & 0 deletions docs/examples/use_sqlite_token_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
"""This example demonstrates using the sqlite token manager for refresh tokens.
In order to run this program, you will first need to obtain one or more valid refresh
tokens. You can use the ``obtain_refresh_token.py`` example to help.
In this example, refresh tokens will be saved into a file ``tokens.sqlite3`` relative to
your current working directory. If your current working directory is under version
control it is strongly encouraged you add ``tokens.sqlite3`` to the version control ignore
list.
This example differs primarily from ``use_file_token_manager.py`` due to the fact that a
shared SQLite3 database can manage many ``refresh_tokens``. While each instance of
Reddit still needs to have 1-to-1 mapping to a token manager, multiple Reddit instances
can concurrently share access to the same SQLite3 database; the same cannot be done with
the FileTokenManager.
Usage:
EXPORT praw_client_id=<REDDIT_CLIENT_ID>
EXPORT praw_client_secret=<REDDIT_CLIENT_SECRET>
python3 use_sqlite_token_manager.py TOKEN_KEY
"""
import os
import sys

import praw
from praw.util.token_manager import SQLiteTokenManager

DATABASE_PATH = "tokens.sqlite3"


def main():
if "praw_client_id" not in os.environ:
sys.stderr.write("Environment variable ``praw_client_id`` must be defined\n")
return 1
if "praw_client_secret" not in os.environ:
sys.stderr.write(
"Environment variable ``praw_client_secret`` must be defined\n"
)
return 1
if len(sys.argv) != 2:
sys.stderr.write(
"KEY must be provided.\n\nUsage: python3 use_sqlite_token_manager.py TOKEN_KEY\n"
)
return 1

refresh_token_manager = SQLiteTokenManager(DATABASE_PATH, key=sys.argv[1])
reddit = praw.Reddit(
token_manager=refresh_token_manager,
user_agent="sqlite_token_manager/v0 by u/bboe",
)

if not refresh_token_manager.is_registered():
refresh_token = input("Enter initial refresh token: ").strip()
refresh_token_manager.register(refresh_token)

scopes = reddit.auth.scopes()
if scopes == {"*"}:
print(f"{reddit.user.me()} is authenticated with all scopes")
elif "identity" in scopes:
print(
f"{reddit.user.me()} is authenticated with the following scopes: {scopes}"
)
else:
print(f"You are authenticated with the following scopes: {scopes}")


if __name__ == "__main__":
sys.exit(main())
10 changes: 10 additions & 0 deletions docs/tutorials/refresh_token.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,13 @@ valid refresh token.

.. literalinclude:: ../examples/use_file_token_manager.py
:language: python

.. _sqlite_token_manager:

SQLiteTokenManager
~~~~~~~~~~~~~~~~~~

For more complex examples, PRAW provides the :class:`.SQLiteTokenManager`.

.. literalinclude:: ../examples/use_sqlite_token_manager.py
:language: python
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@
" python package that allows for simple access to reddit's API."
),
extras_require=extras,
install_requires=["aiofiles", "asyncprawcore >=2.1, <3", "update_checker >=0.18"],
install_requires=[
"aiofiles <=0.6.0",
"aiosqlite <=0.17.0",
"asyncio_extras <=1.3.2",
"asyncprawcore >=2.1, <3",
"update_checker >=0.18",
],
keywords="reddit api wrapper asyncpraw praw async asynchronous",
license="Simplified BSD License",
long_description=README,
Expand Down
73 changes: 72 additions & 1 deletion tests/unit/util/test_token_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""Test asyncpraw.util.refresh_token_manager."""
import sys
from tempfile import NamedTemporaryFile

import aiofiles
import pytest
from asynctest import mock

from asyncpraw.util.token_manager import BaseTokenManager, FileTokenManager
from asyncpraw.util.token_manager import (
BaseTokenManager,
FileTokenManager,
SQLiteTokenManager,
)

from .. import UnitTest
from ..test_reddit import DummyTokenManager
Expand Down Expand Up @@ -89,3 +96,67 @@ async def test_pre_refresh_token_callback__reads_from_file(self):
closefd=True,
opener=None,
)


class TestSQLiteTokenManager(UnitTest):
def setUp(self):
self.manager = SQLiteTokenManager(":memory:", "dummy_key")

async def test_is_registered(self):
assert not await self.manager.is_registered()
await self.manager.close()

@pytest.mark.skipif(
sys.platform.startswith("win"), reason="this test fails on windows"
)
async def test_multiple_instances(self):
with NamedTemporaryFile() as fp:
manager1 = SQLiteTokenManager(fp.name, "dummy_key1")
manager2 = SQLiteTokenManager(fp.name, "dummy_key1")
manager3 = SQLiteTokenManager(fp.name, "dummy_key2")

await manager1.register("dummy_value1")
assert await manager2.is_registered()
assert not await manager3.is_registered()
await manager1.close()
await manager2.close()
await manager3.close()

async def test_post_refresh_token_callback__sets_value(self):
authorizer = DummyAuthorizer("dummy_value")

await self.manager.post_refresh_callback(authorizer)
assert authorizer.refresh_token is None
assert await self.manager._get() == "dummy_value"
await self.manager.close()

async def test_post_refresh_token_callback__updates_value(self):
authorizer = DummyAuthorizer("dummy_value_updated")
await self.manager.register("dummy_value")

await self.manager.post_refresh_callback(authorizer)
assert authorizer.refresh_token is None
assert await self.manager._get() == "dummy_value_updated"
await self.manager.close()

async def test_pre_refresh_token_callback(self):
authorizer = DummyAuthorizer(None)
await self.manager.register("dummy_value")

await self.manager.pre_refresh_callback(authorizer)
assert authorizer.refresh_token == "dummy_value"
await self.manager.close()

async def test_pre_refresh_token_callback__raises_key_error(self):
authorizer = DummyAuthorizer(None)

with pytest.raises(KeyError):
await self.manager.pre_refresh_callback(authorizer)
await self.manager.close()

async def test_register(self):
assert await self.manager.register("dummy_value")
assert await self.manager.is_registered()
assert not await self.manager.register("dummy_value2")
assert await self.manager._get() == "dummy_value"
await self.manager.close()

0 comments on commit 8fb5fb8

Please sign in to comment.