In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import TimeSeriesSplit, RandomizedSearchCV
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.inspection import permutation_importance
import matplotlib.pyplot as plt
import seaborn as sns

In [8]:
# Read the AAPL price history and order the rows by their "Date" column.
# No date parsing is requested from read_csv, so the sort compares the values
# exactly as they were read; the YYYY-MM-DD layout visible in the preview keeps
# that ordering chronological.
df = pd.read_csv("data/archive/stocks/AAPL.csv").sort_values(by="Date")
df.head()

Unnamed: 0,Date,Open,High,Low,Close,Adj Close,Volume
0,1980-12-12,0.513393,0.515625,0.513393,0.513393,0.406782,117258400
1,1980-12-15,0.488839,0.488839,0.486607,0.486607,0.385558,43971200
2,1980-12-16,0.453125,0.453125,0.450893,0.450893,0.35726,26432000
3,1980-12-17,0.462054,0.464286,0.462054,0.462054,0.366103,21610400
4,1980-12-18,0.475446,0.477679,0.475446,0.475446,0.376715,18362400


In [11]:
# Row-over-row return: fractional change of "Close" relative to the previous row.
# The first row has no predecessor, so its return is NaN. Calling .dropna() on the
# right-hand side cannot remove it: assigning a Series to a column aligns on the
# frame's index, and the missing label is filled with NaN again. The NaN therefore
# stays until whole rows are dropped at the DataFrame level.
# NOTE(review): this uses "Close", not "Adj Close" -- presumably intentional; confirm.
df['returns'] = df['Close'].pct_change()
df.head()

Unnamed: 0,Date,Open,High,Low,Close,Adj Close,Volume,returns
0,1980-12-12,0.513393,0.515625,0.513393,0.513393,0.406782,117258400,
1,1980-12-15,0.488839,0.488839,0.486607,0.486607,0.385558,43971200,-0.052174
2,1980-12-16,0.453125,0.453125,0.450893,0.450893,0.35726,26432000,-0.073394
3,1980-12-17,0.462054,0.464286,0.462054,0.462054,0.366103,21610400,0.024752
4,1980-12-18,0.475446,0.477679,0.475446,0.475446,0.376715,18362400,0.028986


In [13]:
# Build one column per lag k, holding the value of "returns" from k rows earlier
# (k = 1, 2, 5, 10), and append them to the frame in that order.
lagged_returns = {
    f'return_lag{lag}': df['returns'].shift(lag)
    for lag in [1, 2, 5, 10]
}
# dropna() removes every row that has a NaN in ANY column, not only in the new lag
# columns. Because `df` is rebound to the trimmed frame, running this cell again
# would shift the already-trimmed returns and drop further leading rows.
df = df.assign(**lagged_returns).dropna()
df.head()

Unnamed: 0,Date,Open,High,Low,Close,Adj Close,Volume,returns,return_lag1,return_lag2,return_lag5,return_lag10
11,1980-12-30,0.629464,0.629464,0.627232,0.627232,0.496981,17220000,-0.024306,0.014084,0.092308,0.048673,-0.052174
12,1980-12-31,0.611607,0.611607,0.609375,0.609375,0.482832,8937600,-0.02847,-0.024306,0.014084,0.042194,-0.073394
13,1981-01-02,0.616071,0.620536,0.616071,0.616071,0.488138,5415200,0.010989,-0.02847,-0.024306,0.052632,0.024752
14,1981-01-05,0.604911,0.604911,0.602679,0.602679,0.477526,8932000,-0.021739,0.010989,-0.02847,0.092308,0.028986
15,1981-01-06,0.578125,0.578125,0.575893,0.575893,0.456303,11289600,-0.044444,-0.021739,0.010989,0.014084,0.061033
