Skip to content

Commit dbb9ecc

Browse files
committed
made it easier to change the outperformance parameter
1 parent 394bb2e commit dbb9ecc

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

stock_prediction.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
from utils import data_string_to_float, status_calc
44

55

6+
# The percentage by which a stock has to beat the S&P500 to be considered a 'buy'
7+
OUTPERFORMANCE = 10
8+
9+
610
def build_data_set():
711
"""
812
Reads the keystats.csv file and prepares it for scikit-learn
@@ -14,8 +18,9 @@ def build_data_set():
1418

1519
X_train = training_data[features].values
1620
# Generate the labels: '1' if a stock beats the S&P500 by more than 10%, else '0'.
17-
y_train = list(map(
18-
status_calc, training_data["stock_p_change"], training_data["SP500_p_change"]))
21+
y_train = list(status_calc(training_data["stock_p_change"],
22+
training_data["SP500_p_change"],
23+
OUTPERFORMANCE))
1924

2025
return X_train, y_train
2126

@@ -40,7 +45,7 @@ def predict_stocks():
4045
else:
4146
invest_list = z[y_pred].tolist()
4247
print(
43-
f"{len(invest_list)} stocks predicted to outperform the S&P500 by more than 10%:")
48+
f"{len(invest_list)} stocks predicted to outperform the S&P500 by more than {OUTPERFORMANCE}%:")
4449
print(' '.join(invest_list))
4550
return invest_list
4651

0 commit comments

Comments
 (0)