Skip to content

Commit a2316b1

Browse files
authored
Merge pull request #61 from turintech/feat/matplotlib-visualization-plots
Add visualizations module with predictions and residuals plotting functions
2 parents 8b0396f + f466ac8 commit a2316b1

File tree

1 file changed

+315
-0
lines changed

1 file changed

+315
-0
lines changed

visualizations.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
"""Visualization module for model evaluation and feature analysis.
2+
3+
This module provides matplotlib-based visualization functions for evaluating
4+
linear regression model performance and understanding feature importance. All
5+
functions leave their inputs unmodified and return pyplot-managed Figure objects (close them with plt.close(fig) when done) for
6+
flexible use in reports or interactive displays.
7+
"""
8+
9+
import numpy as np
10+
import pandas as pd
11+
import matplotlib.pyplot as plt
12+
from matplotlib.figure import Figure
13+
from typing import Union, List
14+
15+
16+
def create_predictions_plot(
    y_actual: Union[np.ndarray, pd.Series],
    y_predicted: Union[np.ndarray, pd.Series]
) -> Figure:
    """
    Create a scatter plot comparing actual vs predicted target values.

    Actual values are placed on the x-axis and predicted values on the y-axis.
    A dashed y=x diagonal ("ideal fit") is drawn across the combined value range
    so that the distance of each point from the line shows its prediction error.

    Parameters
    ----------
    y_actual : Union[np.ndarray, pd.Series]
        Actual (true) target values from the dataset.
    y_predicted : Union[np.ndarray, pd.Series]
        Predicted target values from the model.

    Returns
    -------
    Figure
        Matplotlib Figure object containing the predictions scatter plot.
        The figure can be saved, displayed, or embedded in reports.

    Raises
    ------
    ValueError
        If y_actual and y_predicted have different lengths or are empty
        (raised by ``_validate_arrays``).

    Examples
    --------
    >>> import numpy as np
    >>> y_actual = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    >>> y_predicted = np.array([1.1, 2.2, 2.9, 4.1, 4.8])
    >>> fig = create_predictions_plot(y_actual, y_predicted)
    >>> fig.savefig('predictions.png')
    >>> plt.close(fig)

    Notes
    -----
    - Semi-transparent points (alpha=0.6) help visualize overlapping predictions
    - The diagonal line represents perfect predictions (y_actual = y_predicted)
    - The diagonal's end points ignore NaN entries, so the line is still drawn
      when some values are missing
    - Figure size is set to 10x6 inches for readability
    - Use plt.close(fig) after use to prevent memory leaks
    """
    # Convert to numpy arrays for consistent handling
    y_actual = np.asarray(y_actual)
    y_predicted = np.asarray(y_predicted)

    # Validate before creating the figure so bad input never leaves an
    # orphaned figure registered with pyplot.
    _validate_arrays(y_actual, y_predicted)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Scatter plot with semi-transparent points
    ax.scatter(y_actual, y_predicted, alpha=0.6, color='steelblue',
               edgecolors='navy', linewidth=0.5, label='Predictions')

    # Ideal fit line (y=x diagonal). Plain min()/max() would turn into NaN as
    # soon as either input holds a single NaN, making the line vanish; the
    # nan-aware reductions skip missing values instead.
    min_val = np.nanmin([np.nanmin(y_actual), np.nanmin(y_predicted)])
    max_val = np.nanmax([np.nanmax(y_actual), np.nanmax(y_predicted)])
    ax.plot([min_val, max_val], [min_val, max_val],
            'r--', linewidth=2, label='Ideal Fit (y=x)')

    # Labels and title
    ax.set_xlabel('Actual Values', fontsize=12, fontweight='bold')
    ax.set_ylabel('Predicted Values', fontsize=12, fontweight='bold')
    ax.set_title('Actual vs Predicted Values', fontsize=14, fontweight='bold', pad=20)

    # Legend and grid
    ax.legend(loc='upper left', fontsize=10)
    ax.grid(True, alpha=0.3, linestyle='--')

    # Equal aspect keeps the y=x diagonal at 45 degrees
    ax.set_aspect('equal', adjustable='box')

    # Prevent label cutoff; act on this figure explicitly rather than on
    # whatever pyplot currently considers the active figure.
    fig.tight_layout()

    return fig
98+
99+
100+
def create_residuals_plot(
    y_predicted: Union[np.ndarray, pd.Series],
    residuals: Union[np.ndarray, pd.Series]
) -> Figure:
    """
    Plot residuals against predicted values to assess model fit quality.

    Each point is one observation, positioned at (predicted value, residual).
    A dashed horizontal reference line is drawn at y=0. Residuals scattered
    randomly around that line suggest a good fit; visible structure suggests
    systematic error.

    Parameters
    ----------
    y_predicted : Union[np.ndarray, pd.Series]
        Predicted target values from the model.
    residuals : Union[np.ndarray, pd.Series]
        Residual values calculated as (actual - predicted).

    Returns
    -------
    Figure
        Matplotlib Figure object containing the residuals plot.
        The figure can be saved, displayed, or embedded in reports.

    Raises
    ------
    ValueError
        If y_predicted and residuals have different lengths or are empty
        (raised by ``_validate_arrays``).

    Examples
    --------
    >>> import numpy as np
    >>> y_predicted = np.array([1.1, 2.2, 2.9, 4.1, 4.8])
    >>> residuals = np.array([-0.1, -0.2, 0.1, -0.1, 0.2])
    >>> fig = create_residuals_plot(y_predicted, residuals)
    >>> fig.savefig('residuals.png')
    >>> plt.close(fig)

    Notes
    -----
    - Random scatter around y=0 indicates good model fit (homoscedasticity)
    - Patterns (e.g., funnel shape) suggest heteroscedasticity or non-linearity
    - Semi-transparent points (alpha=0.6) help visualize overlapping residuals
    - Coral markers distinguish this plot from the predictions plot
    - Figure size is set to 10x6 inches for readability
    - Use plt.close(fig) after use to prevent memory leaks
    """
    # Normalise both inputs to ndarrays, then check them before any drawing
    predicted_arr = np.asarray(y_predicted)
    residual_arr = np.asarray(residuals)
    _validate_arrays(predicted_arr, residual_arr)

    figure, axes = plt.subplots(figsize=(10, 6))

    # Residual points first, then the zero reference line (legend order
    # follows the order in which the artists are added).
    axes.scatter(
        predicted_arr,
        residual_arr,
        alpha=0.6,
        color='coral',
        edgecolors='darkred',
        linewidth=0.5,
        label='Residuals',
    )
    axes.axhline(
        y=0,
        color='black',
        linestyle='--',
        linewidth=2,
        label='Zero Line',
    )

    # Shared text styling for the two axis labels
    axis_label_style = {'fontsize': 12, 'fontweight': 'bold'}
    axes.set_xlabel('Predicted Values', **axis_label_style)
    axes.set_ylabel('Residuals (Actual - Predicted)', **axis_label_style)
    axes.set_title('Residuals Plot', fontsize=14, fontweight='bold', pad=20)

    axes.legend(loc='upper left', fontsize=10)
    axes.grid(True, alpha=0.3, linestyle='--')

    # Prevent label cutoff
    figure.tight_layout()

    return figure
179+
180+
181+
def create_coefficients_plot(
    feature_names: List[str],
    coefficients: Union[np.ndarray, pd.Series, List[float]]
) -> Figure:
    """
    Create a bar chart showing feature importance based on model coefficients.

    Draws a horizontal bar chart of the coefficients ordered by absolute value,
    with the largest magnitude at the top. Bars for positive coefficients are
    green and all others are red, so both the strength and the direction of each
    feature's effect can be read at a glance.

    Parameters
    ----------
    feature_names : List[str]
        Names of the features corresponding to each coefficient.
    coefficients : Union[np.ndarray, pd.Series, List[float]]
        One-dimensional model coefficients, one per feature
        (e.g., from LinearRegression.coef_).

    Returns
    -------
    Figure
        Matplotlib Figure object containing the coefficients bar chart.
        The figure can be saved, displayed, or embedded in reports.

    Raises
    ------
    ValueError
        If coefficients is not one-dimensional, if feature_names and
        coefficients have different lengths, or if they are empty.

    Examples
    --------
    >>> feature_names = ['age', 'income', 'education', 'experience']
    >>> coefficients = [0.5, 1.2, -0.3, 0.8]
    >>> fig = create_coefficients_plot(feature_names, coefficients)
    >>> fig.savefig('coefficients.png')
    >>> plt.close(fig)

    Notes
    -----
    - Bars are sorted by absolute coefficient value (most important at top)
    - Green bars mark positive coefficients
    - Red bars mark non-positive coefficients (a zero coefficient has no
      visible bar)
    - Horizontal bar chart makes long feature names more readable
    - Figure size is set to 10x6 inches for readability
    - Use plt.close(fig) after use to prevent memory leaks
    """
    # Convert coefficients to numpy array for consistent handling
    coefficients = np.asarray(coefficients)

    # A scalar (0-d) input has no len() and a multi-dimensional array cannot
    # form a single DataFrame column; reject both with one clear error.
    if coefficients.ndim != 1:
        raise ValueError(
            "coefficients must be one-dimensional. "
            f"Got {coefficients.ndim} dimensions."
        )

    if len(feature_names) != len(coefficients):
        raise ValueError(
            f"Length mismatch: feature_names has {len(feature_names)} elements "
            f"but coefficients has {len(coefficients)} elements. They must match."
        )

    if len(feature_names) == 0:
        raise ValueError("feature_names and coefficients cannot be empty.")

    # Create DataFrame for easier sorting
    coef_df = pd.DataFrame({
        'feature': feature_names,
        'coefficient': coefficients
    })

    # Sort ascending by magnitude: barh draws the first row at the bottom, so
    # the largest-magnitude coefficient ends up at the top of the chart.
    coef_df['abs_coefficient'] = np.abs(coef_df['coefficient'])
    coef_df = coef_df.sort_values('abs_coefficient', ascending=True)

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(10, 6))

    # Colour each bar by the sign of its coefficient
    colors = ['green' if c > 0 else 'red' for c in coef_df['coefficient']]

    # Horizontal bar chart
    ax.barh(coef_df['feature'], coef_df['coefficient'], color=colors,
            alpha=0.7, edgecolor='black', linewidth=0.8)

    # Vertical line at x=0 separates positive from negative bars
    ax.axvline(x=0, color='black', linestyle='-', linewidth=1.5)

    # Labels and title
    ax.set_xlabel('Coefficient Value', fontsize=12, fontweight='bold')
    ax.set_ylabel('Feature Name', fontsize=12, fontweight='bold')
    ax.set_title('Feature Importance (Model Coefficients)', fontsize=14,
                 fontweight='bold', pad=20)

    # Grid on the value axis only
    ax.grid(True, alpha=0.3, linestyle='--', axis='x')

    # The bars carry no labels of their own, so build the legend from proxy
    # patches that match the two bar colours.
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='green', alpha=0.7, edgecolor='black', label='Positive Impact'),
        Patch(facecolor='red', alpha=0.7, edgecolor='black', label='Negative Impact')
    ]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=10)

    # Prevent label cutoff; act on this figure explicitly rather than on
    # whatever pyplot currently considers the active figure.
    fig.tight_layout()

    return fig
286+
287+
288+
def _validate_arrays(
289+
arr1: np.ndarray,
290+
arr2: np.ndarray
291+
) -> None:
292+
"""
293+
Validate that two arrays have the same length and are non-empty.
294+
295+
Parameters
296+
----------
297+
arr1 : np.ndarray
298+
First array to validate.
299+
arr2 : np.ndarray
300+
Second array to validate.
301+
302+
Raises
303+
------
304+
ValueError
305+
If arrays have different lengths or are empty.
306+
"""
307+
if len(arr1) == 0 or len(arr2) == 0:
308+
raise ValueError(
309+
f"Arrays cannot be empty. Got lengths: {len(arr1)} and {len(arr2)}."
310+
)
311+
312+
if len(arr1) != len(arr2):
313+
raise ValueError(
314+
f"Arrays must have the same length. Got {len(arr1)} and {len(arr2)}."
315+
)

0 commit comments

Comments
 (0)