This comprehensive guide explores the distributed computing capabilities built into phasic, demonstrating how computational workloads can be scaled from a single laptop to massive computing clusters spanning hundreds of nodes. The material covered here synthesizes everything you need to know about parallel and distributed computing in the context of phase-type distribution analysis, from fundamental concepts through practical implementation and deployment on production clusters.

The journey from sequential computation on a single machine to distributed computation across many machines has historically been fraught with complexity. Researchers and practitioners have traditionally faced hundreds of lines of boilerplate code, intricate environment variable configurations, complex inter-process communication protocols, and cluster-specific quirks that make code difficult to port between different computing environments. The distributed computing framework in phasic was designed specifically to eliminate these barriers, allowing you to write code once and run it anywhere, from your laptop during development to a massive SLURM cluster for production runs.

# Understanding the Need for Distributed Computing

Before diving into implementation details, it's worth understanding why distributed computing matters for phase-type distribution analysis and Bayesian inference. When working with complex phase-type distributions, particularly those arising from population genetic models or intricate Markov processes, the computational demands can quickly become overwhelming for a single machine. Consider a typical Stein Variational Gradient Descent (SVGD) inference problem where we maintain a swarm of particles, each representing a hypothesis about model parameters. Each particle must evaluate the likelihood function, which in our case means computing properties of a phase-type distribution at multiple time points. With hundreds or thousands of particles, each requiring potentially expensive graph traversals and numerical computations, the total computational burden can easily exceed what a single CPU core can handle in reasonable time.

The computational challenge grows multiplicatively rather than additively. If we have 500 particles and each particle requires evaluation at 50 time points across 1000 iterations, we're looking at 25 million individual distribution evaluations. Even if each evaluation takes just 10 milliseconds, that's nearly 70 hours of computation on a single core. By distributing these computations across multiple devices and multiple nodes, we can reduce wall-clock time from days to hours or even minutes, making previously impractical analyses feasible.
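The arithmetic above can be reproduced in a few lines. This is only a back-of-the-envelope sketch: the 10 ms cost per evaluation is the figure assumed in the text, and the "ideal" numbers assume perfect scaling with no communication overhead, which real clusters will not achieve.

```python
# Back-of-the-envelope cost model for the example in the text.
n_particles = 500
n_time_points = 50
n_iterations = 1000
seconds_per_eval = 0.010  # 10 ms per evaluation (assumed, as in the text)

total_evals = n_particles * n_time_points * n_iterations
serial_hours = total_evals * seconds_per_eval / 3600


def ideal_hours(n_devices):
    """Wall-clock hours under perfect scaling (ignores communication overhead)."""
    return serial_hours / n_devices


print(f"Total evaluations: {total_evals:,}")
print(f"Single core: {serial_hours:.1f} h")
for n_devices in (16, 64, 256):
    print(f"{n_devices:>4} devices: {ideal_hours(n_devices):.2f} h (ideal)")
```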

Beyond raw computational speed, distributed computing also enables larger-scale problems. With more compute resources, we can maintain more particles for better posterior approximations, run longer chains for improved convergence, or tackle larger state spaces that would exhaust the memory of a single machine. The distributed framework in phasic handles all the orchestration needed to achieve these benefits while keeping your code simple and portable.

# The Architecture of Distributed Computation

Understanding how distributed computation works under the hood helps clarify both its power and its limitations. At its core, distributed computing in phasic builds on JAX's distributed capabilities, which in turn leverage XLA (Accelerated Linear Algebra) for low-level execution and coordination. When you initialize distributed computing, several things happen behind the scenes to set up the computational environment.

First, the system needs to understand the computational topology: how many processes (typically corresponding to physical machines or nodes) are participating, what rank each process holds in the coordination hierarchy, and how many computational devices (CPU cores or GPU devices) each process controls. In a SLURM cluster environment, this information comes from environment variables that the cluster scheduler sets when launching your job. The SLURM_NTASKS variable tells us how many processes exist, SLURM_PROCID identifies which process we are, SLURM_CPUS_PER_TASK indicates how many CPU cores this process should use, and SLURM_JOB_NODELIST provides the list of machines involved. By parsing these variables automatically, the initialization function removes the burden of manual environment parsing that traditionally required dozens of lines of error-prone code.
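To make the environment parsing concrete, here is a minimal sketch of reading these variables by hand. It is not phasic's implementation; the function name `read_slurm_topology` and the returned dictionary layout are hypothetical, chosen only for illustration.

```python
import os


def read_slurm_topology(env=None):
    """Read process topology from SLURM variables, or fall back to one local process."""
    env = os.environ if env is None else env
    if "SLURM_JOB_ID" not in env:
        # Not running under SLURM: describe a single local process.
        return {"job_id": None, "num_processes": 1, "process_id": 0,
                "cpus_per_task": os.cpu_count() or 1, "nodelist": None}
    return {
        "job_id": env["SLURM_JOB_ID"],
        "num_processes": int(env.get("SLURM_NTASKS", "1")),
        "process_id": int(env.get("SLURM_PROCID", "0")),
        # SLURM_CPUS_PER_TASK may be absent if --cpus-per-task was not requested.
        "cpus_per_task": int(env.get("SLURM_CPUS_PER_TASK", "1")),
        "nodelist": env.get("SLURM_JOB_NODELIST"),
    }


print(read_slurm_topology())
```

Passing a plain dictionary instead of `os.environ` makes the function easy to exercise without a cluster.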

Second, the processes need to establish communication channels. In JAX's distributed model, one process serves as the coordinator, and all other processes connect to it. The coordinator is typically the process with rank 0, and it needs a known network address where others can reach it. In a SLURM environment, we extract the hostname of the first node in the node list and combine it with a specified port number (defaulting to 12345) to create the coordinator address. Each process then initializes its JAX distributed runtime with this address, its own rank, and the total number of processes. This creates a communication fabric that JAX uses internally to coordinate data movement and synchronization during parallel operations.
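As an illustration of the address construction, the sketch below extracts the first hostname from a simple node list and appends the port. The helper names are hypothetical, and the parser only handles plain names and a single bracketed range per name; robust code would more likely ask SLURM itself to expand the list (for example with `scontrol show hostnames`).

```python
import re


def first_host(nodelist):
    """Return the first hostname from a simple SLURM node list.

    Handles plain names ("n1,n2") and one bracketed range per name
    ("node[003-010,015]"). Nested or multi-bracket forms are not handled.
    """
    match = re.match(r"^([^,\[]+)(?:\[([^\]]+)\])?", nodelist)
    if match is None:
        raise ValueError(f"Cannot parse node list: {nodelist!r}")
    prefix, ranges = match.groups()
    if ranges is None:
        return prefix
    first = ranges.split(",")[0].split("-")[0]  # keeps zero padding, e.g. "003"
    return prefix + first


def coordinator_address(nodelist, port=12345):
    """Combine the first host with the coordinator port."""
    return f"{first_host(nodelist)}:{port}"


print(coordinator_address("node[003-010,015]"))  # node003:12345
```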

Third, the local devices on each process need configuration. JAX normally detects available hardware automatically, but in CPU-only cluster environments, we often want to create multiple logical devices corresponding to CPU cores for better parallelization. The XLA_FLAGS environment variable, specifically the xla_force_host_platform_device_count flag, controls this. By setting it to match SLURM_CPUS_PER_TASK, we ensure JAX creates one device per allocated core, enabling fine-grained parallel execution within each node.
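Done by hand, the device-count configuration might look like the sketch below. The helper `host_device_flags` is a hypothetical name; the flag itself is the one named above. The environment variable has to be set before JAX initializes its backends, so in practice this runs before `import jax`.

```python
import os


def host_device_flags(existing_flags, n_devices):
    """Return an XLA_FLAGS string requesting n_devices logical CPU devices."""
    flag = f"--xla_force_host_platform_device_count={n_devices}"
    # Drop any earlier setting of the same flag and keep everything else.
    kept = [f for f in existing_flags.split()
            if not f.startswith("--xla_force_host_platform_device_count")]
    return " ".join(kept + [flag])


n_cpus = int(os.environ.get("SLURM_CPUS_PER_TASK", os.cpu_count() or 1))
os.environ["XLA_FLAGS"] = host_device_flags(os.environ.get("XLA_FLAGS", ""), n_cpus)
# Import jax only after this point so the flag is seen at backend initialization.
```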

In [None]:
# Import the necessary components for distributed computing
import numpy as np
import jax
import jax.numpy as jnp
from phasic import initialize_distributed, Graph, SVGD
import matplotlib.pyplot as plt

# Set up plotting style for consistent visualization
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (12, 6)

# The Power of Automatic Initialization

The initialize_distributed function represents a significant simplification over traditional distributed computing setup. In the past, setting up distributed JAX computation required detecting the execution environment, parsing various environment variables with appropriate error handling, determining the coordinator node through system calls to SLURM utilities, configuring JAX's device detection, setting XLA compilation flags, initializing the distributed runtime with correct addresses and ranks, and implementing fallback logic for local execution. This typically resulted in 200 or more lines of boilerplate code that needed to be copied between projects and debugged independently each time.

The single-line initialization replaces all of this complexity with automatic detection and sensible defaults. When you call initialize_distributed, it first checks whether it's running in a SLURM environment by looking for the SLURM_JOB_ID environment variable. If found, it extracts all relevant SLURM variables, parses the node list to identify the coordinator, computes the coordinator address, sets up JAX with appropriate device counts, and initializes distributed coordination. If not running under SLURM, it falls back to local execution with multiple CPU devices, allowing the same code to run on your laptop during development and on a cluster during production runs without any modifications.

Let's initialize the distributed environment and examine what information it provides about our computational setup.

In [None]:
# Initialize distributed computing - this one line replaces 200+ lines of boilerplate
# The function automatically detects whether we're running on SLURM or locally
# and configures JAX appropriately for the environment
dist_info = initialize_distributed(
    coordinator_port=12345,  # Port for inter-process communication
    platform="cpu",          # Use CPU devices (change to "gpu" for GPU clusters)
    enable_x64=True          # Enable 64-bit precision for numerical accuracy
)

# The returned configuration object contains all information about our computational environment
print("Distributed Computing Configuration")
print("=" * 70)
print(f"Job ID: {dist_info.job_id if dist_info.job_id else 'Local execution'}")
print(f"Process rank: {dist_info.process_id} of {dist_info.num_processes}")
print(f"Is coordinator: {dist_info.is_coordinator}")
print(f"Coordinator address: {dist_info.coordinator_address}")
print(f"Local devices: {dist_info.local_device_count}")
print(f"Global devices: {dist_info.global_device_count}")
print(f"Platform: {dist_info.platform}")
print("=" * 70)

# Only the coordinator process should print certain information to avoid cluttered output
if dist_info.is_coordinator:
    print(f"\nThis process is the coordinator. It will orchestrate distributed operations.")
    print(f"Total computational capacity: {dist_info.global_device_count} devices across {dist_info.num_processes} processes")

# Understanding Devices and Processes

The distinction between processes and devices is fundamental to understanding distributed computing in JAX. A process corresponds to an operating system process, typically one per physical machine or node in a cluster. Each process can control multiple devices, which are the actual computational units that execute operations. On a CPU-based cluster, devices typically correspond to CPU cores, while on GPU systems, devices might be individual GPU cards.

This two-level hierarchy enables flexible parallelization strategies. Within a single node, JAX's pmap (parallel map) operation distributes computation across the local devices, executing the same program on different data in an SPMD (Single Program, Multiple Data) fashion. Across nodes, JAX's distributed runtime coordinates data movement and synchronization, ensuring that operations spanning multiple processes execute correctly despite the physical separation of the machines.

The total computational capacity of your cluster equals the number of processes times the devices per process. For example, if you have 4 nodes with 16 CPU cores each, you have 4 processes and 64 total devices. When you distribute a computation across these 64 devices, JAX automatically handles both the intra-node parallelization (across the 16 cores within each node) and the inter-node coordination (between the 4 nodes), presenting a unified programming model where you simply say "execute this function in parallel across all devices" and the system handles the details.
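The 4 × 16 example can be spelled out as a small table of devices. The numbering used here (process-major, then local index) is an illustrative assumption and not a statement about how JAX orders its global device list.

```python
def device_layout(num_processes, devices_per_process):
    """Map each global device index to (process rank, local device index)."""
    return [(rank, local)
            for rank in range(num_processes)
            for local in range(devices_per_process)]


layout = device_layout(num_processes=4, devices_per_process=16)
print(len(layout))   # 64 devices in total
print(layout[17])    # (1, 1): process 1, local device 1
```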

Let's examine the JAX devices available in our current environment to see this hierarchy in action.

In [None]:
# JAX provides direct access to the device configuration
devices = jax.devices()

print(f"JAX Device Configuration")
print(f"Total devices visible: {len(devices)}")
print(f"\nDevice details:")
for i, device in enumerate(devices[:10]):  # Show first 10 to avoid overwhelming output
    print(f"  Device {i}: {device}")
if len(devices) > 10:
    print(f"  ... and {len(devices) - 10} more devices")

# Verify that our initialization correctly configured the device count
assert len(devices) == dist_info.global_device_count, "Device count mismatch!"
print(f"\nVerification: Device count matches distributed configuration ✓")

# Building Phase-Type Distribution Models for Distributed Inference

Before we can demonstrate distributed inference, we need a phase-type distribution model to work with. The coalescent process from population genetics provides an excellent example because it's both scientifically meaningful and computationally interesting. The coalescent describes how genetic lineages merge backward in time, starting with a sample of DNA sequences from present-day individuals and tracing their ancestry back to a common ancestor.

In the simplest coalescent model, we have n sampled lineages that can coalesce pairwise at rate n(n-1)/2, where the rate reflects the probability that any two lineages find their common ancestor in a small time interval. When two lineages coalesce, we transition from n lineages to n-1 lineages, and the process continues until only one lineage remains, representing the most recent common ancestor of the sample. The distribution of time until this final common ancestor follows a phase-type distribution where states represent different numbers of lineages and transitions represent coalescent events.

What makes this particularly suitable for demonstrating parameterized models is that the coalescent rate depends on the effective population size. Specifically, if we scale time in units of 2N generations (where N is the effective population size), the coalescent rate for n lineages is n(n-1)/2 times a scaling parameter θ that absorbs the dependence on the effective population size. By building a parameterized graph where edge rates are linear functions of θ, we can efficiently evaluate the likelihood across different parameter values without rebuilding the graph structure.
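A short worked example shows how the rates translate into an expected time to the most recent common ancestor. Each state with k lineages is left at rate θ·k(k-1)/2, so its mean holding time is the reciprocal, and the sum over k = 2..n telescopes to (2/θ)(1 - 1/n). The helper names below are illustrative only.

```python
from fractions import Fraction


def coalescent_rates(n, theta=1):
    """Rate of leaving the state with k lineages, for k = n, n-1, ..., 2."""
    return {k: Fraction(k * (k - 1), 2) * theta for k in range(n, 1, -1)}


def expected_tmrca(n, theta=1):
    """Sum of mean holding times 1/rate; equals (2/theta) * (1 - 1/n)."""
    return sum(1 / rate for rate in coalescent_rates(n, theta).values())


rates = coalescent_rates(8)
print(rates[8], rates[2])   # 28 1
print(expected_tmrca(8))    # 7/4
```

Because the rates scale linearly with θ, the expected time scales as 1/θ, which is what makes the observed time scale informative about the parameter.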

The callback-based graph construction approach used here deserves explanation. Rather than manually creating every vertex and edge, we provide a callback function that returns the possible transitions from any given state. The Graph constructor calls this function for the initial empty state to determine starting states, then iteratively explores the reachable state space by calling the callback for each new state discovered. This lazy construction approach is memory-efficient and natural for models defined by transition rules rather than explicit state enumeration.
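The exploration loop described above can be sketched in plain Python. This is a simplified stand-in, not the Graph constructor itself: states are tuples, transitions are (next_state, rate) pairs without edge coefficients, and the names `explore` and `coalescent` are hypothetical.

```python
from collections import deque


def explore(callback):
    """Breadth-first discovery of states reachable through `callback`.

    `callback(state)` returns a list of (next_state, rate) pairs; calling it
    with the empty tuple yields the starting states.
    """
    edges = {}
    queue = deque(nxt for nxt, _ in callback(()))
    while queue:
        state = queue.popleft()
        if state in edges:
            continue  # already expanded
        edges[state] = callback(state)
        queue.extend(nxt for nxt, _ in edges[state])
    return edges


def coalescent(state, n=8):
    """Transition rule: n lineages initially, then pairwise coalescence."""
    if not state:
        return [((n,), 1.0)]
    k = state[0]
    return [((k - 1,), k * (k - 1) / 2)] if k > 1 else []


graph = explore(coalescent)
print(len(graph))     # 8 states: 8 lineages down to 1
print(graph[(3,)])    # [((2,), 3.0)]
```

Only states that are actually reachable are ever created, which is why rule-based construction stays compact even when the space of conceivable states is large.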

In [None]:
def build_coalescent_model(nr_samples=10):
    """
    Construct a parameterized coalescent model using callback-based graph construction.
    
    The callback function is called with the current state and must return a list of
    possible transitions. Each transition is a tuple of (next_state, weight, edge_coefficients)
    where edge_coefficients specify how this transition's rate depends on parameters.
    
    For the coalescent, we start with nr_samples lineages in a single state and allow
    pairwise coalescence until reaching a single ancestral lineage. The coalescent rate
    for n lineages is n(n-1)/2 times the parameter θ.
    """
    def coalescent_callback(state, nr_samples=nr_samples):
        # When called with empty state, return the initial configuration
        if not state.size:
            # Start with all samples as separate lineages
            # The edge coefficient [1] means this initialization probability is constant
            return [[[nr_samples], 1.0, [1.0]]]
        
        # For non-empty states, determine possible transitions
        n_lineages = state[0]
        
        if n_lineages > 1:
            # Coalescent event: n lineages → n-1 lineages
            # Rate is n(n-1)/2 times the parameter θ
            coalescent_rate = n_lineages * (n_lineages - 1) / 2
            next_state = [n_lineages - 1]
            
            # The weight 0.0 means this is not a probability (we'll multiply by θ later)
            # The edge coefficient [coalescent_rate] means the actual rate is coalescent_rate * θ
            return [[next_state, 0.0, [coalescent_rate]]]
        
        # When we reach 1 lineage, we're in the absorbing state (MRCA reached)
        return []
    
    # Build the parameterized graph using the callback
    # The parameterized=True flag indicates edges have parameter-dependent rates
    graph = Graph(
        callback=coalescent_callback,
        parameterized=True,
        nr_samples=nr_samples
    )
    
    return graph

# Construct a coalescent model with 8 sampled sequences
nr_samples = 8
coalescent_graph = build_coalescent_model(nr_samples=nr_samples)

if dist_info.is_coordinator:
    print(f"\nCoalescent Model Construction")
    print(f"Sampled sequences: {nr_samples}")
    print(f"State space size: {coalescent_graph.vertices_length()} states")
    print(f"Parameter dimension: 1 (θ = scaled population size)")
    print(f"\nThe state space includes all configurations from {nr_samples} lineages down to 1.")
    print(f"Each state represents a number of lineages, and transitions represent coalescent events.")

# Converting Graphs to JAX Functions

To use phase-type distributions in modern machine learning and inference workflows, we need to convert them into functions that JAX can work with. JAX's power comes from its ability to transform functions through operations like automatic differentiation, just-in-time compilation, vectorization, and parallelization. However, our phase-type distribution graphs are complex C++ objects that JAX can't directly manipulate. The pmf_from_graph function solves this problem by creating a JAX-compatible wrapper that calls into the C++ implementation through a foreign function interface (FFI).

This conversion process involves several subtle steps. First, the graph structure must be serialized into arrays that can be passed to JAX. This includes the subintensity matrix, initial probability vector, and other structural information. Second, these arrays are registered with JAX's callback mechanism, which allows JAX to call external code during computation. Third, the wrapper function is decorated with appropriate JAX primitives that describe its behavior during transformations like gradient computation or parallelization.

The discrete parameter in pmf_from_graph controls whether we're working with a continuous-time or discrete-time phase-type distribution. For continuous distributions (discrete=False), the function evaluates the probability density function (PDF) at specified time points, giving the density of the absorption time at each point rather than a probability. For discrete distributions (discrete=True), it evaluates the probability mass function (PMF) at integer step counts, representing the probability of absorption after a specific number of transitions.
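For the simple coalescent chain, the continuous-time density also has a closed form that can serve as an independent sanity check. Assuming the model is the linear chain n → n-1 → ... → 1 with rates θ·k(k-1)/2, the absorption time is a sum of independent exponentials with distinct rates (a hypoexponential distribution). The function below is a standalone sketch and does not call phasic.

```python
import math


def tmrca_pdf(t, n, theta=1.0):
    """Density of the time to the MRCA for n lineages.

    The time is a sum of independent exponentials with distinct rates
    theta * k(k-1)/2, k = 2..n, so the hypoexponential closed form applies:
    f(t) = sum_i rate_i * exp(-rate_i * t) * prod_{j != i} rate_j / (rate_j - rate_i).
    """
    rates = [theta * k * (k - 1) / 2 for k in range(2, n + 1)]
    total = 0.0
    for i, li in enumerate(rates):
        weight = 1.0
        for j, lj in enumerate(rates):
            if j != i:
                weight *= lj / (lj - li)
        total += weight * li * math.exp(-li * t)
    return total


for t in (0.5, 1.0, 2.0):
    print(f"f({t}) = {tmrca_pdf(t, n=8):.4f}")
```

Integrating this density numerically should give a total mass close to 1 and a mean close to (2/θ)(1 - 1/n).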

The resulting JAX function has signature model(theta, times) where theta is a parameter vector and times are evaluation points. This signature enables automatic differentiation with respect to parameters, which is essential for gradient-based inference methods like SVGD. The function can be composed with other JAX operations, passed through jax.jit for compilation, vectorized with jax.vmap, or parallelized with jax.pmap, making it a first-class citizen in the JAX ecosystem despite its C++ implementation.

In [None]:
# Convert the coalescent graph to a JAX-compatible function
# This creates a function that can be differentiated, compiled, and parallelized by JAX
coalescent_model = Graph.pmf_from_graph(
    coalescent_graph,
    discrete=False  # We want continuous-time PDF, not discrete-time PMF
)

# Test the model with an example parameter value
test_theta = jnp.array([1.5])  # Population size parameter
test_times = jnp.linspace(0.1, 5.0, 50)  # Time points for evaluation

# Evaluate the PDF at these time points
test_pdf = coalescent_model(test_theta, test_times)

if dist_info.is_coordinator:
    print(f"\nJAX Model Conversion")
    print(f"Model signature: model(theta, times) -> pdf_values")
    print(f"\nTest evaluation:")
    print(f"  Parameter θ = {float(test_theta[0]):.2f}")
    print(f"  Time points: {len(test_times)} values from {float(test_times[0]):.2f} to {float(test_times[-1]):.2f}")
    print(f"  PDF range: [{float(jnp.min(test_pdf)):.6f}, {float(jnp.max(test_pdf)):.6f}]")
    
    # Visualize the PDF to see the shape of the coalescent time distribution
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(test_times, test_pdf, 'b-', linewidth=2, label=f'θ = {float(test_theta[0]):.2f}')
    ax.set_xlabel('Time to MRCA', fontsize=12)
    ax.set_ylabel('Probability Density', fontsize=12)
    ax.set_title(f'Coalescent Time Distribution ({nr_samples} samples)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print(f"\nThe distribution shows the probability density for the time until all lineages coalesce.")
    print(f"The shape reflects the sequential nature of coalescent events, with most probability")
    print(f"concentrated where the last few lineages are coalescing.")

# Generating Synthetic Data for Demonstration

To demonstrate distributed Bayesian inference, we need observed data that our inference procedure will try to explain. In a real application, this data would come from actual experiments or observations—for example, genetic sequence data from which coalescent times have been estimated. For this demonstration, we'll generate synthetic data from a known parameter value, which has the advantage that we can verify our inference procedure correctly recovers the true parameter.

The data generation process involves several steps. First, we choose a true parameter value that will serve as the ground truth we're trying to recover. Second, we evaluate our model at this parameter to get the true PDF. Third, we add realistic noise to simulate measurement uncertainty or sampling variability. The noise level should be calibrated to what you'd expect in real data—too little noise makes the problem artificially easy, while too much noise obscures the signal and makes inference difficult or impossible.

The choice of evaluation points also matters. We want enough points to capture the shape of the distribution, but not so many that computation becomes burdensome. The points should span the region where the PDF has significant mass; evaluating far into the tails where probability is negligible provides little information. For the coalescent, most events occur within a few coalescent time units, so we'll focus our evaluation points there.

Adding noise requires care to maintain statistical validity. We add Gaussian noise whose standard deviation is a fixed fraction of the peak PDF value, then clip the result so that it remains positive, since probability densities cannot be negative. The clipping introduces a slight bias, but it's necessary for numerical stability in the log-likelihood computation during inference. In practice, ensuring your noise model matches your actual measurement process is important for obtaining valid posterior inference.

In [None]:
# Set a true parameter value that we'll try to recover through inference
true_theta = jnp.array([1.2])

# Choose evaluation points where we'll "observe" the PDF
# These points span the region where the coalescent distribution has significant mass
n_observations = 40
observation_times = jnp.linspace(0.05, 6.0, n_observations)

# Evaluate the true model to get the PDF at these points
true_pdf = coalescent_model(true_theta, observation_times)

# Add measurement noise
# The noise standard deviation is a fixed fraction of the peak PDF value,
# so every observation point receives noise of the same absolute size
np.random.seed(42)  # For reproducibility
noise_level = 0.08  # Noise std as a fraction of the peak PDF (8%)
noise_std = noise_level * float(jnp.max(true_pdf))
noise = np.random.normal(0, noise_std, size=true_pdf.shape)

# Add noise and clip to ensure positivity
observed_pdf = jnp.maximum(true_pdf + noise, 1e-10)

if dist_info.is_coordinator:
    print(f"\nSynthetic Data Generation")
    print(f"True parameter: θ = {float(true_theta[0]):.3f}")
    print(f"Observation points: {n_observations}")
    print(f"Time range: [{float(observation_times[0]):.2f}, {float(observation_times[-1]):.2f}]")
    print(f"Noise level: {noise_level*100:.1f}% of peak PDF")
    print(f"Noise std: {noise_std:.6f}")
    
    # Visualize the synthetic data
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(observation_times, true_pdf, 'b-', linewidth=2.5, 
            label=f'True PDF (θ = {float(true_theta[0]):.3f})', alpha=0.8)
    ax.scatter(observation_times, observed_pdf, c='red', s=50, alpha=0.6, 
               label='Observed (with noise)', zorder=5)
    ax.axhline(y=0, color='k', linestyle='-', linewidth=0.5, alpha=0.3)
    ax.set_xlabel('Time to MRCA', fontsize=12)
    ax.set_ylabel('Probability Density', fontsize=12)
    ax.set_title('Synthetic Observed Data for Inference', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print(f"\nThe observed data shows the true PDF corrupted by measurement noise.")
    print(f"Our inference task is to recover the true parameter θ given only the noisy observations.")

# Stein Variational Gradient Descent: Theory and Practice

Stein Variational Gradient Descent (SVGD) is a powerful algorithm for Bayesian inference that approximates a posterior distribution using a set of particles. Unlike Markov Chain Monte Carlo methods that generate sequential samples from the posterior, SVGD maintains a swarm of particles that collectively represent the posterior and updates them simultaneously through deterministic gradient-based updates. This particle-based approach makes SVGD particularly well-suited to distributed computing because particles can be processed independently in parallel.

The key insight behind SVGD comes from the theory of kernelized Stein discrepancy. Given a target distribution (in our case, the Bayesian posterior), SVGD constructs a smooth transformation of the particles that reduces the discrepancy between the empirical particle distribution and the target. The transformation is chosen to be in the unit ball of a reproducing kernel Hilbert space (RKHS), which provides both theoretical guarantees and practical computational tractability. The optimal transformation turns out to have a surprisingly elegant form: it's a weighted combination of gradients of the log target density and gradients of a kernel function measuring particle similarity.

Mathematically, if we denote our particles as θ₁, θ₂, ..., θₙ and the target posterior density as p(θ|data), the SVGD update for each particle θᵢ has the form: θᵢ ← θᵢ + ε φ*(θᵢ), where ε is a step size and φ* is the optimal perturbation. This optimal perturbation is given by: φ*(θᵢ) = (1/n) Σⱼ [k(θⱼ, θᵢ) ∇log p(θⱼ|data) + ∇k(θⱼ, θᵢ)], where both gradients are taken with respect to θⱼ. The first term moves each particle along a kernel-weighted average of the log-posterior gradients, attracting particles toward regions of high posterior probability, similar to gradient ascent. The second term creates repulsion between particles, preventing them from collapsing to a single mode and encouraging exploration of the posterior landscape.

The kernel function k(·,·) measures similarity between particles and determines the spatial scale of interaction. A common choice is the radial basis function (RBF) kernel k(θ, θ') = exp(-||θ - θ'||²/h²), where h is a bandwidth parameter. The bandwidth controls how locally or globally particles interact: small h means only nearby particles affect each other, while large h creates long-range interactions. Adaptive bandwidth selection using the median heuristic (setting h to the median pairwise distance between particles) often works well in practice, balancing exploration and exploitation automatically as the particles move.
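The kernel and the bandwidth rule described above can be sketched with NumPy as follows. This follows the text's convention (kernel exp(-||θ - θ'||²/h²), h set to the median pairwise distance); other SVGD implementations use slightly different scalings, and the function name `rbf_kernel` is illustrative. At least two particles are assumed.

```python
import numpy as np


def rbf_kernel(particles, h=None):
    """Kernel matrix K[i, j] = exp(-||x_i - x_j||^2 / h^2) for particles of shape (n, d).

    If h is None, it is set to the median pairwise distance between particles.
    Returns the kernel matrix and the bandwidth that was used.
    """
    diffs = particles[:, None, :] - particles[None, :, :]   # (n, n, d)
    sq_dists = np.sum(diffs**2, axis=-1)                     # (n, n)
    if h is None:
        iu = np.triu_indices(len(particles), k=1)            # each distinct pair once
        dists = np.sqrt(sq_dists[iu])
        h = max(float(np.median(dists)), 1e-12)              # guard against h == 0
    return np.exp(-sq_dists / h**2), h


particles = np.array([[0.0], [1.0], [3.0]])
K, h = rbf_kernel(particles)
print(h)  # 2.0: the median of the pairwise distances 1, 3, 2
print(K)
```

Recomputing h at every iteration lets the interaction range shrink as the particles concentrate and widen as they spread out.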

For distributed computing, SVGD's particle structure is ideal because the gradient computation for different particles is independent and can happen simultaneously on different devices. The only coupling comes from the kernel matrix computation and the aggregation of updates, both of which can be efficiently parallelized. When we have n_devices devices and n_particles particles, we distribute particles roughly evenly across devices (ideally n_particles is a multiple of n_devices), and each device computes gradients for its subset of particles. The kernel matrix requires all-to-all communication to compute pairwise distances, but modern parallel computing frameworks like JAX handle this efficiently through collective operations.

# Setting Up Distributed SVGD

Configuring SVGD for distributed execution requires careful attention to several parameters that affect both the quality of posterior approximation and the efficiency of distributed computation. The number of particles determines how well we can approximate the posterior—more particles generally mean better approximation but also higher computational cost. The number of iterations controls how long we run the optimization, balancing convergence quality against wall-clock time. The learning rate (step size) affects convergence speed and stability, with too-large values causing instability and too-small values leading to slow convergence.

For distributed execution, the key consideration is ensuring particles distribute evenly across devices. If you have n_devices total devices across all processes, it's ideal if n_particles is a multiple of n_devices, allowing exactly n_particles/n_devices particles per device. This even distribution maximizes hardware utilization and avoids load imbalance, where some devices sit idle while others process more particles. If the division isn't exact, a common approach is to pad the particle count to the next multiple of the device count, adding a few duplicate particles whose effect on the results is usually small.
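The padding rule is a one-line ceiling division. The sketch below uses hypothetical helper names and shows only the counting; how the padding particles are chosen and later discarded depends on the implementation.

```python
def pad_to_multiple(n_particles, n_devices):
    """Smallest multiple of n_devices that is >= n_particles."""
    return -(-n_particles // n_devices) * n_devices


def per_device_count(n_particles, n_devices):
    """Particles per device after padding (equal on every device)."""
    return pad_to_multiple(n_particles, n_devices) // n_devices


print(pad_to_multiple(500, 64))    # 512: 12 padding particles
print(per_device_count(500, 64))   # 8 per device
print(pad_to_multiple(640, 64))    # 640: already a multiple
```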

The prior distribution encodes our beliefs about parameters before seeing data. For the coalescent parameter θ, which represents effective population size on a scaled timeline, we might use a weakly informative prior that prefers moderate values but doesn't strongly constrain the inference. A Gaussian prior centered at a reasonable value with large variance works well, allowing the likelihood to dominate the posterior while preventing numerical issues from extreme parameter values. The log prior function must be differentiable since SVGD needs its gradient.

Initial particle placement also matters for convergence. Random initialization from a reasonable distribution (perhaps the prior or a broad Gaussian) provides diversity that helps explore the posterior. Seeding the random number generator differently on each process ensures particles start from different locations even in distributed settings. Good initialization can significantly reduce the number of iterations needed for convergence.

In [None]:
# Configure SVGD parameters for distributed inference
# Scale particle count with available computational resources
particles_per_device = 10
n_particles = dist_info.global_device_count * particles_per_device
n_iterations = 800
learning_rate = 0.01

if dist_info.is_coordinator:
    print(f"\nSVGD Configuration for Distributed Inference")
    print(f"="*70)
    print(f"Total particles: {n_particles}")
    print(f"Particles per device: {particles_per_device}")
    print(f"Total devices: {dist_info.global_device_count}")
    print(f"Number of processes: {dist_info.num_processes}")
    print(f"Iterations: {n_iterations}")
    print(f"Learning rate: {learning_rate}")
    print(f"="*70)
    
    print(f"\nWith {n_particles} particles and {n_observations} observations per particle,")
    print(f"each iteration requires {n_particles * n_observations:,} likelihood evaluations.")
    print(f"Over {n_iterations} iterations, that's {n_particles * n_observations * n_iterations:,} total evaluations.")
    print(f"\nDistributed across {dist_info.global_device_count} devices, each device handles approximately")
    print(f"{(n_particles * n_observations * n_iterations) // dist_info.global_device_count:,} evaluations.")

# Define a weakly informative prior
# We use a Gaussian centered at 1.0 with standard deviation 2.0
# This prefers moderate population sizes but allows wide variation
def log_prior(theta):
    """
    Log prior density for the population size parameter.
    We use log(p(θ)) = -0.5 * ((θ - μ) / σ)² plus constants.
    The constants can be omitted since SVGD only needs gradients.
    """
    prior_mean = 1.0
    prior_std = 2.0
    return -0.5 * jnp.sum((theta - prior_mean)**2 / prior_std**2)

# Initialize particles randomly
# Use different seeds on different processes to ensure diversity
np.random.seed(42 + dist_info.process_id)
theta_init = np.random.uniform(0.5, 2.0, size=(n_particles, 1))

if dist_info.is_coordinator:
    print(f"\nInitialization Statistics:")
    print(f"Mean: {np.mean(theta_init):.3f}")
    print(f"Std: {np.std(theta_init):.3f}")
    print(f"Range: [{np.min(theta_init):.3f}, {np.max(theta_init):.3f}]")
    print(f"\nParticles are initialized uniformly between 0.5 and 2.0, bracketing the true value.")

# Executing Distributed SVGD Inference

With our model, data, and configuration prepared, we're ready to run the actual inference. The SVGD class handles all the complexity of distributed particle updates, gradient computation, and convergence monitoring. Behind the scenes, several sophisticated operations occur during each iteration of the algorithm.

First, for each particle, we need to compute the gradient of the log posterior density, which decomposes as log p(θ|data) = log p(data|θ) + log p(θ) - log p(data). The last term is a normalizing constant that doesn't depend on θ, so we can ignore it. The gradient is thus ∇log p(θ|data) = ∇log p(data|θ) + ∇log p(θ), combining likelihood and prior gradients. The prior gradient is straightforward since we defined the prior analytically. The likelihood gradient is more involved because p(data|θ) involves our phase-type distribution model evaluated at θ, and we need automatic differentiation through the JAX function wrapper to compute ∇log p(data|θ).
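As a minimal sketch of this decomposition, the snippet below swaps in a stand-in Gaussian likelihood (an assumption for illustration; the real likelihood comes from the phase-type model) and pairs it with the same Gaussian prior as `log_prior`. It then checks the summed analytic gradients against a finite-difference estimate. The names `log_likelihood`, `grad_log_posterior`, `finite_difference`, and the toy data are hypothetical.

```python
import numpy as np

PRIOR_MEAN, PRIOR_STD = 1.0, 2.0    # same values as the prior in this guide
NOISE_STD = 0.5                      # assumed noise level for the toy likelihood
data = np.array([0.8, 1.1, 1.3])     # toy observations, not coalescent data

def log_prior(theta):
    return -0.5 * ((theta - PRIOR_MEAN) / PRIOR_STD) ** 2

def log_likelihood(theta):
    # Stand-in model: each observation is theta plus Gaussian noise
    return -0.5 * np.sum(((data - theta) / NOISE_STD) ** 2)

def grad_log_posterior(theta):
    grad_lik = np.sum(data - theta) / NOISE_STD**2
    grad_prior = -(theta - PRIOR_MEAN) / PRIOR_STD**2
    # log p(data) does not depend on theta, so it contributes nothing
    return grad_lik + grad_prior

def finite_difference(theta, h=1e-5):
    f = lambda t: log_likelihood(t) + log_prior(t)
    return (f(theta + h) - f(theta - h)) / (2 * h)

theta0 = 0.7
analytic = grad_log_posterior(theta0)
numeric = finite_difference(theta0)
print(f"analytic={analytic:.6f}, finite difference={numeric:.6f}")
```

In the real workflow the likelihood gradient comes from automatic differentiation rather than a hand-written formula, but the additive structure is the same.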

Second, we compute the kernel matrix K where Kᵢⱼ = k(θᵢ, θⱼ), giving the similarity between all pairs of particles. This requires n_particles² kernel evaluations, though the matrix is symmetric so we can optimize by computing only the upper triangle. We also need the kernel gradients ∇k(θⱼ, θᵢ) for all pairs, which requires additional automatic differentiation of the kernel function. Both the kernel matrix and gradient computations parallelize naturally since different devices can compute their assigned rows independently.
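The kernel step can be sketched in plain NumPy as follows. This assumes an RBF kernel k(θⱼ, θᵢ) = exp(-‖θⱼ - θᵢ‖² / h) with one common variant of the median heuristic for h; the kernel and bandwidth rule actually used by the SVGD class may differ, and `rbf_kernel_and_grad` is a hypothetical helper name.

```python
import numpy as np

def rbf_kernel_and_grad(particles):
    """particles: array of shape (n, d).

    Returns K with K[j, i] = k(theta_j, theta_i) and grad_K with
    grad_K[j, i] = gradient of k(theta_j, theta_i) with respect to theta_j.
    """
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]   # diffs[j, i] = theta_j - theta_i
    sq_dists = np.sum(diffs**2, axis=-1)
    # Median heuristic variant: bandwidth from the median squared distance
    h = max(np.median(sq_dists) / np.log(n + 1), 1e-12)
    K = np.exp(-sq_dists / h)
    grad_K = -2.0 / h * diffs * K[:, :, None]
    return K, grad_K

rng = np.random.default_rng(0)
particles = rng.uniform(0.5, 2.0, size=(5, 1))
K, grad_K = rbf_kernel_and_grad(particles)
print(K.shape, grad_K.shape)
```

Each row of `K` and `grad_K` depends only on one particle and the full particle set, which is why rows can be assigned to different devices once every device holds all particle positions.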

Third, we compute the SVGD update direction φ*(θᵢ) for each particle using the formula φ*(θᵢ) = (1/n) Σⱼ [k(θⱼ, θᵢ) ∇log p(θⱼ|data) + ∇k(θⱼ, θᵢ)], where both gradients are taken with respect to θⱼ. The first term is a matrix-vector product of the kernel matrix with the gradient vectors and pulls particles toward high posterior density; the second term is a sum of kernel gradients and pushes particles apart so they do not collapse onto a single mode. The aggregation requires communication between devices to gather all particles' gradients, but JAX's collective operations make this efficient.

Finally, we update each particle: θᵢ ← θᵢ + ε φ*(θᵢ), moving particles in the direction that reduces KL divergence to the posterior. The step size ε (learning rate) scales the update magnitude. After updating, we can optionally project particles back onto valid parameter regions if they've strayed outside plausible bounds. Our parameter must stay positive, but with particles initialized between 0.5 and 2.0 and a small learning rate the projection is rarely needed.
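The four steps above can be sketched end to end on a toy problem. This is not the implementation inside the SVGD class: it assumes a one-dimensional Gaussian target standing in for the posterior, an RBF kernel with a median-heuristic bandwidth, a fixed step size, and an optional lower-bound projection. All names (`grad_log_p`, `svgd_step`, the target constants) are hypothetical.

```python
import numpy as np

TARGET_MEAN, TARGET_STD = 1.2, 0.3   # toy Gaussian "posterior" (assumption)

def grad_log_p(theta):
    # Step 1: gradient of the log target density at every particle
    return -(theta - TARGET_MEAN) / TARGET_STD**2

def svgd_step(particles, step_size, lower_bound=None):
    n = particles.shape[0]
    # Step 2: kernel matrix and kernel gradients (w.r.t. theta_j)
    diffs = particles[:, None, :] - particles[None, :, :]   # theta_j - theta_i
    sq_dists = np.sum(diffs**2, axis=-1)
    h = max(np.median(sq_dists) / np.log(n + 1), 1e-12)
    K = np.exp(-sq_dists / h)                                # K[j, i]
    grad_K = -2.0 / h * diffs * K[:, :, None]                # grad_K[j, i]
    # Step 3: phi[i] = (1/n) * sum_j (K[j, i] * grads[j] + grad_K[j, i])
    grads = grad_log_p(particles)
    phi = (K.T @ grads + grad_K.sum(axis=0)) / n
    # Step 4: move particles, then optionally project onto the valid region
    updated = particles + step_size * phi
    if lower_bound is not None:
        updated = np.maximum(updated, lower_bound)
    return updated

rng = np.random.default_rng(0)
particles = rng.uniform(0.5, 2.0, size=(30, 1))
for _ in range(1000):
    particles = svgd_step(particles, step_size=0.02, lower_bound=1e-6)

print(f"particle mean: {particles.mean():.3f}, particle std: {particles.std():.3f}")
```

On this toy target the particle mean and spread should settle near the target's mean and standard deviation, which gives a quick sanity check before moving to an expensive likelihood.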

Throughout execution, the SVGD implementation tracks convergence by monitoring the particle distribution's evolution. If we request return_history=True, it saves particle positions at regular intervals, allowing us to visualize the optimization trajectory and verify convergence. The verbose flag controls whether progress information prints during execution, which is helpful for monitoring long-running jobs but can clutter output in repeated experiments.

In [None]:
# Create a wrapper function that evaluates the model at our observation times
# This is what SVGD will differentiate to compute likelihood gradients
def model_wrapper(theta):
    """Evaluate coalescent model at observation times for given parameter."""
    return coalescent_model(theta, observation_times)

# Create SVGD instance with all our configuration
svgd = SVGD(
    model=model_wrapper,
    observed_data=observed_pdf,
    prior=log_prior,
    theta_dim=1,
    n_particles=n_particles,
    n_iterations=n_iterations,
    learning_rate=learning_rate,
    kernel='median',  # Use RBF kernel with adaptive bandwidth
    theta_init=theta_init,
    seed=42,
    verbose=(dist_info.is_coordinator)  # Only coordinator prints progress
)

if dist_info.is_coordinator:
    print(f"\nStarting Distributed SVGD Inference")
    print(f"This will run {n_iterations} iterations with {n_particles} particles.")
    print(f"Computation is distributed across {dist_info.global_device_count} devices.")
    print(f"\nProgress updates will appear below...\n")

# Run the inference
# This is where the distributed computation happens
# Each device will process its assigned particles in parallel
import time
start_time = time.time()

svgd.fit(return_history=True)

elapsed_time = time.time() - start_time

# Extract results
posterior_particles = svgd.particles
posterior_mean = svgd.theta_mean
posterior_std = svgd.theta_std

if dist_info.is_coordinator:
    print(f"\n" + "="*70)
    print(f"Inference Complete")
    print(f"="*70)
    print(f"Total time: {elapsed_time:.2f} seconds")
    print(f"Time per iteration: {elapsed_time/n_iterations:.3f} seconds")
    print(f"Throughput: {n_particles * n_iterations / elapsed_time:.0f} particle-iterations/second")
    print(f"\nIdeal (linear) speedup with {dist_info.global_device_count} devices: {dist_info.global_device_count}x (upper bound, not a measurement)")

# Analyzing Posterior Results

After SVGD completes, we have a collection of particles that approximate the posterior distribution. These particles should concentrate in regions of high posterior probability, which means regions where both the likelihood and prior are reasonably large. Analyzing these particles involves computing summary statistics, visualizing the distribution, checking convergence, and comparing predictions against observations.

The posterior mean provides a point estimate of the parameter, computed as the average of all particle locations. It is a particle-based estimate of the Bayesian posterior mean: SVGD gives us a finite set of deterministic particles rather than exact posterior samples, so the estimate carries some approximation error. The posterior standard deviation measures uncertainty, indicating how widely the particles are spread. Large standard deviation suggests high uncertainty, while small standard deviation indicates concentrated posterior mass.

A credible interval, typically at 95% coverage, provides a range of plausible parameter values. If we sort the particles by parameter value, the 2.5th and 97.5th percentiles define the 95% credible interval. We can check whether the true parameter (which we know for synthetic data) falls within this interval. Across repeated synthetic experiments it should do so roughly 95% of the time if our inference is well calibrated. Systematic failures to cover the true parameter indicate problems with the model, likelihood specification, or inference procedure.

Visualization of the particle histogram shows the approximate posterior shape. For a one-dimensional parameter like θ, a simple histogram suffices. For higher-dimensional problems, we'd use marginal histograms, pair plots, or other multivariate visualization techniques. The posterior shape reveals important information: unimodal distributions suggest simple inference problems with one clear optimal parameter value, while multimodal posteriors indicate multiple plausible explanations that the data cannot distinguish.

Convergence assessment involves checking whether particles have stopped moving substantially. If we saved particle history, we can plot how the mean or individual particles evolved over iterations. Good convergence looks like particle trajectories that stabilize after an initial transient. Continued drift suggests either insufficient iterations or learning rate issues. MCMC diagnostics such as effective sample size or the potential scale reduction factor assume independent chains and do not transfer directly to SVGD's interacting, deterministic particles. A simpler quantitative check is the change in particle positions, mean, or standard deviation between successive saved snapshots, though visual inspection often suffices for one-dimensional problems.
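One way to turn the "particles have stopped moving" criterion into a number is sketched below. It assumes the saved history can be converted to an array of shape (n_snapshots, n_particles, theta_dim), as the plotting code in this guide does with `svgd.history`; the helper `has_converged`, its `window` and `tol` parameters, and the synthetic history are hypothetical.

```python
import numpy as np

def has_converged(history, window=5, tol=1e-3):
    """Return True if no particle moved more than `tol` between any two
    consecutive snapshots among the last `window` transitions."""
    history = np.asarray(history)
    if history.shape[0] < window + 1:
        return False
    recent = history[-(window + 1):]
    max_shift = np.max(np.abs(np.diff(recent, axis=0)))
    return bool(max_shift < tol)

# Synthetic history: 10 particles relaxing exponentially to fixed positions
steps = np.arange(200)[:, None, None]
final_positions = np.linspace(0.8, 1.2, 10)[None, :, None]
toy_history = final_positions + 0.5 * np.exp(-0.1 * steps)

early = has_converged(toy_history[:20])   # still moving
late = has_converged(toy_history)         # settled
print(f"converged after 20 snapshots: {early}, after 200: {late}")
```

The tolerance should be chosen relative to the posterior standard deviation, since a shift of 0.001 is negligible for a broad posterior but not for a very narrow one.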

In [None]:
if dist_info.is_coordinator:
    # Compute credible interval
    lower_quantile = jnp.percentile(posterior_particles[:, 0], 2.5)
    upper_quantile = jnp.percentile(posterior_particles[:, 0], 97.5)
    
    print(f"\nPosterior Analysis")
    print(f"="*70)
    print(f"True θ:              {float(true_theta[0]):.4f}")
    print(f"Posterior mean:      {float(posterior_mean[0]):.4f}")
    print(f"Posterior std:       {float(posterior_std[0]):.4f}")
    print(f"95% Credible Int:    [{float(lower_quantile):.4f}, {float(upper_quantile):.4f}]")
    print(f"Error (mean - true): {float(posterior_mean[0] - true_theta[0]):.4f}")
    print(f"Relative error:      {100*float(abs(posterior_mean[0] - true_theta[0])/true_theta[0]):.2f}%")
    
    # Check coverage
    covers_true = (lower_quantile <= true_theta[0]) and (true_theta[0] <= upper_quantile)
    if covers_true:
        print(f"\n✓ True parameter is within 95% credible interval")
    else:
        print(f"\n⚠ True parameter is outside 95% credible interval")
        print(f"  This may indicate insufficient data, model misspecification, or")
        print(f"  incomplete convergence. Try increasing iterations or particles.")
    
    print(f"="*70)

# Visualizing the Inference Results

Visual analysis of inference results provides insights that summary statistics alone cannot capture. We'll create several visualizations that together paint a complete picture of the posterior distribution and the quality of our inference. The posterior histogram shows the approximate posterior density, revealing its shape, spread, and any multimodality. Overlaying the true parameter and posterior mean allows quick visual assessment of accuracy.

The convergence trace plot shows how particles evolved over iterations. For clarity, we typically plot only a subset of particles since displaying all particles creates visual clutter. The trace should show particles initially scattered across parameter space, then gradually concentrating toward regions of high posterior probability. A well-converged inference shows stable particle positions in later iterations, while poor convergence exhibits continued drift or wandering.

The posterior predictive plot is particularly important because it shows whether our inferred model can reproduce the observed data. We sample parameters from the posterior (by selecting particles), evaluate the model at each sampled parameter, and overlay these predictions on the observed data. Good inference produces posterior predictions that bracket the observations, with the mean prediction close to the true data-generating curve. If posterior predictions systematically deviate from observations, it suggests model misspecification—the model structure cannot capture important features of the data, regardless of parameter values.

These visualizations serve multiple purposes beyond assessment. They help diagnose problems early in analysis, suggest improvements to the model or inference configuration, communicate results to collaborators who may not be familiar with technical details, and build intuition about how the inference algorithm explores parameter space. Taking time to carefully examine these plots often reveals insights that lead to better analyses.

In [None]:
if dist_info.is_coordinator:
    # Create a comprehensive visualization with three panels
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Panel 1: Posterior histogram
    ax = axes[0]
    ax.hist(posterior_particles[:, 0], bins=30, density=True, alpha=0.7,
            color='steelblue', edgecolor='black', linewidth=0.5)
    ax.axvline(float(true_theta[0]), color='red', linestyle='--', linewidth=2.5,
               label=f'True θ = {float(true_theta[0]):.3f}', zorder=10)
    ax.axvline(float(posterior_mean[0]), color='green', linestyle='-', linewidth=2.5,
               label=f'Posterior mean = {float(posterior_mean[0]):.3f}', zorder=10)
    ax.axvspan(float(lower_quantile), float(upper_quantile), alpha=0.2, color='green',
               label='95% Credible Int')
    ax.set_xlabel('θ (Population Size Parameter)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Posterior Density', fontsize=12, fontweight='bold')
    ax.set_title('Posterior Distribution', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10, loc='best')
    ax.grid(True, alpha=0.3)
    
    # Panel 2: Convergence traces
    ax = axes[1]
    if svgd.history is not None and len(svgd.history) > 0:
        history_array = np.array(svgd.history)
        # Plot a subset of particle trajectories for clarity
        n_traces = min(20, n_particles)
        for i in range(n_traces):
            ax.plot(history_array[:, i, 0], alpha=0.3, color='steelblue', linewidth=1)
        # Plot the mean trajectory
        mean_trajectory = np.mean(history_array[:, :, 0], axis=1)
        ax.plot(mean_trajectory, color='darkblue', linewidth=3, label='Mean trajectory', zorder=10)
        ax.axhline(float(true_theta[0]), color='red', linestyle='--', linewidth=2,
                   label=f'True θ', zorder=9)
        ax.set_xlabel('Iteration', fontsize=12, fontweight='bold')
        ax.set_ylabel('θ', fontsize=12, fontweight='bold')
        ax.set_title('SVGD Convergence', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10, loc='best')
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'Convergence history not saved\n(run with return_history=True)',
                ha='center', va='center', fontsize=12, transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
    
    # Panel 3: Posterior predictive
    ax = axes[2]
    # Sample particles for posterior predictions
    n_posterior_samples = min(50, n_particles)
    sample_indices = np.random.choice(n_particles, n_posterior_samples, replace=False)
    
    # Plot posterior predictive samples
    for idx in sample_indices:
        theta_sample = jnp.array([posterior_particles[idx, 0]])
        pred_pdf = coalescent_model(theta_sample, observation_times)
        ax.plot(observation_times, pred_pdf, 'b-', alpha=0.1, linewidth=1)
    
    # Plot true PDF
    ax.plot(observation_times, true_pdf, 'r-', linewidth=3,
            label=f'True PDF (θ={float(true_theta[0]):.3f})', alpha=0.8, zorder=10)
    
    # Plot posterior mean prediction
    mean_pred = coalescent_model(jnp.array([posterior_mean[0]]), observation_times)
    ax.plot(observation_times, mean_pred, 'g--', linewidth=2.5,
            label=f'Posterior mean (θ={float(posterior_mean[0]):.3f})', alpha=0.9, zorder=9)
    
    # Plot observed data
    ax.scatter(observation_times, observed_pdf, c='black', s=40, alpha=0.6,
               label='Observed data', zorder=11)
    
    ax.set_xlabel('Time to MRCA', fontsize=12, fontweight='bold')
    ax.set_ylabel('Probability Density', fontsize=12, fontweight='bold')
    ax.set_title('Posterior Predictive Check', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10, loc='best')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nVisualization Interpretation:")
    print(f"Left panel: Shows the approximate posterior distribution over θ.")
    print(f"  The histogram represents our uncertainty about the parameter value.")
    print(f"  The true value (red line) should fall within the main mass if inference succeeded.")
    print(f"\nMiddle panel: Shows how particles moved during optimization.")
    print(f"  Particles should converge toward the true parameter region.")
    print(f"  Stable trajectories in later iterations indicate convergence.")
    print(f"\nRight panel: Compares posterior predictions against observed data.")
    print(f"  Light blue curves are predictions from individual posterior samples.")
    print(f"  These should bracket the observations, indicating good model fit.")

# Understanding Distributed Performance

The performance benefits of distributed computing come from parallelizing independent computations across multiple devices. In SVGD, the primary computational bottleneck is evaluating gradients of the log posterior for each particle. Each particle's gradient depends only on that particle's own position, so these evaluations are embarrassingly parallel: with n devices and n particles, each device can compute one particle's gradient simultaneously, for an ideal n-fold speedup of this step over sequential execution. The particles do interact through the kernel terms, and that is where communication enters.

However, perfect linear speedup rarely occurs in practice due to several factors. Communication overhead arises when devices need to exchange information, such as gathering all particles' positions and gradients to compute the kernel matrix and update direction, or broadcasting updated particle positions. Load imbalance occurs if particle counts don't divide evenly across devices or if some gradient computations take longer than others. Synchronization barriers force faster devices to wait for slower ones before proceeding. Fixed costs like initialization and result collection don't parallelize.

The actual speedup depends on the ratio of computation to communication time. For large models where each gradient evaluation is expensive, computation dominates and we achieve near-linear speedup. For simple models with cheap gradients, communication overhead becomes significant and speedup plateaus. The particle-to-device ratio also matters: too few particles per device means devices are underutilized, while too many particles increases memory pressure and can cause slowdowns.

In our coalescent example, the phase-type distribution evaluation involves graph traversal and numerical computation that's expensive enough to make parallelization worthwhile. With properly configured distributed execution, parallel efficiency on the order of 80-90% (that is, 80-90% of theoretical linear speedup) is a reasonable target across moderate numbers of devices, though the actual figure depends on hardware and interconnect and has to be measured against a single-device run. Beyond certain scales, diminishing returns set in as communication overhead grows, suggesting there's an optimal cluster size for any given problem that balances computational gain against coordination cost.
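The relationship between serial fraction, speedup, and efficiency can be made concrete with Amdahl's law, which bounds speedup when only part of the work parallelizes. The sketch below is a simplified model (it ignores communication cost that grows with device count), and the timings passed to `parallel_efficiency` are made-up numbers for illustration, not measurements from this guide.

```python
def amdahl_speedup(parallel_fraction, n_devices):
    """Upper bound on speedup when `parallel_fraction` of the runtime
    parallelizes perfectly and the rest stays serial."""
    return 1.0 / ((1.0 - parallel_fraction) + parallel_fraction / n_devices)

def parallel_efficiency(single_device_time, multi_device_time, n_devices):
    """Measured speedup divided by the ideal linear speedup."""
    return (single_device_time / multi_device_time) / n_devices

for n in (1, 2, 4, 8, 16, 32):
    s = amdahl_speedup(0.95, n)
    print(f"{n:>2} devices: speedup bound {s:5.2f}x, efficiency bound {s / n:6.1%}")

# Hypothetical timings: 120 s on one device, 18 s on eight devices
eff = parallel_efficiency(120.0, 18.0, 8)
print(f"measured efficiency: {eff:.1%}")
```

Even with 95% of the work parallelizable, the bound flattens quickly as devices are added, which is the diminishing-returns effect described above.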

In [None]:
if dist_info.is_coordinator:
    print(f"\nDistributed Performance Analysis")
    print(f"="*70)
    
    # Compute various performance metrics
    total_gradient_evals = n_particles * n_iterations
    evals_per_second = total_gradient_evals / elapsed_time
    evals_per_device = total_gradient_evals / dist_info.global_device_count
    time_per_eval = elapsed_time / total_gradient_evals
    
    print(f"Computational Workload:")
    print(f"  Total gradient evaluations: {total_gradient_evals:,}")
    print(f"  Each evaluation involves computing likelihood at {n_observations} time points")
    print(f"  Total likelihood evaluations: {total_gradient_evals * n_observations:,}")
    
    print(f"\nActual Performance:")
    print(f"  Total elapsed time: {elapsed_time:.2f} seconds")
    print(f"  Gradient evaluations per second: {evals_per_second:.1f}")
    print(f"  Time per gradient evaluation: {time_per_eval*1000:.2f} ms")
    
    print(f"\nDistributed Scaling:")
    print(f"  Devices used: {dist_info.global_device_count}")
    print(f"  Processes (nodes) used: {dist_info.num_processes}")
    print(f"  Evaluations per device: {evals_per_device:,.0f}")
    print(f"  Work distribution: {particles_per_device} particles/device")
    
    # Parallel efficiency needs a measured single-device baseline. Multiplying
    # elapsed_time by the device count assumes linear scaling and would always
    # report 100%. Set this to the wall-clock seconds of a 1-device run of the
    # same configuration (same n_particles and n_iterations).
    single_device_time = None
    
    print(f"\nSpeedup Analysis:")
    print(f"  Actual multi-device time: {elapsed_time/60:.1f} minutes")
    print(f"  Theoretical max speedup: {dist_info.global_device_count}x")
    
    if single_device_time is None:
        print(f"  Parallel efficiency: not available (no single-device baseline measured)")
    else:
        speedup = single_device_time / elapsed_time
        parallel_efficiency = speedup / dist_info.global_device_count
        print(f"  Measured single-device time: {single_device_time/60:.1f} minutes")
        print(f"  Measured speedup: {speedup:.1f}x")
        print(f"  Parallel efficiency: {parallel_efficiency*100:.1f}%")
        
        print(f"\nKey Insight:")
        if parallel_efficiency > 0.8:
            print(f"  Excellent parallel efficiency! Communication overhead is minimal.")
        elif parallel_efficiency > 0.6:
            print(f"  Good parallel efficiency. Some communication overhead present.")
        else:
            print(f"  Moderate parallel efficiency. Communication overhead is significant.")
            print(f"  Consider increasing work per device (more particles or complex model).")
    
    print(f"="*70)

# Deploying on SLURM Clusters

The code we've written so far runs identically on a laptop during development and on a massive SLURM cluster during production. This portability is a key design goal of the distributed computing framework. However, actually submitting jobs to a SLURM cluster requires some additional infrastructure: SLURM batch scripts that request computational resources, load necessary software modules, activate Python environments, and launch your code across multiple nodes.

Writing SLURM scripts manually is tedious and error-prone, with subtle issues like incorrect environment variable settings or module loading order causing hard-to-debug failures. The phasic package ships with tools that generate SLURM scripts automatically from YAML configuration files, reducing script creation from a manual chore to a single command. These generated scripts handle all the boilerplate: SBATCH directives for resource requests, module loading, environment activation, coordinator address setup, and proper process launching with srun.

Configuration files separate cluster-specific settings (partition names, module names, resource limits) from your analysis code. You might have different configurations for different clusters, or different profiles for different job sizes on the same cluster. Common profiles include debug (quick testing with minimal resources), small (initial development), medium (standard production), large (big jobs), and production (maximum scale). By selecting an appropriate profile, you automatically get reasonable resource allocations without manually tuning dozens of SLURM parameters.
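To illustrate how a configuration maps onto a batch script, here is a minimal sketch that renders SBATCH directives from a plain dictionary. It is not phasic's generator: `render_sbatch` is a hypothetical helper, the field names simply mirror the example YAML configuration in this guide, the one-task-per-node layout is an assumption, and environment activation and coordinator setup are omitted.

```python
def render_sbatch(config, script):
    """Render a SLURM batch script from a configuration dictionary."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={config['name']}",
        f"#SBATCH --nodes={config['nodes']}",
        "#SBATCH --ntasks-per-node=1",   # assumption: one process per node
        f"#SBATCH --cpus-per-task={config['cpus_per_node']}",
        f"#SBATCH --mem-per-cpu={config['memory_per_cpu']}",
        f"#SBATCH --time={config['time_limit']}",
        f"#SBATCH --partition={config['partition']}",
        "",
    ]
    lines += [f"module load {m}" for m in config.get("modules_to_load", [])]
    lines += [f"export {k}={v}" for k, v in config.get("env_vars", {}).items()]
    lines += ["", f"srun python {script}"]
    return "\n".join(lines) + "\n"

example_cluster = {
    "name": "my_cluster",
    "nodes": 4,
    "cpus_per_node": 24,
    "memory_per_cpu": "8G",
    "time_limit": "03:00:00",
    "partition": "compute",
    "modules_to_load": ["python/3.11"],
    "env_vars": {"JAX_ENABLE_X64": "1"},
}
script_text = render_sbatch(example_cluster, "analysis.py")
print(script_text)
```

Keeping the cluster settings in data rather than in the script body is what lets the same analysis file run under a debug profile and a production profile without edits.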

The workflow for cluster deployment typically follows this pattern: develop and test your code locally, convert your notebook or script to a Python file if needed, choose or create an appropriate cluster configuration, generate a SLURM submission script, submit the job and monitor its progress, examine output logs when complete, and iterate based on results. The key advantage of this workflow is that the same analysis code runs everywhere—only the resource configuration changes between local testing and cluster production.

In [None]:
if dist_info.is_coordinator:
    print(f"\nSLURM Cluster Deployment Guide")
    print(f"="*70)
    
    print(f"\nStep 1: Save your analysis as a Python script")
    print(f"  jupyter nbconvert --to python distributed_computing_complete_guide.ipynb")
    print(f"  # Creates distributed_computing_complete_guide.py")
    
    print(f"\nStep 2: Choose a cluster configuration")
    print(f"  Available predefined profiles:")
    print(f"    - debug: 1 node, 4 CPUs, 30 minutes (quick testing)")
    print(f"    - small: 2 nodes, 8 CPUs each, 1 hour (development)")
    print(f"    - medium: 4 nodes, 16 CPUs each, 2 hours (standard jobs)")
    print(f"    - large: 8 nodes, 16 CPUs each, 4 hours (large-scale)")
    print(f"    - production: 8 nodes, 32 CPUs each, 8 hours (maximum scale)")
    
    print(f"\nStep 3: Generate SLURM submission script")
    print(f"  python generate_slurm_script.py \\")
    print(f"      --profile medium \\")
    print(f"      --script distributed_computing_complete_guide.py \\")
    print(f"      --output submit.sh")
    
    print(f"\nStep 4: Submit to cluster")
    print(f"  sbatch submit.sh")
    print(f"  # Or combine steps 3 and 4:")
    print(f"  sbatch <(python generate_slurm_script.py --profile medium --script yourscript.py)")
    
    print(f"\nStep 5: Monitor job progress")
    print(f"  squeue -u $USER           # Check job status")
    print(f"  tail -f logs/job_*.out    # Follow output log")
    print(f"  scancel <job_id>          # Cancel if needed")
    
    print(f"\nCluster-Specific Configuration:")
    print(f"  For your specific cluster, you may need to create a custom YAML config:")
    print(f"  ")
    example_config = """  # my_cluster.yaml
  name: my_cluster
  nodes: 4
  cpus_per_node: 24
  memory_per_cpu: \"8G\"
  time_limit: \"03:00:00\"
  partition: \"compute\"      # Your cluster's partition name
  modules_to_load:
    - \"python/3.11\"          # Your cluster's Python module
    - \"gcc/11.2.0\"           # If needed
  env_vars:
    JAX_ENABLE_X64: \"1\""""
    print(example_config)
    
    print(f"\nThen use: python generate_slurm_script.py --config my_cluster.yaml --script yourscript.py")
    print(f"="*70)

# Advanced Topics and Optimization Strategies

Beyond basic distributed inference, several advanced techniques can improve performance, convergence quality, or enable new applications. Understanding these techniques helps you tackle challenging inference problems and extract maximum value from computational resources.

Adaptive learning rates can significantly improve convergence. Rather than using a fixed learning rate throughout optimization, you can decay it over time (starting with larger steps for rapid initial movement, then smaller steps for fine-tuning) or adapt it based on convergence indicators like the change in particle positions between iterations. Some SVGD implementations support automatic learning rate scheduling.

Kernel bandwidth selection deserves careful attention. The median heuristic (setting bandwidth to the median pairwise particle distance) adapts automatically as particles move, but you might get better results with problem-specific tuning. Larger bandwidths encourage global exploration, while smaller bandwidths enable local refinement. Some recent work suggests using multiple kernels with different bandwidths simultaneously.

Moment-based regularization addresses a common challenge: likelihood-based inference can sometimes produce posteriors that match observed data points but produce incorrect predictions for other quantities. Adding terms to the objective that penalize disagreement between posterior predictions and observed moments (mean, variance, etc.) can improve generalization. This is particularly useful when you have both distributional observations and moment estimates from independent data.

Symbolic computation for parameterized graphs provides enormous speedup when you need to evaluate phase-type distributions at many parameter values. The symbolic Gaussian elimination procedure converts a parameterized cyclic graph into an acyclic form where the dependence on parameters is explicit. Subsequent evaluations at different parameters require only updating edge weights rather than retraversing the graph, achieving orders of magnitude speedup for parameter sweeps.

Checkpointing for long-running jobs is essential for production workflows. Saving particle positions and optimization state periodically allows resuming interrupted jobs without starting over. On clusters with job time limits, checkpointing enables multi-stage optimization where each job picks up where the previous one left off.
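A minimal checkpointing sketch with NumPy is shown below. It assumes the optimization state that matters is the particle array plus the iteration counter; a real run may also need optimizer state or random seeds. `save_checkpoint` and `load_checkpoint` are hypothetical helpers, not part of phasic. The file is written to a temporary name and then renamed, so a job killed mid-write never leaves a truncated checkpoint.

```python
import os
import tempfile
import numpy as np

def save_checkpoint(path, particles, iteration):
    """Write particles and iteration atomically to `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".npz")
    os.close(fd)
    np.savez(tmp_path, particles=particles, iteration=iteration)
    os.replace(tmp_path, path)   # atomic rename on the same filesystem

def load_checkpoint(path):
    """Return (particles, iteration), or None if no checkpoint exists."""
    if not os.path.exists(path):
        return None
    with np.load(path) as data:
        return data["particles"], int(data["iteration"])

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = os.path.join(tmp_dir, "svgd_checkpoint.npz")
    missing = load_checkpoint(ckpt)                      # nothing saved yet
    current = np.linspace(0.5, 2.0, 8).reshape(-1, 1)    # stand-in particles
    save_checkpoint(ckpt, current, iteration=400)
    restored_particles, restored_iteration = load_checkpoint(ckpt)

print(f"resuming from iteration {restored_iteration}")
```

In a multi-process run, only one process (for example the coordinator) should write the checkpoint, and the restored particles can be passed back in as the initial particle array for the next job.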

Multi-fidelity inference uses cheaper approximate models during early optimization to quickly identify promising parameter regions, then switches to expensive accurate models for final refinement. For phase-type distributions, you might start with a reduced state space or coarser time discretization, then increase resolution as particles converge.

In [None]:
if dist_info.is_coordinator:
    print(f"\nAdvanced Optimization Strategies")
    print(f"="*70)
    
    print(f"\n1. Adaptive Learning Rate Schedule")
    print(f"   Instead of fixed learning_rate=0.01, use a schedule:")
    print(f"   ")
    print(f"   def learning_rate_schedule(iteration, initial_rate=0.05):")
    print(f"       # Exponential decay")
    print(f"       return initial_rate * (0.995 ** iteration)")
    print(f"   ")
    print(f"   This starts with aggressive exploration and gradually refines.")
    
    print(f"\n2. Kernel Bandwidth Tuning")
    print(f"   The median heuristic is automatic but may not be optimal:")
    print(f"   ")
    print(f"   # Try different kernel strategies:")
    print(f"   kernel='median'      # Automatic (default)")
    print(f"   kernel='rbf_adaptive'    # Adapts differently")
    print(f"   kernel=1.5               # Fixed bandwidth")
    print(f"   ")
    print(f"   Larger bandwidth → more global exploration")
    print(f"   Smaller bandwidth → more local refinement")
    
    print(f"\n3. Increasing Particles for Better Approximation")
    print(f"   More particles = better posterior approximation:")
    print(f"   ")
    print(f"   n_particles_options = [50, 100, 200, 500]")
    print(f"   ")
    print(f"   Trade-off: Better approximation vs longer computation time")
    print(f"   With distributed execution, can afford more particles")
    
    print(f"\n4. Multi-Stage Optimization")
    print(f"   Coarse-to-fine refinement:")
    print(f"   ")
    print(f"   # Stage 1: Quick exploration with fewer particles")
    print(f"   svgd_coarse = SVGD(model, data, n_particles=50, n_iterations=200)")
    print(f"   svgd_coarse.fit()")
    print(f"   ")
    print(f"   # Stage 2: Refinement with more particles initialized near coarse solution")
    print(f"   theta_refined = np.random.normal(svgd_coarse.theta_mean, svgd_coarse.theta_std, (200, 1))")
    print(f"   svgd_fine = SVGD(model, data, n_particles=200, theta_init=theta_refined)")
    print(f"   svgd_fine.fit()")
    
    print(f"\n5. Symbolic Computation for Fast Parameter Sweeps")
    print(f"   For models evaluated at many parameter values:")
    print(f"   ")
    print(f"   # Convert to symbolic form once (expensive)")
    print(f"   symbolic_graph = graph.symbolic_elimination()")
    print(f"   ")
    print(f"   # Then evaluate quickly at many parameters")
    print(f"   for theta in parameter_grid:")
    print(f"       symbolic_graph.update_weights(theta)")
    print(f"       pdf = symbolic_graph.pdf(times)  # Much faster!")
    
    print(f"\n6. Moment-Based Regularization")
    print(f"   If you have moment estimates in addition to distribution observations:")
    print(f"   ")
    print(f"   observed_moments = [mean_estimate, variance_estimate]")
    print(f"   svgd.fit_regularized(")
    print(f"       observed_times=observation_times,")
    print(f"       observed_moments=observed_moments,")
    print(f"       regularization=0.5  # Weight of moment term")
    print(f"   )")
    
    print(f"="*70)

# Troubleshooting Distributed Computation

Distributed computing introduces complexity, and with complexity comes potential for issues. Understanding common problems and their solutions helps you debug issues quickly and maintain productive workflows. Most distributed computing problems fall into a few categories: environment configuration, communication failures, resource limitations, convergence issues, or numerical instabilities.

Environment configuration problems often manifest as import errors or "module not found" exceptions. These typically stem from the Python environment not being properly activated on compute nodes or required packages not being installed. Ensure your SLURM script activates the environment (using pixi shell-hook or conda activate) before running Python. Module loading order can also matter—some modules need to be loaded before others, and the automatically generated scripts handle common cases but may need customization for unusual clusters.

Communication failures usually appear as timeouts or "unable to connect to coordinator" errors. These indicate processes can't communicate, often due to firewall rules blocking the coordinator port, incorrect coordinator address computation, or network interface mismatches. Verify the coordinator port is open and that processes can reach each other over the network. On clusters with multiple network interfaces (like InfiniBand for fast interconnect), you may need to specify which interface to use.
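A quick way to test whether a worker can reach the coordinator is a plain TCP connection attempt, sketched below with the standard library. `can_reach` is a hypothetical helper; on a cluster you would call it from a compute node with the coordinator's host name and port. The demonstration here uses a throwaway listener on the local machine so it runs anywhere.

```python
import socket

def can_reach(host, port, timeout=3.0):
    """Return True if a TCP connection to (host, port) succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

# Local demonstration: open a listener on a free port, probe it, close it
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))          # port 0 lets the OS pick a free port
listener.listen(1)
open_port = listener.getsockname()[1]

reachable = can_reach("127.0.0.1", open_port)
listener.close()
after_close = can_reach("127.0.0.1", open_port, timeout=1.0)

print(f"while listening: {reachable}, after closing: {after_close}")
```

A failed connection only tells you that the path is blocked or nothing is listening. It does not distinguish a firewall rule from a wrong address, so check the coordinator address in the logs as well.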

Resource limitations cause jobs to fail due to exceeding memory, time, or storage limits. Memory exhaustion typically occurs with too many particles or too large a state space. Reduce particles per device or increase memory allocation in your configuration. Time limit exhaustion means the job didn't finish before the cluster's time limit. Increase the time_limit in your config or reduce iterations. Storage issues arise when writing large output files fills quota—compress outputs or write to appropriate storage locations.

Convergence issues manifest as poor parameter estimates or high posterior variance. These might indicate insufficient iterations, an inappropriate learning rate, or fundamental identifiability problems where the data doesn't constrain the parameters. Check convergence diagnostics, visualize particle trajectories, and verify your model is identifiable given the available data. Sometimes the issue is model misspecification rather than inference failure.
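One simple diagnostic is to check whether the particle mean has stopped moving over the most recent iterations. The sketch below assumes the particle history for a single scalar parameter is available as a sequence of iterations, each holding that iteration's particle values; this layout and the function name mean_has_stabilized are assumptions for illustration and do not describe the format phasic returns.

```python
from statistics import fmean


def mean_has_stabilized(history, window=50, rel_tol=1e-3):
    """Check whether the particle mean has stopped moving.

    history: sequence of iterations, each a sequence of particle values
             for a single scalar parameter.
    Returns True if the particle mean varied by less than rel_tol
    (relative to its final value) over the last `window` iterations.
    """
    if len(history) < window:
        return False  # not enough iterations to judge
    means = [fmean(particles) for particles in history[-window:]]
    spread = max(means) - min(means)
    scale = max(abs(means[-1]), 1e-12)  # avoid division by zero
    return spread / scale < rel_tol
```

A stable mean is necessary but not sufficient for convergence: particles can settle into a poor region, or the spread of the swarm can still be changing. Treat a True result as a signal to look at trajectories and posterior predictive checks, not as proof.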

Numerical instabilities appear as NaN or Inf values in gradients or likelihoods. These usually stem from evaluating distributions at extreme parameter values where computations overflow or underflow. Adding bounds to parameter ranges, using log-space computations, or adjusting prior distributions can help. Numerical issues can also arise when the learning rate is too large, causing particles to jump to pathological regions.
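To illustrate what log-space computation buys you, the sketch below implements the standard log-sum-exp trick in plain Python. It is a generic numerical technique shown for illustration, not a description of how phasic evaluates likelihoods internally.

```python
import math


def log_sum_exp(log_values):
    """Compute log(sum(exp(v) for v in log_values)) without overflow."""
    m = max(log_values)
    if math.isinf(m):
        # Either every term is -inf (the sum is 0) or some term is +inf
        return m
    # Subtracting the maximum keeps every exponent at or below zero
    return m + math.log(sum(math.exp(v - m) for v in log_values))


# A naive math.log(math.exp(-1000.0) + math.exp(-1000.0)) fails because
# exp(-1000) underflows to 0.0; the shifted version stays finite.
print(log_sum_exp([-1000.0, -1000.0]))
```

When working with JAX arrays, jax.scipy.special.logsumexp provides the same operation in a differentiable form, so there is rarely a reason to hand-roll it inside a model.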

In [None]:
# Print the guide once, from the coordinator process only
if dist_info.is_coordinator:
    print(f"\nTroubleshooting Guide for Distributed Computing")
    print(f"="*70)
    
    troubleshooting_scenarios = [
        {
            "problem": "Job fails immediately with ModuleNotFoundError",
            "diagnosis": "Python environment not activated on compute nodes",
            "solution": [
                "Verify SLURM script includes environment activation",
                "For pixi: eval \"$(pixi shell-hook)\"",
                "For conda: conda activate your_env",
                "Test environment on compute node: srun python -c 'import phasic'"
            ]
        },
        {
            "problem": "Processes time out connecting to coordinator",
            "diagnosis": "Network communication failure",
            "solution": [
                "Check coordinator port is not blocked by firewall",
                "Verify coordinator address in logs",
                "Try a different coordinator_port (e.g., 12346 instead of 12345)",
                "Check network interface: may need to specify interface for InfiniBand clusters"
            ]
        },
        {
            "problem": "Job killed due to memory exhaustion",
            "diagnosis": "Insufficient memory allocation",
            "solution": [
                "Reduce particles_per_device",
                "Increase memory_per_cpu in cluster config",
                "Use smaller state spaces or discretization",
                "Monitor memory with: sstat -j <job_id> --format=MaxRSS"
            ]
        },
        {
            "problem": "Job exceeds time limit before completion",
            "diagnosis": "Computation took longer than allocated time",
            "solution": [
                "Increase time_limit in cluster config",
                "Reduce n_iterations or n_particles",
                "Use checkpointing to resume in subsequent jobs",
                "Profile to identify bottlenecks: may be suboptimal model evaluation"
            ]
        },
        {
            "problem": "Poor convergence or high posterior variance",
            "diagnosis": "Insufficient optimization or identifiability issues",
            "solution": [
                "Increase n_iterations (try 2x current value)",
                "Increase n_particles for better approximation",
                "Check particle trajectories for convergence",
                "Verify model is identifiable from data (try synthetic data with known parameters)",
                "Adjust learning_rate (try both 0.5x and 2x current value)"
            ]
        },
        {
            "problem": "NaN or Inf in gradients",
            "diagnosis": "Numerical instability",
            "solution": [
                "Reduce learning_rate",
                "Add parameter bounds to prior",
                "Check for extreme likelihood values at particle positions",
                "Enable 64-bit precision: enable_x64=True",
                "Add small epsilon to prevent log(0): jnp.log(value + 1e-10)"
            ]
        },
        {
            "problem": "Code runs locally but fails on cluster",
            "diagnosis": "Environment differences",
            "solution": [
                "Check JAX/Python versions match",
                "Verify all dependencies installed in cluster environment",
                "Test on single cluster node first: srun --nodes=1 python script.py",
                "Check environment variables are correctly set in SLURM script"
            ]
        }
    ]
    
    for i, scenario in enumerate(troubleshooting_scenarios, 1):
        print(f"\nScenario {i}: {scenario['problem']}")
        print(f"Diagnosis: {scenario['diagnosis']}")
        print(f"Solutions:")
        for solution in scenario['solution']:
            print(f"  • {solution}")
    
    print(f"\n" + "="*70)
    print(f"\nGeneral Debugging Strategy:")
    print(f"  1. Test locally first to isolate cluster-specific issues")
    print(f"  2. Start with debug profile (minimal resources) for quick iteration")
    print(f"  3. Check logs carefully - error messages usually point to the problem")
    print(f"  4. Verify environment on compute nodes before submitting large jobs")
    print(f"  5. Monitor resource usage during execution to catch issues early")
    print(f"="*70)

# Summary and Best Practices

This guide has covered distributed computing with phasic from fundamental concepts through practical implementation to advanced optimization and troubleshooting. The key insight is straightforward: modern distributed computing frameworks can dramatically simplify what was once extremely complex, but success requires understanding both the high-level abstractions and the underlying mechanisms.

The single most important practice is to develop incrementally, testing at each stage before scaling up. Write and debug your code locally where iteration is fast and debugging is easy. Test on a cluster with minimal resources (debug profile) to verify environment configuration and basic functionality. Scale to small production jobs to validate performance and convergence. Only then move to large-scale production runs. This staged approach catches problems early when they're easy to fix rather than late when they're expensive.

Understanding your computational requirements helps you choose appropriate resources. Estimate how long computations will take based on test runs, then request cluster resources accordingly. Over-requesting wastes allocation and may delay job start, while under-requesting leads to killed jobs and wasted computation. Profile your code to identify bottlenecks: if evaluation is cheap, adding more devices won't help much because fixed overheads dominate, but if evaluation is expensive, distributed execution can approach linear speedup.
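The trade-off between cheap and expensive evaluation can be made concrete with Amdahl's law, which bounds the speedup by the fraction of runtime that actually parallelizes. The sketch below is a generic back-of-the-envelope model, not a phasic utility; the parallel fractions in the loop are made-up illustrative values that you would replace with numbers from your own profiling.

```python
def amdahl_speedup(parallel_fraction, n_devices):
    """Ideal speedup when only parallel_fraction of the runtime is parallelized."""
    if not 0.0 <= parallel_fraction <= 1.0:
        raise ValueError("parallel_fraction must be between 0 and 1")
    if n_devices < 1:
        raise ValueError("n_devices must be at least 1")
    serial_fraction = 1.0 - parallel_fraction
    return 1.0 / (serial_fraction + parallel_fraction / n_devices)


# Illustrative fractions: cheap evaluation (0.5) versus expensive evaluation (0.99)
for fraction in (0.5, 0.95, 0.99):
    speedups = [amdahl_speedup(fraction, n) for n in (1, 8, 64)]
    print(fraction, [round(s, 1) for s in speedups])
```

Because this model ignores communication cost, it gives an upper bound. If even the bound shows little gain from more devices, requesting them will only lengthen your queue time.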

Monitoring and logging are essential for production workflows. Write informative log messages at key points (initialization, iteration milestones, completion). Save intermediate results so you can analyze partial results if jobs fail. Use the coordinator check (dist_info.is_coordinator) to avoid cluttered output from multiple processes all printing the same information. Monitor resource usage during execution to catch memory leaks or inefficiencies.
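The logging and checkpointing practices above can be sketched with the standard library alone. The helper names make_logger, save_checkpoint, and load_checkpoint are illustrative assumptions, and the checkpoint contents are assumed to be JSON-serializable (for example, particle values converted to plain lists); phasic may offer its own mechanisms for either task.

```python
import json
import logging
import os
import tempfile


def make_logger(is_coordinator, name="svgd_run"):
    """Return a logger that emits INFO messages only on the coordinator."""
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(message)s")
        )
        logger.addHandler(handler)
    # Non-coordinator processes still report warnings and errors
    logger.setLevel(logging.INFO if is_coordinator else logging.WARNING)
    return logger


def save_checkpoint(path, state):
    """Write a JSON-serializable state to path atomically."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temporary file first, then rename it into place, so a
    # killed job never leaves a half-written file at `path`
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Return the saved state, or None if no checkpoint exists yet."""
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)
```

In a training loop you would typically pass dist_info.is_coordinator to make_logger once at startup, then call save_checkpoint from the coordinator every fixed number of iterations. Writing from a single process avoids several processes racing to replace the same file.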

Finally, remember that distributed computing is a tool to solve problems faster or at larger scale, not an end in itself. If local execution suffices for your problem, use it; simplicity has value. Distribute when you need results faster than local execution can provide, when your problem doesn't fit in local memory, or when you want to explore larger parameter spaces or use more particles for better inference. The framework makes distribution easy, but that doesn't mean you should always use it.

In [None]:
# Print the summary once, from the coordinator process only
if dist_info.is_coordinator:
    print(f"\n" + "="*70)
    print(f"DISTRIBUTED COMPUTING WITH phasic - SUMMARY")
    print(f"="*70)
    
    print(f"\nKey Capabilities Demonstrated:")
    print(f"  ✓ Automatic SLURM detection and configuration")
    print(f"  ✓ JAX distributed initialization")
    print(f"  ✓ Parameterized phase-type distribution models")
    print(f"  ✓ JAX function conversion for automatic differentiation")
    print(f"  ✓ Distributed SVGD Bayesian inference")
    print(f"  ✓ Parallel gradient computation across devices")
    print(f"  ✓ Convergence monitoring and visualization")
    print(f"  ✓ Posterior analysis and predictive checking")
    
    print(f"\nBest Practices Checklist:")
    print(f"  □ Test locally before submitting to cluster")
    print(f"  □ Start with debug profile for initial cluster testing")
    print(f"  □ Use particles evenly divisible by device count")
    print(f"  □ Monitor convergence with return_history=True")
    print(f"  □ Check posterior predictive matches observations")
    print(f"  □ Use coordinator check for output/logging")
    print(f"  □ Save intermediate results for long-running jobs")
    print(f"  □ Profile performance to identify bottlenecks")
    
    print(f"\nNext Steps:")
    print(f"  1. Apply to your own phase-type distribution models")
    print(f"  2. Experiment with different cluster configurations")
    print(f"  3. Explore advanced techniques (symbolic computation, regularization)")
    print(f"  4. Scale to production workloads on your cluster")
    
    print(f"\nResources:")
    print(f"  • phasic documentation: https://docs.phasic.org")
    print(f"  • JAX distributed computing: https://jax.readthedocs.io/en/latest/distributed.html")
    print(f"  • SVGD paper: Liu & Wang (2016), NIPS")
    print(f"  • Phase-type distributions: Røikjer et al. (2022), Statistics and Computing")
    
    print(f"\n" + "="*70)
    print(f"Thank you for reading this comprehensive guide!")
    print(f"="*70)