In [1]:
import sys
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LassoCV, RidgeCV, ElasticNetCV
from sklearn.metrics import mean_squared_error
from tqdm.auto import tqdm
from typing import List
import time

sys.path.append("../")

from RGS import FastRandomizedGreedySelection, RandomizedGreedySelection
from RGS_experimental import FastRandomizedGreedySelectionCV
from data_generator import *
from data_plotting import *

In [2]:
import numpy as np
from sklearn.model_selection import KFold
from RGS_experimental import FastRandomizedGreedySelectionCV

# Demo: fit FastRandomizedGreedySelectionCV on synthetic sparse linear data and
# compare the coefficients it reports with the known ground truth.

# One seed shared by the data generator, the CV splitter and the model, so the
# three sources of randomness cannot drift apart when the seed is changed.
SEED = 42
rng = np.random.default_rng(SEED)

# Generate synthetic data
n_samples = 1000
n_features = 20
# Non-zero coefficients of the leading features, in decreasing importance.
RELEVANT_COEF = [1.0, 0.5, 0.25]
# Derived from the list so the slice assignment below always has a matching
# length (a hand-maintained count would raise a broadcast error if it drifted).
k_true = len(RELEVANT_COEF)  # Only the first k_true features are relevant
# Scale applied to the standard-normal noise added to the target.
NOISE_STD = 0.1

# Generate features in sklearn convention: (n_samples, n_features)
X = rng.standard_normal((n_samples, n_features))

# Generate target with only the first k_true features being relevant
true_coef = np.zeros(n_features)
true_coef[:k_true] = RELEVANT_COEF
noise = rng.standard_normal(n_samples) * NOISE_STD
y = X @ true_coef + noise

# Initialize model
# NOTE(review): the exact meaning of k_max / m_grid / n_replications is defined
# in RGS_experimental, which is not visible here; the cell output shows that
# fit() reports an MSE for each (k, m) combination across the CV folds.
k_max = 10
m_grid = [1, 3, 5, 10, 15, 20]
cv = KFold(n_splits=5, shuffle=True, random_state=SEED)
model = FastRandomizedGreedySelectionCV(
    k_max=k_max,
    m_grid=m_grid,
    n_replications=1000,
    random_state=SEED,
    cv=cv,
)

# Fit model (prints per-fold progress and the MSE table)
model.fit(X, y)

# Print results: ground truth next to the selected (k, m) and its coefficients
print("True coefficients:", true_coef)
print("\nSelected k:", model.k_)
print("Selected m:", model.m_)
# coef_ is indexed first by m and then by k, as in the original cell.
print("\nLearned coefficients:", model.coef_[model.m_][model.k_])

Starting cross-validation...

Fold 1/5

Fold 2/5

Fold 3/5

Fold 4/5

Fold 5/5

MSE scores for each (k,m) combination:

k=1:
  m= 1: 1.052463 ± 0.092810
  m= 3: 1.335862 ± 0.140009
  m= 5: 1.052463 ± 0.092810
  m=10: 1.052463 ± 0.092810
  m=15: 0.325863 ± 0.018843
  m=20: 0.325863 ± 0.018843

k=2:
  m= 1: 0.947816 ± 0.083991
  m= 3: 0.937409 ± 0.101480
  m= 5: 0.601409 ± 0.054413
  m=10: 0.267409 ± 0.024639
  m=15: 0.064859 ± 0.003508
  m=20: 0.073128 ± 0.002926

k=3:
  m= 1: 1.297012 ± 0.135025
  m= 3: 1.312033 ± 0.126087
  m= 5: 1.297287 ± 0.157340
  m=10: 1.301559 ± 0.162221
  m=15: 1.297975 ± 0.130966
  m=20: 0.010082 ± 0.000761

k=4:
  m= 1: 1.333053 ± 0.137567
  m= 3: 1.333476 ± 0.138066
  m= 5: 1.333301 ± 0.139155
  m=10: 1.331618 ± 0.138071
  m=15: 1.329814 ± 0.136228
  m=20: 0.010178 ± 0.000763

k=5:
  m= 1: 1.334059 ± 0.138694
  m= 3: 1.333782 ± 0.138544
  m= 5: 1.332780 ± 0.139853
  m=10: 1.332451 ± 0.138072
  m=15: 1.330223 ± 0.136964
  m=20: 0.010190 ± 0.000768

k=6:
  m= 