In [1]:
sc

In [2]:
# -*- coding: utf-8 -*-

import os
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql.functions import *

from pyspark.ml import Pipeline
from pyspark.mllib.util import MLUtils
from pyspark.ml.feature import StringIndexer, VectorAssembler, VectorIndexer
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel


class Utils():
    """Base class holding small DataFrame helper routines."""

    def __init__(self):
        """Hold no state; defined so subclasses can call it explicitly."""

    # Descriptive statistics: mean and standard deviation
    def getStatValue(self, df, fieldName):
        """Return the mean and standard deviation of one column.

        df        -- object exposing ``select(...).collect()`` (a Spark
                     DataFrame in this file's usage).
        fieldName -- name of the column to aggregate.

        Returns the first element of the collected result, which holds
        ``avg(fieldName)`` followed by ``stddev(fieldName)``.
        """
        aggregated = df.select(avg(fieldName), stddev(fieldName))
        rows = aggregated.collect()
        return rows[0]

class LoadSavedData(Utils):
    """Load a saved Parquet data set and run logistic regression on it.

    The methods rely on a module-level ``sqlContext`` being defined before
    ``loadData`` is called.
    """

    # Inherit from Utils
    def __init__(self):
        Utils.__init__(self)

    # Load the data set file
    def loadData(self, dataFile):
        """Return a DataFrame with every row of the Parquet file ``dataFile``.

        The path is interpolated into a SQL-on-files query, so it must not
        contain a backtick.
        """
        sql = 'SELECT * FROM parquet.`%s`' % dataFile
        return sqlContext.sql(sql)

    # Print descriptive statistics
    def printStats(self, df, fields=None):
        """Show ``describe()`` output for ``df``.

        fields -- optional iterable of column names; when given, one table
                  is shown per column, otherwise a single table covering
                  the whole DataFrame is shown.
        """
        if fields is None:
            df.describe().show()
            return
        for field in fields:
            df.describe(field).show()

    # Logistic regression
    def LR(self, trainingData, testData,
            labelIndexer, features,
            # Tunable hyper-parameters
            maxIteration=100, regessionParam=0.001):
        """Fit a logistic-regression pipeline and return test predictions.

        trainingData   -- DataFrame used to fit the pipeline.
        testData       -- DataFrame the fitted pipeline is applied to.
        labelIndexer   -- pipeline stage that must produce the
                          ``indexedLabel`` column the classifier reads.
        features       -- list of input column names to assemble into the
                          ``features`` vector column.
        maxIteration   -- passed to LogisticRegression as ``maxIter``.
        regessionParam -- passed to LogisticRegression as ``regParam``
                          (parameter name kept as-is for compatibility).

        Returns the DataFrame produced by ``model.transform(testData)``.
        """
        # Assemble the independent variables into one vector column.
        # A separate local is used so the ``features`` argument is not
        # overwritten by the assembler object.
        assembler = (VectorAssembler()
                        .setInputCols(features)
                        .setOutputCol('features'))

        # Configure the logistic-regression estimator
        lr = LogisticRegression(labelCol='indexedLabel', featuresCol='features',
                                    maxIter=maxIteration, regParam=regessionParam)

        # Chain label indexing, feature assembly and the classifier
        pipeline = Pipeline(stages=[labelIndexer, assembler, lr])

        # Fit the model on the training split
        model = pipeline.fit(trainingData)

        # Predicted values for the test split
        return model.transform(testData)

    # Print the logistic-regression results
    def printStatsLR(self, predictions):
        """Show the misclassified rows and print the accuracy.

        predictions -- DataFrame containing the ``indexedLabel``,
                       ``prediction``, ``features`` and ``probability``
                       columns.

        Prints accuracy, the number of wrong rows and the total number of
        rows. Accuracy is reported as ``nan`` when there are no rows.
        """
        # Keep only the columns relevant to the report
        result = predictions.select('indexedLabel', 'prediction', 'features', 'probability')

        # Rows where the prediction disagrees with the label
        resultError = result.where(result.indexedLabel != result.prediction)
        resultError.show()

        # Each count() runs a Spark job, so evaluate each exactly once.
        errorCount = resultError.count()
        totalCount = result.count()

        # Force float division: under Python 2, int / int floors to 0 and
        # the accuracy would always print as 1.000.
        if totalCount:
            accuracy = 1.000 - float(errorCount) / totalCount
        else:
            accuracy = float('nan')

        print(u'準確率=%.3f (%d\t%d)' % (accuracy, errorCount, totalCount))

def m_fun(o, c):
    """Return 0 when ``o`` is strictly greater than ``c``, otherwise 1."""
    return 0 if o > c else 1

# Main program
def main(dataDir):
    """Run the logistic-regression workflow on ``<dataDir>/IBM.parquet``.

    Loads the Parquet file, adds an integer ``result`` column computed by
    ``m_fun`` from the ``open`` and ``close`` columns, splits the rows
    70/30 at random, fits and applies the model through
    ``LoadSavedData.LR``, then prints the prediction schema and the
    misclassification report.
    """
    # Helper object that wraps loading and modelling
    loader = LoadSavedData()

    # Load the data set
    parquetPath = '%s/IBM.parquet' % dataDir
    prices = loader.loadData(dataFile=parquetPath)

    # Derive the label column from the open/close columns
    labelUdf = udf(m_fun, IntegerType())
    prices = prices.withColumn('result', labelUdf('open', 'close'))

    # Randomly split the rows into two groups.
    # NOTE(review): no seed is passed, so the split differs between runs.
    splits = prices.randomSplit([0.7, 0.3])
    (trainingData, testData) = splits

    # Build the label-to-index mapping from the full data set
    indexerStage = StringIndexer(inputCol='result', outputCol='indexedLabel')
    labelIndexer = indexerStage.fit(prices)

    # Logistic regression: columns 1..4 are the independent variables.
    # NOTE(review): if these are open/high/low/close, they include the two
    # columns the label is computed from -- confirm this is intended.
    featureColumns = prices.columns[1:5]
    predictions = loader.LR(trainingData, testData, labelIndexer, featureColumns)
    predictions.printSchema()

    # Print the logistic-regression results
    loader.printStatsLR(predictions)

# Program entry point
if __name__ == '__main__':
    # Settings for local execution; they only take effect when no
    # SparkContext exists yet.
    appName = 'Cup-12'
    master = 'local'

    # A notebook or pyspark shell already provides an active context, which
    # getOrCreate() returns unchanged. When run as a plain script there is
    # none, so one is created here instead of failing with a NameError on
    # the undefined name ``sc``.
    sc = SparkContext.getOrCreate(
        conf=SparkConf().setAppName(appName).setMaster(master))

    # SQL interface; read as a module-level name by LoadSavedData.loadData
    sqlContext = SQLContext(sc)

    # Resolve the data directories under the user's home directory.
    # expanduser() also works when the HOME variable is not set.
    homeDir = os.path.expanduser('~')
    dirName = 'Data'
    sampleDir = '%s/Sample' % homeDir
    dataDir = '%s/%s' % (homeDir, dirName)

    # Run the main program
    main(dataDir)


root
 |-- date: date (nullable = true)
 |-- open: double (nullable = true)
 |-- high: double (nullable = true)
 |-- low: double (nullable = true)
 |-- close: double (nullable = true)
 |-- volumn: double (nullable = true)
 |-- adjclose: double (nullable = true)
 |-- result: integer (nullable = true)
 |-- indexedLabel: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)

+------------+----------+--------------------+--------------------+
|indexedLabel|prediction|            features|         probability|
+------------+----------+--------------------+--------------------+
|         1.0|       0.0|[122.099998,124.2...|[0.66740492051320...|
|         1.0|       0.0|[118.779999,119.6...|[0.51230983286720...|
|         1.0|       0.0|[147.399994,147.5...|[0.51094041179876...|
|         1.0|       0.0|[152.75,153.69000...|[0.51492470177133...|
+--------