# Note: Partially adapted from Matthew Scharf's Homework #4 for CIS 530, which he took in Fall 2020 with Professor Clayton Greenberg

# Imports

In [None]:
import re, unicodedata, numpy as np, pandas as pd, pickle, os, math

from numpy import random
from google.colab import files

from google.colab import drive

drive.mount('/content/gdrive')

Mounted at /content/gdrive


# Download from kaggle

In [None]:
files.upload()
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/

Saving kaggle.json to kaggle.json


In [None]:
! chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d mswarbrickjones/reddit-selfposts

Downloading reddit-selfposts.zip to /content
 99% 349M/352M [00:06<00:00, 60.5MB/s]
100% 352M/352M [00:06<00:00, 57.8MB/s]


In [None]:
! unzip reddit-selfposts.zip -d data

Archive:  reddit-selfposts.zip
  inflating: data/rspct.tsv          
  inflating: data/subreddit_info.csv  


# Data Processing Functions

In [None]:
def preprocess_text(text):
    """Normalize raw post text to lowercase ASCII letters, apostrophes and spaces.

    Steps, in order:
      1. lowercase and NFD-decompose, then drop every non-ASCII code point
         (this strips accents: 'é' becomes 'e');
      2. replace every character that is not a letter, apostrophe or
         whitespace with a space;
      3. remove the standalone token ``lb`` (what is left of a ``<lb>``
         marker once the angle brackets are replaced in step 2);
      4. collapse whitespace runs to a single space and strip the ends.

    Args:
        text: the raw string to clean.

    Returns:
        The cleaned string.
    """
    text = unicodedata.normalize('NFD', text.lower())
    text = text.encode('ascii', 'ignore').decode('ascii')
    text = re.sub(r'[^A-Za-z\'\s]', ' ', text)
    # Word boundaries keep real words such as "lbs" intact and also catch an
    # "lb" token at the very start of the text.
    text = re.sub(r'\blb\b', ' ', text)
    # Collapse whitespace only after the token removal so that the removal
    # cannot leave runs of spaces behind.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

In [None]:
def get_data(path):
    """Load the self-post dataset and return cleaned (category, text) rows.

    Reads ``rspct.tsv`` and ``subreddit_info.csv`` from ``path``, inner-joins
    them on the ``subreddit`` column and keeps only the ``category_1`` and
    ``selftext`` columns. ``selftext`` is run through ``preprocess_text``;
    in ``category_1`` every '/' becomes '&' and spaces are removed.

    Args:
        path: directory containing the two data files. A trailing separator
            is optional.

    Returns:
        A pandas DataFrame with the columns ``category_1`` and ``selftext``.
    """
    print('Getting data')
    raw_df = pd.read_table(os.path.join(path, 'rspct.tsv'))
    sub_df = pd.read_csv(os.path.join(path, 'subreddit_info.csv'))

    print('data read in. starting dataframe join:')

    merged = pd.merge(raw_df, sub_df, how='inner', on='subreddit')
    # Explicit copy so the column assignments below write to an independent
    # frame rather than to a slice of the merge result.
    df = merged[['category_1', 'selftext']].copy()

    print('dataframes joined. starting text preprocessing:')

    df['selftext'] = df['selftext'].apply(preprocess_text)
    df['category_1'] = df['category_1'].apply(
        lambda label: label.replace('/', '&').replace(' ', ''))

    print('Text processed.')
    return df


# Trie object definition

In [None]:
class CountTree(object):
    """Character n-gram counts used for language modeling.

    The counts live in two flat dictionaries:

    * ``d`` maps an n-gram (length 0 to ``n``) to how often it was seen;
      the empty string holds the total number of characters seen.
    * ``next`` maps ``(ngram, char)`` to how often ``char`` directly
      followed ``ngram``. Only n-grams shorter than ``n`` are recorded.

    Every text is wrapped as ``'<' + text + '>'`` before it is counted.
    """

    def __init__(self, n=2):
        """Create an empty count tree of order ``n``."""
        # The order of the tree (longest n-gram that is counted)
        self.n = n

        # The counts of each ngram
        self.d = {'': 0}

        # The counts of characters following observed ngrams where
        # len(ngram) < order
        self.next = {}

    def hist_prep(self, h, str_len):
        """Return the longest usable suffix of the history ``h``.

        The result is the longest suffix of ``h`` that is at most
        ``str_len`` characters long and has a non-zero count. It is the
        empty string when no such suffix exists or when ``str_len <= 0``.
        """
        # h[-0:] would be the whole history, not the empty one, so a zero
        # (or negative) window has to be handled explicitly.
        ngram = h[-str_len:] if str_len > 0 else ''

        # Drop leading characters until the suffix has been observed.
        while ngram and self.get_count(ngram) == 0:
            ngram = ngram[1:]
        return ngram

    def add_to_counts(self, ngram):
        """Increment the count of ``ngram`` by one."""
        self.d[ngram] = self.d.get(ngram, 0) + 1

    def add_to_next(self, ngram, nxt):
        """Record one occurrence of character ``nxt`` right after ``ngram``."""
        key = (ngram, nxt)
        self.next[key] = self.next.get(key, 0) + 1

    def update(self, text):
        """Add the n-grams of one text to the counts."""
        str_ = '<' + text + '>'
        size = len(str_)

        # Level 0: total character count and unconditioned character counts
        self.d[''] += size
        for char in str_:
            self.add_to_next('', char)

        # Levels 1..n (the range is empty when n < 1)
        for lvl in range(1, self.n + 1):
            for i in range(size - lvl + 1):
                ngram = str_[i:i + lvl]
                self.add_to_counts(ngram)

                # A following character is recorded only when one exists
                # and the ngram is shorter than the order of the tree.
                if i < size - lvl and lvl < self.n:
                    self.add_to_next(ngram, str_[i + lvl])

    def get_count(self, ngram):
        """Return the count of ``ngram`` (0 if it was never seen)."""
        return self.d.get(ngram, 0)

    def get_next(self, ngram, nxt):
        """Return how often the single character ``nxt`` followed ``ngram``."""
        assert len(nxt) == 1
        return self.next.get((ngram, nxt), 0)

    def get_extensions(self, ngram):
        """Return ``(char, count)`` for every character seen after ``ngram``.

        Note: this scans the whole ``next`` table on every call.
        """
        return [(nxt, count)
                for (pos_ngram, nxt), count in self.next.items()
                if pos_ngram == ngram]

    def train(self, training_data):
        """Call ``update`` on every text in ``training_data``."""
        for text in training_data:
            self.update(text)


# Text generation functions

In [None]:
def random_char(count_tree, h):
    """Sample the next character after history ``h``.

    The history is first reduced with ``count_tree.hist_prep`` to a suffix
    of at most ``n - 1`` characters. The character is then drawn from the
    extensions of that suffix with probability proportional to their
    counts, walking the extensions in sorted character order.

    Args:
        count_tree: model providing ``n``, ``hist_prep`` and
            ``get_extensions``.
        h: the text generated so far.

    Returns:
        A single character.

    Raises:
        ValueError: if no character was ever observed after the reduced
            history (for example, an untrained model).
    """
    # Preprocess history (as described in hist_prep) to get the ngram
    ngram = count_tree.hist_prep(h, count_tree.n - 1)

    # Candidate next characters with their counts, in a stable order
    extensions = sorted(count_tree.get_extensions(ngram),
                        key=lambda pair: pair[0])
    if not extensions:
        raise ValueError(
            'no observed continuation for history {!r}'.format(ngram))

    # Work with integer counts, normalised by their own sum, so the
    # cumulative total is exact and always reaches the threshold.
    total = sum(count for _, count in extensions)
    threshold = random.random() * total

    running = 0
    for nxt, count in extensions:
        running += count
        if running > threshold:
            return nxt

    # Only reachable if rounding pushed the threshold up to `total`.
    return extensions[-1][0]

In [None]:
def random_text(count_tree, length):
    """Generate text one character at a time with ``random_char``.

    The text starts with the '<' marker. Generation stops after ``length``
    sampled characters or right after a '>' has been sampled, whichever
    happens first. The returned string includes the leading '<' (and the
    '>' when one was sampled).
    """
    generated = '<'
    steps = 0

    while steps < length:
        sampled = random_char(count_tree, generated)
        generated += sampled
        steps += 1
        # The closing marker ends the text early.
        if sampled == '>':
            break

    return generated

In [None]:
def generate_text_all_models(path, length):
    """Print a random sample from the pickled model of every category.

    For each name in the module-level ``cats`` list, the model is loaded
    from ``path + <category> + '.pickle'`` and a text of up to ``length``
    sampled characters is printed.
    """
    for category in cats:
        model_file = '{}{}.pickle'.format(path, category)

        # NOTE: pickle.load executes code from the file; only load models
        # that were written by train_models.
        with open(model_file, 'rb') as source:
            model = pickle.load(source)

        print('Generating text for category:', category)
        sample = random_text(model, length)
        print(sample, '\n')

# Training function

In [None]:
def train_models(df, out_path, cats, depth=2):
    """Train one CountTree per category and pickle each to disk.

    Args:
        df: DataFrame with the columns ``category_1`` and ``selftext``.
        out_path: prefix the file name is appended to; each model is
            written to ``out_path + <category> + '.pickle'``.
        cats: category names to train, matched against ``category_1``.
        depth: order ``n`` passed to each CountTree.
    """
    print('Training models:')
    total = len(cats)

    for position, label in enumerate(cats, start=1):
        print('Category {}/{}:{}'.format(position, total, label))

        # All texts belonging to this category
        texts = list(df[df.category_1 == label].selftext)

        model = CountTree(n=depth)
        model.train(texts)

        target = '{}{}.pickle'.format(out_path, label)
        with open(target, 'wb') as sink:
            pickle.dump(model, sink, protocol=pickle.HIGHEST_PROTOCOL)

    print('Done training!')

# Hyperparameters and training setup

In [None]:
# Categories to train, generate text for and classify against.
# train_models compares these names with the category_1 column, which
# get_data rewrites ('/' becomes '&', spaces are removed), and the pickle
# files are named after them. The remaining categories are kept commented
# out below.
cats = ['writing&stories',
'tv_show',
'autos',
'hardware&tools',
'electronics']
# 'video_game',
# 'crypto',
# 'sports',
# 'hobby',
# 'appearance'
# ,
# 'card_game',
# 'drugs',
# 'advice&question',
# 'social_group',
# 'anime&manga',
# 'sex&relationships',
# 'software',
# 'health',
# 'animals',
# 'arts',
# 'programming',
# 'rpg',
# 'books',
# 'parenting',
# 'education',
# 'company&website',
# 'profession',
# 'music',
# 'politics&viewpoint',
# 'stem',
# 'travel',
# 'geo',
# 'religion&supernatural',
# 'board_game',
# 'movies',
# 'food&drink',
# 'finance&money',
# 'meta']

In [None]:
# Prefix for the pickled models. The functions below append
# '<category>.pickle' directly to it, so it has to end with '/'.
path = '/content/gdrive/MyDrive/CIS 522 Final Project/ngram_models/depth10/'
# Order n of every CountTree that train_models builds.
depth = 10

# Training models

In [None]:
%%time
df = get_data('data/')

Getting data
data read in. starting dataframe join:
dataframes joined. starting text preprocessing:
Text processed.
CPU times: user 1min 13s, sys: 1.6 s, total: 1min 15s
Wall time: 1min 15s


In [None]:
%%time
train_models(df,path, [cats[0]], depth=depth)

Training models:
Category 1/1:writing&stories
Done training!
CPU times: user 9min 41s, sys: 10.5 s, total: 9min 51s
Wall time: 9min 56s


# Text generation

In [None]:
generate_text_all_models(path,100)

Generating text for category: writing&stories
<a washable folder destiny is a new computer and i went home with me i have looked lb lb he's awkward 

Generating text for category: tv_show
<ok so i'm wondering if anyone else's opinion what happened so far even an answer and try not to ment 

Generating text for category: autos
<i have a volvo c t it is a handling and changed new pla from the second time this happen with these  

Generating text for category: hardware&tools
<i looked around that much  i have some money   light control you have any question but i want   sold 

Generating text for category: electronics
<so been higher frequency it's also a wealth of information i have the police know some people buy on 



# Classification functions

In [None]:
def ad_prob(count_tree, w, h, d, verbose=False):
    """Return the probability of character ``w`` given history ``h``.

    Uses absolute discounting with recursive back-off to shorter histories:

        P(w | g) = max(c(g, w) - disc, 0) / c(g)
                   + (disc * |extensions(g)| / c(g)) * P(w | g[1:])

    where ``g`` is the last ``n - 1`` characters of ``h``. A history that
    was never observed is shortened by one character without any
    discounting. The empty history uses the plain relative frequency.

    Args:
        count_tree: model providing ``n``, ``get_count``, ``get_next`` and
            ``get_extensions``.
        w: the single candidate character.
        h: the text preceding ``w``.
        d: sequence of discounts, indexed from the end by the length of
            the history in use (``d[-len(g)]``). For a history that starts
            with '<' the sequence is first cut to ``len(g)`` entries.
        verbose: print every intermediate value when true.

    Returns:
        The probability as a float. It is 0.0 when the model holds no
        observations at all, which ``perplexity`` maps to infinity.
    """
    if count_tree.n == 1:
        # Order-1 model: plain relative frequency of the character.
        total = count_tree.get_count('')
        return count_tree.get_next('', w) / total if total else 0.0

    def disp(*text):
        if verbose:
            print(*text)

    disp("w:", w)
    disp("h:", h)
    disp("d:", d)

    # Keep at most the last n-1 characters of the history.
    if len(h) <= count_tree.n - 1:
        ngram = h
    else:
        ngram = h[-count_tree.n + 1:]
    disp("ngram:", ngram)

    count = count_tree.get_count(ngram)
    disp("count:", count)
    follow = count_tree.get_next(ngram, w)
    disp("next:", follow)

    if ngram == '':
        # Base case. An empty model has a zero total, so guard the division.
        prob = follow / count if count else 0.0
        disp("Empty Probability:", prob, '\n')
        return prob

    if count == 0:
        # Unseen history: back off without discounting.
        disp("COUNT ZERO")
        disp("\n")
        return ad_prob(count_tree, w, ngram[1:], d, verbose)

    beta_history = ngram[1:]
    disp("beta history:", beta_history)
    if ngram[0] == '<':
        # History reaches back to the start marker: cut the discount list
        # to the history length before indexing it from the end.
        d = d[:len(ngram)]
    disc = d[-len(ngram)]
    disp("discount:", disc)
    beta_prob = ad_prob(count_tree, w, beta_history, d, verbose)
    disp("beta_prob:", beta_prob)
    adjstd_mle = max([follow - disc, 0]) / count
    disp("adjusted mle:", adjstd_mle)
    lambda_ = (disc * len(count_tree.get_extensions(ngram))) / count
    disp("labmda_:", lambda_)
    disp("Probability:", adjstd_mle + (lambda_ * beta_prob))
    return adjstd_mle + (lambda_ * beta_prob)


In [None]:
def perplexity(count_tree, text, verbose=False):
    """Return the per-character perplexity of ``text`` under ``count_tree``.

    The text is run through ``preprocess_text`` and wrapped in '<' and '>'.
    Every character after the opening '<' is scored with ``ad_prob`` using
    a discount of 0.9 at each level. The result is 2 raised to the average
    negative log2 probability, or ``math.inf`` as soon as one character
    has probability 0.
    """
    def disp(*parts):
        if verbose:
            print(*parts)

    # One discount per history length 1 .. n-1
    discounts = [.9] * (count_tree.n - 1)

    padded = '<' + preprocess_text(text) + '>'
    scored = len(padded) - 1

    neg_log_sum = 0
    disp("sup:", 0)

    for pos in range(1, len(padded)):
        prob = ad_prob(count_tree, padded[pos], padded[:pos], discounts,
                       verbose)
        if prob == 0:
            return math.inf
        neg_log_sum -= math.log(prob, 2)
        disp("total sup:", neg_log_sum, '\n')

    return 2 ** (neg_log_sum / scored)

In [None]:
def classify(path, text):
    """Return the category whose model gives ``text`` the lowest perplexity.

    For each name in the module-level ``cats`` list, the model is loaded
    from ``path + <category> + '.pickle'`` and the perplexity of ``text``
    is computed and printed. On a tie the category listed first wins.
    """
    scores = []

    for label in cats:
        print('Category: {}:'.format(label))

        model_file = '{}{}.pickle'.format(path, label)
        with open(model_file, 'rb') as source:
            model = pickle.load(source)

        score = perplexity(model, text)
        print('Perplexity: {}\n'.format(score))
        scores.append(score)

    winner = np.argmin(scores)
    return cats[winner]

# Classify Text

In [None]:
# Classify one example sentence and print the winning category.
# Note: perplexity() runs preprocess_text on its input as well, so the
# text is preprocessed twice on this path.
text = "I just finished Scandal and I'm looking for a new show."
best_cat = classify(path, preprocess_text(text))
print(best_cat)

Category: writing&stories:
Perplexity: 3.4059914895981405

Category: tv_show:
Perplexity: 2.2930703335426306

Category: autos:
Perplexity: 3.4246802047921605

Category: hardware&tools:
Perplexity: 3.1555841493501213

Category: electronics:
Perplexity: 3.2564956621137218

tv_show
