
Commit 4718e85

feat(model): add model persistence with save and load functionality for serialization and metadata tracking
1 parent 87eb5ac

File tree

1 file changed (+340, -1)

model.py

Lines changed: 340 additions & 1 deletion
@@ -4,13 +4,25 @@
 on preprocessed data and calculate comprehensive regression metrics. The module
 is designed to work with data that has already been preprocessed (scaled and imputed)
 by the preprocessing pipeline.
+
+Additionally, this module provides model persistence functionality to save and load
+trained models along with their preprocessing pipelines and metadata for reproducibility.
 """
 
 import pandas as pd
 import numpy as np
+import joblib
+import pickle
+import sys
+import platform
+import warnings
+from datetime import datetime
+from pathlib import Path
 from sklearn.linear_model import LinearRegression
 from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
-from typing import Dict, Tuple
+from sklearn.pipeline import Pipeline
+from typing import Dict, Tuple, List, Optional, Any
+import sklearn
 
 
 def train_model(X: pd.DataFrame, y: pd.Series) -> Tuple[LinearRegression, Dict[str, float]]:
@@ -180,3 +192,330 @@ def _validate_inputs(X: pd.DataFrame, y: pd.Series) -> None:
             f"Insufficient data: only {X.shape[0]} sample(s) provided. "
             "At least 2 samples are required for Linear Regression."
         )
+
+
+def _get_feature_names_out(pipeline: Pipeline, original_features: List[str]) -> List[str]:
+    """
+    Extract feature names after preprocessing pipeline transformation.
+
+    This helper function attempts to retrieve the feature names that result from
+    applying the preprocessing pipeline. This is particularly important for pipelines
+    that include transformations like OneHotEncoding, which generate new feature names.
+
+    Parameters
+    ----------
+    pipeline : Pipeline
+        A fitted preprocessing pipeline.
+    original_features : List[str]
+        Original feature names before preprocessing.
+
+    Returns
+    -------
+    List[str]
+        Feature names after preprocessing transformation. If feature names cannot
+        be extracted (e.g., the pipeline doesn't support get_feature_names_out),
+        returns the original feature names.
+
+    Notes
+    -----
+    - Attempts to use the get_feature_names_out() method if available (sklearn >= 1.0)
+    - Falls back to the original feature names if the method is not available
+    - Useful for understanding which features the model is actually using
+    """
+    try:
+        # Try to get feature names from the pipeline
+        if hasattr(pipeline, 'get_feature_names_out'):
+            feature_names = pipeline.get_feature_names_out()
+            return list(feature_names)
+    except Exception:
+        # If anything fails, fall back to original names
+        pass
+
+    # Fallback to original feature names
+    return original_features
+
+
+def save_model(
+    model: LinearRegression,
+    pipeline: Pipeline,
+    feature_names: List[str],
+    target_name: str,
+    save_path: str
+) -> None:
+    """
+    Save a trained model, preprocessing pipeline, and metadata to disk.
+
+    This function serializes a trained LinearRegression model along with its
+    preprocessing pipeline and comprehensive metadata for reproducibility. The
+    saved file can be loaded later using load_model() to recreate the exact
+    training environment.
+
+    File Format
+    -----------
+    The saved file is a dictionary containing three keys:
+    - 'model': The trained LinearRegression object
+    - 'pipeline': The fitted preprocessing Pipeline object
+    - 'metadata': Dictionary with training information including:
+        - 'original_feature_names': Feature names before preprocessing
+        - 'transformed_feature_names': Feature names after preprocessing
+        - 'target_name': Name of the target variable
+        - 'training_timestamp': ISO 8601 formatted timestamp
+        - 'sklearn_version': Version of scikit-learn used for training
+        - 'python_version': Python version used for training
+
+    Parameters
+    ----------
+    model : LinearRegression
+        Trained LinearRegression model to save.
+    pipeline : Pipeline
+        Fitted preprocessing pipeline used to transform training data.
+    feature_names : List[str]
+        Original feature names (before preprocessing transformations).
+    target_name : str
+        Name of the target variable/column.
+    save_path : str
+        File path where the model bundle should be saved. Parent directories
+        will be created if they don't exist.
+
+    Raises
+    ------
+    TypeError
+        If model is not a LinearRegression instance or pipeline is not a Pipeline.
+    ValueError
+        If feature_names is empty or target_name is empty/None.
+    OSError
+        If the save path directory cannot be created or the file cannot be written.
+
+    Examples
+    --------
+    >>> from sklearn.linear_model import LinearRegression
+    >>> from sklearn.pipeline import Pipeline
+    >>>
+    >>> # Assume model and pipeline are already trained
+    >>> save_model(
+    ...     model=trained_model,
+    ...     pipeline=fitted_pipeline,
+    ...     feature_names=['age', 'income', 'credit_score'],
+    ...     target_name='loan_amount',
+    ...     save_path='models/my_model.joblib'
+    ... )
+    >>> print("Model saved successfully!")
+
+    Notes
+    -----
+    - Uses joblib for efficient serialization of sklearn objects
+    - Saved files are cross-platform compatible
+    - A file extension of .joblib or .pkl is recommended but not enforced
+    - Metadata enables version compatibility checks during loading
+    - Parent directories are created automatically if they don't exist
+    - Original feature names are preserved to ensure correct column ordering during prediction
+    """
+    # Validate inputs
+    if not isinstance(model, LinearRegression):
+        raise TypeError(
+            f"model must be a LinearRegression instance, got {type(model).__name__} instead."
+        )
+
+    if not isinstance(pipeline, Pipeline):
+        raise TypeError(
+            f"pipeline must be a Pipeline instance, got {type(pipeline).__name__} instead."
+        )
+
+    if not feature_names or len(feature_names) == 0:
+        raise ValueError("feature_names cannot be empty.")
+
+    if not target_name or not isinstance(target_name, str):
+        raise ValueError("target_name must be a non-empty string.")
+
+    # Create metadata dictionary
+    metadata = {
+        'original_feature_names': feature_names,
+        'transformed_feature_names': _get_feature_names_out(pipeline, feature_names),
+        'target_name': target_name,
+        'training_timestamp': datetime.now().isoformat(),
+        'sklearn_version': sklearn.__version__,
+        'python_version': platform.python_version()
+    }
+
+    # Create model bundle
+    model_bundle = {
+        'model': model,
+        'pipeline': pipeline,
+        'metadata': metadata
+    }
+
+    # Create parent directories if they don't exist
+    save_path_obj = Path(save_path)
+    save_path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+    # Save using joblib
+    try:
+        joblib.dump(model_bundle, save_path)
+    except Exception as e:
+        raise OSError(f"Failed to save model to {save_path}: {str(e)}") from e
+
+
+def _check_version_compatibility(saved_sklearn_version: str) -> None:
+    """
+    Check version compatibility and issue warnings if versions differ.
+
+    Compares the current scikit-learn version with the version used to train
+    the saved model. Issues a warning if there's a mismatch, as this could
+    potentially lead to compatibility issues or different prediction results.
+
+    Parameters
+    ----------
+    saved_sklearn_version : str
+        The scikit-learn version string from the saved model metadata.
+
+    Notes
+    -----
+    - Uses Python's warnings module to issue version mismatch warnings
+    - Warnings are issued at the UserWarning level
+    - Major version differences are more likely to cause issues than minor ones
+    """
+    current_version = sklearn.__version__
+
+    if current_version != saved_sklearn_version:
+        warnings.warn(
+            f"Scikit-learn version mismatch: Model was trained with version "
+            f"{saved_sklearn_version}, but current version is {current_version}. "
+            f"This may lead to compatibility issues or different prediction results.",
+            UserWarning,
+            stacklevel=3
+        )
+
+
+def load_model(load_path: str) -> Dict[str, Any]:
+    """
+    Load a saved model, preprocessing pipeline, and metadata from disk.
+
+    This function deserializes a model bundle previously saved with save_model(),
+    validates its structure, and checks for version compatibility. The returned
+    dictionary contains the model, pipeline, and metadata needed for making
+    predictions on new data.
+
+    Parameters
+    ----------
+    load_path : str
+        File path to the saved model bundle (.joblib or .pkl file).
+
+    Returns
+    -------
+    Dict[str, Any]
+        Dictionary containing three keys:
+        - 'model': The trained LinearRegression object
+        - 'pipeline': The fitted preprocessing Pipeline object
+        - 'metadata': Dictionary with training information:
+            - 'original_feature_names': Feature names before preprocessing
+            - 'transformed_feature_names': Feature names after preprocessing
+            - 'target_name': Name of the target variable
+            - 'training_timestamp': ISO 8601 formatted timestamp
+            - 'sklearn_version': Version of scikit-learn used for training
+            - 'python_version': Python version used for training
+
+    Raises
+    ------
+    FileNotFoundError
+        If the specified load_path does not exist.
+    ValueError
+        If the loaded object doesn't have the expected structure (missing keys).
+    EOFError
+        If the file is corrupted or truncated.
+    pickle.UnpicklingError
+        If the file cannot be deserialized (corrupted or incompatible format).
+
+    Warnings
+    --------
+    UserWarning
+        Issued if the scikit-learn version differs from the one used for training.
+
+    Examples
+    --------
+    >>> # Load a previously saved model
+    >>> model_bundle = load_model('models/my_model.joblib')
+    >>>
+    >>> # Extract components
+    >>> model = model_bundle['model']
+    >>> pipeline = model_bundle['pipeline']
+    >>> metadata = model_bundle['metadata']
+    >>>
+    >>> # Check metadata
+    >>> print(f"Model trained on: {metadata['training_timestamp']}")
+    >>> print(f"Features: {metadata['original_feature_names']}")
+    >>> print(f"Target: {metadata['target_name']}")
+    >>>
+    >>> # Use for predictions
+    >>> X_new_preprocessed = pipeline.transform(X_new)
+    >>> predictions = model.predict(X_new_preprocessed)
+
+    Notes
+    -----
+    - The model bundle must have been created with the save_model() function
+    - Version compatibility warnings help identify potential issues
+    - The pipeline is already fitted and ready to transform new data
+    - Original feature names help ensure correct column ordering
+    - Cross-platform compatible (models saved on a different OS can be loaded)
+    """
+    # Check if the file exists
+    if not Path(load_path).exists():
+        raise FileNotFoundError(
+            f"Model file not found at path: {load_path}. "
+            f"Please check that the file exists and the path is correct."
+        )
+
+    # Load the model bundle
+    try:
+        model_bundle = joblib.load(load_path)
+    except EOFError as e:
+        raise EOFError(
+            f"Failed to load model from {load_path}: File appears to be corrupted or truncated. "
+            f"The file may have been incompletely written or damaged."
+        ) from e
+    except pickle.UnpicklingError as e:
+        raise pickle.UnpicklingError(
+            f"Failed to deserialize model from {load_path}: File format is invalid or incompatible. "
+            f"The file may be corrupted or created with an incompatible version."
+        ) from e
+    except Exception as e:
+        raise RuntimeError(
+            f"Unexpected error loading model from {load_path}: {str(e)}"
+        ) from e
+
+    # Validate structure
+    if not isinstance(model_bundle, dict):
+        raise ValueError(
+            f"Loaded object is not a dictionary. Expected a model bundle with "
+            f"'model', 'pipeline', and 'metadata' keys, but got {type(model_bundle).__name__}."
+        )
+
+    required_keys = {'model', 'pipeline', 'metadata'}
+    missing_keys = required_keys - set(model_bundle.keys())
+
+    if missing_keys:
+        raise ValueError(
+            f"Model bundle is missing required keys: {missing_keys}. "
+            f"Expected keys: {required_keys}. Found keys: {set(model_bundle.keys())}."
+        )
+
+    # Validate metadata structure
+    metadata = model_bundle['metadata']
+    required_metadata_keys = {
+        'original_feature_names',
+        'transformed_feature_names',
+        'target_name',
+        'training_timestamp',
+        'sklearn_version'
+    }
+    missing_metadata_keys = required_metadata_keys - set(metadata.keys())
+
+    if missing_metadata_keys:
+        raise ValueError(
+            f"Metadata is missing required keys: {missing_metadata_keys}. "
+            f"The model bundle may have been created with an older version of this module."
+        )
+
+    # Check version compatibility
+    _check_version_compatibility(metadata['sklearn_version'])
+
+    return model_bundle
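
The docstring examples above exercise each function in isolation; a minimal end-to-end sketch of the intended round trip follows. The sample data, pipeline construction, and file path are illustrative assumptions, and only train_model, save_model, and load_model come from this module.

import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from model import train_model, save_model, load_model

# Hypothetical training data (any numeric DataFrame/Series pair works)
X_train = pd.DataFrame({'age': [25, 32, 47, 51],
                        'income': [40000, 55000, 72000, 90000]})
y_train = pd.Series([10000, 15000, 22000, 30000], name='loan_amount')

# Fit an illustrative preprocessing pipeline, then train on its output
pipeline = Pipeline([('impute', SimpleImputer()), ('scale', StandardScaler())])
X_processed = pd.DataFrame(pipeline.fit_transform(X_train), columns=X_train.columns)
model, metrics = train_model(X_processed, y_train)

# Persist the bundle, then reload it (load_model emits a UserWarning
# if the current scikit-learn version differs from the one recorded)
save_model(model, pipeline, list(X_train.columns), 'loan_amount', 'models/demo.joblib')
bundle = load_model('models/demo.joblib')

# New data passes through the reloaded pipeline before prediction
X_new = pd.DataFrame({'age': [38], 'income': [60000]})
predictions = bundle['model'].predict(bundle['pipeline'].transform(X_new))

Because the fitted pipeline travels with the estimator, prediction code never has to re-derive scaling or imputation parameters; it only calls transform on the reloaded pipeline.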
