Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQLite3TokenManager and associated example #1692

Merged
merged 3 commits into from
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Unreleased
- :meth:`.trusted` to retrieve a :class:`.RedditorList` of trusted users.
- :meth:`.trust` to add a user to the trusted list.
- :meth:`.distrust` to remove a user from the trusted list.
- :class:`.SQLiteTokenManager` (may not work on Windows)

**Changed**

Expand Down
71 changes: 71 additions & 0 deletions docs/examples/use_sqlite_token_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
"""This example demonstrates using the sqlite token manager for refresh tokens.

In order to run this program, you will first need to obtain one or more valid refresh
tokens. You can use the ``obtain_refresh_token.py`` example to help.

In this example, refresh tokens will be saved into a file ``tokens.sqlite3`` relative to
your current working directory. If your current working directory is under version
control it is strongly encouraged you add ``tokens.sqlite3`` to the version control ignore
list.

This example differs primarily from ``use_file_token_manager.py`` due to the fact that a
shared SQLite3 database can manage many ``refresh_tokens``. While each instance of
Reddit still needs to have 1-to-1 mapping to a token manager, multiple Reddit instances
can concurrently share access to the same SQLite3 database; the same cannot be done with
the FileTokenManager.

Usage:

EXPORT praw_client_id=<REDDIT_CLIENT_ID>
EXPORT praw_client_secret=<REDDIT_CLIENT_SECRET>
python3 use_sqlite_token_manager.py TOKEN_KEY

"""
import os
import sys

import praw
from praw.util.token_manager import SQLiteTokenManager

DATABASE_PATH = "tokens.sqlite3"


def main():
if "praw_client_id" not in os.environ:
sys.stderr.write("Environment variable ``praw_client_id`` must be defined\n")
return 1
if "praw_client_secret" not in os.environ:
sys.stderr.write(
"Environment variable ``praw_client_secret`` must be defined\n"
)
return 1
if len(sys.argv) != 2:
sys.stderr.write(
"KEY must be provided.\n\nUsage: python3 use_sqlite_token_manager.py TOKEN_KEY\n"
)
return 1

refresh_token_manager = SQLiteTokenManager(DATABASE_PATH, key=sys.argv[1])
reddit = praw.Reddit(
token_manager=refresh_token_manager,
user_agent="sqlite_token_manager/v0 by u/bboe",
)

if not refresh_token_manager.is_registered():
refresh_token = input("Enter initial refresh token: ").strip()
refresh_token_manager.register(refresh_token)

scopes = reddit.auth.scopes()
if scopes == {"*"}:
print(f"{reddit.user.me()} is authenticated with all scopes")
elif "identity" in scopes:
print(
f"{reddit.user.me()} is authenticated with the following scopes: {scopes}"
)
else:
print(f"You are authenticated with the following scopes: {scopes}")


if __name__ == "__main__":
sys.exit(main())
10 changes: 10 additions & 0 deletions docs/tutorials/refresh_token.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,13 @@ valid refresh token.

.. literalinclude:: ../examples/use_file_token_manager.py
:language: python

.. _sqlite_token_manager:

SQLiteTokenManager
~~~~~~~~~~~~~~~~~~

For more complex examples, PRAW provides the :class:`.SQLiteTokenManager`.

.. literalinclude:: ../examples/use_sqlite_token_manager.py
:language: python
117 changes: 112 additions & 5 deletions praw/util/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
There should be a 1-to-1 mapping between an instance of a subclass of
:class:`.BaseTokenManager` and a :class:`.Reddit` instance.

A few trivial token manager classes are provided here, but it is expected that PRAW
users will create their own token manager classes suitable for their needs.
A few proof of concept token manager classes are provided here, but it is expected that
PRAW users will create their own token manager classes suitable for their needs.

See ref:`using_refresh_tokens` for examples on how to leverage these classes.
See :ref:`using_refresh_tokens` for examples on how to leverage these classes.

"""

import sqlite3
from abc import ABC, abstractmethod


Expand Down Expand Up @@ -59,7 +59,19 @@ def pre_refresh_callback(self, authorizer):


class FileTokenManager(BaseTokenManager):
"""Provides a trivial single-file based token manager."""
"""Provides a single-file based token manager.

It is expected that the file with the initial ``refresh_token`` is created prior to
use.

.. warning::

The same ``file`` should not be used by more than one instance of this class
concurrently. Doing so may result in data corruption. Consider using
:class:`.SQLiteTokenManager` if you want more than one instance of PRAW to
concurrently manage a specific ``refresh_token`` chain.

"""

def __init__(self, filename):
"""Load and save refresh tokens from a file.
Expand All @@ -80,3 +92,98 @@ def pre_refresh_callback(self, authorizer):
if authorizer.refresh_token is None:
with open(self._filename) as fp:
authorizer.refresh_token = fp.read().strip()


class SQLiteTokenManager(BaseTokenManager):
"""Provides a SQLite3 based token manager.

Unlike, :class:`.FileTokenManager`, the initial database need not be created ahead
of time, as it'll automatically be created on first use. However, initial
``refresh_tokens`` will need to be registered via :meth:`.register` prior to use.
See :ref:`sqlite_token_manager` for an example of use.

.. warning::

This class is untested on Windows because we encountered file locking issues in
the test environment.

"""

def __init__(self, database, key):
"""Load and save refresh tokens from a SQLite database.

:param database: The path to the SQLite database.
:param key: The key used to locate the ``refresh_token``. This ``key`` can be
anything. You might use the ``client_id`` if you expect to have unique
``refresh_tokens`` for each ``client_id``, or you might use a Redditor's
``username`` if you're manage multiple users' authentications.

"""
super().__init__()
self._connection = sqlite3.connect(database)
self._connection.execute(
"CREATE TABLE IF NOT EXISTS tokens (id, refresh_token, updated_at)"
)
self._connection.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ux_tokens_id on tokens(id)"
)
self._connection.commit()
self.key = key

def _get(self):
cursor = self._connection.execute(
"SELECT refresh_token FROM tokens WHERE id=?", (self.key,)
)
result = cursor.fetchone()
if result is None:
raise KeyError
return result[0]

def _set(self, refresh_token):
"""Set the refresh token in the database.

This function will overwrite an existing value if the corresponding ``key``
already exists.

"""
self._connection.execute(
"REPLACE INTO tokens VALUES (?, ?, datetime('now'))",
(self.key, refresh_token),
)
self._connection.commit()

def is_registered(self):
"""Return whether or not ``key`` already has a ``refresh_token``."""
cursor = self._connection.execute(
"SELECT refresh_token FROM tokens WHERE id=?", (self.key,)
)
return cursor.fetchone() is not None

def post_refresh_callback(self, authorizer):
"""Update the refresh token in the database."""
self._set(authorizer.refresh_token)

# While the following line is not strictly necessary, it ensures that the
# refresh token is not used elsewhere. And also forces the pre_refresh_callback
# to always load the latest refresh_token from the database.
authorizer.refresh_token = None

def pre_refresh_callback(self, authorizer):
"""Load the refresh token from the database."""
assert authorizer.refresh_token is None
authorizer.refresh_token = self._get()

def register(self, refresh_token):
"""Register the initial refresh token in the database.

:returns: ``True`` if ``refresh_token`` is saved to the database, otherwise,
``False`` if there is already a ``refresh_token`` for the associated
``key``.

"""
cursor = self._connection.execute(
"INSERT OR IGNORE INTO tokens VALUES (?, ?, datetime('now'))",
(self.key, refresh_token),
)
self._connection.commit()
return cursor.rowcount == 1
66 changes: 65 additions & 1 deletion tests/unit/util/test_token_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Test praw.util.refresh_token_manager."""
import sys
from tempfile import NamedTemporaryFile
from unittest import mock

import pytest

from praw.util.token_manager import BaseTokenManager, FileTokenManager
from praw.util.token_manager import (
BaseTokenManager,
FileTokenManager,
SQLiteTokenManager,
)

from .. import UnitTest
from ..test_reddit import DummyTokenManager
Expand Down Expand Up @@ -58,3 +64,61 @@ def test_pre_refresh_token_callback__reads_from_file(self):

assert authorizer.refresh_token == "token_value"
mock_open.assert_called_once_with("mock/dummy_path")


class TestSQLiteTokenManager(UnitTest):
def test_is_registered(self):
manager = SQLiteTokenManager(":memory:", "dummy_key")
assert not manager.is_registered()

@pytest.mark.skipif(
sys.platform.startswith("win"), reason="this test fails on windows"
)
def test_multiple_instances(self):
with NamedTemporaryFile() as fp:
manager1 = SQLiteTokenManager(fp.name, "dummy_key1")
manager2 = SQLiteTokenManager(fp.name, "dummy_key1")
manager3 = SQLiteTokenManager(fp.name, "dummy_key2")

manager1.register("dummy_value1")
assert manager2.is_registered()
assert not manager3.is_registered()

def test_post_refresh_token_callback__sets_value(self):
authorizer = DummyAuthorizer("dummy_value")
manager = SQLiteTokenManager(":memory:", "dummy_key")

manager.post_refresh_callback(authorizer)
assert authorizer.refresh_token is None
assert manager._get() == "dummy_value"

def test_post_refresh_token_callback__updates_value(self):
authorizer = DummyAuthorizer("dummy_value_updated")
manager = SQLiteTokenManager(":memory:", "dummy_key")
manager.register("dummy_value")

manager.post_refresh_callback(authorizer)
assert authorizer.refresh_token is None
assert manager._get() == "dummy_value_updated"

def test_pre_refresh_token_callback(self):
authorizer = DummyAuthorizer(None)
manager = SQLiteTokenManager(":memory:", "dummy_key")
manager.register("dummy_value")

manager.pre_refresh_callback(authorizer)
assert authorizer.refresh_token == "dummy_value"

def test_pre_refresh_token_callback__raises_key_error(self):
authorizer = DummyAuthorizer(None)
manager = SQLiteTokenManager(":memory:", "dummy_key")

with pytest.raises(KeyError):
manager.pre_refresh_callback(authorizer)

def test_register(self):
manager = SQLiteTokenManager(":memory:", "dummy_key")
assert manager.register("dummy_value")
assert manager.is_registered()
assert not manager.register("dummy_value2")
assert manager._get() == "dummy_value"