# About

Script to create the subreddit groupings. Each subreddit grouping will be used to train a separate fine tuned model. Final generated summaries will be done using their respective subreddit group fine tuned model. 

In [1]:
from IPython.display import clear_output

!pip install datasets transformers rouge_score nltk
!pip install pyarrow
# !pip install -q sentencepiece
# !pip install rouge-score # google package version

clear_output()

In [2]:
import os
import re
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

# nlp stuff
import nltk
nltk.download('punkt')

# tf stuff
import tensorflow_datasets as tfds 
import tensorflow as tf
from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration # pegasus
from transformers import BartTokenizer, TFBartForConditionalGeneration # bart

# pytorch dataset types
import datasets
from datasets.dataset_dict import DatasetDict
from datasets import Dataset, load_metric, load_dataset

# pytorch bart stuff
import torch
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

clear_output()

# Load Data

In [3]:
%%time
from google.colab import drive
drive.mount('/content/gdrive')
# data_path = "/content/gdrive/MyDrive/Classes/W266_NLP/w266_reddit_summarization/data/reddit_parquet/"
data_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization/data/reddit_parquet/'

os.chdir(data_path)
files = [i for i in os.listdir(data_path) if re.search("reddit_data", i)]

# df = pd.read_parquet(files[:4])
# df = pd.read_parquet(files[0])
df = pd.read_parquet(files)

Mounted at /content/gdrive
CPU times: user 24.8 s, sys: 21.5 s, total: 46.3 s
Wall time: 1min 1s


In [4]:
# filter out content/summaries that are way too short
print("Starting number of rows: {}".format(df.shape[0]))

df = df[df['content'].map(lambda x: (len(x) >= 20) and (len(x) < 1500))]
print("Rows after filtering too long/short content: {}".format(df.shape[0]))

df = df[df['summary'].map(lambda x: (len(x) >= 10) and (len(x) < 150))]
print("Rows after filtering too long/short summaries: {}".format(df.shape[0]))

Starting number of rows: 3848330
Rows after filtering too long/short content: 2572452
Rows after filtering too long/short summaries: 2002658


# Create Groupings

In [5]:
%%time
def group_subreddit(subreddit):

  if subreddit in ['buildapc', 'LifeProTips', 'lifehacks', 'IAmA', 'DoesAnybodyElse', 'answers', 
                   'DIY', 'howtonotgiveafuck', 'Parenting'] or \
      re.search('advice|ask|explain|question|socialskills', subreddit.lower()) or \
      re.search('okcupid|relationship|sex|seduction', subreddit.lower()) or \
      re.search('finance|frugal|investing|crypto|bitcoin|entrepreneur|business|occupywallstreet', subreddit.lower()):
    x = 'advice/story' # advice

  elif subreddit in ['leagueoflegends', 'DotA2', 'starcraft', 'magicTCG', 
                     'Guildwars2', 'DestinyTheGame', 'pcmasterrace', 'gamecollecting',
                     'Planetside', 'rpg', 'pokemon', 'smashbros', 'swtor', 
                     'runescape', 'battlefield3', 'DarkSouls2', 'LeagueofLegendsMeta', 
                     'WorldofTanks', 'darksouls', 'gamedev', 'Minecraft', 'Diablo', 
                     'DnD', 'skyrim', 'halo', 'PS4', 'xboxone', 'battlefield_4', 
                     'ShouldIbuythisgame', 'Pathfinder_RPG', 'elderscrollsonline', 
                     'Fallout', 'GrandTheftAutoV', 'assassinscreed', 'summonerschool', 
                     'GlobalOffensive', '3DS', 'ffxiv', 'tf2', 'MonsterStrike', 
                     'GuildWars2Builds', 'hearthstone', 'Warframe', 'MonsterHunter', 'wow', 
                     'Smite', 'dayz', 'Eve', 'Warthunder', 'GameDeals', '2007scape', 'pathofexile', 
                     'masseffect', 'starcitizen', 'oculus', 'wiiu', 'Steam', 'bravefrontier', 
                     'diablo3', 'gamegrumps', 'totalwar', 'Borderlands', 'CoDCompetitive', 
                     'Civcraft', 'blackops2', 'gaymers', 'KotakuInAction', 'PS3', 'Warhammer', 
                     'zelda', 'truetf2', 'dwarffortress', 'mw3', 'nintendo', 'learndota2', 
                     'HeroesofNewerth', 'feedthebeast', 'h1z1', 'archeage', 'ClashOfClans', 
                     'CodAW', 'dragonage', 'PuzzleAndDragons', 'RotMG', 'SimCity', 'Bioshock'
                     ] or re.search('gaming|games|gamer|pokemon', subreddit.lower()):
    x = 'gaming'

  elif subreddit in ['tifu', 'TwoXChromosomes', 'offmychest', 'todayilearned', 
                     'fffffffuuuuuuuuuuuu', 'JusticePorn', 'cringe', 'pettyrevenge',
                     'confession', 'WTF', 'aww', 'SubredditDrama', 'facepalm'
                     ] or \
                     re.search('tales', subreddit.lower()):
    x = 'advice/story' #'story'

  elif subreddit in ['changemyview', 'MensRights'] or \
      re.search('news|politic|libertarian|anarchism|democrat|republican|conservative|liberal|socialism', subreddit.lower()) or \
      re.search('atheis|religion|christian|islam|mormon|catholicism|judaism', subreddit.lower()):
    x = 'media/religion/sports' #'news/religion'

  elif re.search('sport|baseball|soccer|golf|football|basketball|hockey|nfl|nba|mlb', subreddit.lower()) or \
      subreddit in ['MMA', 'SquaredCircle', 'MLS', 'MTB', 'FIFA', 'LiverpoolFC', 'chelseafc', 'NASCAR', 'formula1', 
                    'longboarding', 'snowboarding', 'skiing', 'climbing', 
                    'photography'] or \
      subreddit in ['loseit', 'GetMotivated', 'bodybuilding', 'martialarts', 'Health'] or \
      re.search('cycling|running|motorcycles|cfb|airsoft|paintball|fitness', subreddit.lower()):
    x = 'media/religion/sports' # 'sports/hobbies/fitness'
    
  elif re.search('pics|videos|funny|comedy|jokes|movies|television|music|books|anime|gifs', subreddit.lower()) or \
      subreddit in ['gameofthrones', 'asoiaf', 'doctorwho', 'thewalkingdead', 'TheLastAirbender', 
                    'Naruto', 'harrypotter', 'yugioh', 'startrek', 'StarWars', 'breakingbad'
                    'mylittlepony', 'hiphopheads', 'comics', 'vinyl', 'community', 'kpop'
                    ]:
    x = 'media/religion/sports' # 'media'

  else:
    x = 'other'
  
  return x


df['subreddit_group'] = df['subreddit'].map(group_subreddit)
subreddit_group_counts = df['subreddit_group'].value_counts().to_frame().reset_index().rename(columns={'index': 'subreddit_group', 'subreddit_group': 'N'})
print(subreddit_group_counts)

         subreddit_group       N
0           advice/story  752765
1                  other  660654
2                 gaming  304512
3  media/religion/sports  284727
CPU times: user 16.1 s, sys: 9.97 ms, total: 16.1 s
Wall time: 16 s


In [170]:
# visualize leftover subreddits

from IPython.display import display, HTML
all_counts = df[['subreddit', 'subreddit_group']].value_counts().to_frame().reset_index().rename(columns={0: 'N'})
# filter out ones we've already grouped
all_counts = all_counts[all_counts['subreddit_group']=='other']

# Puts the scrollbar next to the DataFrame
display(HTML(
    "<div style='height: 400px; overflow: auto; width: fit-content'>" 
    + all_counts.to_html() 
    + "</div>"
    ))

Unnamed: 0,subreddit,subreddit_group,N
14,trees,other,2771
17,reddit.com,other,1795
18,technology,other,1736
27,science,other,1248
37,guns,other,906
38,Android,other,886
58,electronic_cigarette,other,640
59,teenagers,other,634
65,programming,other,575
68,canada,other,558


In [178]:
# inspect individual ones
findme = 'woahdude'
ind = 50
print("content:")
pprint(df[df['subreddit'] == findme]['content'].iloc[ind])
print("\nsummary:")
df[df['subreddit'] == findme]['summary'].iloc[ind]

content:
('As an EMT in training, we are required to write down the mechanism of injury '
 'in our pre-hospital care reports (PCR), if your case had gone to court the '
 "PCR can be used as evidence, and the EMT's or paramedics in question may "
 "have been called in as witnesses. I'm also not sure on the exact legal "
 'specifications of caregiver patient confidentiality when it comes to '
 'emergency responders.')

summary:


"They weren't being dicks, as long as your injuries weren't life or limb threatening they were protecting you from a possible court case."

# Train/Test/Valid

In [10]:
# sample to create train/test/valid splits
num_rows_train = 50000
num_rows_test = 5000
num_rows_valid = 5000
ratio1 = (num_rows_train + num_rows_test + num_rows_valid) / df.shape[0]

from sklearn.model_selection import train_test_split
throw_away, data = train_test_split(df, test_size=ratio1, random_state=1)
print(data.shape[0])

# get train
train, test = train_test_split(data, test_size = (num_rows_test+num_rows_valid) / data.shape[0], random_state=1)

# get test/valid
test, valid = train_test_split(test, test_size=.5, random_state=1)

# print shapes
print("train size: {}".format(train.shape[0]))
print("test size: {}".format(test.shape[0]))
print("valid size: {}".format(valid.shape[0]))

60000
train size: 50000
test size: 5000
valid size: 5000


# Write data to disk

In [16]:
# %%time
# write_path = '/content/gdrive/MyDrive/w266/w266_reddit_summarization/data/reddit_parquet/'

# # create dir if not exists
# write_path2 = os.path.join(write_path, 'train_test_split')

# if not os.path.exists(write_path2):
#   os.makedirs(write_path2)

# # write train/test/valid separately
# # make sure you install pyarrow
# train.to_parquet(os.path.join(write_path2, 'reddit_train.parquet'))
# test.to_parquet(os.path.join(write_path2, 'reddit_test.parquet'))
# valid.to_parquet(os.path.join(write_path2, 'reddit_validation.parquet'))

yes
