# Invertible Neural Networks (Normalizing Flows)

This notebook demonstrates two complementary uses of invertible neural networks:

1. **Part 1: Regression with Uncertainty** (ConditionalInvertibleNN)
   - **Problem**: Predict outputs Y from inputs X with uncertainty estimates
   - **Solution**: Learn p(Y|X) - the conditional distribution of outputs given inputs
   - **Use cases**: Predictions with confidence intervals, heteroscedastic uncertainty

2. **Part 2: Density Estimation & Generative Modeling** (InvertibleNN)
   - **Problem**: Learn the probability distribution of data X
   - **Solution**: Learn p(X) - transform data to/from a simple Gaussian distribution
   - **Use cases**: Anomaly detection, synthetic data generation, likelihood computation

## Mathematical Foundation

Both models use **normalizing flows** - invertible transformations with tractable Jacobians:

$$\log p(x) = \log p(z) + \log \left|\det \frac{\partial f}{\partial x}\right|$$

where $z = f(x)$ and $p(z)$ is a simple base distribution (Gaussian).
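
As a quick sanity check of the formula, the sketch below uses a hypothetical one-dimensional affine flow $z = (x - \mu)/\sigma$, which is not part of pycse. Its log-determinant is $-\log\sigma$, so the change-of-variables formula should reproduce the closed-form log-density of $\mathcal{N}(\mu, \sigma^2)$. Plain NumPy is imported as `onp` so it does not shadow the notebook's `np` (which is `jax.numpy`).

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

def standard_normal_logpdf(z):
    """Log-density of the base distribution N(0, 1)."""
    return -0.5 * z**2 - 0.5 * onp.log(2.0 * onp.pi)

def affine_flow_logpdf(x, mu, sigma):
    """log p(x) for the flow z = f(x) = (x - mu) / sigma."""
    z = (x - mu) / sigma
    log_det = -onp.log(sigma)  # log |df/dx| = log(1 / sigma)
    return standard_normal_logpdf(z) + log_det

def gaussian_logpdf(x, mu, sigma):
    """Closed-form log-density of N(mu, sigma^2), used as the reference."""
    return -0.5 * ((x - mu) / sigma) ** 2 - onp.log(sigma) - 0.5 * onp.log(2.0 * onp.pi)

x = onp.linspace(-2.0, 4.0, 7)
flow_lp = affine_flow_logpdf(x, mu=1.0, sigma=0.5)
ref_lp = gaussian_logpdf(x, mu=1.0, sigma=0.5)
print("max difference:", onp.max(onp.abs(flow_lp - ref_lp)))
```

A real flow replaces the single affine map with many learned invertible layers, but the bookkeeping is the same: evaluate the base density at $z$ and add the log-determinant.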

In [1]:
import jax
import jax.numpy as np
from sklearn.datasets import make_moons, make_circles, make_blobs
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from pycse.sklearn.cinn import ConditionalInvertibleNN
from pycse.sklearn.inn import InvertibleNN

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

print("✓ ConditionalInvertibleNN (Regression) imported!")
print("✓ InvertibleNN (Density Estimation) imported!")

✓ ConditionalInvertibleNN (Regression) imported!
✓ InvertibleNN (Density Estimation) imported!


---

# Part 1: Regression with Uncertainty (ConditionalInvertibleNN)

## Problem: Predicting Outputs with Uncertainty

**Traditional regression** (linear regression, neural networks) gives point predictions:
- Input X → Output Y
- No uncertainty estimates
- Assumes constant noise

**ConditionalInvertibleNN** learns the full conditional distribution p(Y|X):
- Input X → Distribution over Y
- Provides mean predictions AND uncertainty
- Handles heteroscedastic noise (input-dependent uncertainty)
- Can sample multiple plausible predictions
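
To make "distribution over Y" concrete, here is a minimal sketch of a conditional flow with a single affine layer, $y = \mu(x) + \sigma(x)\,z$ with $z \sim \mathcal{N}(0, 1)$. The functions `mu` and `sigma` are hand-written stand-ins for what a trained network would learn; nothing here is pycse's API. Drawing many `z` per input and summarizing the resulting samples yields a point prediction plus an uncertainty estimate.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

def mu(x):
    """Hypothetical learned conditional mean."""
    return 2.0 * x + 1.0

def sigma(x):
    """Hypothetical learned conditional scale (grows with |x|)."""
    return 0.1 + 0.3 * onp.abs(x)

def sample_predictive(x, n_samples, rng):
    """Draw y = mu(x) + sigma(x) * z; returns shape (n_samples, len(x))."""
    z = rng.standard_normal((n_samples, x.shape[0]))
    return mu(x)[None, :] + sigma(x)[None, :] * z

rng = onp.random.default_rng(0)
x_query = onp.array([0.0, 1.0, 3.0])
samples = sample_predictive(x_query, n_samples=20000, rng=rng)

pred_mean = samples.mean(axis=0)  # Monte Carlo estimate of E[Y | X]
pred_std = samples.std(axis=0)    # Monte Carlo estimate of Std[Y | X]
print("mean:", pred_mean)
print("std: ", pred_std)
```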

**When to use this:**
- Need confidence intervals on predictions
- Uncertainty varies across input space (heteroscedastic)
- Want to quantify prediction reliability
- Need full predictive distribution, not just point estimates

## Example 1: Simple Regression with Uncertainty

In [2]:
# Generate simple regression data
key = jax.random.PRNGKey(42)
X = np.linspace(-3, 3, 200)[:, None]
y_true = 2 * X + 1
y = y_true + 0.3 * jax.random.normal(key, X.shape)

print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Task: Learn Y = f(X) with uncertainty")

Data shape: X=(200, 1), y=(200, 1)
Task: Learn Y = f(X) with uncertainty


In [3]:
# Create conditional flow for regression
cinn = ConditionalInvertibleNN(
    n_features_in=1,   # 1D input (X)
    n_features_out=1,  # 1D output (Y)
    n_layers=8,  # More layers for better modeling
    hidden_dims=[128, 128],  # Larger network
    seed=42
)

cinn

| Parameter | Value |
|-----------|-------|
| n_features_in | 1 |
| n_features_out | 1 |
| n_layers | 8 |
| hidden_dims | [128, 128] |
| seed | 42 |


In [None]:
print("Training regression model...")
cinn.fit(X, y, maxiter=2000)  # More training

# Print report
cinn.report()

Training regression model...


In [None]:
# Predict with uncertainty - use many samples for smooth estimates
y_pred, y_std = cinn.predict(X, return_std=True, n_samples=2000)

# Visualize
plt.figure(figsize=(10, 6))
plt.scatter(X, y, alpha=0.3, s=10, label='Training data', c='gray')
plt.plot(X, y_pred, 'r-', label='Mean prediction', linewidth=2)
plt.fill_between(
    X.ravel(),
    (y_pred - 2*y_std).ravel(),
    (y_pred + 2*y_std).ravel(),
    alpha=0.3,
    color='red',
    label='95% confidence interval'
)
plt.plot(X, y_true, 'k--', label='True function', linewidth=1)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Regression with Uncertainty Quantification')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()

print("✓ Model provides mean prediction AND uncertainty estimates!")

## Example 2: Heteroscedastic Regression (Varying Noise)

**Real-world challenge**: Uncertainty often varies across the input space!

- Near x=0: Low noise, high confidence
- Far from x=0: High noise, low confidence

ConditionalInvertibleNN **automatically learns** this input-dependent uncertainty.
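
Why would likelihood training pick up input-dependent noise at all? The sketch below, a simplified Gaussian stand-in rather than the flow itself, compares the average negative log-likelihood of the same data under a single constant noise level and under the true input-dependent one. To isolate the effect of the noise model, the mean is fixed to the true function $x^2$, which is an assumption made only for this illustration.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

rng = onp.random.default_rng(99)
x = onp.linspace(-3.0, 3.0, 5000)
true_std = 0.1 + 0.3 * onp.abs(x)
y = x**2 + true_std * rng.standard_normal(x.shape)

def gaussian_nll(y, mean, std):
    """Average negative log-likelihood of y under N(mean, std^2)."""
    return onp.mean(
        0.5 * ((y - mean) / std) ** 2 + onp.log(std) + 0.5 * onp.log(2.0 * onp.pi)
    )

residual = y - x**2
constant_std = onp.full_like(x, residual.std())  # best single noise level

nll_constant = gaussian_nll(y, x**2, constant_std)
nll_input_dependent = gaussian_nll(y, x**2, true_std)
print(f"constant noise NLL:        {nll_constant:.3f}")
print(f"input-dependent noise NLL: {nll_input_dependent:.3f}")
```

The input-dependent model scores a lower negative log-likelihood, so an optimizer that minimizes this loss is pushed toward a noise level that varies with the input.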

In [None]:
# Generate data with varying noise levels
key = jax.random.PRNGKey(99)
X_het = np.linspace(-3, 3, 250)[:, None]
y_true_het = X_het**2

# Noise increases with |X|
noise_std = 0.1 + 0.3 * np.abs(X_het)
noise = noise_std * jax.random.normal(key, X_het.shape)
y_het = y_true_het + noise

print("Generated heteroscedastic data:")
print(f"  Noise at X=0: ~{noise_std[len(X_het)//2, 0]:.2f}")
print(f"  Noise at X=±3: ~{noise_std[-1, 0]:.2f}")
print("  → Noise level depends on X!")

In [None]:
# Train conditional flow with better hyperparameters
cinn_het = ConditionalInvertibleNN(
    n_features_in=1,
    n_features_out=1,
    n_layers=10,  # More layers to capture input-dependent uncertainty
    hidden_dims=[128, 128, 128],  # Deeper network
    seed=42
)

print("Training on heteroscedastic data...")
cinn_het.fit(X_het, y_het, maxiter=2500)  # More iterations
print("Training complete!")

In [None]:
# Predict with learned uncertainty - use many samples for smooth estimate
y_pred_het, y_std_het = cinn_het.predict(X_het, return_std=True, n_samples=5000)

# Apply smoothing to uncertainty for cleaner visualization
from scipy.ndimage import gaussian_filter1d
y_std_het_smooth = gaussian_filter1d(y_std_het.ravel(), sigma=5)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Left: Predictions with confidence bands
ax = axes[0]
ax.scatter(X_het, y_het, alpha=0.3, s=10, label='Training data', c='gray')
ax.plot(X_het, y_pred_het, 'r-', label='Mean prediction', linewidth=2)
ax.fill_between(
    X_het.ravel(),
    (y_pred_het.ravel() - 2*y_std_het_smooth),
    (y_pred_het.ravel() + 2*y_std_het_smooth),
    alpha=0.3,
    color='red',
    label='95% confidence'
)
ax.plot(X_het, y_true_het, 'k--', label='True function', linewidth=1)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Heteroscedastic Regression')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Learned vs True uncertainty
ax = axes[1]
ax.plot(X_het, noise_std * 2, 'k--', label='True noise (2σ)', linewidth=2)
ax.plot(X_het, y_std_het_smooth * 2, 'r-', label='Learned uncertainty (2σ, smoothed)', linewidth=2, alpha=0.8)
ax.set_xlabel('X')
ax.set_ylabel('Uncertainty (2σ)')
ax.set_title('Model Learns Input-Dependent Uncertainty!')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()

print("\n✓ Uncertainty bands widen where data is noisier!")
print("✓ This is heteroscedastic uncertainty - traditional regression can't do this!")

## Example 3: Sampling from the Predictive Distribution

Unlike traditional regression (single prediction), ConditionalInvertibleNN learns the **full distribution** p(Y|X).

We can sample multiple plausible predictions for each input.
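
Having samples rather than just a mean and standard deviation also allows intervals that do not assume a Gaussian shape. The sketch below uses synthetic log-normal draws as a stand-in for predictive samples (shape `(n_samples, n_points)`, matching how `y_samples` is indexed in this notebook) and compares a percentile-based 95% interval with the symmetric mean ± 2σ band.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

rng = onp.random.default_rng(123)
n_samples, n_points = 5000, 4
# Stand-in for predictive samples: a skewed (log-normal) distribution per input.
y_samples = rng.lognormal(mean=0.0, sigma=0.5, size=(n_samples, n_points))

mean = y_samples.mean(axis=0)
std = y_samples.std(axis=0)
lower, upper = onp.percentile(y_samples, [2.5, 97.5], axis=0)

# Fraction of samples that fall inside each kind of interval
inside_quantile = onp.mean((y_samples >= lower) & (y_samples <= upper), axis=0)
inside_2std = onp.mean(onp.abs(y_samples - mean) <= 2 * std, axis=0)
print("quantile interval coverage:", inside_quantile)
print("mean ± 2σ coverage:        ", inside_2std)
```

For a skewed predictive distribution the symmetric band reaches well below the lower quantile, so the percentile interval is usually the more faithful summary.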

In [None]:
# Generate nonlinear data
key = jax.random.PRNGKey(123)
X_nl = np.linspace(-2*np.pi, 2*np.pi, 300)[:, None]
y_true_nl = np.sin(X_nl)
y_nl = y_true_nl + 0.15 * jax.random.normal(key, X_nl.shape)

# Train model with better hyperparameters
cinn_nl = ConditionalInvertibleNN(
    n_features_in=1,
    n_features_out=1,
    n_layers=10,  # More layers
    hidden_dims=[128, 128],  # Larger network
    seed=42
)

print("Training on sine wave...")
cinn_nl.fit(X_nl, y_nl, maxiter=2000)  # More training

# Get predictions with samples
y_pred_nl, y_samples = cinn_nl.predict(X_nl, return_samples=True, n_samples=500)

print(f"Generated {y_samples.shape[0]} plausible predictions for each input!")

In [None]:
# Visualize samples from predictive distribution
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Left: Mean and confidence
ax = axes[0]
y_std_nl = np.std(y_samples, axis=0)
ax.scatter(X_nl, y_nl, alpha=0.3, s=10, label='Training data', c='gray')
ax.plot(X_nl, y_pred_nl, 'r-', label='Mean prediction', linewidth=2)
ax.fill_between(
    X_nl.ravel(),
    (y_pred_nl - 2*y_std_nl).ravel(),
    (y_pred_nl + 2*y_std_nl).ravel(),
    alpha=0.3,
    color='red',
    label='95% confidence'
)
ax.plot(X_nl, y_true_nl, 'k--', label='True function', linewidth=1)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Mean Prediction with Uncertainty')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Sampled predictions
ax = axes[1]
for i in range(min(20, y_samples.shape[0])):
    ax.plot(X_nl, y_samples[i], 'r-', alpha=0.1, linewidth=0.5)
ax.scatter(X_nl[::10], y_nl[::10], alpha=0.5, s=20, label='Training data', c='gray')
ax.plot(X_nl, y_true_nl, 'k--', label='True function', linewidth=2)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('20 Samples from p(Y|X) for Each X')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()

print("\n✓ Each red line is a plausible prediction!")
print("✓ Sampling gives us the full distribution, not just a single answer.")

## Example 4: Multi-Output Regression

ConditionalInvertibleNN handles **multi-dimensional outputs** with correlated uncertainties.
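
A per-dimension standard deviation cannot show whether two outputs vary together; that information lives in the joint samples. The sketch below uses synthetic correlated Gaussian draws as a stand-in for samples of `(x, y)` at one input (an assumption for illustration, not output from the model) and recovers the correlation from the sample covariance.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

rng = onp.random.default_rng(77)
# Stand-in for joint samples of (x, y) at a single input t, with correlation 0.8.
true_cov = onp.array([[0.010, 0.008],
                      [0.008, 0.010]])
samples = rng.multivariate_normal(mean=[0.5, -0.2], cov=true_cov, size=10000)

cov = onp.cov(samples, rowvar=False)        # 2x2 sample covariance
per_output_std = onp.sqrt(onp.diag(cov))    # what a per-dimension std reports
corr = cov[0, 1] / (per_output_std[0] * per_output_std[1])
print("per-output std:", per_output_std)
print("correlation:   ", corr)
```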

In [None]:
# Generate 2D output from 1D input (parametric curve)
key = jax.random.PRNGKey(77)
t = np.linspace(0, 2*np.pi, 200)[:, None]

# Lissajous curve
x_true = np.sin(3*t)
y_true = np.cos(2*t)

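# Split the key so each coordinate gets independent noise
# (reusing one key would give x and y identical noise values)
key_x, key_y = jax.random.split(key)
x = x_true + 0.1 * jax.random.normal(key_x, x_true.shape)
y = y_true + 0.1 * jax.random.normal(key_y, y_true.shape)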

Y_multi = np.concatenate([x, y], axis=1)  # Shape: (200, 2)

print(f"Multi-output regression:")
print(f"  Input: t (time) - shape {t.shape}")
print(f"  Output: (x, y) coordinates - shape {Y_multi.shape}")

In [None]:
# Train multi-output model with better hyperparameters
cinn_multi = ConditionalInvertibleNN(
    n_features_in=1,   # Input: t
    n_features_out=2,  # Output: (x, y)
    n_layers=10,  # More layers
    hidden_dims=[128, 128],  # Larger network
    seed=42
)

print("Training multi-output model...")
cinn_multi.fit(t, Y_multi, maxiter=2000)  # More training

# Predict with more samples for smoother uncertainty
Y_pred, Y_std = cinn_multi.predict(t, return_std=True, n_samples=500)

print("Multi-output predictions complete!")

In [None]:
# Visualize multi-output regression
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Left: X output
ax = axes[0]
ax.scatter(t, x, alpha=0.3, s=10, label='Training data (x)', c='gray')
ax.plot(t, Y_pred[:, 0], 'r-', label='Prediction (x)', linewidth=2)
ax.fill_between(
    t.ravel(),
    (Y_pred[:, 0] - 2*Y_std[:, 0]).ravel(),
    (Y_pred[:, 0] + 2*Y_std[:, 0]).ravel(),
    alpha=0.3,
    color='red'
)
ax.plot(t, x_true, 'k--', label='True x', linewidth=1)
ax.set_xlabel('t')
ax.set_ylabel('x')
ax.set_title('X Component with Uncertainty')
ax.legend()
ax.grid(True, alpha=0.3)

# Middle: Y output
ax = axes[1]
ax.scatter(t, y, alpha=0.3, s=10, label='Training data (y)', c='gray')
ax.plot(t, Y_pred[:, 1], 'r-', label='Prediction (y)', linewidth=2)
ax.fill_between(
    t.ravel(),
    (Y_pred[:, 1] - 2*Y_std[:, 1]).ravel(),
    (Y_pred[:, 1] + 2*Y_std[:, 1]).ravel(),
    alpha=0.3,
    color='red'
)
ax.plot(t, y_true, 'k--', label='True y', linewidth=1)
ax.set_xlabel('t')
ax.set_ylabel('y')
ax.set_title('Y Component with Uncertainty')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Parametric curve
ax = axes[2]
ax.scatter(x, y, alpha=0.3, s=10, label='Training data', c='gray')
ax.plot(Y_pred[:, 0], Y_pred[:, 1], 'r-', label='Prediction', linewidth=2)
ax.plot(x_true, y_true, 'k--', label='True curve', linewidth=1)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Parametric Curve (x, y)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

plt.tight_layout()

print("\n✓ Single model predicts both x AND y with their uncertainties!")
print("✓ Captures correlations between output dimensions.")

### Summary: When to Use Conditional INN for Regression

✅ **Use ConditionalInvertibleNN when you need:**
- Predictions with confidence intervals
- Heteroscedastic uncertainty (varies with input)
- Full predictive distribution p(Y|X)
- Multi-output regression with correlations
- A quantitative measure of prediction reliability

⚠️ **Not ideal for:**
- Simple tasks where point estimates suffice
- When uncertainty doesn't matter
- Very small datasets (<100 samples)

**Comparison to alternatives:**
- vs **Standard NN**: CINN provides uncertainty, not just point predictions
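- vs **Bayesian NN**: CINN is typically faster and easier to train and gives exact likelihoods, but it models noise in the data rather than uncertainty in the weights
- vs **Gaussian Process**: CINN scales better to large datasets and high-dimensional inputs (exact GPs cost O(n³) in the number of training points)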

---

# Part 2: Density Estimation & Generative Modeling (InvertibleNN)

## Problem: Learning Data Distributions

**Goal**: Learn the probability distribution p(X) of data X itself (not predicting Y from X).

**Why this matters:**
- **Anomaly Detection**: Assign low probability to outliers
- **Generative Modeling**: Sample new, realistic data points
- **Exact Likelihoods**: Know exactly how probable any point is
- **Density Estimation**: Understand the shape of your data distribution

**InvertibleNN approach:**
- Learn transformation: Complex data X ↔ Simple Gaussian Z
- Forward: Map data to Gaussian latent space
- Inverse: Generate new data from Gaussian samples
- Exact likelihood via change of variables formula
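
The sketch below shows one common way to build such a transformation: an affine coupling layer in the style of Real NVP. It is an illustration in plain NumPy with a tiny random "network" (`conditioner`, a hypothetical stand-in); whether pycse's layers take exactly this form is an assumption. Half of the input passes through unchanged and parameterizes a scale and shift for the other half, which makes both the inverse and the log-determinant cheap.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

def conditioner(x_a, weights):
    """Tiny stand-in for the neural network: maps x_a to (log_scale, shift)."""
    h = onp.tanh(x_a @ weights["w1"] + weights["b1"])
    log_scale = onp.tanh(h @ weights["w_s"])  # bounded for numerical stability
    shift = h @ weights["w_t"]
    return log_scale, shift

def coupling_forward(x, weights):
    """x -> z: first half passes through, second half is scaled and shifted."""
    d = x.shape[1] // 2
    x_a, x_b = x[:, :d], x[:, d:]
    log_scale, shift = conditioner(x_a, weights)
    z_b = x_b * onp.exp(log_scale) + shift
    log_det = log_scale.sum(axis=1)  # triangular Jacobian: log|det| = sum(log_scale)
    return onp.concatenate([x_a, z_b], axis=1), log_det

def coupling_inverse(z, weights):
    """z -> x: recompute scale and shift from the untouched half, then undo them."""
    d = z.shape[1] // 2
    z_a, z_b = z[:, :d], z[:, d:]
    log_scale, shift = conditioner(z_a, weights)
    x_b = (z_b - shift) * onp.exp(-log_scale)
    return onp.concatenate([z_a, x_b], axis=1)

rng = onp.random.default_rng(0)
weights = {
    "w1": rng.normal(size=(1, 8)), "b1": rng.normal(size=8),
    "w_s": rng.normal(size=(8, 1)), "w_t": rng.normal(size=(8, 1)),
}
x = rng.normal(size=(5, 2))
z, log_det = coupling_forward(x, weights)
x_back = coupling_inverse(z, weights)
print("max round-trip error:", onp.max(onp.abs(x - x_back)))
```

Note that the conditioner itself is never inverted, so it can be an arbitrary network. Stacking several such layers while alternating which half is left unchanged lets every dimension be transformed.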

**When to use this:**
- No input-output pairs, just data X
- Need to detect anomalies
- Want to generate new synthetic data
- Need exact probability densities

## Example 5: Basic Density Estimation - Learning the Moons Distribution

In [None]:
# Generate 2D moons dataset (NO input-output pairs, just data points!)
X_moons, _ = make_moons(n_samples=1000, noise=0.05, random_state=42)
X_moons = np.array(X_moons)

print(f"Training data shape: {X_moons.shape}")
print(f"Task: Learn p(X) - the probability distribution of the moon shape")
print(f"\nNote: No Y here! We're learning the distribution of X itself.")

In [None]:
# Create invertible NN for density estimation with better hyperparameters
inn = InvertibleNN(
    n_features=2,        # 2D data
    n_layers=10,         # More coupling layers
    hidden_dims=[128, 128],  # Larger network
    seed=42
)

print("Training density estimation model...")
inn.fit(X_moons, normalize=True, maxiter=2000, tol=1e-5)  # More training

# Print report
inn.report()

In [None]:
# Visualize the learned distribution
fig = inn.plot(X_moons, n_samples=1000)
plt.suptitle('Density Estimation on Moons Dataset', fontsize=14, y=1.02)
plt.show()

print("\n✓ Left: Original data (complex moon shape)")
print("✓ Middle: Latent space (should be ~Gaussian blob)")
print("✓ Right: Generated samples (new moon-shaped data!)")

## Example 6: Anomaly Detection with Likelihoods

**Use case**: Detect outliers by computing their probability under learned distribution.

**Intuition**: 
- Normal points → High probability
- Anomalies → Low probability
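
The thresholding logic can be tried in isolation before involving a trained model. In the sketch below a closed-form 2D Gaussian log-density (`log_density`, a hypothetical stand-in for a model's log-probability method) plays the role of the learned distribution, and the cut-off is the 5th percentile of the training log-densities.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

def log_density(points, std=0.5):
    """Stand-in for a trained model's log p(x): isotropic 2D Gaussian N(0, std^2 I)."""
    sq_norm = onp.sum(points**2, axis=1)
    return -0.5 * sq_norm / std**2 - onp.log(2.0 * onp.pi * std**2)

rng = onp.random.default_rng(42)
train = rng.normal(scale=0.5, size=(2000, 2))

# Use the lowest-density 5% of the training data to set the cut-off.
threshold = onp.percentile(log_density(train), 5)

test_points = onp.array([[0.0, 0.0], [0.3, 0.3], [3.0, 3.0]])
is_anomaly = log_density(test_points) < threshold
false_alarm_rate = onp.mean(log_density(train) < threshold)
print("flagged:", is_anomaly)
print("fraction of training data flagged:", false_alarm_rate)
```

By construction about 5% of the normal training points fall below the threshold, so the percentile directly sets the false-alarm rate that is traded against sensitivity.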

In [None]:
# Train on "normal" data (centered Gaussian blob)
X_normal, _ = make_blobs(n_samples=500, centers=[[0, 0]], 
                         cluster_std=0.5, random_state=42)
X_normal = np.array(X_normal)

# Train anomaly detector with better hyperparameters
inn_anomaly = InvertibleNN(n_features=2, n_layers=8, hidden_dims=[128, 128], seed=42)
print("Training anomaly detector on normal data...")
inn_anomaly.fit(X_normal, normalize=True, maxiter=1500)  # More training
print("Training complete!\n")

# Test points (some normal, some anomalous)
X_test = np.array([
    [0.0, 0.0],    # Normal (center)
    [0.3, 0.3],    # Normal
    [0.5, -0.5],   # Normal (edge)
    [1.5, 1.5],    # Mild outlier
    [3.0, 3.0],    # Moderate outlier
    [5.0, 5.0],    # Strong outlier
])

# Compute log probabilities
log_probs_test = inn_anomaly.log_prob(X_test)

# Set threshold (5th percentile of training data)
train_log_probs = inn_anomaly.log_prob(X_normal)
threshold = np.percentile(train_log_probs, 5)

print("Anomaly Detection Results:")
print(f"{'Point':<20} {'Log-Prob':<15} {'Status':<10}")
print("-" * 45)

for point, lp in zip(X_test, log_probs_test):
    status = "🚨 ANOMALY" if lp < threshold else "✓ NORMAL"
    print(f"{str(point):<20} {lp:<15.3f} {status:<10}")

print(f"\nThreshold (5th percentile): {threshold:.3f}")
print("\n✓ Points far from training data have low probability!")

In [None]:
# Visualize likelihood landscape
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Create grid
x1_range = np.linspace(-3, 6, 100)
x2_range = np.linspace(-3, 6, 100)
X1, X2 = np.meshgrid(x1_range, x2_range)
X_grid = np.column_stack([X1.ravel(), X2.ravel()])
log_probs_grid = inn_anomaly.log_prob(X_grid)
probs_grid = np.exp(log_probs_grid).reshape(X1.shape)

# Left: Probability density
ax = axes[0]
contour = ax.contourf(X1, X2, probs_grid, levels=20, cmap='YlOrRd')
ax.scatter(X_normal[:, 0], X_normal[:, 1], alpha=0.3, s=10, c='blue', label='Training')
ax.scatter(X_test[:, 0], X_test[:, 1], s=200, c='black', marker='X', 
           edgecolors='white', linewidths=2, label='Test', zorder=5)
plt.colorbar(contour, ax=ax, label='Probability Density')
ax.set_xlabel('x₁')
ax.set_ylabel('x₂')
ax.set_title('Probability Density Landscape')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Decision boundary
ax = axes[1]
log_probs_reshaped = log_probs_grid.reshape(X1.shape)
contour2 = ax.contourf(X1, X2, log_probs_reshaped, levels=20, cmap='RdYlGn')
ax.contour(X1, X2, log_probs_reshaped, levels=[threshold], 
           colors='black', linewidths=3, linestyles='--')

normal_mask = log_probs_test >= threshold
anomaly_mask = log_probs_test < threshold

if np.any(normal_mask):
    ax.scatter(X_test[normal_mask, 0], X_test[normal_mask, 1], 
               s=200, c='green', marker='o', edgecolors='white', 
               linewidths=2, label='Normal', zorder=5)
if np.any(anomaly_mask):
    ax.scatter(X_test[anomaly_mask, 0], X_test[anomaly_mask, 1], 
               s=200, c='red', marker='X', edgecolors='white', 
               linewidths=2, label='Anomaly', zorder=5)

plt.colorbar(contour2, ax=ax, label='Log-Probability')
ax.set_xlabel('x₁')
ax.set_ylabel('x₂')
ax.set_title('Anomaly Detection Boundary')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Example 7: Forward and Inverse Transformations

**Key feature**: Perfect invertibility!
- Forward: Complex data X → Simple Gaussian Z
- Inverse: Gaussian Z → Realistic data X
- No information loss (reconstruction error ≈ 0)
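
The error is "≈ 0" rather than exactly 0 because each layer is inverted in floating-point arithmetic. The sketch below, a toy chain of elementwise affine maps rather than the actual network, measures the round-trip error in 32-bit and 64-bit precision. It gives a rough sense of why 64-bit mode was enabled at the top of this notebook.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

def round_trip_error(dtype, n_layers=10, seed=0):
    """Apply n_layers elementwise affine maps, invert them, return the max error."""
    rng = onp.random.default_rng(seed)
    x = rng.normal(size=1000).astype(dtype)
    log_scales = rng.uniform(-1.0, 1.0, size=n_layers).astype(dtype)
    shifts = rng.normal(size=n_layers).astype(dtype)

    z = x.copy()
    for s, t in zip(log_scales, shifts):              # forward pass
        z = z * onp.exp(s) + t
    x_back = z.copy()
    for s, t in zip(log_scales[::-1], shifts[::-1]):  # inverse pass, reversed order
        x_back = (x_back - t) * onp.exp(-s)
    return float(onp.max(onp.abs(x - x_back)))

err32 = round_trip_error(onp.float32)
err64 = round_trip_error(onp.float64)
print(f"float32 round-trip error: {err32:.1e}")
print(f"float64 round-trip error: {err64:.1e}")
```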

In [None]:
# Forward: data → latent
Z, log_det = inn.forward(X_moons[:100])

print("Forward Transformation (X → Z):")
print(f"  Original data: {X_moons[:100].shape}")
print(f"  Latent space: {Z.shape}")
print(f"\nLatent statistics (should be ~ N(0, 1)):")
print(f"  Mean: [{np.mean(Z[:, 0]):.3f}, {np.mean(Z[:, 1]):.3f}]")
print(f"  Std:  [{np.std(Z[:, 0]):.3f}, {np.std(Z[:, 1]):.3f}]")

# Inverse: latent → data
X_reconstructed = inn.inverse(Z)
reconstruction_error = np.max(np.abs(X_moons[:100] - X_reconstructed))

print(f"\nInverse Transformation (Z → X):")
print(f"  Reconstructed: {X_reconstructed.shape}")
print(f"  Max error: {reconstruction_error:.2e}")
print(f"  Perfect invertibility: {reconstruction_error < 1e-6} ✓")

In [None]:
# Visualize transformation
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

colors = np.arange(100)

# Original
ax = axes[0]
scatter = ax.scatter(X_moons[:100, 0], X_moons[:100, 1], 
                     c=colors, cmap='viridis', s=50, alpha=0.7)
ax.set_xlabel('x₁')
ax.set_ylabel('x₂')
ax.set_title('Original Data (X)\nComplex Moon Shape')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

# Latent
ax = axes[1]
ax.scatter(Z[:, 0], Z[:, 1], c=colors, cmap='viridis', s=50, alpha=0.7)
# Reference circles
theta = np.linspace(0, 2*np.pi, 100)
for r in [1, 2, 3]:
    ax.plot(r*np.cos(theta), r*np.sin(theta), 'r--', alpha=0.3, linewidth=1)
ax.set_xlabel('z₁')
ax.set_ylabel('z₂')
ax.set_title('Latent Space (Z)\nSimple Gaussian')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

# Reconstructed
ax = axes[2]
ax.scatter(X_reconstructed[:, 0], X_reconstructed[:, 1], 
           c=colors, cmap='viridis', s=50, alpha=0.7)
ax.set_xlabel('x₁')
ax.set_ylabel('x₂')
ax.set_title(f'Reconstructed (f⁻¹(Z))\nError: {reconstruction_error:.2e}')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()

print("\n✓ Colors show point correspondence")
print("✓ Complex moon maps to simple Gaussian and back perfectly!")

## Example 8: Testing on Different Distributions

In [None]:
# Create diverse datasets
datasets = {}

X_moons_test, _ = make_moons(n_samples=800, noise=0.05, random_state=42)
datasets['Moons'] = np.array(X_moons_test)

X_circles, _ = make_circles(n_samples=800, noise=0.05, factor=0.5, random_state=42)
datasets['Circles'] = np.array(X_circles)

X_blobs, _ = make_blobs(n_samples=800, centers=4, cluster_std=0.3, random_state=42)
datasets['Blobs'] = np.array(X_blobs)

n_spiral = 800
theta = np.linspace(0, 4*np.pi, n_spiral)
r = theta / (4*np.pi) * 3
X_spiral = np.stack([r*np.cos(theta), r*np.sin(theta)], axis=1)
noise = jax.random.normal(jax.random.PRNGKey(42), X_spiral.shape) * 0.1
datasets['Spiral'] = X_spiral + noise

print("Testing INN on 4 different distributions...")

In [None]:
# Train on each with better hyperparameters
trained_models = {}

for name, X in datasets.items():
    print(f"Training on {name}...")
    model = InvertibleNN(
        n_features=2, 
        n_layers=12,  # More layers for complex distributions
        hidden_dims=[128, 128],  # Larger network
        seed=42
    )
    model.fit(X, normalize=True, maxiter=2000)  # More training
    
    trained_models[name] = {
        'model': model,
        'score': model.score(X),
        'samples': model.sample(500, key=jax.random.PRNGKey(42)),
        'data': X
    }
    print(f"  Score: {trained_models[name]['score']:.3f}\n")

In [None]:
# Visualize results
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for idx, name in enumerate(datasets.keys()):
    X = trained_models[name]['data']
    samples = trained_models[name]['samples']
    score = trained_models[name]['score']
    
    # Top: original
    ax = axes[0, idx]
    ax.scatter(X[:, 0], X[:, 1], alpha=0.5, s=10, c='blue')
    ax.set_title(f'{name}\nOriginal', fontsize=10, fontweight='bold')
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, alpha=0.3)
    
    # Bottom: generated
    ax = axes[1, idx]
    ax.scatter(X[:, 0], X[:, 1], alpha=0.2, s=5, c='gray', label='Original')
    ax.scatter(samples[:, 0], samples[:, 1], alpha=0.6, s=10, c='red', label='Generated')
    ax.set_title(f'Score: {score:.2f}', fontsize=10)
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=7)

plt.suptitle('INN Performance on Different Distributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n✓ INN works well on diverse distribution types!")
print("✓ Circles are challenging due to topology - may need more layers.")

### Summary: When to Use INN for Density Estimation

✅ **Use InvertibleNN when you need to:**
- Detect anomalies by thresholding exact log-densities
- Generate new realistic synthetic data
- Compute exact likelihoods p(X)
- Understand data distribution structure
- Model unlabeled data (no input-output pairs, just samples from the distribution)

⚠️ **Not ideal for:**
- Supervised learning (use ConditionalInvertibleNN instead)
- Discrete data (INNs are for continuous distributions)
- Very high dimensions (>100) with limited data

**Comparison to alternatives:**
- vs **VAE**: INN has exact likelihood, no reconstruction error, harder to train
- vs **GAN**: INN is more stable, provides likelihoods, less mode collapse
- vs **KDE**: INN scales better to high dimensions, learns complex distributions

---

# Final Comparison: Conditional vs Unconditional

| Aspect | ConditionalInvertibleNN | InvertibleNN |
|--------|------------------------|-------------|
| **Learns** | p(Y\|X) - outputs given inputs | p(X) - data distribution |
| **Task** | Regression with uncertainty | Density estimation |
| **Input** | fit(X, y) - paired data | fit(X) - just data points |
| **Output** | Predictions + uncertainty | Likelihoods + samples |
| **Use cases** | • Regression<br>• Uncertainty quantification<br>• Heteroscedastic noise | • Anomaly detection<br>• Generative modeling<br>• Density estimation |
| **Example** | "Predict house price from features with confidence" | "Detect fraudulent transactions" |

## Recommended Configurations

### For Regression (ConditionalINN):
```python
cinn = ConditionalInvertibleNN(
    n_features_in=X.shape[1],
    n_features_out=y.shape[1],
    n_layers=8,              # 8 to 10 layers for good capacity
    hidden_dims=[128, 128],  # Larger networks work better
    seed=42
)
cinn.fit(X, y, maxiter=2000)  # 2000+ iterations recommended

# Predict with many samples for smooth uncertainty
y_pred, y_std = cinn.predict(X_test, return_std=True, n_samples=2000)
```

**Tips for ConditionalINN:**
- Use 8-10 layers for complex relationships
- Larger networks (128+ hidden units) capture distributions better
- Train longer (2000+ iterations) for convergence
- Use 2000-5000 samples for smooth uncertainty estimates
- For heteroscedastic data: use 10+ layers with deep networks [128, 128, 128]

### For Density Estimation (INN):
```python
inn = InvertibleNN(
    n_features=X.shape[1],
    n_layers=10,             # 10 to 12 layers for complex distributions
    hidden_dims=[128, 128],  # Larger capacity
    seed=42
)
inn.fit(X, normalize=True, maxiter=2000)  # Always normalize!

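# For anomaly detection: threshold at the 5th percentile of training log-probs
train_log_probs = inn.log_prob(X)
threshold = np.percentile(train_log_probs, 5)
log_probs = inn.log_prob(X_test)
anomalies = log_probs < threshold

# For generation (same call pattern as in Example 8)
samples = inn.sample(100, key=jax.random.PRNGKey(0))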
```

**Tips for INN:**
- **Always use normalization** for stability
- 10-12 layers for complex distributions (moons, spirals)
- Train longer (2000+ iterations) for convergence
- Circles/toroidal shapes are challenging - may need 15+ layers
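
Normalization here means standardizing each feature before fitting. The sketch below shows the idea in plain NumPy, together with the log-determinant correction needed to report densities in the original units. Whether `normalize=True` in pycse applies exactly this transformation and correction is an assumption; `standard_normal_logpdf` is a hypothetical stand-in for a trained flow's log-density.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

rng = onp.random.default_rng(0)
# Two features on very different scales
X = rng.normal(loc=[50.0, -3.0], scale=[10.0, 0.01], size=(1000, 2))

mean, std = X.mean(axis=0), X.std(axis=0)

def normalize(x):
    return (x - mean) / std

def denormalize(x_norm):
    return x_norm * std + mean

def standard_normal_logpdf(x_norm):
    """Stand-in for a trained flow's log-density in normalized space."""
    d = x_norm.shape[1]
    return -0.5 * onp.sum(x_norm**2, axis=1) - 0.5 * d * onp.log(2.0 * onp.pi)

def log_prob_original_units(x, log_prob_normalized):
    """Convert a density on normalized data back to a density on raw data."""
    # x_norm = (x - mean) / std contributes log|det J| = -sum(log std)
    return log_prob_normalized(normalize(x)) - onp.sum(onp.log(std))

X_norm = normalize(X)
log_p = log_prob_original_units(X, standard_normal_logpdf)
print("normalized mean:", X_norm.mean(axis=0))
print("normalized std: ", X_norm.std(axis=0))
```

Without the correction term, log-densities computed in normalized space would be offset by a constant. That does not change rankings for anomaly detection but does matter when comparing likelihoods across differently scaled datasets.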

### When to Use More Resources

**Complex problems need:**
- **More layers** (12-15): Spirals, multi-modal distributions, heteroscedastic regression
- **Deeper networks** [128, 128, 128]: Input-dependent uncertainty, high-dimensional data
- **More iterations** (3000-5000): Slow convergence, complex distributions
- **More samples** (5000+): Smooth uncertainty visualization, precise confidence intervals

**Simple problems can use:**
- **Fewer layers** (6-8): Linear/simple nonlinear regression, Gaussian blobs
- **Smaller networks** [64, 64]: Low-dimensional data, simple distributions
- **Fewer iterations** (1000-1500): Fast convergence
- **Fewer samples** (500-1000): Quick predictions
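
The sample-count guidance can be made quantitative. For Gaussian samples, the noise in an estimated standard deviation is roughly $\sigma / \sqrt{2n}$, so quadrupling the number of samples halves the jitter in the uncertainty band. The sketch below checks this approximation empirically with synthetic draws; it does not use the model.

```python
import numpy as onp  # plain NumPy; aliased to avoid clashing with jax.numpy as np

rng = onp.random.default_rng(0)
true_std = 1.0
n_repeats = 400
results = {}

for n_samples in (500, 2000, 5000):
    # Repeat the "draw n_samples, compute std" experiment many times.
    draws = rng.normal(scale=true_std, size=(n_repeats, n_samples))
    std_estimates = draws.std(axis=1)
    observed_noise = std_estimates.std()
    predicted_noise = true_std / onp.sqrt(2 * n_samples)  # Gaussian approximation
    results[n_samples] = (observed_noise, predicted_noise)
    print(f"n_samples={n_samples}: observed {observed_noise:.4f}, "
          f"predicted {predicted_noise:.4f}")
```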

## Further Reading

- Dinh et al., "Density estimation using Real NVP", ICLR 2017
- Winkler et al., "Learning Likelihoods with Conditional Normalizing Flows", 2019
- Ardizzone et al., "Analyzing Inverse Problems with Invertible Neural Networks", ICLR 2019
- Papamakarios et al., "Normalizing Flows for Probabilistic Modeling and Inference", JMLR 2021