In [2]:
sc

In [12]:
# -*- coding: utf-8 -*-

import os
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql.functions import *

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, VectorIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

class Utils:
    """Base class holding small DataFrame helper routines."""

    def __init__(self):
        """Nothing to set up; the class carries no state."""

    # Descriptive statistics: mean and standard deviation
    def getStatValue(self, df, fieldName):
        """Return the first collected row of avg/stddev over ``fieldName``.

        The row comes from selecting ``avg(fieldName)`` and
        ``stddev(fieldName)`` on ``df``, in that order.
        """
        aggregates = [avg(fieldName), stddev(fieldName)]
        rows = df.select(*aggregates).collect()
        return rows[0]

class LoadSavedData(Utils):
    """Load a Parquet data set and run/report a decision-tree pipeline on it."""

    def __init__(self):
        # Inherit from Utils; call the base initializer explicitly.
        Utils.__init__(self)

    # Load the data set file
    def loadData(self, dataFile):
        """Read the Parquet data at ``dataFile`` and return it as a DataFrame.

        Relies on the module-level ``sqlContext`` having been assigned before
        this method is called.
        """
        # Go through the DataFrameReader rather than formatting the path into
        # a SQL string, so a path containing a backtick cannot break the query.
        return sqlContext.read.parquet(dataFile)

    # Print descriptive statistics
    def printStats(self, df, fields=None):
        """Show ``describe()`` output for ``df``.

        With ``fields`` left as ``None`` a single summary of the whole
        DataFrame is shown; otherwise one summary is shown per field name in
        ``fields``.
        """
        if fields is None:
            df.describe().show()
            return

        for field in fields:
            df.describe(field).show()

    # Decision tree
    def DT(self, trainingData, testData, labelIndexer, features):
        """Fit a decision-tree pipeline and return predictions for ``testData``.

        Args:
            trainingData: DataFrame the pipeline is fitted on.
            testData: DataFrame the fitted model is applied to.
            labelIndexer: pipeline stage that yields the ``indexedLabel``
                column (the caller in this file passes a fitted
                ``StringIndexer``).
            features: names of the input columns to assemble into the single
                ``features`` vector column.

        Returns:
            The DataFrame produced by ``model.transform(testData)``.
        """
        # Combine the independent-variable columns into one vector column.
        # Kept under its own name so the ``features`` argument is not rebound.
        assembler = (VectorAssembler()
                     .setInputCols(features)
                     .setOutputCol('features'))

        # Decision-tree estimator reading the indexed label and feature vector.
        dt = DecisionTreeClassifier(labelCol='indexedLabel', featuresCol='features')

        # Chain label indexing, feature assembly and the classifier.
        pipeline = Pipeline(stages=[labelIndexer, assembler, dt])

        # Fit the model on the training split.
        model = pipeline.fit(trainingData)

        # Predicted values for the test split.
        predictions = model.transform(testData)

        return predictions

    # Print the decision-tree analysis result
    def printStatsDT(self, predictions):
        """Show the mispredicted rows and print the accuracy.

        The printed line contains the accuracy followed by the number of
        mispredicted rows and the total number of rows. When ``predictions``
        has no rows the accuracy is reported as NaN instead of raising
        ``ZeroDivisionError``.
        """
        # Columns of interest:
        #   indexedLabel - the indexed actual label
        #   prediction   - the value predicted by the algorithm
        #   features     - the assembled feature vector
        result = predictions.select('indexedLabel', 'prediction', 'features', 'probability')

        # Keep only the rows that were predicted incorrectly.
        resultError = result.where(result.indexedLabel != result.prediction)
        resultError.show()

        # Each count() launches a Spark job, so evaluate each one only once.
        errorCount = resultError.count()
        totalCount = result.count()

        if totalCount:
            # Force float division: under Python 2, int / int truncates and
            # the accuracy would print as 1.000 for any partially wrong result.
            accuracy = 1.0 - float(errorCount) / totalCount
        else:
            accuracy = float('nan')

        print(u'準確率=%.3f (%d\t%d)' % (accuracy, errorCount, totalCount))

def m_fun(o, c):
    """Return 0 when ``o`` is greater than ``c``, otherwise 1."""
    is_greater = o > c
    return 0 if is_greater else 1

# 主程式
def main(dataDir):
    """Run the decision-tree experiment on ``<dataDir>/IBM.parquet``.

    Loads the data set, adds a ``result`` label column computed by ``m_fun``
    from the ``open`` and ``close`` columns, splits the rows at random with
    weights 0.7 / 0.3, fits the decision-tree pipeline and prints the schema
    and the evaluation of the predictions.
    """
    # Helper object that wraps loading, modelling and reporting.
    worker = LoadSavedData()

    # Load the data set.
    parquetPath = '%s/IBM.parquet' % dataDir
    df = worker.loadData(dataFile=parquetPath)

    # Derive the integer label column from the open/close columns.
    labelUdf = udf(m_fun, IntegerType())
    df = df.withColumn('result', labelUdf('open', 'close'))

    df.show()

    # Randomly split the rows into a training group and a test group
    # (weights 0.7 and 0.3).
    trainingData, testData = df.randomSplit([0.7, 0.3])

    # Build the label lookup: input column 'result', output column
    # 'indexedLabel', fitted on the full DataFrame.
    indexer = StringIndexer(inputCol='result', outputCol='indexedLabel')
    labelIndexer = indexer.fit(df)

    # Decision tree: the independent variables are the DataFrame columns at
    # positions 1 through 4.
    featureColumns = df.columns[1:5]
    result = worker.DT(trainingData, testData, labelIndexer, featureColumns)
    result.printSchema()

    # Print the decision-tree analysis result.
    worker.printStatsDT(result)

# 程式進入點
if __name__ == '__main__':
    # Names assigned here are already module globals (a ``global`` statement
    # at module scope is a no-op); ``sqlContext`` is read by
    # LoadSavedData.loadData.

    # Run on local resources.
    appName = 'Cup-11'
    master = 'local'

    # Reuse the SparkContext the notebook/shell already provides, or create
    # one when this file is run as a plain script (where ``sc`` would
    # otherwise be undefined and raise NameError).
    sc = SparkContext.getOrCreate(
        conf=SparkConf().setAppName(appName).setMaster(master))

    # Obtain the SQL interface.
    sqlContext = SQLContext(sc)

    # Resolve the data directories under the user's home directory.
    # expanduser('~') returns $HOME when it is set and falls back to the
    # platform's home lookup instead of raising KeyError when it is not.
    homeDir = os.path.expanduser('~')
    dirName = 'Data'
    sampleDir = '%s/Sample' % homeDir
    dataDir = '%s/Data' % homeDir

    # Invoke the main program.
    main(dataDir)

+----------+----------+----------+----------+----------+---------+----------+------+
|      date|      open|      high|       low|     close|   volumn|  adjclose|result|
+----------+----------+----------+----------+----------+---------+----------+------+
|2016-12-30|166.440002|166.699997|     165.5|165.990005|2952800.0|164.687836|     0|
|2016-12-29|166.020004|166.990005|     166.0|166.600006|1663500.0|165.293051|     1|
|2016-12-28|167.289993|167.740005|     166.0|166.190002|1757500.0|164.886264|     0|
|2016-12-27|166.979996|167.979996|166.850006|167.139999|1397500.0|165.828809|     1|
|2016-12-23|     167.0|167.490005|166.449997|166.710007|1701200.0|165.402189|     0|
|2016-12-22|167.360001|168.229996|166.580002|167.059998|2802600.0|165.749434|     0|
|2016-12-21|    166.25|167.940002|    165.25|167.330002|3575700.0|166.017321|     1|
|2016-12-20|167.490005|    168.25|166.449997|167.600006|2174600.0|166.285207|     1|
|2016-12-19|166.830002|167.259995|     166.0|166.679993|2955900.0