<a href="https://colab.research.google.com/github/xjdeng/subreddit-ai-query/blob/main/ai_query_subreddit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install praw path.py==12.0.1

In [None]:
import praw
from google.colab import userdata
import time
import google.generativeai as genai
import pprint
import json
from path import Path as path


In [None]:
def _require_secret(name):
    """Return the Colab secret ``name``, failing loudly if it is missing or empty.

    An explicit ``raise`` is used instead of ``assert`` so the check still runs
    when Python is started with ``-O`` (which strips assert statements).
    """
    value = userdata.get(name)
    if not value:  # rejects both None and the empty string
        raise ValueError(
            f"Colab secret {name!r} is missing or empty; add it in the Colab Secrets panel."
        )
    return value


# Set up Reddit credentials, see video: https://www.youtube.com/watch?v=VAJFZEeKjSY
client_id = _require_secret("reddit_client_id")
client_secret = _require_secret("reddit_client_secret")
username = _require_secret("reddit_username")
password = _require_secret("reddit_password")
# Passed to praw as user_agent, so it is validated like the other secrets.
app_name = _require_secret("reddit_app")
reddit = praw.Reddit(
    client_id=client_id,
    client_secret=client_secret,
    user_agent=app_name,
    username=username,
    password=password,
    # Notebooks run an event loop; this disables PRAW's async-environment check.
    check_for_async=False,
)

# Set up Gemini credentials, see video: https://www.youtube.com/watch?v=S1elvCs1gyI
GOOGLE_API_KEY = _require_secret('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel("gemini-1.5-flash-latest")



In [None]:
def get_posts(sub, subreddit_query=None, LIMIT=100, time_filter="year", score_cutoff=5, comments_cutoff=5, client=None):
    """Fetch candidate posts from a subreddit and keep only well-engaged ones.

    Args:
        sub: Subreddit name, without the ``r/`` prefix.
        subreddit_query: Optional search string. When given, posts come from a
            relevance-sorted search; otherwise from the union of the hot, top
            and new listings.
        LIMIT: Maximum number of posts requested from each listing (or from
            the search).
        time_filter: Time window used by the search and the top listing.
        score_cutoff: Minimum post score (inclusive).
        comments_cutoff: Minimum ``len(post.comments)`` (inclusive, matching
            ``score_cutoff`` and the ``min_comments`` name used by callers).
        client: Reddit client exposing ``.subreddit(name)``; defaults to the
            module-level ``reddit`` instance.

    Returns:
        List of posts passing both filters, in listing order; when the
        hot/top/new listings are merged, a post appearing in several of them
        is kept once (first occurrence).
    """
    source = reddit if client is None else client
    subreddit = source.subreddit(sub)

    if subreddit_query:
        # Perform a search in the subreddit
        candidates = list(subreddit.search(subreddit_query, sort='relevance', time_filter=time_filter, limit=LIMIT))
    else:
        # Default behavior: Fetch hot, top, and new posts
        hot = list(subreddit.hot(limit=LIMIT))
        top = list(subreddit.top(time_filter=time_filter, limit=LIMIT))
        recent = list(subreddit.new(limit=LIMIT))
        # dict.fromkeys drops duplicates with the same hash/eq semantics as
        # set(), but keeps first-seen order. A set would yield hash order,
        # which for string-hashed objects can differ from run to run.
        candidates = list(dict.fromkeys(hot + top + recent))

    # Score is tested first so len(post.comments), which may trigger a
    # network fetch in PRAW, is only evaluated for posts that clear it.
    return [
        post for post in candidates
        if post.score >= score_cutoff and len(post.comments) >= comments_cutoff
    ]

def get_comments(post, LIMIT=50):
    """Return the text of the first ``LIMIT`` entries of ``post.comments``.

    ``replace_more(limit=0)`` is invoked first so that "load more comments"
    placeholders are not expanded (per PRAW docs, limit=0 removes them),
    keeping the number of extra requests down.
    """
    forest = post.comments
    forest.replace_more(limit=0)
    head = forest[:LIMIT]
    bodies = []
    for comment in head:
        bodies.append(comment.body)
    return bodies

def pipeline(subreddit_name, subreddit_query=None, postlimit=100, post_score_cutoff=5, min_comments=5, max_comments=50, time_filter="year", delay=2):
    """Download posts and their comments from a subreddit as plain dicts.

    Args:
        subreddit_name: Subreddit name, without the ``r/`` prefix.
        subreddit_query: Optional search string forwarded to ``get_posts``.
        postlimit: Per-listing post limit forwarded to ``get_posts``.
        post_score_cutoff: Minimum post score forwarded to ``get_posts``.
        min_comments: Comment-count threshold forwarded to ``get_posts``.
        max_comments: Maximum comments kept per post (via ``get_comments``).
        time_filter: Time window forwarded to ``get_posts`` (default "year",
            the value ``get_posts`` previously always received).
        delay: Seconds to pause between consecutive posts. No pause is taken
            after the last post.

    Returns:
        List of dicts with keys ``'title'``, ``'body'``, ``'score'`` and
        ``'comments'`` (a list of comment bodies).
    """
    posts = get_posts(
        subreddit_name,
        subreddit_query=subreddit_query,
        LIMIT=postlimit,
        time_filter=time_filter,
        score_cutoff=post_score_cutoff,
        comments_cutoff=min_comments,
    )
    data = []
    for index, post in enumerate(posts):
        # Throttle between posts rather than after each one, so the final
        # post does not incur a pointless wait.
        # NOTE(review): get_posts already evaluates len(post.comments), which
        # in PRAW likely fetches the comment tree, so this pause may throttle
        # little — confirm before tuning.
        if index:
            time.sleep(delay)
        comments = get_comments(post, LIMIT=max_comments)
        data.append({
            'title': post.title,
            'body': post.selftext,
            'score': post.score,
            'comments': comments,
        })
    return data


In [None]:
def query_subreddit(query, subreddit, subreddit_query = None, *args, **kwargs):
    """Answer ``query`` with the Gemini model, grounded in data scraped from a subreddit.

    Scraped data is cached in the working directory as
    ``<subreddit>.jsonl`` or ``<subreddit>-<subreddit_query>.jsonl``; when the
    file exists it is reused instead of calling ``pipeline`` again.

    Args:
        query: The question to ask.
        subreddit: Subreddit name, without the ``r/`` prefix.
        subreddit_query: Optional search string passed to ``pipeline``; also
            part of the cache file name.
        *args, **kwargs: Forwarded to ``pipeline`` (e.g. ``postlimit``,
            ``max_comments``).

    Returns:
        The ``text`` attribute of the model's response.
    """
    if subreddit_query:
        # A "/" in the search text would act as a directory separator and make
        # open() fail; other characters are kept so existing cache names match.
        safe_query = subreddit_query.replace("/", "_")
        jsonfile = f"{subreddit}-{safe_query}.jsonl"
    else:
        jsonfile = f"{subreddit}.jsonl"
    # Despite the .jsonl suffix, the file holds a single JSON array (not JSON
    # Lines); the name is kept so previously written caches are still found.
    # NOTE(review): the cache key ignores *args/**kwargs, so changing e.g.
    # postlimit silently reuses an older download — delete the file to refresh.
    try:
        with open(jsonfile, encoding="utf-8") as fh:
            data = json.load(fh)
        cached = True
    except FileNotFoundError:
        cached = False
    if not cached:
        data = pipeline(subreddit, subreddit_query, *args, **kwargs)
        # Serialize before opening the file so a failure here cannot leave an
        # empty/truncated cache that would break json.load on the next run.
        payload = json.dumps(data)
        with open(jsonfile, "w", encoding="utf-8") as fh:
            fh.write(payload)
    prompt = f"""

  I'd like to ask a question to the following subreddit: /r/{subreddit}

  Here's the question:
  ---
  {query}
  ---

  Do not download data from Internet, instead, formulate your answer using the following data that I've downloaded from the subreddit:
  ---
  {data}
  ---
  """
    response = model.generate_content(prompt)
    return response.text


In [None]:
query = """
What are examples of good jobs for INTPs?
"""

pprint.pprint(query_subreddit(query,"INTP"))