Skip to content

Commit 67b84d8

Browse files
committed
fix loading data on some tickers on stock prediction tutorial
1 parent 1151b59 commit 67b84d8

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

machine-learning/stock-prediction/stock_prediction.ipynb

+2
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@
215215
" dates = result[\"X_test\"][:, -1, -1]\n",
216216
" # retrieve test features from the original dataframe\n",
217217
" result[\"test_df\"] = result[\"df\"].loc[dates]\n",
218+
" # remove duplicated dates in the testing dataframe\n",
219+
" result[\"test_df\"] = result[\"test_df\"][~result[\"test_df\"].index.duplicated(keep='first')]\n",
218220
" # remove dates from the training/testing sets & convert to float32\n",
219221
" result[\"X_train\"] = result[\"X_train\"][:, :, :len(feature_columns)].astype(np.float32)\n",
220222
" result[\"X_test\"] = result[\"X_test\"][:, :, :len(feature_columns)].astype(np.float32)\n",

machine-learning/stock-prediction/stock_prediction.py

+2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def load_data(ticker, n_steps=50, scale=True, shuffle=True, lookup_step=1, split
129129
dates = result["X_test"][:, -1, -1]
130130
# retrieve test features from the original dataframe
131131
result["test_df"] = result["df"].loc[dates]
132+
# remove duplicated dates in the testing dataframe
133+
result["test_df"] = result["test_df"][~result["test_df"].index.duplicated(keep='first')]
132134
# remove dates from the training/testing sets & convert to float32
133135
result["X_train"] = result["X_train"][:, :, :len(feature_columns)].astype(np.float32)
134136
result["X_test"] = result["X_test"][:, :, :len(feature_columns)].astype(np.float32)

0 commit comments

Comments
 (0)