Skip to content

Commit d65b52a

Browse files
feat(model): add linear regression training with comprehensive metrics and validation
1 parent e2ab547 commit d65b52a

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

model.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""Model training module for Linear Regression.
2+
3+
This module provides functionality to train a scikit-learn LinearRegression model
4+
on preprocessed data and calculate comprehensive regression metrics. The module
5+
is designed to work with data that has already been preprocessed (scaled and imputed)
6+
by the preprocessing pipeline.
7+
"""
8+
9+
import pandas as pd
10+
import numpy as np
11+
from sklearn.linear_model import LinearRegression
12+
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
13+
from typing import Dict, Tuple
14+
15+
16+
def train_model(X: pd.DataFrame, y: pd.Series) -> Tuple[LinearRegression, Dict[str, float]]:
    """
    Fit an ordinary-least-squares model and score it on its own training data.

    The inputs are first checked by ``_validate_inputs``. A scikit-learn
    ``LinearRegression`` is then fitted on ``X`` and ``y``, predictions are
    produced for the same ``X``, and four regression metrics are derived
    from those in-sample predictions.

    Parameters
    ----------
    X : pd.DataFrame
        Feature matrix, one row per sample and one column per feature.
        Expected to be preprocessed already (no missing values).
    y : pd.Series
        Target values, one per row of ``X``.

    Returns
    -------
    Tuple[LinearRegression, Dict[str, float]]
        ``(model, metrics)`` where ``model`` is the fitted estimator and
        ``metrics`` maps:

        - ``'r2'``   -> coefficient of determination
        - ``'mse'``  -> mean squared error
        - ``'rmse'`` -> square root of ``'mse'``
        - ``'mae'``  -> mean absolute error

        Every metric value is converted to a built-in ``float``.

    Raises
    ------
    TypeError
        If ``X`` is not a DataFrame or ``y`` is not a Series.
    ValueError
        If the inputs are empty, contain NaN, disagree on the number of
        samples, or contain fewer than two samples.

    Examples
    --------
    >>> import pandas as pd
    >>> from sklearn.datasets import make_regression
    >>> X_array, y_array = make_regression(
    ...     n_samples=100, n_features=3, noise=10, random_state=42
    ... )
    >>> X = pd.DataFrame(X_array, columns=['feature1', 'feature2', 'feature3'])
    >>> y = pd.Series(y_array, name='target')
    >>> model, metrics = train_model(X, y)
    >>> sorted(metrics)
    ['mae', 'mse', 'r2', 'rmse']

    Notes
    -----
    - All metrics are in-sample (computed on the data used for fitting),
      so they say nothing about generalisation to unseen data.
    - No file I/O is performed here.
    """
    # Fail fast with a descriptive error before handing data to scikit-learn.
    _validate_inputs(X, y)

    # ``fit`` returns the estimator itself, so construction and fitting chain.
    fitted = LinearRegression().fit(X, y)

    # In-sample predictions: the same rows that were used for fitting.
    predictions = fitted.predict(X)

    # MSE is computed once and reused for RMSE.
    mean_sq_err = mean_squared_error(y, predictions)

    scores = {
        'r2': float(r2_score(y, predictions)),
        'mse': float(mean_sq_err),
        'rmse': float(np.sqrt(mean_sq_err)),
        'mae': float(mean_absolute_error(y, predictions)),
    }

    return fitted, scores
110+
111+
112+
def _validate_inputs(X: pd.DataFrame, y: pd.Series) -> None:
    """
    Validate input data for model training.

    Checks are performed in this order, and the first failure raises:

    1. ``X`` is a pandas DataFrame and ``y`` is a pandas Series.
    2. ``X`` has at least one row, ``y`` has at least one value, and ``X``
       has at least one column.
    3. Neither ``X`` nor ``y`` contains NaN values.
    4. ``X`` and ``y`` have the same number of samples.
    5. There are at least 2 samples.

    Parameters
    ----------
    X : pd.DataFrame
        Features DataFrame to validate.
    y : pd.Series
        Target Series to validate.

    Raises
    ------
    TypeError
        If X is not a pandas DataFrame or y is not a pandas Series.
    ValueError
        If X or y are empty, X has no columns, either contains NaN values,
        their sample counts differ, or fewer than 2 samples are provided.
    """
    # Check types
    if not isinstance(X, pd.DataFrame):
        raise TypeError(
            f"X must be a pandas DataFrame, got {type(X).__name__} instead."
        )

    if not isinstance(y, pd.Series):
        raise TypeError(
            f"y must be a pandas Series, got {type(y).__name__} instead."
        )

    n_rows, n_cols = X.shape

    # Check if empty.
    # NOTE: ``DataFrame.empty`` is True when *either* axis has length 0, so
    # it cannot tell "no rows" apart from "no columns". Inspect the row count
    # directly so the dedicated no-columns check below stays reachable and
    # each condition reports an accurate message.
    if n_rows == 0:
        raise ValueError("X DataFrame is empty (no rows).")

    if len(y) == 0:
        raise ValueError("y Series is empty (no values).")

    if n_cols == 0:
        raise ValueError("X DataFrame has no columns (no features).")

    # Check for NaN values; the per-column mask is computed once and reused
    # both for the test and for naming the offending columns.
    nan_mask = X.isna().any()
    if nan_mask.any():
        nan_columns = X.columns[nan_mask].tolist()
        raise ValueError(
            f"X contains NaN values. Columns with NaN: {nan_columns}. "
            "Please preprocess the data to handle missing values."
        )

    if y.isna().any():
        raise ValueError(
            "y contains NaN values. Please preprocess the data to handle missing values."
        )

    # Check shape matching
    if n_rows != len(y):
        raise ValueError(
            f"Shape mismatch: X has {n_rows} samples but y has {len(y)} samples. "
            "X and y must have the same number of samples."
        )

    # Check for at least 2 samples (minimum for regression)
    if n_rows < 2:
        raise ValueError(
            f"Insufficient data: only {n_rows} sample(s) provided. "
            "At least 2 samples are required for Linear Regression."
        )

0 commit comments

Comments
 (0)