In [1]:
import sys
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LassoCV, RidgeCV, ElasticNetCV
from sklearn.metrics import mean_squared_error
from tqdm.auto import tqdm
from typing import List
import time

# Put the parent directory on the module search path so the project-local
# imports that follow (RGS, RGS_experimental, data_generator, data_plotting)
# can be resolved -- presumably those modules live one level up; confirm.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate "../" entry each time.
if "../" not in sys.path:
    sys.path.append("../")

from RGS import FastRandomizedGreedySelection, RandomizedGreedySelection
from RGS_experimental import FastRandomizedGreedySelectionCV
from data_generator import *
from data_plotting import *

In [2]:
import numpy as np
from sklearn.model_selection import KFold
from RGS_experimental import FastRandomizedGreedySelectionCV

# ---------------------------------------------------------------------------
# Reproducibility: a single seed feeds the data generator, the CV shuffling
# and the model's own random_state.
# ---------------------------------------------------------------------------
random_seed = 42
rng = np.random.default_rng(random_seed)

# ---------------------------------------------------------------------------
# Synthetic sparse regression problem
# ---------------------------------------------------------------------------
n_samples = 1000
n_features = 20
k_true = 3       # only the first k_true features carry signal
noise_std = 0.1  # scale applied to the standard-normal noise

# Design matrix in sklearn layout: one row per sample, one column per feature.
# X is drawn before the noise on purpose -- both come from the same generator,
# so swapping the two draws would change the data.
X = rng.standard_normal(size=(n_samples, n_features))

# Ground-truth coefficients: zero everywhere except the leading k_true
# entries, whose magnitudes halve from one feature to the next.
true_coef = np.zeros(n_features)
true_coef[:k_true] = [1.0, 0.5, 0.25]

# Response = linear signal plus scaled Gaussian noise.
noise = noise_std * rng.standard_normal(n_samples)
y = noise + X @ true_coef

# ---------------------------------------------------------------------------
# Cross-validated randomized greedy selection
# ---------------------------------------------------------------------------
# k_max and m_grid are forwarded to the estimator unchanged; presumably they
# bound the subset size and list the candidate m values searched by CV --
# confirm against RGS_experimental.
k_max = 10
m_grid = [1, 3, 5, 10, 15, 20]
n_replications = 1000
cv = KFold(n_splits=5, shuffle=True, random_state=random_seed)

model_kwargs = {
    "k_max": k_max,
    "m_grid": m_grid,
    "n_replications": n_replications,
    "random_state": random_seed,
    "cv": cv,
}
model = FastRandomizedGreedySelectionCV(**model_kwargs)

# Fit on the full synthetic data set.
model.fit(X, y)

# ---------------------------------------------------------------------------
# Report: ground truth next to what the fitted model exposes
# ---------------------------------------------------------------------------
print("True coefficients:", true_coef)
print("\nSelected k:", model.k_)
print("Selected m:", model.m_)

# coef_ is indexed first by the selected m, then by the selected k.
# NOTE(review): in the output captured with this cell (k_ = 3, m_ = 20) the
# printed vector has four nonzero entries (indices 0, 1, 2 and 8), one more
# than k_. Confirm how coef_ is indexed/aggregated in RGS_experimental.
selected_coef = model.coef_[model.m_][model.k_]
print("\nLearned coefficients:", selected_coef)

True coefficients: [1.   0.5  0.25 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.  ]

Selected k: 3
Selected m: 20

Learned coefficients: [0.99998206 0.50466162 0.25161512 0.         0.         0.
 0.         0.         0.00560357 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]
