# gplearn Development - WandB Integration

This notebook demonstrates integrating `gplearn` with Weights & Biases (WandB) for logging, to give a better view of model training. 
We avoid modifying `gplearn`'s own implementation for as long as possible.

In [1]:
import numpy as np
import pandas as pd 

# 1. Working Model

A simple `gplearn` survival model with a partial log-likelihood & shrink penalty fitness function.

In [2]:
##### Fitness #####
from survshares.metrics import partial_likelihood
from gplearn_clean.gplearn.fitness import make_fitness

def fitness_pll_shrink(y_true, y_pred, sample_weight):
    """
    gplearn fitness: partial log-likelihood plus a shrink penalty on the
    predictions (0.05 times their mean absolute value). Lower is better.
    """
    likelihood_term = partial_likelihood(y_true, y_pred, sample_weight)
    shrink_term = 0.05 * np.abs(y_pred).mean()
    return likelihood_term + shrink_term


# Rebind the name to the gplearn fitness object built from the raw function;
# greater_is_better=False because smaller values of this metric are better.
fitness_pll_shrink = make_fitness(
    greater_is_better=False,
    function=fitness_pll_shrink,
)

In [3]:
##### Dataset #####
from survshares.datasets import Rossi

# Test dataset (Rossi, loaded with normalise=True).
# Baseline for comparison - Cox achieves C-index of 0.6403292470997135
# X, T, E are passed to fit() below as features, target and sample weights
# (presumably covariates, event times and event indicators - TODO confirm).
X, T, E = Rossi().load(normalise=True)
feature_names = Rossi.features  # forwarded to the regressors as `feature_names`
calibration_time = Rossi.tmax  # not referenced in the cells visible here

In [5]:
##### Wrapper #####
from sklearn.base import BaseEstimator, RegressorMixin
from pycox.models.cox import _CoxPHBase
import torchtuples as tt

class SymRegPH(BaseEstimator, RegressorMixin, _CoxPHBase):
    """
    Adapter that lets a gplearn SymbolicRegressor be used with pycox's
    supporting functions.

    The wrapped estimator is stored as ``self.model``.
    """

    def __init__(self, model):
        self.model = model

    def fit(self, X, T, E, *args, **kwargs):
        """
        Fit the wrapped model with ``T`` as the target and ``E`` passed as
        ``sample_weight``; any extra arguments are forwarded. Returns ``self``.
        """
        self.model.fit(X, T, *args, sample_weight=E, **kwargs)
        return self

    def predict(self, X, *args, **kwargs):
        """
        Predict with the wrapped model. A torchtuples ``TupleTree`` input is
        reduced to its first element; extra arguments are accepted but ignored.
        """
        features = X[0] if isinstance(X, tt.TupleTree) else X
        return self.model.predict(features)

In [None]:
from gplearn_clean.gplearn.genetic import SymbolicRegressor
# Fit the wrapped symbolic regressor on the survival data loaded above.
# random_state is fixed so that re-running the notebook reproduces the same
# evolution; without it every run draws a different initial population.
# NOTE(review): the best fitness shown in the output is ~5.9, so
# stopping_criteria=0.7 does not appear to trigger early stopping - confirm
# this threshold is intended for the pll_shrink metric.
symreg_rl = SymRegPH(SymbolicRegressor(
    metric=fitness_pll_shrink,
    population_size=500,
    generations=20,
    stopping_criteria=0.7,
    parsimony_coefficient=1e-4,
    feature_names=feature_names,
    random_state=0,
    verbose=True,
)).fit(X, T, E)
symreg_rl.model



    |   Population Average    |             Best Individual              |
---- ------------------------- ------------------------------------------ ----------
 Gen   Length          Fitness   Length          Fitness      OOB Fitness  Time Left
   0    28.25              inf        7          5.92223              N/A     26.35s
   1    12.50              inf        7          5.92143              N/A     26.18s
   2     7.66          7.04097        5          5.89823              N/A     23.39s
   3     5.58          6.16305        9          5.89224              N/A     22.94s
   4     8.25          8.04746       11          5.88795              N/A     21.33s
   5     8.63          9.53506       13          5.87906              N/A     19.32s
   6    11.56          11.1594       23          5.87321              N/A     19.11s
   7    16.38          9.25956       33          5.86582              N/A     17.55s
   8    22.23          7.69193       33          5.86135              N/A  

# 2. WandB Wrapper

In [14]:
import wandb
from gplearn_clean.gplearn.genetic import SymbolicRegressor

class SymbolicSurvivalRegressor(SymbolicRegressor):
    """
    Survival regression using symbolic regression with gplearn.

    Extends ``SymbolicRegressor`` with a ``fit(X, T, E)`` signature and a
    reporter that can send per-generation statistics to Weights and Biases.

    Parameters
    ----------
    logging_console : bool, default=True
        If True, delegate to the parent reporter so gplearn's progress table
        is printed.
    logging_wandb : bool, default=True
        If True, log average/best fitness and length for each generation via
        ``wandb.log``. When no wandb run is active the logging is skipped
        with a warning instead of aborting the fit.

    All remaining keyword arguments are forwarded unchanged to
    ``SymbolicRegressor.__init__``. ``verbose`` is not exposed; it is always
    passed as True (presumably so gplearn calls ``_verbose_reporter`` each
    generation - verify against the gplearn version in use).
    """

    def __init__(self,
                 *,
                 logging_console = True,
                 logging_wandb = True,
                 population_size=1000,
                 generations=20,
                 tournament_size=20,
                 stopping_criteria=0.0,
                 const_range=(-1., 1.),
                 init_depth=(2, 6),
                 init_method='half and half',
                 function_set=('add', 'sub', 'mul', 'div'),
                 metric='mean absolute error',
                 parsimony_coefficient=0.001,
                 p_crossover=0.9,
                 p_subtree_mutation=0.01,
                 p_hoist_mutation=0.01,
                 p_point_mutation=0.01,
                 p_point_replace=0.05,
                 max_samples=1.0,
                 feature_names=None,
                 warm_start=False,
                 low_memory=False,
                 n_jobs=1,
                 random_state=None):

        super().__init__(
            verbose=True,
            population_size=population_size,
            generations=generations,
            tournament_size=tournament_size,
            stopping_criteria=stopping_criteria,
            const_range=const_range,
            init_depth=init_depth,
            init_method=init_method,
            function_set=function_set,
            metric=metric,
            parsimony_coefficient=parsimony_coefficient,
            p_crossover=p_crossover,
            p_subtree_mutation=p_subtree_mutation,
            p_hoist_mutation=p_hoist_mutation,
            p_point_mutation=p_point_mutation,
            p_point_replace=p_point_replace,
            max_samples=max_samples,
            feature_names=feature_names,
            warm_start=warm_start,
            low_memory=low_memory,
            n_jobs=n_jobs,
            random_state=random_state)

        self.logging_console, self.logging_wandb = logging_console, logging_wandb

    def _verbose_reporter(self, run_details=None):
        """
        Report the current generation to the console and wandb.

        Parameters
        ----------
        run_details : dict or None
            Per-generation history lists keyed by 'generation',
            'average_fitness', 'average_length', 'best_fitness' and
            'best_length'. ``None`` carries no statistics, so nothing is
            sent to wandb in that case.
        """
        if self.logging_console:
            super()._verbose_reporter(run_details)

        if not self.logging_wandb or run_details is None:
            return

        # wandb.log raises when no run has been initialised; since
        # logging_wandb defaults to True, that would abort fit() for anyone
        # using the estimator outside a `wandb.init()` context.
        if wandb.run is None:
            import warnings

            warnings.warn(
                "logging_wandb=True but there is no active wandb run; "
                "call wandb.init() before fit() to record metrics. "
                "Skipping wandb logging.",
                RuntimeWarning,
                stacklevel=2,
            )
            return

        # Log the most recent entry of each history list, using the
        # generation number as the wandb step.
        wandb.log({
            'average fitness': run_details['average_fitness'][-1],
            'average length': run_details['average_length'][-1],
            'best fitness': run_details['best_fitness'][-1],
            'best length': run_details['best_length'][-1],
        }, step=run_details['generation'][-1])

    def fit(self, X, T, E, *args, **kwargs):
        """
        Fit with ``T`` as the regression target and ``E`` passed to the
        parent as ``sample_weight``. Returns ``self``.
        """
        super().fit(X, T, sample_weight=E, *args, **kwargs)
        return self

    def predict(self, X, *args, **kwargs):
        """
        Predict for ``X``. A torchtuples ``TupleTree`` is reduced to its
        first element; extra arguments are accepted but ignored.
        """
        if isinstance(X, tt.TupleTree):
            X = X[0]
        return super().predict(X)
    
# Start a new wandb run to track this script.
with wandb.init(entity="steliosbl-cambridge", project="test_proj_1") as run:
    # Hyperparameters shared between the run config and the estimator.
    # random_state is fixed and recorded so a tracked run can be reproduced.
    params = dict(
        population_size=500,
        generations=20,
        stopping_criteria=0.7,
        parsimony_coefficient=1e-2,
        feature_names=feature_names,
        random_state=0,
    )
    run.config.update(params)
    # The fitness object itself is not stored in the config; record its name instead.
    run.config['metric'] = 'pll_shrink'
    symreg_rl = SymbolicSurvivalRegressor(
        metric=fitness_pll_shrink,
        n_jobs = 10,
        **params
    ).fit(X, T, E)



    |   Population Average    |             Best Individual              |
---- ------------------------- ------------------------------------------ ----------
 Gen   Length          Fitness   Length          Fitness      OOB Fitness  Time Left
   0    28.68              inf        7          5.92304              N/A      3.91s
   1     5.56          12.3431       13          5.88604              N/A      9.26s
   2     2.68           6.0331        3          5.88947              N/A      9.77s
   3     1.88          5.98637        3          5.88947              N/A      8.87s
   4     3.07          5.99099        3          5.88947              N/A      9.01s
   5     3.06          6.02913        3          5.88947              N/A      8.31s
   6     2.98          7.03024        3          5.88947              N/A     11.11s
   7     3.08          6.01529        3          5.88947              N/A      7.10s
   8     3.00          6.69549        3          5.88947              N/A  

0,1
average fitness,█▁▁▁▁▂▁▂▁▁▁ ▁ ▁ ▁▁▂
average length,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
best fitness,█▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
best length,▄█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
average fitness,6.98754
average length,2.94
best fitness,5.88947
best length,3.0


In [15]:
from gplearn_clean.gplearn.utils import _partition_estimators

# Inspect gplearn's partitioning helper for the arguments (500, 10), matching
# population_size=500 and n_jobs=10 used above. Per the output below it returns
# 10, ten chunk sizes of 50, and the cumulative start offsets 0..500.
_partition_estimators(500, 10)

(10,
 [50, 50, 50, 50, 50, 50, 50, 50, 50, 50],
 [0, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500])