diff --git a/example_settings.py b/example_settings.py
new file mode 100644
index 000000000..aa817759c
--- /dev/null
+++ b/example_settings.py
@@ -0,0 +1,9 @@
+REDDIT_DOMAIN = 'reddit.local:8888'
+WAIT_BETWEEN_CALL_TIME = 0
+CACHE_TIMEOUT = 0
+
+OBJECT_KIND_MAPPING = {'Comment': 't1',
+ 'Redditor': 't2',
+ 'Submission': 't6',
+ 'Subreddit': 't5',
+ 'MoreComments': 'more'}
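The example file above targets a local reddit install. Note that its Submission kind is 't6' while the in-package default later in this patch is 't3': thing-type prefixes are assigned per installation, which is exactly why the mapping is made configurable (the 't6' value presumably matches the author's local instance). To activate the file, point the REDDIT_CONFIG environment variable at it before the package is first imported; a minimal sketch, with a hypothetical path:

```python
import os

# Must happen before the first `import reddit`, because reddit/settings.py
# applies the override at import time (path is hypothetical).
os.environ['REDDIT_CONFIG'] = '/path/to/example_settings.py'

import reddit
print reddit.settings.REDDIT_DOMAIN  # -> 'reddit.local:8888'
```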
diff --git a/reddit/__init__.py b/reddit/__init__.py
index d895dc171..046c11b34 100644
--- a/reddit/__init__.py
+++ b/reddit/__init__.py
@@ -22,13 +22,13 @@
except ImportError:
import simplejson as json
+import settings
from base_objects import RedditObject
from comment import Comment, MoreComments
from decorators import require_captcha, require_login, parse_api_json_response
from errors import ClientException
from helpers import _modify_relationship, _request
from redditor import LoggedInRedditor, Redditor
-from settings import DEFAULT_CONTENT_LIMIT
from submission import Submission
from subreddit import Subreddit
from urls import urls
@@ -122,7 +122,7 @@ def content_id(self):
"""
return self.user.content_id
- def _get_content(self, page_url, limit=DEFAULT_CONTENT_LIMIT,
+ def _get_content(self, page_url, limit=settings.DEFAULT_CONTENT_LIMIT,
url_data=None, place_holder=None, root_field='data',
thing_field='children', after_field='after'):
"""A generator method to return Reddit content from a URL. Starts at
@@ -258,23 +258,24 @@ def _mark_as_read(self, content_ids):
'uh': self.modhash}
self._request_json(urls["read_message"], params)
- def get_front_page(self, limit=DEFAULT_CONTENT_LIMIT):
+ def get_front_page(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return the reddit front page. Login isn't required, but you'll only
see your own front page if you are logged in."""
return self._get_content(urls["reddit_url"], limit=limit)
@require_login
- def get_saved_links(self, limit=DEFAULT_CONTENT_LIMIT):
+ def get_saved_links(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return a listing of the logged-in user's saved links."""
return self._get_content(urls["saved"], limit=limit)
- def get_all_comments(self, limit=DEFAULT_CONTENT_LIMIT, place_holder=None):
+ def get_all_comments(self, limit=settings.DEFAULT_CONTENT_LIMIT,
+ place_holder=None):
"""Returns a listing from reddit.com/comments (which provides all of
the most recent comments from all users to all submissions)."""
return self._get_content(urls["comments"], limit=limit,
place_holder=place_holder)
- def info(self, url=None, id=None, limit=DEFAULT_CONTENT_LIMIT):
+ def info(self, url=None, id=None, limit=settings.DEFAULT_CONTENT_LIMIT):
"""
Query the API to see if the given URL has been submitted already, and
if it has, return the submissions.
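One subtlety with the `limit=settings.DEFAULT_CONTENT_LIMIT` defaults above: Python evaluates default argument values once, when the `def` statement runs. Switching from `from settings import ...` to attribute access therefore does not make these defaults dynamic by itself; it works because the REDDIT_CONFIG override in settings.py fires at import time, before any of these `def` statements execute. A toy illustration of the pitfall:

```python
import settings  # as the package modules do

def get_stuff(limit=settings.DEFAULT_CONTENT_LIMIT):
    return limit

# Changing the setting afterwards does not touch the captured default.
settings.DEFAULT_CONTENT_LIMIT = 100
print get_stuff()  # still the value seen when `def` ran (25 by default)
```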
diff --git a/reddit/comment.py b/reddit/comment.py
index 32bbc80a1..1a997f7d1 100644
--- a/reddit/comment.py
+++ b/reddit/comment.py
@@ -15,6 +15,7 @@
from urlparse import urljoin
+import settings
from base_objects import RedditContentObject
from features import Voteable, Deletable
from util import limit_chars
@@ -22,7 +23,7 @@
class Comment(RedditContentObject, Voteable, Deletable):
"""A class for comments."""
- kind = "t1"
+ kind = settings.OBJECT_KIND_MAPPING['Comment']
def __init__(self, reddit_session, json_dict):
super(Comment, self).__init__(reddit_session, None, json_dict)
@@ -68,7 +69,7 @@ def mark_read(self):
class MoreComments(RedditContentObject):
"""A class indicating there are more comments."""
- kind = "more"
+ kind = settings.OBJECT_KIND_MAPPING['MoreComments']
def __init__(self, reddit_session, json_dict):
super(MoreComments, self).__init__(reddit_session, None, json_dict)
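The same import-time caveat applies to the `kind` class attributes: each is evaluated once, when the class body runs, so OBJECT_KIND_MAPPING must already hold the right values by then. Presumably the mapping exists so JSON responses can be matched back to classes; a hypothetical dispatch sketch (the library's actual JSON handling is not part of this diff):

```python
from comment import Comment, MoreComments

# Hypothetical reverse lookup from API kind prefix to class.
KIND_TO_CLASS = dict((cls.kind, cls) for cls in (Comment, MoreComments))

def wrap_thing(reddit_session, thing):
    # reddit API "things" look like {'kind': 't1', 'data': {...}}.
    return KIND_TO_CLASS[thing['kind']](reddit_session, thing['data'])
```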
diff --git a/reddit/decorators.py b/reddit/decorators.py
index 97e827c1c..6d2494610 100644
--- a/reddit/decorators.py
+++ b/reddit/decorators.py
@@ -20,7 +20,7 @@
import errors
import reddit
-from settings import WAIT_BETWEEN_CALL_TIME
+import settings
from urls import urls
@@ -96,7 +96,6 @@ class sleep_after(object):
delayed until the proper duration is reached.
"""
last_call_time = 0 # init to 0 to always allow the 1st call
- WAIT_BETWEEN_CALL_TIME = WAIT_BETWEEN_CALL_TIME
def __init__(self, func):
wraps(func)(self)
@@ -106,8 +105,8 @@ def __call__(self, *args, **kwargs):
call_time = time.time()
since_last_call = call_time - self.last_call_time
- if since_last_call < WAIT_BETWEEN_CALL_TIME:
- time.sleep(WAIT_BETWEEN_CALL_TIME - since_last_call)
+ if since_last_call < settings.WAIT_BETWEEN_CALL_TIME:
+ time.sleep(settings.WAIT_BETWEEN_CALL_TIME - since_last_call)
self.__class__.last_call_time = call_time
return self.func(*args, **kwargs)
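This hunk is the heart of the change: the old class attribute (`WAIT_BETWEEN_CALL_TIME = WAIT_BETWEEN_CALL_TIME`) copied the value once, at class-creation time, while `settings.WAIT_BETWEEN_CALL_TIME` is looked up on every call, so a test config that sets it to 0 takes effect without any reloading. The binding difference in miniature:

```python
import settings

# Old style: the value is copied once at import and never sees updates.
from settings import WAIT_BETWEEN_CALL_TIME as FROZEN

settings.WAIT_BETWEEN_CALL_TIME = 0    # e.g. applied by a test config

print FROZEN                           # still 2
print settings.WAIT_BETWEEN_CALL_TIME  # 0 -- what the decorator now sees
```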
diff --git a/reddit/errors.py b/reddit/errors.py
index 6622ca406..6da86fae6 100644
--- a/reddit/errors.py
+++ b/reddit/errors.py
@@ -85,6 +85,6 @@ class RateLimitExceeded(APIException):
ERROR_MAPPING = {}
-for name, obj in inspect.getmembers(sys.modules[__name__]):
- if inspect.isclass(obj) and hasattr(obj, 'ERROR_TYPE'):
- ERROR_MAPPING[obj.ERROR_TYPE] = obj
+predicate = lambda x: inspect.isclass(x) and hasattr(x, 'ERROR_TYPE')
+for name, obj in inspect.getmembers(sys.modules[__name__], predicate):
+ ERROR_MAPPING[obj.ERROR_TYPE] = obj
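`inspect.getmembers` accepts an optional predicate, so the filter moves out of the loop body. The resulting ERROR_MAPPING is presumably used to raise the right exception class for an API error string; a hypothetical sketch of that lookup (the raising code is not in this diff):

```python
def raise_api_error(error_type, *args):
    # Fall back to the generic APIException for unknown error types
    # (assumed behaviour, not shown in the patch).
    cls = ERROR_MAPPING.get(error_type, APIException)
    raise cls(*args)
```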
diff --git a/reddit/helpers.py b/reddit/helpers.py
index 1a1c369a0..9e939bcc8 100644
--- a/reddit/helpers.py
+++ b/reddit/helpers.py
@@ -17,8 +17,8 @@
import urllib2
from urlparse import urljoin
+import settings
from decorators import require_login, sleep_after
-from settings import DEFAULT_CONTENT_LIMIT
from urls import urls
from util import memoize
@@ -27,7 +27,8 @@ def _get_section(subpath=""):
Used by the Redditor class to generate each of the sections (overview,
comments, submitted).
"""
- def get_section(self, sort="new", time="all", limit=DEFAULT_CONTENT_LIMIT,
+ def get_section(self, sort="new", time="all",
+ limit=settings.DEFAULT_CONTENT_LIMIT,
place_holder=None):
url_data = {"sort" : sort, "time" : time}
return self.reddit_session._get_content(urljoin(self._url, subpath),
@@ -41,7 +42,8 @@ def _get_sorter(subpath="", **defaults):
Used by the Reddit Page classes to generate each of the currently supported
sorts (hot, top, new, best).
"""
- def sorted(self, limit=DEFAULT_CONTENT_LIMIT, place_holder=None, **data):
+ def sorted(self, limit=settings.DEFAULT_CONTENT_LIMIT,
+ place_holder=None, **data):
for k, v in defaults.items():
if k == "time":
# time should be "t" in the API data dict
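For readers new to helpers.py, `_get_section` and `_get_sorter` are closure factories: module-level functions that build a function which the Redditor and page classes then bind as ordinary methods (e.g. `get_comments = _get_section("comments")`). The pattern in miniature:

```python
def _make_greeter(greeting):
    def greet(self, name):
        return '%s, %s!' % (greeting, name)
    return greet

class Greeter(object):
    # Each class attribute is a distinct function closed over its own
    # `greeting`, and Python binds it as a regular method.
    hello = _make_greeter('Hello')

print Greeter().hello('reddit')  # Hello, reddit!
```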
diff --git a/reddit/reddit_test.py b/reddit/reddit_test.py
index a747137bb..ea638815a 100755
--- a/reddit/reddit_test.py
+++ b/reddit/reddit_test.py
@@ -15,8 +15,9 @@
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.
-import itertools, unittest, util, uuid, warnings
+import itertools, time, unittest, util, uuid, warnings
+import settings
from reddit import Reddit, errors
from reddit.comment import Comment, MoreComments
from reddit.redditor import LoggedInRedditor
@@ -37,7 +38,10 @@ def configure(self):
class BasicTest(unittest.TestCase, BasicHelper):
def setUp(self):
self.configure()
- self.self = 'http://www.reddit.com/r/programming/comments/bn2wi/'
+ if settings.REDDIT_DOMAIN == 'www.reddit.com':
+ self.self = 'http://www.reddit.com/r/programming/comments/bn2wi/'
+ else:
+ self.self = 'http://reddit.local:8888/r/bboe/comments/2z/tasdest/'
def test_require_user_agent(self):
self.assertRaises(TypeError, Reddit, user_agent=None)
@@ -50,8 +54,12 @@ def test_not_logged_in_submit(self):
self.sr, 'TITLE', text='BODY')
def test_info_by_known_url_returns_known_id_link_post(self):
- url = 'http://imgur.com/Vr8ZZ'
- comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
+ if settings.REDDIT_DOMAIN == 'www.reddit.com':
+ url = 'http://imgur.com/Vr8ZZ'
+ comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
+ else:
+ url = 'http://google.com/?q=82.1753988563'
+ comm = 'http://reddit.local:8888/r/reddit_test8/comments/2s/'
found_links = self.r.info(url)
tmp = self.r.get_submission(url=comm)
self.assertTrue(tmp in found_links)
@@ -65,8 +73,12 @@ def test_info_by_self_url_raises_warning(self):
self.assertTrue('self' in str(w[-1].message))
def test_info_by_url_also_found_by_id(self):
- url = 'http://imgur.com/Vr8ZZ'
- comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
+ if settings.REDDIT_DOMAIN == 'www.reddit.com':
+ url = 'http://imgur.com/Vr8ZZ'
+ comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/'
+ else:
+ url = 'http://google.com/?q=82.1753988563'
+ comm = 'http://reddit.local:8888/r/reddit_test8/comments/2s/'
found_links = self.r.info(url)
for link in found_links:
found_by_id = self.r.info(id=link.name)
@@ -74,7 +86,10 @@ def test_info_by_url_also_found_by_id(self):
self.assertTrue(link in found_by_id)
def test_comments_contains_no_noncomment_objects(self):
- url = 'http://www.reddit.com/r/programming/comments/bn2wi/'
+ if settings.REDDIT_DOMAIN == 'www.reddit.com':
+ url = 'http://www.reddit.com/r/programming/comments/bn2wi/'
+ else:
+ url = 'http://reddit.local:8888/r/reddit_test9/comments/1a/'
comments = self.r.get_submission(url=url).comments
self.assertFalse([item for item in comments if not
(isinstance(item, Comment) or
@@ -91,6 +106,7 @@ def test_add_comment_and_verify(self):
submission = self.subreddit.get_new_by_date().next()
self.assertTrue(submission.add_comment(text))
# reload the submission
+ time.sleep(1)
submission = self.r.get_submission(url=submission.permalink)
for comment in submission.comments:
if comment.body == text:
@@ -108,6 +124,7 @@ def test_add_reply_and_verify(self):
self.fail('Could not find a submission with comments.')
self.assertTrue(comment.reply(text))
# reload the submission (use id to bypass cache)
+ time.sleep(1)
submission = self.r.get_submission(id=submission.id)
for comment in submission.comments[0].replies:
if comment.body == text:
@@ -155,15 +172,16 @@ def test_flair_list(self):
self.assertTrue(self.subreddit.flair_list().next())
def test_flair_csv(self):
- flair_mapping = [{'user':'bboe', 'flair_text':'dev',
- 'flair_css_class':''},
- {'user':'pyapitestuser3', 'flair_text':'',
- 'flair_css_class':'css2'},
- {'user':'pyapitestuser2', 'flair_text':'AWESOME',
- 'flair_css_class':'css'}]
+ flair_mapping = [{u'user':'bboe', u'flair_text':u'dev',
+ u'flair_css_class':u''},
+ {u'user':u'PyAPITestUser3', u'flair_text':u'',
+ u'flair_css_class':u'css2'},
+ {u'user':u'PyAPITestUser2', u'flair_text':u'AWESOME',
+ u'flair_css_class':u'css'}]
self.subreddit.set_flair_csv(flair_mapping)
- expected = set([tuple(x) for x in flair_mapping])
- result = set([tuple(x) for x in self.subreddit.flair_list()])
+ expected = set([tuple(sorted(x.items())) for x in flair_mapping])
+ result = set([tuple(sorted(x.items())) for x in
+ self.subreddit.flair_list()])
self.assertTrue(not expected - result)
def test_flair_csv_optional_args(self):
@@ -184,7 +202,10 @@ def test_flair_csv_requires_user(self):
class RedditorTest(unittest.TestCase, AuthenticatedHelper):
def setUp(self):
self.configure()
- self.other = {'id':'6c1xj', 'name':'PyApiTestUser3'}
+ if settings.REDDIT_DOMAIN == 'www.reddit.com':
+ self.other = {'id':'6c1xj', 'name':'PyApiTestUser3'}
+ else:
+ self.other = {'id':'pa', 'name':'PyApiTestUser3'}
self.user = self.r.get_redditor(self.other['name'])
def test_get(self):
@@ -243,7 +264,6 @@ def test_clear_vote(self):
else:
self.fail('Could not find a down-voted submission.')
submission.clear_vote()
- print submission, 'clear'
# reload the submission
submission = self.r.get_submission(id=submission.id)
self.assertEqual(submission.likes, None)
@@ -255,7 +275,6 @@ def test_downvote(self):
else:
self.fail('Could not find an up-voted submission.')
submission.downvote()
- print submission, 'down'
# reload the submission
submission = self.r.get_submission(id=submission.id)
self.assertEqual(submission.likes, False)
@@ -267,7 +286,6 @@ def test_upvote(self):
else:
self.fail('Could not find a non-voted submission.')
submission.upvote()
- print submission, 'up'
# reload the submission
submission = self.r.get_submission(id=submission.id)
self.assertEqual(submission.likes, True)
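With these changes the suite runs against either target: when `settings.REDDIT_DOMAIN` is 'www.reddit.com' it uses the live fixture URLs, otherwise the IDs and URLs from the author's local install. Assuming `example_settings.py` sits in the repository root, the local run would look something like `REDDIT_CONFIG=example_settings.py python reddit/reddit_test.py`. The added `time.sleep(1)` calls presumably give the server a moment to register a new comment before the tests re-fetch the submission.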
diff --git a/reddit/redditor.py b/reddit/redditor.py
index e1a578600..27220255b 100644
--- a/reddit/redditor.py
+++ b/reddit/redditor.py
@@ -13,10 +13,10 @@
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.
+import settings
from base_objects import RedditContentObject
from decorators import require_login
from helpers import _get_section
-from settings import DEFAULT_CONTENT_LIMIT
from urls import urls
from util import limit_chars
@@ -24,7 +24,7 @@
class Redditor(RedditContentObject):
"""A class for Redditor methods."""
- kind = "t2"
+ kind = settings.OBJECT_KIND_MAPPING['Redditor']
get_overview = _get_section("")
get_comments = _get_section("comments")
@@ -58,13 +58,13 @@ def unfriend(self):
class LoggedInRedditor(Redditor):
"""A class for a currently logged in redditor"""
@require_login
- def my_reddits(self, limit=DEFAULT_CONTENT_LIMIT):
+ def my_reddits(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return all of the current user's subscribed subreddits."""
return self.reddit_session._get_content(urls["my_reddits"],
limit=limit)
@require_login
- def my_moderation(self, limit=DEFAULT_CONTENT_LIMIT):
+ def my_moderation(self, limit=settings.DEFAULT_CONTENT_LIMIT):
"""Return all of the current user's subreddits that they moderate."""
return self.reddit_session._get_content(urls["my_moderation"],
limit=limit)
diff --git a/reddit/settings.py b/reddit/settings.py
index 4646a8c27..b0f8cc64d 100644
--- a/reddit/settings.py
+++ b/reddit/settings.py
@@ -13,11 +13,40 @@
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.
-# How many results to retrieve by default when making content calls
+import imp
+import inspect
+import os
+import sys
+
+# The domain to send API requests to. Useful to change for local reddit
+# installations.
+REDDIT_DOMAIN = 'www.reddit.com'
+
+# The domain to use for SSL requests (login). Set to None to disable SSL
+# requests.
+HTTPS_DOMAIN = 'ssl.reddit.com'
+
+# How many results to retrieve by default when making content calls
DEFAULT_CONTENT_LIMIT = 25
+
# Seconds to wait between calls, see http://code.reddit.com/wiki/API
# specifically "In general, and especially for crawlers, make fewer than one
# request per two seconds"
WAIT_BETWEEN_CALL_TIME = 2
CACHE_TIMEOUT = 30 # in seconds
+
+OBJECT_KIND_MAPPING = {'Comment': 't1',
+ 'Redditor': 't2',
+ 'Submission': 't3',
+ 'Subreddit': 't5',
+ 'MoreComments': 'more'}
+
+# Python magic to overwrite the above default values if a user-defined settings
+# file is provided via the REDDIT_CONFIG environment variable.
+if 'REDDIT_CONFIG' in os.environ:
+ _tmp = imp.load_source('config', os.environ['REDDIT_CONFIG'])
+ for name, _ in inspect.getmembers(sys.modules[__name__],
+ lambda x: not inspect.ismodule(x)):
+ if name.startswith('_'): continue
+ if hasattr(_tmp, name):
+ setattr(sys.modules[__name__], name, getattr(_tmp, name))
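Because the override loop iterates over this module's own members (filtered to non-underscore, non-module names) and only copies a value when `hasattr(_tmp, name)` holds, a name that exists solely in the user config is silently ignored: misspelled settings do not fail loudly. A toy reproduction, with hypothetical paths, run from within the reddit package directory so that `import settings` resolves:

```python
import imp
import os

# A hypothetical user config that misspells CACHE_TIMEOUT.
with open('/tmp/cfg.py', 'w') as fp:
    fp.write("REDDIT_DOMAIN = 'reddit.local:8888'\nCACHE_TIMEUT = 0\n")
os.environ['REDDIT_CONFIG'] = '/tmp/cfg.py'

import settings               # first import runs the override loop above
print settings.REDDIT_DOMAIN  # 'reddit.local:8888' -- copied over
print settings.CACHE_TIMEOUT  # 30 -- the misspelled name was skipped
```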
diff --git a/reddit/submission.py b/reddit/submission.py
index f0e3b8f58..188ab702b 100644
--- a/reddit/submission.py
+++ b/reddit/submission.py
@@ -16,13 +16,14 @@
from urls import urls
from urlparse import urljoin
+import settings
from base_objects import RedditContentObject
from features import Deletable, Saveable, Voteable
class Submission(RedditContentObject, Saveable, Voteable, Deletable):
"""A class for submissions to Reddit."""
- kind = "t3"
+ kind = settings.OBJECT_KIND_MAPPING['Submission']
def __init__(self, reddit_session, title=None, json_dict=None):
super(Submission, self).__init__(reddit_session, title, json_dict)
diff --git a/reddit/subreddit.py b/reddit/subreddit.py
index cdc0a5364..0ffdc55f3 100644
--- a/reddit/subreddit.py
+++ b/reddit/subreddit.py
@@ -15,6 +15,7 @@
from urls import urls
+import settings
from base_objects import RedditContentObject
from helpers import _modify_relationship, _get_sorter
from util import limit_chars
@@ -22,7 +23,7 @@
class Subreddit(RedditContentObject):
"""A class for Subreddits."""
- kind = "t5"
+ kind = settings.OBJECT_KIND_MAPPING['Subreddit']
ban = _modify_relationship("banned")
make_contributor = _modify_relationship("contributor")
diff --git a/reddit/urls.py b/reddit/urls.py
index c53e9f922..a38826f01 100644
--- a/reddit/urls.py
+++ b/reddit/urls.py
@@ -12,20 +12,21 @@
#
# You should have received a copy of the GNU General Public License
# along with reddit_api. If not, see <http://www.gnu.org/licenses/>.
-
from urlparse import urljoin
+import settings
+
class URLDict(object):
- def __init__(self, base_url, *args):
+ def __init__(self, *args):
"""
Builds a URL dictionary. `args` should be tuples of the form:
`(url_prefix, {"url_name": "url_path"})`
"""
- self._base_url = base_url
+ _base_url = 'http://%s' % settings.REDDIT_DOMAIN
self._urls = {}
for prefix, url_dict in args:
- full_prefix = urljoin(self._base_url, prefix)
+ full_prefix = urljoin(_base_url, prefix)
for name, url in url_dict.iteritems():
self[name] = '%s/' % urljoin(full_prefix, url)
@@ -38,17 +39,16 @@ def __setitem__(self, key, value):
def group(self, *urls):
return [v for v in (self[k] for k in urls)]
-urls = URLDict("http://www.reddit.com",
- ("", {"reddit_url" : "",
- "api_url" : "api",
- "comments" : "comments",
- "help" : "help",
- "info" : "button_info",
- "logout" : "logout",
- "my_reddits" : "reddits/mine",
- "my_moderation" : "reddits/mine/moderator",
- "saved" : "saved",
- "view_captcha" : "captcha"}),
+urls = URLDict(("", {"reddit_url" : "",
+ "api_url" : "api",
+ "comments" : "comments",
+ "help" : "help",
+ "info" : "button_info",
+ "logout" : "logout",
+ "my_reddits" : "reddits/mine",
+ "my_moderation" : "reddits/mine/moderator",
+ "saved" : "saved",
+ "view_captcha" : "captcha"}),
("api/", {"comment" : "comment",
"compose_message" : "compose",
"del" : "del",
diff --git a/reddit/util.py b/reddit/util.py
index b5514d23b..aa7e3f382 100644
--- a/reddit/util.py
+++ b/reddit/util.py
@@ -16,7 +16,8 @@
import time
from functools import wraps
-from settings import CACHE_TIMEOUT
+import settings
+
class memoize(object):
"""
@@ -26,7 +27,6 @@ class memoize(object):
For RedditContentObject methods, this means removal by URL, provided by the
is_stale method.
"""
- TIMEOUT = CACHE_TIMEOUT
def __init__(self, func):
wraps(func)(self)
@@ -50,7 +50,7 @@ def clear_timeouts(self, call_time):
Clears the _caches of results which have timed out.
"""
need_clearing = (k for k, v in self._timeouts.items()
- if call_time - v > self.TIMEOUT)
+ if call_time - v > settings.CACHE_TIMEOUT)
for k in need_clearing:
try:
del self._cache[k]
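As with the rate-limit delay, reading `settings.CACHE_TIMEOUT` inside `clear_timeouts` means a config that sets it to 0 effectively disables the cache: every stored entry is already stale by the next call. A minimal sketch, assuming `memoize.__call__` (not shown in this hunk) keys its cache on the call arguments:

```python
import time
import settings

settings.CACHE_TIMEOUT = 0  # as in example_settings.py

@memoize
def fetch(url):
    return time.time()

first = fetch('http://example.com/')
time.sleep(0.01)
second = fetch('http://example.com/')
assert first != second  # the entry timed out, so the function ran again
```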