In [None]:
import numpy as np
import pandas as pd
from scipy.stats import special_ortho_group
import matplotlib.pyplot as plt

import hisel

In [None]:
# Configuration of the synthetic feature-selection problem.
# Seed NumPy's global generator so the data (and scipy's special_ortho_group draws,
# which fall back to NumPy's global state when random_state is None) are reproducible.
# NOTE(review): hisel is later called with random_state=None; whether it also draws
# from NumPy's global state is not visible here — TODO confirm.
SEED = 0
np.random.seed(SEED)

n = 5000       # number of samples (rows of the design matrix)
n_cat = 10     # number of categorical features
n_cont = 30    # number of continuous features
n_relcat = 2   # categorical features that influence y
n_relcont = 8  # continuous features that influence y
dim_y = 1      # number of target columns

# The relevant features are carved out of the full feature sets by zero-padding, and
# scipy's special_ortho_group requires a rotation dimension greater than 1.
assert 0 <= n_relcat <= n_cat, "n_relcat must be in [0, n_cat]"
assert 1 < n_relcont <= n_cont, "n_relcont must be in (1, n_cont]"

In [None]:
# Number of levels per categorical feature, drawn from {5, 6, 7} (high is exclusive).
ms = np.random.randint(low=5, high=8, size=n_cat)
# One (n, 1) column of integer codes per feature; feature j takes values in [0, ms[j]).
cats = [np.random.randint(n_levels, size=(n, 1)) for n_levels in ms]
cat = np.hstack(cats)           # (n, n_cat)
cat_ = cat[:, :, np.newaxis]    # (n, n_cat, 1): per-sample column vectors for batched matmul
catdf = pd.DataFrame(cat, columns=[f'cat{j}' for j in range(n_cat)])

In [None]:
# Sparse sign matrix picking out the relevant categorical features: a random +/-1 on the
# diagonal of the first n_relcat columns, zero-padded to n_cat columns, then the columns
# are shuffled so the relevant features sit at random positions.
signs = np.diag(np.random.choice([-1, 1], size=n_relcat))
padded = np.concatenate(
    (signs, np.zeros(shape=(n_relcat, n_cat - n_relcat), dtype=int)),
    axis=1,
)
acat = np.random.permutation(padded.T).T  # shuffling rows of the transpose == shuffling columns
# Signed column sums as a (1, 1, n_cat) row, ready to broadcast against cat_.
tcat = acat.sum(axis=0, keepdims=True)[np.newaxis]
# The columns of acat holding a nonzero entry are exactly the relevant categorical features.
relevant_cats = np.flatnonzero(np.abs(acat).sum(axis=0))

In [None]:
# i.i.d. U(-1, 1) continuous features.
cont = np.random.uniform(-1, 1, (n, n_cont))   # (n, n_cont)
cont_ = cont[..., np.newaxis]                  # (n, n_cont, 1) for batched matmul
contdf = pd.DataFrame(cont, columns=[f'cont{j}' for j in range(n_cont)])

In [None]:
# Full design matrix: categorical columns followed by continuous columns, aligned on the row index.
xdf = catdf.join(contdf, how='inner')

In [None]:
# Two independent random rotations of the relevant continuous subspace, one per regime.
u1 = special_ortho_group.rvs(n_relcont)
u2 = special_ortho_group.rvs(n_relcont)
# Selector: identity on n_relcont columns, zero elsewhere, with the columns shuffled.
selector = np.hstack((np.eye(n_relcont), np.zeros((n_relcont, n_cont - n_relcont))))
acont = np.random.permutation(selector.T).T
ct1 = (u1 @ acont)[np.newaxis]   # (1, n_relcont, n_cont)
ct2 = (u2 @ acont)[np.newaxis]   # (1, n_relcont, n_cont)
# Positions in xdf: continuous columns come after the n_cat categorical ones.
relevant_conts = n_cat + np.flatnonzero(np.abs(acont).sum(axis=0))

In [None]:
# Positional indices (into xdf's columns) of every feature that influences y.
all_relevant = np.sort(np.hstack((relevant_cats, relevant_conts)))
# Translate positions to column names; lists are sorted lexicographically for comparison.
column_names = xdf.columns
relevant_cat_features = sorted(column_names[relevant_cats].tolist())
relevant_cont_features = sorted(column_names[relevant_conts].tolist())
relevant_features = sorted(column_names[all_relevant].tolist())

In [None]:
# Weights mapping the n_relcont mixed features to dim_y targets; the leading axis broadcasts over samples.
t = np.random.uniform(-1, 1, size=(1, dim_y, n_relcont))

In [None]:
# Split the samples into two regimes by thresholding a signed sum of the relevant
# categorical features at its median.
chooser = tcat @ cat_            # (1, 1, n_cat) @ (n, n_cat, 1) -> (n, 1, 1)
q = np.quantile(chooser, .5)
# Each regime sees the relevant continuous features through its own rotation.
# np.where makes the regime choice explicit (no reliance on `*`/`@` precedence) and
# avoids materialising masked (n, n_relcont, n_cont) copies of ct1 and ct2.
regime_features = np.where(chooser > q, ct1 @ cont_, ct2 @ cont_)   # (n, n_relcont, 1)
# Drop only the trailing singleton axis so y is always (n, dim_y), whatever dim_y is.
y = np.squeeze(t @ regime_features, axis=-1)

In [None]:
# Target frame with columns y0, y1, ..., one per target dimension.
ydf = pd.DataFrame(y, columns=[f'y{k}' for k in range(dim_y)])

## KSG mutual-information selection

In [None]:
# Threshold handed to hisel's KSG-based selector (presumably a cut-off on the estimated MI — TODO confirm).
ksg_threshold = 1e-4
ksgfeatures, ksgmis = hisel.select.ksgmi(xdf, ydf, threshold=ksg_threshold)

In [None]:
# Compare the KSG selection with the ground-truth relevant features.
expected = sorted(relevant_features)
selected = sorted(ksgfeatures)
leftout = sorted(set(expected) - set(selected))   # relevant but not selected
print(f'Expected features:\n{expected}')
print(f'Selected features:\n{selected}')
print(f'Left-out features:\n{leftout}')

## HISEL selection (HSIC Lasso + categorical search)

In [None]:
# Parameters for hisel's search over the categorical features.
# NOTE(review): the meaning of each field is defined by hisel — confirm against its docs.
categorical_search_params = hisel.feature_selection.SearchParameters(
    random_state=None,  # no explicit seed is passed to hisel
    parallel=True,
    max_iter=1,
    im_ratio=0.01,
    num_permutations=20,
)

In [None]:
# Parameters for hisel's HSIC Lasso stage (presumably the continuous-feature stage).
# NOTE(review): field semantics are defined by hisel — confirm against its docs.
hsiclasso_params = hisel.feature_selection.HSICLassoParameters(
    device=None,
    use_preselection=False,
    number_of_epochs=4,
    minibatch_size=500,
    batch_size=5000,  # NOTE(review): equals n in this notebook; revisit if n changes
    hsic_threshold=0.01,
)

In [None]:
# Run hisel's combined selection on the full design matrix and target.
selection = hisel.feature_selection.select_features(
    xdf, ydf, hsiclasso_params, categorical_search_params,
)
# Per-stage results exposed by the returned selection object.
cat_selection = selection.categorical_search_selection
hsic_selection = selection.continuous_lasso_selection

In [None]:
# Sorted feature names chosen by each stage and by the combined selection.
selected_cat_features, selected_cont_features, selected_features = [
    sorted(names)
    for names in (cat_selection.features, hsic_selection.features, selection.selected_features)
]

In [None]:
# Relevant categorical features that the categorical search missed.
leftout_cat = sorted(set(relevant_cat_features) - set(selected_cat_features))
print(f'Relevant cat features:\n{relevant_cat_features}')
print(f'Selected cat features:\n{selected_cat_features}')
print(f'Left-out cat features:\n{leftout_cat}')

In [None]:
# Relevant continuous features that the HSIC Lasso stage missed.
leftout_cont = sorted(set(relevant_cont_features) - set(selected_cont_features))
print(f'Relevant cont features:\n{relevant_cont_features}')
print(f'Selected cont features:\n{selected_cont_features}')
print(f'Left-out cont features:\n{leftout_cont}')

In [None]:
# Ground truth vs. the combined selection, each on its own block of lines.
print(
    f'All relevant features:\n{relevant_features}',
    f'Selected features:\n{selected_features}',
    sep='\n',
)

You can explore how the selection threshold affects which continuous features are chosen:

In [None]:
# Re-run the continuous selection along the lasso path at a different threshold; the
# return value (if any) is displayed as the cell output.
# NOTE(review): whether this also updates hsic_selection in place is not visible here — confirm.
hsic_selection.select_from_lasso_path(threshold=0.025)

You can visualise the regularisation curve used to select the continuous features:

In [None]:
# Plot the regularisation curve of the HSIC Lasso stage against a 1-based position.
curve = hsic_selection.regcurve
fig, ax = plt.subplots()
ax.plot(np.arange(1, 1 + len(curve)), curve)
ax.set(
    title='HSIC Lasso regularisation curve',
    xlabel='Position along the curve (1-based)',
    ylabel='regcurve value',
);  # trailing ';' suppresses the repr of ax.set's return value