Added the ability to connect to another running instance of reddit that's
somewhere other than reddit.com. See example_settings.py.
bboe committed Dec 4, 2011
1 parent a377f53 commit 84a5d4b
Showing 13 changed files with 122 additions and 61 deletions.
9 changes: 9 additions & 0 deletions example_settings.py
@@ -0,0 +1,9 @@
REDDIT_DOMAIN = 'reddit.local:8888'
WAIT_BETWEEN_CALL_TIME = 0
CACHE_TIMEOUT = 0

OBJECT_KIND_MAPPING = {'Comment': 't1',
'Redditor': 't2',
'Submission': 't6',
'Subreddit': 't5',
'MoreComments': 'more'}
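
example_settings.py targets a local reddit install (reddit.local:8888) with the call delay and caching turned off, and a Submission kind of 't6' instead of the stock 't3'. It is picked up through the REDDIT_CONFIG environment variable handled at the bottom of reddit/settings.py, so the variable must be set before the package is first imported. A minimal sketch, assuming the repository root is the working directory:

import os

# Must be set before `reddit` (and therefore reddit/settings.py) is imported,
# because the override runs at module import time.
os.environ['REDDIT_CONFIG'] = 'example_settings.py'

from reddit import Reddit, settings

print(settings.REDDIT_DOMAIN)  # 'reddit.local:8888', from example_settings.py
r = Reddit(user_agent='local reddit example')  # illustrative user agent
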
13 changes: 7 additions & 6 deletions reddit/__init__.py
@@ -22,13 +22,13 @@
except ImportError:
import simplejson as json

import settings
from base_objects import RedditObject
from comment import Comment, MoreComments
from decorators import require_captcha, require_login, parse_api_json_response
from errors import ClientException
from helpers import _modify_relationship, _request
from redditor import LoggedInRedditor, Redditor
from settings import DEFAULT_CONTENT_LIMIT
from submission import Submission
from subreddit import Subreddit
from urls import urls
@@ -122,7 +122,7 @@ def content_id(self):
"""
return self.user.content_id

def _get_content(self, page_url, limit=DEFAULT_CONTENT_LIMIT,
def _get_content(self, page_url, limit=settings.DEFAULT_CONTENT_LIMIT,
url_data=None, place_holder=None, root_field='data',
thing_field='children', after_field='after'):
"""A generator method to return Reddit content from a URL. Starts at
@@ -258,23 +258,24 @@ def _mark_as_read(self, content_ids):
'uh': self.modhash}
self._request_json(urls["read_message"], params)

def get_front_page(self, limit=DEFAULT_CONTENT_LIMIT):
def get_front_page(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return the reddit front page. Login isn't required, but you'll only
see your own front page if you are logged in."""
return self._get_content(urls["reddit_url"], limit=limit)

@require_login
def get_saved_links(self, limit=DEFAULT_CONTENT_LIMIT):
def get_saved_links(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return a listing of the logged-in user's saved links."""
return self._get_content(urls["saved"], limit=limit)

def get_all_comments(self, limit=DEFAULT_CONTENT_LIMIT, place_holder=None):
def get_all_comments(self, limit=settings.DEFAULT_CONTENT_LIMIT,
place_holder=None):
"""Returns a listing from reddit.com/comments (which provides all of
the most recent comments from all users to all submissions)."""
return self._get_content(urls["comments"], limit=limit,
place_holder=place_holder)

def info(self, url=None, id=None, limit=DEFAULT_CONTENT_LIMIT):
def info(self, url=None, id=None, limit=settings.DEFAULT_CONTENT_LIMIT):
"""
Query the API to see if the given URL has been submitted already, and
if it has, return the submissions.
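
For reference, a small illustrative use of info(), which the hunk above touches; the user_agent string is a placeholder and the imgur URL is the one the test suite uses:

from reddit import Reddit

r = Reddit(user_agent='reddit_api example')  # user_agent is required
# info() checks whether the URL has been submitted and yields the matches.
for submission in r.info(url='http://imgur.com/Vr8ZZ'):
    print(submission.permalink)
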
5 changes: 3 additions & 2 deletions reddit/comment.py
@@ -15,14 +15,15 @@

from urlparse import urljoin

import settings
from base_objects import RedditContentObject
from features import Voteable, Deletable
from util import limit_chars

class Comment(RedditContentObject, Voteable, Deletable):
"""A class for comments."""

kind = "t1"
kind = settings.OBJECT_KIND_MAPPING['Comment']

def __init__(self, reddit_session, json_dict):
super(Comment, self).__init__(reddit_session, None, json_dict)
@@ -68,7 +69,7 @@ def mark_read(self):
class MoreComments(RedditContentObject):
"""A class indicating there are more comments."""

kind = "more"
kind = settings.OBJECT_KIND_MAPPING['MoreComments']

def __init__(self, reddit_session, json_dict):
super(MoreComments, self).__init__(reddit_session, None, json_dict)
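
The kind strings are the prefixes reddit uses in object "fullnames" (a comment with id c3xyz is addressed as t1_c3xyz on reddit.com). A local install may number its types differently (example_settings.py maps Submission to 't6' rather than the stock 't3'), which is why the classes now look the prefix up in settings instead of hard-coding it. A tiny illustration with a hypothetical comment id:

from reddit import settings

comment_id = 'c3xyz'  # hypothetical id
fullname = '%s_%s' % (settings.OBJECT_KIND_MAPPING['Comment'], comment_id)
print(fullname)  # 't1_c3xyz' with the default settings
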
7 changes: 3 additions & 4 deletions reddit/decorators.py
@@ -20,7 +20,7 @@

import errors
import reddit
from settings import WAIT_BETWEEN_CALL_TIME
import settings
from urls import urls


@@ -96,7 +96,6 @@ class sleep_after(object):
delayed until the proper duration is reached.
"""
last_call_time = 0 # init to 0 to always allow the 1st call
WAIT_BETWEEN_CALL_TIME = WAIT_BETWEEN_CALL_TIME

def __init__(self, func):
wraps(func)(self)
@@ -106,8 +105,8 @@ def __call__(self, *args, **kwargs):
call_time = time.time()

since_last_call = call_time - self.last_call_time
if since_last_call < WAIT_BETWEEN_CALL_TIME:
time.sleep(WAIT_BETWEEN_CALL_TIME - since_last_call)
if since_last_call < settings.WAIT_BETWEEN_CALL_TIME:
time.sleep(settings.WAIT_BETWEEN_CALL_TIME - since_last_call)

self.__class__.last_call_time = call_time
return self.func(*args, **kwargs)
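
Switching decorators.py from "from settings import WAIT_BETWEEN_CALL_TIME" to "import settings" keeps every read of the delay going through the settings module, so a value changed on that module after import (for example by a test) is picked up, whereas a from-import freezes whatever value was copied when decorators.py loaded. A minimal, self-contained illustration of that binding difference, using a throwaway stand-in module rather than the real settings:

import types

cfg = types.ModuleType('cfg')           # stand-in for reddit/settings.py
cfg.WAIT_BETWEEN_CALL_TIME = 2

snapshot = cfg.WAIT_BETWEEN_CALL_TIME   # what a from-import binds
cfg.WAIT_BETWEEN_CALL_TIME = 0          # a later change to the module

print(snapshot)                         # 2: the copied value never changes
print(cfg.WAIT_BETWEEN_CALL_TIME)       # 0: attribute lookup sees the change
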
6 changes: 3 additions & 3 deletions reddit/errors.py
@@ -85,6 +85,6 @@ class RateLimitExceeded(APIException):


ERROR_MAPPING = {}
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and hasattr(obj, 'ERROR_TYPE'):
ERROR_MAPPING[obj.ERROR_TYPE] = obj
predicate = lambda x: inspect.isclass(x) and hasattr(x, 'ERROR_TYPE')
for name, obj in inspect.getmembers(sys.modules[__name__], predicate):
ERROR_MAPPING[obj.ERROR_TYPE] = obj
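
Passing a predicate to inspect.getmembers filters the module's members up front, so the loop only ever sees classes that define ERROR_TYPE; the behaviour is the same as the previous explicit if. A self-contained sketch with hypothetical exception classes standing in for the real ones:

import inspect
import sys

class UserRequired(Exception):
    ERROR_TYPE = 'USER_REQUIRED'   # hypothetical error type string

class RateLimitExceeded(Exception):
    ERROR_TYPE = 'RATELIMIT'       # hypothetical error type string

predicate = lambda x: inspect.isclass(x) and hasattr(x, 'ERROR_TYPE')
ERROR_MAPPING = {}
for name, obj in inspect.getmembers(sys.modules[__name__], predicate):
    ERROR_MAPPING[obj.ERROR_TYPE] = obj

print(ERROR_MAPPING)   # {'RATELIMIT': <class ...>, 'USER_REQUIRED': <class ...>}
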
8 changes: 5 additions & 3 deletions reddit/helpers.py
@@ -17,8 +17,8 @@
import urllib2
from urlparse import urljoin

import settings
from decorators import require_login, sleep_after
from settings import DEFAULT_CONTENT_LIMIT
from urls import urls
from util import memoize

@@ -27,7 +27,8 @@ def _get_section(subpath=""):
Used by the Redditor class to generate each of the sections (overview,
comments, submitted).
"""
def get_section(self, sort="new", time="all", limit=DEFAULT_CONTENT_LIMIT,
def get_section(self, sort="new", time="all",
limit=settings.DEFAULT_CONTENT_LIMIT,
place_holder=None):
url_data = {"sort" : sort, "time" : time}
return self.reddit_session._get_content(urljoin(self._url, subpath),
@@ -41,7 +42,8 @@ def _get_sorter(subpath="", **defaults):
Used by the Reddit Page classes to generate each of the currently supported
sorts (hot, top, new, best).
"""
def sorted(self, limit=DEFAULT_CONTENT_LIMIT, place_holder=None, **data):
def sorted(self, limit=settings.DEFAULT_CONTENT_LIMIT,
place_holder=None, **data):
for k, v in defaults.items():
if k == "time":
# time should be "t" in the API data dict
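
_get_section and _get_sorter are function factories: each call returns a new function, a closure over subpath and the sort defaults, which the Redditor and listing classes then assign as a method (redditor.py does get_overview = _get_section("") and get_comments = _get_section("comments")). A stripped-down sketch of the pattern, with a plain string in place of the real reddit_session._get_content call:

def _get_section(subpath=''):
    def get_section(self, sort='new'):
        # the real version builds the URL with urljoin and fetches content
        return '%s%s?sort=%s' % (self._url, subpath, sort)
    return get_section


class Redditor(object):
    _url = 'http://www.reddit.com/user/bboe/'
    get_overview = _get_section('')
    get_comments = _get_section('comments')


print(Redditor().get_comments())  # http://www.reddit.com/user/bboe/comments?sort=new
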
56 changes: 37 additions & 19 deletions reddit/reddit_test.py
@@ -15,8 +15,9 @@
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.

import itertools, unittest, util, uuid, warnings
import itertools, time, unittest, util, uuid, warnings

import settings
from reddit import Reddit, errors
from reddit.comment import Comment, MoreComments
from reddit.redditor import LoggedInRedditor
@@ -37,7 +38,10 @@ def configure(self):
class BasicTest(unittest.TestCase, BasicHelper):
def setUp(self):
self.configure()
self.self = 'http://www.reddit.com/r/programming/comments/bn2wi/'
if settings.REDDIT_DOMAIN == 'www.reddit.com':
self.self = 'http://www.reddit.com/r/programming/comments/bn2wi/'
else:
self.self = 'http://reddit.local:8888/r/bboe/comments/2z/tasdest/'

def test_require_user_agent(self):
self.assertRaises(TypeError, Reddit, user_agent=None)
@@ -50,8 +54,12 @@ def test_not_logged_in_submit(self):
self.sr, 'TITLE', text='BODY')

def test_info_by_known_url_returns_known_id_link_post(self):
url = 'http://imgur.com/Vr8ZZ'
comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
if settings.REDDIT_DOMAIN == 'www.reddit.com':
url = 'http://imgur.com/Vr8ZZ'
comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
else:
url = 'http://google.com/?q=82.1753988563'
comm = 'http://reddit.local:8888/r/reddit_test8/comments/2s/'
found_links = self.r.info(url)
tmp = self.r.get_submission(url=comm)
self.assertTrue(tmp in found_links)
@@ -65,16 +73,23 @@ def test_info_by_self_url_raises_warning(self):
self.assertTrue('self' in str(w[-1].message))

def test_info_by_url_also_found_by_id(self):
url = 'http://imgur.com/Vr8ZZ'
comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
if settings.REDDIT_DOMAIN == 'www.reddit.com':
url = 'http://imgur.com/Vr8ZZ'
comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
else:
url = 'http://google.com/?q=82.1753988563'
comm = 'http://reddit.local:8888/r/reddit_test8/comments/2s/'
found_links = self.r.info(url)
for link in found_links:
found_by_id = self.r.info(id=link.name)
self.assertTrue(found_by_id)
self.assertTrue(link in found_by_id)

def test_comments_contains_no_noncomment_objects(self):
url = 'http://www.reddit.com/r/programming/comments/bn2wi/'
if settings.REDDIT_DOMAIN == 'www.reddit.com':
url = 'http://www.reddit.com/r/programming/comments/bn2wi/'
else:
url = 'http://reddit.local:8888/r/reddit_test9/comments/1a/'
comments = self.r.get_submission(url=url).comments
self.assertFalse([item for item in comments if not
(isinstance(item, Comment) or
@@ -91,6 +106,7 @@ def test_add_comment_and_verify(self):
submission = self.subreddit.get_new_by_date().next()
self.assertTrue(submission.add_comment(text))
# reload the submission
time.sleep(1)
submission = self.r.get_submission(url=submission.permalink)
for comment in submission.comments:
if comment.body == text:
@@ -108,6 +124,7 @@ def test_add_reply_and_verify(self):
self.fail('Could not find a submission with comments.')
self.assertTrue(comment.reply(text))
# reload the submission (use id to bypass cache)
time.sleep(1)
submission = self.r.get_submission(id=submission.id)
for comment in submission.comments[0].replies:
if comment.body == text:
@@ -155,15 +172,16 @@ def test_flair_list(self):
self.assertTrue(self.subreddit.flair_list().next())

def test_flair_csv(self):
flair_mapping = [{'user':'bboe', 'flair_text':'dev',
'flair_css_class':''},
{'user':'pyapitestuser3', 'flair_text':'',
'flair_css_class':'css2'},
{'user':'pyapitestuser2', 'flair_text':'AWESOME',
'flair_css_class':'css'}]
flair_mapping = [{u'user':'bboe', u'flair_text':u'dev',
u'flair_css_class':u''},
{u'user':u'PyAPITestUser3', u'flair_text':u'',
u'flair_css_class':u'css2'},
{u'user':u'PyAPITestUser2', u'flair_text':u'AWESOME',
u'flair_css_class':u'css'}]
self.subreddit.set_flair_csv(flair_mapping)
expected = set([tuple(x) for x in flair_mapping])
result = set([tuple(x) for x in self.subreddit.flair_list()])
expected = set([tuple(sorted(x.items())) for x in flair_mapping])
result = set([tuple(sorted(x.items())) for x in
self.subreddit.flair_list()])
self.assertTrue(not expected - result)

def test_flair_csv_optional_args(self):
@@ -184,7 +202,10 @@ def test_flair_csv_requires_user(self):
class RedditorTest(unittest.TestCase, AuthenticatedHelper):
def setUp(self):
self.configure()
self.other = {'id':'6c1xj', 'name':'PyApiTestUser3'}
if settings.REDDIT_DOMAIN == 'www.reddit.com':
self.other = {'id':'6c1xj', 'name':'PyApiTestUser3'}
else:
self.other = {'id':'pa', 'name':'PyApiTestUser3'}
self.user = self.r.get_redditor(self.other['name'])

def test_get(self):
@@ -243,7 +264,6 @@ def test_clear_vote(self):
else:
self.fail('Could not find a down-voted submission.')
submission.clear_vote()
print submission, 'clear'
# reload the submission
submission = self.r.get_submission(id=submission.id)
self.assertEqual(submission.likes, None)
@@ -255,7 +275,6 @@ def test_downvote(self):
else:
self.fail('Could not find an up-voted submission.')
submission.downvote()
print submission, 'down'
# reload the submission
submission = self.r.get_submission(id=submission.id)
self.assertEqual(submission.likes, False)
@@ -267,7 +286,6 @@ def test_upvote(self):
else:
self.fail('Could not find a non-voted submission.')
submission.upvote()
print submission, 'up'
# reload the submission
submission = self.r.get_submission(id=submission.id)
self.assertEqual(submission.likes, True)
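
The reworked comparison in test_flair_csv deserves a note: dicts are not hashable, and tuple(x) over a dict yields only its keys, so the old expected/result sets effectively compared key names rather than flair values. Converting each row with tuple(sorted(x.items())) gives a hashable, order-independent snapshot of the whole mapping. For illustration:

row = {'user': 'bboe', 'flair_text': 'dev', 'flair_css_class': ''}

print(tuple(row))                  # keys only, in arbitrary order
print(tuple(sorted(row.items())))  # (('flair_css_class', ''),
                                   #  ('flair_text', 'dev'),
                                   #  ('user', 'bboe'))
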
8 changes: 4 additions & 4 deletions reddit/redditor.py
@@ -13,18 +13,18 @@
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.

import settings
from base_objects import RedditContentObject
from decorators import require_login
from helpers import _get_section
from settings import DEFAULT_CONTENT_LIMIT
from urls import urls
from util import limit_chars


class Redditor(RedditContentObject):
"""A class for Redditor methods."""

kind = "t2"
kind = settings.OBJECT_KIND_MAPPING['Redditor']

get_overview = _get_section("")
get_comments = _get_section("comments")
@@ -58,13 +58,13 @@ def unfriend(self):
class LoggedInRedditor(Redditor):
"""A class for a currently logged in redditor"""
@require_login
def my_reddits(self, limit=DEFAULT_CONTENT_LIMIT):
def my_reddits(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return all of the current user's subscribed subreddits."""
return self.reddit_session._get_content(urls["my_reddits"],
limit=limit)

@require_login
def my_moderation(self, limit=DEFAULT_CONTENT_LIMIT):
def my_moderation(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return all of the current user's subreddits that they moderate."""
return self.reddit_session._get_content(urls["my_moderation"],
limit=limit)
29 changes: 29 additions & 0 deletions reddit/settings.py
@@ -13,11 +13,40 @@
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.
# How many results to retrieve by default when making content calls
import imp
import inspect
import os
import sys

# The domain to send API requests to. Useful to change for local reddit
# installations.
REDDIT_DOMAIN = 'www.reddit.com'

# The domain to use for SSL requests (login). Set to None to disable SSL
# requests.
HTTPS_DOMAIN = 'ssl.reddit.com'

DEFAULT_CONTENT_LIMIT = 25

# Seconds to wait between calls, see http://code.reddit.com/wiki/API
# specifically "In general, and especially for crawlers, make fewer than one
# request per two seconds"
WAIT_BETWEEN_CALL_TIME = 2

CACHE_TIMEOUT = 30 # in seconds

OBJECT_KIND_MAPPING = {'Comment': 't1',
'Redditor': 't2',
'Submission': 't3',
'Subreddit': 't5',
'MoreComments': 'more'}

# Python magic to overwrite the above default values if a user-defined settings
# file is provided via the REDDIT_CONFIG environment variable.
if 'REDDIT_CONFIG' in os.environ:
_tmp = imp.load_source('config', os.environ['REDDIT_CONFIG'])
for name, _ in inspect.getmembers(sys.modules[__name__],
lambda x: not inspect.ismodule(x)):
if name.startswith('_'): continue
if hasattr(_tmp, name):
setattr(sys.modules[__name__], name, getattr(_tmp, name))
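
Because of the hasattr check, an override file only needs to define the names it wants to change; anything it omits keeps the default above. A hypothetical minimal override (the file name and comments are illustrative; example_settings.py in this commit is a fuller version):

# my_settings.py (hypothetical): only these two names are replaced;
# REDDIT_DOMAIN, OBJECT_KIND_MAPPING and the rest keep their defaults.
WAIT_BETWEEN_CALL_TIME = 0   # no delay between API calls
CACHE_TIMEOUT = 0            # cached results expire immediately

Point REDDIT_CONFIG at that file before the reddit package is first imported and these values replace the defaults when settings.py loads.
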
3 changes: 2 additions & 1 deletion reddit/submission.py
@@ -16,13 +16,14 @@
from urls import urls
from urlparse import urljoin

import settings
from base_objects import RedditContentObject
from features import Deletable, Saveable, Voteable

class Submission(RedditContentObject, Saveable, Voteable, Deletable):
"""A class for submissions to Reddit."""

kind = "t3"
kind = settings.OBJECT_KIND_MAPPING['Submission']

def __init__(self, reddit_session, title=None, json_dict=None):
super(Submission, self).__init__(reddit_session, title, json_dict)
