<a href="https://colab.research.google.com/github/xjdeng/mbtimodel/blob/main/mbti_model_reddit_gemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/xjdeng/mbtimodel
!pip install praw google-generativeai
!pip install -r mbtimodel/requirements.txt

In [None]:
import joblib
import praw
from google.colab import userdata
import bs4
import markdown
import re
import pandas as pd
import praw
from prawcore.exceptions import NotFound
import string
import google.generativeai as genai
import enum
from typing_extensions import TypedDict
import pprint

In [None]:

# IMPORTANT: Set up your Reddit AND your Gemini credentials and enter them into your Google Colab Secrets.
# See the videos for instructions. For Reddit: https://www.youtube.com/watch?v=VAJFZEeKjSY
# For Gemini: https://www.youtube.com/watch?v=S1elvCs1gyI

# Gemini API key, read from the Colab secret named "GOOGLE_API_KEY".
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
assert GOOGLE_API_KEY is not None, "Missing Colab secret: GOOGLE_API_KEY"
genai.configure(api_key=GOOGLE_API_KEY)

# Reddit credentials, each read from its own Colab secret. Every value is
# checked right after it is read so a missing secret fails here, with its
# name in the message, rather than later inside the Reddit client.
client_id = userdata.get("reddit_client_id")
assert client_id is not None, "Missing Colab secret: reddit_client_id"
client_secret = userdata.get("reddit_client_secret")
assert client_secret is not None, "Missing Colab secret: reddit_client_secret"
username = userdata.get("reddit_username")
assert username is not None, "Missing Colab secret: reddit_username"
password = userdata.get("reddit_password")
assert password is not None, "Missing Colab secret: reddit_password"
# The app name is passed to praw as the user agent, so it is validated like
# the other credentials.
app_name = userdata.get("reddit_app")
assert app_name is not None, "Missing Colab secret: reddit_app"

# Reddit client shared by the functions below.
reddit = praw.Reddit(
    client_id=client_id,
    client_secret=client_secret,
    user_agent=app_name,
    username=username,
    password=password,
)

# Gemini model shared by the functions below.
model = genai.GenerativeModel("gemini-1.5-flash-latest")

In [None]:
# The sixteen four-letter MBTI type codes. Each member's value is the same
# string as its name. Used as the type of the `personality` field of the
# `Personality` response schema.
class MBTI(enum.Enum):
    # Introverted types.
    ISTJ = "ISTJ"
    ISFJ = "ISFJ"
    INFJ = "INFJ"
    INTJ = "INTJ"
    ISTP = "ISTP"
    ISFP = "ISFP"
    INFP = "INFP"
    INTP = "INTP"
    # Extraverted types.
    ENTP = "ENTP"
    ESTP = "ESTP"
    ESFP = "ESFP"
    ENFP = "ENFP"
    ESTJ = "ESTJ"
    ESFJ = "ESFJ"
    ENFJ = "ENFJ"
    ENTJ = "ENTJ"

# Three-level confidence rating; each member's value is the same string as
# its name. Used as the type of the `confidence` field of the `Personality`
# response schema.
class Confidence(enum.Enum):
    low = "low"
    medium = "medium"
    high = "high"

# Shape of the structured answer requested from the model: this class is
# passed as `response_schema` in `predict_mbti_txt`.
class Personality(TypedDict):
    # Predicted type, one of the MBTI members.
    personality: MBTI
    # How confident the prediction is: low, medium or high.
    confidence: Confidence
    # Free-text explanation; the prompt asks for the reasoning behind the
    # chosen type and the stated confidence.
    explanation: str

def noquotes(text):
    """
    Convert raw Reddit Markdown into cleaned plain text.

    This function first started out as a way to remove Markdown quotes from
    raw Reddit Markdown text, but it is now more of a general-purpose text
    parser; the name has not changed.

    Steps, in order:
      1. Render the Markdown to HTML.
      2. Join all text nodes of the parsed HTML into one string.
      3. Delete every span that starts at a literal ">" and runs through the
         end of that line (or the end of the text).
      4. Delete literal backslash-n sequences, then any remaining backslashes.

    Args:
        text: Markdown source, e.g. a comment body or a submission selftext.

    Returns:
        The cleaned plain-text string.
    """
    # https://stackoverflow.com/questions/761824/python-how-to-convert-markdown-formatted-text-to-text
    html = markdown.markdown(text)
    soup = bs4.BeautifulSoup(html, 'lxml')
    # find_all is the supported method name; findAll is a deprecated alias.
    plain = ''.join(soup.find_all(string=True))
    # NOTE(review): this pattern runs on the text extracted from the rendered
    # HTML, not on the raw Markdown, so it removes everything from any literal
    # ">" through the end of that line -- confirm it still strips quotes as
    # intended.
    stripped = re.sub(r">.+?(\n|$)", "", plain)
    # Drop literal "\n" two-character sequences first, then stray backslashes.
    return stripped.replace("\\n", "").replace("\\", "")



def predict_mbti_txt(txt):
    """
    Ask the Gemini model to predict an MBTI type from a user's collected text.

    Args:
        txt: Text to analyse; it is inserted verbatim into the prompt between
            the two "---" separator lines.

    Returns:
        The response object returned by ``model.generate_content``. The
        request asks for JSON output constrained to the ``Personality`` schema.
    """
    request_text = f"""
    The follow is a list of comments and posts by a particular Reddit user. Predict their MBTI (Myer-Briggs) type to the best of your ability. When making the prediction, disregard any type the user claims to be and make the judgement yourself.
    Please also state your confidence in the prediction and give an explanation for the type you chose and your confidence in it. You may also discuss other potential types in the explanation.

    Comments:
    ---
    {txt}
    ---
    """
    # Request a JSON reply that follows the Personality schema.
    json_config = genai.GenerationConfig(
        response_mime_type="application/json",
        response_schema=Personality,
    )
    return model.generate_content(request_text, generation_config=json_config)

def predict_user(username):
    """
    Predict a Reddit user's MBTI type and compute a controversiality score.

    Fetches the user's newest comments and submissions, cleans each text with
    ``noquotes``, joins them, and sends the result to ``predict_mbti_txt``.
    While walking the comments it also tallies an "up" and a "down" vote
    weight that feed the controversiality score.

    Args:
        username: Name of the Reddit account, as accepted by
            ``reddit.redditor``.

    Returns:
        A ``(prediction, controversiality)`` tuple.

        * ``prediction`` is the text of the first part of the first candidate
          in the model response (the request asks for JSON in the
          ``Personality`` schema).
        * ``controversiality`` is ``B / (A + B)`` where ``A`` and ``B`` are
          the absolute up and down tallies, so it lies in ``[0, 1]``. It is
          ``0.0`` when both tallies are zero (no comments, or no net votes).

        ``(None, None)`` is returned when Reddit reports the user as not
        found.
    """
    try:
        redditor = reddit.redditor(username)
        comms = list(redditor.comments.new(limit=None))
        subs = list(redditor.submissions.new(limit=None))
        text = []
        ups = 0
        downs = 0
        for comment in comms:
            text.append(noquotes(comment.body))
            # Presumably discounts the author's own automatic upvote --
            # TODO confirm.
            votes = comment.ups - 1
            if comment.controversiality == 1:
                # Controversial comment: spread the net score into an
                # estimated up part and down part around it.
                extent = abs(votes)*1.5 + 3
                ups += int(round(votes + extent))
                downs += int(round(votes - extent))
            else:
                if votes > 0:
                    ups += votes
                else:
                    downs += votes
        for sub in subs:
            newsub = noquotes(sub.selftext)
            # Skip submissions whose cleaned text is empty.
            if len(newsub) > 0:
                text.append(newsub)
        fulltext = "\n\n\n".join(text)
        A = abs(ups)
        B = abs(downs)
        total = A + B
        # Both tallies are zero for users with no comments or with no net
        # votes on any comment; avoid dividing by zero in that case.
        controversiality = B / total if total else 0.0
        return predict_mbti_txt(fulltext).to_dict()['candidates'][0]['content']['parts'][0]['text'], controversiality
    except NotFound:
        return None, None

In [None]:
# Example run: predict for a single Reddit account and pretty-print the
# (prediction, controversiality) tuple returned by predict_user.
result = predict_user("lexfridman")
pprint.pprint(result)