diff --git a/example_settings.py b/example_settings.py new file mode 100644 index 000000000..aa817759c --- /dev/null +++ b/example_settings.py @@ -0,0 +1,9 @@ +REDDIT_DOMAIN = 'reddit.local:8888' +WAIT_BETWEEN_CALL_TIME = 0 +CACHE_TIMEOUT = 0 + +OBJECT_KIND_MAPPING = {'Comment': 't1', + 'Redditor': 't2', + 'Submission': 't6', + 'Subreddit': 't5', + 'MoreComments': 'more'} diff --git a/reddit/__init__.py b/reddit/__init__.py index d895dc171..046c11b34 100644 --- a/reddit/__init__.py +++ b/reddit/__init__.py @@ -22,13 +22,13 @@ except ImportError: import simplejson as json +import settings from base_objects import RedditObject from comment import Comment, MoreComments from decorators import require_captcha, require_login, parse_api_json_response from errors import ClientException from helpers import _modify_relationship, _request from redditor import LoggedInRedditor, Redditor -from settings import DEFAULT_CONTENT_LIMIT from submission import Submission from subreddit import Subreddit from urls import urls @@ -122,7 +122,7 @@ def content_id(self): """ return self.user.content_id - def _get_content(self, page_url, limit=DEFAULT_CONTENT_LIMIT, + def _get_content(self, page_url, limit=settings.DEFAULT_CONTENT_LIMIT, url_data=None, place_holder=None, root_field='data', thing_field='children', after_field='after'): """A generator method to return Reddit content from a URL. Starts at @@ -258,23 +258,24 @@ def _mark_as_read(self, content_ids): 'uh': self.modhash} self._request_json(urls["read_message"], params) - def get_front_page(self, limit=DEFAULT_CONTENT_LIMIT): + def get_front_page(self, limit=settings.DEFAULT_CONTENT_LIMIT): """Return the reddit front page. Login isn't required, but you'll only see your own front page if you are logged in.""" return self._get_content(urls["reddit_url"], limit=limit) @require_login - def get_saved_links(self, limit=DEFAULT_CONTENT_LIMIT): + def get_saved_links(self, limit=settings.DEFAULT_CONTENT_LIMIT): """Return a listing of the logged-in user's saved links.""" return self._get_content(urls["saved"], limit=limit) - def get_all_comments(self, limit=DEFAULT_CONTENT_LIMIT, place_holder=None): + def get_all_comments(self, limit=settings.DEFAULT_CONTENT_LIMIT, + place_holder=None): """Returns a listing from reddit.com/comments (which provides all of the most recent comments from all users to all submissions).""" return self._get_content(urls["comments"], limit=limit, place_holder=place_holder) - def info(self, url=None, id=None, limit=DEFAULT_CONTENT_LIMIT): + def info(self, url=None, id=None, limit=settings.DEFAULT_CONTENT_LIMIT): """ Query the API to see if the given URL has been submitted already, and if it has, return the submissions. diff --git a/reddit/comment.py b/reddit/comment.py index 32bbc80a1..1a997f7d1 100644 --- a/reddit/comment.py +++ b/reddit/comment.py @@ -15,6 +15,7 @@ from urlparse import urljoin +import settings from base_objects import RedditContentObject from features import Voteable, Deletable from util import limit_chars @@ -22,7 +23,7 @@ class Comment(RedditContentObject, Voteable, Deletable): """A class for comments.""" - kind = "t1" + kind = settings.OBJECT_KIND_MAPPING['Comment'] def __init__(self, reddit_session, json_dict): super(Comment, self).__init__(reddit_session, None, json_dict) @@ -68,7 +69,7 @@ def mark_read(self): class MoreComments(RedditContentObject): """A class indicating there are more comments.""" - kind = "more" + kind = settings.OBJECT_KIND_MAPPING['MoreComments'] def __init__(self, reddit_session, json_dict): super(MoreComments, self).__init__(reddit_session, None, json_dict) diff --git a/reddit/decorators.py b/reddit/decorators.py index 97e827c1c..6d2494610 100644 --- a/reddit/decorators.py +++ b/reddit/decorators.py @@ -20,7 +20,7 @@ import errors import reddit -from settings import WAIT_BETWEEN_CALL_TIME +import settings from urls import urls @@ -96,7 +96,6 @@ class sleep_after(object): delayed until the proper duration is reached. """ last_call_time = 0 # init to 0 to always allow the 1st call - WAIT_BETWEEN_CALL_TIME = WAIT_BETWEEN_CALL_TIME def __init__(self, func): wraps(func)(self) @@ -106,8 +105,8 @@ def __call__(self, *args, **kwargs): call_time = time.time() since_last_call = call_time - self.last_call_time - if since_last_call < WAIT_BETWEEN_CALL_TIME: - time.sleep(WAIT_BETWEEN_CALL_TIME - since_last_call) + if since_last_call < settings.WAIT_BETWEEN_CALL_TIME: + time.sleep(settings.WAIT_BETWEEN_CALL_TIME - since_last_call) self.__class__.last_call_time = call_time return self.func(*args, **kwargs) diff --git a/reddit/errors.py b/reddit/errors.py index 6622ca406..6da86fae6 100644 --- a/reddit/errors.py +++ b/reddit/errors.py @@ -85,6 +85,6 @@ class RateLimitExceeded(APIException): ERROR_MAPPING = {} -for name, obj in inspect.getmembers(sys.modules[__name__]): - if inspect.isclass(obj) and hasattr(obj, 'ERROR_TYPE'): - ERROR_MAPPING[obj.ERROR_TYPE] = obj +predicate = lambda x: inspect.isclass(x) and hasattr(x, 'ERROR_TYPE') +for name, obj in inspect.getmembers(sys.modules[__name__], predicate): + ERROR_MAPPING[obj.ERROR_TYPE] = obj diff --git a/reddit/helpers.py b/reddit/helpers.py index 1a1c369a0..9e939bcc8 100644 --- a/reddit/helpers.py +++ b/reddit/helpers.py @@ -17,8 +17,8 @@ import urllib2 from urlparse import urljoin +import settings from decorators import require_login, sleep_after -from settings import DEFAULT_CONTENT_LIMIT from urls import urls from util import memoize @@ -27,7 +27,8 @@ def _get_section(subpath=""): Used by the Redditor class to generate each of the sections (overview, comments, submitted). """ - def get_section(self, sort="new", time="all", limit=DEFAULT_CONTENT_LIMIT, + def get_section(self, sort="new", time="all", + limit=settings.DEFAULT_CONTENT_LIMIT, place_holder=None): url_data = {"sort" : sort, "time" : time} return self.reddit_session._get_content(urljoin(self._url, subpath), @@ -41,7 +42,8 @@ def _get_sorter(subpath="", **defaults): Used by the Reddit Page classes to generate each of the currently supported sorts (hot, top, new, best). """ - def sorted(self, limit=DEFAULT_CONTENT_LIMIT, place_holder=None, **data): + def sorted(self, limit=settings.DEFAULT_CONTENT_LIMIT, + place_holder=None, **data): for k, v in defaults.items(): if k == "time": # time should be "t" in the API data dict diff --git a/reddit/reddit_test.py b/reddit/reddit_test.py index a747137bb..ea638815a 100755 --- a/reddit/reddit_test.py +++ b/reddit/reddit_test.py @@ -15,8 +15,9 @@ # You should have received a copy of the GNU General Public License # along with reddit_api. If not, see . -import itertools, unittest, util, uuid, warnings +import itertools, time, unittest, util, uuid, warnings +import settings from reddit import Reddit, errors from reddit.comment import Comment, MoreComments from reddit.redditor import LoggedInRedditor @@ -37,7 +38,10 @@ def configure(self): class BasicTest(unittest.TestCase, BasicHelper): def setUp(self): self.configure() - self.self = 'http://www.reddit.com/r/programming/comments/bn2wi/' + if settings.REDDIT_DOMAIN == 'www.reddit.com': + self.self = 'http://www.reddit.com/r/programming/comments/bn2wi/' + else: + self.self = 'http://reddit.local:8888/r/bboe/comments/2z/tasdest/' def test_require_user_agent(self): self.assertRaises(TypeError, Reddit, user_agent=None) @@ -50,8 +54,12 @@ def test_not_logged_in_submit(self): self.sr, 'TITLE', text='BODY') def test_info_by_known_url_returns_known_id_link_post(self): - url = 'http://imgur.com/Vr8ZZ' - comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/' + if settings.REDDIT_DOMAIN == 'www.reddit.com': + url = 'http://imgur.com/Vr8ZZ' + comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/' + else: + url = 'http://google.com/?q=82.1753988563' + comm = 'http://reddit.local:8888/r/reddit_test8/comments/2s/' found_links = self.r.info(url) tmp = self.r.get_submission(url=comm) self.assertTrue(tmp in found_links) @@ -65,8 +73,12 @@ def test_info_by_self_url_raises_warning(self): self.assertTrue('self' in str(w[-1].message)) def test_info_by_url_also_found_by_id(self): - url = 'http://imgur.com/Vr8ZZ' - comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/' + if settings.REDDIT_DOMAIN == 'www.reddit.com': + url = 'http://imgur.com/Vr8ZZ' + comm = 'http://www.reddit.com/r/UCSantaBarbara/comments/m77nc/' + else: + url = 'http://google.com/?q=82.1753988563' + comm = 'http://reddit.local:8888/r/reddit_test8/comments/2s/' found_links = self.r.info(url) for link in found_links: found_by_id = self.r.info(id=link.name) @@ -74,7 +86,10 @@ def test_info_by_url_also_found_by_id(self): self.assertTrue(link in found_by_id) def test_comments_contains_no_noncomment_objects(self): - url = 'http://www.reddit.com/r/programming/comments/bn2wi/' + if settings.REDDIT_DOMAIN == 'www.reddit.com': + url = 'http://www.reddit.com/r/programming/comments/bn2wi/' + else: + url = 'http://reddit.local:8888/r/reddit_test9/comments/1a/' comments = self.r.get_submission(url=url).comments self.assertFalse([item for item in comments if not (isinstance(item, Comment) or @@ -91,6 +106,7 @@ def test_add_comment_and_verify(self): submission = self.subreddit.get_new_by_date().next() self.assertTrue(submission.add_comment(text)) # reload the submission + time.sleep(1) submission = self.r.get_submission(url=submission.permalink) for comment in submission.comments: if comment.body == text: @@ -108,6 +124,7 @@ def test_add_reply_and_verify(self): self.fail('Could not find a submission with comments.') self.assertTrue(comment.reply(text)) # reload the submission (use id to bypass cache) + time.sleep(1) submission = self.r.get_submission(id=submission.id) for comment in submission.comments[0].replies: if comment.body == text: @@ -155,15 +172,16 @@ def test_flair_list(self): self.assertTrue(self.subreddit.flair_list().next()) def test_flair_csv(self): - flair_mapping = [{'user':'bboe', 'flair_text':'dev', - 'flair_css_class':''}, - {'user':'pyapitestuser3', 'flair_text':'', - 'flair_css_class':'css2'}, - {'user':'pyapitestuser2', 'flair_text':'AWESOME', - 'flair_css_class':'css'}] + flair_mapping = [{u'user':'bboe', u'flair_text':u'dev', + u'flair_css_class':u''}, + {u'user':u'PyAPITestUser3', u'flair_text':u'', + u'flair_css_class':u'css2'}, + {u'user':u'PyAPITestUser2', u'flair_text':u'AWESOME', + u'flair_css_class':u'css'}] self.subreddit.set_flair_csv(flair_mapping) - expected = set([tuple(x) for x in flair_mapping]) - result = set([tuple(x) for x in self.subreddit.flair_list()]) + expected = set([tuple(sorted(x.items())) for x in flair_mapping]) + result = set([tuple(sorted(x.items())) for x in + self.subreddit.flair_list()]) self.assertTrue(not expected - result) def test_flair_csv_optional_args(self): @@ -184,7 +202,10 @@ def test_flair_csv_requires_user(self): class RedditorTest(unittest.TestCase, AuthenticatedHelper): def setUp(self): self.configure() - self.other = {'id':'6c1xj', 'name':'PyApiTestUser3'} + if settings.REDDIT_DOMAIN == 'www.reddit.com': + self.other = {'id':'6c1xj', 'name':'PyApiTestUser3'} + else: + self.other = {'id':'pa', 'name':'PyApiTestUser3'} self.user = self.r.get_redditor(self.other['name']) def test_get(self): @@ -243,7 +264,6 @@ def test_clear_vote(self): else: self.fail('Could not find a down-voted submission.') submission.clear_vote() - print submission, 'clear' # reload the submission submission = self.r.get_submission(id=submission.id) self.assertEqual(submission.likes, None) @@ -255,7 +275,6 @@ def test_downvote(self): else: self.fail('Could not find an up-voted submission.') submission.downvote() - print submission, 'down' # reload the submission submission = self.r.get_submission(id=submission.id) self.assertEqual(submission.likes, False) @@ -267,7 +286,6 @@ def test_upvote(self): else: self.fail('Could not find a non-voted submission.') submission.upvote() - print submission, 'up' # reload the submission submission = self.r.get_submission(id=submission.id) self.assertEqual(submission.likes, True) diff --git a/reddit/redditor.py b/reddit/redditor.py index e1a578600..27220255b 100644 --- a/reddit/redditor.py +++ b/reddit/redditor.py @@ -13,10 +13,10 @@ # You should have received a copy of the GNU General Public License # along with reddit_api. If not, see . +import settings from base_objects import RedditContentObject from decorators import require_login from helpers import _get_section -from settings import DEFAULT_CONTENT_LIMIT from urls import urls from util import limit_chars @@ -24,7 +24,7 @@ class Redditor(RedditContentObject): """A class for Redditor methods.""" - kind = "t2" + kind = settings.OBJECT_KIND_MAPPING['Redditor'] get_overview = _get_section("") get_comments = _get_section("comments") @@ -58,13 +58,13 @@ def unfriend(self): class LoggedInRedditor(Redditor): """A class for a currently logged in redditor""" @require_login - def my_reddits(self, limit=DEFAULT_CONTENT_LIMIT): + def my_reddits(self, limit=settings.DEFAULT_CONTENT_LIMIT): """Return all of the current user's subscribed subreddits.""" return self.reddit_session._get_content(urls["my_reddits"], limit=limit) @require_login - def my_moderation(self, limit=DEFAULT_CONTENT_LIMIT): + def my_moderation(self, limit=settings.DEFAULT_CONTENT_LIMIT): """Return all of the current user's subreddits that they moderate.""" return self.reddit_session._get_content(urls["my_moderation"], limit=limit) diff --git a/reddit/settings.py b/reddit/settings.py index 4646a8c27..b0f8cc64d 100644 --- a/reddit/settings.py +++ b/reddit/settings.py @@ -13,11 +13,40 @@ # You should have received a copy of the GNU General Public License # along with reddit_api. If not, see . # How many results to retrieve by default when making content calls +import imp +import inspect +import os +import sys + +# The domain to send API requests to. Useful to change for local reddit +# installations. +REDDIT_DOMAIN = 'www.reddit.com' + +# The domain to use for SSL requests (login). Set to None to disable SSL +# requests. +HTTPS_DOMAIN = 'ssl.reddit.com' DEFAULT_CONTENT_LIMIT = 25 + # Seconds to wait between calls, see http://code.reddit.com/wiki/API # specifically "In general, and especially for crawlers, make fewer than one # request per two seconds" WAIT_BETWEEN_CALL_TIME = 2 CACHE_TIMEOUT = 30 # in seconds + +OBJECT_KIND_MAPPING = {'Comment': 't1', + 'Redditor': 't2', + 'Submission': 't3', + 'Subreddit': 't5', + 'MoreComments': 'more'} + +# Python magic to overwrite the above default values if a user-defined settings +# file is provided via the REDDIT_CONFIG environment variable. +if 'REDDIT_CONFIG' in os.environ: + _tmp = imp.load_source('config', os.environ['REDDIT_CONFIG']) + for name, _ in inspect.getmembers(sys.modules[__name__], + lambda x: not inspect.ismodule(x)): + if name.startswith('_'): continue + if hasattr(_tmp, name): + setattr(sys.modules[__name__], name, getattr(_tmp, name)) diff --git a/reddit/submission.py b/reddit/submission.py index f0e3b8f58..188ab702b 100644 --- a/reddit/submission.py +++ b/reddit/submission.py @@ -16,13 +16,14 @@ from urls import urls from urlparse import urljoin +import settings from base_objects import RedditContentObject from features import Deletable, Saveable, Voteable class Submission(RedditContentObject, Saveable, Voteable, Deletable): """A class for submissions to Reddit.""" - kind = "t3" + kind = settings.OBJECT_KIND_MAPPING['Submission'] def __init__(self, reddit_session, title=None, json_dict=None): super(Submission, self).__init__(reddit_session, title, json_dict) diff --git a/reddit/subreddit.py b/reddit/subreddit.py index cdc0a5364..0ffdc55f3 100644 --- a/reddit/subreddit.py +++ b/reddit/subreddit.py @@ -15,6 +15,7 @@ from urls import urls +import settings from base_objects import RedditContentObject from helpers import _modify_relationship, _get_sorter from util import limit_chars @@ -22,7 +23,7 @@ class Subreddit(RedditContentObject): """A class for Subreddits.""" - kind = "t5" + kind = settings.OBJECT_KIND_MAPPING['Subreddit'] ban = _modify_relationship("banned") make_contributor = _modify_relationship("contributor") diff --git a/reddit/urls.py b/reddit/urls.py index c53e9f922..a38826f01 100644 --- a/reddit/urls.py +++ b/reddit/urls.py @@ -12,20 +12,21 @@ # # You should have received a copy of the GNU General Public License # along with reddit_api. If not, see . - from urlparse import urljoin +import settings + class URLDict(object): - def __init__(self, base_url, *args): + def __init__(self, *args): """ Builds a URL dictionary. `args` should be tuples of the form: `(url_prefix, {"url_name", "url_path"})` """ - self._base_url = base_url + _base_url = 'http://%s' % settings.REDDIT_DOMAIN self._urls = {} for prefix, url_dict in args: - full_prefix = urljoin(self._base_url, prefix) + full_prefix = urljoin(_base_url, prefix) for name, url in url_dict.iteritems(): self[name] = '%s/' % urljoin(full_prefix, url) @@ -38,17 +39,16 @@ def __setitem__(self, key, value): def group(self, *urls): return [v for v in (self[k] for k in urls)] -urls = URLDict("http://www.reddit.com", - ("", {"reddit_url" : "", - "api_url" : "api", - "comments" : "comments", - "help" : "help", - "info" : "button_info", - "logout" : "logout", - "my_reddits" : "reddits/mine", - "my_moderation" : "reddits/mine/moderator", - "saved" : "saved", - "view_captcha" : "captcha"}), +urls = URLDict(("", {"reddit_url" : "", + "api_url" : "api", + "comments" : "comments", + "help" : "help", + "info" : "button_info", + "logout" : "logout", + "my_reddits" : "reddits/mine", + "my_moderation" : "reddits/mine/moderator", + "saved" : "saved", + "view_captcha" : "captcha"}), ("api/", {"comment" : "comment", "compose_message" : "compose", "del" : "del", diff --git a/reddit/util.py b/reddit/util.py index b5514d23b..aa7e3f382 100644 --- a/reddit/util.py +++ b/reddit/util.py @@ -16,7 +16,8 @@ import time from functools import wraps -from settings import CACHE_TIMEOUT +import settings + class memoize(object): """ @@ -26,7 +27,6 @@ class memoize(object): For RedditContentObject methods, this means removal by URL, provided by the is_stale method. """ - TIMEOUT = CACHE_TIMEOUT def __init__(self, func): wraps(func)(self) @@ -50,7 +50,7 @@ def clear_timeouts(self, call_time): Clears the _caches of results which have timed out. """ need_clearing = (k for k, v in self._timeouts.items() - if call_time - v > self.TIMEOUT) + if call_time - v > settings.CACHE_TIMEOUT) for k in need_clearing: try: del self._cache[k]