Skip to content

Commit a54baba

Browse files
feat(model): add predict function with input validation and preprocessing pipeline support
1 parent f276a27 commit a54baba

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed

model.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,151 @@ def load_model(load_path: str) -> Dict[str, Any]:
519519
_check_version_compatibility(metadata['sklearn_version'])
520520

521521
return model_bundle
522+
523+
524+
def predict(
    model: LinearRegression,
    preprocessing_pipeline: Pipeline,
    raw_features_df: pd.DataFrame,
    expected_feature_names: List[str]
) -> pd.Series:
    """
    Generate predictions on new data using a saved model and preprocessing.

    The raw input features are validated, restricted to and reordered by
    ``expected_feature_names``, passed through the already-fitted
    preprocessing pipeline with ``transform()``, and finally fed to the
    trained model.

    Parameters
    ----------
    model : LinearRegression
        Trained LinearRegression model instance.
    preprocessing_pipeline : Pipeline
        Fitted sklearn Pipeline that was used during training.
        Must be already fitted on training data.
    raw_features_df : pd.DataFrame
        pandas DataFrame with raw feature data to make predictions on.
        Must contain all expected feature columns (order doesn't matter).
        May contain extra columns, which are ignored with a UserWarning.
    expected_feature_names : List[str]
        List of feature column names from training data, in training order.
        These are the columns the model expects to see.

    Returns
    -------
    pd.Series
        Series named ``'predictions'``, indexed to match the input DataFrame.

    Raises
    ------
    TypeError
        If inputs are not of the expected types.
    ValueError
        If raw_features_df has no rows.
        If expected_feature_names is empty.
        If expected feature columns are missing from raw_features_df.
        If all values are NaN after preprocessing.
    RuntimeError
        If the preprocessing pipeline's ``transform()`` or the model's
        ``predict()`` raises; the original exception is chained as the cause.

    Warns
    -----
    UserWarning
        If raw_features_df contains columns not in expected_feature_names.

    Examples
    --------
    >>> # Load a saved model
    >>> model_bundle = load_model('models/my_model.joblib')
    >>> model = model_bundle['model']
    >>> pipeline = model_bundle['pipeline']
    >>> feature_names = model_bundle['metadata']['original_feature_names']
    >>>
    >>> # Make predictions on new data
    >>> predictions = predict(model, pipeline, new_data_df, feature_names)
    >>> print(predictions)

    Notes
    -----
    - The preprocessing pipeline is applied using transform() (not fit_transform!)
    - Column order in raw_features_df doesn't matter; columns are reordered automatically
    - Extra columns in the input are ignored (a UserWarning lists them)
    - Missing columns raise a clear error with details about what's missing
    - The function preserves the index from the input DataFrame
    - Empty input and all-NaN preprocessed features are rejected with ValueError
    """
    # Validate input types
    if not isinstance(model, LinearRegression):
        raise TypeError(
            f"model must be a LinearRegression instance, got {type(model).__name__} instead."
        )

    if not isinstance(preprocessing_pipeline, Pipeline):
        raise TypeError(
            f"preprocessing_pipeline must be a Pipeline instance, got {type(preprocessing_pipeline).__name__} instead."
        )

    if not isinstance(raw_features_df, pd.DataFrame):
        raise TypeError(
            f"raw_features_df must be a pandas DataFrame, got {type(raw_features_df).__name__} instead."
        )

    if not isinstance(expected_feature_names, list):
        raise TypeError(
            f"expected_feature_names must be a list, got {type(expected_feature_names).__name__} instead."
        )

    # Check the row count explicitly: DataFrame.empty is also True for a frame
    # that has rows but zero columns, which would make the "(no rows)" message
    # wrong. A zero-column frame is reported by the missing-features check below.
    if len(raw_features_df) == 0:
        raise ValueError("raw_features_df is empty (no rows). Cannot make predictions on empty data.")

    if len(expected_feature_names) == 0:
        raise ValueError("expected_feature_names is empty. Cannot validate features.")

    # Validate feature columns
    expected_set = set(expected_feature_names)
    actual_set = set(raw_features_df.columns)

    missing_features = expected_set - actual_set
    extra_features = actual_set - expected_set

    # Sort with key=str so that mixed-type column labels (e.g. 0 and 'foo')
    # cannot raise TypeError while the message is being built; for all-string
    # labels the ordering is unchanged.
    if missing_features:
        raise ValueError(
            f"Expected features: {sorted(expected_feature_names, key=str)}, "
            f"got: {sorted(raw_features_df.columns.tolist(), key=str)}. "
            f"Missing: {sorted(missing_features, key=str)}, "
            f"extra: {sorted(extra_features, key=str)}"
        )

    # Warn about extra features if present
    if extra_features:
        warnings.warn(
            f"Input data contains extra columns that will be ignored: {sorted(extra_features, key=str)}",
            UserWarning,
            stacklevel=2
        )

    # Select and reorder columns to match training data order; copy so the
    # pipeline never operates on the caller's frame.
    features_df = raw_features_df[expected_feature_names].copy()

    # Apply preprocessing pipeline (using transform, not fit_transform!)
    try:
        preprocessed_features = preprocessing_pipeline.transform(features_df)
    except Exception as e:
        raise RuntimeError(
            f"Failed to apply preprocessing pipeline: {e}. "
            "Ensure the pipeline is fitted and compatible with the input data."
        ) from e

    # Check for all-NaN features after preprocessing
    if np.isnan(preprocessed_features).all():
        raise ValueError(
            "All features are NaN after preprocessing. "
            "Check that input data contains valid numeric values."
        )

    # Generate predictions
    try:
        predictions_array = model.predict(preprocessed_features)
    except Exception as e:
        raise RuntimeError(
            f"Failed to generate predictions: {e}"
        ) from e

    # Return predictions as pandas Series with original index
    predictions = pd.Series(predictions_array, index=raw_features_df.index, name='predictions')

    return predictions

0 commit comments

Comments
 (0)