# Linear Regression: Assumptions and Diagnostics

---

## Learning Objectives

By the end of this notebook, you will be able to:

- List and explain the 5 key assumptions of linear regression
- Create and interpret diagnostic plots: residual vs fitted, Q-Q plots
- Detect multicollinearity using correlation matrices and VIF
- Recognize heteroscedasticity and understand its consequences
- Know what to do when assumptions are violated

## Prerequisites

- Completed Notebook 01 (Linear Regression basics with sklearn)
- Basic understanding of residuals and MSE
- Familiarity with correlation

## Table of Contents

1. [The 5 Key Assumptions](#1-the-5-key-assumptions)
2. [Setup: Generate Well-Behaved Data](#2-setup-generate-well-behaved-data)
3. [Checking Linearity](#3-checking-linearity)
4. [Checking Normality of Residuals](#4-checking-normality-of-residuals)
5. [Checking Homoscedasticity](#5-checking-homoscedasticity)
6. [Checking Multicollinearity](#6-checking-multicollinearity)
7. [When Assumptions Are Violated: What to Do](#7-when-assumptions-are-violated-what-to-do)
8. [Common Mistakes](#8-common-mistakes)
9. [Exercise](#9-exercise)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

np.random.seed(42)
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (8, 5)

---

## 1. The 5 Key Assumptions

Linear regression results are reliable **only if** these assumptions hold (approximately):

| # | Assumption | What It Means | How to Check |
|---|---|---|---|
| 1 | **Linearity** | The relationship between X and y is linear | Residual vs fitted plot |
| 2 | **Independence** | Errors are independent of each other (no autocorrelation or clustering) | Study design / Durbin-Watson test |
| 3 | **Normality of residuals** | Residuals follow a normal distribution | Q-Q plot, Shapiro-Wilk test |
| 4 | **Homoscedasticity** | Residual variance is constant across all fitted values | Residual vs fitted plot |
| 5 | **No multicollinearity** | Features are not highly correlated with each other | Correlation matrix, VIF |

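**Important:** Assumptions 1-4 are about the model's **errors**, which we assess through the **residuals**, not about the raw distributions of X or y, so we check them **after** fitting the model. Multicollinearity (assumption 5) is a property of the features alone and can be checked before fitting.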
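The Durbin-Watson check named in the table can be sketched by hand with NumPy. This is a minimal illustration, not a full test with critical values: the helper name `durbin_watson` and the AR(1) coefficient of 0.8 are assumptions made for the example. Values near 2 suggest no first-order autocorrelation, while values well below 2 suggest positive autocorrelation. The statistic is only meaningful when observations have a natural order, such as time.

```python
import numpy as np

def durbin_watson(residuals):
    """Sum of squared successive differences divided by the sum of squares."""
    residuals = np.asarray(residuals, dtype=float)
    diff = np.diff(residuals)
    return np.sum(diff ** 2) / np.sum(residuals ** 2)

rng = np.random.default_rng(0)

# Independent errors: the statistic should land near 2
e_indep = rng.normal(size=500)

# Positively autocorrelated errors: AR(1) with coefficient 0.8 (illustrative choice)
e_ar = np.zeros(500)
for t in range(1, 500):
    e_ar[t] = 0.8 * e_ar[t - 1] + rng.normal()

dw_indep = durbin_watson(e_indep)
dw_ar = durbin_watson(e_ar)
print(f"Independent errors:    DW = {dw_indep:.2f}")
print(f"Autocorrelated errors: DW = {dw_ar:.2f}")
```

In practice you would pass the residuals of a fitted model, kept in their original (for example, chronological) order.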

---

## 2. Setup: Generate Well-Behaved Data

First, let's create a dataset that satisfies all assumptions so we know what "good" diagnostics look like.

In [None]:
# Well-behaved data: linear relationship, normal errors, constant variance
np.random.seed(42)
n = 200

X_good = np.random.randn(n, 2)
y_good = 3 + 2 * X_good[:, 0] - 1.5 * X_good[:, 1] + np.random.randn(n) * 0.8

# Fit model
model_good = LinearRegression()
model_good.fit(X_good, y_good)
y_pred_good = model_good.predict(X_good)
residuals_good = y_good - y_pred_good

print(f"R2: {model_good.score(X_good, y_good):.4f}")
print(f"Coefficients: {model_good.coef_}")
print(f"Intercept: {model_good.intercept_:.4f}")

---

## 3. Checking Linearity

**Residual vs Fitted plot:** If the relationship is linear, residuals should scatter randomly around zero with no visible pattern.

In [None]:
# --- Good case: linear relationship ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Good: residuals vs fitted
axes[0].scatter(y_pred_good, residuals_good, alpha=0.5, edgecolors="k", linewidths=0.5)
axes[0].axhline(y=0, color="r", linestyle="--", linewidth=1.5)
axes[0].set_xlabel("Fitted values")
axes[0].set_ylabel("Residuals")
axes[0].set_title("GOOD: Residual vs Fitted (Linear Data)")

# --- Bad case: nonlinear relationship ---
np.random.seed(42)
X_nonlinear = np.random.uniform(0, 5, (n, 1))
y_nonlinear = 2 + 3 * X_nonlinear.flatten() ** 2 + np.random.randn(n) * 3

model_bad_lin = LinearRegression()
model_bad_lin.fit(X_nonlinear, y_nonlinear)
y_pred_bad_lin = model_bad_lin.predict(X_nonlinear)
residuals_bad_lin = y_nonlinear - y_pred_bad_lin

axes[1].scatter(y_pred_bad_lin, residuals_bad_lin, alpha=0.5, edgecolors="k",
                linewidths=0.5, color="salmon")
axes[1].axhline(y=0, color="r", linestyle="--", linewidth=1.5)
axes[1].set_xlabel("Fitted values")
axes[1].set_ylabel("Residuals")
axes[1].set_title("BAD: Residual vs Fitted (Nonlinear Data)")

plt.tight_layout()
plt.show()

print("Left: Random scatter around zero = GOOD (linearity holds)")
print("Right: Curved pattern = BAD (linearity violated)")
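The visual check above can be backed by a rough numeric summary: sort the residuals by fitted value, split them into bins, and look at the mean residual per bin. The sketch below assumes freshly simulated data and two hypothetical helpers (`fit_line_residuals`, `binned_residual_means`); the choice of five bins is arbitrary. Bin means that stay close to zero are consistent with linearity, while a positive-negative-positive pattern mirrors the curved residual plot.

```python
import numpy as np

def fit_line_residuals(x, y):
    """Fit y = a + b*x by least squares; return fitted values and residuals."""
    A = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    fitted = A @ coef
    return fitted, y - fitted

def binned_residual_means(fitted, residuals, n_bins=5):
    """Mean residual within equal-count bins of the fitted values."""
    order = np.argsort(fitted)
    chunks = np.array_split(residuals[order], n_bins)
    return np.array([chunk.mean() for chunk in chunks])

rng = np.random.default_rng(0)
x = rng.uniform(0, 5, 200)

y_lin = 2 + 3 * x + rng.normal(scale=3, size=200)        # truly linear
y_quad = 2 + 3 * x ** 2 + rng.normal(scale=3, size=200)  # quadratic, like the BAD panel

means_lin = binned_residual_means(*fit_line_residuals(x, y_lin))
means_quad = binned_residual_means(*fit_line_residuals(x, y_quad))

print("Linear data   :", np.round(means_lin, 2))
print("Quadratic data:", np.round(means_quad, 2))
```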

---

## 4. Checking Normality of Residuals

**Q-Q plot:** Compares residual quantiles against theoretical normal quantiles. Points should fall on the diagonal line.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Good: normal residuals from the well-behaved model
stats.probplot(residuals_good, dist="norm", plot=axes[0])
axes[0].set_title("GOOD: Q-Q Plot (Normal Residuals)")
axes[0].get_lines()[0].set_markerfacecolor("steelblue")
axes[0].get_lines()[0].set_alpha(0.6)

# Bad: skewed residuals
np.random.seed(42)
X_skew = np.random.randn(n, 1)
# Use exponential errors (right-skewed) instead of normal
y_skew = 2 + 3 * X_skew.flatten() + np.random.exponential(2, n)

model_skew = LinearRegression()
model_skew.fit(X_skew, y_skew)
residuals_skew = y_skew - model_skew.predict(X_skew)

stats.probplot(residuals_skew, dist="norm", plot=axes[1])
axes[1].set_title("BAD: Q-Q Plot (Skewed Residuals)")
axes[1].get_lines()[0].set_markerfacecolor("salmon")
axes[1].get_lines()[0].set_alpha(0.6)

plt.tight_layout()
plt.show()

# Shapiro-Wilk test
stat_good, p_good = stats.shapiro(residuals_good)
stat_skew, p_skew = stats.shapiro(residuals_skew)

print("Shapiro-Wilk test (H0: residuals are normal):")
print(f"  Good data: W={stat_good:.4f}, p={p_good:.4f}  {'-> No evidence against normality' if p_good > 0.05 else '-> Reject normality'}")
print(f"  Skew data: W={stat_skew:.4f}, p={p_skew:.4f}  {'-> No evidence against normality' if p_skew > 0.05 else '-> Reject normality'}")

In [None]:
# Histogram of residuals (supplementary check)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].hist(residuals_good, bins=25, edgecolor="k", alpha=0.7, color="steelblue", density=True)
x_range = np.linspace(residuals_good.min(), residuals_good.max(), 100)
axes[0].plot(x_range, stats.norm.pdf(x_range, residuals_good.mean(), residuals_good.std()),
             "r-", linewidth=2, label="Normal fit")
axes[0].set_title("GOOD: Residual Distribution")
axes[0].set_xlabel("Residual")
axes[0].legend()

axes[1].hist(residuals_skew, bins=25, edgecolor="k", alpha=0.7, color="salmon", density=True)
x_range2 = np.linspace(residuals_skew.min(), residuals_skew.max(), 100)
axes[1].plot(x_range2, stats.norm.pdf(x_range2, residuals_skew.mean(), residuals_skew.std()),
             "r-", linewidth=2, label="Normal fit")
axes[1].set_title("BAD: Residual Distribution (Right-Skewed)")
axes[1].set_xlabel("Residual")
axes[1].legend()

plt.tight_layout()
plt.show()
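The straightness of a Q-Q plot can also be summarized with a single number. When called without a plot target, `scipy.stats.probplot` returns the plotted points together with a least-squares fit, including the correlation coefficient `r` between the ordered sample and the theoretical quantiles. The sketch below uses simulated samples rather than the notebook's residuals, and treats `r` as an informal summary, not a formal test.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
normal_sample = rng.normal(size=200)
skewed_sample = rng.exponential(scale=2, size=200)  # right-skewed, like the BAD example

# probplot returns ((theoretical_quantiles, ordered_values), (slope, intercept, r))
(_, _), (_, _, r_normal) = stats.probplot(normal_sample, dist="norm")
(_, _), (_, _, r_skewed) = stats.probplot(skewed_sample, dist="norm")

print(f"Q-Q correlation, normal sample: {r_normal:.4f}")
print(f"Q-Q correlation, skewed sample: {r_skewed:.4f}")
print(f"Sample skewness: normal={stats.skew(normal_sample):.2f}, "
      f"skewed={stats.skew(skewed_sample):.2f}")
```

A correlation very close to 1 means the points hug the reference line; a visibly lower value corresponds to the bending seen in the BAD panel.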

---

## 5. Checking Homoscedasticity

**Homoscedasticity** means the variance of residuals is constant across all levels of the predicted values.

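**Heteroscedasticity** (the violation) means the variance changes, often increasing with larger fitted values. The OLS coefficient estimates remain unbiased, but heteroscedasticity:
- Makes the usual standard errors unreliable
- Invalidates the confidence intervals and p-values built on them
- Makes the coefficient estimates inefficient (no longer minimum-variance)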

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Good: homoscedastic residuals (from well-behaved data)
axes[0].scatter(y_pred_good, residuals_good, alpha=0.5, edgecolors="k", linewidths=0.5)
axes[0].axhline(y=0, color="r", linestyle="--", linewidth=1.5)
axes[0].set_xlabel("Fitted values")
axes[0].set_ylabel("Residuals")
axes[0].set_title("GOOD: Constant Variance (Homoscedastic)")

# Bad: heteroscedastic data (variance increases with X)
np.random.seed(42)
X_hetero = np.random.uniform(1, 10, (n, 1))
# Error standard deviation scales with X (so the variance scales with X**2)
noise_hetero = np.random.randn(n) * (0.5 * X_hetero.flatten())
y_hetero = 2 + 3 * X_hetero.flatten() + noise_hetero

model_hetero = LinearRegression()
model_hetero.fit(X_hetero, y_hetero)
y_pred_hetero = model_hetero.predict(X_hetero)
residuals_hetero = y_hetero - y_pred_hetero

axes[1].scatter(y_pred_hetero, residuals_hetero, alpha=0.5, edgecolors="k",
                linewidths=0.5, color="salmon")
axes[1].axhline(y=0, color="r", linestyle="--", linewidth=1.5)
axes[1].set_xlabel("Fitted values")
axes[1].set_ylabel("Residuals")
axes[1].set_title("BAD: Increasing Variance (Heteroscedastic)")

plt.tight_layout()
plt.show()

print("Left: Residuals form a uniform band — GOOD")
print("Right: Residuals fan out (funnel shape) — BAD (heteroscedasticity)")
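The funnel shape can also be checked with a formal test. The sketch below hand-rolls the studentized (Koenker) form of the Breusch-Pagan test: regress the squared residuals on the features and compare n times the resulting R-squared to a chi-squared distribution. The helper names `ols_residuals` and `breusch_pagan` are assumptions for this example, and the data are simulated to resemble the two panels above. A small p-value is evidence against constant variance.

```python
import numpy as np
from scipy import stats

def ols_residuals(x, y):
    """Residuals from a least-squares fit of y on a single feature x."""
    A = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return y - A @ coef

def breusch_pagan(X, residuals):
    """Studentized Breusch-Pagan: LM = n * R^2 of residuals**2 regressed on X."""
    X = np.asarray(X, dtype=float).reshape(len(residuals), -1)
    n, k = X.shape
    A = np.column_stack([np.ones(n), X])
    u2 = residuals ** 2
    coef, *_ = np.linalg.lstsq(A, u2, rcond=None)
    r2 = 1 - np.sum((u2 - A @ coef) ** 2) / np.sum((u2 - u2.mean()) ** 2)
    lm = n * r2
    return lm, stats.chi2.sf(lm, df=k)  # H0: constant variance

rng = np.random.default_rng(0)
x = rng.uniform(1, 10, 200)
y_homo = 2 + 3 * x + rng.normal(scale=2, size=200)     # constant error spread
y_hetero = 2 + 3 * x + rng.normal(size=200) * 0.5 * x  # spread grows with x

lm_homo, p_homo = breusch_pagan(x, ols_residuals(x, y_homo))
lm_hetero, p_hetero = breusch_pagan(x, ols_residuals(x, y_hetero))

print(f"Homoscedastic data:   LM={lm_homo:.2f}, p={p_homo:.4f}")
print(f"Heteroscedastic data: LM={lm_hetero:.2f}, p={p_hetero:.2e}")
```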

---

## 6. Checking Multicollinearity

**Multicollinearity** occurs when features are highly correlated with each other. This:
- Makes individual coefficient estimates unstable
- Inflates standard errors
- Makes it hard to determine which feature is driving the prediction

**Detection methods:**
- Correlation matrix (pairwise)
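- **Variance Inflation Factor (VIF):** measures how much the variance of a coefficient is inflated due to collinearity. For feature $j$, $\text{VIF}_j = 1 / (1 - R_j^2)$, where $R_j^2$ comes from regressing feature $j$ on all the other features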
  - VIF = 1: no collinearity
  - VIF > 5: moderate concern
  - VIF > 10: serious multicollinearity

In [None]:
# Generate data with multicollinearity
np.random.seed(42)
x1 = np.random.randn(n)
x2 = x1 + np.random.randn(n) * 0.1  # x2 is almost identical to x1
x3 = np.random.randn(n)              # x3 is independent

X_collinear = np.column_stack([x1, x2, x3])
y_collinear = 3 + 2 * x1 - 1 * x3 + np.random.randn(n) * 0.5

df_collinear = pd.DataFrame(X_collinear, columns=["x1", "x2", "x3"])

# Correlation matrix
plt.figure(figsize=(6, 5))
corr = df_collinear.corr()
sns.heatmap(corr, annot=True, cmap="coolwarm", vmin=-1, vmax=1,
            square=True, fmt=".3f", linewidths=0.5)
plt.title("Correlation Matrix (x1 and x2 are highly correlated)")
plt.tight_layout()
plt.show()

In [None]:
# Compute VIF for each feature
def compute_vif(X_df):
    """Compute Variance Inflation Factor for each feature."""
    vif_data = []
    for col in X_df.columns:
        # Regress this feature on all other features
        other_cols = [c for c in X_df.columns if c != col]
        X_other = X_df[other_cols].values
        y_feat = X_df[col].values

        model = LinearRegression()
        model.fit(X_other, y_feat)
        r2 = model.score(X_other, y_feat)

        vif = 1 / (1 - r2) if r2 < 1 else float("inf")
        vif_data.append({"Feature": col, "VIF": vif})

    return pd.DataFrame(vif_data)

vif_df = compute_vif(df_collinear)
print("Variance Inflation Factors:")
print(vif_df.to_string(index=False))
print("\nVIF > 10 for x1 and x2 indicates serious multicollinearity.")
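As a cross-check on the regression-based VIF above, the same values appear on the diagonal of the inverse of the feature correlation matrix. The sketch below regenerates similar collinear data with its own random generator (so the numbers will not match the cell above exactly) and compares both routes; `vif_by_regression` is a hypothetical helper written for this comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
x1 = rng.normal(size=n)
x2 = x1 + rng.normal(scale=0.1, size=n)  # nearly a copy of x1
x3 = rng.normal(size=n)                  # independent feature
X = np.column_stack([x1, x2, x3])

# Route 1: diagonal of the inverse correlation matrix
corr = np.corrcoef(X, rowvar=False)
vif_from_inverse = np.diag(np.linalg.inv(corr))

# Route 2: 1 / (1 - R^2) from regressing each feature on the others
def vif_by_regression(X, j):
    y = X[:, j]
    A = np.column_stack([np.ones(len(y)), np.delete(X, j, axis=1)])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    r2 = 1 - np.sum((y - A @ coef) ** 2) / np.sum((y - y.mean()) ** 2)
    return 1 / (1 - r2)

vif_regression = np.array([vif_by_regression(X, j) for j in range(X.shape[1])])

print("VIF via inverse correlation:", np.round(vif_from_inverse, 2))
print("VIF via regression:         ", np.round(vif_regression, 2))
```

The matrix route is convenient for many features at once, but it fails outright when two features are perfectly collinear, because the correlation matrix is then singular.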

In [None]:
# Impact on coefficients: fit the same model many times with slight data changes
np.random.seed(42)
coefs_collinear = []

for _ in range(100):
    # Bootstrap resample
    idx = np.random.choice(n, size=n, replace=True)
    model_temp = LinearRegression()
    model_temp.fit(X_collinear[idx], y_collinear[idx])
    coefs_collinear.append(model_temp.coef_)

coefs_collinear = np.array(coefs_collinear)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, name in enumerate(["x1 (collinear)", "x2 (collinear)", "x3 (independent)"]):
    axes[i].hist(coefs_collinear[:, i], bins=20, edgecolor="k", alpha=0.7,
                 color="salmon" if i < 2 else "steelblue")
    axes[i].set_title(f"{name}\nStd = {coefs_collinear[:, i].std():.3f}")
    axes[i].set_xlabel("Coefficient value")

plt.suptitle("Coefficient Instability Due to Multicollinearity", fontsize=13, y=1.02)
plt.tight_layout()
plt.show()

print("x1 and x2 coefficients vary wildly across resamples,")
print("while x3 (independent) remains stable.")
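One of the remedies listed in Section 7, Ridge regression, can be tried on the same kind of data. The sketch below repeats the bootstrap with both plain least squares and scikit-learn's `Ridge`. The penalty `alpha=10.0` is an arbitrary illustrative value, not a tuned one, and `bootstrap_coef_std` is a hypothetical helper. Ridge trades some bias for stability: it tends to split the shared effect between `x1` and `x2` instead of recovering the generating coefficients.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge

rng = np.random.default_rng(0)
n = 200
x1 = rng.normal(size=n)
x2 = x1 + rng.normal(scale=0.1, size=n)  # nearly a copy of x1
x3 = rng.normal(size=n)
X = np.column_stack([x1, x2, x3])
y = 3 + 2 * x1 - 1 * x3 + rng.normal(scale=0.5, size=n)

def bootstrap_coef_std(make_model, n_boot=100):
    """Std of each coefficient across bootstrap resamples of (X, y)."""
    coefs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        coefs.append(make_model().fit(X[idx], y[idx]).coef_)
    return np.std(coefs, axis=0)

std_ols = bootstrap_coef_std(LinearRegression)
std_ridge = bootstrap_coef_std(lambda: Ridge(alpha=10.0))  # alpha is illustrative

print("Bootstrap std (OLS):  ", np.round(std_ols, 3))
print("Bootstrap std (Ridge):", np.round(std_ridge, 3))
print("Ridge coefficients on full data:", np.round(Ridge(alpha=10.0).fit(X, y).coef_, 3))
```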

---

## 7. When Assumptions Are Violated: What to Do

| Violation | Possible Remedies |
|---|---|
| **Nonlinearity** | Add polynomial features, use log/sqrt transforms, switch to a nonlinear model |
| **Non-normal residuals** | Transform the target (log, Box-Cox), use robust regression, or collect more data (with large samples, coefficient inference is less sensitive to non-normality) |
| **Heteroscedasticity** | Transform the target (log), use weighted least squares, use robust standard errors |
| **Multicollinearity** | Remove one of the correlated features, use PCA, use Ridge regression (L2 regularization) |
| **Non-independence** | Use time-series models (for temporal data), mixed-effects models (for grouped data) |

**Key takeaway:** Diagnostic plots are not just academic exercises. They tell you whether you can trust your model's predictions and coefficient estimates.
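As an illustration of the weighted least squares remedy in the table, `LinearRegression.fit` accepts a `sample_weight` argument. The sketch below simulates data whose error standard deviation is proportional to x and weights each point by the inverse of its error variance. It assumes the variance structure is known, which is rarely true in practice, where the weights usually have to be estimated. Repeating the simulation shows that both estimators centre on the true slope, but the weighted one varies less.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n, n_sims = 200, 300
slopes_ols, slopes_wls = [], []

for _ in range(n_sims):
    x = rng.uniform(1, 10, n)
    y = 2 + 3 * x + rng.normal(size=n) * (0.5 * x)  # error std grows with x
    X = x.reshape(-1, 1)
    weights = 1.0 / x ** 2  # inverse error variance, up to a constant (assumed known)

    slopes_ols.append(LinearRegression().fit(X, y).coef_[0])
    slopes_wls.append(LinearRegression().fit(X, y, sample_weight=weights).coef_[0])

mean_ols, std_ols = np.mean(slopes_ols), np.std(slopes_ols)
mean_wls, std_wls = np.mean(slopes_wls), np.std(slopes_wls)

print(f"OLS slope: mean={mean_ols:.3f}, std={std_ols:.3f}")
print(f"WLS slope: mean={mean_wls:.3f}, std={std_wls:.3f}")
```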

---

## 8. Common Mistakes

| Mistake | Why It's a Problem | Fix |
|---|---|---|
| Ignoring residual plots | You miss violations that make coefficients unreliable | Always create residual vs fitted and Q-Q plots |
| Assuming linear regression always works | Many real relationships are nonlinear | Check linearity assumption first |
| Reporting R2 without diagnostics | A high R2 can coexist with violated assumptions | Report R2 together with the diagnostic plots |
| Ignoring multicollinearity | Coefficients become uninterpretable | Check correlation matrix and VIF |
| Only checking normality of features | The assumption is about **residuals**, not inputs | Check residual normality after fitting |

---

## 9. Exercise

**Task:** Generate a dataset with a known violation, fit a linear regression, and create diagnostic plots.

Steps:
1. Generate data where $y = 5 + 2x^2 + \text{noise}$ (linearity violated)
2. Fit a `LinearRegression` model
3. Create a residual vs fitted plot and a Q-Q plot
4. Describe what the plots reveal about the assumption violations

In [None]:
# --- Exercise Solution ---

# Step 1: Generate nonlinear data
np.random.seed(42)
X_ex = np.random.uniform(-3, 3, (150, 1))
y_ex = 5 + 2 * X_ex.flatten() ** 2 + np.random.randn(150) * 1.5

# Step 2: Fit linear regression
model_ex = LinearRegression()
model_ex.fit(X_ex, y_ex)
y_pred_ex = model_ex.predict(X_ex)
residuals_ex = y_ex - y_pred_ex

# Step 3: Diagnostic plots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Scatter with fit line
X_sorted_ex = np.sort(X_ex, axis=0)
axes[0].scatter(X_ex, y_ex, alpha=0.6, edgecolors="k", linewidths=0.5)
axes[0].plot(X_sorted_ex, model_ex.predict(X_sorted_ex), "r-", linewidth=2, label="Linear fit")
axes[0].set_xlabel("x")
axes[0].set_ylabel("y")
axes[0].set_title("Data + Linear Fit (Clearly Wrong)")
axes[0].legend()

# Residual vs fitted
axes[1].scatter(y_pred_ex, residuals_ex, alpha=0.6, edgecolors="k", linewidths=0.5, color="salmon")
axes[1].axhline(y=0, color="r", linestyle="--", linewidth=1.5)
axes[1].set_xlabel("Fitted values")
axes[1].set_ylabel("Residuals")
axes[1].set_title("Residual vs Fitted (U-shaped = Nonlinearity)")

# Q-Q plot
stats.probplot(residuals_ex, dist="norm", plot=axes[2])
axes[2].set_title("Q-Q Plot of Residuals")
axes[2].get_lines()[0].set_markerfacecolor("steelblue")
axes[2].get_lines()[0].set_alpha(0.6)

plt.tight_layout()
plt.show()

# Step 4: Interpretation
print("Diagnostic Summary:")
print("- Residual vs Fitted: Clear U-shaped curve -> linearity assumption violated")
print("- Q-Q Plot: Deviations from the line -> residuals are not perfectly normal")
print("- Fix: Add polynomial feature (x^2) or use a nonlinear model")
print(f"\nR2 with linear model: {model_ex.score(X_ex, y_ex):.4f}")
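The fix suggested in the summary can be sketched with scikit-learn's `PolynomialFeatures`. The example regenerates data of the same form with its own random generator (so the numbers differ slightly from the solution above) and compares a straight-line fit with a fit that includes the squared term. The variable names here are illustrative.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, (150, 1))
y = 5 + 2 * x.ravel() ** 2 + rng.normal(scale=1.5, size=150)

# Straight-line model: misses the curvature entirely
linear = LinearRegression().fit(x, y)
r2_linear = linear.score(x, y)

# Add x^2 as a feature; the model stays linear in its coefficients
X_poly = PolynomialFeatures(degree=2, include_bias=False).fit_transform(x)
quadratic = LinearRegression().fit(X_poly, y)
r2_quadratic = quadratic.score(X_poly, y)

print(f"R2, linear features:    {r2_linear:.4f}")
print(f"R2, with squared term:  {r2_quadratic:.4f}")
print(f"Coefficients [x, x^2]:  {np.round(quadratic.coef_, 3)}")
print(f"Intercept:              {quadratic.intercept_:.3f}")
```

After adding the squared term, the residual vs fitted plot for the new model should lose its U shape, which is the real confirmation that the fix worked.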