# Subreddit Guesser

Can I train a program to identify which subreddit a post came from using only the post title?

The top 10 most popular subreddits are 'announcements', 'funny', 'AskReddit', 'gaming', 'awww', 'Music', 'pics', 'science', 'worldnews', 'videos'. So I'll use those.


In [9]:
import numpy as np
import praw
import pandas as pd
import xgboost as xgb
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

In [10]:
import secrets

# Create a Reddit instance
user_agent = "Test 0.1 by /u/IsThisATrollBot"
reddit = praw.Reddit(
    client_id=secrets.client_ID,
    client_secret=secrets.client_secret,
    password=secrets.password,
    user_agent=user_agent,
    username=secrets.username,
)

In [11]:
# Get the 10 most popular subreddits
top_subreddits = ['announcements', 'funny', 'AskReddit', 'gaming', 'awww', 'Music', 'pics', 'science', 'worldnews', 'videos']

Create a list of 1000 posts from the top 10 most popular subreddits.

In [12]:
# Create an empty list to store the posts
posts = []

# Iterate through the subreddits and get the last 1000 posts from each
for sub in top_subreddits:
    subreddit_posts = reddit.subreddit(sub).new(limit=1000)
    for post in subreddit_posts:
        posts.append(post)


In [71]:

# Create a list of dictionaries containing the data for each post
data = [{'title': post.title, 'subreddit': post.subreddit.display_name} for post in posts]

# Create a Pandas dataframe from the list of dictionaries
df = pd.DataFrame(data)
df

Unnamed: 0,title,subreddit
0,This subreddit is closed for new posts and com...,announcements
1,COVID denialism and policy clarifications,announcements
2,"Debate, dissent, and protest on Reddit",announcements
3,Sunsetting Secret Santa and Reddit Gifts,announcements
4,Second,announcements
...,...,...
8353,Mukbang Has Gone To Far,videos
8354,GWAR goes door to door among the suburbs of Ri...,videos
8355,Justin Roiland & Mikey Spano join Oddheader to...,videos
8356,Testing toilet flushing efficiency,videos


### Split the data into test and training data

In [72]:
from sklearn.model_selection import train_test_split

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(df['title'], df['subreddit'], test_size=0.2)

# Get the indices of the test set samples
test_indices = X_test.index

### Try different vectorization models

Bag-of-Words and Doc2Vec

In [90]:
# Try Bag-of-Words first
# This just measures the frequency of each word in each title

from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(X_train)
X_test = vectorizer.transform(X_test)

AttributeError: lower not found

In [89]:
X_train_dense = X_train.toarray()
X_train_dense

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int64)

In [88]:
type(X_test)

scipy.sparse._csr.csr_matrix

In [91]:

# Vectorize the text using bag-of-words
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(X_train)
X_test = vectorizer.transform(X_test)

# Train a support vector machine (SVM) classifier
model = SVC()
model.fit(X_train, y_train)

# Evaluate the model's performance on the testing data
accuracy = model.score(X_test, y_test)
print("Accuracy:", accuracy)

AttributeError: lower not found

In [69]:
# Select the titles for the test set samples
test_titles = df['title'][test_indices]

# Create a dataframe to hold the results
same = y_test == y_pred
df_results = pd.DataFrame({'title': test_titles, 'same':same, 'actual_subreddit': y_test, 'predicted_subreddit': y_pred})


# Display the results
df_results

Unnamed: 0,title,same,actual_subreddit,predicted_subreddit
7762,A Special Rendition of O Holy Night,False,videos,AskReddit
8002,Jack Black sings Marvin Gaye in HIGH FIDELITY,False,videos,worldnews
4993,Sun halo,False,pics,AskReddit
540,"Shine It On, Rudolph!",False,funny,science
2205,What’s a nostalgic game you love that you genu...,False,gaming,Music
...,...,...,...,...
3536,"[OC] I have small hands, but this bunny is sma...",False,Awww,funny
3195,Pretty much me in a few hours...maybe a touch ...,True,Awww,Awww
2092,An article about how one can use GPT-3 and Sta...,False,gaming,Music
5455,Plague Doctor,False,pics,gaming


In [57]:
# Make predictions on the test set
y_pred = model.predict(X_test)

# Create a dataframe to hold the results
df_results = pd.DataFrame({'title': X_test, 'actual_subreddit': y_test, 'predicted_subreddit': y_pred})

# Display the results
df_results

Unnamed: 0,title,actual_subreddit,predicted_subreddit
1203,"(0, 5981)\t1\n (0, 6396)\t1\n (0, 10184)\t...",AskReddit,AskReddit
7186,"(0, 993)\t1\n (0, 2879)\t1\n (0, 3742)\t1\...",worldnews,worldnews
110,"(0, 368)\t1\n (0, 655)\t1\n (0, 2202)\t1\n...",announcements,AskReddit
5981,"(0, 569)\t1\n (0, 2152)\t1\n (0, 3402)\t1\...",science,science
4045,"(0, 773)\t1\n (0, 4004)\t1\n (0, 6240)\t1",Music,Music
...,...,...,...
2709,"(0, 773)\t1\n (0, 941)\t1\n (0, 5244)\t1\n...",gaming,funny
3420,"(0, 773)\t1\n (0, 1202)\t2\n (0, 3206)\t1\...",Awww,Awww
3936,"(0, 4424)\t1\n (0, 5708)\t1\n (0, 9761)\t1...",Music,Music
2275,"(0, 655)\t1\n (0, 773)\t2\n (0, 2657)\t1\n...",gaming,gaming


In [14]:
#Check to see if there are any null entries
df.isnull().sum()

title        0
subreddit    0
dtype: int64

### Train a doc2vec model on the post titles

In [92]:

# Get the post titles from the dataframe
titles = df['title'].tolist()

# Create a list of TaggedDocument objects from the titles
documents = [TaggedDocument(words=title.split(), tags=[str(i)]) for i, title in enumerate(titles)]

In [16]:

# Create a doc2vec model
model = Doc2Vec(documents, vector_size=100, window=2, min_count=1, workers=4)

# Train the model
model.train(documents, total_examples=model.corpus_count, epochs=100)

In [93]:
titles

['This subreddit is closed for new posts and comments. For future updates, announcements, and news related to Reddit Inc. and the platform, please visit r/reddit.',
 'COVID denialism and policy clarifications',
 'Debate, dissent, and protest on Reddit',
 'Sunsetting Secret Santa and Reddit Gifts',
 'Second',
 'An update on the recent issues surrounding a Reddit employee',
 'Today we’re testing a new way to discuss political ads (and announcements)',
 'Now you can make posts with multiple images.',
 'Update to Our Content Policy',
 'Upcoming changes to our content policy, our board, and where we’re going from here',
 'Changes to Reddit’s Political Ads Policy',
 'Introducing the Solidarity Award — A 100% contribution to the COVID-19 Solidarity Response Fund for WHO',
 'Imposter',
 'Introducing Reddit Polls, An All-New Post Type',
 'Announcing our partnership and AMA with Crisis Text Line',
 'Spring forward… into Reddit’s 2019 transparency report',
 'Suspected Campaign from Russia on Redd

In [17]:

# Get the vectorized titles from the doc2vec model
vectors = [model.infer_vector(title.split()) for title in titles]

# Add the vectors to the dataframe as a new column
df['vectors'] = vectors

In [18]:
df

Unnamed: 0,title,subreddit,vectors
0,This subreddit is closed for new posts and com...,announcements,"[-1.1355639, -0.772142, 0.33601853, 0.30085954..."
1,COVID denialism and policy clarifications,announcements,"[-0.22140661, -0.042934902, 0.12624909, 0.0191..."
2,"Debate, dissent, and protest on Reddit",announcements,"[-0.23785444, 0.11020816, -0.059264645, -0.149..."
3,Sunsetting Secret Santa and Reddit Gifts,announcements,"[-0.26244256, -0.045981333, 0.28132874, -0.058..."
4,Second,announcements,"[-0.034034044, 0.27952778, -0.053140096, 0.280..."
...,...,...,...
8353,Mukbang Has Gone To Far,videos,"[0.13389727, 0.25366965, 0.034213733, -0.20066..."
8354,GWAR goes door to door among the suburbs of Ri...,videos,"[0.16855897, -0.30004513, -0.33736327, 1.26674..."
8355,Justin Roiland & Mikey Spano join Oddheader to...,videos,"[-0.37707096, -0.0922984, -0.3191751, 0.096676..."
8356,Testing toilet flushing efficiency,videos,"[0.012790001, -0.18439461, 0.113403104, -0.302..."


### Train an XGBoost model to predict which subreddit each post belongs to

In [19]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(df['vectors'].tolist(), df['subreddit'].tolist(), test_size=0.2, random_state=42)


In [20]:
# X_train = pd.DataFrame(X_train)
# y_train = pd.DataFrame(y_train)

In [21]:
# X_train.info()

In [22]:
from sklearn.preprocessing import LabelEncoder

# Convert the labels to numerical values
le = LabelEncoder()
y_train = np.array(y_train)
y_train = le.fit_transform(y_train)
y_test = le.transform(y_test)


In [26]:
# Convert the data to DMatrix objects
dtrain = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
dtest = xgb.DMatrix(X_test, label=y_test, enable_categorical=True)


In [27]:
# Set the XGBoost parameters
param = {'max_depth': 3, 'eta': 0.1, 'objective': 'multi:softmax', 'num_class': 10}

# Train the XGBoost model
bst = xgb.train(param, dtrain, num_boost_round=10)

In [46]:

# Make predictions on the test set
predictions = bst.predict(dtest)

# Print the accuracy
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy:.2f}')


Accuracy: 0.43


In [47]:
predictions = np.array(list(map(int, predictions)))
predictions

array([7, 0, 7, ..., 0, 5, 9])

array([7., 0., 7., ..., 0., 5., 9.], dtype=float32)

In [52]:
# Create a list of the predicted subreddit names
predicted_subreddits = le.inverse_transform(predictions)
actual_subreddits = le.inverse_transform(y_test)

# Create a dataframe with the post title, actual subreddit, and predicted subreddit
df = pd.DataFrame({'title': X_test, 'actual': actual_subreddits, 'predicted': predicted_subreddits})
df

Unnamed: 0,title,actual,predicted
0,"[-0.6965856, -0.36041075, -0.016394246, 0.2248...",science,science
1,"[-0.19899519, -0.057780933, 0.101786934, 0.147...",gaming,AskReddit
2,"[-0.1301142, -0.3742794, 0.23896079, 0.0127759...",science,science
3,"[-0.010454829, 0.12601396, -0.04458673, -0.089...",pics,Awww
4,"[0.12843373, 0.27396354, 0.14378692, 0.1771490...",worldnews,Music
...,...,...,...
1667,"[0.041328456, -0.01589865, 0.14951625, -0.0083...",funny,Awww
1668,"[-0.06795082, -0.063612305, 0.3198764, 0.38556...",videos,gaming
1669,"[-0.25345698, -0.010241983, 0.68700796, 0.0033...",funny,AskReddit
1670,"[0.04323784, -0.24545237, 0.18534826, -0.30235...",AskReddit,gaming


In [49]:
df

Unnamed: 0,title,actual,predicted
0,"[-0.6965856, -0.36041075, -0.016394246, 0.2248...",7,science
1,"[-0.19899519, -0.057780933, 0.101786934, 0.147...",5,AskReddit
2,"[-0.1301142, -0.3742794, 0.23896079, 0.0127759...",7,science
3,"[-0.010454829, 0.12601396, -0.04458673, -0.089...",6,Awww
4,"[0.12843373, 0.27396354, 0.14378692, 0.1771490...",9,Music
...,...,...,...
1667,"[0.041328456, -0.01589865, 0.14951625, -0.0083...",4,Awww
1668,"[-0.06795082, -0.063612305, 0.3198764, 0.38556...",8,gaming
1669,"[-0.25345698, -0.010241983, 0.68700796, 0.0033...",4,AskReddit
1670,"[0.04323784, -0.24545237, 0.18534826, -0.30235...",0,gaming


In [64]:
for label in y_train[:10]:
    if not label.isascii():
        print(f'Non-ASCII characters found in label: {label}')


Awww
funny
Awww
pics
funny
pics
funny
AskReddit
Music
funny


In [65]:
print(X_train[:5])
print(y_train[:5])


[[-0.03523372 -0.11046028 -0.1423582   0.4160329  -0.29461384 -0.4887824
   0.69866467  1.0167252  -0.5445614  -0.5502897  -0.15627718 -0.55022484
   0.41585827 -0.30390897 -0.27205878  0.41279483  0.56562847 -0.00461209
  -0.6877157  -0.6729945   0.5200319  -0.35202685  0.13109298 -0.24230972
   0.06181084 -0.2828839   0.13002984 -0.29786134 -0.09322817 -0.16211532
   0.7680654   0.3817985   0.01414419 -0.03392706 -0.19665286  0.0705587
   0.24197136 -0.10805511 -0.27039778 -0.13752991 -0.06370736 -0.27373213
   0.09091808 -0.30996516  0.4304253  -0.5640691  -0.14220345  0.24700731
   0.15278882  0.00705155  0.289233   -0.28843135  0.1336196  -0.5520374
  -0.6056017   0.16359329  0.45646515  0.14271544 -0.0997751  -0.27343363
  -0.08578045  0.12324178  0.91274434 -0.17022388 -0.37442383  0.18313143
   0.21369663 -0.12813434 -0.49952796  0.25478166  0.03890992 -0.07900175
   0.33613116 -0.12469764 -0.2777261   0.2407419   0.49526253  0.26033592
   0.27007443  0.3625942   0.1401691  -0.