In [1]:
import os
import sys
# When this file is run directly for testing, the project path has to be added
# to sys.path first, otherwise the later project-local imports fail.
BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, BASE_DIR)

# Interpreter used by both the PySpark workers and the driver. With several
# Python versions installed, leaving these unset is likely to cause errors.
PYSPARK_PYTHON = "/miniconda2/envs/reco_sys/bin/python"
os.environ.update({
    "PYSPARK_PYTHON": PYSPARK_PYTHON,
    "PYSPARK_DRIVER_PYTHON": PYSPARK_PYTHON,
})

from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import LogisticRegressionModel
from offline import SparkSessionBase

class CtrLogisticRegression(SparkSessionBase):
    """Thin ``SparkSessionBase`` subclass that owns the Spark session used by this notebook.

    NOTE(review): ``SparkSessionBase`` is imported from the project's
    ``offline`` package and is not shown here; presumably it reads the class
    attributes below when building the session -- confirm against that class.
    """

    # Application name for this job.
    SPARK_APP_NAME = "ctrLogisticRegression"
    # Hive support flag; the cells below query Hive databases through spark.sql.
    ENABLE_HIVE_SUPPORT = True

    def __init__(self):
        """Create the session via ``_create_spark_hbase`` and keep it on ``self.spark``."""

        self.spark = self._create_spark_hbase()

# Build the helper once; the rest of the notebook issues queries through `ctr.spark`.
ctr = CtrLogisticRegression()

In [2]:
# Step 2: read the user click-behaviour table; it is later combined with the
# user profile and the article profile to build the training samples.
ctr.spark.sql('use profile')  # make `profile` the session's current database
news_article_basic = ctr.spark.sql("select user_id, article_id, channel_id, clicked from user_article_basic")

In [3]:
# Print a preview of the click-behaviour rows.
news_article_basic.show()

+-------------------+-------------------+----------+-------+
|            user_id|         article_id|channel_id|clicked|
+-------------------+-------------------+----------+-------+
|1105045287866466304|              14225|         0|  false|
|1106476833370537984|              14208|         0|  false|
|1109980466942836736|              19233|         0|  false|
|1109980466942836736|              44737|         0|  false|
|1109993249109442560|              17283|         0|  false|
|1111189494544990208|              19322|         0|  false|
|1111524501104885760|              44161|         0|  false|
|1112727762809913344|              18172|        18|   true|
|1113020831425888256|1112592065390182400|         0|  false|
|1114863735962337280|              17665|         0|  false|
|1114863741448486912|              14208|         0|  false|
|1114863751909081088|              13751|         0|  false|
|1114863846486441984|              17940|         0|  false|
|1114863941936218112|   

In [4]:
# Load the user-profile data. `env` is part of the query but is not needed
# afterwards, so it is dropped straight away.
user_profile_hbase = ctr.spark.sql(
    "select user_id, information.birthday, information.gender, article_partial, env from user_profile_hbase"
).drop('env')

In [None]:
# Print a preview of the raw user-profile rows (user_id still carries its "user:" prefix).
user_profile_hbase.show()

+--------------------+--------+------+--------------------+
|             user_id|birthday|gender|     article_partial|
+--------------------+--------+------+--------------------+
|              user:1|     0.0|  null|Map(18:vars -> 0....|
|             user:10|     0.0|  null|Map(18:tp2 -> 0.2...|
|             user:11|     0.0|  null|               Map()|
|user:110249052282...|     0.0|  null|               Map()|
|user:110319567345...|    null|  null|Map(18:Animal -> ...|
|user:110504528786...|    null|  null|Map(18:text -> 0....|
|user:110509388310...|    null|  null|Map(18:text -> 0....|
|user:110510518565...|    null|  null|Map(18:SHOldboySt...|
|user:110639618314...|    null|  null|Map(18:tp2 -> 0.2...|
|user:110647320376...|    null|  null|Map(18:text -> 0....|
|user:110647683337...|    null|  null|Map(18:text -> 1....|
|user:110826490119...|    null|  null|Map(18:text -> 0....|
|user:110997636345...|    null|  null|Map(18:text -> 0....|
|user:110997980510...|    null|  null|Ma

In [None]:
# Normalise the user ID.
def get_user_id(row):
    """Turn a user-profile row into a tuple whose first item is a numeric user ID.

    ``row.user_id`` is split on ``':'`` and its second field is converted to
    ``int`` (e.g. ``"user:10"`` becomes ``10``). The remaining fields are
    passed through unchanged.

    Returns:
        ``(user_id, birthday, gender, article_partial)``.

    Raises:
        IndexError: if ``row.user_id`` contains no ``':'``.
        ValueError: if the second field is not an integer literal.
    """
    id_fields = row.user_id.split(':')
    numeric_id = int(id_fields[1])
    return numeric_id, row.birthday, row.gender, row.article_partial

# Apply get_user_id to every row. The name is rebound from a DataFrame to the
# mapped RDD of plain tuples.
# NOTE(review): because the name is reused, re-running this cell on its own
# would call `.rdd` on the already-mapped RDD -- consider a distinct name.
user_profile_hbase = user_profile_hbase.rdd.map(get_user_id)

In [None]:
# With toDF the types of some columns cannot be determined, so the DataFrame
# column types are declared by hand. Order matches the tuples from get_user_id.
_schema_fields = [
    ('user_id', LongType()),
    ('birthday', DoubleType()),
    ('gender', BooleanType()),
    ('article_partial', MapType(StringType(), DoubleType())),
]
_schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in _schema_fields
])

user_profile_hbase = ctr.spark.createDataFrame(user_profile_hbase, schema=_schema)

In [None]:
# Print a preview of the typed user-profile rows (user_id is now numeric).
user_profile_hbase.show()

+-------------------+--------+------+--------------------+
|            user_id|birthday|gender|     article_partial|
+-------------------+--------+------+--------------------+
|                  1|     0.0|  null|Map(18:vars -> 0....|
|                 10|     0.0|  null|Map(18:tp2 -> 0.2...|
|                 11|     0.0|  null|               Map()|
|1102490522829717504|     0.0|  null|               Map()|
|1103195673450250240|    null|  null|Map(18:Animal -> ...|
|1105045287866466304|    null|  null|Map(18:text -> 0....|
|1105093883106164736|    null|  null|Map(18:text -> 0....|
|1105105185656537088|    null|  null|Map(18:SHOldboySt...|
|1106396183141548032|    null|  null|Map(18:tp2 -> 0.2...|
|1106473203766657024|    null|  null|Map(18:text -> 0....|
|1106476833370537984|    null|  null|Map(18:text -> 1....|
|1108264901190615040|    null|  null|Map(18:text -> 0....|
|1109976363453906944|    null|  null|Map(18:text -> 0....|
|1109979805106831360|    null|  null|Map(18:text -> 0...

In [None]:
# Merge the click-behaviour table with the user-profile table (left join on
# user_id) and remove the features that are not used.
train = (
    news_article_basic
    .join(user_profile_hbase, on=['user_id'], how='left')
    .drop('birthday', 'channel_id', 'gender')
)



In [None]:
# Print a preview of the joined click/user-profile samples.
train.show()

+-------------------+----------+-------+--------------------+
|            user_id|article_id|clicked|     article_partial|
+-------------------+----------+-------+--------------------+
|1106473203766657024|     16005|  false|Map(18:text -> 0....|
|1106473203766657024|     14335|  false|Map(18:text -> 0....|
|1106473203766657024|     13778|  false|Map(18:text -> 0....|
|1106473203766657024|     13039|  false|Map(18:text -> 0....|
|1106473203766657024|     13648|  false|Map(18:text -> 0....|
|1106473203766657024|     17304|  false|Map(18:text -> 0....|
|1106473203766657024|     19233|  false|Map(18:text -> 0....|
|1106473203766657024|     44466|  false|Map(18:text -> 0....|
|1106473203766657024|     18795|  false|Map(18:text -> 0....|
|1106473203766657024|    134812|  false|Map(18:text -> 0....|
|1106473203766657024|     13357|  false|Map(18:text -> 0....|
|1106473203766657024|     19171|  false|Map(18:text -> 0....|
|1106473203766657024|     44104|  false|Map(18:text -> 0....|
|1106473

In [None]:
# Next: merge in the article vectors and article weight features, together
# with the real channel ID each article belongs to.
ctr.spark.sql('use article')  # make `article` the session's current database
article_vector = ctr.spark.sql("select * from article_vector")

In [None]:
# Left-join the article_vector columns onto the samples by article_id;
# samples without a matching article keep nulls in those columns.
train_user_article = train.join(article_vector, on=['article_id'], how='left')

In [None]:
# Print a preview of the samples after adding the article_vector columns.
train_user_article.show()

In [None]:
# Read the article profile (article_id and its keywords column). The table
# name is unqualified, so it is resolved against the session's current database.
article_profile = ctr.spark.sql("select article_id, keywords from article_profile")

def get_article_weights(row):
    """Return ``(article_id, weights)`` where ``weights`` always holds exactly 10 values.

    The keyword weights in ``row.keywords`` are sorted in ascending order and
    the first 10 are kept. If fewer than 10 are available (or ``keywords`` is
    missing or unsortable), the list is right-padded with ``0.0``. A fixed
    length matters because these lists are later turned into dense vectors
    and assembled into one feature vector; rows of differing length would
    give feature vectors of differing size.

    Args:
        row: object exposing ``article_id`` and ``keywords`` (a mapping of
            keyword to weight, or ``None``).

    Returns:
        Tuple ``(article_id, list_of_10_weights)``.
    """
    size = 10

    try:
        # NOTE(review): ascending sort keeps the 10 *smallest* weights --
        # confirm this is intended rather than the 10 largest (reverse=True).
        weights = sorted(row.keywords.values())[:size]
    except (AttributeError, TypeError):
        # keywords is None / not a mapping, or its values cannot be ordered.
        weights = []

    # Pad short lists so every article yields the same number of features.
    weights = weights + [0.0] * (size - len(weights))

    return row.article_id, weights

# Replace article_profile with a two-column DataFrame (article_id,
# article_weights) built by applying get_article_weights to every row.
article_profile = article_profile.rdd.map(get_article_weights).toDF(['article_id', 'article_weights'])

In [None]:
# Merge the article weights into the samples (left join on article_id).
# NOTE(review): the result is assigned back to the same name, so re-running
# this cell on its own would join article_profile a second time -- confirm
# and consider a distinct name.
train_user_article = train_user_article.join(article_profile, on=['article_id'], how='left')

In [None]:
# Print a preview of the samples after adding article_weights.
train_user_article.show()

In [None]:
# The user's keyword weights for every channel were kept in article_partial;
# the weights for the channel of the article the user acted on are picked out
# in a later step. Here, rows that contain a null in any column (left over
# from the left joins) are removed first.
train_user_article = train_user_article.dropna()

In [None]:
# Print a preview of the samples that remain after dropping rows with nulls.
train_user_article.show()

In [None]:
# Bare expression: the notebook displays the DataFrame's repr as the cell output.
train_user_article

In [None]:
# Column names, in order, for the tuples returned by get_user_weights.
columns = ['article_id', 'user_id', 'channel_id', 'articlevector', 'user_weights', 'article_weights', 'clicked']
def get_user_weights(row):
    """Build one training tuple with fixed-size weight vectors.

    From ``row.article_partial`` (keys shaped like ``"<channel>:<keyword>"``)
    the weights whose channel prefix equals ``row.channel_id`` are collected,
    sorted ascending, and the first 10 are kept. Both the user weights and
    ``row.article_weights`` are right-padded with ``0.0`` to exactly 10
    entries, so every row produces dense vectors of the same size. Without
    the padding, a user with fewer than 10 keywords for the channel would
    yield a shorter vector, and the assembled feature vectors would differ
    in length from row to row.

    Args:
        row: object exposing ``article_id``, ``user_id``, ``channel_id``,
            ``articlevector``, ``article_partial``, ``article_weights`` and
            ``clicked``.

    Returns:
        Tuple ``(article_id, user_id, channel_id, articlevector,
        user_weights, article_weights, clicked)`` where the three vector
        fields are dense vectors and ``clicked`` is an ``int``.
    """
    # Imported locally, as in the original cell, so the function carries its
    # own dependency when it is used inside rdd.map.
    from pyspark.ml.linalg import Vectors

    size = 10

    def fixed_length(values):
        """Truncate to ``size`` entries and right-pad with 0.0."""
        head = list(values)[:size]
        return head + [0.0] * (size - len(head))

    channel = str(row.channel_id)
    try:
        # NOTE(review): ascending sort keeps the 10 *smallest* weights --
        # confirm this is intended rather than the 10 largest.
        user_weights = sorted(
            weight
            for key, weight in row.article_partial.items()
            if key.split(':')[0] == channel
        )
    except (AttributeError, TypeError):
        # article_partial is missing / not a mapping, or weights are unsortable.
        user_weights = []

    return (
        row.article_id,
        row.user_id,
        row.channel_id,
        Vectors.dense(row.articlevector),
        Vectors.dense(fixed_length(user_weights)),
        Vectors.dense(fixed_length(row.article_weights)),
        int(row.clicked),
    )

# Apply get_user_weights to every sample and name the resulting tuple fields
# with `columns`.
train_vector = train_user_article.rdd.map(get_user_weights).toDF(columns)





In [None]:
# Gather all features into a single `features` column. columns[2:6] selects
# channel_id, articlevector, user_weights and article_weights.
train_res = VectorAssembler(
    inputCols=columns[2:6],
    outputCol='features',
).transform(train_vector)

In [46]:
# Print a preview of the assembled training rows, including the `features` column.
train_res.show()

DataFrame[article_id: bigint, user_id: bigint, channel_id: bigint, articlevector: vector, user_weights: vector, article_weights: vector, clicked: bigint, features: vector]