From 4a788c754b614d13b2a58dd6453bc24407540a9a Mon Sep 17 00:00:00 2001 From: Bryce Boe Date: Wed, 24 Feb 2021 00:06:45 -0600 Subject: [PATCH 1/3] Add SQLite3TokenManager and associated example (cherry picked from commit bb0e0070a89b2e801e691a7161074d83378bd81a) --- CHANGES.rst | 1 + asyncpraw/util/token_manager.py | 129 +++++++++++++++++++++- docs/examples/use_sqlite_token_manager.py | 71 ++++++++++++ docs/tutorials/refresh_token.rst | 10 ++ setup.py | 8 +- tests/unit/util/test_token_manager.py | 69 +++++++++++- 6 files changed, 284 insertions(+), 4 deletions(-) create mode 100755 docs/examples/use_sqlite_token_manager.py diff --git a/CHANGES.rst b/CHANGES.rst index 6a1c4ce4..633f0ab5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,7 @@ Unreleased - :meth:`.trusted` to retrieve a :class:`.RedditorList` of trusted users. - :meth:`.trust` to add a user to the trusted list. - :meth:`.distrust` to remove a user from the trusted list. +- :class:`.SQLiteTokenManager` (may not work on Windows) **Changed** diff --git a/asyncpraw/util/token_manager.py b/asyncpraw/util/token_manager.py index 614a253c..c273bd44 100644 --- a/asyncpraw/util/token_manager.py +++ b/asyncpraw/util/token_manager.py @@ -6,12 +6,14 @@ A few trivial token manager classes are provided here, but it is expected that Async PRAW users will create their own token manager classes suitable for their needs. -See ref:`using_refresh_tokens` for examples on how to leverage these classes. +See :ref:`using_refresh_tokens` for examples on how to leverage these classes. """ from abc import ABC, abstractmethod import aiofiles +import aiosqlite +from asyncio_extras import async_contextmanager class BaseTokenManager(ABC): @@ -60,7 +62,19 @@ def pre_refresh_callback(self, authorizer): class FileTokenManager(BaseTokenManager): - """Provides a trivial single-file based token manager.""" + """Provides a trivial single-file based token manager. + + It is expected that the file with the initial ``refresh_token`` is created prior to + use. + + .. warning:: + + The same ``file`` should not be used by more than one instance of this class + concurrently. Doing so may result in data corruption. Consider using + :class:`.SQLiteTokenManager` if you want more than one instance of PRAW to + concurrently manage a specific ``refresh_token`` chain. + + """ def __init__(self, filename): """Load and save refresh tokens from a file. @@ -81,3 +95,114 @@ async def pre_refresh_callback(self, authorizer): if authorizer.refresh_token is None: async with aiofiles.open(self._filename) as fp: authorizer.refresh_token = (await fp.read()).strip() + + +class SQLiteTokenManager(BaseTokenManager): + """Provides a SQLite3 based token manager. + + Unlike, :class:`.FileTokenManager`, the initial database need not be created ahead + of time, as it'll automatically be created on first use. However, initial + ``refresh_tokens`` will need to be registered via :meth:`.register` prior to use. + See :ref:`sqlite_token_manager` for an example of use. + + """ + + def __init__(self, database, key): + """Load and save refresh tokens from a SQLite database. + + :param database: The path to the SQLite database. + :param key: The key used to locate the ``refresh_token``. This ``key`` can be + anything. You might use the ``client_id`` if you expect to have unique + ``refresh_tokens`` for each ``client_id``, or you might use a Redditor's + ``username`` if you're manage multiple users' authentications. + + """ + super().__init__() + self._connection = None + self._database = database + self._setup_ran = False + self.key = key + + @async_contextmanager + async def connection(self): + """Asynchronously setup and provide the sqlite3 connection.""" + if self._connection is None: + self._connection = await aiosqlite.connect(self._database) + if not self._setup_ran: + await self._connection.execute( + "CREATE TABLE IF NOT EXISTS tokens (id, refresh_token, updated_at)" + ) + await self._connection.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ux_tokens_id on tokens(id)" + ) + await self._connection.commit() + self._setup_ran = True + yield self._connection + + async def _get(self): + async with self.connection() as conn: + cursor = await conn.execute( + "SELECT refresh_token FROM tokens WHERE id=?", (self.key,) + ) + result = await cursor.fetchone() + if result is None: + raise KeyError + return result[0] + + async def _set(self, refresh_token): + """Set the refresh token in the database. + + This function will overwrite an existing value if the corresponding ``key`` + already exists. + + """ + async with self.connection() as conn: + await conn.execute( + "REPLACE INTO tokens VALUES (?, ?, datetime('now'))", + (self.key, refresh_token), + ) + await conn.commit() + + async def close(self): + """Close the sqlite3 connection.""" + await self._connection.close() + + async def is_registered(self): + """Return whether ``key`` already has a ``refresh_token``.""" + async with self.connection() as conn: + cursor = await conn.execute( + "SELECT refresh_token FROM tokens WHERE id=?", (self.key,) + ) + result = await cursor.fetchone() + return result is not None + + async def post_refresh_callback(self, authorizer): + """Update the refresh token in the database.""" + await self._set(authorizer.refresh_token) + + # While the following line is not strictly necessary, it ensures that the + # refresh token is not used elsewhere. And also forces the pre_refresh_callback + # to always load the latest refresh_token from the database. + authorizer.refresh_token = None + + async def pre_refresh_callback(self, authorizer): + """Load the refresh token from the database.""" + assert authorizer.refresh_token is None + authorizer.refresh_token = await self._get() + + async def register(self, refresh_token): + """Register the initial refresh token in the database. + + :returns: ``True`` if ``refresh_token`` is saved to the database, otherwise, + ``False`` if there is already a ``refresh_token`` for the associated + ``key``. + + """ + async with self.connection() as conn: + cursor = await conn.execute( + "INSERT OR IGNORE INTO tokens VALUES (?, ?, datetime('now'))", + (self.key, refresh_token), + ) + await conn.commit() + row_count = cursor.rowcount + return row_count == 1 diff --git a/docs/examples/use_sqlite_token_manager.py b/docs/examples/use_sqlite_token_manager.py new file mode 100755 index 00000000..10cbd8c8 --- /dev/null +++ b/docs/examples/use_sqlite_token_manager.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""This example demonstrates using the sqlite token manager for refresh tokens. + +In order to run this program, you will first need to obtain one or more valid refresh +tokens. You can use the ``obtain_refresh_token.py`` example to help. + +In this example, refresh tokens will be saved into a file ``tokens.sqlite3`` relative to +your current working directory. If your current working directory is under version +control it is strongly encouraged you add ``tokens.sqlite3`` to the version control ignore +list. + +This example differs primarily from ``use_file_token_manager.py`` due to the fact that a +shared SQLite3 database can manage many ``refresh_tokens``. While each instance of +Reddit still needs to have 1-to-1 mapping to a token manager, multiple Reddit instances +can concurrently share access to the same SQLite3 database; the same cannot be done with +the FileTokenManager. + +Usage: + + EXPORT praw_client_id= + EXPORT praw_client_secret= + python3 use_sqlite_token_manager.py TOKEN_KEY + +""" +import os +import sys + +import praw +from praw.util.token_manager import SQLiteTokenManager + +DATABASE_PATH = "tokens.sqlite3" + + +def main(): + if "praw_client_id" not in os.environ: + sys.stderr.write("Environment variable ``praw_client_id`` must be defined\n") + return 1 + if "praw_client_secret" not in os.environ: + sys.stderr.write( + "Environment variable ``praw_client_secret`` must be defined\n" + ) + return 1 + if len(sys.argv) != 2: + sys.stderr.write( + "KEY must be provided.\n\nUsage: python3 use_sqlite_token_manager.py TOKEN_KEY\n" + ) + return 1 + + refresh_token_manager = SQLiteTokenManager(DATABASE_PATH, key=sys.argv[1]) + reddit = praw.Reddit( + token_manager=refresh_token_manager, + user_agent="sqlite_token_manager/v0 by u/bboe", + ) + + if not refresh_token_manager.is_registered(): + refresh_token = input("Enter initial refresh token: ").strip() + refresh_token_manager.register(refresh_token) + + scopes = reddit.auth.scopes() + if scopes == {"*"}: + print(f"{reddit.user.me()} is authenticated with all scopes") + elif "identity" in scopes: + print( + f"{reddit.user.me()} is authenticated with the following scopes: {scopes}" + ) + else: + print(f"You are authenticated with the following scopes: {scopes}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/docs/tutorials/refresh_token.rst b/docs/tutorials/refresh_token.rst index c6ddbb1b..de921b57 100644 --- a/docs/tutorials/refresh_token.rst +++ b/docs/tutorials/refresh_token.rst @@ -102,3 +102,13 @@ valid refresh token. .. literalinclude:: ../examples/use_file_token_manager.py :language: python + +.. _sqlite_token_manager: + +SQLiteTokenManager +~~~~~~~~~~~~~~~~~~ + +For more complex examples, PRAW provides the :class:`.SQLiteTokenManager`. + +.. literalinclude:: ../examples/use_sqlite_token_manager.py + :language: python diff --git a/setup.py b/setup.py index 4ca8f50e..83bf7a6d 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,13 @@ " python package that allows for simple access to reddit's API." ), extras_require=extras, - install_requires=["aiofiles", "asyncprawcore >=2.1, <3", "update_checker >=0.18"], + install_requires=[ + "aiofiles <=0.6.0", + "aiosqlite <=0.17.0", + "asyncio_extras <=1.3.2", + "asyncprawcore >=2.1, <3", + "update_checker >=0.18", + ], keywords="reddit api wrapper asyncpraw praw async asynchronous", license="Simplified BSD License", long_description=README, diff --git a/tests/unit/util/test_token_manager.py b/tests/unit/util/test_token_manager.py index 29586102..18b843f3 100644 --- a/tests/unit/util/test_token_manager.py +++ b/tests/unit/util/test_token_manager.py @@ -1,9 +1,15 @@ """Test asyncpraw.util.refresh_token_manager.""" +from tempfile import NamedTemporaryFile + import aiofiles import pytest from asynctest import mock -from asyncpraw.util.token_manager import BaseTokenManager, FileTokenManager +from asyncpraw.util.token_manager import ( + BaseTokenManager, + FileTokenManager, + SQLiteTokenManager, +) from .. import UnitTest from ..test_reddit import DummyTokenManager @@ -89,3 +95,64 @@ async def test_pre_refresh_token_callback__reads_from_file(self): closefd=True, opener=None, ) + + +class TestSQLiteTokenManager(UnitTest): + def setUp(self): + self.manager = SQLiteTokenManager(":memory:", "dummy_key") + + async def test_is_registered(self): + assert not await self.manager.is_registered() + await self.manager.close() + + async def test_multiple_instances(self): + with NamedTemporaryFile() as fp: + manager1 = SQLiteTokenManager(fp.name, "dummy_key1") + manager2 = SQLiteTokenManager(fp.name, "dummy_key1") + manager3 = SQLiteTokenManager(fp.name, "dummy_key2") + + await manager1.register("dummy_value1") + assert await manager2.is_registered() + assert not await manager3.is_registered() + await manager1.close() + await manager2.close() + await manager3.close() + + async def test_post_refresh_token_callback__sets_value(self): + authorizer = DummyAuthorizer("dummy_value") + + await self.manager.post_refresh_callback(authorizer) + assert authorizer.refresh_token is None + assert await self.manager._get() == "dummy_value" + await self.manager.close() + + async def test_post_refresh_token_callback__updates_value(self): + authorizer = DummyAuthorizer("dummy_value_updated") + await self.manager.register("dummy_value") + + await self.manager.post_refresh_callback(authorizer) + assert authorizer.refresh_token is None + assert await self.manager._get() == "dummy_value_updated" + await self.manager.close() + + async def test_pre_refresh_token_callback(self): + authorizer = DummyAuthorizer(None) + await self.manager.register("dummy_value") + + await self.manager.pre_refresh_callback(authorizer) + assert authorizer.refresh_token == "dummy_value" + await self.manager.close() + + async def test_pre_refresh_token_callback__raises_key_error(self): + authorizer = DummyAuthorizer(None) + + with pytest.raises(KeyError): + await self.manager.pre_refresh_callback(authorizer) + await self.manager.close() + + async def test_register(self): + assert await self.manager.register("dummy_value") + assert await self.manager.is_registered() + assert not await self.manager.register("dummy_value2") + assert await self.manager._get() == "dummy_value" + await self.manager.close() From bf7dd6d6b3876269e53c90998173ec173519165c Mon Sep 17 00:00:00 2001 From: Bryce Boe Date: Thu, 10 Jun 2021 22:27:24 -0500 Subject: [PATCH 2/3] Provide a warning for SQLiteTokenManager not being tested on Windows (cherry picked from commit bb0e00f95a40bc13db46a90ee6bb621af687197c) --- asyncpraw/util/token_manager.py | 5 +++++ tests/unit/util/test_token_manager.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/asyncpraw/util/token_manager.py b/asyncpraw/util/token_manager.py index c273bd44..467ad3b8 100644 --- a/asyncpraw/util/token_manager.py +++ b/asyncpraw/util/token_manager.py @@ -105,6 +105,11 @@ class SQLiteTokenManager(BaseTokenManager): ``refresh_tokens`` will need to be registered via :meth:`.register` prior to use. See :ref:`sqlite_token_manager` for an example of use. + .. warning:: + + This class is untested on Windows because we encountered file locking issues in + the test environment. + """ def __init__(self, database, key): diff --git a/tests/unit/util/test_token_manager.py b/tests/unit/util/test_token_manager.py index 18b843f3..cf48f710 100644 --- a/tests/unit/util/test_token_manager.py +++ b/tests/unit/util/test_token_manager.py @@ -1,4 +1,5 @@ """Test asyncpraw.util.refresh_token_manager.""" +import sys from tempfile import NamedTemporaryFile import aiofiles @@ -105,6 +106,9 @@ async def test_is_registered(self): assert not await self.manager.is_registered() await self.manager.close() + @pytest.mark.skipif( + sys.platform.startswith("win"), reason="this test fails on windows" + ) async def test_multiple_instances(self): with NamedTemporaryFile() as fp: manager1 = SQLiteTokenManager(fp.name, "dummy_key1") From 874bf138746043e9f5531b37b6caf854d17c6a94 Mon Sep 17 00:00:00 2001 From: Bryce Boe Date: Sun, 13 Jun 2021 21:04:45 -0500 Subject: [PATCH 3/3] Replace 'trivial' with 'proof of concept' (cherry picked from commit bb0e00025f4cf10be64469d8998cc6ae1d81ca8c) --- asyncpraw/util/token_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncpraw/util/token_manager.py b/asyncpraw/util/token_manager.py index 467ad3b8..299a7134 100644 --- a/asyncpraw/util/token_manager.py +++ b/asyncpraw/util/token_manager.py @@ -3,8 +3,8 @@ There should be a 1-to-1 mapping between an instance of a subclass of :class:`.BaseTokenManager` and a :class:`.Reddit` instance. -A few trivial token manager classes are provided here, but it is expected that Async -PRAW users will create their own token manager classes suitable for their needs. +A few proof of concept token manager classes are provided here, but it is expected that +Async PRAW users will create their own token manager classes suitable for their needs. See :ref:`using_refresh_tokens` for examples on how to leverage these classes. @@ -62,7 +62,7 @@ def pre_refresh_callback(self, authorizer): class FileTokenManager(BaseTokenManager): - """Provides a trivial single-file based token manager. + """Provides a single-file based token manager. It is expected that the file with the initial ``refresh_token`` is created prior to use.