|
| 1 | +"""Report generation module for creating HTML reports with embedded visualizations. |
| 2 | +
|
| 3 | +This module provides functionality to generate self-contained HTML reports from |
| 4 | +trained models, including performance metrics, coefficients, and matplotlib |
| 5 | +visualizations. Reports use Jinja2 templates and embed images as base64-encoded |
| 6 | +strings for complete portability. |
| 7 | +""" |
| 8 | + |
| 9 | +import base64 |
| 10 | +from io import BytesIO |
| 11 | +from pathlib import Path |
| 12 | +from datetime import datetime |
| 13 | +from typing import Dict, List, Optional, Union, Any |
| 14 | +from matplotlib.figure import Figure |
| 15 | +from jinja2 import Environment, FileSystemLoader, Template |
| 16 | + |
| 17 | + |
| 18 | +def figure_to_base64(fig: Figure) -> str: |
| 19 | + """ |
| 20 | + Convert a matplotlib Figure to a base64-encoded PNG string. |
| 21 | + |
| 22 | + This function takes a matplotlib Figure object, saves it to an in-memory |
| 23 | + BytesIO buffer as PNG format, and encodes the resulting bytes as a base64 |
| 24 | + string. This allows the image to be embedded directly in HTML using a |
| 25 | + data URI, creating self-contained reports with no external file dependencies. |
| 26 | + |
| 27 | + Parameters |
| 28 | + ---------- |
| 29 | + fig : Figure |
| 30 | + A matplotlib Figure object to convert. Should be a complete figure |
| 31 | + ready for display (with all desired formatting, labels, etc.). |
| 32 | + |
| 33 | + Returns |
| 34 | + ------- |
| 35 | + str |
| 36 | + Base64-encoded string representation of the PNG image. This string |
| 37 | + can be used directly in HTML <img> tags with a data URI: |
| 38 | + `<img src="data:image/png;base64,{base64_string}">` |
| 39 | + |
| 40 | + Raises |
| 41 | + ------ |
| 42 | + TypeError |
| 43 | + If fig is not a matplotlib Figure object. |
| 44 | + RuntimeError |
| 45 | + If there's an error during PNG conversion or base64 encoding. |
| 46 | + |
| 47 | + Examples |
| 48 | + -------- |
| 49 | + >>> import matplotlib.pyplot as plt |
| 50 | + >>> fig, ax = plt.subplots() |
| 51 | + >>> ax.plot([1, 2, 3], [1, 4, 9]) |
| 52 | + >>> base64_str = figure_to_base64(fig) |
| 53 | + >>> html = f'<img src="data:image/png;base64,{base64_str}">' |
| 54 | + >>> plt.close(fig) |
| 55 | + |
| 56 | + Notes |
| 57 | + ----- |
| 58 | + - The figure is saved at 100 DPI by default for reasonable file size |
| 59 | + - The figure is not modified or closed by this function |
| 60 | + - Remember to close figures after use to prevent memory leaks |
| 61 | + - Typical base64 strings are 50-500KB depending on figure complexity |
| 62 | + - The bbox_inches='tight' parameter minimizes whitespace around the plot |
| 63 | + """ |
| 64 | + # Validate input |
| 65 | + if not isinstance(fig, Figure): |
| 66 | + raise TypeError( |
| 67 | + f"Expected matplotlib Figure object, got {type(fig).__name__} instead." |
| 68 | + ) |
| 69 | + |
| 70 | + try: |
| 71 | + # Create in-memory bytes buffer |
| 72 | + buffer = BytesIO() |
| 73 | + |
| 74 | + # Save figure to buffer as PNG |
| 75 | + # bbox_inches='tight' removes excess whitespace |
| 76 | + # dpi=100 provides good quality while keeping file size reasonable |
| 77 | + fig.savefig(buffer, format='png', bbox_inches='tight', dpi=100) |
| 78 | + |
| 79 | + # Get the bytes from buffer |
| 80 | + buffer.seek(0) |
| 81 | + image_bytes = buffer.read() |
| 82 | + |
| 83 | + # Encode as base64 |
| 84 | + base64_encoded = base64.b64encode(image_bytes).decode('utf-8') |
| 85 | + |
| 86 | + # Close the buffer |
| 87 | + buffer.close() |
| 88 | + |
| 89 | + return base64_encoded |
| 90 | + |
| 91 | + except Exception as e: |
| 92 | + raise RuntimeError( |
| 93 | + f"Failed to convert figure to base64: {str(e)}" |
| 94 | + ) from e |
| 95 | + |
| 96 | + |
| 97 | +def generate_report( |
| 98 | + metrics: Optional[Dict[str, float]], |
| 99 | + coefficients: Optional[List[Dict[str, Union[str, float]]]], |
| 100 | + figures: Optional[Dict[str, Figure]], |
| 101 | + metadata: Optional[Dict[str, Any]], |
| 102 | + output_path: Union[str, Path] |
| 103 | +) -> None: |
| 104 | + """ |
| 105 | + Generate a self-contained HTML report with model metrics and visualizations. |
| 106 | + |
| 107 | + This function creates a comprehensive HTML report using a Jinja2 template, |
| 108 | + embedding all visualizations as base64-encoded images. The resulting HTML |
| 109 | + file is completely self-contained and can be shared, emailed, or opened in |
| 110 | + any modern web browser without external dependencies. |
| 111 | + |
| 112 | + Parameters |
| 113 | + ---------- |
| 114 | + metrics : Optional[Dict[str, float]] |
| 115 | + Dictionary of performance metrics with keys: |
| 116 | + - 'r2': R² score (coefficient of determination) |
| 117 | + - 'mse': Mean Squared Error |
| 118 | + - 'rmse': Root Mean Squared Error |
| 119 | + - 'mae': Mean Absolute Error |
| 120 | + Can be None if metrics are not available. |
| 121 | + |
| 122 | + coefficients : Optional[List[Dict[str, Union[str, float]]]] |
| 123 | + List of dictionaries containing feature coefficients, where each dict has: |
| 124 | + - 'feature': Feature name (str) |
| 125 | + - 'value': Coefficient value (float) |
| 126 | + Example: [{'feature': 'age', 'value': 0.5}, {'feature': 'income', 'value': 1.2}] |
| 127 | + Can be None if coefficients are not available. |
| 128 | + |
| 129 | + figures : Optional[Dict[str, Figure]] |
| 130 | + Dictionary of matplotlib Figure objects with keys: |
| 131 | + - 'predictions': Actual vs predicted values plot |
| 132 | + - 'residuals': Residuals plot |
| 133 | + - 'coefficients': Feature coefficients bar chart |
| 134 | + Figures will be converted to base64 and embedded in HTML. |
| 135 | + Can be None if visualizations are not available. |
| 136 | + |
| 137 | + metadata : Optional[Dict[str, Any]] |
| 138 | + Dictionary containing model metadata with optional keys: |
| 139 | + - 'model_type': Type of model (default: 'Linear Regression') |
| 140 | + - 'training_date': Date model was trained |
| 141 | + - 'feature_count': Number of features used |
| 142 | + - 'imputation_method': Method used for handling missing values |
| 143 | + - 'scaling_method': Method used for feature scaling |
| 144 | + Can be None, in which case defaults will be used. |
| 145 | + |
| 146 | + output_path : Union[str, Path] |
| 147 | + Path where the HTML report should be saved. Parent directories |
| 148 | + will be created if they don't exist. Example: 'reports/model_report.html' |
| 149 | + |
| 150 | + Raises |
| 151 | + ------ |
| 152 | + FileNotFoundError |
| 153 | + If the Jinja2 template file cannot be found. |
| 154 | + PermissionError |
| 155 | + If the output path is not writable. |
| 156 | + RuntimeError |
| 157 | + If there's an error during template rendering or file writing. |
| 158 | + |
| 159 | + Examples |
| 160 | + -------- |
| 161 | + >>> from visualizations import create_predictions_plot |
| 162 | + >>> import numpy as np |
| 163 | + >>> |
| 164 | + >>> # Prepare data |
| 165 | + >>> y_actual = np.array([1, 2, 3, 4, 5]) |
| 166 | + >>> y_pred = np.array([1.1, 2.2, 2.9, 4.1, 4.8]) |
| 167 | + >>> |
| 168 | + >>> # Create visualizations |
| 169 | + >>> fig = create_predictions_plot(y_actual, y_pred) |
| 170 | + >>> figures = {'predictions': fig} |
| 171 | + >>> |
| 172 | + >>> # Prepare metrics |
| 173 | + >>> metrics = {'r2': 0.95, 'mse': 0.1, 'rmse': 0.316, 'mae': 0.2} |
| 174 | + >>> |
| 175 | + >>> # Prepare coefficients |
| 176 | + >>> coefficients = [ |
| 177 | + ... {'feature': 'age', 'value': 0.5}, |
| 178 | + ... {'feature': 'income', 'value': 1.2} |
| 179 | + ... ] |
| 180 | + >>> |
| 181 | + >>> # Prepare metadata |
| 182 | + >>> metadata = { |
| 183 | + ... 'model_type': 'Linear Regression', |
| 184 | + ... 'training_date': '2024-01-15', |
| 185 | + ... 'feature_count': 2, |
| 186 | + ... 'imputation_method': 'Mean', |
| 187 | + ... 'scaling_method': 'Standard Scaler' |
| 188 | + ... } |
| 189 | + >>> |
| 190 | + >>> # Generate report |
| 191 | + >>> generate_report(metrics, coefficients, figures, metadata, 'report.html') |
| 192 | + >>> |
| 193 | + >>> # Clean up |
| 194 | + >>> import matplotlib.pyplot as plt |
| 195 | + >>> plt.close(fig) |
| 196 | + |
| 197 | + Notes |
| 198 | + ----- |
| 199 | + - All visualizations are embedded as base64-encoded PNG images |
| 200 | + - The report is completely self-contained (no external file dependencies) |
| 201 | + - Missing data (None values) is handled gracefully with 'N/A' placeholders |
| 202 | + - The template includes responsive CSS for mobile and desktop viewing |
| 203 | + - Typical report file size is <2MB for 3 embedded plots |
| 204 | + - Parent directories are created automatically if they don't exist |
| 205 | + - The function closes no figures - caller is responsible for cleanup |
| 206 | + """ |
| 207 | + # Convert output_path to Path object for easier handling |
| 208 | + output_path = Path(output_path) |
| 209 | + |
| 210 | + # Create parent directories if they don't exist |
| 211 | + output_path.parent.mkdir(parents=True, exist_ok=True) |
| 212 | + |
| 213 | + try: |
| 214 | + # Set up Jinja2 environment |
| 215 | + # The template should be in the 'templates' directory relative to this file |
| 216 | + template_dir = Path(__file__).parent / 'templates' |
| 217 | + |
| 218 | + if not template_dir.exists(): |
| 219 | + raise FileNotFoundError( |
| 220 | + f"Templates directory not found at {template_dir}. " |
| 221 | + "Please ensure the 'templates' directory exists with report_template.html." |
| 222 | + ) |
| 223 | + |
| 224 | + env = Environment(loader=FileSystemLoader(str(template_dir))) |
| 225 | + |
| 226 | + # Load the template |
| 227 | + try: |
| 228 | + template = env.get_template('report_template.html') |
| 229 | + except Exception as e: |
| 230 | + raise FileNotFoundError( |
| 231 | + f"Could not load template 'report_template.html' from {template_dir}: {str(e)}" |
| 232 | + ) from e |
| 233 | + |
| 234 | + # Convert figures to base64 if provided |
| 235 | + encoded_figures = {} |
| 236 | + if figures: |
| 237 | + for key, fig in figures.items(): |
| 238 | + if fig is not None: |
| 239 | + try: |
| 240 | + encoded_figures[key] = figure_to_base64(fig) |
| 241 | + except Exception as e: |
| 242 | + # Log warning but continue - allow partial reports |
| 243 | + print(f"Warning: Failed to encode figure '{key}': {str(e)}") |
| 244 | + |
| 245 | + # Prepare metadata with defaults |
| 246 | + if metadata is None: |
| 247 | + metadata = {} |
| 248 | + |
| 249 | + # Ensure we have default values for missing metadata |
| 250 | + metadata_with_defaults = { |
| 251 | + 'model_type': metadata.get('model_type', 'Linear Regression'), |
| 252 | + 'training_date': metadata.get('training_date', 'N/A'), |
| 253 | + 'feature_count': metadata.get('feature_count', 'N/A'), |
| 254 | + 'imputation_method': metadata.get('imputation_method', 'Mean (numeric)'), |
| 255 | + 'scaling_method': metadata.get('scaling_method', 'Standard Scaler'), |
| 256 | + } |
| 257 | + |
| 258 | + # Generate timestamp for report |
| 259 | + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
| 260 | + |
| 261 | + # Render the template |
| 262 | + html_content = template.render( |
| 263 | + timestamp=timestamp, |
| 264 | + metrics=metrics, |
| 265 | + coefficients=coefficients, |
| 266 | + figures=encoded_figures if encoded_figures else None, |
| 267 | + metadata=metadata_with_defaults |
| 268 | + ) |
| 269 | + |
| 270 | + # Write to output file |
| 271 | + try: |
| 272 | + with open(output_path, 'w', encoding='utf-8') as f: |
| 273 | + f.write(html_content) |
| 274 | + except PermissionError as e: |
| 275 | + raise PermissionError( |
| 276 | + f"Permission denied: Cannot write to {output_path}. " |
| 277 | + "Please check file permissions." |
| 278 | + ) from e |
| 279 | + except Exception as e: |
| 280 | + raise RuntimeError( |
| 281 | + f"Failed to write report to {output_path}: {str(e)}" |
| 282 | + ) from e |
| 283 | + |
| 284 | + except (FileNotFoundError, PermissionError, RuntimeError): |
| 285 | + # Re-raise known exceptions |
| 286 | + raise |
| 287 | + except Exception as e: |
| 288 | + # Catch any other unexpected errors |
| 289 | + raise RuntimeError( |
| 290 | + f"Unexpected error while generating report: {str(e)}" |
| 291 | + ) from e |
0 commit comments