# seed_everything

Comprehensive seeding for reproducible machine learning training across all major frameworks.

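This notebook consolidates the entire `seed_everything` package into a single file. A single call to `seed_everything()` seeds Python, NumPy, PyTorch, TensorFlow, and JAX; separate helper functions cover scikit-learn and distributed training. Supported targets: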
- Python's built-in `random` module
- NumPy
- PyTorch (CPU + CUDA + cuDNN)
- TensorFlow / Keras
- JAX
- scikit-learn
- Distributed training frameworks (torch.distributed, Horovod, DeepSpeed)

## Utilities

Utility functions for seed validation and management.

In [None]:
import logging
from typing import Dict, Any

# Module-level logger named after this module (__name__); the seeding helpers
# in this file emit their info/warning messages through it.
logger = logging.getLogger(__name__)


def validate_seed(seed: int) -> None:
    """
    Validate that the seed is a non-negative integer within the valid range.

    ``bool`` values are rejected explicitly: ``bool`` is a subclass of ``int``,
    so ``True``/``False`` would otherwise slip through as the seeds 1/0.

    Args:
        seed: The seed value to validate

    Raises:
        TypeError: If seed is not an integer (or is a bool)
        ValueError: If seed is negative or out of valid range
    """
    # isinstance(True, int) is True, so the bool check has to come first.
    if isinstance(seed, bool) or not isinstance(seed, int):
        raise TypeError(f"Seed must be an integer, got {type(seed).__name__}")

    if seed < 0:
        raise ValueError(f"Seed must be non-negative, got {seed}")

    # Maximum seed value for most systems (2^32 - 1)
    max_seed = 2**32 - 1
    if seed > max_seed:
        raise ValueError(f"Seed must be <= {max_seed}, got {seed}")


def get_seed_info() -> Dict[str, Any]:
    """
    Report which frameworks can be imported, together with their versions.

    Each optional framework is probed with a local import. A successful import
    flips the matching ``<name>_available`` flag and records ``<name>_version``;
    for PyTorch the CUDA availability (and device count, when CUDA is usable)
    is recorded as well.

    Returns:
        Dictionary of availability flags and version details
    """
    # Pessimistic defaults: only the standard library is assumed present.
    info = dict(
        python_available=True,
        numpy_available=False,
        torch_available=False,
        tensorflow_available=False,
        jax_available=False,
    )

    try:
        import numpy
        info.update(numpy_available=True, numpy_version=numpy.__version__)
    except ImportError:
        pass

    try:
        import torch
        info.update(torch_available=True, torch_version=torch.__version__)
        has_cuda = info["cuda_available"] = torch.cuda.is_available()
        if has_cuda:
            info["cuda_device_count"] = torch.cuda.device_count()
    except ImportError:
        pass

    try:
        import tensorflow
        info.update(tensorflow_available=True, tensorflow_version=tensorflow.__version__)
    except ImportError:
        pass

    try:
        import jax
        info.update(jax_available=True, jax_version=jax.__version__)
    except ImportError:
        pass

    return info


def log_seeding(framework: str, seed: int, warn: bool = True) -> None:
    """
    Log a seeding operation for a framework.

    Always emits an info-level record. When ``warn`` is true and the framework
    is PyTorch or TensorFlow, a warning about residual non-determinism is
    emitted as well.

    Args:
        framework: Name of the framework being seeded. Matching for the
            warning is case-insensitive and accepts "torch"/"PyTorch" and
            "tensorflow"/"TensorFlow".
        seed: The seed value used
        warn: Whether to emit warnings about potential non-deterministic operations
    """
    logger.info(f"Seeded {framework} with seed={seed}")

    if not warn:
        return

    # The seeding helpers in this file pass display names ("PyTorch",
    # "TensorFlow"), so compare on a normalised form; an exact match against
    # "torch"/"tensorflow" would never fire for them.
    normalized = str(framework).lower()
    if normalized in ("torch", "pytorch"):
        logger.warning(
            "PyTorch seeding configured. Note that some operations may still be non-deterministic. "
            "See https://pytorch.org/docs/stable/notes/randomness.html for details."
        )
    elif normalized == "tensorflow":
        logger.warning(
            "TensorFlow seeding configured. Some operations may still be non-deterministic. "
            "See https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism"
        )

## Python Seeding

Seed Python's built-in `random` module and set `PYTHONHASHSEED`.

In [None]:
import os
import random
from typing import Optional


def seed_python(seed: int, warn: bool = True) -> None:
    """
    Seed the standard-library ``random`` module and export PYTHONHASHSEED.

    Args:
        seed: The seed value (non-negative integer)
        warn: Forwarded to the logging helper (default: True)
    """
    validate_seed(seed)

    # Hash randomisation is fixed when the interpreter starts, so this export
    # cannot change the running process; it is set for subprocesses and as a
    # record of the seed in use.
    os.environ['PYTHONHASHSEED'] = f"{seed}"

    # Global RNG of the `random` module.
    random.seed(seed)

    log_seeding("Python", seed, warn)

## NumPy Seeding

Seed NumPy's random number generator (both legacy and modern Generator).

In [None]:
import numpy as np


def seed_numpy(seed: int, warn: bool = True) -> Optional['np.random.Generator']:
    """
    Seed NumPy's legacy global RNG and build a fresh PCG64-backed Generator.

    The legacy ``numpy.random.seed`` call covers code that still uses the
    module-level functions; the returned Generator is meant for new code.

    Args:
        seed: The seed value (non-negative integer)
        warn: Forwarded to the logging helper (default: True)

    Returns:
        A numpy.random.Generator wrapping a PCG64 bit generator seeded with ``seed``
    """
    validate_seed(seed)

    # Legacy module-level state (np.random.rand and friends).
    np.random.seed(seed)

    # Modern API: explicit bit generator wrapped in a Generator.
    bit_generator = np.random.PCG64(seed)
    generator = np.random.Generator(bit_generator)

    log_seeding("NumPy", seed, warn)
    return generator

## PyTorch Seeding

Seed PyTorch for CPU and CUDA devices with optional deterministic configuration.

In [None]:
import torch


def seed_torch(
    seed: int,
    deterministic: bool = True,
    benchmark: bool = False,
    warn: bool = True
) -> None:
    """
    Seed PyTorch for CPU and CUDA devices with optional deterministic configuration.

    The determinism switches follow ``deterministic`` in both directions: passing
    ``False`` also turns ``torch.use_deterministic_algorithms`` back off, the same
    way the cuDNN flag is reset, so repeated calls leave a consistent state.

    Args:
        seed: The seed value (non-negative integer)
        deterministic: If True, configure PyTorch for deterministic operations (default: True).
                       This also sets CUBLAS_WORKSPACE_CONFIG (only if it is not already set).
        benchmark: If True, enable cuDNN benchmark mode for performance (default: False)
                   Note: benchmark mode may introduce non-determinism
        warn: Whether to emit warnings (default: True)
    """
    validate_seed(seed)

    # Seed PyTorch CPU
    torch.manual_seed(seed)

    # Seed every CUDA device (manual_seed_all covers the current device too)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        # PyTorch's reproducibility notes require this variable for deterministic
        # cuBLAS on CUDA >= 10.2; without it CUDA matmuls raise once deterministic
        # algorithms are enabled. setdefault keeps any value the user chose.
        # NOTE(review): presumably must be set before the first cuBLAS call - confirm.
        os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

    # Configure cuDNN for determinism
    if hasattr(torch.backends, 'cudnn'):
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = benchmark

    # Toggle deterministic algorithms (PyTorch >= 1.8) to match the request,
    # so deterministic=False undoes an earlier deterministic=True call.
    if hasattr(torch, 'use_deterministic_algorithms'):
        try:
            torch.use_deterministic_algorithms(deterministic)
        except RuntimeError as e:
            # Best effort: report and carry on rather than abort seeding.
            if warn:
                action = "enable" if deterministic else "disable"
                logger.warning(
                    f"Could not {action} torch.use_deterministic_algorithms: {e}. "
                    "Some operations may still be non-deterministic."
                )

    log_seeding("PyTorch", seed, warn)

## TensorFlow Seeding

Seed TensorFlow and configure for deterministic operations.

In [None]:
import tensorflow as tf


def seed_tensorflow(seed: int, deterministic: bool = True, warn: bool = True) -> None:
    """
    Seed TensorFlow's global RNG and optionally request deterministic ops.

    Args:
        seed: The seed value (non-negative integer)
        deterministic: If True, export the TF determinism environment flags and
            call ``enable_op_determinism`` when the installed TensorFlow has it
            (default: True)
        warn: Whether to emit warnings (default: True)
    """
    validate_seed(seed)

    tf.random.set_seed(seed)

    if deterministic:
        # Environment switches for deterministic kernels.
        os.environ.update(TF_DETERMINISTIC_OPS='1', TF_CUDNN_DETERMINISTIC='1')

        # Newer TensorFlow exposes an explicit API; failures are reported, not raised.
        experimental = tf.config.experimental
        if hasattr(experimental, 'enable_op_determinism'):
            try:
                experimental.enable_op_determinism()
            except Exception as err:
                if warn:
                    logger.warning(f"Could not enable TensorFlow op determinism: {err}")

    log_seeding("TensorFlow", seed, warn)

## JAX Seeding

Create a JAX PRNG key from a given seed. JAX uses explicit PRNG keys instead of global state.

In [None]:
import jax
import jax.random


def seed_jax(seed: int, warn: bool = True) -> Any:
    """
    Build a JAX PRNG key from ``seed``.

    No global state is touched here: the function only constructs a key via
    ``jax.random.PRNGKey`` and hands it back, so the caller has to thread the
    key through its own JAX code.

    Args:
        seed: The seed value (non-negative integer)
        warn: Forwarded to the logging helper (default: True)

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``
    """
    validate_seed(seed)

    prng_key = jax.random.PRNGKey(seed)
    log_seeding("JAX", seed, warn)

    return prng_key

## scikit-learn Seeding

Utilities for seeding scikit-learn estimators with `RandomState`.

In [None]:
def get_sklearn_random_state(seed: int, warn: bool = True) -> 'np.random.RandomState':
    """
    Build a seeded ``numpy.random.RandomState`` for scikit-learn's ``random_state``.

    The returned object is intended to be passed as the ``random_state``
    argument of scikit-learn estimators and utilities.

    Args:
        seed: The seed value (non-negative integer)
        warn: Forwarded to the logging helper (default: True)

    Returns:
        A numpy.random.RandomState seeded with ``seed``
    """
    validate_seed(seed)

    state = np.random.RandomState(seed)
    log_seeding("scikit-learn", seed, warn)

    return state


def seed_sklearn_estimator(estimator: Any, seed: int) -> Any:
    """
    Assign ``seed`` as the ``random_state`` of an estimator, when it supports one.

    ``set_params(random_state=seed)`` is attempted first. If that is missing or
    rejects the parameter, a plain attribute assignment is tried on estimators
    that already expose ``random_state``. Estimators supporting neither are
    returned untouched.

    Args:
        estimator: A scikit-learn estimator object
        seed: The seed value (non-negative integer)

    Returns:
        The same estimator object that was passed in
    """
    validate_seed(seed)

    # Preferred route: the estimator's own parameter setter.
    if hasattr(estimator, 'set_params'):
        try:
            estimator.set_params(random_state=seed)
        except (ValueError, TypeError):
            # No such parameter (or the setter refused it): fall through.
            pass
        else:
            return estimator

    # Fallback: assign the attribute directly, but only if it already exists.
    if not hasattr(estimator, 'random_state'):
        return estimator

    try:
        setattr(estimator, 'random_state', seed)
    except AttributeError:
        # Read-only property: leave the estimator as it is.
        pass

    return estimator

## Distributed Training

Distributed training seeding utilities including rank detection, rank-aware seeding, and DataLoader worker initialization.

In [None]:
from typing import Callable


def get_rank() -> Optional[int]:
    """
    Get the rank of the current process in distributed training.

    Detection is read-only: a framework is only queried when it has already
    been initialised, and no framework is initialised as a side effect.

    Sources, in order of precedence:
    - PyTorch Distributed (torch.distributed), if initialised
    - Horovod, if initialised
    - DeepSpeed, if its comm layer is initialised
    - Environment variables: RANK, SLURM_PROCID, PMI_RANK, then LOCAL_RANK

    LOCAL_RANK is consulted last because it is only unique within one node;
    the other variables carry a job-wide rank.

    Returns:
        The rank as an integer, or None if not in distributed mode
    """
    # Try PyTorch distributed
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
    except (ImportError, RuntimeError):
        pass

    # Try Horovod. Calling hvd.init() here would start Horovod on any machine
    # that merely has it installed and would shadow the launcher's RANK
    # variable, so only an already-initialised Horovod is queried.
    try:
        import horovod.torch as hvd
        if hvd.is_initialized():
            return hvd.rank()
    except (ImportError, ValueError, AttributeError):
        # AttributeError: Horovod build without is_initialized().
        pass

    # Try DeepSpeed
    try:
        import deepspeed
        if hasattr(deepspeed, 'comm') and deepspeed.comm.is_initialized():
            return deepspeed.comm.get_rank()
    except (ImportError, AttributeError):
        pass

    # Try environment variables: global ranks first, node-local rank last.
    for env_var in ('RANK', 'SLURM_PROCID', 'PMI_RANK', 'LOCAL_RANK'):
        raw_value = os.environ.get(env_var)
        if raw_value is None:
            continue
        try:
            return int(raw_value)
        except ValueError:
            # Malformed value: keep looking at the remaining variables.
            continue

    return None


def seed_distributed(
    seed: int,
    rank: Optional[int] = None,
    warn: bool = True
) -> int:
    """
    Derive a rank-specific seed for distributed training.

    Each rank/worker gets a deterministic but different seed (base seed + rank)
    so that workers can differ in data ordering while staying reproducible.

    Note: this function only computes, logs and returns the value. It does not
    seed any random number generator itself; pass the result to the seeding
    functions to apply it.

    Args:
        seed: The base seed value (non-negative integer)
        rank: The rank of the current process. If None, attempts to auto-detect
            and falls back to 0 when no rank can be found.
        warn: Whether to emit warnings (default: True)

    Returns:
        The rank-specific seed (base seed + rank), wrapped into the valid seed
        range when the sum falls outside it

    Raises:
        TypeError: If seed is not an integer
        ValueError: If seed is negative or out of valid range
    """
    validate_seed(seed)

    # Auto-detect rank if not provided
    if rank is None:
        detected = get_rank()
        rank = 0 if detected is None else detected  # rank 0 outside distributed mode

    # Compute rank-specific seed
    rank_seed = seed + rank

    # Keep the result a valid seed. Sums above the maximum wrap exactly as
    # before; a negative sum (negative rank with a small seed) is wrapped too
    # instead of being returned as an invalid negative seed.
    max_seed = 2**32 - 1
    if not 0 <= rank_seed <= max_seed:
        rank_seed %= max_seed

    log_seeding(f"Distributed (rank={rank})", rank_seed, warn)

    # Default NCCL_DEBUG to WARN unless the user already set it.
    os.environ.setdefault('NCCL_DEBUG', 'WARN')

    return rank_seed


def get_worker_init_fn(base_seed: int = 42) -> Callable[[int], None]:
    """
    Get a worker initialization function for PyTorch DataLoader.

    Each DataLoader worker is seeded with ``(base_seed + worker_id) % 2**32``,
    giving every worker a deterministic but different seed.
    Use this with torch.utils.data.DataLoader(worker_init_fn=...).

    Example:
        >>> loader = DataLoader(dataset, worker_init_fn=get_worker_init_fn(42))

    Args:
        base_seed: The base seed value (default: 42)

    Returns:
        A function that can be used as worker_init_fn for DataLoader
    """
    def _init_fn(worker_id: int) -> None:
        """Seed Python, NumPy and PyTorch for the worker with id ``worker_id``."""
        # np.random.seed only accepts values in [0, 2**32 - 1]; wrap the sum so
        # a large base seed cannot make the worker crash on start-up.
        worker_seed = (base_seed + worker_id) % 2**32

        # Seed Python
        random.seed(worker_seed)

        # Seed NumPy
        np.random.seed(worker_seed)

        # Seed PyTorch
        torch.manual_seed(worker_seed)

    return _init_fn

## Core: `seed_everything`

The main entry point that seeds all available ML frameworks and Python's random module with a single call.

In [None]:
def seed_everything(
    seed: int = 42,
    deterministic: bool = True,
    warn: bool = True
) -> Dict[str, Any]:
    """
    Seed Python, NumPy, PyTorch, TensorFlow and JAX in one call.

    The frameworks are seeded in that order through the per-framework helpers,
    and a summary dictionary is built along the way.

    Args:
        seed: The seed value (non-negative integer, default: 42)
        deterministic: Forwarded to the PyTorch and TensorFlow helpers to request
            deterministic operations (default: True)
        warn: Whether to emit warnings about non-deterministic operations (default: True)

    Returns:
        Dictionary with the keys 'seed', 'deterministic', one boolean flag per
        framework ('python', 'numpy', 'torch', 'torch_cuda', 'tensorflow', 'jax'),
        plus 'numpy_rng' (the NumPy Generator) and 'jax_key' (the JAX PRNG key)
        when those objects were produced

    Raises:
        TypeError: If seed is not an integer
        ValueError: If seed is negative or out of valid range

    Example:
        >>> result = seed_everything(42)
        >>> print(result)
    """
    validate_seed(seed)

    summary = {'seed': seed, 'deterministic': deterministic}

    # Python standard library
    seed_python(seed, warn=warn)
    summary['python'] = True

    # NumPy: keep the Generator so callers can use the modern API
    generator = seed_numpy(seed, warn=warn)
    has_generator = generator is not None
    summary['numpy'] = has_generator
    if has_generator:
        summary['numpy_rng'] = generator

    # PyTorch: benchmark mode is always off here
    seed_torch(seed, deterministic=deterministic, benchmark=False, warn=warn)
    summary.update(torch=True, torch_cuda=torch.cuda.is_available())

    # TensorFlow
    seed_tensorflow(seed, deterministic=deterministic, warn=warn)
    summary['tensorflow'] = True

    # JAX: there is no global state to seed, so the key is handed back
    prng_key = seed_jax(seed, warn=warn)
    has_key = prng_key is not None
    summary['jax'] = has_key
    if has_key:
        summary['jax_key'] = prng_key

    return summary

## Usage Examples

### Basic Usage

Seed all available frameworks with a single call.

In [None]:
# Seed every framework handled by seed_everything() with a single call;
# `result` holds the summary dictionary it returns.
result = seed_everything(42)

# Draws from the global RNGs that were just seeded, so re-running this cell
# prints the same values.
print(random.random())       # Python's `random` module
print(np.random.rand())      # NumPy's legacy global RNG
print(torch.rand(1))         # PyTorch's default generator

### Check Framework Availability

Get information about which frameworks are installed and their versions.

In [None]:
# Print one "key: value" line per entry reported by get_seed_info()
# (availability flags first, then version details of importable frameworks).
info = get_seed_info()
for key, value in info.items():
    print(f"{key}: {value}")

### Advanced Usage

Seed with custom options and check which frameworks were seeded.

In [None]:
# Seed with every option spelled out (these are also the defaults)
result = seed_everything(
    seed=42,
    deterministic=True,   # Enable deterministic operations (may impact performance)
    warn=True             # Emit warnings about potential non-deterministic operations
)

# The returned dictionary carries one boolean flag per framework
print(f"NumPy seeded: {result['numpy']}")
print(f"PyTorch seeded: {result['torch']}")
print(f"TensorFlow seeded: {result['tensorflow']}")

# JAX has no global RNG state; the PRNG key is returned under 'jax_key'
# and has to be passed explicitly to JAX random functions.
if result['jax']:
    jax_key = result['jax_key']
    print(f"JAX PRNG key: {jax_key}")

### Individual Framework Seeding

Seed each framework independently.

In [None]:
# Seed only specific frameworks by calling the per-framework helpers directly.
# seed_numpy() also returns a Generator (discarded here); seed_jax() returns
# the PRNG key, which must be kept.
seed_python(42)
seed_numpy(42)
seed_torch(42, deterministic=True)
seed_tensorflow(42)
jax_key = seed_jax(42)

### Distributed Training

Use rank-aware seeding for distributed training.

In [None]:
# Automatically detect rank and derive the rank-specific seed.
# The keyword is `seed` (seed_distributed has no `base_seed` parameter).
# The call only returns the derived seed; pass it to the seeding functions to apply it.
seed = seed_distributed(seed=42)  # rank 0 gets seed 42, rank 1 gets 43, etc.
print(f"Distributed seed: {seed}")

# Or explicitly specify rank
seed = seed_distributed(seed=42, rank=3)  # This worker gets seed 45
print(f"Distributed seed (rank 3): {seed}")

### DataLoader Worker Seeding (PyTorch)

Create a worker initialization function for reproducible data loading.

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Create a simple dataset for demonstration: 100 samples with 10 random
# features each and a random 0/1 label
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

# Create a DataLoader with deterministic worker seeding.
# NOTE(review): with num_workers=0 no worker processes are started, so the
# worker_init_fn is presumably never invoked in this demo - confirm against
# the DataLoader documentation.
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=0,  # Set to >0 for multi-worker usage
    worker_init_fn=get_worker_init_fn(base_seed=42)
)

print(f"DataLoader created with {len(loader)} batches")

### scikit-learn Support

Create `RandomState` instances and seed scikit-learn estimators.

In [None]:
from sklearn.ensemble import RandomForestClassifier

# Get a seeded numpy RandomState for sklearn
random_state = get_sklearn_random_state(42)
print(f"RandomState: {random_state}")

# Pass the RandomState object as the estimator's random_state at construction
clf = RandomForestClassifier(random_state=random_state)
print(f"Classifier random_state set: {clf.random_state}")

# Or seed an existing estimator in place; here random_state becomes the
# integer seed rather than a RandomState object
clf2 = RandomForestClassifier()
seed_sklearn_estimator(clf2, 42)
print(f"Classifier random_state seeded: {clf2.random_state}")

### Seed Validation

The `validate_seed` function ensures seed values are valid.

In [None]:
# Valid seeds: both ends of the accepted range (0 and 2**32 - 1) and a
# typical value; validate_seed returns None when the seed is accepted
validate_seed(0)
validate_seed(42)
validate_seed(2**32 - 1)
print("All valid seeds passed!")

# Invalid seeds (each raises; the error message is printed instead)

# Negative value -> ValueError
try:
    validate_seed(-1)
except ValueError as e:
    print(f"ValueError: {e}")

# Non-integer value -> TypeError
try:
    validate_seed("not_an_int")
except TypeError as e:
    print(f"TypeError: {e}")

# One past the maximum (2**32 - 1) -> ValueError
try:
    validate_seed(2**32)
except ValueError as e:
    print(f"ValueError: {e}")