
Commit d95e5e6

docs: add comprehensive documentation, sample data, docstrings, and fix NaN handling
Add complete README with usage examples, create sample dataset with realistic data, document all functions with Google-style docstrings, and filter out NaN target values
1 parent 492ce84 commit d95e5e6

File tree: 4 files changed (+656 / -33 lines)


README.md

Lines changed: 312 additions & 33 deletions
# ML Training and Prediction CLI

A command-line tool for training Linear Regression models with automated preprocessing, comprehensive reporting, and easy prediction capabilities. Built with modern Python tooling using `uv` and scikit-learn.

## Overview

This CLI tool simplifies the end-to-end machine learning workflow for regression tasks. It handles data loading, preprocessing (missing value imputation, feature scaling, categorical encoding), model training, evaluation, and prediction, all through simple command-line commands. Training produces detailed HTML reports with visualizations to help you understand model performance and feature importance.

**Key Use Cases:**
- Quick prototyping and baseline models for regression problems
- Automated preprocessing pipelines with consistent train/test handling
- Model training with reproducible results and comprehensive reports
- Easy deployment of trained models for batch predictions

## Features

- **Training (`train` command)**:
  - Trains Linear Regression models on CSV data
  - Automatically handles numeric and categorical features
  - Missing value imputation (mean for numeric, most frequent for categorical)
  - Feature scaling with StandardScaler
  - One-hot encoding for categorical variables
  - Saves the trained model with its preprocessing pipeline for reproducibility

- **Prediction (`predict` command)**:
  - Loads saved models and applies the same preprocessing automatically
  - Validates input data against the training schema
  - Generates predictions with summary statistics
  - Outputs a CSV with the original data plus predictions

- **Automated Reporting**:
  - Generates self-contained HTML reports with embedded visualizations
  - Performance metrics (R², MSE, RMSE, MAE)
  - Feature importance (model coefficients)
  - Actual vs Predicted scatter plot
  - Residuals plot for model diagnostics
  - Coefficients bar chart

- **Data Quality Handling**:
  - Automatic detection and imputation of missing values
  - Validation of feature names and data types
  - Clear error messages for data issues
## Installation

1. **Install uv** (if you haven't already):
   ```bash
   # macOS/Linux
   curl -LsSf https://astral.sh/uv/install.sh | sh

   # Windows
   powershell -c "irm https://astral.sh/uv/install.ps1 | iex"

   # Or via pip
   pip install uv
   ```

2. **Clone the repository**:
   ```bash
   git clone <repository-url>
   cd ml-cli-tool
   ```

3. **Sync dependencies** (creates virtual environment automatically):
   ```bash
   uv sync
   ```

## Usage

### Training a Model

Train a model on your CSV data with the `train` command:

```bash
uv run python main.py train \
  --input sample_data.csv \
  --target price \
  --output-model model.joblib \
  --report report.html
```

**Example with Sample Data:**

The repository includes `sample_data.csv` with house price data (80 samples, 4 features):
- `sqft`: Square footage (numeric)
- `bedrooms`: Number of bedrooms (numeric)
- `age`: Age of house in years (numeric)
- `location_score`: Location quality score 1-10 (numeric)
- `price`: Sale price in dollars (target variable)
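
If you want a quick look at the dataset before training, a short pandas snippet (illustrative only, not part of the CLI) can summarize it:

```python
import pandas as pd

# Load the bundled example dataset and summarize it
df = pd.read_csv("sample_data.csv")
print(df.head())        # first few rows
print(df.describe())    # count, mean, std, min/max for each numeric column
print(df.isna().sum())  # missing values per column
```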

```bash
# Train on the sample data
uv run python main.py train \
  --input sample_data.csv \
  --target price \
  --output-model house_model.joblib \
  --report house_report.html
```

**Expected Output:**
```
=== Train Command ===
Input CSV: sample_data.csv
Target Column: price
Output Model Path: house_model.joblib
Report Path: house_report.html

Step 1: Loading and validating data...
✓ Data loaded successfully: 80 samples, 4 features
  - Numeric features: 4
  - Target column: 'price'

Step 2: Creating and fitting preprocessing pipeline...
✓ Preprocessing complete: 4 transformed features

Step 3: Training model...
✓ Model trained successfully
  - R² Score: 0.9234
  - MSE: 123456789.0
  - RMSE: 11111.11
  - MAE: 8888.88

Step 4: Generating visualizations...
✓ Visualizations created successfully

Step 5: Saving model...
✓ Model saved to: house_model.joblib

Step 6: Generating HTML report...
✓ Report saved to: house_report.html

============================================================
🎉 Training completed successfully!
============================================================

📊 Model Performance:
  - R² Score: 0.9234
  - MSE: 123456789.0
  - RMSE: 11111.11
  - MAE: 8888.88

💾 Output Files:
  - Model: house_model.joblib
  - Report: house_report.html
```
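
Internally, the saved `.joblib` file bundles the fitted preprocessing pipeline, the fitted model, and training metadata so later predictions reproduce the exact same transformations. The snippet below is only a rough sketch of that idea using plain scikit-learn on the sample data; the artifact keys and layout are assumptions, not the tool's actual format (see `model.py` for that):

```python
from datetime import datetime

import joblib
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Fit a simple scale-then-regress pipeline on the sample data
df = pd.read_csv("sample_data.csv")
X, y = df.drop(columns=["price"]), df["price"]
pipeline = make_pipeline(StandardScaler(), LinearRegression()).fit(X, y)

# Bundle the fitted pipeline with metadata (illustrative layout)
artifact = {
    "pipeline": pipeline,
    "target": "price",
    "feature_names": list(X.columns),
    "trained_at": datetime.now().isoformat(),
}
joblib.dump(artifact, "house_model_sketch.joblib")
```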

### Making Predictions

Use a trained model to make predictions on new data:

```bash
uv run python main.py predict \
  --model model.joblib \
  --input new_data.csv \
  --output predictions.csv
```

**Example with Saved Model:**

```bash
# Make predictions using the trained model
uv run python main.py predict \
  --model house_model.joblib \
  --input sample_data.csv \
  --output predictions.csv
```

**Expected Output:**
```
=== Predict Command ===
Model Path: house_model.joblib
Input CSV: sample_data.csv
Output Path: predictions.csv

Step 1: Loading model...
✓ Model loaded successfully
  - Target variable: 'price'
  - Expected features: 4
  - Trained on: 2024-01-15T10:30:45.123456

Step 2: Loading input data...
✓ Input data loaded successfully: 80 samples, 4 features

Step 3: Making predictions...
✓ Predictions generated successfully: 80 predictions

Step 4: Creating output file...
✓ Output DataFrame created with column 'predicted_price'

Step 5: Saving predictions...
✓ Predictions saved to: predictions.csv

Step 6: Calculating summary statistics...
✓ Summary statistics calculated
  - Count: 80
  - Mean: 315000.0000
  - Median: 320000.0000
  - Std Dev: 85000.0000
  - Min: 185000.0000
  - Max: 485000.0000

============================================================
🎉 Prediction completed successfully!
============================================================

📊 Prediction Summary:
  - Output file: predictions.csv
  - Number of predictions: 80
  - Prediction column: 'predicted_price'

📈 Statistics:
  - Mean: 315000.0000
  - Median: 320000.0000
  - Range: [185000.0000, 485000.0000]
```

The output CSV will contain all original columns plus a new `predicted_price` column.
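
If you want to reuse a saved model outside the CLI, you can load it with joblib and apply it to a DataFrame yourself. A rough sketch, assuming an artifact laid out like the training sketch above (the key names are illustrative, not the tool's guaranteed format):

```python
import joblib
import pandas as pd

artifact = joblib.load("house_model_sketch.joblib")
new_data = pd.read_csv("sample_data.csv").drop(columns=["price"])

# Validate that the prediction data has every column seen during training
missing = set(artifact["feature_names"]) - set(new_data.columns)
if missing:
    raise ValueError(f"Missing feature columns: {sorted(missing)}")

# Reorder columns to match training, predict, and save alongside the inputs
predictions = artifact["pipeline"].predict(new_data[artifact["feature_names"]])
out = new_data.assign(predicted_price=predictions)
out.to_csv("predictions_sketch.csv", index=False)
print(out["predicted_price"].describe())
```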

## Preprocessing Details

The tool automatically applies a preprocessing pipeline to ensure clean, standardized data for the model:

### For Numeric Features:
1. **Missing Value Imputation**: Replaces missing values (NaN) with the mean of that feature calculated from the training data
2. **Standardization**: Applies StandardScaler to normalize features to zero mean and unit variance (z-score normalization)

### For Categorical Features:
1. **Missing Value Imputation**: Replaces missing values with the most frequent category from the training data
2. **One-Hot Encoding**: Converts categorical variables into binary indicator columns (handles unknown categories gracefully during prediction)
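
The pipeline itself is built in `preprocessing.py`. As a rough illustration of how such a pipeline is typically assembled with scikit-learn's `ColumnTransformer` (a sketch with a made-up categorical column, since `sample_data.csv` happens to be all numeric):

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Tiny illustrative frame with one numeric and one categorical column
df = pd.DataFrame({"sqft": [1200.0, np.nan, 2000.0], "zone": ["A", "B", np.nan]})

numeric = Pipeline([
    ("impute", SimpleImputer(strategy="mean")),           # fill NaN with the column mean
    ("scale", StandardScaler()),                          # zero mean, unit variance
])
categorical = Pipeline([
    ("impute", SimpleImputer(strategy="most_frequent")),  # fill NaN with the modal category
    ("encode", OneHotEncoder(handle_unknown="ignore")),   # unseen categories encode as all zeros
])

preprocessor = ColumnTransformer([
    ("num", numeric, ["sqft"]),
    ("cat", categorical, ["zone"]),
])
print(preprocessor.fit_transform(df))
```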

### Why This Matters:
- **Consistency**: The exact same transformations are applied during both training and prediction
- **No Data Leakage**: Imputation and scaling statistics are learned from the training data only
- **Reproducibility**: The preprocessing pipeline is saved with the model, ensuring identical transformations
- **Robustness**: Missing values in new data are handled automatically using training statistics

## Report Contents

After training, an HTML report is generated containing:

### Performance Metrics
- **R² Score**: Coefficient of determination (typically between 0 and 1, higher is better). Measures the proportion of variance explained by the model
- **MSE**: Mean Squared Error. Average of squared differences between actual and predicted values
- **RMSE**: Root Mean Squared Error. Square root of MSE, in the same units as the target variable
- **MAE**: Mean Absolute Error. Average absolute difference between actual and predicted values
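
All four metrics come from the standard definitions, and scikit-learn provides them directly. A standalone sketch with made-up numbers, just to show how they relate:

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

y_true = np.array([250_000, 310_000, 405_000])  # actual prices (illustrative values)
y_pred = np.array([260_000, 300_000, 415_000])  # model predictions (illustrative values)

mse = mean_squared_error(y_true, y_pred)
print("R²  :", r2_score(y_true, y_pred))
print("MSE :", mse)
print("RMSE:", np.sqrt(mse))                    # same units as the target
print("MAE :", mean_absolute_error(y_true, y_pred))
```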

### Visualizations
1. **Actual vs Predicted Plot**: Scatter plot showing how well predictions match actual values. Points close to the diagonal line indicate good predictions
2. **Residuals Plot**: Shows prediction errors (actual - predicted) vs predicted values. Random scatter around zero indicates good model fit; patterns suggest issues
3. **Feature Coefficients**: Horizontal bar chart showing each feature's impact on the target. Green bars indicate positive correlation, red bars indicate negative correlation
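
For intuition, the coefficients chart described above can be reproduced with a few lines of matplotlib (a sketch with invented coefficient values; the tool's own plotting code lives in `visualizations.py`):

```python
import matplotlib.pyplot as plt

# Invented coefficients, purely for illustration
features = ["sqft", "location_score", "bedrooms", "age"]
coefs = [45_000.0, 22_000.0, 8_000.0, -15_000.0]

colors = ["green" if c >= 0 else "red" for c in coefs]
fig, ax = plt.subplots()
ax.barh(features, coefs, color=colors)  # horizontal bars, colored by sign
ax.axvline(0, linewidth=0.8)
ax.set_xlabel("Coefficient (effect on the target per standardized unit)")
ax.set_title("Feature Coefficients")
fig.savefig("coefficients_sketch.png", dpi=150)
```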

### Model Metadata
- Model type (Linear Regression)
- Training date and time
- Number of features used
- Imputation method (Mean for numeric, Most Frequent for categorical)
- Scaling method (Standard Scaler)

The report is **completely self-contained** (all images embedded as base64) and can be opened in any browser or shared via email.
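
Embedding the images as base64 data URIs is what makes the report a single portable file. The technique, in miniature (an illustrative sketch; the tool's own implementation lives in `report.py` / `visualizations.py` and the Jinja2 template):

```python
import base64
import io

import matplotlib.pyplot as plt

# Render a figure into an in-memory PNG and encode it as a data URI
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2, 4, 6])
buf = io.BytesIO()
fig.savefig(buf, format="png")
encoded = base64.b64encode(buf.getvalue()).decode("ascii")

# The data URI can be dropped straight into an <img> tag in the HTML template
html = f'<img src="data:image/png;base64,{encoded}" alt="example plot">'
print(html[:80])
```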

## Example End-to-End Workflow

Here's a complete example showing the full workflow from training to prediction to interpretation:

```bash
# 1. Train a model on your data
uv run python main.py train \
  --input sample_data.csv \
  --target price \
  --output-model house_model.joblib \
  --report house_report.html

# 2. Open the HTML report in your browser to evaluate performance
open house_report.html             # macOS
# or: xdg-open house_report.html   # Linux
# or: start house_report.html      # Windows

# 3. If satisfied with the model, use it to make predictions on new data
uv run python main.py predict \
  --model house_model.joblib \
  --input new_houses.csv \
  --output predicted_prices.csv

# 4. View the predictions
cat predicted_prices.csv           # or open in Excel/spreadsheet software
```

**Tips for Best Results:**
- Ensure your CSV has a header row with column names
- The target column should be numeric for regression
- Remove or encode categorical features with too many unique values (see the check sketched below)
- Check the HTML report to identify important features
- Use the residuals plot to diagnose model issues (non-linearity, heteroscedasticity)
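
As a quick pre-flight check for the tips above (numeric target, high-cardinality categoricals), a short pandas snippet (illustrative, not part of the CLI) can be run before training:

```python
import pandas as pd

df = pd.read_csv("sample_data.csv")

print(df.dtypes)  # the target column should be numeric
categorical = df.select_dtypes(include=["object", "category"])
# High unique-value counts here suggest columns to drop or re-encode
print(categorical.nunique().sort_values(ascending=False))
```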

## Project Structure

```
ml-cli-tool/
├── main.py                   # Entry point - runs the CLI
├── cli.py                    # CLI command definitions (train, predict)
├── data_loader.py            # Data loading and validation
├── preprocessing.py          # Preprocessing pipeline creation
├── model.py                  # Model training, saving, loading, prediction
├── visualizations.py         # Matplotlib plotting functions
├── report.py                 # HTML report generation
├── templates/
│   └── report_template.html  # Jinja2 template for reports
├── sample_data.csv           # Example dataset (house prices)
├── pyproject.toml            # Project configuration and dependencies
├── .python-version           # Pinned Python version (3.13)
└── README.md                 # This file
```

## Requirements

- Python 3.13+
- Dependencies (automatically installed by `uv sync`):
  - scikit-learn >= 1.3
  - pandas >= 2.0
  - matplotlib >= 3.7
  - jinja2 >= 3.1
  - click >= 8.1

## Troubleshooting

**Issue**: "Target column 'X' not found in CSV file"
- **Solution**: Check the column name spelling and ensure the CSV has a header row

**Issue**: "Expected features: [...], got: [...]. Missing: [...]"
- **Solution**: Ensure the prediction data has all the same feature columns as the training data (order doesn't matter)

**Issue**: All predictions are NaN
- **Solution**: Check that the input data contains valid numeric values (not all NaN after preprocessing)

**Issue**: Model file is corrupted or incompatible
- **Solution**: Retrain the model if the file was modified or created with an incompatible scikit-learn version

## License

See the LICENSE file for details.
