In [21]:
import csv
from collections import Counter

import numpy as np

In [28]:
# Load the BBC corpus: each CSV record is (article text, category label).
# The quoted text field spans several physical lines, so the csv module is used to
# find record boundaries rather than guessing them from '",' substrings (a paragraph
# line that happens to end in `",` would otherwise be mistaken for the record end).
# NOTE(review): assumes the file is UTF-8 — previously the platform default encoding was used.
articles = []
labels = []

with open('input/bbc_text_cls.csv', encoding='utf-8', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:
            continue  # blank line between records
        if len(row) != 2:
            raise ValueError(f'expected 2 fields (text, label), got {len(row)}: {row!r:.80}')
        text, label = row
        # collapse the article's internal newlines/indentation into single spaces
        articles.append(' '.join(text.split()))
        labels.append(label.strip())



In [29]:
# Characters stripped by remove_punc. '.' is not in the set, so sentence-ending
# periods survive cleaning; other characters not listed ('$', '%', digits, ...) are kept too.
# The backslash is escaped explicitly: a bare '\,' is an invalid escape sequence.
punc = '''!()-[]{};:'"\\,<>/?@#^&*_~'''

def remove_punc(s):
    ''' Remove every character found in `punc` from `s` and lowercase the rest.

    Parameters
    ----------
    s : str
        Raw text.

    Returns
    -------
    str
        `s` with the characters in `punc` dropped and each remaining character
        lowercased; whitespace is left untouched.
    '''
    # join over a generator instead of repeated `+=` on a growing string
    return ''.join(char.lower() for char in s if char not in punc)

In [30]:
def articles_to_mm(articles):
    ''' Build the "fill-in-the-middle" lookup tables from a corpus.

    Each article is cleaned with `remove_punc` and split on whitespace. For the
    word w at position i, let next_word be the following word, or 'END' when w is
    the last word of the article:

    * i == 0: w is appended to isd[next_word] (initial-state table).
    * i > 0:  w is appended to mm[(previous_word, next_word)].

    'END' cannot collide with a real word because cleaned words are lowercase.

    Parameters
    ----------
    articles : iterable of str

    Returns
    -------
    (isd, mm) : tuple of dict
        isd maps next_word -> list of words seen at the start of an article before it;
        mm maps (previous_word, next_word) -> list of words seen between them.
        Lists keep duplicates, so a word's multiplicity is its observed frequency.

    Raises
    ------
    AssertionError
        If an article has no words after cleaning.
    '''
    isd = {}
    mm = {}

    for article in articles:
        words = remove_punc(article).split()
        assert len(words) > 0

        last = len(words) - 1
        for idx, w in enumerate(words):
            # the end of the article is marked with 'END' so every word, including
            # the only word of a one-word article, has a right-hand context
            next_w = words[idx + 1] if idx < last else 'END'
            if idx == 0:
                isd.setdefault(next_w, []).append(w)
            else:
                mm.setdefault((words[idx - 1], next_w), []).append(w)

    return isd, mm

# Build both lookup tables (start-of-article and (prev, next) contexts) from the full corpus.
isd, mm = articles_to_mm(articles=articles)

In [31]:
def normalize_dict(d):
    ''' Turn lists of observed words into per-key probability tables.

    Used for both the initial-state table (isd) and the middle-word table (mm).

    Parameters
    ----------
    d : dict
        Maps a context key to a list of words observed in that context
        (duplicates allowed).

    Returns
    -------
    dict
        Maps the same keys to {word: probability}, where probability is the
        word's count divided by the list length. Words appear in order of first
        occurrence. Keys whose list is empty are omitted.
    '''
    d_norm = {}
    for k, v in d.items():
        if not v:
            # nothing observed in this context, so there is no distribution to build
            continue
        # Counter tallies every word in a single pass over v
        counts = Counter(v)
        total = len(v)
        d_norm[k] = {w: c / total for w, c in counts.items()}
    return d_norm

# Convert each table's raw word lists into {word: probability} distributions.
isd_norm, mm_norm = normalize_dict(isd), normalize_dict(mm)

In [32]:
def pick_from_dist(d):
    ''' Draw one key of `d` at random, weighting each key by its value.

    `d` has the form {word: prob, ...}; the values are handed to numpy as the
    sampling probabilities (numpy requires them to sum to 1). Uses numpy's
    global random state.
    '''
    candidates = list(d)
    weights = list(d.values())
    return np.random.choice(candidates, p=weights)

def get_random_article(articles):
    ''' Pick one article uniformly at random and clean it with `remove_punc`.

    Parameters
    ----------
    articles : sequence of str

    Returns
    -------
    str
        The chosen article, lowercased with punctuation removed.

    Raises
    ------
    ValueError
        If `articles` is empty.
    '''
    if len(articles) == 0:
        raise ValueError('articles must be non-empty')
    # Draw an index instead of calling np.random.choice(articles): that would copy
    # the whole corpus into a fixed-width unicode array padded to the longest article.
    # Legacy np.random.choice draws its index with randint as well, so a given seed
    # still selects the same article.
    idx = np.random.randint(len(articles))
    return remove_punc(articles[idx])

In [44]:
# Spin a random corpus article: every word is replaced by a word sampled from those
# seen in the same context in the corpus. Contexts come from the ORIGINAL article's
# neighbours: the next word for position 0 (isd_norm), otherwise the
# (previous, next) pair (mm_norm), with 'END' standing in for "no next word".
my_article = get_random_article(articles).split()
new_article = []
for idx in range(len(my_article)):
    next_word = my_article[idx + 1] if idx < len(my_article) - 1 else 'END'
    if idx == 0:
        new_word = pick_from_dist(isd_norm[next_word])
    else:
        prev_word = my_article[idx - 1]
        new_word = pick_from_dist(mm_norm[(prev_word, next_word)])
    new_article.append(new_word)

new_article_str = ' '.join(new_article)

# capitalize the first letter of every sentence (sentences are split on '. ';
# the text is already all lowercase, so capitalize() loses nothing)
sentences = [s.capitalize() for s in new_article_str.split('. ')]
final_art = '. '.join(sentences)
print(final_art)


