In [1]:
%load_ext autoreload
%autoreload 2
import config as cfg
import numpy as np
import pandas as pd
import logging
from sca import plots, helpers as h
from sklearn.pipeline import make_pipeline

In [2]:
X, y, pts, ks = h.load_data(cfg.DATA_DIR / 'ascadv_clean.h5', as_df=True)
best_feats_idx = np.load('data/rf_rfe_5_best.npy')
X_best = X.iloc[:, best_feats_idx[:5]]

In [39]:
from scipy import fft
from sklearn.base import TransformerMixin
from dataclasses import dataclass, field

def _feat_window(X, f, n, k):
    i = list(n).index(f)
    return X[:, i-k:i+k]

@dataclass
class FFTFeaturizer(TransformerMixin):
    feats: list
    names: list
    k: int = 8
    n: int = 8
    
    def _fft_feat(self, W):
        freqs = fft.rfft(W, axis=1)
        return np.abs(freqs)[:, 1:self.n+1]

    def fit(self, X, y):
        return self

    def transform(self, X):
        X = np.asarray(X, dtype=np.float32)
        windows = [_feat_window(X, f, self.names, self.k) 
                   for f in self.feats]
        return np.concat([self._fft_feat(W) for W in windows], axis=1)

# FFTFeaturizer(["187", "1070"], X.columns, k=16, n=8).transform(X.values)

In [47]:
from sklearn.preprocessing import RobustScaler, PolynomialFeatures, FunctionTransformer
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.pipeline import FeatureUnion

feats = FeatureUnion([
    ("fft", FFTFeaturizer(["187", "1070"], X_best.columns, k=16, n=3)),
    ("poly", make_pipeline(
        RobustScaler(quantile_range=(5, 95)),
        PolynomialFeatures(degree=(2, 2), include_bias=False)
    ))
])
rf = RFC(max_depth=5, min_samples_leaf=10, n_jobs=-1, random_state=cfg.SEED)
pl_rf = make_pipeline(feats, rf)

score = np.mean(h.cv(pl_rf, X_best, y, pts, ks, seed=cfg.SEED))
logging.info(f"Mean PI: {score:.2e}")

21:47:32: [1] Mean PI: 0.027
21:47:37: [2] Mean PI: 0.028
21:47:42: [3] Mean PI: 0.027
21:47:47: [4] Mean PI: 0.026
21:47:52: [5] Mean PI: 0.027
21:47:57: [6] Mean PI: 0.028
21:48:02: [7] Mean PI: 0.027
21:48:07: [8] Mean PI: 0.026
21:48:12: [9] Mean PI: 0.027
21:48:17: [10] Mean PI: 0.028
21:48:17: Mean PI: 2.72e-02
