# About

Script to clean data and create the subreddit groupings. Each subreddit grouping will be used to train a separate fine tuned model. Final generated summaries will be done using their respective subreddit group fine tuned model. 

In [31]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score nltk
!pip install pyarrow
# !pip install -q sentencepiece
# !pip install rouge-score # google package version

clear_output()

In [32]:
import os
import re
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

clear_output()

# Load Data

In [33]:
%%time
from google.colab import drive
drive.mount('/content/gdrive')
# data_path = "/content/gdrive/MyDrive/Classes/W266_NLP/w266_reddit_summarization/data/reddit_parquet/"
data_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization/data/reddit_parquet/'

os.chdir(data_path)
files = [i for i in os.listdir(data_path) if re.search("reddit_data", i)]

# df = pd.read_parquet(files[:4])
# df = pd.read_parquet(files[0])
df = pd.read_parquet(files)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
CPU times: user 39.2 s, sys: 42.5 s, total: 1min 21s
Wall time: 1min 2s


# Filter out bad data

- Content will have >= 5 distinct words. And summaries will have >= 2 distinct words. 
- Content will have between 20-1000 total words. Summaries will have between 2-100 total words. 
- See EDA_3 script for distributions that led to this decision.

In [34]:
%%time

def count_vocab(item):
  return len(set(item.lower().split(" ")))

def count_total_words(x):
  return len(x.lower().split(" "))

df['content_vocab'] = df['content'].map(count_vocab)
df['summary_vocab'] = df['summary'].map(count_vocab)
df['content_total_words'] = df['content'].map(count_total_words)
df['summary_total_words'] = df['summary'].map(count_total_words)

print("Starting number of rows: {}".format(df.shape[0]))

# some posts just have 1 word, or repeat the one word over and over and is not a real post. 
# only keep contents with at least 5 distinct words

df = df[df['content_vocab'] >= 10]
print("Rows after filtering content to have >= 10 distinct words: {}".format(df.shape[0]))

# do the same for summaries, want at least 2 distinct words
# df = df[df['summary'].map(lambda x: count_vocab(x) >= 2)]
df = df[df['summary_vocab'] >= 2]
print("Rows after filtering content to have >= 2 distinct words: {}".format(df.shape[0]))

# after observing distributions of word counts in EDA, deciding to keep...
# content with 20-1000 TOTAL words
# summaries with 2 - 100 TOTAL words
df = df[(df['content_total_words'] >= 20) & (df['content_total_words'] < 1001)]
print("Rows after filtering content to 20-1000 total words: {}".format(df.shape[0]))

df = df[(df['summary_total_words'] >= 2) & (df['summary_total_words'] < 101)]
print("Rows after filtering summaries to 2-100 total words: {}".format(df.shape[0]))

# remove the columns
df = df[['content', 'summary', 'subreddit']]

Starting number of rows: 3848330
Rows after filtering content to have >= 10 distinct words: 3818715
Rows after filtering content to have >= 2 distinct words: 3777540
Rows after filtering content to 20-1000 total words: 3642077
Rows after filtering summaries to 2-100 total words: 3550127
CPU times: user 4min 20s, sys: 1.7 s, total: 4min 21s
Wall time: 4min 21s


# Create Groupings

In [35]:
%%time
def group_subreddit(subreddit):

  if subreddit in ['buildapc', 'LifeProTips', 'lifehacks', 'IAmA', 'DoesAnybodyElse', 'answers', 
                   'DIY', 'howtonotgiveafuck', 'Parenting'] or \
      re.search('advice|ask|explain|question|socialskills', subreddit.lower()) or \
      re.search('okcupid|relationship|sex|seduction', subreddit.lower()) or \
      re.search('finance|frugal|investing|crypto|bitcoin|entrepreneur|business|occupywallstreet', subreddit.lower()):
    x = 'advice/story' # advice

  elif subreddit in ['leagueoflegends', 'DotA2', 'starcraft', 'magicTCG', 
                     'Guildwars2', 'DestinyTheGame', 'pcmasterrace', 'gamecollecting',
                     'Planetside', 'rpg', 'pokemon', 'smashbros', 'swtor', 
                     'runescape', 'battlefield3', 'DarkSouls2', 'LeagueofLegendsMeta', 
                     'WorldofTanks', 'darksouls', 'gamedev', 'Minecraft', 'Diablo', 
                     'DnD', 'skyrim', 'halo', 'PS4', 'xboxone', 'battlefield_4', 
                     'ShouldIbuythisgame', 'Pathfinder_RPG', 'elderscrollsonline', 
                     'Fallout', 'GrandTheftAutoV', 'assassinscreed', 'summonerschool', 
                     'GlobalOffensive', '3DS', 'ffxiv', 'tf2', 'MonsterStrike', 
                     'GuildWars2Builds', 'hearthstone', 'Warframe', 'MonsterHunter', 'wow', 
                     'Smite', 'dayz', 'Eve', 'Warthunder', 'GameDeals', '2007scape', 'pathofexile', 
                     'masseffect', 'starcitizen', 'oculus', 'wiiu', 'Steam', 'bravefrontier', 
                     'diablo3', 'gamegrumps', 'totalwar', 'Borderlands', 'CoDCompetitive', 
                     'Civcraft', 'blackops2', 'gaymers', 'KotakuInAction', 'PS3', 'Warhammer', 
                     'zelda', 'truetf2', 'dwarffortress', 'mw3', 'nintendo', 'learndota2', 
                     'HeroesofNewerth', 'feedthebeast', 'h1z1', 'archeage', 'ClashOfClans', 
                     'CodAW', 'dragonage', 'PuzzleAndDragons', 'RotMG', 'SimCity', 'Bioshock'
                     ] or re.search('gaming|games|gamer|pokemon', subreddit.lower()):
    x = 'gaming'

  elif subreddit in ['tifu', 'TwoXChromosomes', 'offmychest', 'todayilearned', 
                     'fffffffuuuuuuuuuuuu', 'JusticePorn', 'cringe', 'pettyrevenge',
                     'confession', 'WTF', 'aww', 'SubredditDrama', 'facepalm'
                     ] or \
                     re.search('tales', subreddit.lower()):
    x = 'advice/story' #'story'

  elif subreddit in ['changemyview', 'MensRights'] or \
      re.search('news|politic|libertarian|anarchism|democrat|republican|conservative|liberal|socialism', subreddit.lower()) or \
      re.search('atheis|religion|christian|islam|mormon|catholicism|judaism', subreddit.lower()):
    x = 'media/lifestyle/sports' #'news/religion'

  elif re.search('sport|baseball|soccer|golf|football|basketball|hockey|nfl|nba|mlb', subreddit.lower()) or \
      subreddit in ['MMA', 'SquaredCircle', 'MLS', 'MTB', 'FIFA', 'LiverpoolFC', 'chelseafc', 'NASCAR', 'formula1', 
                    'longboarding', 'snowboarding', 'skiing', 'climbing', 
                    'photography'] or \
      subreddit in ['loseit', 'GetMotivated', 'bodybuilding', 'martialarts', 'Health'] or \
      re.search('cycling|running|motorcycles|cfb|airsoft|paintball|fitness', subreddit.lower()):
    x = 'media/lifestyle/sports' # 'sports/hobbies/fitness'
    
  elif re.search('pics|videos|funny|comedy|jokes|movies|television|music|books|anime|gifs', subreddit.lower()) or \
      subreddit in ['gameofthrones', 'asoiaf', 'doctorwho', 'thewalkingdead', 'TheLastAirbender', 
                    'Naruto', 'harrypotter', 'yugioh', 'startrek', 'StarWars', 'breakingbad'
                    'mylittlepony', 'hiphopheads', 'comics', 'vinyl', 'community', 'kpop'
                    ]:
    x = 'media/lifestyle/sports' # 'media'

  else:
    x = 'other'
  
  return x


df['subreddit_group'] = df['subreddit'].map(group_subreddit)
subreddit_group_counts = df['subreddit_group'].value_counts().to_frame().reset_index().rename(columns={'index': 'subreddit_group', 'subreddit_group': 'N'})
print(subreddit_group_counts)

          subreddit_group        N
0            advice/story  1409813
1                   other  1184045
2                  gaming   501041
3  media/lifestyle/sports   455228
CPU times: user 30.5 s, sys: 4.15 ms, total: 30.5 s
Wall time: 30.5 s


In [26]:
# visualize leftover subreddits

from IPython.display import display, HTML
all_counts = df[['subreddit', 'subreddit_group']].value_counts().to_frame().reset_index().rename(columns={0: 'N'})
# filter out ones we've already grouped
all_counts = all_counts[all_counts['subreddit_group']=='other']

# Puts the scrollbar next to the DataFrame
display(HTML(
    "<div style='height: 400px; overflow: auto; width: fit-content'>" 
    + all_counts.to_html() 
    + "</div>"
    ))

Output hidden; open in https://colab.research.google.com to view.

In [39]:
# inspect individual ones
findme = 'reddit.com'
ind = 1
print("content:")
pprint(df[df['subreddit'] == findme]['content'].iloc[ind])
print("\nsummary:")
df[df['subreddit'] == findme]['summary'].iloc[ind]

content:
('Zell Miller never joined the Republican party. Are you implying that anyone '
 "who has ever agreed with a Republican must also be a Republican? I'm sorry "
 'to disappoint you, then, but even President Obama qualifies as a '
 'Republican. \n'
 ' The bold text was used to emphasize key phrases for the selective (')

summary:


"readers of Reddit. If I was trying to present those phrases as the only historical events from the referenced page worth discussing, I wouldn't have attempted to include their historical context."

# Train/Test/Valid

In [40]:
# sample to create train/test/valid splits
num_rows_train = 50000
num_rows_test = 5000
num_rows_valid = 5000
ratio1 = (num_rows_train + num_rows_test + num_rows_valid) / df.shape[0]

from sklearn.model_selection import train_test_split
throw_away, data = train_test_split(df, test_size=ratio1, random_state=1)
print(data.shape[0])

# get train
train, test = train_test_split(data, test_size = (num_rows_test+num_rows_valid) / data.shape[0], random_state=1)

# get test/valid
test, valid = train_test_split(test, test_size=.5, random_state=1)

# print shapes
print("train size: {}".format(train.shape[0]))
print("test size: {}".format(test.shape[0]))
print("valid size: {}".format(valid.shape[0]))

60000
train size: 50000
test size: 5000
valid size: 5000


In [41]:
# check distribution
print("Train dist:")
x = train['subreddit_group'].value_counts().to_frame().reset_index().rename(columns={'index': 'subreddit_group', 'subreddit_group': 'N'})
print(x)

print("\nTest dist:")
x = test['subreddit_group'].value_counts().to_frame().reset_index().rename(columns={'index': 'subreddit_group', 'subreddit_group': 'N'})
print(x)

print("\nValid dist:")
x = valid['subreddit_group'].value_counts().to_frame().reset_index().rename(columns={'index': 'subreddit_group', 'subreddit_group': 'N'})
print(x)

Train dist:
          subreddit_group      N
0            advice/story  19938
1                   other  16743
2                  gaming   6874
3  media/lifestyle/sports   6445

Test dist:
          subreddit_group     N
0            advice/story  1946
1                   other  1730
2                  gaming   680
3  media/lifestyle/sports   644

Valid dist:
          subreddit_group     N
0            advice/story  1952
1                   other  1678
2                  gaming   709
3  media/lifestyle/sports   661


# Write data to disk

In [45]:
train

Unnamed: 0,content,summary,subreddit,subreddit_group
779138,My dad got sick of the neighbors dog and went ...,neighbor married a piece of shit wife and she ...,AskReddit,advice/story
405453,"Is this the case? No, not entirely. First of a...",Most Christians in the United States don't bel...,TrueAtheism,media/lifestyle/sports
2422458,So after listening to [this]( nonstop I've dec...,"Could anyone with experience in choppy, call-a...",edmproduction,other
1797466,Good works do not get you into heaven without ...,You're absolutely correct as far as my studies...,worldnews,media/lifestyle/sports
274535,Ugh. My ex fiance proposed with a replica of t...,don't get your girl the One Ring unless you kn...,AdviceAnimals,advice/story
...,...,...,...,...
817999,"IS IT TOO LATE?! \n Ok, so when I was in 7th g...",My little brother tripped and lost our copy of...,Playdate,other
1824426,But not a country that has nothing much else t...,Brazilians don't care as much about their nati...,soccer,media/lifestyle/sports
1544704,Sex & Relationships. \n I was taught in high s...,dont raise your children in conservative areas...,AskReddit,advice/story
3352037,I have seen a ton of threads lately where peop...,If you are the victim of theft call the police...,RBI,other


In [47]:
%%time
write_path = data_path

# create dir if not exists
write_path2 = os.path.join(write_path, 'train_test_split')

if not os.path.exists(write_path2):
  os.makedirs(write_path2)

# write train/test/valid separately
# make sure you install pyarrow
train.to_parquet(os.path.join(write_path2, 'reddit_train.parquet'))
test.to_parquet(os.path.join(write_path2, 'reddit_test.parquet'))
valid.to_parquet(os.path.join(write_path2, 'reddit_validation.parquet'))

CPU times: user 637 ms, sys: 409 ms, total: 1.05 s
Wall time: 1.24 s


In [48]:
!ls /content/gdrive/MyDrive/w266/w266_reddit_summarization/data/reddit_parquet/train_test_split

reddit_test.parquet  reddit_train.parquet  reddit_validation.parquet
