# Redactle

Train a decision tree to narrow down the topic of an arbitrary wikipedia article, motivated by the game [redactle](redactle.com).

In [1]:
import re
import os
import bs4
import json
import nltk
import time
import string
import pickle
import pathlib
import requests
import collections
import urllib.parse
import sklearn.tree
import multiprocessing

import numpy as np
import matplotlib.pyplot as plt

In [2]:
ROOT_URL = 'https://en.wikipedia.org/'
DATA_DIR = pathlib.Path('articles')
assert DATA_DIR.exists()

In [3]:
#nltk.download('stopwords')

In [4]:
stop_words = nltk.corpus.stopwords.words('english')

In [5]:
with open('articles.json', 'r') as infile:
    articles_data = json.load(infile)

In [6]:
def get_all_leaves(tree_dict):
    if 'children' not in tree_dict.keys():
        return [tree_dict['id']]
    leaves = []
    for child in tree_dict['children']:
        leaves.extend(get_all_leaves(child))
    return leaves

In [7]:
leaves = get_all_leaves(articles_data)

In [8]:
ROOT_URL + leaves[0]

'https://en.wikipedia.org//wiki/Julie_Andrews'

In [9]:
def get_articles_by_category(articles):
    categories = {}
    for category in articles['children']:
        categories[category['name']] = get_all_leaves(category)
    return categories

In [10]:
categories = get_articles_by_category(articles_data)

In [10]:
def download_articles(articles):
    for label, url_paths in articles.items():
        os.makedirs(DATA_DIR / label)
        for url_path in url_paths:
            time.sleep(0.1)
            request = requests.get(ROOT_URL + url_path)
            if request.status_code == 200:
                html = request.content
                title = os.path.split(url_path)[-1]
                with open(DATA_DIR / label / f'{title}.html', 'wb') as outfile:
                    outfile.write(html)

In [11]:
#download_articles(get_articles_by_category(articles_data))

In [12]:
def extract_text(html):
    soup = bs4.BeautifulSoup(html, 'html.parser')
    content = soup.find('div', id='content').find('div', id='bodyContent').find('div', id='mw-content-text').find('div')
    if content.find(class_='shortdescription'): content.find(class_='shortdescription').decompose()
    if content.find(id='References'): content.find(id='References').decompose()
    if content.find(class_='reflist'): content.find(class_='reflist').decompose()
    for i in content.find_all(class_='mwe-math-element'): i.decompose()
    text = '\n'.join([i.get_text() for i in content if i.name == 'p' or i.name == 'h1' or i.name == 'h2' or i.name == 'h3' or i.name == 'ul'])
    text = text.split('See also[edit]')[0]
    text = re.sub('[\(\[].*?[\)\]]', '', text)
    text = text.lower()
    return text

In [13]:
def digest_articles(directory):
    for article in os.listdir(directory):
        if os.path.splitext(article)[-1] != '.html': continue
        with open(os.path.join(directory, article), 'rb') as infile:
            html = infile.read()
        text = extract_text(html)
        tokens = nltk.tokenize.word_tokenize(text)
        tokens = [token for token in tokens if token not in stop_words]
        tokens = [token for token in tokens if not any([(punc in token) for punc in string.punctuation])]
        counts = collections.Counter(tokens)
        filename = os.path.splitext(article)[0]+'.txt'
        with open(os.path.join(directory, filename), 'w') as outfile:
            outfile.write(f'ALL_WORDS:{sum(counts.values())}\n')
            title = os.path.splitext(article)[0]
            outfile.write(f'TITLE_NUM_WORDS:{len(title.split("_"))}\n')
            for i in range(5):
                if i < len(title.split('_')):
                    w = title.split('_')[i]
                    outfile.write(f'TITLE_WORD_{i}_LEN:{len(w)}\n')
            for word, count in counts.items():
                if count > 1:
                    outfile.write(f'{word}:{count}\n')

In [372]:
categories = [os.path.join(DATA_DIR, category) for category in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, category))]

with multiprocessing.Pool(len(categories)) as p:
    p.map(digest_articles, categories)

In [14]:
def load_word_list(directory):
    words = collections.Counter()
    for category in os.listdir(directory):
        for article in os.listdir(os.path.join(directory, category)):
            if os.path.splitext(article)[-1] != '.txt': continue
            with open(os.path.join(directory, category, article), 'r') as infile:
                data = infile.read()
            for line in data.split('\n'):
                if line == '': continue
                word = line.split(':')[0]
                words[word] += 1
    words_array = np.array([word for word, count in words.items() if count > 5]) # optimized for getting word of index
    words_dict = {word:i for i, word in enumerate(words_array)} # optimized for getting index of word
    return words_array, words_dict

In [385]:
words_array, words_dict = load_word_list('articles')

In [389]:
with open('words.pkl', 'wb') as outfile:
    pickle.dump([words_array, words_dict], outfile)

In [11]:
with open('words.pkl', 'rb') as infile:
    words_array, words_dict = pickle.load(infile)

In [12]:
len(words_array)

39620

In [61]:
def generate_features(min_count=100, max_label_depth=3):
    
    nleaves = len(leaves)
    features = np.zeros((nleaves, len(words_array)))
    labels = np.zeros(nleaves)
    label_counts = [ ]
    all_labels = [ ]
    idx = 0
    top = None
    tpath = [ ]
    
    def scan(node, depth, label):
        nonlocal idx, top,  tpath
        saved_tpath = list(tpath)
        if depth == 1:
            top = node['name']
            tpath = [ top ]
            label = len(all_labels)
            all_labels.append(node['name'])
            label_counts.append(0)
        elif (depth > 1) and (depth <= max_label_depth) and (node.get('count',0) >= min_count):
            label = len(all_labels)
            tpath.append(node['name'])
            all_labels.append(' / '.join(tpath))
            label_counts.append(0)
            
        children = node.get('children', [])
        if children:
            for child in node['children']:
                scan(child, depth + 1, label)
        else:
            name = node["id"][6:]
            if '/' in name:
                # This only occurs for "HIV/AIDS"
                name = name[:name.index('/')]
            path = DATA_DIR / top / f'{name}.txt'
            with open(path, 'r') as infile:
                data = infile.read()
            for line in data.split('\n'):
                if line == '': continue
                word, count = line.split(':')
                count = int(count)
                if word in words_dict:
                    features[idx, words_dict[word]] = count
            labels[idx] = label
            label_counts[label] += 1
            idx += 1
        tpath = saved_tpath

    scan(articles_data, 0, None)
    return features, labels, all_labels, label_counts

In [62]:
features, labels, all_labels, label_counts = generate_features()

In [63]:
dict(zip(all_labels, label_counts))

{'People': 422,
 'People / Visual artists': 124,
 'People / Writers': 79,
 'People / Writers / Late modern': 176,
 'People / Musicians and composers': 149,
 'People / Philosophers, historians, political and social scientists': 162,
 'People / Religious figures': 124,
 'People / Politicians and leaders': 72,
 'People / Politicians and leaders / Post-classical': 125,
 'People / Politicians and leaders / Early modern period': 116,
 'People / Politicians and leaders / Late modern period': 187,
 'People / Scientists, inventors and mathematicians': 254,
 'History': 253,
 'History / Ancient history': 127,
 'History / Post-classical history': 133,
 'History / Late modern history': 171,
 'Geography': 54,
 'Geography / Physical geography': 192,
 'Geography / Physical geography / Bodies of water': 189,
 'Geography / Countries': 207,
 'Geography / Regions and country subdivisions': 111,
 'Geography / Cities': 274,
 'Geography / Cities / Asia': 177,
 'Arts': 208,
 'Arts / Literature': 56,
 'Arts / 

In [64]:
with open('features-new.pkl', 'wb') as outfile:
    pickle.dump([features, labels, all_labels], outfile)

In [36]:
with open('features-new.pkl', 'rb') as infile:
    features, labels, all_labels = pickle.load(infile)

In [65]:
classifier = sklearn.tree.DecisionTreeClassifier(max_depth = 20, criterion='entropy') #, class_weight='balanced')

In [66]:
%time classifier.fit(features, labels)

CPU times: user 26.5 s, sys: 128 ms, total: 26.6 s
Wall time: 26.7 s


DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='entropy',
                       max_depth=20, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')

In [67]:
classifier.score(features, labels)

0.9754270302667066

In [68]:
top_features = np.argsort(classifier.feature_importances_)[::-1][:200]
print(words_array[top_features])

['father' 'species' 'population' 'used' 'city' 'death' 'career' 'economy'
 'music' 'military' 'earth' 'people' 'world' 'works' 'art' 'language'
 'political' 'may' 'genus' 'god' 'government' 'united' 'work' 'include'
 'designed' 'production' 'cells' 'energy' 'north' 'area' 'games' 'played'
 'climate' 'numbers' 'example' 'literary' 'reign' 'chemical' 'water'
 'TITLE_WORD_1_LEN' 'century' 'treatment' 'ALL_WORDS' 'life' 'country'
 'mathematical' 'family' 'december' 'one' 'compounds' 'theory'
 'philosophy' 'food' 'common' 'stars' 'symptoms' 'many' 'king'
 'university' 'leaves' 'southern' 'published' 'grown' 'TITLE_WORD_0_LEN'
 'human' 'use' 'years' 'later' 'first' 'spoken' 'science' 'often'
 'scientific' 'plants' 'body' 'early' 'term' 'bc' 'also' 'including'
 'story' 'religious' 'army' 'organisms' 'time' 'known' 'languages'
 'humans' 'period' 'high' 'TITLE_NUM_WORDS' 'large' 'defined' 'usually'
 'popular' 'national' 'made' 'person' 'around' 'systems' 'ancient' 'form'
 'unit' 'index' 'surfac

In [69]:
def tree2json(tree, name):
    data = dict(
        classes = all_labels,
        words = [words_array[i] if i >= 0 else "" for i in tree.feature],
        cuts = [int(i) if i >= 0 else 0 for i in tree.threshold],
        below = tree.children_left.tolist(),
        above = tree.children_right.tolist(),
        probs = [p[0].astype(int).tolist() for p in tree.value],
    )
    with open(name, 'w') as fp:
        json.dump(data, fp)

In [70]:
tree2json(classifier.tree_, 'tree20-new-entropy-100-3.json')