Skip to content

Commit

Permalink
Merge pull request #1211 from PokestarFan/add-types-collections
Browse files Browse the repository at this point in the history
(#1164) Add types to reddit/collections.py
  • Loading branch information
bboe committed Dec 29, 2019
2 parents bda2d07 + 63d9a49 commit 8d10ac4
Showing 1 changed file with 42 additions and 17 deletions.
59 changes: 42 additions & 17 deletions praw/models/reddit/collections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Provide Collections functionality."""
from typing import Any, Dict, Generator, List, Optional, TypeVar, Union

from ...const import API_PATH
from ...exceptions import ClientException
Expand All @@ -8,6 +9,10 @@
from .submission import Submission
from .subreddit import Subreddit

Reddit = TypeVar("Reddit")
_CollectionModeration = TypeVar("_CollectionModeration")
_SubredditCollectionsModeration = TypeVar("_SubredditCollectionsModeration")


class Collection(RedditBase):
"""Class to represent a Collection.
Expand Down Expand Up @@ -57,7 +62,7 @@ class Collection(RedditBase):
STR_FIELD = "collection_id"

@cachedproperty
def mod(self):
def mod(self) -> _CollectionModeration:
"""Get an instance of :class:`.CollectionModeration`.
Provides access to various methods, including
Expand All @@ -77,11 +82,17 @@ def mod(self):
return CollectionModeration(self._reddit, self.collection_id)

@cachedproperty
def subreddit(self):
def subreddit(self) -> Subreddit:
"""Get the subreddit that this collection belongs to."""
return next(self._reddit.info([self.subreddit_id]))

def __init__(self, reddit, _data=None, collection_id=None, permalink=None):
def __init__(
self,
reddit: Reddit,
_data: Dict[str, Any] = None,
collection_id: Optional[str] = None,
permalink: Optional[str] = None,
):
"""Initialize this collection.
:param reddit: An instance of :class:`.Reddit`.
Expand All @@ -108,7 +119,7 @@ def __init__(self, reddit, _data=None, collection_id=None, permalink=None):
"include_links": True,
}

def __iter__(self):
def __iter__(self) -> Generator[Any, None, None]:
"""Provide a way to iterate over the posts in this Collection.
Example usage:
Expand All @@ -123,7 +134,7 @@ def __iter__(self):
for item in self.sorted_links:
yield item

def __len__(self):
def __len__(self) -> int:
"""Get the number of posts in this Collection.
Example usage:
Expand All @@ -136,7 +147,7 @@ def __len__(self):
"""
return len(self.link_ids)

def __setattr__(self, attribute, value):
def __setattr__(self, attribute: str, value: Any):
"""Objectify author, subreddit, and sorted_links attributes."""
if attribute == "author_name":
self.author = self._reddit.redditor(value)
Expand Down Expand Up @@ -235,15 +246,15 @@ def _post_fullname(self, post):
except ClientException:
return self._reddit.submission(id=post).fullname

def __init__(self, reddit, collection_id):
def __init__(self, reddit: Reddit, collection_id: str):
"""Initialize an instance of CollectionModeration.
:param collection_id: The ID of a collection.
"""
super().__init__(reddit, _data=None)
self.collection_id = collection_id

def add_post(self, submission):
def add_post(self, submission: Submission):
"""Add a post to the collection.
:param submission: The post to add, a :class:`.Submission`, its
Expand Down Expand Up @@ -287,7 +298,7 @@ def delete(self):
data={"collection_id": self.collection_id},
)

def remove_post(self, submission):
def remove_post(self, submission: Submission):
"""Remove a post from the collection.
:param submission: The post to remove, a :class:`.Submission`, its
Expand All @@ -314,7 +325,7 @@ def remove_post(self, submission):
},
)

def reorder(self, links):
def reorder(self, links: List[Union[str, Submission]]):
"""Reorder posts in the collection.
:param links: A ``list`` of submissions, as :class:`.Submission`,
Expand All @@ -336,7 +347,7 @@ def reorder(self, links):
data={"collection_id": self.collection_id, "link_ids": link_ids},
)

def update_description(self, description):
def update_description(self, description: str):
"""Update the collection's description.
:param description: The new description.
Expand All @@ -359,7 +370,7 @@ def update_description(self, description):
},
)

def update_title(self, title):
def update_title(self, title: str):
"""Update the collection's title.
:param title: The new title.
Expand Down Expand Up @@ -392,7 +403,7 @@ class SubredditCollections(PRAWBase):
"""

@cachedproperty
def mod(self):
def mod(self) -> _SubredditCollectionsModeration:
"""Get an instance of :class:`.SubredditCollectionsModeration`.
Provides :meth:`~SubredditCollectionsModeration.create`:
Expand All @@ -407,7 +418,11 @@ def mod(self):
self._reddit, self.subreddit.fullname
)

def __call__(self, collection_id=None, permalink=None):
def __call__(
self,
collection_id: Optional[str] = None,
permalink: Optional[str] = None,
):
"""Return the :class:`.Collection` with the specified ID.
:param collection_id: The ID of a Collection (default: None).
Expand Down Expand Up @@ -442,7 +457,12 @@ def __call__(self, collection_id=None, permalink=None):
self._reddit, collection_id=collection_id, permalink=permalink
)

def __init__(self, reddit, subreddit, _data=None):
def __init__(
self,
reddit: Reddit,
subreddit: Subreddit,
_data: Optional[Dict[str, Any]] = None,
):
"""Initialize an instance of SubredditCollections."""
super().__init__(reddit, _data)
self.subreddit = subreddit
Expand Down Expand Up @@ -477,12 +497,17 @@ class SubredditCollectionsModeration(PRAWBase):
"""

def __init__(self, reddit, sub_fullname, _data=None):
def __init__(
self,
reddit: Reddit,
sub_fullname: str,
_data: Optional[Dict[str, Any]] = None,
):
"""Initialize the SubredditCollectionsModeration instance."""
super().__init__(reddit, _data)
self.subreddit_fullname = sub_fullname

def create(self, title, description):
def create(self, title: str, description: str):
"""Create a new :class:`.Collection`.
The authenticated account must have appropriate moderator
Expand Down

0 comments on commit 8d10ac4

Please sign in to comment.