## DS1003 Final Project: Data Collectiing from Reddit

## Authenticating with the Reddit API

This code snippet demonstrates how to authenticate and interact with the Reddit API using OAuth.

### Key Components of the Code:
1. **Authentication Setup**: 
   - Utilizes the `requests` library to manage HTTP requests.
   - The credentials, including `CLIENT_ID` (personal use script ID) and `SECRET_TOKEN` (OAuth token), are specified to perform HTTP basic authentication.

2. **Requesting an OAuth Token**:
   - The script sends a POST request containing the user credentials (username and password) to obtain an OAuth access token.
   - This token is essential for authenticating subsequent API requests.

3. **Setting Headers**:
   - A `User-Agent` header provides Reddit with a brief description of the bot application.
   - After obtaining the OAuth access token, it's added to the headers to enable authorization for further API requests.

4. **Making Authenticated Requests**:
   - Once the OAuth token is acquired, the `Authorization` header is included in all subsequent requests.
   - In this example, an authenticated request fetches the profile information of the currently authenticated user.


In [1]:
import requests

# note that CLIENT_ID refers to 'personal use script' and SECRET_TOKEN to 'token'
auth = requests.auth.HTTPBasicAuth('3gn07Oqp2_sut_zD3n8Dnw', 'WED8ZTRAfziqJPu-uugMfFGVhl-ozw')

# here we pass our login method (password), username, and password
data = {'grant_type': 'password',
        'username': 'diamonddurfhands',
        'password': '**DiamondHands22'}

# setup our header info, which gives reddit a brief description of our app
headers = {'User-Agent': 'MyBot/0.0.1'}

# send our request for an OAuth token
res = requests.post('https://www.reddit.com/api/v1/access_token',
                    auth=auth, data=data, headers=headers)

# convert response to JSON and pull access_token value
TOKEN = res.json()['access_token']

# add authorization to our headers dictionary
headers = {**headers, **{'Authorization': f"bearer {TOKEN}"}}

# while the token is valid (~2 hours) we just add headers=headers to our requests
requests.get('https://oauth.reddit.com/api/v1/me', headers=headers)

<Response [200]>

## Retrieving and Filtering Data from Reddit

This section of code demonstrates how to use the OAuth token obtained earlier to retrieve and filter posts from a specific subreddit. The filtered data is then stored in a pandas DataFrame for further analysis.

### Key Components of the Code:
1. **Query Parameters**:
   - The `params` dictionary specifies the number of posts to retrieve (`limit: 1000`).

2. **Making a GET Request**:
   - The GET request retrieves submissions that meet the specified search criteria:
     - Subreddit: `wallstreetbets`
     - Flair: `DD` (Deep Dive)
     - Sorted by newest posts.

3. **Initializing an Empty DataFrame**:
   - A pandas DataFrame is initialized to store the relevant fields from the posts.

4. **Filtering and Storing Posts**:
   - Each post retrieved is checked for the desired flair and subreddit name.
   - If it matches the criteria, a new DataFrame is created from that post's data and concatenated with the main DataFrame.


In [2]:
params = {'limit': 1000}
res = requests.get("https://oauth.reddit.com/r/wallstreetbets/search/?q=flair%3ADD&restrict_sr=on&include_over_18=on&sort=new",
                   headers=headers,
                   params=params)

import pandas as pd
df = pd.DataFrame()  # initialize dataframe

# loop through each post retrieved from GET request
for post in res.json()['data']['children']:

  # append relevant data to dataframe
  if post['data']['link_flair_richtext'] == [{'e': 'text', 't': 'DD'}] and post['data']['subreddit'] == "wallstreetbets":
    new_data = pd.DataFrame([{
        'subreddit': post['data']['subreddit'],
        'title': post['data']['title'],
        'selftext': post['data']['selftext'],
        'upvote_ratio': post['data']['upvote_ratio'],
        'ups': post['data']['ups'],
        'downs': post['data']['downs'],
        'score': post['data']['score'],
        'flair': post['data']['link_flair_richtext']
    }])
      # Concatenate with the existing DataFrame
    df = pd.concat([df, new_data], ignore_index=True)


In [4]:
#Here is a brif overview of the data collected
df
df.to_csv('wsbdata.csv')

Unnamed: 0,subreddit,title,selftext,upvote_ratio,ups,downs,score,flair
0,wallstreetbets,PG&amp;E California Power Utility ($PCG),"Richmond, California just passed a resolution ...",0.91,9,0,9,"[{'e': 'text', 't': 'DD'}]"
1,wallstreetbets,Uber position update and thoughts going forward,**Position Update**\n\nAbout a month ago I pos...,0.73,9,0,9,"[{'e': 'text', 't': 'DD'}]"
2,wallstreetbets,INTC Outlook and Marketcap vs Assets,"Why is foundry for Intel so important, despite...",0.70,9,0,9,"[{'e': 'text', 't': 'DD'}]"
3,wallstreetbets,DD - Put ProKidney on your radar,I know y'all like shouty AI-generated DD full ...,0.76,9,0,9,"[{'e': 'text', 't': 'DD'}]"
4,wallstreetbets,Robinhood Q1 Earnings - Is It A Trap?,**Summary:**\n\nI believe Robinhood is poised ...,0.80,32,0,32,"[{'e': 'text', 't': 'DD'}]"
...,...,...,...,...,...,...,...,...
94,wallstreetbets,Cannabis - not too late to get high bros,EDIT: I WAS RIGHT YOU REGARDS!!! LETS FUCKIN ...,0.84,4571,0,4571,"[{'e': 'text', 't': 'DD'}]"
95,wallstreetbets,CHK-SWN merger to create largest US natural ga...,I’ve been following Chesapeake Energy since Ja...,0.92,70,0,70,"[{'e': 'text', 't': 'DD'}]"
96,wallstreetbets,Solventum- Potentially undervalued spin-off,The following Due Diligence is pretty simple a...,0.72,13,0,13,"[{'e': 'text', 't': 'DD'}]"
97,wallstreetbets,TSMC Earthquake DD,**Like most of you guys I was initially worrie...,0.81,71,0,71,"[{'e': 'text', 't': 'DD'}]"


## Using Pushshift API to Retrieve Reddit Submissions

In this section, we access Reddit submissions via the Pushshift API. The collected data is organized into a pandas DataFrame for further analysis.

### Key Components of the Code:
1. **Pushshift API Request Function** (`getPushshiftData`):
   - This function constructs a query URL based on parameters like title (`query`), time frame (`after`, `before`), and subreddit name (`sub`).
   - The request is sent using the `requests` library, and the JSON response is parsed to obtain the submissions' data.

2. **Data Extraction Function** (`collectSubData`):
   - The function extracts key data points (title, URL, author, score, etc.) from each submission.
   - Error handling ensures that missing flair and selftext fields are marked as "NaN."
   - The extracted data is organized into a pandas DataFrame.

3. **Concatenating Submissions**:
   - The submission data for each post is merged into a global DataFrame (`subStats_df`) using `pd.concat`.
   - The global DataFrame is updated within the function to accumulate all the submissions.


In [2]:
# going to try pushshift api
import pandas as pd
import requests
import json
import csv
import time
import datetime

def getPushshiftData(query, after, before, sub):
    url = 'https://api.pushshift.io/reddit/search/submission/?title='+str(query)+'&size=1000&after='+str(after)+'&before='+str(before)+'&subreddit='+str(sub)
    print(url)
    r = requests.get(url)
    data = json.loads(r.text)
    return data['data']

#This function will be used to extract the key data points from each JSON result
def collectSubData(subm):
    try:
        flair = subm['link_flair_richtext']
    except KeyError:
        flair = "NaN"  # Handle missing flair
    try:
        selftext = subm['selftext']
    except KeyError:
        selftext = "NaN"  # Handle missing selftext

    # Create a DataFrame for the single submission
    subData = pd.DataFrame([{
        'sub_id': subm['id'],
        'title': subm['title'],
        'url': subm['url'],
        'author': subm['author'],
        'score': subm['score'],
        'created': datetime.datetime.fromtimestamp(subm['created_utc']),
        'selftext': selftext,
        'flair': flair
    }])
    # Use pd.concat to merge with the global DataFrame
    global subStats_df  # To modify the global DataFrame within the function
    subStats_df = pd.concat([subStats_df, new_data], ignore_index=True)

## Setting Up Search Parameters for the Pushshift API

In this section, the code sets up essential search parameters for querying submissions from the Pushshift API.

### Key Components of the Code:
1. **Time Range**:
   - The `after` and `before` variables represent Unix timestamps that define the time range for the search.
   - Unix timestamps can be created or converted using online tools like [unixtimestamp.com](https://www.unixtimestamp.com/index.php).

2. **Search Query**:
   - The `query` variable specifies keywords to search for within submissions (e.g., "DD" for Deep Dive).

3. **Subreddit**:
   - The `sub` variable defines the specific subreddit to query, in this case, `wallstreetbets`.

4. **Tracking Data Collection**:
   - `subCount`: Tracks the total number of submissions collected.
   - `subStats`: A dictionary used to store collected submissions data.


In [3]:
#Create your timestamps and queries for your search URL
#https://www.unixtimestamp.com/index.php > Use this to create your timestamps

after = "1672549200" #Submissions after this timestamp
before = "1641013200" #Submissions before this timestamp
query = "DD" #Keyword(s) to look for in submissions
sub = "wallstreetbets" #Which Subreddit to search in

#subCount tracks the no. of total submissions we collect
subCount = 0
#subStats is the dictionary where we will store our data.
subStats = {}

## Iteratively Collecting Reddit Submissions

In this section, we utilize the functions defined earlier to iteratively collect submissions using the Pushshift API, while adjusting the search range dynamically.

### Key Components of the Code:
1. **Initial API Call**:
   - The function `getPushshiftData` is called first with the given parameters (`after`, `before`, `query`, and `sub`) to initialize the `data` variable.

2. **While Loop for Data Collection**:
   - The loop continues until `data` is empty, indicating no more submissions in the specified range.
   - Inside the loop:
     - **Data Extraction**: Each submission in the `data` list is passed to `collectSubData` for processing and storage, and the submission count (`subCount`) is incremented.
     - **Updating the 'After' Variable**: The `after` variable is updated to the `created_utc` timestamp of the last submission retrieved. This ensures that the next API call continues where the previous left off.

3. **Fetching More Data**:
   - The API call is repeated using the updated `after` value, collecting more data within the specified range until no new data is found.


In [4]:
# We need to run this function outside the loop first to get the updated after variable
data = getPushshiftData(query, after, before, sub)
# Will run until all posts have been gathered i.e. When the length of data variable = 0
# from the 'after' date up until before date
while len(data) > 0: #The length of data is the number submissions (data[0], data[1] etc), once it hits zero (after and before vars are the same) end
    for submission in data:
        collectSubData(submission)
        subCount+=1
    # Calls getPushshiftData() with the created date of the last submission
    print(len(data))
    print(str(datetime.datetime.fromtimestamp(data[-1]['created_utc'])))
    #update after variable to last created date of submission
    after = data[-1]['created_utc']
    #data has changed due to the new after variable provided by above code
    data = getPushshiftData(query, after, before, sub)

print(len(data))

https://api.pushshift.io/reddit/search/submission/?title=DD&size=1000&after=1672549200&before=1641013200&subreddit=wallstreetbets


KeyError: 'data'

In [None]:
print(str(len(subStats)) + " submissions have added to list")
print("1st entry is:")
print(list(subStats.values())[0][0][1] + " created: " + str(list(subStats.values())[0][0][5]))
print("Last entry is:")
print(list(subStats.values())[-1][0][1] + " created: " + str(list(subStats.values())[-1][0][5]))

171 submissions have added to list
1st entry is:
watch list for Monday 23rd...do your own DD no financial advice created: 2023-01-22 22:23:03
Last entry is:
NAIL in the Coffin (light dd) created: 2022-11-04 16:35:10


In [None]:
def updateSubs_file():
    upload_count = 0
    #location = "\\Reddit Data\\" >> If you're running this outside of a notebook you'll need this to direct to a specific location
    print("wsbpushdata.csv")
    filename = input() #This asks the user what to name the file
    file = filename
    with open(file, 'w', newline='', encoding='utf-8') as file:
        a = csv.writer(file, delimiter=',')
        headers = ["Post ID","Title","Url","Author","Score","Publish Date", "Selftext", "Flair"]
        a.writerow(headers)
        for sub in subStats:
            a.writerow(subStats[sub][0])
            upload_count+=1

        print(str(upload_count) + " submissions have been uploaded")
updateSubs_file()

wsbpushdata.csv
2023datafirsttest.csv
171 submissions have been uploaded


## Installing `reddit-data-collector` for Efficient Data Collection

To simplify Reddit data collection, we utilize the `reddit-data-collector` Python package. Installing it with `pip` helps us manage Reddit API queries more efficiently.

### Key Benefits and Usage:
1. **Effortless API Management**:
   - The package abstracts complexities of Reddit's API, allowing faster data collection for different types of Reddit posts and comments.
   - This aligns well with our goal of fetching a significant amount of posts using Pushshift or the official Reddit API.

2. **Incorporating with Existing Code**:
   - Using `reddit-data-collector` can replace custom API calls and data processing logic.
   - It provides pre-built functions to specify search criteria, similar to `getPushshiftData`.

3. **Installation Command**:
   - To install the package, use:

In [8]:
!pip install reddit-data-collector

Collecting reddit-data-collector
  Obtaining dependency information for reddit-data-collector from https://files.pythonhosted.org/packages/62/9d/3994784d7609692163b50a2682dbaaaf8a00387b287a515267af1d9f6d6b/reddit_data_collector-1.1.0-py3-none-any.whl.metadata
  Downloading reddit_data_collector-1.1.0-py3-none-any.whl.metadata (4.9 kB)
Collecting praw>=7.5.0 (from reddit-data-collector)
  Obtaining dependency information for praw>=7.5.0 from https://files.pythonhosted.org/packages/81/6a/21bc058bcccbe03f6a0895bf1bd60c805f0c526aa4e9bfaac775ed0b299c/praw-7.7.1-py3-none-any.whl.metadata
  Downloading praw-7.7.1-py3-none-any.whl.metadata (9.8 kB)
Collecting prawcore<3,>=2.1 (from praw>=7.5.0->reddit-data-collector)
  Obtaining dependency information for prawcore<3,>=2.1 from https://files.pythonhosted.org/packages/96/5c/8af904314e42d5401afcfaff69940dc448e974f80f7aa39b241a4fbf0cf1/prawcore-2.4.0-py3-none-any.whl.metadata
  Downloading prawcore-2.4.0-py3-none-any.whl.metadata (5.0 kB)
Collecti

## Setting Up Reddit API Credentials

Here, we store sensitive Reddit API credentials and user details as variables to authenticate requests to Reddit's API.

### Components Explained:

1. **Credentials**:
   - `client_id`: Unique identifier assigned to the application via Reddit's developer dashboard.
   - `client_secret`: Secret token used alongside `client_id` for secure API access.
   - `user_agent`: Custom identifier string that describes the application. This should include a name and version, like `"myapp/0.1"`.

2. **User Credentials**:
   - `username`: Reddit account username, necessary for authenticated access to private data or posting.
   - `password`: Corresponding password to authenticate the Reddit account.


In [9]:
import json
client_id = "3gn07Oqp2_sut_zD3n8Dnw"
client_secret = "WED8ZTRAfziqJPu-uugMfFGVhl-ozw"
user_agent = "ddhandsbot"
username = "diamonddurfhands"
password = "**DiamondHands22"

### Overview of Reddit Data Collection

This section describes the data collection process using a Reddit Data Collector class. The purpose is to gather structured data about posts and comments from specified subreddits on Reddit.

**Key Components:**

1. **`DataCollector` Class**  
   - **Initialization**: Requires Reddit API credentials (client ID, secret, and user agent) and optionally the username and password to access the Reddit API.
   - **Subreddit Verification**: The class checks if specified subreddits exist before attempting to collect data, raising a `SubredditError` if not.
   - **Post Filters**: Validates the filtering method for posts (e.g., "hot", "new", "top") and raises a `FilterError` if an invalid filter is used.
   - **Top Post Filters**: Similar validation ensures the "top" filter only uses valid parameters (like "day," "week," etc.).

2. **Data Collection Methods**  
   - **Post Data**: Retrieves posts based on the specified filter method (e.g., "new," "hot," "top"). Each post's attributes include title, score, URL, and other metadata.
   - **Comment Data**: Optionally gathers comments and replies to each post if enabled. The depth of data collected can be configured via parameters.
   - **Output Format**: Data can be returned as either pandas DataFrames or Python dictionaries for posts and comments.

3. **Helper Functions**  
   - Subreddit verification and filter validation functions assist with error handling.
   - Post and comment retrieval functions are modular, streamlining data collection based on various criteria.

The class ensures that users can efficiently collect data from Reddit by handling errors gracefully, maximizing API usage, and providing flexibility in data organization and structure.


In [10]:
import praw
import pandas as pd
from tqdm import tqdm
def to_pandas(subreddit_data, separate=False):
    """Convert raw post or comment data collected to a pandas `DataFrame`.
    Parameters
    ----------
    subreddit_data : dict
        Raw post or comment data collected with the `DataCollector.get_data`
        method.
    separate : bool, default=False
        Whether or not to return a separate pandas `DataFrame` for the
        data of each subreddit.
    Returns
    -------
    df or dfs : pd.DataFrame or dict
        If separate is `False`, returns a pandas `DataFrame` containing
        the post or comment data.
        If separate is `True` returns a Python dictionary containing
        a pandas `DataFrame` for each subreddit that existed in the
        post or comment data.  The dctionary keys are the subreddits
        names.  The dictionary values are pandas `DataFrame`s of post
        or comment data.
    See Also
    --------
    reddit_data_collector.reddit_data_collector.DataCollector
        Class that performs the data collection from Reddit.
    reddit_data_collector.io.update_data
        Update a `.csv` file containing existing post or comment
        data collected with new data collected with `DataCollector`.
    """
    dfs = dict()

    for subreddit, data in subreddit_data.items():
        dfs[subreddit] = pd.DataFrame(data)

    if separate:
        return dfs
    else:
        return pd.concat(dfs.values(), ignore_index=True)


def update_data(csv_path, df, key="id", sort="subreddit_name", save=True):
    """Update a `.csv` file containing post or comment data with new data.
    The main purpose of this method is to allow a user to update a `.csv`
    file that contains historical data that they collected with Reddit Data
    Collector with new data collected.  The default method settings ensure
    that duplicated post or comment data, if any, is not saved to the `.csv`
    file.  In other words only one copy of each post or comment is kept in
    the combined data.
    If the `save` parameter is set to `True` then the method will automatically
    overwrite the existing `.csv` file.  Otherwise it will just return the
    combined data to the user as a pandas `DataFrame` for which they can
    then save with the pandas `DataFrame.to_csv` method when desired.
    Parameters
    ----------
    csv_path : str
        The file path to the existing `.csv` file.
    df : pandas DataFrame
        A pandas `DataFrame` containing the new data collected.  It is
        recommended that this `DataFrame` comes from the output of the
        `to_pandas` method in `reddit_data_collector.io`.
    key : str, default="id"
        The key to remove duplicate data on.  Default is the post or
        comment `id` as set by Reddit.  It is not recommended to set
        this parameter manually.  However, it is included as a parameter
        in case for some reason duplicate data is desired.
    sort : str, default="subreddit_name"
        How to sort the new data.  By default sorts the data by subreddit.
        This is purely aesthetic and has no impact on the data itself.
    save : bool, default=True
        Whether or not to automatically overwrite the existing `.csv`
        file with the new data.
    Returns
    -------
    new_df : pandas DataFrame
        A pandas `DataFrame` containing the newly combined post or comment
        data.
    Raises
    ------
    ColumnNameError
        If the update is attempted with two pandas `DataFrame`s that have
        different column names.
    See Also
    --------
    reddit_data_collector.reddit_data_collector.DataCollector
        Class that performs the data collection from Reddit.
    reddit_data_collector.io.to_pandas
        Used to convert raw `posts` or `comments` collected with
        `DataCollector` to a pandas `DataFrame`.
    Examples
    --------
    >>> import reddit_data_collector as rdc
    >>> # create instance of DataCollector
    >>> data_collector = rdc.DataCollector(
    ...     "<your_client_id>",
    ...     "<your_client_secret>",
    ...     "mac:script:v1.0 (by u/FakeRedditUser)",
    ...     "FakeRedditUser",
    ...     "FakePassword"
    ... )
    >>> # collect some data from Reddit
    >>> subreddits = ["pics", "funny"]
    >>> post_filter = "hot"
    >>> comment_data = True
    >>> replies_data = True
    >>> posts, comments = data_collector(
    ...     subreddits=subreddits,
    ...     post_filter=post_filter,
    ...     comment_data=comment_data,
    ...     replies_data=replies_data
    ... )
    >>> # update existing .csv file
    >>> new_posts_df = rdc.update_data("post_data.csv", posts_df)
    >>> new_comments_df = rdc.update_data("comment_data.csv", comments_df)
    """

    if not set(pd.read_csv(csv_path).columns) == set(df.columns):
        raise ColumnNameError("Both data sets must have the same features")

    old_df = pd.read_csv(csv_path)

    new_df = (
        pd.concat([old_df, df], ignore_index=True)
        .drop_duplicates(subset=[key], ignore_index=True)
        .sort_values(sort, ignore_index=True)
    )

    if save:
        new_df.to_csv(csv_path, index=False)

    return new_df
class SubredditError(Exception):
    """Exception class raised if invalid subreddit is used.
    Examples
    --------
    >>> from reddit_data_collector.exceptions import SubredditError
    >>> try:
    ...     data_collector.get_data(subreddits="1nv4ald")
    ... except SubredditError as e:
    ...     print(repr(e))
    SubredditError('r/1nv4ald does not exist')
    """

    pass


class FilterError(Exception):
    """Exception class raised if an invalid post or top post filter is used.
    Examples
    --------
    >>> from reddit_data_collector.exceptions import FilterError
    >>> try:
    ...     data_collector.get_data(subreddits="funny", post_filter="any")
    ... except FilterError as e:
    ...     print(repr(e))
    FilterError('Invalid post_filter used: any')
    >>> try:
    ...     data_collector.get_data(
    ...        subreddits="funny",
    ...        post_filter="top",
    ...        top_post_filter="now"
    ...     )
    ... except FilterError as e:
    ...     print(repr(e))
    FilterError('Invalid top_post_filter used: now')
    """

    pass


class ColumnNameError(Exception):
    """Exception class used if data update is attempted with mismatched columns.
    Examples
    --------
    >>> import pandas as pd
    >>> from reddit_data_collector.exceptions import ColumnNameError
    >>> csv_path = "example.csv"
    >>> # create and save first DataFrame
    >>> df = pd.DataFrame(data=[[1, 2], [3, 4]], columns=["a", "b"])
    >>> df.to_csv(path, index=False)
    >>> # create second DataFrame
    >>> df2 = pd.DataFrame(data=[[5, 6], [7, 8]], columns=["c", "d"])
    >>> try:
    ...    rdc.update_data(csv_path, df2)
    ... except ColumnNameError as e:
    ...    print(repr(e))
    ColumnNameError('Both data sets must have the same features')
    """

    pass

class DataCollector:
    """Object that performs data collection from Reddit.
    Once a `DataCollector` object is instantiated, you simply need to pass the subreddit
    name(s) that you desire to collect data from to the method `get_data`, and the data
    collection will be performed.
    Please see the Reddit's "First Step Guide" which describes how to obtain the
    `client_id` and `client_secret` parameters below:
    https://github.com/reddit-archive/reddit/wiki/OAuth2-Quick-Start-Example#first-steps
    Important: If you instantiate `DataCollector` without a Reddit username and password,
    it will have read only access to the reddit API, which is limited to 30 requests
    per minute.  However, if you do provide a Reddit username and password, it will
    have full access to the API and an increased limit of 60 requests per minute.  Full
    access can increase data collection by 2x.
    Finally, for safety, it is recommended that the parameters below are not hard-coded
    directly into a program that uses Reddit Data Collector.  Rather, they should be
    kept in a separate credentials file as data which is then read into the program.
    (e.g. a JSON credentials file that is read into a program with a Python `with`
    clause).
    Parameters
    ----------
    client_id : str
        The client id for your Reddit application.
    client_secret : str
        The client secret for your Reddit application.
    user_agent : str
        A unique identifier that helps Reddit determine the souce of network requests.
        To use Reddit's API, you need a unique and descriptive user agent.  The
        following format is recommended:
            <platform>:<app ID>:<version string> (by u/<Reddit username>)
    username : str, default=None
        Your Reddit username.
    password : str, default=None
        Your Reddit password.
    Attributes
    ----------
    reddit : praw.Reddit
        An instance of the PRAW `Reddit` class that provides access to Reddit's API.
    Examples
    --------
    >>> import reddit_data_collector as rdc
    >>> # create instance of DataCollector
    >>> data_collector = rdc.DataCollector(
    ...     "<your_client_id>",
    ...     "<your_client_secret>",
    ...     "mac:script:v1.0 (by u/FakeRedditUser)",
    ...     "FakeRedditUser",
    ...     "FakePassword"
    ... )
    >>> # collect some data from Reddit
    >>> posts, comments = data_collector.get_data(
    ...     subreddits=["pics", "funny"],
    ...     post_filter="hot",
    ...     post_limit=10,
    ...     comment_data=True,
    ...     replies_data=True,
    ...     replace_more_limit=0
    ... )
    >>> # save data as .csv files
    >>> posts.to_csv("posts.csv", index=False)
    >>> comments.to_csv("posts.csv", index=False)
    """

    def __init__(
        self, client_id, client_secret, user_agent, username=None, password=None
    ):
        self.reddit = praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
            username=username,
            password=password,
        )

    def get_data(
        self,
        subreddits,
        post_filter="new",
        post_limit=None,
        top_post_filter=None,
        comment_data=True,
        replies_data=False,
        replace_more_limit=0,
        dataframe=True,
    ):
        """Collects post and comment data from Reddit.
        Parameters
        ----------
        subreddits : str or list of str
            The subreddit(s) to collect post and comment data from.
        post_filter : str, default="new"
            How to filter the subreddit posts.  Must be one of:  new, hot, or top.
        post_limit : int, default=None
            The number of posts to collect.  A limit of `None` sets the limit to
            the max allowed by Reddit's API, which is 1,000 in most cases.
        top_post_filter : str, default=None
            Determines how to filter the top posts for a subreddit.  Only required
            if `post_filter` is set to "top". Must be one of: all, day, hour, month,
            week, or year.
        comment_data : bool, default=True
            Whether or not to collect comment data for each post that is collected.
            If only post data is desired, set to `False`.  Only collecting posts can
            significantly speed up data collection since it will likely reduce the
            number of API requests by a lot.
        replies_data : bool, default=False
            Whether or not to collect the data for all replies to each comment that
            is collected for each post.  Setting this to `True` can cause the script
            to take arbitrarily long, as some reddit comments can have arbitrarily
            long reply threads.  Think carefully if you actually need this data before
            setting this parameter to True.  Often times, reply threads will contain
            useless data, since they often contain discussions of people trolling one
            another.
        replace_more_limit : int, default=0
            The number of PRAW `MoreComment` instances to replace when collecting
            comment data.  If you don't know what this means, the recommended
            setting is a value between 0 to 32.  A higher value means that
            potentially more comments will be collected in a sample. You can also
            set this to `None` which will ensure all comments and replies on a
            single post are collected.  Note that, setting this to an integer value
            higher than 32 or to `None` can significantly slow down the script,
            since this can increase the number of API calls drastically.
            For more info on the PRAW `MoreComment` class read this:
            https://praw.readthedocs.io/en/stable/tutorials/comments.html
        dataframe : bool, default=True
            Whether or not to return the collected data as a pandas DataFrame.
            If False, the data is returned in the raw form of a dictionary,
            where the keys for each dictionary are the subreddit name(s) and
            the values for each dictionary are the data collected.
        Returns
        -------
        posts, comments : pandas DataFrames
            The collected reddit data.
            If `comment_data` is False, `None` is returned for `comments`.
        See Also
        --------
        reddit_data_collector.io.to_pandas
            Used to convert raw `posts` or `comments` to a pandas `DataFrame`.
            Not needed if dataframe argument is left as True.
        reddit_data_collector.io.update_data
            Used to update an existing `.csv` file that contains prior data collected
            with Reddit Data Collector with new data collected.
        """
        if isinstance(subreddits, str):
            subreddits = [subreddits]

        self._verify_subreddits(subreddits)
        self._verify_post_filter(post_filter)

        if top_post_filter is not None:
            self._verify_top_post_filter(top_post_filter)

        posts = self._get_posts(subreddits, post_filter, post_limit, top_post_filter)

        if comment_data:
            comments = self._get_comments(posts, replies_data, replace_more_limit)
        else:
            comments = None

        if dataframe:
            posts = to_pandas(posts)

            if comments is not None:
                comments = to_pandas(comments)

        return posts, comments

    # ------------------------------HELPER FUNCTIONS------------------------------ #

    def _verify_subreddits(self, subreddits):
        """Verifies that each subreddit in a list of subreddits exist."""
        for subreddit in subreddits:
            if not self._check_subreddit_exists(subreddit):
                msg = f"r/{subreddit} does not exist"
                raise (SubredditError(msg))

    def _check_subreddit_exists(self, subreddit):
        """Checks if a subreddit exists."""
        # PRAW Subreddits instance
        subreddits = self.reddit.subreddits

        # may return numerous similar subreddits, first value should match
        exists = subreddits.search_by_name(subreddit)

        if not exists:
            return False
        else:
            return exists[0].display_name.lower() == subreddit.lower()

    def _verify_post_filter(self, post_filter):
        """Verifies that a post filter is valid.
        Raises FilterError if a invalid post filter is used.
        """
        if post_filter.lower() not in ["new", "hot", "top"]:
            msg = f"Invalid post_filter used: {post_filter}"
            raise (FilterError(msg))

    def _verify_top_post_filter(self, top_post_filter):
        """Verifies that a top post filter is valid.
        Raises FilterError if a invalid top post filter is used.
        """
        if top_post_filter.lower() not in [
            None,
            "all",
            "day",
            "hour",
            "month",
            "week",
            "year",
        ]:
            msg = f"Invalid top_post_filter used: {top_post_filter}"
            raise (FilterError(msg))

    def _get_posts(self, subreddits, post_filter, post_limit, top_post_filter):
        """Collects the post data for each subreddit in a list of subreddits."""
        posts = dict()

        for subreddit in subreddits:
            posts[subreddit] = self._get_subreddit_posts(
                subreddit, post_filter, post_limit, top_post_filter
            )

        return posts

    def _get_subreddit_posts(self, subreddit, post_filter, post_limit, top_post_filter):
        """Collects the post data for a single subreddit."""
        subreddit_posts = []

        # convert to PRAW Subreddit instance
        subreddit = self.reddit.subreddit(subreddit)

        desc = f"Collecting {post_filter} r/{subreddit} posts"

        # a "submission" is an instance of the PRAW Subission class
        if post_filter.lower() == "new":
            for submission in tqdm(subreddit.new(limit=post_limit), desc, post_limit):
                subreddit_posts.append(self._get_post_data(submission))

        elif post_filter.lower() == "hot":
            for submission in tqdm(subreddit.hot(limit=post_limit), desc, post_limit):
                subreddit_posts.append(self._get_post_data(submission))

        elif post_filter.lower() == "top":
            for submission in tqdm(subreddit.top(time_filter=top_post_filter), desc):
                subreddit_posts.append(self._get_post_data(submission))

        return subreddit_posts

    def _get_post_data(self, submission):
        """Collects the data for a single post in a subreddit."""
        post_data = {
            "subreddit_name": submission.subreddit.display_name,
            "post_created_utc": submission.created_utc,
            "id": submission.id,
            "selftext": submission.selftext,
            "is_original_content": submission.is_original_content,
            "is_self": submission.is_self,
            "link_flair_text": submission.link_flair_text,
            "locked": submission.locked,
            "num_comments": submission.num_comments,
            "over_18": submission.over_18,
            "score": submission.score,
            "spoiler": submission.spoiler,
            "stickied": submission.stickied,
            "title": submission.title,
            "upvote_ratio": submission.upvote_ratio,
            "url": submission.url,
        }

        return post_data

    def _get_comments(self, posts, replies_data, replace_more_limit):
        """Collects the comment data for each subreddit in a list of subreddits."""
        comments = dict()

        for subreddit, subreddit_post_data in posts.items():
            comments[subreddit] = self._get_subreddit_comments(
                subreddit, subreddit_post_data, replies_data, replace_more_limit
            )

        return comments

    def _get_subreddit_comments(
        self, subreddit, subreddit_post_data, replies_data, replace_more_limit
    ):
        """Collects the comment data for posts in a single subreddit."""
        subreddit_comments = []

        desc = f"Collecting comments for {len(subreddit_post_data)} r/{subreddit} posts"

        # a "submission" is an instance of the PRAW Subission class
        for post in tqdm(subreddit_post_data, desc, len(subreddit_post_data)):
            submission = self.reddit.submission(id=post["id"])
            submission.comments.replace_more(limit=replace_more_limit)

            if replies_data:
                for comment in submission.comments.list():
                    comment_data = self._get_comment_data(subreddit, comment)
                    subreddit_comments.append(comment_data)
            else:
                for comment in submission.comments:
                    comment_data = self._get_comment_data(subreddit, comment)
                    subreddit_comments.append(comment_data)

        return subreddit_comments

    def _get_comment_data(self, subreddit, comment):
        """Collects the data for a single comment on a subreddit post."""
        comment_data = {
            "subreddit_name": subreddit,
            "id": comment.id,
            "post_id": comment.link_id,
            "parent_id": comment.parent_id,
            "top_level_comment": comment.parent_id == comment.link_id,
            "body": comment.body,
            "comment_created_utc": comment.created_utc,
            "is_submitter": comment.is_submitter,
            "score": comment.score,
            "stickied": comment.stickied,
        }

        return comment_data

In [11]:
#import reddit_data_collector as rdc
data_collector = DataCollector(
    client_id=client_id,
    client_secret=client_secret,
    user_agent=user_agent,
    username=username,
    password=password
)

In [12]:
posts, comments = data_collector.get_data(
    subreddits=["wallstreetbets"],
    post_filter="new",
    post_limit=2000,
    comment_data=False,
    replies_data=False,
    replace_more_limit=1000,
    dataframe=True
)

Collecting new r/wallstreetbets posts:  48%|▍| 963/2000 [00:13<00:14, 69.31it/s]


In [None]:
posts.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 866 entries, 0 to 865
Data columns (total 16 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   subreddit_name       866 non-null    object 
 1   post_created_utc     866 non-null    float64
 2   id                   866 non-null    object 
 3   selftext             866 non-null    object 
 4   is_original_content  866 non-null    bool   
 5   is_self              866 non-null    bool   
 6   link_flair_text      863 non-null    object 
 7   locked               866 non-null    bool   
 8   num_comments         866 non-null    int64  
 9   over_18              866 non-null    bool   
 10  score                866 non-null    int64  
 11  spoiler              866 non-null    bool   
 12  stickied             866 non-null    bool   
 13  title                866 non-null    object 
 14  upvote_ratio         866 non-null    float64
 15  url                  866 non-null    obj

In [None]:
posts.to_csv("wsbnewdata.csv", index=False)