## Content-Similarity Recommendation with Word2Vec Trained in PySpark

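Implementation steps:
1. Fetch the article list: ID, title, and content
2. Segment the Chinese text into words with jieba
3. Train word2vec with PySpark to get one vector per article
4. For a given article ID, compute the list of most similar articles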

### 1. Load the data

In [1]:
import pandas as pd

In [2]:
import json

In [3]:
# Data exported with: SELECT id,post_title,post_content FROM `wp_posts` WHERE post_status='publish' and post_type='post'
with open("./datas/wp_posts.json", encoding="utf-8") as fin:
    data = json.load(fin)

In [4]:
# The exported JSON is a list; element 2 holds the rows returned by the query
df = pd.DataFrame(data[2]["data"])
df.head(3)

|   | id | post_title | post_content |
|---|----|------------|--------------|
| 0 | 78 | JavaScript对Select的子元素Option的操作 | `<ul>\r\n\t<li>\r\n<h3 style="color: red;">java...` |
| 1 | 83 | 当用header方法输出内容时出现“Cannot modify header informa... | `<h3>&nbsp;&nbsp;&nbsp; 解决方法有2种：</h3>\r\n<ol>\r...` |
| 2 | 85 | linux下禁止机箱蜂鸣方法 | 1、在图像界面下 ，注意，是图形界面下，即使在图像界面下按快捷键出现的虚拟终端里\r\n\r... |
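The cell above picks the rows with the hard-coded index `data[2]`. The minimal sketch below looks the rows up by table name instead. It assumes the file is a phpMyAdmin-style JSON export in which each table appears as an object with `"type": "table"`, a `"name"` and a `"data"` list; the post does not say which tool produced the file, so both that layout and the helper name `load_table_rows` are assumptions.

```python
def load_table_rows(export, table_name=None):
    """Return the "data" rows of a table entry in a phpMyAdmin-style JSON export.

    Assumed layout (not confirmed by the post): a list of objects, where table
    entries look like {"type": "table", "name": ..., "data": [...]}.
    If table_name is None, the first table entry is returned.
    """
    for entry in export:
        if entry.get("type") != "table":
            continue  # skip header / database entries
        if table_name is None or entry.get("name") == table_name:
            return entry["data"]
    raise ValueError(f"no table {table_name!r} in export")


# Hand-written sample in the assumed layout (not real data from the post)
sample = [
    {"type": "header"},
    {"type": "database", "name": "blog"},
    {"type": "table", "name": "wp_posts", "database": "blog",
     "data": [{"id": "78", "post_title": "demo", "post_content": "<p>hi</p>"}]},
]
rows = load_table_rows(sample, "wp_posts")
print(rows[0]["post_title"])  # demo
```

With the real file this would read `pd.DataFrame(load_table_rows(data, "wp_posts"))`, which keeps working if the export gains or loses a leading entry.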


#### Strip the HTML from the post content

In [5]:
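import re
from bs4 import BeautifulSoup

def clean_post_cont(x):
    # Parse the HTML fragment and keep only its text
    soup = BeautifulSoup(x, 'html.parser')
    result = soup.get_text()
    # Drop carriage returns, newlines and tabs left over from the HTML source
    return re.sub(r"\r|\n|\t", "", result)

df["post_content"] = df["post_content"].map(clean_post_cont)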

In [6]:
df.head(3)

|   | id | post_title | post_content |
|---|----|------------|--------------|
| 0 | 78 | JavaScript对Select的子元素Option的操作 | javascript_删除所有select下面的option的方法//增加之前删除所有opt... |
| 1 | 83 | 当用header方法输出内容时出现“Cannot modify header informa... | 解决方法有2种：使用ultraEditor打开该文件，然后变成utf-8编码，就会发... |
| 2 | 85 | linux下禁止机箱蜂鸣方法 | 1、在图像界面下 ，注意，是图形界面下，即使在图像界面下按快捷键出现的虚拟终端里用xset ... |
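If BeautifulSoup is not installed, the same cleanup can be approximated with the standard library. This is a minimal sketch that swaps in `html.parser.HTMLParser` for BeautifulSoup (a different tool, not the author's code); `strip_html` and `_TextExtractor` are hypothetical names. It keeps every text node, and entities are decoded, so `&nbsp;` becomes a non-breaking space (`\xa0`).

```python
import re
from html.parser import HTMLParser


class _TextExtractor(HTMLParser):
    """Collects the text nodes of an HTML fragment, ignoring the tags."""

    def __init__(self):
        super().__init__(convert_charrefs=True)  # decode &amp;, &nbsp; etc.
        self.parts = []

    def handle_data(self, data):
        self.parts.append(data)


def strip_html(fragment):
    parser = _TextExtractor()
    parser.feed(fragment)
    parser.close()  # flush any buffered text at the end of the input
    text = "".join(parser.parts)
    # Same post-processing as clean_post_cont: drop \r, \n and \t
    return re.sub(r"[\r\n\t]", "", text)


print(repr(strip_html("<h3>&nbsp;解决方法</h3>\r\n<ol><li>hello</li></ol>")))
# '\xa0解决方法hello'
```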


### 2. Chinese word segmentation with jieba

pip install jieba

In [7]:
import jieba

In [8]:
def do_cut_words(param_df):
    # The title and the content, joined by a comma, form the text to segment
    sentence = param_df["post_title"]+","+(param_df["post_content"])
    # Segment the text with jieba
    words = list(jieba.cut(sentence))
    # Filter: drop empty and single-character tokens, lower-case the rest
    result = []
    for word in words:
        if len(word) < 2:
            continue
        word = word.lower()
        result.append(word)
    return " ".join(result)

df["words"] = df.apply(do_cut_words, axis=1)

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.731 seconds.
Prefix dict has been built successfully.


In [9]:
df.head(5)

|   | id | post_title | post_content | words |
|---|----|------------|--------------|-------|
| 0 | 78 | JavaScript对Select的子元素Option的操作 | javascript_删除所有select下面的option的方法//增加之前删除所有opt... | javascript select 元素 option 操作 javascript 删除 所... |
| 1 | 83 | 当用header方法输出内容时出现“Cannot modify header informa... | 解决方法有2种：使用ultraEditor打开该文件，然后变成utf-8编码，就会发... | 当用 header 方法 输出 内容 出现 cannot modify header inf... |
| 2 | 85 | linux下禁止机箱蜂鸣方法 | 1、在图像界面下 ，注意，是图形界面下，即使在图像界面下按快捷键出现的虚拟终端里用xset ... | linux 禁止 机箱 蜂鸣 方法 图像 界面 注意 图形界面 即使 图像 界面 快捷键 出... |
| 3 | 87 | 硬盘分区表丢失、修复大事记--分区表修复利器testdisk | 今天是2009年11月14日，就在刚才，我找回了前几天丢失分区表的硬盘分区，特此记录。事件起... | 硬盘分区 丢失 修复 大事记 -- 分区表 修复 利器 testdisk 今天 2009 1... |
| 4 | 91 | vi编辑器命令 | vi编辑器的文字说明模式：命令模式，编辑模式，末行模式。切换方式：命令模式→i→编辑模式，编... | vi 编辑器 命令 vi 编辑器 文字说明 模式 命令 模式 编辑 模式 末行 模式 切换 ... |
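The filtering rule inside `do_cut_words` can be tried on its own without installing jieba. In this minimal sketch `filter_tokens` is a hypothetical helper, and the token list is hand-written for illustration rather than real jieba output. Note that the length rule also discards meaningful single-character Chinese words such as 下 ("under").

```python
def filter_tokens(tokens):
    """Drop tokens shorter than two characters and lower-case the rest
    (the same rule as the loop in do_cut_words)."""
    return [token.lower() for token in tokens if len(token) >= 2]


# Hand-written tokens for illustration; in the notebook they come from jieba.cut
tokens = ["linux", "下", "禁止", "机箱", "蜂鸣", "方法", " ", "JavaScript", "、"]
print(" ".join(filter_tokens(tokens)))  # linux 禁止 机箱 蜂鸣 方法 javascript
```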


In [10]:
# Save as CSV
df[["id", "post_title", "words"]].to_csv("./datas/crazyant_blog_articles_wordsegs.csv", index=False)

### 3. Train word2vec with PySpark

In [11]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
spark = SparkSession \
    .builder \
    .appName("test pyspark") \
    .getOrCreate()

sc = spark.sparkContext

#### Read the CSV data with PySpark

In [12]:
df = spark.read.csv("./datas/crazyant_blog_articles_wordsegs.csv", header=True)
df.show(5)

+---+-----------------------------------+-------------------------------+
| id|                         post_title|                          words|
+---+-----------------------------------+-------------------------------+
| 78|              JavaScript对Select...|           javascript select...|
| 83|    当用header方法输出内容时出现...|    当用 header 方法 输出 内...|
| 85|            linux下禁止机箱蜂鸣方法|   linux 禁止 机箱 蜂鸣 方法...|
| 87|硬盘分区表丢失、修复大事记--分区...|硬盘分区 丢失 修复 大事记 --...|
| 91|                       vi编辑器命令|   vi 编辑器 命令 vi 编辑器 ...|
+---+-----------------------------------+-------------------------------+
only showing top 5 rows



In [13]:
from pyspark.sql import functions as F
from pyspark.sql import types as T

In [14]:
# Split the space-separated word string into an array (list) column
df = df.withColumn('words_split', F.split(df.words, " "))

#### Train word2vec and transform the documents

In [15]:
# https://spark.apache.org/docs/2.4.6/ml-features.html#word2vec

from pyspark.ml.feature import Word2Vec

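word2Vec = Word2Vec(
    vectorSize=5,         # dimension of each embedding (5 is very small)
    minCount=0,           # keep every word, however rare
    inputCol="words_split",
    outputCol="word2vec")

model = word2Vec.fit(df)

# Note: transform() averages the word vectors of each document,
# giving one embedding per article
df_word2vec = model.transform(df)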

In [16]:
df_word2vec.printSchema()

root
 |-- id: string (nullable = true)
 |-- post_title: string (nullable = true)
 |-- words: string (nullable = true)
 |-- words_split: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- word2vec: vector (nullable = true)



In [17]:
df_word2vec.select("word2vec").show(3, truncate=False)

+-------------------------------------------------------------------------------------------------------+
|word2vec                                                                                               |
+-------------------------------------------------------------------------------------------------------+
|[-0.06689483870158566,-0.31063375444468727,0.09132219613580914,-0.2524164506121651,0.2576837945949532] |
|[-0.09172025156525396,0.009816996645318663,0.19234123595467792,0.20329144122306167,0.11286570988141183]|
|[-0.09057465753391046,0.004765861371362751,0.24302226597633722,0.1654780033594844,0.12196611434178935] |
+-------------------------------------------------------------------------------------------------------+
only showing top 3 rows
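Spark's `Word2VecModel.transform` turns each document into the average of its word vectors (as described in the Spark ML feature documentation linked above), and the trained per-word vectors can be inspected as a DataFrame with `model.getVectors()`. The minimal NumPy sketch below reproduces that averaging on a small hand-written vocabulary. `average_document_vector` is a hypothetical helper, and dividing by the total number of words, including words missing from the vocabulary, reflects our reading of Spark's implementation, so treat that detail as an assumption.

```python
import numpy as np


def average_document_vector(words, word_vectors, dim):
    """Average the vectors of a document's words.

    Words missing from word_vectors add nothing to the sum but are still
    counted in the denominator (assumed to match Spark's behaviour).
    An empty document maps to a zero vector (Spark returns a sparse zero vector).
    """
    total = np.zeros(dim)
    if not words:
        return total
    for word in words:
        vec = word_vectors.get(word)
        if vec is not None:
            total += vec
    return total / len(words)


# Hand-written 2-dimensional vocabulary for illustration
vocab = {"php": np.array([1.0, 0.0]), "文件": np.array([0.0, 1.0])}
print(average_document_vector(["php", "文件"], vocab, dim=2))  # [0.5 0.5]
```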



In [18]:
df_word2vec.select("id", "post_title", "word2vec") \
           .toPandas() \
           .to_csv('./datas/crazyant_blog_articles_word2vec.csv', index=False)

### 4. Find the 10 articles most similar to a given article

In [19]:
df = pd.read_csv("./datas/crazyant_blog_articles_word2vec.csv")
df.head(3)

|   | id | post_title | word2vec |
|---|----|------------|----------|
| 0 | 78 | JavaScript对Select的子元素Option的操作 | [-0.06689483870158566,-0.31063375444468727,0.0... |
| 1 | 83 | 当用header方法输出内容时出现“Cannot modify header informa... | [-0.09172025156525396,0.009816996645318663,0.1... |
| 2 | 85 | linux下禁止机箱蜂鸣方法 | [-0.09057465753391046,0.004765861371362751,0.2... |


In [20]:
import numpy as np
import json

In [21]:
df["word2vec"] = df["word2vec"].map(lambda x : np.array(json.loads(x)))

In [22]:
df.head(3)

|   | id | post_title | word2vec |
|---|----|------------|----------|
| 0 | 78 | JavaScript对Select的子元素Option的操作 | [-0.06689483870158566, -0.31063375444468727, 0... |
| 1 | 83 | 当用header方法输出内容时出现“Cannot modify header informa... | [-0.09172025156525396, 0.009816996645318663, 0... |
| 2 | 85 | linux下禁止机箱蜂鸣方法 | [-0.09057465753391046, 0.004765861371362751, 0... |


In [28]:
# Pick any article ID, e.g. 2583 (a pandas article) or 581 (a PHP article)
article_id = 581
df.loc[df["id"]==article_id]

|    | id  | post_title | word2vec |
|----|-----|------------|----------|
| 78 | 581 | PHP对文件的操作总结 | [-0.0007125509248077937, -0.019375492545679725... |


In [29]:
article_embedding = df.loc[df["id"]==article_id, "word2vec"].iloc[0]
article_embedding

array([-0.00071255, -0.01937549,  0.15768537,  0.16893182,  0.21808635])

In [30]:
# Cosine similarity: scipy's distance.cosine returns the cosine distance, so subtract it from 1
from scipy.spatial import distance
df["sim_value"] = df["word2vec"].map(lambda x : 1 - distance.cosine(article_embedding, x))
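For reference, the value stored in `sim_value` is the cosine similarity between the query article's embedding $\mathbf{a}$ and another article's embedding $\mathbf{b}$; `scipy.spatial.distance.cosine` returns the cosine distance $d_{\cos}$, hence the `1 -` in the cell above:

$$
\operatorname{sim}(\mathbf{a}, \mathbf{b}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert \mathbf{a} \rVert \, \lVert \mathbf{b} \rVert} = 1 - d_{\cos}(\mathbf{a}, \mathbf{b})
$$

The result lies in $[-1, 1]$, and an article compared with itself scores exactly 1.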

In [31]:
df[["id", "post_title", "sim_value"]].head(3)

|   | id | post_title | sim_value |
|---|----|------------|-----------|
| 0 | 78 | JavaScript对Select的子元素Option的操作 | 0.218421 |
| 1 | 83 | 当用header方法输出内容时出现“Cannot modify header informa... | 0.887608 |
| 2 | 85 | linux下禁止机箱蜂鸣方法 | 0.881212 |


In [32]:
# Sort by similarity in descending order and show the top 10 (the first row is the query article itself)
df.sort_values(by="sim_value", ascending=False).head(10)

|     | id   | post_title | word2vec | sim_value |
|-----|------|------------|----------|-----------|
| 78  | 581  | PHP对文件的操作总结 | [-0.0007125509248077937, -0.019375492545679725... | 1.0 |
| 131 | 1057 | Magento获取指定分类下的所有子分类信息 | [0.0023514057645497226, 0.015581226250487228, ... | 0.989693 |
| 97  | 724  | Python关于apply的知识 | [-0.016336282122112508, 0.021025621154458843, ... | 0.986838 |
| 77  | 576  | PHP操作符可变变量测试变量等总结 | [-0.011872579748280342, -0.019191751616759557,... | 0.973529 |
| 182 | 1703 | Java怎样创建两个KEY（key-pair）的MAP | [0.02837977321652619, 0.0177825763158194, 0.07... | 0.963872 |
| 25  | 138  | 获取服务器传来的数据-必须用JS去空格 | [-0.08882668862100114, -0.020467974797620907, ... | 0.96337 |
| 120 | 930  | PHP和MySQL处理树状、分级、无限分类、分层数据的方法 | [0.03533486779765854, 0.009575359927538184, 0.... | 0.958006 |
| 9   | 105  | PHP安全笔记 | [-0.027503433840624262, -0.025421138406898674,... | 0.951645 |
| 47  | 236  | C++数组类型学习笔记 | [-0.01885643251955992, 0.004926461203636049, 0... | 0.948503 |
| 89  | 675  | 数据采集必备知识-php计划任务的实现 | [-0.0685060779376628, 0.011630317228794181, 0.... | 0.945682 |
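The cells above compute the similarity one row at a time with `scipy.spatial.distance.cosine`, and the top-10 list includes the query article itself (similarity 1.0). As a minimal sketch of an alternative, the function below computes every cosine similarity with a single matrix-vector product and leaves the query article out of the result. `top_k_similar` is a hypothetical helper; in the notebook, `ids` and `vectors` would correspond to `df["id"]` and `df["word2vec"].tolist()`.

```python
import numpy as np


def top_k_similar(ids, vectors, query_id, k=10):
    """Return [(id, cosine_similarity), ...] for the k items most similar
    to query_id, excluding the query item itself."""
    ids = np.asarray(ids)
    matrix = np.asarray(vectors, dtype=float)            # shape (n_docs, dim)
    query = matrix[np.flatnonzero(ids == query_id)[0]]   # first row with that ID
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
    # cosine similarity = dot(a, b) / (|a| * |b|); zero vectors get similarity 0
    sims = np.divide(matrix @ query, norms, out=np.zeros(len(ids)), where=norms > 0)
    order = np.argsort(-sims, kind="stable")             # highest similarity first
    return [(ids[i].item(), float(sims[i])) for i in order if ids[i] != query_id][:k]


# Tiny hand-written example with 2-dimensional vectors
ids = [1, 2, 3, 4]
vecs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
print(top_k_similar(ids, vecs, query_id=1, k=2))  # id 2 first, then id 3
```

Computing all similarities at once avoids a Python-level loop over rows, which matters more as the number of articles grows.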
