In [1]:
!pip install transformers --quiet
!pip install praw --quiet
!pip install torch torchvision --quiet

import os
from getpass import getpass
from itertools import chain

import praw
import torch
from transformers import BertTokenizerFast, BertTokenizer, BertModel

[K     |████████████████████████████████| 4.4 MB 5.3 MB/s 
[K     |████████████████████████████████| 596 kB 37.3 MB/s 
[K     |████████████████████████████████| 86 kB 3.3 MB/s 
[K     |████████████████████████████████| 6.6 MB 37.7 MB/s 
[K     |████████████████████████████████| 188 kB 5.5 MB/s 
[K     |████████████████████████████████| 54 kB 1.5 MB/s 
[?25h

In [2]:
from transformers import logging
# Only surface transformers log messages at ERROR level or above
# (hides the library's INFO/WARNING output in this notebook).
logging.set_verbosity_error()

In [40]:
class RedditUser:
    """Fetches the recent texts (comment bodies and self-post bodies) of one Reddit user."""

    def __init__(self, reddit: "praw.Reddit", user_name):
        """
        Args:
            reddit: authenticated PRAW client used to look the user up.
            user_name: Reddit user name (without the ``u/`` prefix).
        """
        self.user = reddit.redditor(user_name)

    def fetch_texts(self, depth):
        """Return an iterator over the user's non-blank texts.

        Yields up to ``depth`` newest comment bodies followed by up to ``depth``
        newest submission self-texts, so at most ``2 * depth`` texts in total.
        Blank texts are skipped: link posts have an empty self-text, and an
        empty string would otherwise be embedded as if it were real content.
        """
        return filter(self._has_content,
                      chain(self._fetch_comments(depth),
                            self._fetch_post(depth)))

    @staticmethod
    def _has_content(text):
        """True if ``text`` contains at least one non-whitespace character."""
        return bool(text and text.strip())

    def _fetch_comments(self, depth):
        """Lazily yield the bodies of the user's ``depth`` newest comments."""
        return (comment.body for comment in self.user.comments.new(limit=depth))

    def _fetch_post(self, depth):
        """Lazily yield the self-texts of the user's ``depth`` newest submissions."""
        return (submission.selftext
                for submission in self.user.submissions.new(limit=depth))


class RedditAuth:
  """Plain container for the Reddit API credentials handed to ``praw.Reddit``."""

  def __init__(self, client_id, client_secret, user_agent):
    # Stored verbatim; no validation happens here.
    self.client_id, self.client_secret, self.user_agent = (
        client_id, client_secret, user_agent)


class Reddit:
  """Authenticated PRAW client that fetches a user's recent texts."""

  def __init__(self, text_fetch_depth, reddit_auth):
    """
    Args:
      text_fetch_depth: per-listing limit forwarded to ``RedditUser.fetch_texts``.
      reddit_auth: object exposing ``client_id``, ``client_secret`` and ``user_agent``.
    """
    self.text_fetch_depth = text_fetch_depth
    credentials = {
      'client_id': reddit_auth.client_id,
      'client_secret': reddit_auth.client_secret,
      'user_agent': reddit_auth.user_agent,
    }
    # Disable PRAW's asynchronous-environment check, which otherwise emits
    # warnings when running inside the notebook.
    self.reddit = praw.Reddit(**credentials, check_for_async=False)

  def user_texts(self, user_name):
    """Iterator over ``user_name``'s recent comment and submission texts."""
    user = RedditUser(self.reddit, user_name)
    return user.fetch_texts(self.text_fetch_depth)

In [41]:
class TextToVecUsingHiddenState:
  """Embed a text as the mean of BERT's last-layer token vectors.

  Every token of the (truncated) input contributes equally to the mean,
  including any special tokens the tokenizer adds.
  """

  def __init__(self, max_length, model_name='bert-base-uncased') -> None:
    """
    Args:
      max_length: inputs are truncated to at most this many tokens.
      model_name: Hugging Face checkpoint used for both tokenizer and encoder.
    """
    self.max_length = max_length
    self.tokenizer = BertTokenizer.from_pretrained(model_name)
    self.bert_encoder = BertModel.from_pretrained(model_name)

  def _hidden_state(self, last_hidden_state):
    """Average ``last_hidden_state`` over its token axis.

    Expects a single-text batch shaped (1, seq_len, hidden) and returns a
    tensor of shape (hidden,).
    """
    # Drop only the batch axis: a bare squeeze() would also drop seq_len when
    # it is 1, and the mean would then collapse the hidden axis to a scalar.
    return last_hidden_state.squeeze(0).mean(dim=0)

  def __call__(self, text):
    """Return a 1-D tensor embedding ``text``."""
    tokens = self.tokenizer.encode_plus(text, return_tensors='pt',
                                        max_length=self.max_length,
                                        truncation=True)
    # Inference only: without no_grad every call records an autograd graph
    # that the returned vector would keep alive.
    with torch.no_grad():
      encoded = self.bert_encoder(**tokens)
      return self._hidden_state(encoded.last_hidden_state)

In [42]:
class TextToVecUsingPoolerOutput:
  def __init__(self, max_length) -> None:
      self.max_length = max_length
      self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
      self.feature_extractor = BertModel.from_pretrained("bert-base-uncased")

  def __call__(self, text):
      token = self.tokenizer.encode_plus(text, return_tensors='pt', 
                             max_length=self.max_length, 
                             truncation=True)
      result = self.feature_extractor(**token)
      return result.pooler_output.flatten()

In [43]:
class Application:
  """Interactive tool scoring how similar two Reddit users' writing is.

  Each user is represented by the mean of the vectors ``text_to_vec`` produces
  for their recent texts; the score is the cosine similarity of the two
  user vectors.
  """

  def __init__(self,
               text_to_vec,
               text_fetch_depth,
               reddit_auth):
    """
    Args:
      text_to_vec: callable mapping one text (str) to a 1-D torch tensor.
      text_fetch_depth: per-listing fetch limit passed on to ``Reddit``.
      reddit_auth: ``RedditAuth`` holding the API credentials.
    """
    self.reddit = Reddit(text_fetch_depth, reddit_auth)
    self.text_to_vec = text_to_vec

  def _mean(self, list_of_text_vec):
    """Element-wise mean of a non-empty list of equally shaped tensors."""
    return torch.stack(list_of_text_vec, dim=0).mean(dim=0)

  def _user_vector(self, user_name):
    """Mean text vector for ``user_name``.

    Raises:
      ValueError: if no texts were fetched for the user (``torch.stack``
        cannot stack an empty list).
    """
    # Inference only: skip autograd bookkeeping for the embedding calls.
    with torch.no_grad():
      text_vecs = [self.text_to_vec(text)
                   for text in self.reddit.user_texts(user_name)]
    if not text_vecs:
      raise ValueError(f'No texts found for user {user_name!r}')
    return self._mean(text_vecs)

  def _calculate_similarity_score(self, user_name_1, user_name_2):
    """Cosine similarity (0-dim tensor) between the two users' mean text vectors."""
    vec_1 = self._user_vector(user_name_1)
    vec_2 = self._user_vector(user_name_2)
    # cosine_similarity guards the norms with a small eps, so a zero vector
    # gives 0 rather than NaN from a 0/0 division.
    return torch.nn.functional.cosine_similarity(vec_1, vec_2, dim=0)

  @staticmethod
  def main_loop(text_to_vec, text_fetch_depth, reddit_auth):
    """Prompt for pairs of user names and print their similarity.

    Entering a blank user name ends the loop.
    """
    app = Application(text_to_vec, text_fetch_depth, reddit_auth)

    while True:
      user_name_1 = input('Please enter first user name (blank to quit): ').strip()
      if not user_name_1:
        break
      user_name_2 = input('Please enter second user name (blank to quit): ').strip()
      if not user_name_2:
        break
      try:
        score = app._calculate_similarity_score(user_name_1, user_name_2)
      except Exception as err:
        # Report the failure (e.g. a PRAW API error for an unknown user, or a
        # user with no texts) and keep the interactive session alive.
        print(f'Could not compare {user_name_1} and {user_name_2}: {err!r}')
        continue
      print(f'The similarity score between {user_name_1} and {user_name_2} '
            f'is {score.item():.4f}')

In [None]:
# Credentials come from the environment (or an interactive prompt) instead of
# being hard-coded: anything written into a cell is saved in the .ipynb file.
# If a client secret was ever stored in this notebook, revoke it at
# https://www.reddit.com/prefs/apps and create a new one.
reddit_auth = RedditAuth(
    client_id=os.environ.get("REDDIT_CLIENT_ID") or input("Reddit client id: "),
    client_secret=os.environ.get("REDDIT_CLIENT_SECRET") or getpass("Reddit client secret: "),
    user_agent=os.environ.get(
        "REDDIT_USER_AGENT",
        "android:com.example.test:v1.2.3 (by u/kemitcheProfessionalInside45)"))

text_fetch_depth = 35        # newest comments AND newest submissions fetched per user (each)
max_tokenizer_length = 300   # texts are truncated to this many tokens (BERT's limit is 512)
text_to_vec = TextToVecUsingHiddenState(max_tokenizer_length)  # alternative: TextToVecUsingPoolerOutput

Application.main_loop(text_to_vec, text_fetch_depth, reddit_auth)

In [None]:
# Example user names to enter at the `Application.main_loop` prompts.
example_user_names = [
    'xtilexx',
    'Sir_Loinbeef',
    'Repulsive_Love_',
    'ForecastForFourCats',
]
example_user_names