In [1]:
import os
from datetime import datetime

import IPython
import IPython.display

import numpy as np
import pandas as pd
import seaborn as sns

import csv
from csv import writer

#custom functions import
from utility.finance_formula import rsi
from data_func.time_window_torch import pandas_window_single_step
from data_func.df_preprocessing import convert_to_pct
from data_func.ts_dataset import Ts_dataset
from utility.utils import mape as mape_custom
from utility.utils import MAE as MAE_custom


#plotting setting
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['axes.grid'] = False


#pytorch 
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

#sklearn
from sklearn.preprocessing import StandardScaler


#models
from models.stacked_lstm import Stacked_lstm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
#option flags
save_csv = True

time_periodicity = False

debug_mode = True
debug_output_n = 2
financial_feature_engineering = False

performance_log = True
data_log_csv_path = "training_log/pytorch_test2.csv"

In [4]:
fileDir = os.path.dirname(os.path.realpath('__file__'))
filename = os.path.join(fileDir, r"stock_data\index300_05_22_day.csv")

# Data Preprocessing

In [5]:
drop_list = ['change']

df = pd.read_csv(filename).drop(drop_list, axis=1)

In [6]:
if debug_mode == True:
    display(df)

Unnamed: 0,ts_code,trade_date,open,high,low,close,pre_close,pct_chg,vol,amount
0,600783.SH,20221215,13.63,13.63,13.38,13.58,13.55,0.2214,38397.67,51819.5570
1,600783.SH,20221214,13.54,13.82,13.42,13.55,13.46,0.6686,57557.68,78307.7080
2,600783.SH,20221213,13.12,13.65,13.10,13.46,13.24,1.6616,74087.26,99829.4500
3,600783.SH,20221212,13.25,13.46,13.04,13.24,13.29,-0.3762,60428.45,79917.9930
4,600783.SH,20221209,13.56,13.60,13.19,13.29,13.53,-1.7738,66449.00,88408.6250
...,...,...,...,...,...,...,...,...,...,...
1743021,002064.SZ,20060829,9.60,9.79,9.50,9.65,9.63,0.2100,58861.23,56827.0509
1743022,002064.SZ,20060828,9.61,9.91,9.42,9.63,9.54,0.9400,75694.74,73239.1825
1743023,002064.SZ,20060825,9.22,10.00,9.13,9.54,9.28,2.8000,121564.97,115534.6703
1743024,002064.SZ,20060824,8.91,9.37,8.85,9.28,9.18,1.0900,110143.17,100371.4400


In [7]:
#split data into different stock
# stock_list = [x.sort_values("trade_date", ascending = True) for _,x in df.groupby('ts_code')]

stock_list = []

for ts_code, x in df.groupby('ts_code'):
    stock_list.append(x.copy().sort_values("trade_date", ascending = True).reset_index().drop('index', axis=1))

hs300_stock_pct_list = ["open", "high", "low"]

for i in stock_list:
    convert_to_pct(i, hs300_stock_pct_list)


#sort by trade_date, ascending
if debug_mode:
    print("number of stocks : ", len(stock_list))

number of stocks :  577


## add hs300 as a proxy

In [8]:
#add hs300 as a proxy

hs300_filename = os.path.join(fileDir, r"stock_data\hs300index_05_22_day.csv")

df_hs300 = pd.read_csv(hs300_filename).drop(['index_ts_code', 'index_vol', 'index_change', 'index_amount'], axis=1).sort_values("index_trade_date", ascending = True)

hs300_index_pct_list = ["index_open", "index_high", "index_low"]

convert_to_pct(df_hs300, hs300_index_pct_list)

display(df_hs300)

Unnamed: 0,index_trade_date,index_open,index_high,index_low,index_close,index_pre_close,index_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
4285,20050509,909.1740,934.6500,937.3900,909.1740,932.3950,-2.4905,,,
4284,20050510,913.0760,905.5430,913.3880,892.3130,909.1740,0.4292,0.429181,-3.114214,-2.560514
4283,20050511,901.8510,911.8380,917.2230,900.4380,913.0760,-1.2294,-1.229361,0.695163,0.419865
4282,20050512,885.8200,899.9680,900.0630,883.5110,901.8510,-1.7776,-1.777566,-1.301766,-1.870865
4281,20050513,887.5430,883.5050,898.5050,875.5760,885.8200,0.1945,0.194509,-1.829287,-0.173099
...,...,...,...,...,...,...,...,...,...,...
4,20221209,3998.2442,3961.9919,4003.3178,3944.4396,3959.1798,0.9867,0.986679,0.222425,0.801338
3,20221212,3953.4433,3976.1722,3983.4332,3950.3203,3998.2442,-1.1205,-1.120514,0.357908,-0.496703
2,20221213,3945.6813,3953.5482,3964.3957,3939.9795,3953.4433,-0.1963,-0.196335,-0.568989,-0.477917
1,20221214,3954.8857,3952.7885,3972.7381,3935.7668,3945.6813,0.2333,0.233278,-0.019216,0.210433


In [9]:
#testing date matching

# stock_list[2]['trade_date']
# print(len(stock_list[2]['trade_date']))

# tx = df_hs300[df_hs300['index_trade_date'].isin(stock_list[2]['trade_date'])].sort_values("index_trade_date", ascending = True).reset_index().drop('index', axis=1)
# display(tx)

# display(stock_list[2])

# display(df_hs300)

# txp = pd.concat([stock_list[2], tx], axis=1, join='inner')
# display(txp)

In [10]:
#concat

for i in range(len(stock_list)):

    temp = df_hs300[df_hs300['index_trade_date'].isin(stock_list[i]['trade_date'])].sort_values("index_trade_date", ascending = True).reset_index().drop('index', axis=1)
    stock_list[i] = pd.concat([stock_list[i], temp], axis=1, join='outer').drop('index_trade_date', axis=1)




In [11]:
# #check date matching
# error_ct = 0



# for i in stock_list:
#     if not i["trade_date"].equals(i["index_trade_date"]):
#         #print("index time stamp error")
#         error_ct += 1
# print(error_ct)

In [12]:
display(stock_list[0])

Unnamed: 0,ts_code,trade_date,open,high,low,close,pre_close,pct_chg,vol,amount,...,low_pct_change,index_open,index_high,index_low,index_close,index_pre_close,index_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,000001.SZ,20050509,6.23,6.27,5.98,6.09,6.20,-1.7700,96939.11,5.942849e+04,...,,909.1740,934.6500,937.3900,909.1740,932.3950,-2.4905,,,
1,000001.SZ,20050510,6.09,6.36,5.97,6.31,6.09,3.6100,108414.13,6.733231e+04,...,-0.167224,913.0760,905.5430,913.3880,892.3130,909.1740,0.4292,0.429181,-3.114214,-2.560514
2,000001.SZ,20050511,6.28,6.40,6.12,6.19,6.31,-1.9000,106868.37,6.681369e+04,...,2.512563,901.8510,911.8380,917.2230,900.4380,913.0760,-1.2294,-1.229361,0.695163,0.419865
3,000001.SZ,20050512,6.18,6.34,6.14,6.19,6.19,0.0000,79396.48,4.938704e+04,...,0.326797,885.8200,899.9680,900.0630,883.5110,901.8510,-1.7776,-1.777566,-1.301766,-1.870865
4,000001.SZ,20050513,6.19,6.24,5.90,6.02,6.19,-2.7500,111481.38,6.729219e+04,...,-3.908795,887.5430,883.5050,898.5050,875.5760,885.8200,0.1945,0.194509,-1.829287,-0.173099
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4135,000001.SZ,20221209,13.40,13.75,13.35,13.70,13.36,2.5449,1615831.92,2.197502e+06,...,1.753049,3998.2442,3961.9919,4003.3178,3944.4396,3959.1798,0.9867,0.986679,0.222425,0.801338
4136,000001.SZ,20221212,13.58,13.58,13.06,13.11,13.70,-4.3066,1392584.28,1.848278e+06,...,-2.172285,3953.4433,3976.1722,3983.4332,3950.3203,3998.2442,-1.1205,-1.120514,0.357908,-0.496703
4137,000001.SZ,20221213,13.15,13.36,13.14,13.24,13.11,0.9916,902285.13,1.194285e+06,...,0.612557,3945.6813,3953.5482,3964.3957,3939.9795,3953.4433,-0.1963,-0.196335,-0.568989,-0.477917
4138,000001.SZ,20221214,13.23,13.40,13.08,13.23,13.24,-0.0755,974061.10,1.287868e+06,...,-0.456621,3954.8857,3952.7885,3972.7381,3935.7668,3945.6813,0.2333,0.233278,-0.019216,0.210433


In [13]:
display(stock_list[0][0:15])

Unnamed: 0,ts_code,trade_date,open,high,low,close,pre_close,pct_chg,vol,amount,...,low_pct_change,index_open,index_high,index_low,index_close,index_pre_close,index_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,000001.SZ,20050509,6.23,6.27,5.98,6.09,6.2,-1.77,96939.11,59428.4874,...,,909.174,934.65,937.39,909.174,932.395,-2.4905,,,
1,000001.SZ,20050510,6.09,6.36,5.97,6.31,6.09,3.61,108414.13,67332.3061,...,-0.167224,913.076,905.543,913.388,892.313,909.174,0.4292,0.429181,-3.114214,-2.560514
2,000001.SZ,20050511,6.28,6.4,6.12,6.19,6.31,-1.9,106868.37,66813.6905,...,2.512563,901.851,911.838,917.223,900.438,913.076,-1.2294,-1.229361,0.695163,0.419865
3,000001.SZ,20050512,6.18,6.34,6.14,6.19,6.19,0.0,79396.48,49387.0372,...,0.326797,885.82,899.968,900.063,883.511,901.851,-1.7776,-1.777566,-1.301766,-1.870865
4,000001.SZ,20050513,6.19,6.24,5.9,6.02,6.19,-2.75,111481.38,67292.1901,...,-3.908795,887.543,883.505,898.505,875.576,885.82,0.1945,0.194509,-1.829287,-0.173099
5,000001.SZ,20050516,6.02,6.02,5.76,5.9,6.02,-1.99,54461.62,32003.5853,...,-2.372881,875.271,885.389,885.389,869.334,887.543,-1.3827,-1.382694,0.213242,-1.459758
6,000001.SZ,20050517,5.9,6.15,5.81,6.06,5.9,2.71,60944.28,36815.7997,...,0.868056,881.462,873.077,888.281,868.212,875.271,0.7073,0.707324,-1.390575,0.326636
7,000001.SZ,20050518,6.06,6.2,6.04,6.14,6.06,1.32,52046.44,31843.2828,...,3.958692,883.196,881.141,890.403,871.821,881.462,0.1967,0.196719,0.92363,0.238888
8,000001.SZ,20050519,6.14,6.27,6.01,6.2,6.14,0.98,72994.06,45054.1898,...,-0.496689,884.171,882.842,888.016,871.289,883.196,0.1104,0.110395,0.193045,-0.268081
9,000001.SZ,20050520,6.22,6.34,6.12,6.21,6.2,0.16,98810.85,61928.4089,...,1.830283,882.763,883.513,891.02,879.18,884.171,-0.1592,-0.159245,0.076005,0.338282


In [14]:
#get name of stocks
stock_name_list = []

for stock in stock_list:

    str = stock.iloc[0,0]
    stock_name_list.append(str)

    stock.pop("ts_code") #pop the stock code
    #stock.pop("pre_close") #pop previous day close price

## add time periodicity

In [15]:
#setup timestamp for year periodicity 
    
if time_periodicity:

    date_time_db = []

    for stock_i in stock_list:
        date_time_db.append(pd.to_datetime(stock_i.pop('trade_date'), format='%Y%m%d'))

    if debug_mode:
        print(date_time_db[0])


    date_time_stamp_db = []

    for dt in date_time_db:
        date_time_stamp_db.append(dt.map(pd.Timestamp.timestamp))


    if debug_mode:
        display(date_time_stamp_db[0])


    year = 24*60*60*(365.2425)

    for i in range(len(stock_list)):
        
        stock_list[i]["Year sin"] = np.sin(date_time_stamp_db[i] * (2 * np.pi / year))
        stock_list[i]["Year cos"] = np.cos(date_time_stamp_db[i] * (2 * np.pi / year))

    if debug_mode:
        plt.plot(np.array(stock_list[0]['Year sin'])[:1000])
        plt.plot(np.array(stock_list[0]['Year cos'])[:1000])
        plt.xlabel('Time [day]')
        plt.title('Time of year signal, in day')

    if debug_mode:
        display(stock_list[5].head())


    if debug_mode:
        display(stock_list[0].head())

## Financial features implementation

In [16]:
if financial_feature_engineering == True:

    ma_label = 'open'

    for stock in stock_list:
        
        #MA

        processing_row = stock[ma_label].to_frame()

        stock['sma5'] = processing_row[ma_label].rolling(5).mean()

        stock['ema5'] = processing_row[ma_label].ewm(span=5).mean()

        stock['ema20'] = processing_row[ma_label].ewm(span=20).mean()

        stock['ema50'] = processing_row[ma_label].ewm(span=50).mean()



        #rsi implementation
        stock['rsi'] = rsi(stock)

        

        # # Golden Cross  implementation
        # np.where(stock['ema20'] > stock['ema50'], 1, 0)

        # stock['ema20_50_GC'] = 0 

        # for i in range(stock.shape[0]):
        #     if i == 0:
        #         pass
        #     elif stock['ema20'][i] >= stock['ema50'][i] and stock['ema20'][i-1] < stock['ema50'][i-1]: #slow code, works for now
        #         stock['ema20_50_GC'][i] = 1



        stock.dropna(inplace=True)


    display(stock_list[0])
    

In [17]:
# #GC value count inspction
# for stock in stock_list:
#     print(stock['ema20_50_GC'].value_counts())



## dropping all the unnecessary features

In [18]:
extra_drop_columns = ["pre_close", "close", "index_pre_close", "index_close"]


for i in range(len(stock_list)):
    stock_list[i] = stock_list[i].drop(hs300_index_pct_list + hs300_stock_pct_list+extra_drop_columns, axis=1).rename(columns={"pct_chg":"close_pct_chg","index_pct_chg":"index_close_pct_chg"})
    stock_list[i].dropna(inplace=True)
    stock_list[i] = stock_list[i].reset_index(drop=True)


display(stock_list[0])

Unnamed: 0,trade_date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,3.6100,108414.13,6.733231e+04,-2.247191,1.435407,-0.167224,0.4292,0.429181,-3.114214,-2.560514
1,20050511,-1.9000,106868.37,6.681369e+04,3.119869,0.628931,2.512563,-1.2294,-1.229361,0.695163,0.419865
2,20050512,0.0000,79396.48,4.938704e+04,-1.592357,-0.937500,0.326797,-1.7776,-1.777566,-1.301766,-1.870865
3,20050513,-2.7500,111481.38,6.729219e+04,0.161812,-1.577287,-3.908795,0.1945,0.194509,-1.829287,-0.173099
4,20050516,-1.9900,54461.62,3.200359e+04,-2.746365,-3.525641,-2.372881,-1.3827,-1.382694,0.213242,-1.459758
...,...,...,...,...,...,...,...,...,...,...,...
4134,20221209,2.5449,1615831.92,2.197502e+06,1.901141,1.701183,1.753049,0.9867,0.986679,0.222425,0.801338
4135,20221212,-4.3066,1392584.28,1.848278e+06,1.343284,-1.236364,-2.172285,-1.1205,-1.120514,0.357908,-0.496703
4136,20221213,0.9916,902285.13,1.194285e+06,-3.166421,-1.620029,0.612557,-0.1963,-0.196335,-0.568989,-0.477917
4137,20221214,-0.0755,974061.10,1.287868e+06,0.608365,0.299401,-0.456621,0.2333,0.233278,-0.019216,0.210433


In [19]:
# if save_csv :

#     for i in range(1):
#         f_name = "stock_" + stock_name_list[i] + ".csv"
#         stock_list[i].to_csv(f_name, index=False)

## standardization

In [20]:
stock_list_stdized = []

scaler = StandardScaler()



for df_i in stock_list:
    
    column_to_std = df_i.columns.tolist()[1:]
    
    temp_df = df_i[column_to_std]
    temp_df = pd.DataFrame(scaler.fit_transform(temp_df))

    x2 = df_i.copy()
    x2[column_to_std] = temp_df
    

    stock_list_stdized.append(x2)

# Final data check

In [21]:
display(stock_list[0])
display(stock_list[50])

Unnamed: 0,trade_date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,3.6100,108414.13,6.733231e+04,-2.247191,1.435407,-0.167224,0.4292,0.429181,-3.114214,-2.560514
1,20050511,-1.9000,106868.37,6.681369e+04,3.119869,0.628931,2.512563,-1.2294,-1.229361,0.695163,0.419865
2,20050512,0.0000,79396.48,4.938704e+04,-1.592357,-0.937500,0.326797,-1.7776,-1.777566,-1.301766,-1.870865
3,20050513,-2.7500,111481.38,6.729219e+04,0.161812,-1.577287,-3.908795,0.1945,0.194509,-1.829287,-0.173099
4,20050516,-1.9900,54461.62,3.200359e+04,-2.746365,-3.525641,-2.372881,-1.3827,-1.382694,0.213242,-1.459758
...,...,...,...,...,...,...,...,...,...,...,...
4134,20221209,2.5449,1615831.92,2.197502e+06,1.901141,1.701183,1.753049,0.9867,0.986679,0.222425,0.801338
4135,20221212,-4.3066,1392584.28,1.848278e+06,1.343284,-1.236364,-2.172285,-1.1205,-1.120514,0.357908,-0.496703
4136,20221213,0.9916,902285.13,1.194285e+06,-3.166421,-1.620029,0.612557,-0.1963,-0.196335,-0.568989,-0.477917
4137,20221214,-0.0755,974061.10,1.287868e+06,0.608365,0.299401,-0.456621,0.2333,0.233278,-0.019216,0.210433


Unnamed: 0,trade_date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,0.9100,909.99,297.6228,-2.380952,-2.052786,-2.735562,0.4292,0.429181,-3.114214,-2.560514
1,20050511,2.1000,1728.90,585.8502,1.829268,3.592814,3.750000,-1.2294,-1.229361,0.695163,0.419865
2,20050512,0.5900,1084.00,368.0140,1.497006,-0.578035,0.602410,-1.7776,-1.777566,-1.301766,-1.870865
3,20050513,0.5800,1841.65,630.2068,1.179941,1.162791,0.898204,0.1945,0.194509,-1.829287,-0.173099
4,20050516,-0.2900,1610.50,547.8560,0.583090,-0.574713,-0.296736,-1.3827,-1.382694,0.213242,-1.459758
...,...,...,...,...,...,...,...,...,...,...,...
3919,20221209,0.7305,49745.92,68044.9130,-0.072833,0.072569,-0.367107,0.9867,0.986679,0.222425,0.801338
3920,20221212,-1.7404,47332.20,64308.5080,0.510204,0.145033,-0.515844,-1.1205,-1.120514,0.357908,-0.496703
3921,20221213,0.3690,26568.19,36080.4420,-1.885424,-1.230992,-0.074074,-0.1963,-0.196335,-0.568989,-0.477917
3922,20221214,0.0000,26729.47,36301.8130,0.517369,0.146628,-0.074129,0.2333,0.233278,-0.019216,0.210433


In [22]:
display(stock_list_stdized[0])
display(stock_list_stdized[50])

Unnamed: 0,trade_date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,1.410223,-0.913949,-0.987371,-0.851769,0.564253,-0.088991,0.233111,0.233106,-1.778239,-1.861204
1,20050511,-0.793465,-0.916352,-0.987909,1.133797,0.236012,1.012783,-0.767121,-0.767076,0.364426,0.271044
2,20050512,-0.033573,-0.959071,-1.005979,-0.609510,-0.401536,0.114122,-1.097718,-1.097670,-0.758789,-1.367809
3,20050513,-1.133417,-0.909179,-0.987413,0.039452,-0.661933,-1.627311,0.091573,0.091587,-1.055504,-0.153179
4,20050516,-0.829460,-0.997845,-1.024004,-1.036441,-1.454926,-0.995831,-0.859570,-0.859543,0.093359,-1.073692
...,...,...,...,...,...,...,...,...,...,...,...
4134,20221209,0.984243,1.430091,1.221468,0.682923,0.672426,0.700515,0.569315,0.569304,0.098525,0.543961
4135,20221212,-1.755969,1.082940,0.859347,0.476542,-0.523175,-0.913357,-0.701448,-0.701436,0.174730,-0.384695
4136,20221213,0.363011,0.320523,0.181202,-1.191842,-0.679330,0.231610,-0.144102,-0.144111,-0.346623,-0.371255
4137,20221214,-0.063769,0.432135,0.278240,0.204656,0.101891,-0.207975,0.114972,0.114967,-0.037391,0.121211


Unnamed: 0,trade_date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,0.240106,-0.819902,-0.564355,-0.668493,-0.641736,-0.838243,0.236867,0.236862,-1.837628,-1.903769
1,20050511,0.604577,-0.813219,-0.563559,0.462321,1.045525,1.081478,-0.794897,-0.794847,0.374820,0.273549
2,20050512,0.142097,-0.818482,-0.564160,0.373080,-0.200987,0.149794,-1.135915,-1.135863,-0.784976,-1.399946
3,20050513,0.139034,-0.812298,-0.563436,0.287920,0.319281,0.237348,0.090867,0.090882,-1.091355,-0.159641
4,20050516,-0.127427,-0.814185,-0.563664,0.127613,-0.199994,-0.116353,-0.890260,-0.890229,0.094925,-1.099611
...,...,...,...,...,...,...,...,...,...,...,...
3919,20221209,0.185129,-0.421302,-0.377290,-0.048560,-0.006546,-0.137183,0.583670,0.583658,0.100259,0.552235
3920,20221212,-0.571653,-0.441003,-0.387607,0.108037,0.015111,-0.181209,-0.727153,-0.727139,0.178946,-0.396051
3921,20221213,0.074410,-0.610479,-0.465551,-0.535400,-0.396132,-0.050445,-0.152237,-0.152245,-0.359387,-0.382326
3922,20221214,-0.038607,-0.609163,-0.464939,0.109961,0.015588,-0.050461,0.115003,0.114999,-0.040084,0.120548


## saving data

In [23]:
save_std = False

In [24]:
if save_csv and save_std:
    for s_std_i in range(len(stock_list_stdized)):
        stock_list_stdized[s_std_i] = stock_list_stdized[s_std_i].rename(columns={'trade_date':'date'})
    display(stock_list_stdized[0])
    display(stock_list_stdized[505])

In [25]:
# save csv, temporary

if save_csv and save_std:


    for i in range(100):
        f_name = "stock_" + stock_name_list[i].replace('.','') + ".csv"
        stock_list_stdized[i].to_csv(f_name, index=False)

In [26]:
# stock_list[0].shape
# stock_list[0] = stock_list[0].reset_index(drop=True)

In [27]:
if save_csv and not save_std:
    for s_i in range(len(stock_list)):
        stock_list[s_i] = stock_list[s_i].rename(columns={'trade_date':'date'})
    display(stock_list[0])
    display(stock_list[505])

Unnamed: 0,date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,3.6100,108414.13,6.733231e+04,-2.247191,1.435407,-0.167224,0.4292,0.429181,-3.114214,-2.560514
1,20050511,-1.9000,106868.37,6.681369e+04,3.119869,0.628931,2.512563,-1.2294,-1.229361,0.695163,0.419865
2,20050512,0.0000,79396.48,4.938704e+04,-1.592357,-0.937500,0.326797,-1.7776,-1.777566,-1.301766,-1.870865
3,20050513,-2.7500,111481.38,6.729219e+04,0.161812,-1.577287,-3.908795,0.1945,0.194509,-1.829287,-0.173099
4,20050516,-1.9900,54461.62,3.200359e+04,-2.746365,-3.525641,-2.372881,-1.3827,-1.382694,0.213242,-1.459758
...,...,...,...,...,...,...,...,...,...,...,...
4134,20221209,2.5449,1615831.92,2.197502e+06,1.901141,1.701183,1.753049,0.9867,0.986679,0.222425,0.801338
4135,20221212,-4.3066,1392584.28,1.848278e+06,1.343284,-1.236364,-2.172285,-1.1205,-1.120514,0.357908,-0.496703
4136,20221213,0.9916,902285.13,1.194285e+06,-3.166421,-1.620029,0.612557,-0.1963,-0.196335,-0.568989,-0.477917
4137,20221214,-0.0755,974061.10,1.287868e+06,0.608365,0.299401,-0.456621,0.2333,0.233278,-0.019216,0.210433


Unnamed: 0,date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20071213,-5.7900,4552507.29,5068219.325,4.143126,-3.146259,2.079395,-3.8030,-3.802978,-1.111738,-1.234310
1,20071214,-3.4900,2300360.06,2397683.313,-4.159132,-5.531168,-5.925926,1.9113,1.911308,-4.441837,-1.740640
2,20071217,-6.2700,1633122.27,1640865.341,-2.358491,-2.788104,-3.248031,-2.4181,-2.418068,3.083383,-0.032395
3,20071218,2.7400,871342.23,874507.087,-6.086957,-1.434034,-2.339776,-0.5637,-0.563689,-2.993807,-1.900063
4,20071219,-0.2000,796890.03,804259.694,5.041152,-0.193986,3.333333,2.4095,2.409486,1.046329,1.609627
...,...,...,...,...,...,...,...,...,...,...,...
3526,20221209,-0.3788,579852.56,152083.183,-0.375940,-0.374532,-0.760456,0.9867,0.986679,0.222425,0.801338
3527,20221212,-1.9011,577014.87,149552.360,-1.132075,-1.503759,-1.532567,-1.1205,-1.120514,0.357908,-0.496703
3528,20221213,0.7752,424061.88,110257.755,-1.908397,0.000000,0.000000,-0.1963,-0.196335,-0.568989,-0.477917
3529,20221214,-1.5385,506453.00,130385.199,1.556420,-0.381679,-0.778210,0.2333,0.233278,-0.019216,0.210433


In [None]:
# save csv, temporary

if save_csv and not save_std:


    for i in range(1):
        f_name = "stock_" + stock_name_list[i].replace('.','') + ".csv"
        stock_list[i].to_csv(f_name, index=False)

In [26]:
# stock_list[0]

In [27]:
# t1 = pandas_window(stock_list[0], 10)
# t1[0]

In [28]:
# t1_v = pandas_window(stock_list[0].values, 10)
# t1_v[0]

In [29]:
# t1_tensor = torch.tensor(t1_v, dtype=torch.float32).to(device)
# print(t1_tensor.shape)
# print(t1_tensor.get_device())

In [30]:
# stock_list[0]["close"][0]

In [31]:
# convert_to_pct(stock_list[0])
# stock_list[0].dropna()

In [32]:
# display(stock_list[0])

In [33]:
x_t1, y_t1 = pandas_window_single_step(stock_list_stdized[0], "open_pct_change", 3)
print(x_t1[0])
print(y_t1[0])

[[ 2.00505100e+07  1.41022256e+00 -9.13948648e-01 -9.87370979e-01
  -8.51769063e-01  5.64252816e-01 -8.89913626e-02  2.33110617e-01
   2.33105828e-01 -1.77823858e+00 -1.86120433e+00]
 [ 2.00505110e+07 -7.93465353e-01 -9.16352310e-01 -9.87908747e-01
   1.13379659e+00  2.36011799e-01  1.01278331e+00 -7.67121306e-01
  -7.67075598e-01  3.64426048e-01  2.71044460e-01]
 [ 2.00505120e+07 -3.35729695e-02 -9.59071192e-01 -1.00597898e+00
  -6.09510428e-01 -4.01535902e-01  1.14121930e-01 -1.09771768e+00
  -1.09767014e+00 -7.58789108e-01 -1.36780949e+00]]
0.03945155831717496


In [34]:
display(stock_list_stdized[0][:12])

Unnamed: 0,date,close_pct_chg,vol,amount,open_pct_change,high_pct_change,low_pct_change,index_close_pct_chg,index_open_pct_change,index_high_pct_change,index_low_pct_change
0,20050510,1.410223,-0.913949,-0.987371,-0.851769,0.564253,-0.088991,0.233111,0.233106,-1.778239,-1.861204
1,20050511,-0.793465,-0.916352,-0.987909,1.133797,0.236012,1.012783,-0.767121,-0.767076,0.364426,0.271044
2,20050512,-0.033573,-0.959071,-1.005979,-0.60951,-0.401536,0.114122,-1.097718,-1.09767,-0.758789,-1.367809
3,20050513,-1.133417,-0.909179,-0.987413,0.039452,-0.661933,-1.627311,0.091573,0.091587,-1.055504,-0.153179
4,20050516,-0.82946,-0.997845,-1.024004,-1.036441,-1.454926,-0.995831,-0.85957,-0.859543,0.093359,-1.073692
5,20050517,1.050274,-0.987764,-1.019014,-0.757862,0.858951,0.336656,0.400821,0.40084,-0.808742,0.204346
6,20050518,0.494352,-1.001601,-1.024171,0.982854,0.310932,1.607349,0.0929,0.09292,0.492932,0.141568
7,20050519,0.358372,-0.969027,-1.010472,0.467977,0.439556,-0.224448,0.040856,0.040862,0.081999,-0.221132
8,20050520,0.030418,-0.928882,-0.992974,0.461613,0.434426,0.732269,-0.121729,-0.121744,0.016167,0.212678
9,20050523,-1.645345,-0.954031,-1.005387,-0.555715,-1.175506,-1.095121,-1.35251,-1.352501,-0.232342,-0.891608


In [35]:
x_tensor = torch.tensor(x_t1, dtype=torch.float32).to(device)
y_tensor = torch.tensor(y_t1, dtype=torch.float32).to(device)
y_tensor = torch.reshape(y_tensor, (y_tensor.shape[0],1))

print(x_tensor.shape)
print(x_tensor.get_device())
print(y_tensor.shape)
print(y_tensor.get_device())
print(y_tensor[0])

torch.Size([4135, 3, 11])
0
torch.Size([4135, 1])
0
tensor([0.0395], device='cuda:0')
