In [None]:
!pip install --upgrade scikit-learn==1.0.0 --quiet
!pip install --upgrade linear-tree --quiet

In [51]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import load_boston, load_diabetes
from sklearn.model_selection import train_test_split, RepeatedKFold, cross_validate, cross_val_score, cross_val_predict
from sklearn.metrics import mean_squared_error, r2_score

from sklearn.linear_model import LinearRegression, Lasso, Ridge, ElasticNet, BayesianRidge
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, BaggingRegressor

from lineartree import LinearTreeRegressor, LinearForestRegressor

from tqdm.auto import tqdm

In [72]:
class Config:
    SEED = 3655
    N_SPLITS = 5
    N_REPEATS = 2
    N_ESTIMATORS = 100

In [121]:
X, y = load_diabetes(
    return_X_y=True, 
    as_frame=True
    )

In [122]:
X_train, X_valid, y_train, y_valid = train_test_split(X, y, train_size=.8, test_size=.2, random_state=Config.SEED)

In [123]:
cv = RepeatedKFold(
    n_splits=Config.N_SPLITS, 
    n_repeats=Config.N_REPEATS, 
    random_state=Config.SEED
    ).split(X)
cv = list(cv)

## RandomForestRegressor

In [75]:
reg_RF = RandomForestRegressor(
    n_estimators=Config.N_ESTIMATORS, 
    ccp_alpha=.01, 
    random_state=Config.SEED
    )

oof_rmse_RF = -cross_val_score(
        estimator=reg_RF, 
        X=X, 
        y=y, 
        scoring="neg_root_mean_squared_error", 
        cv=cv,
        n_jobs=-1
        )

print(f'oof rmse: {np.mean(oof_rmse_RF):.3f}')

oof rmse: 58.360


In [77]:
reg_RF.fit(X_train, y_train)

train_rmse = mean_squared_error(y_train, reg_RF.predict(X_train), squared=False)
valid_rmse = mean_squared_error(y_valid, reg_RF.predict(X_valid), squared=False)

print(f'train rmse: {train_rmse:.3f}')
print(f'valid rmse: {valid_rmse:.3f}')

train rmse: 21.661
valid rmse: 57.471


## LinearForestRegressor

In [124]:
reg_LF = LinearForestRegressor(
    base_estimator=LinearRegression(),
    n_estimators=Config.N_ESTIMATORS, 
    ccp_alpha=.01, 
    random_state=Config.SEED
    )

oof_rmse_LF = -cross_val_score(
        estimator=reg_LF, 
        X=X, 
        y=y, 
        scoring="neg_root_mean_squared_error", 
        cv=cv,
        n_jobs=-1
        )

print(f'oof rmse: {np.mean(oof_rmse_LF):.3f}')

oof rmse: 56.564


In [125]:
reg_LF.fit(X_train, y_train)

train_rmse = mean_squared_error(y_train, reg_LF.predict(X_train), squared=False)
valid_rmse = mean_squared_error(y_valid, reg_LF.predict(X_valid), squared=False)

print(f'train rmse: {train_rmse:.3f}')
print(f'valid rmse: {valid_rmse:.3f}')

train rmse: 20.465
valid rmse: 54.903


### Same Algorithms

In [127]:
reg_LR = LinearRegression()
reg_RF = RandomForestRegressor(
    n_estimators=Config.N_ESTIMATORS, 
    ccp_alpha=.01, 
    random_state=Config.SEED
    )

reg_LR.fit(X_train, y_train)

reg_RF.fit(X_train, y_train - reg_LR.predict(X_train))

train_rmse = mean_squared_error(y_train, reg_RF.predict(X_train) + reg_LR.predict(X_train), squared=False)
valid_rmse = mean_squared_error(y_valid, reg_RF.predict(X_valid) + reg_LR.predict(X_valid), squared=False)

print(f'train rmse: {train_rmse:.3f}')
print(f'valid rmse: {valid_rmse:.3f}')

train rmse: 20.465
valid rmse: 54.903


## BaggingRegressor + LinearTreeRegressor

In [87]:
LT = LinearTreeRegressor(
    base_estimator=LinearRegression(), 
    )

reg_BR_LT = BaggingRegressor(
    base_estimator=LT,
    n_estimators=Config.N_ESTIMATORS,
    random_state=Config.SEED,
    n_jobs=-1
)

oof_rmse_BR_LT = -cross_val_score(
        estimator=LT, 
        X=X, 
        y=y, 
        scoring="neg_root_mean_squared_error", 
        cv=cv,
        n_jobs=2
        )

print(f'oof rmse: {np.mean(oof_rmse_BR_LT):.3f}')

oof rmse: 69.159


In [88]:
reg_BR_LT.fit(X_train, y_train)

train_rmse = mean_squared_error(y_train, reg_BR_LT.predict(X_train), squared=False)
valid_rmse = mean_squared_error(y_valid, reg_BR_LT.predict(X_valid), squared=False)

print(f'train rmse: {train_rmse:.3f}')
print(f'valid rmse: {valid_rmse:.3f}')

train rmse: 37.676
valid rmse: 55.312
