diff --git a/CHANGES.rst b/CHANGES.rst index 97381a7f0..99aa39624 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,7 @@ Unreleased - :meth:`.trusted` to retrieve a :class:`.RedditorList` of trusted users. - :meth:`.trust` to add a user to the trusted list. - :meth:`.distrust` to remove a user from the trusted list. +- :class:`.SQLiteTokenManager` (may not work on Windows) **Changed** diff --git a/docs/examples/use_sqlite_token_manager.py b/docs/examples/use_sqlite_token_manager.py new file mode 100755 index 000000000..10cbd8c8b --- /dev/null +++ b/docs/examples/use_sqlite_token_manager.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""This example demonstrates using the sqlite token manager for refresh tokens. + +In order to run this program, you will first need to obtain one or more valid refresh +tokens. You can use the ``obtain_refresh_token.py`` example to help. + +In this example, refresh tokens will be saved into a file ``tokens.sqlite3`` relative to +your current working directory. If your current working directory is under version +control it is strongly encouraged you add ``tokens.sqlite3`` to the version control ignore +list. + +This example differs primarily from ``use_file_token_manager.py`` due to the fact that a +shared SQLite3 database can manage many ``refresh_tokens``. While each instance of +Reddit still needs to have 1-to-1 mapping to a token manager, multiple Reddit instances +can concurrently share access to the same SQLite3 database; the same cannot be done with +the FileTokenManager. + +Usage: + + EXPORT praw_client_id= + EXPORT praw_client_secret= + python3 use_sqlite_token_manager.py TOKEN_KEY + +""" +import os +import sys + +import praw +from praw.util.token_manager import SQLiteTokenManager + +DATABASE_PATH = "tokens.sqlite3" + + +def main(): + if "praw_client_id" not in os.environ: + sys.stderr.write("Environment variable ``praw_client_id`` must be defined\n") + return 1 + if "praw_client_secret" not in os.environ: + sys.stderr.write( + "Environment variable ``praw_client_secret`` must be defined\n" + ) + return 1 + if len(sys.argv) != 2: + sys.stderr.write( + "KEY must be provided.\n\nUsage: python3 use_sqlite_token_manager.py TOKEN_KEY\n" + ) + return 1 + + refresh_token_manager = SQLiteTokenManager(DATABASE_PATH, key=sys.argv[1]) + reddit = praw.Reddit( + token_manager=refresh_token_manager, + user_agent="sqlite_token_manager/v0 by u/bboe", + ) + + if not refresh_token_manager.is_registered(): + refresh_token = input("Enter initial refresh token: ").strip() + refresh_token_manager.register(refresh_token) + + scopes = reddit.auth.scopes() + if scopes == {"*"}: + print(f"{reddit.user.me()} is authenticated with all scopes") + elif "identity" in scopes: + print( + f"{reddit.user.me()} is authenticated with the following scopes: {scopes}" + ) + else: + print(f"You are authenticated with the following scopes: {scopes}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/docs/tutorials/refresh_token.rst b/docs/tutorials/refresh_token.rst index c6ddbb1b2..de921b574 100644 --- a/docs/tutorials/refresh_token.rst +++ b/docs/tutorials/refresh_token.rst @@ -102,3 +102,13 @@ valid refresh token. .. literalinclude:: ../examples/use_file_token_manager.py :language: python + +.. _sqlite_token_manager: + +SQLiteTokenManager +~~~~~~~~~~~~~~~~~~ + +For more complex examples, PRAW provides the :class:`.SQLiteTokenManager`. + +.. literalinclude:: ../examples/use_sqlite_token_manager.py + :language: python diff --git a/praw/util/token_manager.py b/praw/util/token_manager.py index c89911470..af4c3231b 100644 --- a/praw/util/token_manager.py +++ b/praw/util/token_manager.py @@ -3,13 +3,13 @@ There should be a 1-to-1 mapping between an instance of a subclass of :class:`.BaseTokenManager` and a :class:`.Reddit` instance. -A few trivial token manager classes are provided here, but it is expected that PRAW -users will create their own token manager classes suitable for their needs. +A few proof of concept token manager classes are provided here, but it is expected that +PRAW users will create their own token manager classes suitable for their needs. -See ref:`using_refresh_tokens` for examples on how to leverage these classes. +See :ref:`using_refresh_tokens` for examples on how to leverage these classes. """ - +import sqlite3 from abc import ABC, abstractmethod @@ -59,7 +59,19 @@ def pre_refresh_callback(self, authorizer): class FileTokenManager(BaseTokenManager): - """Provides a trivial single-file based token manager.""" + """Provides a single-file based token manager. + + It is expected that the file with the initial ``refresh_token`` is created prior to + use. + + .. warning:: + + The same ``file`` should not be used by more than one instance of this class + concurrently. Doing so may result in data corruption. Consider using + :class:`.SQLiteTokenManager` if you want more than one instance of PRAW to + concurrently manage a specific ``refresh_token`` chain. + + """ def __init__(self, filename): """Load and save refresh tokens from a file. @@ -80,3 +92,98 @@ def pre_refresh_callback(self, authorizer): if authorizer.refresh_token is None: with open(self._filename) as fp: authorizer.refresh_token = fp.read().strip() + + +class SQLiteTokenManager(BaseTokenManager): + """Provides a SQLite3 based token manager. + + Unlike, :class:`.FileTokenManager`, the initial database need not be created ahead + of time, as it'll automatically be created on first use. However, initial + ``refresh_tokens`` will need to be registered via :meth:`.register` prior to use. + See :ref:`sqlite_token_manager` for an example of use. + + .. warning:: + + This class is untested on Windows because we encountered file locking issues in + the test environment. + + """ + + def __init__(self, database, key): + """Load and save refresh tokens from a SQLite database. + + :param database: The path to the SQLite database. + :param key: The key used to locate the ``refresh_token``. This ``key`` can be + anything. You might use the ``client_id`` if you expect to have unique + ``refresh_tokens`` for each ``client_id``, or you might use a Redditor's + ``username`` if you're manage multiple users' authentications. + + """ + super().__init__() + self._connection = sqlite3.connect(database) + self._connection.execute( + "CREATE TABLE IF NOT EXISTS tokens (id, refresh_token, updated_at)" + ) + self._connection.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ux_tokens_id on tokens(id)" + ) + self._connection.commit() + self.key = key + + def _get(self): + cursor = self._connection.execute( + "SELECT refresh_token FROM tokens WHERE id=?", (self.key,) + ) + result = cursor.fetchone() + if result is None: + raise KeyError + return result[0] + + def _set(self, refresh_token): + """Set the refresh token in the database. + + This function will overwrite an existing value if the corresponding ``key`` + already exists. + + """ + self._connection.execute( + "REPLACE INTO tokens VALUES (?, ?, datetime('now'))", + (self.key, refresh_token), + ) + self._connection.commit() + + def is_registered(self): + """Return whether or not ``key`` already has a ``refresh_token``.""" + cursor = self._connection.execute( + "SELECT refresh_token FROM tokens WHERE id=?", (self.key,) + ) + return cursor.fetchone() is not None + + def post_refresh_callback(self, authorizer): + """Update the refresh token in the database.""" + self._set(authorizer.refresh_token) + + # While the following line is not strictly necessary, it ensures that the + # refresh token is not used elsewhere. And also forces the pre_refresh_callback + # to always load the latest refresh_token from the database. + authorizer.refresh_token = None + + def pre_refresh_callback(self, authorizer): + """Load the refresh token from the database.""" + assert authorizer.refresh_token is None + authorizer.refresh_token = self._get() + + def register(self, refresh_token): + """Register the initial refresh token in the database. + + :returns: ``True`` if ``refresh_token`` is saved to the database, otherwise, + ``False`` if there is already a ``refresh_token`` for the associated + ``key``. + + """ + cursor = self._connection.execute( + "INSERT OR IGNORE INTO tokens VALUES (?, ?, datetime('now'))", + (self.key, refresh_token), + ) + self._connection.commit() + return cursor.rowcount == 1 diff --git a/tests/unit/util/test_token_manager.py b/tests/unit/util/test_token_manager.py index 60c09a764..9c42b220f 100644 --- a/tests/unit/util/test_token_manager.py +++ b/tests/unit/util/test_token_manager.py @@ -1,9 +1,15 @@ """Test praw.util.refresh_token_manager.""" +import sys +from tempfile import NamedTemporaryFile from unittest import mock import pytest -from praw.util.token_manager import BaseTokenManager, FileTokenManager +from praw.util.token_manager import ( + BaseTokenManager, + FileTokenManager, + SQLiteTokenManager, +) from .. import UnitTest from ..test_reddit import DummyTokenManager @@ -58,3 +64,61 @@ def test_pre_refresh_token_callback__reads_from_file(self): assert authorizer.refresh_token == "token_value" mock_open.assert_called_once_with("mock/dummy_path") + + +class TestSQLiteTokenManager(UnitTest): + def test_is_registered(self): + manager = SQLiteTokenManager(":memory:", "dummy_key") + assert not manager.is_registered() + + @pytest.mark.skipif( + sys.platform.startswith("win"), reason="this test fails on windows" + ) + def test_multiple_instances(self): + with NamedTemporaryFile() as fp: + manager1 = SQLiteTokenManager(fp.name, "dummy_key1") + manager2 = SQLiteTokenManager(fp.name, "dummy_key1") + manager3 = SQLiteTokenManager(fp.name, "dummy_key2") + + manager1.register("dummy_value1") + assert manager2.is_registered() + assert not manager3.is_registered() + + def test_post_refresh_token_callback__sets_value(self): + authorizer = DummyAuthorizer("dummy_value") + manager = SQLiteTokenManager(":memory:", "dummy_key") + + manager.post_refresh_callback(authorizer) + assert authorizer.refresh_token is None + assert manager._get() == "dummy_value" + + def test_post_refresh_token_callback__updates_value(self): + authorizer = DummyAuthorizer("dummy_value_updated") + manager = SQLiteTokenManager(":memory:", "dummy_key") + manager.register("dummy_value") + + manager.post_refresh_callback(authorizer) + assert authorizer.refresh_token is None + assert manager._get() == "dummy_value_updated" + + def test_pre_refresh_token_callback(self): + authorizer = DummyAuthorizer(None) + manager = SQLiteTokenManager(":memory:", "dummy_key") + manager.register("dummy_value") + + manager.pre_refresh_callback(authorizer) + assert authorizer.refresh_token == "dummy_value" + + def test_pre_refresh_token_callback__raises_key_error(self): + authorizer = DummyAuthorizer(None) + manager = SQLiteTokenManager(":memory:", "dummy_key") + + with pytest.raises(KeyError): + manager.pre_refresh_callback(authorizer) + + def test_register(self): + manager = SQLiteTokenManager(":memory:", "dummy_key") + assert manager.register("dummy_value") + assert manager.is_registered() + assert not manager.register("dummy_value2") + assert manager._get() == "dummy_value"