# Train

> This module provides trainers for different models, along with helpers for predicting with the saved models

In [None]:
#| default_exp train

In [None]:
#| hide
import sys
# NOTE(review): hard-coded absolute path; presumably it makes the local `katlas`
# checkout importable for the imports in this cell — confirm, and prefer an
# installed package (or a configurable path) so the notebook runs elsewhere.
sys.path.append("/notebooks/katlas")
from nbdev.showdoc import *
%matplotlib inline
from katlas.core import Data
from katlas.feature import *

In [None]:
#| export
from fastbook import *
import xgboost as xgb
import matplotlib.pyplot as plt
from scipy.stats import spearmanr,pearsonr
from sklearn.model_selection import train_test_split
from pathlib import Path

In [None]:
#| export
def _importance_frame(model, # trained booster exposing `get_score`
                      importance_type, # importance type passed to `get_score`, e.g. 'gain' or 'weight'
                     ):
    "Return the feature importances of `model`, sorted in descending order, and plot the ten largest."
    col = f'{importance_type}_importance'
    scores = model.get_score(importance_type=importance_type)
    frame = (pd.DataFrame({'feature': list(scores.keys()), col: list(scores.values())})
             .set_index('feature')
             .sort_values(by=col, ascending=False))
    # If the booster reports no scores the frame is empty; skip the chart
    # rather than asking pandas to draw a bar plot of nothing.
    if not frame.empty:
        frame[:10].plot.barh()
        plt.show()
        plt.close()
    return frame


def xgb_trainer(df, # dataframe holding both the features and the target
                feature_col, # feature column name(s), used as `df[feature_col]`
                target_col, # target column name, used as `df[target_col]`
                test_index=None, # index labels of the held-out rows; a random 80/20 split is used if None
                xgb_params = { 
                            'max_depth':7, #from 4 to 7
                            'learning_rate':0.001, #from 0.001
                            'subsample':0.8,
                            'colsample_bytree':1, # from 0.2 to 1, because need to take all features
                            'eval_metric':'rmse',
                            'objective':'reg:squarederror',
                            'tree_method':'gpu_hist',
                            'predictor':'gpu_predictor',
                            'random_state':123
                        },
                model_file='xgb_model.bin', # path the trained model is saved to
                split_seed = 123, # seed of random split
               ):
    """Train an XGBoost model on `df`, save it to `model_file` and evaluate it on the held-out rows.

    Prints the split shapes, the Spearman and Pearson correlations between the
    held-out targets and the predictions, and shows a scatter plot plus bar
    charts of the ten most important features (by gain and by weight).

    Returns a tuple `(pred_df, gain, weight)`:
    `pred_df` is indexed like the held-out rows with columns `label` and `pred`;
    `gain` and `weight` are feature-importance frames sorted in descending order.
    """
    # Work on a copy so the shared default object is never the one handed to
    # downstream code (`.copy()` is used because it does not change the container type).
    params = xgb_params.copy()
    # NOTE(review): 'gpu_hist' / 'gpu_predictor' in the defaults are reportedly
    # deprecated in xgboost 2.x in favour of `device='cuda'` — confirm against the pinned version.

    X = df[feature_col]
    y = df[target_col]

    print(f'xgb params is: {params}')

    if test_index is None:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=split_seed)
    else:
        # everything that is not explicitly held out is used for training
        in_test = X.index.isin(test_index)
        X_train, y_train = X.loc[~in_test], y.loc[~in_test]
        X_test, y_test = X.loc[test_index], y.loc[test_index]

    print(X_train.shape,y_train.shape,X_test.shape, y_test.shape)
    print(y_test.index)

    #prepare matrix for xgb
    dtrain = xgb.DMatrix(X_train, y_train)
    dtest = xgb.DMatrix(X_test, y_test)

    # The held-out set doubles as the validation set that drives early stopping.
    model = xgb.train(params,
            dtrain=dtrain,
            evals=[(dtrain,'train'),(dtest,'valid')],
            num_boost_round=9999,
            early_stopping_rounds=100,
            verbose_eval=100,)

    # Make the target directory (including any missing parents) before saving.
    Path(model_file).parent.mkdir(parents=True, exist_ok=True)
    model.save_model(model_file)
    print(f'Model saved to {model_file}')

    # NOTE(review): with the native API the booster may predict with every
    # boosted round rather than only up to the best early-stopped iteration —
    # confirm for the installed xgboost version.
    pred = model.predict(dtest)
    spearman_corr, _ = spearmanr(y_test, pred)
    print(f'Spearman correlation: {spearman_corr:.2f}')
    pearson_corr, _ = pearsonr(y_test, pred)
    print(f'Pearson correlation: {pearson_corr:.2f}')

    pred_df = pd.DataFrame({'label': np.ravel(y_test), 'pred': np.ravel(pred)},
                           index=y_test.index)

    fig, ax = plt.subplots()
    ax.scatter(y_test, pred)
    ax.set_xlabel('True values')
    ax.set_ylabel('Predicted values')
    ax.set_title('Scatter plot of true versus predicted values')
    plt.show()
    plt.close()

    gain = _importance_frame(model, 'gain')
    weight = _importance_frame(model, 'weight')

    return pred_df, gain, weight

In [None]:
show_doc(xgb_trainer)

---

### xgb_trainer

>      xgb_trainer (df, feature_col, target_col, test_index=None,
>                   xgb_params={'max_depth': 7, 'learning_rate': 0.001,
>                   'subsample': 0.8, 'colsample_bytree': 1, 'eval_metric':
>                   'rmse', 'objective': 'reg:squarederror', 'tree_method':
>                   'gpu_hist', 'predictor': 'gpu_predictor', 'random_state':
>                   123}, model_file='xgb_model.bin', split_seed=123)

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| df |  |  |  |
| feature_col |  |  |  |
| target_col |  |  |  |
| test_index | NoneType | None |  |
| xgb_params | dict | {'max_depth': 7, 'learning_rate': 0.001, 'subsample': 0.8, 'colsample_bytree': 1, 'eval_metric': 'rmse', 'objective': 'reg:squarederror', 'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor', 'random_state': 123} |  |
| model_file | str | xgb_model.bin |  |
| split_seed | int | 123 | seed of random split |

In [None]:
#| export
def xgb_predict(df, # a dataframe that contains ID and features for prediction
                feature_col, #feature column name
                ID_col = "ID", #ID column name
                model_file='xgb_model.bin', # path of the saved XGBoost model to load
               ):
    "Load the XGBoost model stored in `model_file` and predict from the `feature_col` columns of `df`."
    # Restore the booster from disk first
    booster = xgb.Booster()
    booster.load_model(model_file)

    # Wrap the selected feature columns for XGBoost and score them
    scores = booster.predict(xgb.DMatrix(df[feature_col]))

    # One row per input row: its ID next to the predicted value
    return pd.DataFrame({ID_col: df[ID_col], 'preds': scores})

In [None]:
show_doc(xgb_predict)

---

### xgb_predict

>      xgb_predict (df, feature_col, ID_col='ID', model_file='xgb_model.bin')

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| df |  |  | a dataframe that contains ID and features for prediction |
| feature_col |  |  | feature column name |
| ID_col | str | ID | ID column name |
| model_file | str | xgb_model.bin |  |

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
