Url: https://tbrain.trendmicro.com.tw/Competitions/Details/2

In [1]:
#import
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyspark.sql import Row
from pyspark.sql.functions import col, udf, lag, rank, lit
from pyspark.sql.window import Window

In [2]:
global Path
if sc.master[0:5]=="local":
    Path = "file:/c:/D Drive/work/bigData/pySpark/TBrain_Round2_DataSet_20180427"
    #Path = "file:/Users/yungchuanlee/Documents/learn/AI競賽/ETF預測/TBrain_Round2_DataSet_20180427"
    #Path = "file:/home/hduser/app/bigdata/competition/etf/TBrain_Round2_DataSet_20180427"
else:
    Path = "hdfs://master:9000/user/hduser"

In [3]:
sc.master

'local[*]'

In [4]:
#define alias of columns
col_alias_etf= {"代碼":"etf_id", "日期": "etf_date", "中文簡稱": "etf_name", "開盤價(元)":"etf_open", 
            "最高價(元)":"etf_high", "最低價(元)":"etf_low", "收盤價(元)":"etf_close", "成交張數(張)":"etf_count"}
col_alias_stock= {"代碼":"stock_id", "日期": "stock_date", "中文簡稱": "stock_name", "開盤價(元)":"stock_open", 
            "最高價(元)":"stock_high", "最低價(元)":"stock_low", "收盤價(元)":"stock_close", "成交張數(張)":"stock_count"}

In [5]:
#udf
def to_double(str_val):
    return float(str_val.replace(",",""))
to_double=udf(to_double)

In [6]:
#def function to read data (因檔案格式都相同)
def read_data(file_name, col_alias):
    str_cols = ["代碼","日期", "中文簡稱"]
    raw_data = spark.read.option("encoding", "Big5").csv(Path + "/" + file_name, header=True, sep=",")
    print("Total " + file_name + " count: " + str(raw_data.count()))
    #rename cols and correct type 
    num_cols = [col_name for col_name in raw_data.columns if col_name not in str_cols]
    final_data=raw_data.select( [col(str_col_name).alias(col_alias[str_col_name]) for str_col_name in str_cols] + 
                                  [to_double(col(num_col_name)).cast("double").alias(col_alias[num_col_name]) for num_col_name in num_cols] )
    final_data.printSchema()
    final_data.show(5)
    return final_data

In [7]:
print("starting import tetfp.csv(台灣18檔ETF股價資料)...")
tetfp_dt=read_data("tetfp.csv", col_alias_etf)

starting import tetfp.csv(台灣18檔ETF股價資料)...
Total tetfp.csv count: 6515
root
 |-- etf_id: string (nullable = true)
 |-- etf_date: string (nullable = true)
 |-- etf_name: string (nullable = true)
 |-- etf_open: double (nullable = true)
 |-- etf_high: double (nullable = true)
 |-- etf_low: double (nullable = true)
 |-- etf_close: double (nullable = true)
 |-- etf_count: double (nullable = true)

+-------+--------+----------------+--------+--------+-------+---------+---------+
| etf_id|etf_date|        etf_name|etf_open|etf_high|etf_low|etf_close|etf_count|
+-------+--------+----------------+--------+--------+-------+---------+---------+
|0050   |20130102|元大台灣50          |    54.0|   54.65|   53.9|     54.4|  16487.0|
|0050   |20130103|元大台灣50          |    54.9|   55.05|  54.65|    54.85|  29020.0|
|0050   |20130104|元大台灣50          |   54.85|   54.85|   54.4|     54.5|   9837.0|
|0050   |20130107|元大台灣50          |   54.55|   54.55|   53.9|    54.25|   8910.0|
|0050   |20130108|元大台灣50      

In [8]:
#EDA
#range of date
tetfp_dt.describe('etf_date').show()

+-------+--------------------+
|summary|            etf_date|
+-------+--------------------+
|  count|                6515|
|   mean|2.0152356018419035E7|
| stddev|  15378.548560182882|
|    min|            20130102|
|    max|            20180427|
+-------+--------------------+



In [27]:
import sys
from pyspark.sql.functions import lag, col, avg,collect_list, lit
from pyspark.sql.window import Window
from pyspark.sql.types import ArrayType, DoubleType, IntegerType
#declare previous row windows
wsSpec_etf = Window.partitionBy('etf_id').orderBy('etf_date') #time window for normal case
wsSpec_etf_close_price_raw = Window.partitionBy('etf_id').orderBy('row_idx').rangeBetween(-sys.maxsize, -1)
wsSpec_etf_dif_raw = Window.partitionBy('etf_id').orderBy('row_idx').rangeBetween(-sys.maxsize, 0)
def avg_list(p_list):
    #計算數字list的平均值
    return sum(p_list)/len(p_list)
#計算EMA的udf
def calculate_ema_native(close_p_list, window_len):
    #透過歷史收盤價計算
    if len(close_p_list) < window_len:
        return None
    elif len(close_p_list) == window_len:
        #if len of list = win_len then return avg, 
        return avg_list(close_p_list)
    else:
        #else EMA[t] =(EMA[t-1]*(win_len-1)+close[t]*2)/(win_len+1)
        ema = avg_list(close_p_list[:window_len])
        for price in close_p_list[window_len:]:
            ema = (ema*(window_len-1)+price*2)/(window_len+1)
        return ema
calculate_ema=udf(calculate_ema_native, DoubleType())
#計算BIAS的udf
def calculate_bias(close_p_list):
    #計算前日收盤價與N日均線之差比: (close price - MA)/MA   ,Paper 建議用20日MA
    #因要預測今日的收盤價，故計算前日收盤價與前20日均線
    if len(close_p_list) < 21:
        return None
    else:
        list_len = len(close_p_list)
        p_close = close_p_list[-1]
        cal_list = close_p_list[list_len-21: list_len-1]
        return p_close - avg_list(cal_list)
calculate_bias=udf(calculate_bias, DoubleType())

def get_min_max_last(p_list):
    #找出list中最大最小和最後一個值, 回傳(min, max, last)
    return (min(p_list), max(p_list), p_list[-1])
def calculate_raw_rsv(p_list):
    #RSV = (收盤價-9日低值)/(9日高值-9日低值)
    p_min, p_max, p_last = get_min_max_last(p_list)
    rsv = (p_last - p_min)/(p_max - p_min)
    return rsv
def calculate_rsv(p_9_list, k_prev, d_prev):
    #計算加權後的RSV，p_9_list=>9日收盤價
    rrsv = calculate_raw_rsv(p_9_list)
    k_curr = (1/3)*rrsv + (2/3)*k_prev
    d_curr = (1/3)*k_curr + (2/3)*d_prev
    return [k_curr, d_curr]
#計算隨機指標（Stochastic Oscillator，KD），原名%K&%D
def calculate_KD(close_p_list):
    win_len = 9 #看過去 9 日值
    #RSV = (收盤價-9日低值)/(9日高值-9日低值)
    #K_curr = 1/3*RSV + 2/3*K_prev
    #D_curr = 1/3*K_curr + 2/3*D_prev
    if len(close_p_list) < win_len:
        return None
    elif len(close_p_list) == win_len:
        #無前日K, D時，以0.5帶入
        return calculate_rsv(close_p_list, 0.5, 0.5)
    else:
        kds = calculate_rsv(close_p_list[0:9], 0.5, 0.5)
        for idx in range(1, (len(close_p_list)+1-9)):
            p_9_list = close_p_list[idx: idx+9]
            kds = calculate_rsv(p_9_list, kds[0], kds[1])
        return kds
calculate_KD=udf(calculate_KD, ArrayType(DoubleType()))

#計算差離值DIF = 12日EMA - 26日EMA
def calculate_DIF(close_p_list):
    if len(close_p_list) < 26:
        return None
    else:
        ema12 = calculate_ema_native(close_p_list, 12)
        ema26 = calculate_ema_native(close_p_list, 26)
        return ema12 - ema26
calculate_DIF=udf(calculate_DIF, DoubleType())

#計算MACD=(前一日MACD × (9 - 1) + 今日DIF × 2) ÷ (9 + 1)
def calculate_MACD(dif_list, dif_curr):
    win_len = 9
    if len(dif_list) < win_len:
        return None
    elif len(dif_list) == win_len:
        #if len of list = win_len then return avg, 
        return avg_list(dif_list)
    else:
        #MACD=(前一日MACD × (9 - 1) + 今日DIF × 2) ÷ (9 + 1)
        macd = avg_list(dif_list[:win_len])
        for price in dif_list[win_len:]:
            macd = (macd*(win_len-1)+dif_curr*2)/(win_len+1)
        return macd
calculate_MACD=udf(calculate_MACD, DoubleType())

#計算相對強弱指數(RSI)
def calculate_RSI(close_p_list):
    win_len = 9
    if len(close_p_list) < (win_len + 1):
        return None
    else:
        cur_list = close_p_list[1:]
        prv_list = close_p_list[0:-1]
        p_dif_list = list(map(lambda x,y : x - y, cur_list, prv_list)) #dif list
        u_list = []
        d_list = []
        for dif in p_dif_list:
            if dif == 0:
                #若兩天價格相同，則U及D皆等於零
                u_list.append(0)
                d_list.append(0)
            elif dif > 0:
                #在價格上升的日子, U = diff, D = 0
                u_list.append(dif)
                d_list.append(0)
            else:
                #在價格下跌的日子, U = 0, D = abs(diff)
                u_list.append(0)
                d_list.append(abs(dif))
        #RSI = ema(u,9)/(ema(u,9)+ema(d,9))
        ema_u = calculate_ema_native(u_list, win_len)
        ema_d = calculate_ema_native(d_list, win_len)
        return ema_u/(ema_u + ema_d)
calculate_RSI=udf(calculate_RSI, DoubleType())

#計算威廉指標（Williams %R）
def calculate_WR(close_p_list):
    win_len = 9
    if len(close_p_list) < win_len:
        return None
    else:
        p_list = close_p_list[len(close_p_list) - win_len :]
        return 1.0 - calculate_raw_rsv(p_list)
calculate_WR=udf(calculate_WR, DoubleType())

#計算上或下的值
def judge_up_down_native(curr_price, close_p_list):
    prev_price = 0.0
    if len(close_p_list) < 1:
        prev_price = curr_price
    else:
        prev_price = close_p_list[-1]
    if curr_price == prev_price:
        return 0.0
    elif curr_price > prev_price:
        return 1.0
    else:
        return 2.0
judge_up_down=udf(judge_up_down_native, DoubleType())

#calculate diff between close_price and last_price
def calculate_price_diff(curr_price, close_p_list):
    if len(close_p_list) < 1:
        return None
    return curr_price - close_p_list[-1]
calculate_price_diff=udf(calculate_price_diff, DoubleType())

In [28]:
#calculate ema [5,10,20] #cannot remove row_idx, row_idx for next window usage
tetfp_dt2=tetfp_dt.withColumn("row_idx", rank().over(wsSpec_etf)) \
    .withColumn("close_price_raw", collect_list(col('etf_close')).over(wsSpec_etf_close_price_raw)) \
    .withColumn("EMA5", calculate_ema(col("close_price_raw"), lit(5))) \
    .withColumn("EMA10", calculate_ema(col("close_price_raw"), lit(10))) \
    .withColumn("EMA20", calculate_ema(col("close_price_raw"), lit(20))) \
    .withColumn("BIAS", calculate_bias(col("close_price_raw"))) \
    .withColumn("KD", calculate_KD(col("close_price_raw"))) \
    .withColumn("K", col("KD")[0]).withColumn("D", col("KD")[1]) \
    .withColumn("DIF", calculate_DIF(col("close_price_raw"))) \
    .withColumn("dif_list", collect_list(col('DIF')).over(wsSpec_etf_dif_raw)) \
    .withColumn("MACD", calculate_MACD(col("dif_list"), col("DIF"))) \
    .withColumn("RSI", calculate_RSI(col("close_price_raw")))\
    .withColumn("WR", calculate_WR(col("close_price_raw"))) \
    .withColumn("price_dif", calculate_price_diff(col("etf_close"), col("close_price_raw"))) \
    .withColumn("up_down", judge_up_down(col("etf_close"), col("close_price_raw")))

tetfp_dt2.cache()
tetfp_dt2.printSchema()

root
 |-- etf_id: string (nullable = true)
 |-- etf_date: string (nullable = true)
 |-- etf_name: string (nullable = true)
 |-- etf_open: double (nullable = true)
 |-- etf_high: double (nullable = true)
 |-- etf_low: double (nullable = true)
 |-- etf_close: double (nullable = true)
 |-- etf_count: double (nullable = true)
 |-- row_idx: integer (nullable = true)
 |-- close_price_raw: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- EMA5: double (nullable = true)
 |-- EMA10: double (nullable = true)
 |-- EMA20: double (nullable = true)
 |-- BIAS: double (nullable = true)
 |-- KD: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- K: double (nullable = true)
 |-- D: double (nullable = true)
 |-- DIF: double (nullable = true)
 |-- dif_list: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- MACD: double (nullable = true)
 |-- RSI: double (nullable = true)
 |-- WR: double (nullable = true)
 |-- price_dif: double

In [29]:
tot_dt = tetfp_dt2.filter("MACD is not null") \
    .select("etf_id", "etf_date", "EMA5", "EMA10", "EMA20", "BIAS", "K", "D", "DIF", "MACD", "RSI", "WR", "etf_close","price_dif","up_down") \
    .orderBy("etf_id", "etf_date", ascending=True)
tot_dt.show(20)

+-------+--------+------------------+------------------+------------------+--------------------+-------------------+-------------------+--------------------+--------------------+-------------------+-------------------+---------+--------------------+-------+
| etf_id|etf_date|              EMA5|             EMA10|             EMA20|                BIAS|                  K|                  D|                 DIF|                MACD|                RSI|                 WR|etf_close|           price_dif|up_down|
+-------+--------+------------------+------------------+------------------+--------------------+-------------------+-------------------+--------------------+--------------------+-------------------+-------------------+---------+--------------------+-------+
|0050   |20130227| 55.48233888645433|  55.4036427292022|55.109781220405175| 0.24749999999999517|   0.48004877445193| 0.6936407134791391|  0.4425475841886879| 0.46653575041343565| 0.3681396149827187|                1.0|     55.

In [30]:
from pyspark.ml.feature import MinMaxScaler, StandardScaler
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
#將Feature合併為Vector 並作標準化
assembler = VectorAssembler(
    inputCols=["EMA5", "EMA10", "EMA20", "BIAS", "K", "D", "DIF", "MACD", "RSI", "WR"],
    outputCol="features")
tot_dt_1 = assembler.transform(tot_dt)
#minmax_scaler = MinMaxScaler(inputCol="features", outputCol="stdFeatures")
#scaler_model = minmax_scaler.fit(tot_dt_1)
std_scaler = StandardScaler(inputCol="features", outputCol="stdFeatures")
scaler_model = std_scaler.fit(tot_dt_1)
tot_dt_scale = scaler_model.transform(tot_dt_1)


In [31]:
#取出4/16~4/27 (共兩週資料作為測試集)
train_dt = tot_dt_scale.filter("etf_date < '20180416' and MACD is not null") \
    .select("etf_id", "etf_date", "stdFeatures", "price_dif","up_down") \
    .orderBy("etf_id", "etf_date", ascending=True)
test_dt = tot_dt_scale.filter("etf_date >= '20180416'") \
    .select("etf_id", "etf_date", "stdFeatures", "price_dif","up_down") \
    .orderBy("etf_id", "etf_date", ascending=True)
print('train count: ', str(train_dt.count()), ', test count: ', str(test_dt.count()))
train_dt.show(10)
test_dt.show(10)
train_dt.cache()
test_dt.cache()

train count:  6295 , test count:  50
+-------+--------+--------------------+--------------------+-------+
| etf_id|etf_date|         stdFeatures|           price_dif|up_down|
+-------+--------+--------------------+--------------------+-------+
|0050   |20130227|[3.256641619336,3...| 0.10000000000000142|    1.0|
|0050   |20130301|[3.25111747966177...| 0.19999999999999574|    1.0|
|0050   |20130304|[3.25134784693247...| -0.6499999999999986|    2.0|
|0050   |20130305|[3.23878376218901...| 0.45000000000000284|    1.0|
|0050   |20130306|[3.23921224156378...|                0.25|    1.0|
|0050   |20130307|[3.24438930329719...|-0.05000000000000426|    2.0|
|0050   |20130308|[3.24686239602275...|  0.3999999999999986|    1.0|
|0050   |20130311|[3.25633737861350...| 0.10000000000000142|    1.0|
|0050   |20130312|[3.26461059720076...| -0.3500000000000014|    2.0|
|0050   |20130313|[3.26327810391527...| 0.10000000000000142|    1.0|
+-------+--------+--------------------+--------------------+------

DataFrame[etf_id: string, etf_date: string, stdFeatures: vector, price_dif: double, up_down: double]

In [17]:
#取出etf的distinct id
etf_ids = []
for row in test_dt.select("etf_id").distinct().collect():
    etf_ids.append(row["etf_id"])
etf_ids

['0051   ', '0052   ', '0050   ', '0054   ', '0053   ']

In [18]:
#訓練Model及評估(RandomForestRegressor)
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
rf_c = RandomForestClassifier(featuresCol="stdFeatures",labelCol="up_down")
rf_r = RandomForestRegressor(featuresCol="stdFeatures",labelCol="price_dif")
predit_res = None
for etfid in etf_ids:
    train_data = train_dt.filter("etf_id='" + etfid + "'")
    test_data = test_dt.filter("etf_id='" + etfid + "'")
    #訓練判斷上升或下降
    rf_model_c = rf_c.fit(train_data)
    pred_c = rf_model_c.transform(test_data)
    #訓練判斷差額
    rf_model_r = rf_r.fit(train_data)
    pred_r = rf_model_r.transform(test_data)
    #合併結果輸出
    predicts = pred_c.join(pred_r, ["etf_id", "etf_date"]) \
        .select(pred_c.etf_id, pred_c.etf_date, pred_c.up_down, pred_c.prediction.alias("up_down_pred"), 
                pred_r.price_dif, pred_r.prediction.alias("price_dif_pred"))
    if predit_res is None:
        predit_res = predicts
    else:
        predit_res = predit_res.unionAll(predicts)
predit_res.show(10)
#評估差價預測的RMES
evaluator = RegressionEvaluator(
    labelCol="price_dif", predictionCol="price_dif_pred", metricName="rmse")
rmse = evaluator.evaluate(predit_res)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
#評估漲跌預測的Accuracy
evaluator = MulticlassClassificationEvaluator(
    labelCol="up_down", predictionCol="up_down_pred", metricName="accuracy")
accuracy = evaluator.evaluate(predit_res)
print("accuracy = %g " % accuracy)



+-------+--------+-------+------------+--------------------+-------------------+
| etf_id|etf_date|up_down|up_down_pred|           price_dif|     price_dif_pred|
+-------+--------+-------+------------+--------------------+-------------------+
|0051   |20180416|    1.0|         1.0|                0.25| 0.1696036418869628|
|0051   |20180417|    2.0|         2.0| 0.17999999999999972|0.17181694570802525|
|0051   |20180418|    2.0|         1.0|                0.25|0.18179509929343066|
|0051   |20180419|    1.0|         1.0|  0.3999999999999986| 0.2123318169994039|
|0051   |20180420|    2.0|         1.0| 0.12999999999999545|0.18746735447703192|
|0051   |20180423|    2.0|         1.0|0.020000000000003126|  0.166848623705315|
|0051   |20180424|    2.0|         1.0|  0.5600000000000023| 0.1872001698780043|
|0051   |20180425|    2.0|         1.0|  0.4199999999999946|0.18760086273909177|
|0051   |20180426|    2.0|         1.0|  0.3100000000000023| 0.7313826358743557|
|0051   |20180427|    1.0|  

In [26]:
#訓練Model及評估(RandomForestRegressor in one time, predict updown and price diff) => no different with etf_id wise
# --- accuacy: 0.26, (RMSE) on test data = 0.415976
rf_c = RandomForestClassifier(featuresCol="stdFeatures",labelCol="up_down")
rf_r = RandomForestRegressor(featuresCol="stdFeatures",labelCol="price_dif")

train_data = train_dt
test_data = test_dt
#訓練判斷上升或下降
rf_model_c = rf_c.fit(train_data)
pred_c = rf_model_c.transform(test_data)
#訓練判斷差額
rf_model_r = rf_r.fit(train_data)
pred_r = rf_model_r.transform(test_data)
    #合併結果輸出
predit_res = pred_c.join(pred_r, ["etf_id", "etf_date"]) \
        .select(pred_c.etf_id, pred_c.etf_date, pred_c.up_down, pred_c.prediction.alias("up_down_pred"), 
                pred_r.price_dif, pred_r.prediction.alias("price_dif_pred"))
predit_res.show(10)

#評估差價預測的RMES
evaluator = RegressionEvaluator(
    labelCol="price_dif", predictionCol="price_dif_pred", metricName="rmse")
rmse = evaluator.evaluate(predit_res)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
#評估漲跌預測的Accuracy
evaluator = MulticlassClassificationEvaluator(
    labelCol="up_down", predictionCol="up_down_pred", metricName="accuracy")
accuracy = evaluator.evaluate(predit_res)
print("accuracy = %g " % accuracy)



+-------+--------+-------+------------+-------------------+-------------------+
| etf_id|etf_date|up_down|up_down_pred|          price_dif|     price_dif_pred|
+-------+--------+-------+------------+-------------------+-------------------+
|0050   |20180416|    2.0|         1.0|0.20000000000000284|0.44624315880840426|
|0050   |20180417|    2.0|         1.0| 0.8999999999999915| 0.4529185078290195|
|0050   |20180418|    1.0|         1.0|0.19999999999998863|0.49852523062419296|
|0050   |20180419|    1.0|         1.0| 1.0500000000000114|0.49197118941790424|
|0050   |20180420|    2.0|         1.0| 1.9000000000000057|0.43374005416860156|
|0050   |20180423|    2.0|         1.0| 0.7999999999999972| 0.5037779058673211|
|0050   |20180424|    2.0|         1.0| 0.4000000000000057| 0.5538833105968822|
|0050   |20180425|    2.0|         1.0|               0.25|  0.583693679283252|
|0050   |20180426|    2.0|         1.0|               0.25| 0.5900041727944381|
|0050   |20180427|    1.0|         1.0|0

In [32]:
#計算上或下的值
def judge_up_down_eval(price_diff):
    if price_diff == 0:
        return 0.0
    elif price_diff > 0:
        return 1.0
    else:
        return 2.0
judge_up_down_eval=udf(judge_up_down_eval, DoubleType())

#訓練Model及評估(RandomForestRegressor in one time, just predict price diff) => no different with etf_id wise
rf_r = RandomForestRegressor(featuresCol="stdFeatures",labelCol="price_dif")

train_data = train_dt
test_data = test_dt

#訓練判斷差額
rf_model_r = rf_r.fit(train_data)
pred_r = rf_model_r.transform(test_data)
#合併結果輸出
predit_res = pred_r \
    .select(pred_r.etf_id, pred_r.etf_date, pred_r.up_down, pred_r.price_dif, pred_r.prediction.alias("price_dif_pred")) \
    .withColumn("up_down_pred", judge_up_down_eval(col("price_dif_pred")))
predit_res.show(10)

#評估差價預測的RMES
evaluator = RegressionEvaluator(
    labelCol="price_dif", predictionCol="price_dif_pred", metricName="rmse")
rmse = evaluator.evaluate(predit_res)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
#評估漲跌預測的Accuracy
evaluator = MulticlassClassificationEvaluator(
    labelCol="up_down", predictionCol="up_down_pred", metricName="accuracy")
accuracy = evaluator.evaluate(predit_res)
print("accuracy = %g " % accuracy)



+-------+--------+-------+--------------------+--------------------+------------+
| etf_id|etf_date|up_down|           price_dif|      price_dif_pred|up_down_pred|
+-------+--------+-------+--------------------+--------------------+------------+
|0050   |20180416|    2.0|-0.20000000000000284|0.026065509555671402|         1.0|
|0050   |20180417|    2.0| -0.8999999999999915|0.011225144871080348|         1.0|
|0050   |20180418|    1.0| 0.19999999999998863|0.053681656447893536|         1.0|
|0050   |20180419|    1.0|  1.0500000000000114| 0.03676159650133974|         1.0|
|0050   |20180420|    2.0| -1.9000000000000057| 0.06250014055236705|         1.0|
|0050   |20180423|    2.0| -0.7999999999999972| 0.03219234658751645|         1.0|
|0050   |20180424|    2.0| -0.4000000000000057| 0.05030372616741262|         1.0|
|0050   |20180425|    2.0|               -0.25|0.025122773786460245|         1.0|
|0050   |20180426|    2.0|               -0.25|  0.1449347225491161|         1.0|
|0050   |2018042

In [47]:
tetf_max_idx = tetfp_dt2.groupBy("etf_id").max("row_idx")
tetf_max_idx.select(col("etf_id"), col("max(row_idx)").cast("Double").alias("row_idx")).printSchema()
tetf_max = tetf_max_idx.select(col("etf_id"), col("max(row_idx)").cast("Double").alias("row_idx")) \
    .join(tetfp_dt2, ["etf_id", "row_idx"], "inner") \
    .select(tetfp_dt2.etf_id, tetfp_dt2.etf_date, tetfp_dt2.etf_close, tetfp_dt2.row_idx, tetfp_dt2.close_price_raw)
tetf_max.show(10)
#ridx_dic = {}
#for row in ridx_dic_raw:
#    ridx_dic.update({row["etf_id"], row["row_idx"]})


root
 |-- etf_id: string (nullable = true)
 |-- row_idx: double (nullable = true)

+-------+
| etf_id|
+-------+
|0051   |
|0052   |
|0050   |
|0054   |
|0053   |
+-------+

