@@ -96,9 +96,11 @@ def train(input_path, target_column, output_model_path, report_path):
9696 click .echo (f"✓ Preprocessing complete: { X_transformed .shape [1 ]} transformed features" )
9797
9898 # Convert transformed data to DataFrame for model training
99+ # Use feature names from pipeline to ensure consistency with prediction
100+ feature_names_from_pipeline = fitted_pipeline .get_feature_names_out ().tolist ()
99101 X_transformed_df = pd .DataFrame (
100102 X_transformed ,
101- columns = [ f'feature_ { i } ' for i in range ( X_transformed . shape [ 1 ])]
103+ columns = feature_names_from_pipeline
102104 )
103105 except Exception as e :
104106 click .echo (f"✗ Error during preprocessing: { str (e )} " , err = True )
@@ -264,14 +266,176 @@ def train(input_path, target_column, output_model_path, report_path):
264266)
265267def predict (model_path , input_path , output_path ):
266268 """Make predictions using a trained model."""
267- # Placeholder function - print arguments for now
268269 click .echo ("=== Predict Command ===" )
269270 click .echo (f"Model Path: { model_path } " )
270271 click .echo (f"Input CSV: { input_path } " )
271272 click .echo (f"Output Path: { output_path } " )
273+ click .echo ("" )
272274
273- # TODO: Add validation for output path (check if writable)
274- # TODO: Implement actual prediction logic
275+ try :
276+ # Step 1: Load model bundle
277+ click .echo ("Step 1: Loading model..." )
278+ try :
279+ from model import load_model , predict as make_predictions
280+ model_bundle = load_model (model_path )
281+
282+ # Extract components
283+ model = model_bundle ['model' ]
284+ pipeline = model_bundle ['pipeline' ]
285+ metadata = model_bundle ['metadata' ]
286+ expected_feature_names = metadata ['original_feature_names' ]
287+ target_name = metadata ['target_name' ]
288+
289+ click .echo (f"✓ Model loaded successfully" )
290+ click .echo (f" - Target variable: '{ target_name } '" )
291+ click .echo (f" - Expected features: { len (expected_feature_names )} " )
292+ click .echo (f" - Trained on: { metadata .get ('training_timestamp' , 'N/A' )} " )
293+ except FileNotFoundError as e :
294+ click .echo (f"✗ Error: Model file not found: { model_path } " , err = True )
295+ click .echo (" Suggestion: Check that the model file path is correct and the file exists." , err = True )
296+ raise click .Abort ()
297+ except (EOFError , ValueError ) as e :
298+ click .echo (f"✗ Error loading model: { str (e )} " , err = True )
299+ click .echo (" Suggestion: The model file may be corrupted or incompatible." , err = True )
300+ raise click .Abort ()
301+ except Exception as e :
302+ click .echo (f"✗ Unexpected error loading model: { str (e )} " , err = True )
303+ raise click .Abort ()
304+
305+ # Step 2: Load input CSV
306+ click .echo ("\n Step 2: Loading input data..." )
307+ try :
308+ input_df = pd .read_csv (input_path )
309+
310+ if input_df .empty :
311+ click .echo (f"✗ Error: Input CSV file is empty: { input_path } " , err = True )
312+ click .echo (" Suggestion: Ensure the CSV file contains data rows." , err = True )
313+ raise click .Abort ()
314+
315+ click .echo (f"✓ Input data loaded successfully: { input_df .shape [0 ]} samples, { input_df .shape [1 ]} features" )
316+ except FileNotFoundError as e :
317+ click .echo (f"✗ Error: Input file not found: { input_path } " , err = True )
318+ click .echo (" Suggestion: Check that the input file path is correct and the file exists." , err = True )
319+ raise click .Abort ()
320+ except pd .errors .EmptyDataError as e :
321+ click .echo (f"✗ Error: Input CSV file is empty: { input_path } " , err = True )
322+ click .echo (" Suggestion: Ensure the CSV file contains data." , err = True )
323+ raise click .Abort ()
324+ except Exception as e :
325+ click .echo (f"✗ Error reading input CSV: { str (e )} " , err = True )
326+ raise click .Abort ()
327+
328+ # Step 3: Validate features and make predictions
329+ click .echo ("\n Step 3: Making predictions..." )
330+ try :
331+ predictions = make_predictions (
332+ model = model ,
333+ preprocessing_pipeline = pipeline ,
334+ raw_features_df = input_df ,
335+ expected_feature_names = expected_feature_names
336+ )
337+ click .echo (f"✓ Predictions generated successfully: { len (predictions )} predictions" )
338+ except ValueError as e :
339+ error_msg = str (e )
340+ if "Missing:" in error_msg :
341+ click .echo (f"✗ Error: Feature mismatch between input data and trained model" , err = True )
342+ click .echo (f" { error_msg } " , err = True )
343+ click .echo (" Suggestion: Ensure the input CSV contains all required feature columns." , err = True )
344+ else :
345+ click .echo (f"✗ Error: { error_msg } " , err = True )
346+ raise click .Abort ()
347+ except Exception as e :
348+ click .echo (f"✗ Error making predictions: { str (e )} " , err = True )
349+ raise click .Abort ()
350+
351+ # Step 4: Create output DataFrame
352+ click .echo ("\n Step 4: Creating output file..." )
353+ try :
354+ # Check if predictions contain all NaN values
355+ if predictions .isna ().all ():
356+ click .echo (f"✗ Error: All predictions are NaN" , err = True )
357+ click .echo (" Suggestion: Check that input data contains valid numeric values." , err = True )
358+ raise click .Abort ()
359+
360+ # Create output DataFrame with original data + predictions
361+ output_df = input_df .copy ()
362+ prediction_column_name = f"predicted_{ target_name } "
363+ output_df [prediction_column_name ] = predictions .values
364+
365+ click .echo (f"✓ Output DataFrame created with column '{ prediction_column_name } '" )
366+ except Exception as e :
367+ click .echo (f"✗ Error creating output DataFrame: { str (e )} " , err = True )
368+ raise click .Abort ()
369+
370+ # Step 5: Save output CSV
371+ click .echo ("\n Step 5: Saving predictions..." )
372+ try :
373+ # Create parent directories if they don't exist
374+ output_path_obj = Path (output_path )
375+ output_path_obj .parent .mkdir (parents = True , exist_ok = True )
376+
377+ output_df .to_csv (output_path , index = False )
378+ click .echo (f"✓ Predictions saved to: { output_path } " )
379+ except PermissionError as e :
380+ click .echo (f"✗ Error: Permission denied writing to: { output_path } " , err = True )
381+ click .echo (" Suggestion: Check that you have write permissions for the output path." , err = True )
382+ raise click .Abort ()
383+ except Exception as e :
384+ click .echo (f"✗ Error saving predictions: { str (e )} " , err = True )
385+ click .echo (" Suggestion: Check that the output path is valid and writable." , err = True )
386+ raise click .Abort ()
387+
388+ # Step 6: Calculate and print summary statistics
389+ click .echo ("\n Step 6: Calculating summary statistics..." )
390+ try :
391+ # Calculate statistics, excluding NaN values
392+ valid_predictions = predictions .dropna ()
393+
394+ if len (valid_predictions ) == 0 :
395+ click .echo (f"⚠ Warning: All predictions are NaN, cannot calculate statistics" , err = True )
396+ else :
397+ stats = {
398+ 'count' : len (valid_predictions ),
399+ 'mean' : float (valid_predictions .mean ()),
400+ 'median' : float (valid_predictions .median ()),
401+ 'std' : float (valid_predictions .std ()),
402+ 'min' : float (valid_predictions .min ()),
403+ 'max' : float (valid_predictions .max ())
404+ }
405+
406+ click .echo ("✓ Summary statistics calculated" )
407+ click .echo (f" - Count: { stats ['count' ]} " )
408+ click .echo (f" - Mean: { stats ['mean' ]:.4f} " )
409+ click .echo (f" - Median: { stats ['median' ]:.4f} " )
410+ click .echo (f" - Std Dev: { stats ['std' ]:.4f} " )
411+ click .echo (f" - Min: { stats ['min' ]:.4f} " )
412+ click .echo (f" - Max: { stats ['max' ]:.4f} " )
413+ except Exception as e :
414+ click .echo (f"✗ Error calculating statistics: { str (e )} " , err = True )
415+ # Don't abort here, predictions are already saved
416+
417+ # Step 7: Print success message
418+ click .echo ("\n " + "=" * 60 )
419+ click .echo ("🎉 Prediction completed successfully!" )
420+ click .echo ("=" * 60 )
421+ click .echo (f"\n 📊 Prediction Summary:" )
422+ click .echo (f" - Output file: { output_path } " )
423+ click .echo (f" - Number of predictions: { len (predictions )} " )
424+ if len (valid_predictions ) > 0 :
425+ click .echo (f" - Prediction column: '{ prediction_column_name } '" )
426+ click .echo (f"\n 📈 Statistics:" )
427+ click .echo (f" - Mean: { stats ['mean' ]:.4f} " )
428+ click .echo (f" - Median: { stats ['median' ]:.4f} " )
429+ click .echo (f" - Range: [{ stats ['min' ]:.4f} , { stats ['max' ]:.4f} ]" )
430+ click .echo ("" )
431+
432+ except click .Abort :
433+ # Already handled above
434+ raise
435+ except Exception as e :
436+ click .echo (f"\n ✗ Unexpected error: { str (e )} " , err = True )
437+ click .echo (" Please check the error message above for details." , err = True )
438+ raise click .Abort ()
275439
276440
277441if __name__ == '__main__' :
0 commit comments