# Tests: Fixed Equation Forms
This notebook tests SHARE fitting given fixed programs (equation forms). Do the shape functions fit properly? 

## 0. Setup

In [None]:
import numpy as np 
import pandas as pd 
import torch

import matplotlib.pyplot as plt 
import seaborn as sns
sns.set_style("whitegrid")

from gplearn.gplearn._program import _Program 
import survshares.lightning # SBL Patch

%load_ext autoreload
%autoreload 2

In [6]:
from gplearn.gplearn.model import ShapeNN
from gplearn.gplearn.fitness import mean_square_error
from gplearn.gplearn.functions import (
    add2, sub2, mul2, div2, shape1
)
from sklearn.utils.validation import check_random_state


# Keyword arguments shared by every `_Program` built in this notebook; each
# test unpacks them via `_Program(program=..., **share_params)`.
share_params = dict(
    # Only the four binary arithmetic operators are registered; `arities`
    # lists the same operators under arity 2.
    function_set=[add2, sub2, mul2, div2],
    arities={2: [add2, sub2, mul2, div2]},
    init_depth=(2, 6),
    init_method='half and half',
    n_features=10,
    const_range=(-1.0, 1.0),
    metric=mean_square_error,
    p_point_replace=0.05,
    parsimony_coefficient=0.1,
    # Optimiser settings -- presumably consumed when the shape functions are
    # fitted (TODO confirm against `_Program.raw_fitness`).
    optim_dict=dict(
        alg="adam",
        lr=1e-2,  # tuned automatically
        max_n_epochs=1000,
        tol=1e-3,
        task='regression',
        device='cpu',
        batch_size=1000,
        shape_class=ShapeNN,
        # Constructor keyword arguments for `shape_class`.
        constructor_dict=dict(
            n_hidden_layers=5,
            width=10,
            activation_name="ELU",
        ),
        num_workers_dataloader=0,
        seed=42,
        checkpoint_folder="results/checkpoints/test",
        keep_models=False,
        enable_progress_bar=False,
    ),
    # A single RandomState instance is created once and shared: anything that
    # draws from it (e.g. `test_program`) advances its state.
    random_state=check_random_state(415),
)

## 1. No Shape Functions

We first test a simple program with no shape functions as a sanity check.

In [None]:
from sklearn.utils._testing import assert_array_almost_equal

def test_program():
    """Sanity check: a program without shape functions matches its closed form.

    The fixed program is written in prefix order and encodes
    ``(X[:, 8] / X[:, 1]) * (X[:, 9] - 0.5)``; its output on a random
    5 x 10 input is compared against that expression evaluated with NumPy.
    """
    # Prefix-order token list: mul2(div2(X8, X1), sub2(X9, 0.5))
    program_tokens = [mul2, div2, 8, 1, sub2, 9, .5]

    # Draw the inputs from the shared RandomState before building the program.
    rng = share_params['random_state']
    samples = rng.uniform(size=50).reshape(5, 10)
    program = _Program(program=program_tokens, **share_params)

    actual = program.execute(torch.Tensor(samples))
    quotient = samples[:, 8] / samples[:, 1]
    shifted = samples[:, 9] - 0.5
    assert_array_almost_equal(actual, quotient * shifted)

test_program()

## 2. Univariate Equations 

We check the shapes learned by fitting these univariate equations:

1. $y=x^2$
2. $y=x^3$
3. $y=x+5$
4. $y=x$ (thus $s(x) = x$)
5. $y=\sin(\pi x)$
6. $y=cos(exp(tan(x)))$ 

In [None]:
%%capture
def test_shares_univariate(seed=0):
    """Fit the one-shape program ``shape1(x_0)`` to several univariate targets.

    For each target function a fresh ``_Program`` is fitted on 1000 sorted
    training points in [0, 1) and evaluated on 1000 sorted test points; the
    prediction and the ground truth are drawn on one subplot per target.

    Parameters
    ----------
    seed : int or None, default 0
        Seed for the generator that draws the train/test inputs. Pass ``None``
        to draw fresh, non-reproducible inputs.

    Returns
    -------
    fig : matplotlib.figure.Figure
        2 x 3 grid with one panel per target.
    results : dict
        Maps each target name to the mean absolute error on the test inputs.
    """
    test_share = [shape1, 0]

    # Local generator so the data do not depend on global NumPy random state.
    rng = np.random.default_rng(seed)
    X_train = np.sort(rng.random((1000, 1)), axis=0)
    X_test = np.sort(rng.random((1000, 1)), axis=0)

    # Keys double as LaTeX panel titles, hence the raw string for `\pi`.
    tests = {
        'y=x^2': lambda x: x ** 2,
        'y=x^3': lambda x: x ** 3,
        'y=x+5': lambda x: x + 5,
        'y=x': lambda x: x,
        r'y=sin(x \pi)': lambda x: np.sin(x*np.pi),
        'y=cos(exp(tan(x)))': lambda x: np.cos(np.exp(np.tan(x))),
    }

    fig, axes = plt.subplots(2, 3, figsize=(10, 5))

    results = {}
    for ax, (name, func) in zip(axes.flatten(), tests.items()):
        y_train, y_test = func(X_train[:, 0]), func(X_test[:, 0])
        gp = _Program(program=test_share, **share_params)
        # raw_fitness is called for its side effect of fitting the program.
        gp.raw_fitness(torch.Tensor(X_train), torch.Tensor(y_train), None)
        y_pred = gp.execute(torch.Tensor(X_test))
        results[name] = np.abs(y_pred - y_test).mean()

        ax.plot(X_test, y_pred, label='Predicted', color='blue')
        ax.plot(X_test, y_test, label='True', color='red')
        ax.set_title(rf'${name}$ : $\pm {results[name]:.2f}$')

    # Every panel carries the same two labelled lines; take them from the
    # first panel explicitly rather than from the leaked loop variable.
    handles, labels = axes.flat[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=2)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    return fig, results

fig_univariate, results_univariate = test_shares_univariate()
fig_univariate

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LR finder stopped early after 96 steps due to diverging loss.
Learning rate set to 0.02089296130854041
Restoring states from the checkpoint path at data/lightning_logs/.lr_find_f84bd5f9-2781-421e-831a-0ccb89ef2c4d.ckpt
Restored all states from the checkpoint at data/lightning_logs/.lr_find_f84bd5f9-2781-421e-831a-0ccb89ef2c4d.ckpt
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LR finder stopped early after 96 steps due to diverging loss.
Learning rate set to 0.05248074602497723
Restoring states from the checkpoint path at data/lightning_logs/.lr_find_89f3d0b7-b8f9-4684-b01e-e2c2c1b9cefa.ckpt
Restored all states from the checkpoint at data/lightning_logs/.lr_find_89f3d0b7-b8f9-4684-b01e-e2c2c1b9cefa.ckpt
G

## 3. Multivariate Equations

Next we check the following equations with more variables:

1. $y=x_1 + x_2$ (thus $s(x) = x$)
2. $y=x_1^2 + \sin(\pi x_2)$
3. $y=s(x_1) + s(x_2) + s(x_3)$ (risk study)

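We observe that the learned shape functions are vertically translated from the ground truth - this is expected: in an additive model only the sum of the shape functions is constrained by the data, so a constant can be moved from one shape function to another without changing the prediction. We therefore compare each learned shape to its ground truth by correlation rather than by absolute error.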

In [None]:
%%capture
def test_shares_bivariate(seed=0):
    """Fit ``shape1(x_0) + shape1(x_1)`` and compare each shape to its truth.

    For every additive target a fresh ``_Program`` is fitted on 1000 random
    points in [0, 1)^2. Each learned shape function is then evaluated on the
    sorted test values of its own input column and compared with the
    corresponding ground-truth component via the Pearson correlation.

    Parameters
    ----------
    seed : int or None, default 0
        Seed for the generator that draws the train/test inputs. Pass ``None``
        to draw fresh, non-reproducible inputs.

    Returns
    -------
    fig : matplotlib.figure.Figure
        2 x 2 grid: one row per target, one column per shape function.
    results : dict
        Maps ``f'{name}_{i}'`` to the correlation between the i-th learned
        shape and the i-th ground-truth component.
    """
    test_share = [add2, shape1, 0, shape1, 1]

    # Local generator so the data do not depend on global NumPy random state.
    rng = np.random.default_rng(seed)
    X_train, X_test = rng.random((1000, 2)), rng.random((1000, 2))

    # Keys double as LaTeX panel titles, hence the raw string for `\pi`.
    tests = {
        'y=x_1 + x_2': [
            lambda x: x,
            lambda x: x
        ],
        r'y=x_1^2 + sin(x_2 \pi)': [
            lambda x: x ** 2,
            lambda x: np.sin(x * np.pi)
        ],
    }

    fig, axes = plt.subplots(2, 2, figsize=(10, 5))
    results = {}

    for ax_row, (name, functions) in zip(axes, tests.items()):
        y_train = functions[0](X_train[:, 0]) + functions[1](X_train[:, 1])

        gp = _Program(program=test_share, **share_params)
        # raw_fitness is called for its side effect of fitting the program.
        gp.raw_fitness(torch.Tensor(X_train), torch.Tensor(y_train), None)
        # Result unused; kept as a smoke check that the fitted program runs
        # on held-out inputs.
        gp.execute(torch.Tensor(X_test))

        # NOTE(review): assumes gp.model.shape_functions is ordered like the
        # shape1 nodes in `test_share` (x_0 first, then x_1) -- confirm.
        for i, (ax, func, shape, x_col) in enumerate(zip(ax_row, functions, gp.model.shape_functions, X_test.T)):
            # Sort once and evaluate BOTH curves on the sorted inputs so the
            # plotted x and y values stay paired.
            xi = np.sort(x_col)
            s_true = func(xi)
            shape.to(torch.device('cpu'))
            with torch.no_grad():
                s_pred = shape(torch.Tensor(xi)).flatten().numpy()

            result = np.corrcoef(s_pred, s_true)[0, 1]
            results[f'{name}_{i}'] = result
            ax.plot(xi, s_pred, label='Shape', color='blue')
            ax.plot(xi, s_true, label='Ground Truth', color='red')
            ax.set_title(f'${name}: s_{i+1}$ : $Corr={result:.2f}$')

    handles, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=2)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    return fig, results

fig_bivariate, results_bivariate = test_shares_bivariate()
fig_bivariate

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LR finder stopped early after 96 steps due to diverging loss.
Learning rate set to 0.008317637711026709
Restoring states from the checkpoint path at data/lightning_logs/.lr_find_fc121a04-1981-4567-b3e0-a6ff28e42477.ckpt
Restored all states from the checkpoint at data/lightning_logs/.lr_find_fc121a04-1981-4567-b3e0-a6ff28e42477.ckpt
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LR finder stopped early after 97 steps due to diverging loss.
Learning rate set to 0.008317637711026709
Restoring states from the checkpoint path at data/lightning_logs/.lr_find_cea1ea0e-754a-434d-81a5-f6aeacb4712a.ckpt
Restored all states from the checkpoint at data/lightning_logs/.lr_find_cea1ea0e-754a-434d-81a5-f6aeacb4712a.ckpt
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

In [None]:
%%capture
from experiments.risk_scores_data import get_nodes_feature, get_age_feature, get_bmi_feature, generate_data
from sklearn.model_selection import train_test_split

def get_dataset():
    """Build the risk-score dataset and split it 80/20 into train and test.

    Calls ``generate_data(1000, seed=42)``, uses the ``target`` column as the
    label and every other column as a feature.

    Returns
    -------
    X_train, X_test, y_train, y_test
        The four arrays returned by ``train_test_split`` (fixed
        ``random_state=42``).
    """
    frame = generate_data(1000, seed=42)
    target = frame['target'].values
    features = frame.drop(columns=['target']).values
    return train_test_split(features, target, test_size=0.2, random_state=42)

def test_shares_risk():
    """Fit a three-shape additive program to the risk-score data.

    The program ``shape1(x_0) + shape1(x_1) + shape1(x_2)`` is fitted on the
    training split from ``get_dataset``. Each learned shape function is then
    evaluated on an evenly spaced grid over its feature's range and compared
    with the corresponding ground-truth feature function via the Pearson
    correlation.

    Returns
    -------
    fig : matplotlib.figure.Figure
        1 x 3 grid with one panel per feature.
    results : dict
        Maps ``f'{name}_{i}'`` to the correlation between the i-th learned
        shape and its ground-truth feature function.
    """
    test_share = [add2, add2, shape1, 0, shape1, 1, shape1, 2]
    X_train, X_test, y_train, _ = get_dataset()

    # name -> (factory returning the ground-truth function,
    #          (start, stop, num) arguments for np.linspace).
    # NOTE(review): assumes the feature columns of X are ordered
    # nodes, age, bmi -- confirm against generate_data.
    features = {
        'nodes': (
            get_nodes_feature,
            (0, 50, 1000)
        ),
        'age': (
            get_age_feature,
            (45, 70, 1000)
        ),
        'bmi': (
            get_bmi_feature,
            (17, 45, 1000)
        )
    }

    gp = _Program(program=test_share, **share_params)
    # raw_fitness is called for its side effect of fitting the program.
    gp.raw_fitness(torch.Tensor(X_train), torch.Tensor(y_train), None)
    # Result unused; kept as a smoke check that the fitted program runs on
    # held-out inputs.
    gp.execute(torch.Tensor(X_test))

    # `axes` (the array) is kept distinct from the per-panel `ax` below.
    fig, axes = plt.subplots(1, 3, figsize=(12, 5))
    results = {}

    for i, (ax, shape, (name, (func, ranges))) in enumerate(zip(axes, gp.model.shape_functions, features.items())):
        x_true = np.linspace(*ranges)
        s_true = func()(x_true)
        shape.to(torch.device('cpu'))
        with torch.no_grad():
            # Evaluate the learned shape on the same grid as the ground truth.
            s_pred = shape(torch.Tensor(x_true)).flatten().numpy()

        result = np.corrcoef(s_pred, s_true)[0, 1]
        results[f'{name}_{i}'] = result
        ax.plot(x_true, s_pred, label='Shape', color='blue')
        ax.plot(x_true, s_true, label='Ground Truth', color='red')
        ax.set_title(f'$s_{i}(${name}$)$ : Corr=${result:.2f}$')

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=2)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    return fig, results

fig_risk, results_risk = test_shares_risk()
fig_risk

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LR finder stopped early after 92 steps due to diverging loss.
Learning rate set to 0.04365158322401657
Restoring states from the checkpoint path at data/lightning_logs/.lr_find_8ed54f9c-7963-4df1-8e57-f4fee6248e20.ckpt
Restored all states from the checkpoint at data/lightning_logs/.lr_find_8ed54f9c-7963-4df1-8e57-f4fee6248e20.ckpt
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
