
Commit 1a9da27

feat(cli): implement predict subcommand with model inference and feature name consistency
Add predict subcommand with model loading, feature validation, prediction generation, and output formatting. Fix feature naming to use the pipeline's get_feature_names_out() so training and prediction stay consistent.
1 parent 6f7f64e commit 1a9da27
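
For context, the naming fix described above can be sketched in isolation: synthetic names like feature_0, feature_1, ... depend only on the column count, while get_feature_names_out() on the fitted pipeline yields names that can be reproduced identically at prediction time. The toy ColumnTransformer and column names below are illustrative assumptions, not this repository's actual preprocessing setup.

# Sketch only: why transformed columns are named via the fitted pipeline.
# The columns ("age", "city") and this ColumnTransformer are assumed for
# illustration; they are not taken from this repository.
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

train_df = pd.DataFrame({"age": [25, 32, 47], "city": ["NY", "SF", "NY"]})

fitted_pipeline = ColumnTransformer([
    ("num", StandardScaler(), ["age"]),
    ("cat", OneHotEncoder(), ["city"]),
])
X_transformed = fitted_pipeline.fit_transform(train_df)

# Old behaviour: positional names that prediction code can only recreate
# by counting columns again.
synthetic_names = [f'feature_{i}' for i in range(X_transformed.shape[1])]  # ['feature_0', 'feature_1', 'feature_2']

# New behaviour: names reported by the fitted pipeline itself, so training
# and prediction agree as long as both use the same pipeline object.
pipeline_names = fitted_pipeline.get_feature_names_out().tolist()          # ['num__age', 'cat__city_NY', 'cat__city_SF']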

File tree

2 files changed: +178 -5 lines changed


cli.py

Lines changed: 168 additions & 4 deletions
@@ -96,9 +96,11 @@ def train(input_path, target_column, output_model_path, report_path):
         click.echo(f"✓ Preprocessing complete: {X_transformed.shape[1]} transformed features")

         # Convert transformed data to DataFrame for model training
+        # Use feature names from pipeline to ensure consistency with prediction
+        feature_names_from_pipeline = fitted_pipeline.get_feature_names_out().tolist()
         X_transformed_df = pd.DataFrame(
             X_transformed,
-            columns=[f'feature_{i}' for i in range(X_transformed.shape[1])]
+            columns=feature_names_from_pipeline
         )
     except Exception as e:
         click.echo(f"✗ Error during preprocessing: {str(e)}", err=True)
@@ -264,14 +266,176 @@ def train(input_path, target_column, output_model_path, report_path):
 )
 def predict(model_path, input_path, output_path):
     """Make predictions using a trained model."""
-    # Placeholder function - print arguments for now
     click.echo("=== Predict Command ===")
     click.echo(f"Model Path: {model_path}")
     click.echo(f"Input CSV: {input_path}")
     click.echo(f"Output Path: {output_path}")
+    click.echo("")

-    # TODO: Add validation for output path (check if writable)
-    # TODO: Implement actual prediction logic
+    try:
+        # Step 1: Load model bundle
+        click.echo("Step 1: Loading model...")
+        try:
+            from model import load_model, predict as make_predictions
+            model_bundle = load_model(model_path)
+
+            # Extract components
+            model = model_bundle['model']
+            pipeline = model_bundle['pipeline']
+            metadata = model_bundle['metadata']
+            expected_feature_names = metadata['original_feature_names']
+            target_name = metadata['target_name']
+
+            click.echo(f"✓ Model loaded successfully")
+            click.echo(f" - Target variable: '{target_name}'")
+            click.echo(f" - Expected features: {len(expected_feature_names)}")
+            click.echo(f" - Trained on: {metadata.get('training_timestamp', 'N/A')}")
+        except FileNotFoundError as e:
+            click.echo(f"✗ Error: Model file not found: {model_path}", err=True)
+            click.echo(" Suggestion: Check that the model file path is correct and the file exists.", err=True)
+            raise click.Abort()
+        except (EOFError, ValueError) as e:
+            click.echo(f"✗ Error loading model: {str(e)}", err=True)
+            click.echo(" Suggestion: The model file may be corrupted or incompatible.", err=True)
+            raise click.Abort()
+        except Exception as e:
+            click.echo(f"✗ Unexpected error loading model: {str(e)}", err=True)
+            raise click.Abort()
+
+        # Step 2: Load input CSV
+        click.echo("\nStep 2: Loading input data...")
+        try:
+            input_df = pd.read_csv(input_path)
+
+            if input_df.empty:
+                click.echo(f"✗ Error: Input CSV file is empty: {input_path}", err=True)
+                click.echo(" Suggestion: Ensure the CSV file contains data rows.", err=True)
+                raise click.Abort()
+
+            click.echo(f"✓ Input data loaded successfully: {input_df.shape[0]} samples, {input_df.shape[1]} features")
+        except FileNotFoundError as e:
+            click.echo(f"✗ Error: Input file not found: {input_path}", err=True)
+            click.echo(" Suggestion: Check that the input file path is correct and the file exists.", err=True)
+            raise click.Abort()
+        except pd.errors.EmptyDataError as e:
+            click.echo(f"✗ Error: Input CSV file is empty: {input_path}", err=True)
+            click.echo(" Suggestion: Ensure the CSV file contains data.", err=True)
+            raise click.Abort()
+        except Exception as e:
+            click.echo(f"✗ Error reading input CSV: {str(e)}", err=True)
+            raise click.Abort()
+
+        # Step 3: Validate features and make predictions
+        click.echo("\nStep 3: Making predictions...")
+        try:
+            predictions = make_predictions(
+                model=model,
+                preprocessing_pipeline=pipeline,
+                raw_features_df=input_df,
+                expected_feature_names=expected_feature_names
+            )
+            click.echo(f"✓ Predictions generated successfully: {len(predictions)} predictions")
+        except ValueError as e:
+            error_msg = str(e)
+            if "Missing:" in error_msg:
+                click.echo(f"✗ Error: Feature mismatch between input data and trained model", err=True)
+                click.echo(f" {error_msg}", err=True)
+                click.echo(" Suggestion: Ensure the input CSV contains all required feature columns.", err=True)
+            else:
+                click.echo(f"✗ Error: {error_msg}", err=True)
+            raise click.Abort()
+        except Exception as e:
+            click.echo(f"✗ Error making predictions: {str(e)}", err=True)
+            raise click.Abort()
+
+        # Step 4: Create output DataFrame
+        click.echo("\nStep 4: Creating output file...")
+        try:
+            # Check if predictions contain all NaN values
+            if predictions.isna().all():
+                click.echo(f"✗ Error: All predictions are NaN", err=True)
+                click.echo(" Suggestion: Check that input data contains valid numeric values.", err=True)
+                raise click.Abort()
+
+            # Create output DataFrame with original data + predictions
+            output_df = input_df.copy()
+            prediction_column_name = f"predicted_{target_name}"
+            output_df[prediction_column_name] = predictions.values
+
+            click.echo(f"✓ Output DataFrame created with column '{prediction_column_name}'")
+        except Exception as e:
+            click.echo(f"✗ Error creating output DataFrame: {str(e)}", err=True)
+            raise click.Abort()
+
+        # Step 5: Save output CSV
+        click.echo("\nStep 5: Saving predictions...")
+        try:
+            # Create parent directories if they don't exist
+            output_path_obj = Path(output_path)
+            output_path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+            output_df.to_csv(output_path, index=False)
+            click.echo(f"✓ Predictions saved to: {output_path}")
+        except PermissionError as e:
+            click.echo(f"✗ Error: Permission denied writing to: {output_path}", err=True)
+            click.echo(" Suggestion: Check that you have write permissions for the output path.", err=True)
+            raise click.Abort()
+        except Exception as e:
+            click.echo(f"✗ Error saving predictions: {str(e)}", err=True)
+            click.echo(" Suggestion: Check that the output path is valid and writable.", err=True)
+            raise click.Abort()
+
+        # Step 6: Calculate and print summary statistics
+        click.echo("\nStep 6: Calculating summary statistics...")
+        try:
+            # Calculate statistics, excluding NaN values
+            valid_predictions = predictions.dropna()
+
+            if len(valid_predictions) == 0:
+                click.echo(f"⚠ Warning: All predictions are NaN, cannot calculate statistics", err=True)
+            else:
+                stats = {
+                    'count': len(valid_predictions),
+                    'mean': float(valid_predictions.mean()),
+                    'median': float(valid_predictions.median()),
+                    'std': float(valid_predictions.std()),
+                    'min': float(valid_predictions.min()),
+                    'max': float(valid_predictions.max())
+                }
+
+                click.echo("✓ Summary statistics calculated")
+                click.echo(f" - Count: {stats['count']}")
+                click.echo(f" - Mean: {stats['mean']:.4f}")
+                click.echo(f" - Median: {stats['median']:.4f}")
+                click.echo(f" - Std Dev: {stats['std']:.4f}")
+                click.echo(f" - Min: {stats['min']:.4f}")
+                click.echo(f" - Max: {stats['max']:.4f}")
+        except Exception as e:
+            click.echo(f"✗ Error calculating statistics: {str(e)}", err=True)
+            # Don't abort here, predictions are already saved
+
+        # Step 7: Print success message
+        click.echo("\n" + "=" * 60)
+        click.echo("🎉 Prediction completed successfully!")
+        click.echo("=" * 60)
+        click.echo(f"\n📊 Prediction Summary:")
+        click.echo(f" - Output file: {output_path}")
+        click.echo(f" - Number of predictions: {len(predictions)}")
+        if len(valid_predictions) > 0:
+            click.echo(f" - Prediction column: '{prediction_column_name}'")
+            click.echo(f"\n📈 Statistics:")
+            click.echo(f" - Mean: {stats['mean']:.4f}")
+            click.echo(f" - Median: {stats['median']:.4f}")
+            click.echo(f" - Range: [{stats['min']:.4f}, {stats['max']:.4f}]")
+        click.echo("")
+
+    except click.Abort:
+        # Already handled above
+        raise
+    except Exception as e:
+        click.echo(f"\n✗ Unexpected error: {str(e)}", err=True)
+        click.echo(" Please check the error message above for details.", err=True)
+        raise click.Abort()


 if __name__ == '__main__':
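
Step 1 above unpacks a model bundle with 'model', 'pipeline', and 'metadata' keys, where the metadata carries original_feature_names, target_name, and training_timestamp. As a rough sketch of a bundle that load_model() could return, assuming joblib-based persistence (the actual save/load helpers live in model.py and are not part of this diff):

# Sketch only: a bundle shaped like the one the predict command unpacks.
# joblib persistence and this helper's name/signature are assumptions; the
# real save/load logic in model.py is not shown in this diff.
from datetime import datetime, timezone

import joblib

def save_model_bundle(model, pipeline, original_feature_names, target_name, path):
    bundle = {
        "model": model,
        "pipeline": pipeline,
        "metadata": {
            "original_feature_names": list(original_feature_names),
            "target_name": target_name,
            "training_timestamp": datetime.now(timezone.utc).isoformat(),
        },
    }
    joblib.dump(bundle, path)

# Loading then mirrors what Step 1 reads back:
#   bundle = joblib.load(path)
#   bundle["model"], bundle["pipeline"], bundle["metadata"]["original_feature_names"], ...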

model.py

Lines changed: 10 additions & 1 deletion
@@ -655,9 +655,18 @@ def predict(
             "Check that input data contains valid numeric values."
         )

+    # Convert preprocessed features to DataFrame with feature names to avoid sklearn warning
+    # Get the transformed feature names from the pipeline
+    transformed_feature_names = _get_feature_names_out(preprocessing_pipeline, expected_feature_names)
+    preprocessed_df = pd.DataFrame(
+        preprocessed_features,
+        columns=transformed_feature_names,
+        index=raw_features_df.index
+    )
+
     # Generate predictions
     try:
-        predictions_array = model.predict(preprocessed_features)
+        predictions_array = model.predict(preprocessed_df)
     except Exception as e:
         raise RuntimeError(
             f"Failed to generate predictions: {str(e)}"
