In [1]:
# Record the Python version of the *kernel* for reproducibility.
# `!python -V` would instead report whichever `python` executable is first on
# the shell's PATH, which is not necessarily the interpreter running this kernel.
import platform

print(f"Python {platform.python_version()}")


Python 3.9.6


In [2]:
from __future__ import annotations
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import mean_squared_error, accuracy_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from typing import Dict, Any
import numpy as np

In [4]:


# Registry of candidate models, keyed by a short model name.
# Each entry records:
#   version          - version tag for the configuration
#   estimator        - the estimator class to instantiate
#   estimator_kwargs - keyword arguments passed to the estimator's constructor
#   description      - human-readable summary of the full modelling recipe
# Note: the "StandardScaler" in each description is applied separately to the
# features before fitting; it is not part of the estimator object itself.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = dict(
    linear=dict(
        version="v0.1",
        estimator=LinearRegression,
        estimator_kwargs=dict(),
        description="StandardScaler + LinearRegression",
    ),
    ridge=dict(
        version="v0.2",
        estimator=Ridge,
        estimator_kwargs=dict(alpha=1.0),
        description="StandardScaler + Ridge(alpha=1.0)",
    ),
)

In [5]:
# Load dataset
# Train/test split settings, kept as named constants so the split is
# reproducible and can be tuned in one place.
TEST_SIZE = 0.2
RANDOM_STATE = 42

# Load the diabetes regression dataset. `frame` holds the feature columns plus
# a "target" column; separate it into features X and target y.
Xy = load_diabetes(as_frame=True)
X = Xy.frame.drop(columns=["target"])
y = Xy.frame["target"]

feature_names = list(X.columns)

# Deterministic split on the underlying numpy arrays (fixed random_state).
X_train, X_test, y_train, y_test = train_test_split(
    X.values, y.values, test_size=TEST_SIZE, random_state=RANDOM_STATE
)

# Sanity checks: every row lands in exactly one split, features and targets
# stay aligned, and the column count matches the recorded feature names.
assert len(X_train) + len(X_test) == len(X)
assert len(X_train) == len(y_train) and len(X_test) == len(y_test)
assert X_train.shape[1] == len(feature_names)

In [6]:


# Which MODEL_REGISTRY entry to train and evaluate (e.g. "linear" or "ridge").
MODEL_NAME = "linear"

# Scale features: fit the scaler on the training split only, then apply the
# same transform to both splits so no test-set statistics leak into training.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Look up the model configuration, failing early with the valid choices.
if MODEL_NAME not in MODEL_REGISTRY:
    raise KeyError(
        f"Unknown model {MODEL_NAME!r}; expected one of {sorted(MODEL_REGISTRY)}"
    )
config = MODEL_REGISTRY[MODEL_NAME]
version = config["version"]

# Instantiate the configured estimator and fit it on the scaled training data.
estimator_cls = config["estimator"]
estimator_kwargs = config.get("estimator_kwargs", {})
model = estimator_cls(**estimator_kwargs)
model.fit(X_train_scaled, y_train)

# Evaluate: root-mean-squared error on the held-out test split, as a Python float.
preds = model.predict(X_test_scaled)
rmse = float(np.sqrt(mean_squared_error(y_test, preds)))

In [7]:
# Display the model's test-set RMSE (a Python float assigned in the previous cell).
# NOTE(review): the saved output below shows a tuple of arrays, not a float, so
# it was produced by a different kernel state and is stale. Restart the kernel
# and Run All to regenerate it.
rmse

(array([139.5475584 , 179.51720835, 134.03875572, 291.41702925,
        123.78965872,  92.1723465 , 258.23238899, 181.33732057,
         90.22411311, 108.63375858,  94.13865744, 168.43486358,
         53.5047888 , 206.63081659, 100.12925869, 130.66657085,
        219.53071499, 250.7803234 , 196.3688346 , 218.57511815,
        207.35050182,  88.48340941,  70.43285917, 188.95914235,
        154.8868162 , 159.36170122, 188.31263363, 180.39094033,
         47.99046561, 108.97453871, 174.77897633,  86.36406656,
        132.95761215, 184.53819483, 173.83220911, 190.35858492,
        124.4156176 , 119.65110656, 147.95168682,  59.05405241,
         71.62331856, 107.68284704, 165.45365458, 155.00975931,
        171.04799096,  61.45761356,  71.66672581, 114.96732206,
         51.57975523, 167.57599528, 152.52291955,  62.95568515,
        103.49741722, 109.20751489, 175.64118426, 154.60296242,
         94.41704366, 210.74209145, 120.2566205 ,  77.61585399,
        187.93203995, 206.49337474, 140.