In [1]:
from pyspark import SparkContext

# Reuse the running SparkContext when this cell is executed again: calling
# SparkContext() twice in the same kernel raises ValueError
# ("Cannot run multiple SparkContexts at once").
sc = SparkContext.getOrCreate()

In [2]:
# Data file at https://www.cse.ust.hk/msbd5003/data
# NOTE(review): DATA_PATH points at a notebook (.ipynb) JSON file, not a
# whitespace-separated word-pair file, so the "pairs" extracted later are JSON
# tokens rather than adjective/noun pairs. Confirm the intended dataset.
DATA_PATH = './剑指offer/Untitled.ipynb'
MIN_PARTITIONS = 8  # lower bound on the number of RDD partitions requested

lines = sc.textFile(DATA_PATH, minPartitions=MIN_PARTITIONS)

In [3]:
# Number of lines (records) in the input file.
lines.count()

477

In [4]:
# Number of partitions of the RDD (textFile was asked for at least 8).
lines.getNumPartitions()

8

In [5]:
# Peek at the first few raw lines.
lines.take(num=5)

[u'{',
 u' "cells": [',
 u'  {',
 u'   "cell_type": "code",',
 u'   "execution_count": 1,']

In [6]:
# Split every line into whitespace-separated tokens and keep only the lines
# that yield exactly two tokens (the word pair). The data is dirty: lines with
# fewer or more than two tokens are discarded.
pairs = (lines
         .map(lambda line: tuple(line.split()))
         .filter(lambda tokens: len(tokens) == 2))
# The pairs RDD is reused by several later computations, so keep it cached.
pairs.cache()

PythonRDD[4] at RDD at PythonRDD.scala:48

In [7]:
# Peek at the first few extracted pairs.
pairs.take(num=5)

[(u'"cells":', u'['),
 (u'"cell_type":', u'"code",'),
 (u'"execution_count":', u'1,'),
 (u'"metadata":', u'{},'),
 (u'"outputs":', u'[')]

In [8]:
# Total number of two-token pairs; used as N in the PMI formula of pmi_score.
N = pairs.count()

In [9]:
# Display the total pair count.
N

406

In [10]:
# Minimum number of occurrences for a pair to be scored. Rarer pairs are
# dropped because PMI tends to over-score very infrequent pairs.
MIN_PAIR_FREQ = 2

# Compute the frequency of each pair, ignoring pairs that are not frequent enough.
pair_freqs = (pairs
              .map(lambda p: (p, 1))
              .reduceByKey(lambda f1, f2: f1 + f2)
              .filter(lambda pf: pf[1] >= MIN_PAIR_FREQ))

In [11]:
# Peek at a few frequent pairs with their counts.
pair_freqs.take(num=5)

[((u'"source":', u'['), 2),
 ((u'"outputs":', u'['), 2),
 ((u'"metadata":', u'{'), 2),
 ((u'"cell_type":', u'"code",'), 2)]

In [12]:
def count_words_at(pair_rdd, position):
    """Count how often each word occurs at ``position`` (0 or 1) of the pairs.

    Parameters
    ----------
    pair_rdd : RDD
        RDD of 2-tuples of words.
    position : int
        Index into each pair: 0 for the first word, 1 for the second.

    Returns
    -------
    RDD of ``(word, count)`` tuples.
    """
    return (pair_rdd
            .map(lambda p: (p[position], 1))
            .reduceByKey(lambda x, y: x + y))


# Computing the frequencies of the adjectives (first word of each pair)
# and the nouns (second word of each pair).
a_freqs = count_words_at(pairs, 0)
n_freqs = count_words_at(pairs, 1)

In [13]:
# Peek at a few first-word frequencies.
a_freqs.take(num=5)

[(u'"\\"cols\\"::', 1),
 (u'"counts:', 1),
 (u'"\\"\\\\u001b[0;31mKeyboardInterrupt\\\\u001b[0m::', 1),
 (u'"\\"nbformat_minor\\"::', 1),
 (u'"\\\\u001b[0;36m<module>\\\\u001b[0;34m()\\\\u001b[0m\\\\n\\\\u001b[1;32m:',
  1)]

In [14]:
# Peek at a few second-word frequencies.
n_freqs.take(num=3)

[(u'"code",', 2), (u'16,', 2), (u'2,', 1)]

In [15]:
# n_freqs.collectAsMap()  #{u'"': 1,u'"",': 2,u'"KeyboardInterrupt",': 1,
# n_freqs.collect()  #[(u'"code",', 2),(u'16,', 2),

In [25]:
# Broadcast the adjective (first-word) and noun (second-word) frequency tables
# so the executors running pmi_score get a read-only copy of each.
# collectAsMap() pulls each table to the driver first, which assumes both
# vocabularies fit in driver memory.
# Note: a_dict / n_dict are Broadcast handles, not dicts; read them through
# `.value`, e.g. a_dict.value[word].
n_dict = sc.broadcast(n_freqs.collectAsMap())
a_dict = sc.broadcast(a_freqs.collectAsMap())

In [20]:
from math import log


# Computing the PMI (pointwise mutual information) for a pair.
def pmi_score(pair_freq, total=None, adj_freqs=None, noun_freqs=None):
    """Return ``(pmi, (w1, w2))`` for one ``((w1, w2), count)`` record.

    PMI is computed in bits as
    ``log2(count * total / (freq(w1) * freq(w2)))``, which equals
    ``log2(P(w1, w2) / (P(w1) * P(w2)))`` when all three probabilities are
    estimated by dividing by the same ``total``.

    Parameters
    ----------
    pair_freq : tuple
        ``((w1, w2), count)``, as in the records of ``pair_freqs``.
    total : int, optional
        Number of pairs in the corpus. Defaults to the notebook-level ``N``.
    adj_freqs : mapping, optional
        First-word -> count. Defaults to the broadcast ``a_dict.value``.
    noun_freqs : mapping, optional
        Second-word -> count. Defaults to the broadcast ``n_dict.value``.
        Passing plain dicts lets the function be reused or tested without
        a SparkContext.

    Returns
    -------
    tuple
        ``(pmi, (w1, w2))``. The score comes first so that ``RDD.top`` ranks
        records by PMI.

    Raises
    ------
    KeyError
        If ``w1`` or ``w2`` is missing from the corresponding frequency map.
    """
    (w1, w2), count = pair_freq
    # Resolve defaults at call time, not at def time, so that re-running the
    # count / broadcast cells is picked up without redefining this function.
    if total is None:
        total = N
    if adj_freqs is None:
        adj_freqs = a_dict.value
    if noun_freqs is None:
        noun_freqs = n_dict.value
    # float() keeps the division a true division, even under Python 2.
    pmi = log(float(count) * total / (adj_freqs[w1] * noun_freqs[w2]), 2)
    return pmi, (w1, w2)

In [21]:
# Computing the PMI for all pairs.
# Each element of scored_pairs is the (pmi, (w1, w2)) tuple returned by
# pmi_score; having the score first lets top() rank pairs by PMI.
scored_pairs = pair_freqs.map(pmi_score)

In [23]:
# Show the most strongly associated pairs. top() uses the natural tuple
# ordering in descending order, so records are ranked by PMI score first
# (ties broken by the word pair).
scored_pairs.top(num=10)

[(7.6653359171851765, (u'"cell_type":', u'"code",')),
 (5.6653359171851765, (u'"source":', u'[')),
 (5.6653359171851765, (u'"outputs":', u'[')),
 (5.08037341646402, (u'"metadata":', u'{'))]

In [None]:
|