In [2]:
import re
import math
from pyspark import SparkConf, SparkContext


# getOrCreate() reuses the already-running context when this cell is re-executed;
# constructing SparkContext(...) a second time raises, because only one active
# SparkContext is allowed per process.
sc = SparkContext.getOrCreate(conf=SparkConf().setAppName("MyApp").setMaster("yarn"))

def parse_article(line):
    """Parse one tab-separated ``article_id<TAB>text`` line into a list of words.

    Leading/trailing non-word characters are stripped from the text, which is
    then split on whitespace together with any punctuation touching it, so
    ``"Hello, world!"`` becomes ``["Hello", "world"]``. Intra-word punctuation
    such as the hyphen in ``singer-songwriter`` is kept.

    Returns:
        The list of words, or an empty list when the line has no tab separator
        or when the article text contains no word characters at all (this used
        to yield ``[""]``, which was then counted as a word).
    """
    try:
        _article_id, text = line.rstrip().split('\t', 1)
    except ValueError:
        # No tab separator: malformed line, contribute no words.
        return []
    # Raw strings: "\W" / "\s" in plain literals are invalid escape sequences
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    text = re.sub(r"^\W+|\W+$", "", text, flags=re.UNICODE)
    if not text:
        # re.split on "" would return [""], i.e. one bogus empty "word".
        return []
    return re.split(r"\W*\s+\W*", text, flags=re.UNICODE)
    
def lower_all(words):
    """Return a new list holding the lower-cased form of every word in ``words``."""
    return [token.lower() for token in words]
        
def bigrams_getter(words):
    """Emit a ``("<left>_<right>", 1)`` record for each pair of adjacent words.

    Both words are lower-cased before joining. Sequences with fewer than two
    words produce an empty list.
    """
    return [
        (left.lower() + '_' + right.lower(), 1)
        for left, right in zip(words, words[1:])
    ]

def word_count(words):
    """Emit a ``(lower-cased word, 1)`` record per word, ready for ``reduceByKey``."""
    return [(token.lower(), 1) for token in words]

def remove_stopwords(words):
    """Return the words of ``words``, in order, that are not stop words.

    Membership is tested against ``stopwords_broadcast.value``, the broadcast
    stop-word collection created at module level in this notebook.
    """
    return [token for token in words if token not in stopwords_broadcast.value]

def _split_pair(pair, word_counts):
    """Recover ``(word_1, word_2)`` from a ``"<word_1>_<word_2>"`` bigram key.

    Underscores count as word characters for the tokenizer, so a word may itself
    contain ``_`` and the key can contain several of them. Pick the first
    underscore whose two halves are both known words. If none matches, split at
    the first underscore, so that the count lookup fails with a KeyError.
    """
    parts = pair.split("_")
    for cut in range(1, len(parts)):
        word_1 = "_".join(parts[:cut])
        word_2 = "_".join(parts[cut:])
        if word_1 in word_counts and word_2 in word_counts:
            return word_1, word_2
    word_1, _, word_2 = pair.partition("_")
    return word_1, word_2


def npmi(value):
    """Compute the normalized pointwise mutual information (NPMI) of a bigram.

    Reads the broadcast variables ``word_pairs_map`` (word -> count),
    ``total_words`` and ``total_pairs`` defined at module level in this notebook.

    Args:
        value: ``(pair, count)``, where ``pair`` is ``"<word_1>_<word_2>"`` as
            built by ``bigrams_getter`` and ``count`` is its corpus frequency.

    Returns:
        ``(pair, score)`` where ``score = pmi / -log p(pair)`` and
        ``pmi = log(p(pair) / (p(word_1) * p(word_2)))``.
    """
    pair, count = value
    word_counts = word_pairs_map.value
    word_1, word_2 = _split_pair(pair, word_counts)

    # float() keeps true division even under Python 2 integer semantics.
    pair_prob = float(count) / total_pairs.value
    word_1_prob = float(word_counts[word_1]) / total_words.value
    word_2_prob = float(word_counts[word_2]) / total_words.value

    pmi = math.log(pair_prob / (word_1_prob * word_2_prob))
    # Named `score`, not `npmi`, so the local no longer shadows this function.
    score = pmi / (-1 * math.log(pair_prob))
    return (pair, score)

# Load the English stop-word list on the driver and broadcast it to executors.
STOP_WORDS_PATH = "/datasets/stop_words_en.txt"

with open(STOP_WORDS_PATH, "r") as f:
    stopwords = f.read().splitlines()

# Broadcast a set rather than the raw list: remove_stopwords() does a membership
# test for every word of every article. That test is O(1) on a set but a linear
# scan on a list. Entries are stripped and lower-cased because the words they
# are compared against have already been lower-cased by lower_all().
stopwords_broadcast = sc.broadcast(set(word.strip().lower() for word in stopwords))

# Tokenize each article line, lower-case its words and drop stop words.
# The result is cached because both the bigram and the word counts below read it.
wiki = (
    sc.textFile("/data/wiki/en_articles_part/articles-part", 16)
    .map(parse_article)
    .map(lower_all)
    .map(remove_stopwords)
    .cache()
)

In [12]:
# Corpus-wide frequency tables: (bigram, count) and (word, count).
# Both are cached because later cells reuse them several times.
bigram_pairs = (
    wiki.flatMap(bigrams_getter)
    .reduceByKey(lambda left, right: left + right)
    .cache()
)
word_pairs = (
    wiki.flatMap(word_count)
    .reduceByKey(lambda left, right: left + right)
    .cache()
)

In [13]:
# Total word and bigram occurrences (the NPMI probability denominators),
# broadcast so npmi() can read them through `.value` on the executors.
total_words = sc.broadcast(word_pairs.values().sum())
total_pairs = sc.broadcast(bigram_pairs.values().sum())

In [14]:
# Collect the word -> count table to the driver as a dict and broadcast it,
# so npmi() can look up unigram counts on the executors.
word_pairs_map = sc.broadcast(word_pairs.collectAsMap())

In [16]:
# Score only bigrams seen at least MIN_BIGRAM_COUNT times (presumably to keep
# rare, noisy pairs out of the ranking), then print the TOP_N_BIGRAMS pairs
# with the highest NPMI.
MIN_BIGRAM_COUNT = 500
TOP_N_BIGRAMS = 39

# RDD.top() returns the N largest elements by `key` in descending order
# without the full shuffle-sort that sortBy(...).take(N) performs.
result = (
    bigram_pairs
    .filter(lambda pair_count: pair_count[1] >= MIN_BIGRAM_COUNT)
    .map(npmi)
    .top(TOP_N_BIGRAMS, key=lambda pair_score: pair_score[1])
)

for val in result:
    print(val[0])

los_angeles
external_links
united_states
prime_minister
san_francisco
et_al
new_york
supreme_court
19th_century
20th_century
references_external
soviet_union
air_force
baseball_player
university_press
roman_catholic
united_kingdom
references_reading
notes_references
award_best
north_america
new_zealand
civil_war
catholic_church
world_war
war_ii
south_africa
took_place
roman_empire
united_nations
american_singer-songwriter
high_school
american_actor
american_actress
american_baseball
york_city
american_football
years_later
north_american
