# Data Clean

Script to clean data and create the subreddit groupings. Each subreddit grouping will be used to train a separate fine tuned model. Final generated summaries will be done using their respective subreddit group fine tuned model. 

In [1]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score nltk
!pip install pyarrow
# !pip install -q sentencepiece
# !pip install rouge-score # google package version

clear_output()

In [2]:
import os
import re
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

clear_output()

# Load Data

In [5]:
%%time
from google.colab import drive
drive.mount('/content/gdrive')
# data_path = "/content/gdrive/MyDrive/Classes/W266_NLP/w266_reddit_summarization/data/reddit_parquet/"
data_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization/data/reddit_parquet/'

os.chdir(data_path)
files = [i for i in os.listdir(data_path) if re.search("reddit_data", i)]

# df = pd.read_parquet(files[:4])
# df = pd.read_parquet(files[0])
df = pd.read_parquet(files)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
CPU times: user 26.8 s, sys: 29.5 s, total: 56.3 s
Wall time: 39.9 s


# Filter out bad data

- Content will have >= 5 distinct words. And summaries will have >= 2 distinct words. 
- Content will have between 20-1000 total words. Summaries will have between 2-100 total words. 
- See EDA script for distributions that led to this decision.

In [6]:
%%time

def count_vocab(item):
  return len(set(item.lower().split(" ")))

def count_total_words(x):
  return len(x.lower().split(" "))

df['content_vocab'] = df['content'].map(count_vocab)
df['summary_vocab'] = df['summary'].map(count_vocab)
df['content_total_words'] = df['content'].map(count_total_words)
df['summary_total_words'] = df['summary'].map(count_total_words)

print("Starting number of rows: {}".format(df.shape[0]))

# some posts just have 1 word, or repeat the one word over and over and is not a real post. 
# only keep contents with at least 5 distinct words

df = df[df['content_vocab'] >= 10]
print("Rows after filtering content to have >= 10 distinct words: {}".format(df.shape[0]))

# do the same for summaries, want at least 2 distinct words
# df = df[df['summary'].map(lambda x: count_vocab(x) >= 2)]
df = df[df['summary_vocab'] >= 2]
print("Rows after filtering content to have >= 2 distinct words: {}".format(df.shape[0]))

# after observing distributions of word counts in EDA, deciding to keep...
# content with 20-1000 TOTAL words
# summaries with 2 - 100 TOTAL words
df = df[(df['content_total_words'] >= 20) & (df['content_total_words'] < 1001)]
print("Rows after filtering content to 20-1000 total words: {}".format(df.shape[0]))

df = df[(df['summary_total_words'] >= 2) & (df['summary_total_words'] < 101)]
print("Rows after filtering summaries to 2-100 total words: {}".format(df.shape[0]))

# remove the columns
df = df[['content', 'summary', 'subreddit']]

Starting number of rows: 3848330
Rows after filtering content to have >= 10 distinct words: 3818715
Rows after filtering content to have >= 2 distinct words: 3777540
Rows after filtering content to 20-1000 total words: 3642077
Rows after filtering summaries to 2-100 total words: 3550127
CPU times: user 4min 19s, sys: 1.68 s, total: 4min 21s
Wall time: 4min 21s


# Create Groupings

In [7]:
%%time
def group_subreddit(subreddit):

  if subreddit in ['buildapc', 'LifeProTips', 'lifehacks', 'IAmA', 'DoesAnybodyElse', 'answers', 
                   'DIY', 'howtonotgiveafuck', 'Parenting'] or \
      re.search('advice|ask|explain|question|socialskills', subreddit.lower()) or \
      re.search('okcupid|relationship|sex|seduction', subreddit.lower()) or \
      re.search('finance|frugal|investing|crypto|bitcoin|entrepreneur|business|occupywallstreet', subreddit.lower()):
    x = 'advice/story' # advice

  elif subreddit in ['leagueoflegends', 'DotA2', 'starcraft', 'magicTCG', 
                     'Guildwars2', 'DestinyTheGame', 'pcmasterrace', 'gamecollecting',
                     'Planetside', 'rpg', 'pokemon', 'smashbros', 'swtor', 
                     'runescape', 'battlefield3', 'DarkSouls2', 'LeagueofLegendsMeta', 
                     'WorldofTanks', 'darksouls', 'gamedev', 'Minecraft', 'Diablo', 
                     'DnD', 'skyrim', 'halo', 'PS4', 'xboxone', 'battlefield_4', 
                     'ShouldIbuythisgame', 'Pathfinder_RPG', 'elderscrollsonline', 
                     'Fallout', 'GrandTheftAutoV', 'assassinscreed', 'summonerschool', 
                     'GlobalOffensive', '3DS', 'ffxiv', 'tf2', 'MonsterStrike', 
                     'GuildWars2Builds', 'hearthstone', 'Warframe', 'MonsterHunter', 'wow', 
                     'Smite', 'dayz', 'Eve', 'Warthunder', 'GameDeals', '2007scape', 'pathofexile', 
                     'masseffect', 'starcitizen', 'oculus', 'wiiu', 'Steam', 'bravefrontier', 
                     'diablo3', 'gamegrumps', 'totalwar', 'Borderlands', 'CoDCompetitive', 
                     'Civcraft', 'blackops2', 'gaymers', 'KotakuInAction', 'PS3', 'Warhammer', 
                     'zelda', 'truetf2', 'dwarffortress', 'mw3', 'nintendo', 'learndota2', 
                     'HeroesofNewerth', 'feedthebeast', 'h1z1', 'archeage', 'ClashOfClans', 
                     'CodAW', 'dragonage', 'PuzzleAndDragons', 'RotMG', 'SimCity', 'Bioshock'
                     ] or re.search('gaming|games|gamer|pokemon', subreddit.lower()):
    x = 'gaming'

  elif subreddit in ['tifu', 'TwoXChromosomes', 'offmychest', 'todayilearned', 
                     'fffffffuuuuuuuuuuuu', 'JusticePorn', 'cringe', 'pettyrevenge',
                     'confession', 'WTF', 'aww', 'SubredditDrama', 'facepalm'
                     ] or \
                     re.search('tales', subreddit.lower()):
    x = 'advice/story' #'story'

  elif subreddit in ['changemyview', 'MensRights'] or \
      re.search('news|politic|libertarian|anarchism|democrat|republican|conservative|liberal|socialism', subreddit.lower()) or \
      re.search('atheis|religion|christian|islam|mormon|catholicism|judaism', subreddit.lower()):
    x = 'media/lifestyle/sports' #'news/religion'

  elif re.search('sport|baseball|soccer|golf|football|basketball|hockey|nfl|nba|mlb', subreddit.lower()) or \
      subreddit in ['MMA', 'SquaredCircle', 'MLS', 'MTB', 'FIFA', 'LiverpoolFC', 'chelseafc', 'NASCAR', 'formula1', 
                    'longboarding', 'snowboarding', 'skiing', 'climbing', 
                    'photography'] or \
      subreddit in ['loseit', 'GetMotivated', 'bodybuilding', 'martialarts', 'Health'] or \
      re.search('cycling|running|motorcycles|cfb|airsoft|paintball|fitness', subreddit.lower()):
    x = 'media/lifestyle/sports' # 'sports/hobbies/fitness'
    
  elif re.search('pics|videos|funny|comedy|jokes|movies|television|music|books|anime|gifs', subreddit.lower()) or \
      subreddit in ['gameofthrones', 'asoiaf', 'doctorwho', 'thewalkingdead', 'TheLastAirbender', 
                    'Naruto', 'harrypotter', 'yugioh', 'startrek', 'StarWars', 'breakingbad'
                    'mylittlepony', 'hiphopheads', 'comics', 'vinyl', 'community', 'kpop'
                    ]:
    x = 'media/lifestyle/sports' # 'media'

  else:
    x = 'other'
  
  return x


df['subreddit_group'] = df['subreddit'].map(group_subreddit)
subreddit_group_counts = df['subreddit_group'].value_counts().to_frame().reset_index().rename(columns={'index': 'subreddit_group', 'subreddit_group': 'N'})
print(subreddit_group_counts)

          subreddit_group        N
0            advice/story  1409813
1                   other  1184045
2                  gaming   501041
3  media/lifestyle/sports   455228
CPU times: user 28.1 s, sys: 89.1 ms, total: 28.2 s
Wall time: 28.1 s


In [None]:
# visualize leftover subreddits

from IPython.display import display, HTML
all_counts = df[['subreddit', 'subreddit_group']].value_counts().to_frame().reset_index().rename(columns={0: 'N'})
# filter out ones we've already grouped
all_counts = all_counts[all_counts['subreddit_group']=='other']

# Puts the scrollbar next to the DataFrame
display(HTML(
    "<div style='height: 400px; overflow: auto; width: fit-content'>" 
    + all_counts.to_html() 
    + "</div>"
    ))

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# inspect individual ones
findme = 'reddit.com'
ind = 1
print("content:")
pprint(df[df['subreddit'] == findme]['content'].iloc[ind])
print("\nsummary:")
df[df['subreddit'] == findme]['summary'].iloc[ind]

content:
('Zell Miller never joined the Republican party. Are you implying that anyone '
 "who has ever agreed with a Republican must also be a Republican? I'm sorry "
 'to disappoint you, then, but even President Obama qualifies as a '
 'Republican. \n'
 ' The bold text was used to emphasize key phrases for the selective (')

summary:


"readers of Reddit. If I was trying to present those phrases as the only historical events from the referenced page worth discussing, I wouldn't have attempted to include their historical context."

# Train/Test/Valid

In [None]:
df.head(10)

In [None]:
df_shuffled = df.sample(frac=1).reset_index(drop=True)
df_shuffled.head(10)

In [17]:
df[df['subreddit_group'] == 'advice/story']

Unnamed: 0,content,summary,subreddit,subreddit_group
12,"Yeah, but most folks think avoiding gluten wil...",stupid stuff.,AskReddit,advice/story
13,As an entrepreneur/freelancer (especially a su...,get a good CPA - they aren't that expensive bu...,personalfinance,advice/story
16,You probably won't come off as an ass if you j...,"just get both of their numbers, text the one y...",AskReddit,advice/story
20,"I want to say this was about two weeks ago, co...",Fuck Slender Man.,AskReddit,advice/story
21,I take a beta blocker for my heart condition t...,Butchered daughter to save the universe and wa...,AskReddit,advice/story
...,...,...,...,...
3848307,I (21 M ) My ex (18 F) broke up 2 months ago ....,so what im asking is it true that all ex girlf...,relationships,advice/story
3848315,Our church has been around for five years. In ...,how would you pay a pastor $100 a month? What ...,smallbusiness,advice/story
3848325,I've finally gotten around to initiating plans...,"hate my own feet, and don't know how to give a...",sex,advice/story
3848326,"Long time lurker, first time poster here. I'm ...","want to win cash prize, need answer for radio ...",AskReddit,advice/story


In [20]:
num_rows_train = 15000
num_rows_test = num_rows_valid = 1000
df_advice[:num_rows_train]

Unnamed: 0,content,summary,subreddit,subreddit_group
0,Can bars and/or clubs legally charge men for e...,Is it legal for bars and/or clubs to charge me...,legaladvice,advice/story
1,Great Link! Enjoyed the comments at the end es...,check out the link,WTF,advice/story
2,"Met my bf on DeviantArt, oddly enough. I comme...",since there was of course WAY more involved in...,SRSQuestions,advice/story
3,You told me you really liked me. \n When I mov...,"You seduced me when I tried to push away, stol...",offmychest,advice/story
4,"So, preface. I have a tiny Asian vagina. My bo...","Tiny asian vagina, tears after sex, halp?",sex,advice/story
...,...,...,...,...
14995,"As a wedding DJ, the worst wedding I've ever d...",The worst weddings are the ones run by cheapos,AskReddit,advice/story
14996,DO NOT DO THIS! \n If it turns out that you wo...,It's illegal and you are basically asking for ...,LifeProTips,advice/story
14997,From the Digg update backlash. The comments we...,My parents told me about Reddit.,AskReddit,advice/story
14998,"When I was in grade 12, we had a really hot in...","highschool buddy bangs the gym intern, she lea...",AskReddit,advice/story


In [21]:
num_rows_train = 15000
num_rows_test = num_rows_valid = 1000

In [28]:
%%time
# sample 15k obs from each subreddit category for train, and 1k obs for both validation/test.
num_rows_train = 15000
num_rows_test = num_rows_valid = 1000
test_ind = num_rows_train + num_rows_test
valid_ind = num_rows_train + num_rows_test + num_rows_valid

# shuffle data in each category
df_advice = df[df['subreddit_group'] == 'advice/story'].sample(frac=1, random_state=1).reset_index(drop=True)
df_gaming = df[df['subreddit_group'] == 'gaming'].sample(frac=1, random_state=1).reset_index(drop=True)
df_media = df[df['subreddit_group'] == 'media/lifestyle/sports'].sample(frac=1, random_state=1).reset_index(drop=True)
df_other = df[df['subreddit_group'] == 'other'].sample(frac=1, random_state=1).reset_index(drop=True)

# now extract certain num of rows from each subsection:
train = pd.concat([
  df_advice[:num_rows_train], 
  df_gaming[:num_rows_train], 
  df_media[:num_rows_train], 
  df_other[:num_rows_train]], ignore_index=True)

test = pd.concat([
  df_advice[num_rows_train:test_ind], 
  df_gaming[num_rows_train:test_ind], 
  df_media[num_rows_train:test_ind], 
  df_other[num_rows_train:test_ind]], ignore_index=True)

valid = pd.concat([
  df_advice[test_ind:valid_ind], 
  df_gaming[test_ind:valid_ind], 
  df_media[test_ind:valid_ind], 
  df_other[test_ind:valid_ind]], ignore_index=True)


# check counts
print("train:")
print(train['subreddit_group'].value_counts())
print("\ntest:")
print(test['subreddit_group'].value_counts())
print("\nvalidation:")
print(valid['subreddit_group'].value_counts())

train:
advice/story              15000
gaming                    15000
media/lifestyle/sports    15000
other                     15000
Name: subreddit_group, dtype: int64

test:
advice/story              1000
gaming                    1000
media/lifestyle/sports    1000
other                     1000
Name: subreddit_group, dtype: int64

validation:
advice/story              1000
gaming                    1000
media/lifestyle/sports    1000
other                     1000
Name: subreddit_group, dtype: int64
CPU times: user 7.51 s, sys: 48.8 ms, total: 7.56 s
Wall time: 7.48 s


# Write data to disk

In [29]:
%%time
write_path = data_path

# create dir if not exists
write_path2 = os.path.join(write_path, 'train_test_split_v2')

if not os.path.exists(write_path2):
  os.makedirs(write_path2)

# write train/test/valid separately
# make sure you install pyarrow
train.to_parquet(os.path.join(write_path2, 'reddit_train.parquet'))
test.to_parquet(os.path.join(write_path2, 'reddit_test.parquet'))
valid.to_parquet(os.path.join(write_path2, 'reddit_validation.parquet'))

CPU times: user 687 ms, sys: 220 ms, total: 907 ms
Wall time: 1.27 s


In [30]:
!ls /content/gdrive/MyDrive/w266/w266_reddit_summarization/data/reddit_parquet/train_test_split_v2

reddit_test.parquet  reddit_train.parquet  reddit_validation.parquet
