# Generalized Additive Models (GAMs): Flexible Non-Linear Regression

**The Problem:** Linear regression assumes relationships are straight lines. Real data rarely cooperates.

**The Solution:** Generalized Additive Models (GAMs) combine:
- **Flexibility** of non-linear patterns (like neural networks)
- **Interpretability** of additive models (like linear regression)
- **Statistical rigor** with p-values and confidence intervals (like classical statistics)

GAMs model:
$$y = \beta_0 + f_1(x_1) + f_2(x_2) + \cdots + f_p(x_p) + \varepsilon$$

where each $f_j$ is a **smooth function** learned from data, not a fixed parameter.
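Concretely, in the standard penalized regression spline formulation (the one mgcv uses; we assume `tangent/ds` follows the same idea, though its details may differ), each smooth is a weighted sum of $K_j$ fixed basis functions $b_{jk}$, and the coefficients are estimated by penalized least squares:

$$f_j(x_j) = \sum_{k=1}^{K_j} \beta_{jk}\, b_{jk}(x_j), \qquad \hat{\beta} = \arg\min_{\beta}\; \lVert y - X\beta \rVert^2 + \sum_{j=1}^{p} \lambda_j\, \beta_j^\top S_j\, \beta_j$$

Here $X$ holds the intercept column and every basis function evaluated at the data, $\beta_j$ collects the coefficients of $f_j$, $S_j$ is a fixed penalty matrix measuring how wiggly $f_j$ is (for example $\int f_j''(x)^2\,dx = \beta_j^\top S_j \beta_j$), and $\lambda_j \ge 0$ sets the trade-off between fit and smoothness.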

`tangent/ds` implements mgcv-like GAMs with:
- Multiple basis types (B-splines, cubic regression splines, truncated power)
- Automatic smoothness selection (GCV, REML)
- Statistical inference (EDF, p-values, confidence intervals)

In [None]:
// Setup DOM for plotting in Jupyter with Deno
import { Window } from 'https://esm.sh/happy-dom@12.10.3';
const window = new Window();
globalThis.document = window.document;
globalThis.HTMLElement = window.HTMLElement;

// import packages
import * as ds from '../../src/index.js';
import * as Plot from '@observablehq/plot';

console.log('GAM module loaded successfully');

## Example 1: Non-Linear Regression

Let's fit a sine wave, a classic non-linear pattern that a straight line cannot follow.
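To see how badly a straight line does, here is a plain-JavaScript least-squares line fitted to a noise-free version of the same sine curve. It uses no `tangent/ds` calls; `fitLine` and `rSquared` are helper names introduced only for this sketch.

```javascript
// Least-squares straight line y ≈ a + b·x (one predictor, closed form).
function fitLine(x, y) {
  const n = x.length;
  const mx = x.reduce((s, v) => s + v, 0) / n;
  const my = y.reduce((s, v) => s + v, 0) / n;
  let sxy = 0;
  let sxx = 0;
  for (let i = 0; i < n; i++) {
    sxy += (x[i] - mx) * (y[i] - my);
    sxx += (x[i] - mx) ** 2;
  }
  const b = sxy / sxx;
  return { a: my - b * mx, b };
}

// R² = 1 − RSS/TSS for any vector of fitted values.
function rSquared(y, fitted) {
  const my = y.reduce((s, v) => s + v, 0) / y.length;
  let rss = 0;
  let tss = 0;
  for (let i = 0; i < y.length; i++) {
    rss += (y[i] - fitted[i]) ** 2;
    tss += (y[i] - my) ** 2;
  }
  return 1 - rss / tss;
}

// Noise-free version of the sine data used in this example (same 100-point grid).
const xs = Array.from({ length: 100 }, (_, i) => (i / 99) * 2 * Math.PI);
const ys = xs.map(x => Math.sin(x));

const { a, b } = fitLine(xs, ys);
const r2Line = rSquared(ys, xs.map(x => a + b * x));
console.log(`Best line: y = ${a.toFixed(3)} + ${b.toFixed(3)}·x, R² = ${r2Line.toFixed(3)}`);
```

On this grid R² comes out at about 0.59: even with no noise at all, the best straight line leaves roughly 40% of the variance unexplained.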

In [None]:
// Generate sine wave data with noise
const X_sine = [];
const y_sine = [];

for (let i = 0; i < 100; i++) {
  const x = (i / 99) * 2 * Math.PI;
  X_sine.push([x]);
  y_sine.push(Math.sin(x) + (Math.random() - 0.5) * 0.2);
}

// Plot data
ds.plot.plotScatter({
  x: X_sine.map(row => row[0]),
  y: y_sine,
  title: 'Sine Wave with Noise'
}).show(Plot);

### Fit GAM with Default Settings

This first fit keeps the library defaults except for the basis size:
- Truncated power basis (`tp`)
- No smoothness penalty (backward compatible)
- 4 spline knots by default, raised here with `nSplines: 8`
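The truncated power basis is the simplest spline basis: a global polynomial plus one hinge term $(x - \kappa)_+^3$ per knot $\kappa$, so the curve may change its cubic coefficient at every knot. The sketch below builds rows of such a design matrix under assumed settings (cubic degree, evenly spaced interior knots); `truncatedPowerRow` and `interiorKnots` are illustrative names, not `tangent/ds` functions, and the library's knot placement and column count may differ.

```javascript
// One row of a truncated power basis: 1, x, ..., x^degree, then (x − κ)^degree₊ per knot κ.
// Assumptions (not taken from tangent/ds): cubic degree, evenly spaced interior knots.
function truncatedPowerRow(x, knots, degree = 3) {
  const row = [];
  for (let d = 0; d <= degree; d++) row.push(x ** d);              // global polynomial part
  for (const k of knots) row.push(Math.max(0, x - k) ** degree);   // zero left of κ, cubic right of κ
  return row;
}

// `count` evenly spaced knots strictly inside [lo, hi].
function interiorKnots(lo, hi, count) {
  return Array.from({ length: count }, (_, i) => lo + ((i + 1) * (hi - lo)) / (count + 1));
}

const knots = interiorKnots(0, 2 * Math.PI, 4);
const design = [0, 1, 3, 5].map(x => truncatedPowerRow(x, knots));
console.log('columns per row:', design[0].length); // 4 polynomial terms + 4 knot terms = 8
```

A knot column's coefficient only affects the curve to the right of that knot. Because neighbouring columns are nearly collinear, this basis becomes ill-conditioned as knots are added, which is one motivation for B-splines.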

In [None]:
// Fit basic GAM
const gam_basic = new ds.ml.GAMRegressor({ nSplines: 8 });
gam_basic.fit(X_sine, y_sine);

// Predict on fine grid
const X_pred = [];
for (let i = 0; i < 200; i++) {
  X_pred.push([(i / 199) * 2 * Math.PI]);
}

const y_pred_basic = gam_basic.predict(X_pred);

console.log('Basic GAM fitted successfully');
console.log('R² on training data:', gam_basic.gam.r2.toFixed(3));

### Fit with Automatic Smoothness Selection (GCV)

**GCV (Generalized Cross-Validation)** automatically finds the right amount of smoothing:
- Too little smoothing → overfits noise
- Too much smoothing → misses real patterns
- GCV balances the trade-off
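GCV scores a candidate penalty λ by $\mathcal{V}(\lambda) = n\,\mathrm{RSS}(\lambda) / (n - \mathrm{EDF}(\lambda))^2$, where EDF is the trace of the hat matrix $A = X(X^\top X + \lambda S)^{-1}X^\top$. The self-contained sketch below evaluates that score over a small λ grid for a penalized truncated power spline (a ridge penalty on the knot coefficients only, a classic penalized-spline setup). It is not the `tangent/ds` implementation: the knot count, the λ grid, the deterministic stand-in for noise, and the names `solve` and `penalizedFit` are assumptions for illustration, and a production implementation would optimize λ continuously with a more stable matrix decomposition.

```javascript
// Solve A·x = b by Gaussian elimination with partial pivoting (works on a copy).
function solve(A, b) {
  const n = b.length;
  const M = A.map((row, i) => [...row, b[i]]);
  for (let c = 0; c < n; c++) {
    let p = c;
    for (let r = c + 1; r < n; r++) if (Math.abs(M[r][c]) > Math.abs(M[p][c])) p = r;
    [M[c], M[p]] = [M[p], M[c]];
    for (let r = c + 1; r < n; r++) {
      const f = M[r][c] / M[c][c];
      for (let k = c; k <= n; k++) M[r][k] -= f * M[c][k];
    }
  }
  const x = new Array(n).fill(0);
  for (let r = n - 1; r >= 0; r--) {
    let s = M[r][n];
    for (let k = r + 1; k < n; k++) s -= M[r][k] * x[k];
    x[r] = s / M[r][r];
  }
  return x;
}

// Penalized least squares for one λ: β = (XᵀX + λS)⁻¹Xᵀy, plus RSS, EDF and the GCV score.
function penalizedFit(X, y, S, lambda) {
  const n = X.length;
  const p = X[0].length;
  const XtX = Array.from({ length: p }, (_, i) =>
    Array.from({ length: p }, (_, j) => X.reduce((s, row) => s + row[i] * row[j], 0)));
  const Xty = Array.from({ length: p }, (_, i) => X.reduce((s, row, r) => s + row[i] * y[r], 0));
  const A = XtX.map((row, i) => row.map((v, j) => v + lambda * S[i][j]));
  const beta = solve(A, Xty);
  const fitted = X.map(row => row.reduce((s, v, j) => s + v * beta[j], 0));
  const rss = y.reduce((s, v, i) => s + (v - fitted[i]) ** 2, 0);
  // EDF = trace of the hat matrix = trace((XᵀX + λS)⁻¹ XᵀX), taken one column at a time.
  let edf = 0;
  for (let j = 0; j < p; j++) edf += solve(A, XtX.map(row => row[j]))[j];
  return { beta, rss, edf, gcv: (n * rss) / (n - edf) ** 2 };
}

// Sine data on x ∈ [0, 1]; 0.1·sin(i²) is a deterministic stand-in for noise.
const n = 100;
const u = Array.from({ length: n }, (_, i) => i / (n - 1));
const y = u.map((v, i) => Math.sin(2 * Math.PI * v) + 0.1 * Math.sin(i * i));

// Cubic polynomial columns (unpenalized) + 12 truncated power knot columns (ridge-penalized).
const knots = Array.from({ length: 12 }, (_, k) => (k + 1) / 13);
const X = u.map(v => [1, v, v ** 2, v ** 3, ...knots.map(k => Math.max(0, v - k) ** 3)]);
const p = X[0].length;
const S = Array.from({ length: p }, (_, i) =>
  Array.from({ length: p }, (_, j) => (i === j && i >= 4 ? 1 : 0)));

const fits = [1e-6, 1e-4, 1e-2, 1, 100].map(lambda => ({ lambda, ...penalizedFit(X, y, S, lambda) }));
fits.forEach(f => console.log(`λ=${f.lambda}  EDF=${f.edf.toFixed(2)}  GCV=${f.gcv.toFixed(5)}`));
const best = fits.reduce((a, b) => (b.gcv < a.gcv ? b : a));
console.log('GCV picks λ =', best.lambda);
```

As λ grows, EDF falls from close to 16 (the number of columns) toward 4 (the unpenalized cubic). GCV weighs the lower RSS of small λ against the extra degrees of freedom it spends.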

In [None]:
// Fit GAM with GCV smoothness selection
const gam_gcv = new ds.ml.GAMRegressor({
  nSplines: 12,
  basis: 'cr',  // cubic regression splines
  smoothMethod: 'GCV'  // automatic smoothness
});

gam_gcv.fit(X_sine, y_sine);
const y_pred_gcv = gam_gcv.predict(X_pred);

console.log('GAM with GCV fitted successfully');
console.log('R² on training data:', gam_gcv.gam.r2.toFixed(3));
console.log('Effective degrees of freedom:', gam_gcv.gam.edf.toFixed(2));

### Compare Predictions

In [None]:
// Visualize fits
const x_plot = X_pred.map(row => row[0]);

ds.plot.plotLine({
  series: [
    { x: X_sine.map(row => row[0]), y: y_sine, label: 'Data', type: 'scatter' },
    { x: x_plot, y: y_pred_basic, label: 'GAM (no penalty)', type: 'line' },
    { x: x_plot, y: y_pred_gcv, label: 'GAM (GCV)', type: 'line' },
    { x: x_plot, y: x_plot.map(x => Math.sin(x)), label: 'True function', type: 'line', style: 'dashed' }
  ],
  title: 'GAM Comparison: Basic vs GCV',
  xLabel: 'x',
  yLabel: 'y'
}).show(Plot);

## Statistical Inference: mgcv-like Summary

GAMs provide rich statistical information:

In [None]:
// Get summary (like R's mgcv)
const summary = gam_gcv.summary();

console.log('=== GAM Summary ===\n');
console.log('Sample size:', summary.n);
console.log('R-squared:', summary.rSquared.toFixed(4));
console.log('Residual std error:', summary.residualStdError.toFixed(4));
console.log('Total EDF:', summary.edf.toFixed(2));

console.log('\n=== Smooth Terms ===');
summary.smoothTerms.forEach(term => {
  console.log(`\n${term.term}:`);
  console.log('  EDF:', term.edf.toFixed(2), '/', term.refDf);
  console.log('  p-value:', term.pValue.toFixed(4));
  console.log('  Significance:', term.pValue < 0.001 ? '***' : term.pValue < 0.01 ? '**' : term.pValue < 0.05 ? '*' : 'ns');
});

**Interpreting the Summary:**

- **EDF (Effective Degrees of Freedom)**: How "wiggly" the smooth is
  - EDF ≈ 1 → almost linear
  - EDF ≈ 2-3 → modest curvature
  - EDF > 5 → complex non-linearity
  
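- **p-value**: Does this smooth term contribute at all?
  - It tests the null hypothesis $f_j = 0$ (approximately, as in mgcv), not whether the effect is non-linear
  - < 0.05 → evidence that the term matters; use the EDF to judge how non-linear it is
  - > 0.05 → no clear evidence of any effect of this predictor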

- **R²**: Proportion of variance explained (like linear regression)
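The summary's R² and residual standard error can be computed from the fitted values. In the sketch below (plain JavaScript; `fitStatistics` is a hypothetical helper), we assume the residual degrees of freedom are n − EDF, the usual choice for penalized fits. It also computes the adjusted R² of the kind mgcv reports, since it is not stated here whether `summary.rSquared` is the plain or the adjusted version.

```javascript
// Fit statistics from observed y, fitted values and the model's effective degrees of freedom.
// Assumption: residual degrees of freedom = n − EDF.
function fitStatistics(y, fitted, edf) {
  const n = y.length;
  const mean = y.reduce((s, v) => s + v, 0) / n;
  let rss = 0;
  let tss = 0;
  for (let i = 0; i < n; i++) {
    rss += (y[i] - fitted[i]) ** 2;
    tss += (y[i] - mean) ** 2;
  }
  return {
    rSquared: 1 - rss / tss,                                   // plain proportion of variance explained
    adjRSquared: 1 - (rss / (n - edf)) / (tss / (n - 1)),      // both variances estimated with df corrections
    residualStdError: Math.sqrt(rss / (n - edf)),
  };
}

// Tiny worked example: RSS = 0.1, TSS = 5, n = 4, EDF = 2.
const stats = fitStatistics([1, 2, 3, 4], [1.1, 1.9, 3.2, 3.8], 2);
console.log(stats);
```

Here R² = 1 − 0.1/5 = 0.98, the adjusted R² is 1 − (0.1/2)/(5/3) = 0.97, and the residual standard error is √(0.1/2) ≈ 0.224.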

## Confidence Intervals

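Quantify uncertainty in the fitted curve with point-wise confidence intervals. These describe uncertainty about the mean response at each x; a prediction interval for a new observation would be wider, because it also includes the residual noise.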

In [None]:
// Predictions with 95% confidence intervals
const result = gam_gcv.predictWithInterval(X_pred, 0.95);

// Grid point 50 of 200 sits at x ≈ 1.579, the grid value closest to π/2 ≈ 1.571
const idx = Math.floor(X_pred.length / 4);
console.log('Predictions with confidence intervals:');
console.log(`Sample at x = ${X_pred[idx][0].toFixed(3)} (≈ π/2):`);
console.log('  Fitted:', result.fitted[idx].toFixed(3));
console.log('  95% CI: [', result.lower[idx].toFixed(3), ',', result.upper[idx].toFixed(3), ']');
console.log('  Standard error:', result.se[idx].toFixed(3));
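A point-wise interval like this is conventionally built as fitted ± z·se, where z is the standard normal quantile for the requested level (about 1.96 for 95%). We assume `predictWithInterval` follows this recipe (some implementations use a t quantile instead). The sketch reproduces it in plain JavaScript; because `Math` has no `erf`, it uses the Abramowitz and Stegun 7.1.26 approximation. `normalCdf`, `normalQuantile` and `pointwiseInterval` are illustrative names.

```javascript
// Standard normal CDF via the Abramowitz and Stegun 7.1.26 approximation to erf (|error| < 1.5e-7).
function normalCdf(z) {
  const x = Math.abs(z) / Math.SQRT2;
  const t = 1 / (1 + 0.3275911 * x);
  const poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
  const erf = 1 - poly * Math.exp(-x * x);
  return z >= 0 ? 0.5 * (1 + erf) : 0.5 * (1 - erf);
}

// Invert the CDF by bisection (plenty accurate for interval construction).
function normalQuantile(p) {
  let lo = -10;
  let hi = 10;
  for (let i = 0; i < 100; i++) {
    const mid = (lo + hi) / 2;
    if (normalCdf(mid) < p) lo = mid;
    else hi = mid;
  }
  return (lo + hi) / 2;
}

// Point-wise interval: each x gets its own fitted ± z·se, independently of the others.
function pointwiseInterval(fitted, se, level = 0.95) {
  const z = normalQuantile(1 - (1 - level) / 2);
  return {
    lower: fitted.map((f, i) => f - z * se[i]),
    upper: fitted.map((f, i) => f + z * se[i]),
  };
}

const ci = pointwiseInterval([1.0, 0.5], [0.1, 0.2], 0.95);
console.log(ci);
```

Each interval is meant to cover the true curve at its own x with roughly 95% probability; the band as a whole is not a simultaneous 95% band for the entire curve.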

In [None]:
// Plot with confidence band
ds.plot.plotLine({
  series: [
    { x: X_sine.map(row => row[0]), y: y_sine, label: 'Data', type: 'scatter' },
    { x: x_plot, y: result.fitted, label: 'GAM fit', type: 'line' },
    { x: x_plot, y: result.lower, label: '95% CI', type: 'line', style: 'dashed', opacity: 0.5 },
    { x: x_plot, y: result.upper, label: '', type: 'line', style: 'dashed', opacity: 0.5 }
  ],
  title: 'GAM with 95% Confidence Intervals',
  xLabel: 'x',
  yLabel: 'y'
}).show(Plot);

## Example 2: Multiple Predictors

GAMs naturally handle multiple features with additive smooths.

In [None]:
// Generate data with two predictors
const X_multi = [];
const y_multi = [];

for (let i = 0; i < 200; i++) {
  const x1 = (i / 199) * 2 * Math.PI;
  const x2 = (i % 20) / 19 * Math.PI;
  
  X_multi.push([x1, x2]);
  y_multi.push(
    Math.sin(x1) + 0.5 * Math.cos(2 * x2) + (Math.random() - 0.5) * 0.15
  );
}

console.log('Generated', X_multi.length, 'samples with 2 features');

In [None]:
// Fit GAM with multiple smooths
const gam_multi = new ds.ml.GAMRegressor({
  nSplines: 8,
  basis: 'cr',
  smoothMethod: 'GCV'
});

gam_multi.fit(X_multi, y_multi);

// Get summary
const summary_multi = gam_multi.summary();

console.log('=== Multi-predictor GAM ===\n');
console.log('R-squared:', summary_multi.rSquared.toFixed(4));
console.log('Total EDF:', summary_multi.edf.toFixed(2));

console.log('\n=== Individual Smooth Terms ===');
summary_multi.smoothTerms.forEach(term => {
  console.log(`\n${term.term}:`);
  console.log('  EDF:', term.edf.toFixed(2));
  console.log('  p-value:', term.pValue < 0.001 ? '<0.001***' : term.pValue.toFixed(4));
});

**Interpretation:**
- Each feature gets its own smooth function: $y = \beta_0 + f_1(x_1) + f_2(x_2) + \varepsilon$
- EDF tells you how non-linear each effect is
- Each p-value tests whether that smooth differs from zero
- **Additive** structure maintains interpretability
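Because adding a constant to one $f_j$ and subtracting it from $\beta_0$ leaves every prediction unchanged, each smooth needs an identifiability constraint. mgcv centers each smooth so that it sums to zero over the data; the sketch below shows that idea with plain JavaScript functions standing in for fitted smooths. `centerSmooths` and `predictAdditive` are hypothetical helpers, and whether `tangent/ds` uses the same constraint is an assumption.

```javascript
// Additive model ŷ = β0 + Σ_j f_j(x_j). Center each f_j to mean zero over X and move the means into β0.
function centerSmooths(beta0, smooths, X) {
  const means = smooths.map((f, j) => X.reduce((s, row) => s + f(row[j]), 0) / X.length);
  return {
    beta0: beta0 + means.reduce((s, m) => s + m, 0),
    smooths: smooths.map((f, j) => x => f(x) - means[j]),
  };
}

// Sum the intercept and every smooth evaluated at its own feature.
function predictAdditive(model, X) {
  return X.map(row => model.smooths.reduce((s, f, j) => s + f(row[j]), model.beta0));
}

// Smooths shaped like the simulated data, with a deliberate offset of 3 hidden in f1.
const X = [[0, 0], [1, 2], [2, 1], [3, 3]];
const raw = { beta0: 0, smooths: [x => Math.sin(x) + 3, x => 0.5 * Math.cos(2 * x)] };
const centered = centerSmooths(raw.beta0, raw.smooths, X);

console.log(predictAdditive(raw, X));
console.log(predictAdditive(centered, X)); // same predictions, different split between β0 and the f_j
```

This is why a plotted $f_j$ is read as a deviation from the average effect, with its overall level carried by the intercept.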

## Example 3: Classification with GAM

GAMs work for binary classification too (logistic GAM).
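Logistic GAMs are usually fitted by penalized iteratively reweighted least squares (P-IRLS): each iteration turns the logistic likelihood into a weighted least-squares problem and solves it with the penalty added. The `maxIter` option below suggests `tangent/ds` iterates in a similar way, but that is an assumption. The sketch uses a tiny quadratic basis [1, x, x²] in place of a spline basis so that a non-monotone effect (class 1 in the middle) is visible; `fitLogisticIRLS` and `solve` are illustrative helpers, and the toy data are made up.

```javascript
// Solve A·x = b by Gaussian elimination with partial pivoting (works on a copy).
function solve(A, b) {
  const n = b.length;
  const M = A.map((row, i) => [...row, b[i]]);
  for (let c = 0; c < n; c++) {
    let p = c;
    for (let r = c + 1; r < n; r++) if (Math.abs(M[r][c]) > Math.abs(M[p][c])) p = r;
    [M[c], M[p]] = [M[p], M[c]];
    for (let r = c + 1; r < n; r++) {
      const f = M[r][c] / M[c][c];
      for (let k = c; k <= n; k++) M[r][k] -= f * M[c][k];
    }
  }
  const x = new Array(n).fill(0);
  for (let r = n - 1; r >= 0; r--) {
    let s = M[r][n];
    for (let k = r + 1; k < n; k++) s -= M[r][k] * x[k];
    x[r] = s / M[r][r];
  }
  return x;
}

// Penalized IRLS: maximizes logLik(β) − (λ/2)·βᵀSβ for a logistic model with design matrix X.
function fitLogisticIRLS(X, y, S, lambda, maxIter = 100, tol = 1e-8) {
  const p = X[0].length;
  let beta = new Array(p).fill(0);
  for (let iter = 0; iter < maxIter; iter++) {
    const eta = X.map(row => row.reduce((s, v, j) => s + v * beta[j], 0)); // linear predictor
    const mu = eta.map(e => 1 / (1 + Math.exp(-e)));                       // fitted probabilities
    const w = mu.map(m => Math.max(m * (1 - m), 1e-10));                   // IRLS weights
    const z = eta.map((e, i) => e + (y[i] - mu[i]) / w[i]);                // working response
    // Weighted, penalized least-squares step: (XᵀWX + λS)β = XᵀWz
    const A = Array.from({ length: p }, (_, r) =>
      Array.from({ length: p }, (_, c) =>
        X.reduce((s, row, i) => s + w[i] * row[r] * row[c], 0) + lambda * S[r][c]));
    const rhs = Array.from({ length: p }, (_, r) => X.reduce((s, row, i) => s + w[i] * row[r] * z[i], 0));
    const next = solve(A, rhs);
    const change = Math.max(...next.map((v, j) => Math.abs(v - beta[j])));
    beta = next;
    if (change < tol) break;
  }
  return beta;
}

// Toy data: class 1 mostly in the middle of x, so the logit is non-monotone in x.
const xs = [-3, -2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 2.5, 3];
const ys = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0];
const X = xs.map(x => [1, x, x * x]);          // quadratic basis standing in for a spline basis
const S = [[0, 0, 0], [0, 1, 0], [0, 0, 1]];   // penalize everything except the intercept
const beta = fitLogisticIRLS(X, ys, S, 0.1);

const prob = x => 1 / (1 + Math.exp(-(beta[0] + beta[1] * x + beta[2] * x * x)));
console.log('β =', beta.map(v => v.toFixed(3)), 'P(y=1 | x=0) =', prob(0).toFixed(3));
```

The fitted x² coefficient is negative, so the predicted probability peaks near the middle and falls off in both tails, which a logistic model that is linear in x cannot represent.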

In [None]:
// Generate classification data with non-linear boundary
const X_class = [];
const y_class = [];

for (let i = 0; i < 200; i++) {
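  const x = (i / 199) * 6 - 3;
  // x·sin(x) ≥ 0 on [-3, 3], so shift it down by 1 to get a logit that changes sign
  // (near |x| ≈ 1.11 and |x| ≈ 2.77): a non-linear boundary a straight logit cannot produce
  const prob = 1 / (1 + Math.exp(-2 * (x * Math.sin(x) - 1)));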
  
  X_class.push([x]);
  y_class.push(Math.random() < prob ? 'A' : 'B');
}

console.log('Generated classification data:');
const countA = y_class.filter(y => y === 'A').length;
console.log('Class A:', countA, '| Class B:', y_class.length - countA);

In [None]:
// Fit GAM classifier
const gam_class = new ds.ml.GAMClassifier({
  nSplines: 10,
  basis: 'cr',
  lambda: 0.01,  // small penalty
  maxIter: 100
});

gam_class.fit(X_class, y_class);

// Predict on grid
const X_class_pred = [];
for (let i = 0; i < 100; i++) {
  X_class_pred.push([(i / 99) * 6 - 3]);
}

const y_class_pred = gam_class.predict(X_class_pred);
const y_class_proba = gam_class.predictProba(X_class_pred);

console.log('GAM classifier fitted');
console.log('Sample predictions:', y_class_pred.slice(0, 10));
console.log('Sample probabilities:', y_class_proba.slice(0, 5).map(p => p.map(v => v.toFixed(3))));

## Basis Types Comparison

Different basis functions have different properties:

In [None]:
// Compare basis types
const basis_types = ['tp', 'cr', 'bs'];
const basis_names = {
  'tp': 'Truncated Power',
  'cr': 'Cubic Regression Spline',
  'bs': 'B-spline'
};

console.log('=== Basis Function Comparison ===\n');

basis_types.forEach(basis => {
  const gam_basis = new ds.ml.GAMRegressor({
    nSplines: 8,
    basis: basis,
    smoothMethod: null  // no penalty for comparison
  });
  
  gam_basis.fit(X_sine, y_sine);
  const r2 = gam_basis.gam.r2;
  
  console.log(`${basis_names[basis]} (${basis}):`);
  console.log('  R²:', r2.toFixed(4));
  console.log('  # basis functions:', gam_basis.gam.coef.length);
  console.log('');
});

**Choosing a Basis:**

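- **Truncated Power (`tp`)**: 
  - Simple, interpretable
  - Good for gentle curves with few knots; becomes numerically ill-conditioned as knots are added
  - Default for backward compatibility
  - Not the same as mgcv's `bs = "tp"`, which is a thin plate regression spline

- **Cubic Regression Spline (`cr`)**: 
  - Natural boundary constraints (zero second derivative at the boundary knots, so linear at the edges)
  - Good general-purpose choice
  - Corresponds to mgcv's `bs = "cr"` (mgcv's own default is the thin plate regression spline, `bs = "tp"`)

- **B-spline (`bs`)**: 
  - Numerically stable (Cox-de Boor recursion)
  - Local support: each basis function is non-zero over only a few adjacent knot intervals, so changing one coefficient changes the curve only locally
  - Well suited to complex shapes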
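B-splines are built with the Cox-de Boor recursion: degree-0 pieces are indicator functions of knot intervals, and each higher degree blends two neighbouring functions of the degree below. The sketch assumes a clamped cubic knot vector on [0, 1]; `bsplineBasis` is an illustrative name, and how `tangent/ds` places its knots is not specified here.

```javascript
// Cox-de Boor recursion: value at x of the i-th B-spline of degree k on knot vector t.
function bsplineBasis(i, k, t, x) {
  if (k === 0) {
    const last = t[t.length - 1];
    const inside = t[i] <= x && x < t[i + 1];
    // Close the final non-empty interval on the right so x = last is covered too.
    const rightEnd = x === last && t[i] < t[i + 1] && t[i + 1] === last;
    return inside || rightEnd ? 1 : 0;
  }
  // Terms with a zero-width knot span are dropped (the usual 0/0 = 0 convention).
  const leftSpan = t[i + k] - t[i];
  const rightSpan = t[i + k + 1] - t[i + 1];
  const left = leftSpan === 0 ? 0 : ((x - t[i]) / leftSpan) * bsplineBasis(i, k - 1, t, x);
  const right = rightSpan === 0 ? 0 : ((t[i + k + 1] - x) / rightSpan) * bsplineBasis(i + 1, k - 1, t, x);
  return left + right;
}

// Clamped cubic knot vector on [0, 1] with interior knots 0.25, 0.5, 0.75 → 11 − 3 − 1 = 7 basis functions.
const degree = 3;
const knots = [0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1];
const nBasis = knots.length - degree - 1;
const row = x => Array.from({ length: nBasis }, (_, i) => bsplineBasis(i, degree, knots, x));

console.log(row(0.3).map(v => v.toFixed(3))); // 7 entries; 4 of them are non-zero
```

Every row sums to 1 (partition of unity) and has at most degree + 1 non-zero entries, which is the local support property: each column of the design matrix only "sees" data near its own knots.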

## Smoothness Selection: GCV vs REML

Two automatic methods for choosing the smoothing penalty:

In [None]:
// Compare GCV vs REML
const methods = ['GCV', 'REML'];

console.log('=== Smoothness Selection Methods ===\n');

methods.forEach(method => {
  const gam = new ds.ml.GAMRegressor({
    nSplines: 12,
    basis: 'cr',
    smoothMethod: method
  });
  
  gam.fit(X_sine, y_sine);
  const summary = gam.summary();
  
  console.log(`${method}:`);
  console.log('  R²:', summary.rSquared.toFixed(4));
  console.log('  Total EDF:', summary.edf.toFixed(2));
  console.log('  Residual std error:', summary.residualStdError.toFixed(4));
  console.log('');
});

**GCV vs REML:**

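- **GCV (Generalized Cross-Validation)**:
  - Minimizes an estimate of prediction error (an approximation to leave-one-out cross-validation)
  - Efficient to compute
  - Can undersmooth (slightly wiggly fits)

- **REML (Restricted Maximum Likelihood)**:
  - Less prone to undersmoothing and to multiple local optima than GCV
  - Smoothing parameter estimates tend to be less variable, which helps inference
  - Slightly slower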

**Rule of thumb:** Start with GCV, use REML if you need precise inference.

## Practical Tips

### When to use GAMs:
- ✅ Relationships are clearly non-linear
- ✅ You need interpretability (additive structure)
- ✅ You want statistical inference (p-values, CIs)
- ✅ Moderate sample size (n > 100)
- ✅ Not too many features (p < 20)

### When NOT to use GAMs:
- ❌ Strong interactions between features (use trees/neural nets)
- ❌ Very high dimensions (use regularized linear models)
- ❌ Pure prediction focus (use gradient boosting)
- ❌ Small sample size (use simpler models)

### Hyperparameter Guidelines:

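**nSplines:**
- Start with 4-8
- Increase if residuals show patterns
- Without a penalty, too many → overfitting; with `'GCV'` or `'REML'`, nSplines mainly sets the maximum flexibility and the penalty keeps the EDF below it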

**smoothMethod:**
- `null`: No penalty (only if you know nSplines is right)
- `'GCV'`: General purpose, fast
- `'REML'`: When inference is critical

**basis:**
- `'tp'`: Simple; fine for gentle curves with few knots (can become ill-conditioned with many)
- `'cr'`: Good default for most problems
- `'bs'`: Complex shapes, numerical stability

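**lambda (fixed penalty):**
- Small (0.001-0.01): Light smoothing
- Medium (0.1-1): Moderate smoothing
- Large (10+): Heavy smoothing (almost linear)
- Treat these ranges as rough starting points: the effect of a given λ depends on the basis, nSplines and the scale of the data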
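Why does a large λ push a smooth toward a straight line? After a suitable reparametrization (the Demmler-Reinsch basis, in which $X^\top X$ becomes the identity and the penalty becomes diagonal), the effective degrees of freedom reduce to $\mathrm{EDF}(\lambda) = \sum_k 1/(1 + \lambda d_k)$, where $d_k$ are the penalty eigenvalues. The eigenvalues below are made up for illustration; we assume a penalty (such as a second-derivative penalty) whose null space is the linear functions, so exactly two eigenvalues are zero. Whether `tangent/ds` uses such a penalty is not stated here, and `edfFromEigenvalues` is a hypothetical helper.

```javascript
// EDF in the Demmler-Reinsch basis: each penalty eigenvalue d_k contributes 1 / (1 + λ·d_k).
// Unpenalized directions (d_k = 0) always contribute exactly 1.
function edfFromEigenvalues(d, lambda) {
  return d.reduce((s, dk) => s + 1 / (1 + lambda * dk), 0);
}

// Hypothetical spectrum for 8 basis functions: a linear null space (two zeros), then rising wiggliness costs.
const d = [0, 0, 0.5, 2, 8, 30, 100, 400];

for (const lambda of [0, 0.001, 0.01, 0.1, 1, 10, 1000]) {
  console.log(`λ = ${lambda}\tEDF = ${edfFromEigenvalues(d, lambda).toFixed(2)}`);
}
```

EDF starts at 8 (the number of basis functions) and falls toward 2 (a straight line). Where along the λ axis that happens depends on the $d_k$, which is why λ values are not comparable across bases or data scalings.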

## Example 4: Real Data - MPG Prediction

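Let's apply a GAM to the classic `mtcars` data and predict fuel efficiency (miles per gallon) from car weight and horsepower. The dataset has only 32 cars, below the n > 100 rule of thumb above, so a small basis (`nSplines: 5`) is used and the results should be read with caution.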

In [None]:
// Load mtcars dataset
const mtcarsResponse = await fetch(
  'https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5aab4d93091f3947f9e7ffe12b6d8b8b85e8b1b4/mtcars.csv'
);
const mtcarsText = await mtcarsResponse.text();
const mtcarsLines = mtcarsText.trim().split('\n');
const mtcarsHeader = mtcarsLines[0].split(',');

const mtcarsData = mtcarsLines.slice(1).map(line => {
  const values = line.split(',');
  const row = {};
  mtcarsHeader.forEach((col, i) => {
    row[col] = col === 'model' ? values[i] : parseFloat(values[i]);
  });
  return row;
});

console.log('Loaded', mtcarsData.length, 'cars');
console.table(mtcarsData.slice(0, 5));

In [None]:
// Prepare data: predict mpg from weight (wt) and horsepower (hp)
const X_cars = mtcarsData.map(row => [row.wt, row.hp]);
const y_cars = mtcarsData.map(row => row.mpg);

// Fit GAM with automatic smoothness
const gam_cars = new ds.ml.GAMRegressor({
  nSplines: 5,
  basis: 'cr',
  smoothMethod: 'GCV'
});

gam_cars.fit(X_cars, y_cars);

const summary_cars = gam_cars.summary();

console.log('=== MPG Prediction GAM ===\n');
console.log('R²:', summary_cars.rSquared.toFixed(4));
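// Training RMSE, computed from the fitted values (in-sample, so optimistic)
const fitted_cars = gam_cars.predict(X_cars);
const rmse_cars = Math.sqrt(
  y_cars.reduce((s, y, i) => s + (y - fitted_cars[i]) ** 2, 0) / y_cars.length
);
console.log('Training RMSE:', rmse_cars.toFixed(2), 'mpg');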

console.log('\n=== Smooth Effects ===');
summary_cars.smoothTerms.forEach((term, i) => {
  const feature = i === 0 ? 'weight' : 'horsepower';
  console.log(`\n${feature} (${term.term}):`);
  console.log('  EDF:', term.edf.toFixed(2));
  console.log('  Non-linearity:', term.edf > 2 ? 'Strong' : term.edf > 1.5 ? 'Moderate' : 'Weak');
  console.log('  p-value:', term.pValue < 0.001 ? '<0.001***' : term.pValue.toFixed(4));
});

**Interpretation:**
- A small p-value means that predictor has a detectable effect on MPG; its EDF shows how non-linear that effect is (EDF near 1 means roughly linear)
- R² and RMSE here are in-sample; with 32 cars and two smooths they tend to overstate accuracy on new cars
- Each effect is interpretable: $f_1(\text{weight})$ and $f_2(\text{horsepower})$ can be plotted separately
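The two effects can be inspected without any term-extraction API: in an additive model, varying one feature over a grid while holding the others fixed changes only that feature's $f_j$, so the resulting curve is $f_j$ up to an additive constant. The sketch assumes only a `predict` function that maps rows to numbers; `partialEffect` and `median` are illustrative helpers, and a mock additive predictor on made-up inputs stands in for the fitted model so the block runs on its own.

```javascript
// Median of a numeric array (used to hold the other features fixed).
function median(values) {
  const s = [...values].sort((a, b) => a - b);
  const m = Math.floor(s.length / 2);
  return s.length % 2 ? s[m] : (s[m - 1] + s[m]) / 2;
}

// Vary feature j over its observed range, hold the rest at their medians, and center the result.
function partialEffect(predict, X, j, gridSize = 25) {
  const col = X.map(row => row[j]);
  const lo = Math.min(...col);
  const hi = Math.max(...col);
  const base = X[0].map((_, k) => median(X.map(row => row[k])));
  const grid = Array.from({ length: gridSize }, (_, g) => lo + (g * (hi - lo)) / (gridSize - 1));
  const yhat = predict(grid.map(v => base.map((b, k) => (k === j ? v : b))));
  const mean = yhat.reduce((s, v) => s + v, 0) / yhat.length;
  return { grid, effect: yhat.map(v => v - mean) }; // centered, like a plotted smooth
}

// Mock additive model (hypothetical): f1 = sin, f2 linear.
const mockPredict = rows => rows.map(([a, b]) => Math.sin(a) + 0.1 * b);
const X = [[1, 10], [2, 40], [3, 20], [4, 50], [5, 30]];
const { grid, effect } = partialEffect(mockPredict, X, 0, 5);
console.log(grid, effect.map(v => v.toFixed(3)));
```

With the MPG model, `partialEffect(rows => gam_cars.predict(rows), X_cars, 0)` should trace the weight effect, and the result can be passed to `ds.plot.plotLine` like the earlier curves.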

## Summary

GAMs combine the best of both worlds:

- ✅ **Flexibility**: Capture complex non-linear patterns
- ✅ **Interpretability**: Additive structure (sum of smooth effects)
- ✅ **Statistical rigor**: p-values, EDF, confidence intervals
- ✅ **Automatic tuning**: GCV/REML for smoothness selection

### Feature Comparison:

| Feature | `tangent/ds` GAM | R mgcv |
|---------|------------------|--------|
| Multiple basis types | ✅ tp, cr, bs | ✅ |
| Automatic smoothness | ✅ GCV, REML | ✅ |
| Statistical inference | ✅ EDF, p-values, CIs | ✅ |
| Binary classification | ✅ Logistic GAM | ✅ |
| Confidence intervals | ✅ Point-wise CIs | ✅ |
| Tensor products | ❌ (future) | ✅ |
| Custom families | ❌ (future) | ✅ |

### Quick Reference:

```javascript
// Basic usage
const gam = new ds.ml.GAMRegressor({
  nSplines: 8,           // number of basis functions
  basis: 'cr',           // 'tp', 'cr', or 'bs'
  smoothMethod: 'GCV',   // 'GCV', 'REML', or null
  lambda: null           // or fixed penalty value
});

gam.fit(X, y);

// Predictions
const predictions = gam.predict(X_new);

// With confidence intervals
const result = gam.predictWithInterval(X_new, 0.95);
// result.fitted, result.lower, result.upper, result.se

// Statistical summary
const summary = gam.summary();
// summary.rSquared, summary.edf, summary.smoothTerms
```

GAMs are a powerful tool for understanding complex relationships while maintaining interpretability and statistical rigor!