# Training
This notebook walks through training tree-based models (random forest and XGBoost) with 8-fold cross-validation and evaluating their performance on the training data.

Make sure you have ***preprocessed the data first*** so that the inputs fit seamlessly with the model classes. Preprocessing is explained in `0_process_data.ipynb`. Because these models are tree-based, scaling the data is not necessary.
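Why scaling is unnecessary: a regression tree picks each split by sorting the samples on one feature and choosing the cut that minimizes the squared error of the target, so only the *order* of the feature values matters. A strictly increasing transformation such as standardization preserves that order, and therefore the partition of the samples; only the threshold value changes. Random forests and gradient-boosted trees such as XGBoost are built from splits of this threshold form, which is why they are generally insensitive to feature scaling. The toy sketch below checks this for a single split on one feature (stdlib only; `best_split` and `_sse` are illustrative helpers, not part of `cyc_pep_perm`).

```python
def _sse(values):
    """Sum of squared errors around the mean."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def best_split(x, y):
    """Return (threshold, left_indices) of the SSE-minimizing split on one feature."""
    order = sorted(range(len(x)), key=lambda i: x[i])
    best_sse, best_threshold, best_left = float("inf"), None, None
    for k in range(1, len(order)):
        lo, hi = x[order[k - 1]], x[order[k]]
        if lo == hi:
            continue  # no threshold can separate equal values
        left, right = order[:k], order[k:]
        sse = _sse([y[i] for i in left]) + _sse([y[i] for i in right])
        if sse < best_sse:
            best_sse, best_threshold, best_left = sse, (lo + hi) / 2, frozenset(left)
    return best_threshold, best_left


x = [3.1, 0.5, 7.2, 4.4, 9.0, 1.8]
y = [10.0, 2.0, 30.0, 12.0, 33.0, 3.0]

# Standardize the feature (zero mean, unit variance)
mean = sum(x) / len(x)
std = (sum((v - mean) ** 2 for v in x) / len(x)) ** 0.5
x_scaled = [(v - mean) / std for v in x]

_, left_raw = best_split(x, y)
_, left_scaled = best_split(x_scaled, y)
print(left_raw == left_scaled)  # True: the same samples go to the left child
```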

Replace the paths with those to your own training data. The predefined paths only work if you install from GitHub and use the code in development mode (`pip install -e .`). If you install the Python package (`pip install CycPepPerm`), the paths point to the installed package directory.

### Housekeeping

In [1]:
from pathlib import Path

ROOT_PATH = Path.cwd().parent
DATA_PATH = ROOT_PATH / 'data'
TRAIN_RANDOM_DW = DATA_PATH / 'perm_random80_train_dw.csv'
TRAIN_RANDOM_MORDRED = DATA_PATH / 'perm_random80_train_mordred.csv'
MODEL_PATH = ROOT_PATH / 'models'
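Because `ROOT_PATH` is derived from the current working directory (`Path.cwd().parent`), these paths resolve correctly only if the notebook's working directory sits next to the `data/` and `models/` folders. An optional check before the long training run can catch a wrong location early; `missing_paths` is a hypothetical helper, and the cell rebuilds the Housekeeping paths so it runs on its own.

```python
from pathlib import Path


def missing_paths(*paths: Path) -> list[Path]:
    """Return every given path that does not exist on disk."""
    return [p for p in paths if not p.exists()]


# Same construction as in the Housekeeping cell
ROOT_PATH = Path.cwd().parent
DATA_PATH = ROOT_PATH / "data"

missing = missing_paths(
    DATA_PATH / "perm_random80_train_dw.csv",
    DATA_PATH / "perm_random80_train_mordred.csv",
)
if missing:
    print("Missing input files:", *missing, sep="\n  ")
```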

## Random Forest

### Train model
The example uses the provided DataWarrior descriptors. To train and evaluate the model with mordred descriptors instead, the code is the same: just change the path to the input data (also defined above in Housekeeping).
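Switching descriptor sets only changes the input file (and, sensibly, the output file name). To train on both sets in one go, one option is to build the `(datapath, savepath)` pairs first. In this sketch the `training_jobs` helper and the mordred output file name are hypothetical; the DataWarrior file name matches the one used in the next cell.

```python
from pathlib import Path

ROOT_PATH = Path.cwd().parent
DATA_PATH = ROOT_PATH / "data"
MODEL_PATH = ROOT_PATH / "models"

# Training files as defined in Housekeeping
TRAIN_FILES = {
    "datawarrior": DATA_PATH / "perm_random80_train_dw.csv",
    "mordred": DATA_PATH / "perm_random80_train_mordred.csv",
}


def training_jobs(model_prefix, train_files=TRAIN_FILES):
    """Pair each training file with a save path (illustrative naming scheme)."""
    return [
        (datapath, MODEL_PATH / f"{model_prefix}_best8cv_{name}_random.pkl")
        for name, datapath in train_files.items()
    ]


for datapath, savepath in training_jobs("rf"):
    print(datapath.name, "->", savepath.name)
```

Each pair can then be passed to `train(datapath=datapath, savepath=savepath)`, for example with a fresh `RF()` instance per run. Expect every run to take about as long as the single run in the next cell.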

In [None]:
from cyc_pep_perm.models.randomforest import RF

# instantiate class
rf_regressor = RF()

# train a model using cross-validation
# XXX: this will take some time, especially if not on a GPU
# and if you search over all the default hyperparameters defined in the script

# TODO: define where to save the model
savepath = MODEL_PATH / 'rf_best8cv_datawarrior_random.pkl'

# TODO: replace TRAIN_RANDOM_DW with the path to your own data!
model = rf_regressor.train(datapath=TRAIN_RANDOM_DW, savepath=savepath)
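`train` runs the cross-validated hyperparameter search internally, and its exact splitting strategy is not shown in this notebook. For intuition, 8-fold cross-validation shuffles the samples, cuts them into eight folds of near-equal size, and uses each fold once as validation data while the other seven are used for fitting. The stdlib sketch below shows only that index bookkeeping; `kfold_indices` is a hypothetical helper whose fold-size rule mirrors scikit-learn's `KFold`.

```python
import random


def kfold_indices(n_samples, n_splits=8, seed=0):
    """Yield (train_idx, val_idx) lists for shuffled k-fold cross-validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # The first n_samples % n_splits folds get one extra sample
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


folds = list(kfold_indices(20))
print([len(val) for _, val in folds])  # [3, 3, 3, 3, 2, 2, 2, 2]
```

Every sample lands in exactly one validation fold, so the eight validation scores together cover the whole training set.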

In [None]:
# get predictions on the training data
y_pred, rmse, r2 = rf_regressor.evaluate()
print(f"Training RMSE = {rmse:0.2f}")
print(f"Training R2 = {r2:0.2f}")
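`evaluate()` returns RMSE and R2; how the package computes them is not shown here (scikit-learn's metric functions would be the usual choice, but that is an assumption). The standard definitions are sketched below with the stdlib only. The helpers are deliberately named `root_mean_squared_error` and `r_squared` so they do not overwrite the `rmse` and `r2` variables used by the plotting cell. Keep in mind that these are in-sample numbers: a model scored on the data it was fit on usually looks better than it will on unseen data.

```python
import math


def root_mean_squared_error(y_true, y_pred):
    """sqrt(mean((y_true - y_pred)**2))"""
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


print(f"{root_mean_squared_error([1, 2, 3, 4], [1, 2, 3, 5]):0.2f}")  # 0.50
print(f"{r_squared([1, 2, 3, 4], [1, 2, 3, 5]):0.2f}")  # 0.80
```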

### Plot results

In [None]:
# plot predictions
import matplotlib.pyplot as plt

y_true = rf_regressor.y

plt.scatter(y_true, y_pred, color="r")
plt.xlabel("True permeability [%]")
plt.ylabel("Predicted permeability [%]")
# annotate the plot with RMSE and R2 (positions in axes coordinates)
ax = plt.gca()
plt.text(0.05, 0.90, f"RMSE = {rmse:0.2f}", ha="left", va="center", transform=ax.transAxes)
plt.text(0.05, 0.85, f"R2 = {r2:0.2f}", ha="left", va="center", transform=ax.transAxes)
plt.show()

In [None]:
# this plot tells you the contribution of different features to the prediction
shap_values = rf_regressor.shap_explain(rf_regressor.X)
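`shap_explain` presumably relies on the `shap` library (an assumption; its internals are not shown here). To illustrate what a SHAP value is, the sketch below computes exact Shapley values for a toy function by brute force, replacing each "absent" feature with a single baseline value. This is not the TreeSHAP algorithm that `shap` uses for tree ensembles: the enumeration grows as 2^n with the number of features and is only feasible for a handful of them. It does show the defining property that the attributions add up to f(x) - f(baseline). `shapley_values` and `toy_model` are hypothetical names.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; absent features take their baseline value."""
    n = len(x)

    def value(coalition):
        # Features in the coalition come from x, all others from the baseline
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        contribution = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S without i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                contribution += weight * (value(s | {i}) - value(s))
        phi.append(contribution)
    return phi


def toy_model(z):
    # Feature 0 acts alone; features 1 and 2 only matter together
    return 2 * z[0] + z[1] * z[2]


x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print([round(p, 6) for p in phi])  # [2.0, 3.0, 3.0]
# Additivity: the attributions sum to f(x) - f(baseline)
print(round(sum(phi), 6), toy_model(x) - toy_model(baseline))  # 8.0 8.0
```

The interaction term `z[1] * z[2]` is split equally between features 1 and 2, which is how Shapley values share credit for effects that need several features at once.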

## XGBoost

### Train model
The example uses the provided DataWarrior descriptors. To train and evaluate the model with mordred descriptors instead, the code is the same: just change the path to the input data (also defined above in Housekeeping).

In [None]:
from cyc_pep_perm.models.xgboost import XGB

# instantiate class
xgb_regressor = XGB()

# train a model using cross-validation
# XXX: this will take some time, especially if not on a GPU
# and if you search over all the default hyperparameters defined in the script

# TODO: define where to save the model
savepath = MODEL_PATH / 'xgb_best8cv_datawarrior_random.pkl'

# TODO: replace TRAIN_RANDOM_DW with the path to your own data!
model = xgb_regressor.train(datapath=TRAIN_RANDOM_DW, savepath=savepath)

In [None]:
# get predictions on the training data
y_pred, rmse, r2 = xgb_regressor.evaluate()
print(f"Training RMSE = {rmse:0.2f}")
print(f"Training R2 = {r2:0.2f}")
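After this cell, `y_pred`, `rmse` and `r2` hold the XGBoost results; the random-forest values stored under the same names earlier are overwritten. To compare the two models, one option is to save each pair under its own key right after the corresponding `evaluate()` call and print them side by side. `metrics_table` is a hypothetical helper, not part of `cyc_pep_perm`.

```python
def metrics_table(results):
    """Format {model_name: (rmse, r2)} as a fixed-width text table."""
    header = f"{'model':<10}{'RMSE':>8}{'R2':>8}"
    rows = [
        f"{name:<10}{rmse_value:>8.2f}{r2_value:>8.2f}"
        for name, (rmse_value, r2_value) in results.items()
    ]
    return "\n".join([header, *rows])


# Intended use in this notebook (fill in right after each evaluate() call):
#   results = {}
#   results["RF"] = (rmse, r2)    # after rf_regressor.evaluate()
#   results["XGB"] = (rmse, r2)   # after xgb_regressor.evaluate()
#   print(metrics_table(results))
```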

### Plot results

In [None]:
# plot predictions
import matplotlib.pyplot as plt

y_true = xgb_regressor.y

plt.scatter(y_true, y_pred, color="r")
plt.xlabel("True permeability [%]")
plt.ylabel("Predicted permeability [%]")
# annotate the plot with RMSE and R2 (positions in axes coordinates)
ax = plt.gca()
plt.text(0.05, 0.90, f"RMSE = {rmse:0.2f}", ha="left", va="center", transform=ax.transAxes)
plt.text(0.05, 0.85, f"R2 = {r2:0.2f}", ha="left", va="center", transform=ax.transAxes)
plt.show()

In [None]:
# this plot tells you the contribution of different features to the prediction
shap_values = xgb_regressor.shap_explain(xgb_regressor.X)