diff --git a/README.md b/README.md
index 3bb1c7c..522ef7c 100644
--- a/README.md
+++ b/README.md
@@ -20,19 +20,37 @@ Train on the [included enwik8 dataset](data/README.md), character-level modeling
 # 100k batches on enwik8, 35M param Llama
 python train.py --config configs/simple.yaml
 
-# Quick test (tiny model, 10 batches)
+# Nano run for CPU / MPS shakedowns (10k steps, L6 · H384 · ~9M params)
+python train.py --config configs/nano.yaml
+
+# Quick smoke test (tiny model, 10 batches)
 python train.py --config configs/test.yaml
 ```
+
+## Device Selection & Precision
+
+- The training script calls `decoder_pytorch.get_optimal_device()`, which prefers `cuda → mps → cpu`, returns `(device, device_type, amp_dtype)`, and prints the accelerator it picked.
+- Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument).
+- Mixed precision uses `torch.autocast` with `torch.bfloat16`; set `use_autocast: false` in a config if you want full fp32.
+
+## Device Support
+
+| Device        | Status | Notes                                                |
+| ------------- | ------ | ---------------------------------------------------- |
+| NVIDIA GPU    | ✅      | Best performance, fused optimizer & flash attention  |
+| Apple Silicon | ✅      | Good performance, autocast can be flaky              |
+| CPU           | ✅      | Slow but works; use `configs/nano.yaml`              |
 
 ## Structure
 
 ```text
 decoder-pytorch-template/
 ├── decoder_pytorch/      # Model implementation
 │   ├── llama.py          # Llama architecture
-│   └── utils.py          # Sampling utilities
+│   └── utils.py          # Sampling & device helpers
 ├── configs/              # Training configs
 │   ├── simple.yaml       # Default config
+│   ├── nano.yaml         # Quick CPU/MPS config
 │   └── test.yaml         # Quick test config
 ├── data/
 │   └── enwik8.gz         # Character-level dataset
@@ -61,7 +79,7 @@ To add your own model architecture:
 4. **Update training script**: Modify `train.py` line 16 and 88:
 
    ```python
-   from decoder_pytorch import YourModel, configure_tf32, model_summary
+   from decoder_pytorch import YourModel, model_summary
    # ...
    model = YourModel(
       num_tokens=config.get("num_tokens", 256),
@@ -113,7 +131,7 @@ Dependencies:
 
 - einops, pyyaml, tqdm
 - [rotary-embedding-torch](https://github.com/lucidrains/rotary-embedding-torch)
 
-[^2]: If using PyTorch <2.9, you'll need to modify the TF32 configuration in `decoder_pytorch/utils.py` to use the legacy API (`torch.set_float32_matmul_precision("high")`) or skip TF32 setup entirely.
+[^2]: If using PyTorch <2.9, you may need to adjust the bfloat16/autocast behaviour or fall back to full fp32 depending on hardware support.
 
 ## License
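The README section above is the user-facing summary of the new device/precision flow. As a standalone illustration (not part of the patch; the hyperparameters are just the nano-sized ones, and `FORCE_DEVICE` is pinned to `cpu` only so the snippet runs anywhere), the exported pieces compose like this:

```python
import os

import torch

from decoder_pytorch import Llama, get_optimal_device

# Pin the accelerator before detection runs; delete this line to auto-detect.
os.environ["FORCE_DEVICE"] = "cpu"

device, device_type, amp_dtype = get_optimal_device()

model = Llama(
    num_tokens=256,
    dim=384,
    depth=6,
    heads=6,
    dim_head=64,
    tied_embedding=True,
    ffn_dim_multiplier=1.5,
    flash_attn=False,
).to(device)

tokens = torch.randint(0, 256, (1, 128), device=device)

# Same pattern train.py uses: bfloat16 autocast on whatever backend was picked.
with torch.autocast(device_type=device_type, dtype=amp_dtype):
    logits = model(tokens, return_loss=False)

print(logits.shape)  # per-position logits over the 256-token byte vocabulary
```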
diff --git a/configs/nano.yaml b/configs/nano.yaml
new file mode 100644
index 0000000..10e1e05
--- /dev/null
+++ b/configs/nano.yaml
@@ -0,0 +1,36 @@
+# Smaller than simple, larger than test (L6 · H384 · ~9M params)
+run_dir: runs/nano
+
+# Model
+num_tokens: 256
+dim: 384
+depth: 6
+heads: 6
+dim_head: 64
+tied_embedding: true
+ffn_dim_multiplier: 1.5
+flash_attn: true
+compile: true # turn off if torch.compile is unavailable
+use_autocast: true
+
+# Training schedule
+num_batches: 10000
+batch_size: 1
+grad_accum_every: 16
+learning_rate: 0.002
+weight_decay: 0.0003
+grad_clip_norm: 1.0
+
+# Data
+data_path: data/enwik8.gz
+seq_len: 512
+
+# training/validation/generation
+validate_every: 250
+val_batches: 20
+generate_every: 250
+save_every: 2000
+temperature: 1.0
+min_p: 0.1
+
+seed: 7
diff --git a/configs/simple.yaml b/configs/simple.yaml
index f544948..5e23c0e 100644
--- a/configs/simple.yaml
+++ b/configs/simple.yaml
@@ -12,6 +12,7 @@ tied_embedding: true # share in/out embeddings
 ffn_dim_multiplier: 1.5 # hidden dim multiplier --> FFN size (here, 768)
 flash_attn: true # use flash attn (through torch api)
 compile: false # speed up training
+use_autocast: true # enable mixed precision (bfloat16)
 
 # Training
 num_batches: 100000 # total steps
diff --git a/configs/test.yaml b/configs/test.yaml
index 42e0dba..6e0f54e 100644
--- a/configs/test.yaml
+++ b/configs/test.yaml
@@ -10,6 +10,7 @@ dim_head: 32
 tied_embedding: true
 flash_attn: true
 compile: false
+use_autocast: true
 
 # Quick training
 num_batches: 10
@@ -28,5 +29,4 @@ val_batches: 50
 generate_every: 1000 # Don't generate during test
 save_every: 1000 # Don't save during test
 
-# Random seed for reproducibility
-seed: 42
\ No newline at end of file
+seed: 7
\ No newline at end of file
diff --git a/decoder_pytorch/__init__.py b/decoder_pytorch/__init__.py
index ec1f767..71a7f68 100644
--- a/decoder_pytorch/__init__.py
+++ b/decoder_pytorch/__init__.py
@@ -2,7 +2,7 @@
 from .llama import Llama
 
 from .utils import (
-    configure_tf32,
+    get_optimal_device,
     gumbel_noise,
     gumbel_sample,
     log,
@@ -22,6 +22,6 @@
     "top_k_filter",
     "top_p_filter",
     # Torch utilities
-    "configure_tf32",
     "model_summary",
+    "get_optimal_device",
 ]
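Because train.py reads every key with `config.get(key, default)`, configs that predate this patch and lack `use_autocast` keep working. A small sketch of that behaviour, assuming it is run from the repository root:

```python
import yaml

with open("configs/nano.yaml") as f:
    config = yaml.safe_load(f)

# Missing keys fall back to defaults, so an older config without use_autocast
# still trains with mixed precision enabled.
use_autocast = bool(config.get("use_autocast", True))
flash_attn = config.get("flash_attn")  # None means "decide from the device"

print(use_autocast, flash_attn)  # True True for configs/nano.yaml
```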
diff --git a/decoder_pytorch/llama.py b/decoder_pytorch/llama.py
index dfb9d25..a8ffa32 100644
--- a/decoder_pytorch/llama.py
+++ b/decoder_pytorch/llama.py
@@ -140,6 +140,11 @@ def __init__(
         self.dim_head = dim_head
         self.causal = causal
         self.flash_attn = flash_attn
+        if self.flash_attn and not self._flash_attn_available():
+            print(
+                "Warning: Flash attention requested but not available, using standard attention."
+            )
+            self.flash_attn = False
         self.max_seq_len = max_seq_len
 
         inner_dim = heads * dim_head
@@ -158,6 +163,20 @@ def __init__(
         # Register causal mask buffer (will be created on first use)
         self.register_buffer("causal_mask", None, persistent=False)
 
+    def _flash_attn_available(self) -> bool:
+        """Return True if flash attention kernels are available."""
+        if torch.cuda.is_available():
+            return True
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return True
+        dynamo = getattr(torch, "_dynamo", None)
+        if dynamo is not None:
+            try:
+                return bool(dynamo.is_compiling())
+            except Exception:
+                return False
+        return False
+
     def forward(
         self,
         x: torch.Tensor,
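The warning added above only covers detection; the attention computation itself is outside this diff. For orientation, here is a hypothetical sketch (none of the names below come from `llama.py`) of the two paths a `flash_attn` flag like this usually switches between: the fused `torch.nn.functional.scaled_dot_product_attention` call versus explicit masked attention.

```python
import torch
import torch.nn.functional as F


def attend(q, k, v, use_flash: bool, causal: bool = True):
    """Hypothetical helper, not code from llama.py; it only illustrates the
    two paths a flash_attn flag typically selects between."""
    if use_flash:
        # Fused kernel path through the torch API ("flash attn" in the configs).
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    # Standard attention with an explicit causal mask.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    if causal:
        n = q.shape[-2]
        mask = torch.ones(n, n, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


q = k = v = torch.randn(1, 6, 16, 64)  # (batch, heads, seq, dim_head)
print(torch.allclose(attend(q, k, v, True), attend(q, k, v, False), atol=1e-5))
```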
diff --git a/decoder_pytorch/utils.py b/decoder_pytorch/utils.py
index be6a821..a44511a 100644
--- a/decoder_pytorch/utils.py
+++ b/decoder_pytorch/utils.py
@@ -1,3 +1,4 @@
+import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set, Tuple
 
@@ -108,36 +109,85 @@ def top_p_filter(logits: Tensor, p: float = 0.9) -> Tensor:
 # --------------------------------------------------------------------------
 
 
-def configure_tf32() -> bool:
-    """Enable TF32 precision for GPUs with compute capability >= 8.0 (Ampere+).
+def _mps_available() -> bool:
+    """Return True if MPS is available."""
+    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
 
-    Uses the PyTorch 2.9+ API for TF32 configuration.
 
-    :return: True if TF32 was enabled, False otherwise
+def get_optimal_device(
+    force: Optional[str] = None,
+) -> Tuple[torch.device, str, torch.dtype]:
+    """Return the best available accelerator as (device, device_type, amp_dtype).
+
+    The function tries CUDA → MPS → CPU, unless the user forces a choice via
+    the ``force`` argument or the ``FORCE_DEVICE`` environment variable. The
+    return value is intentionally simple: a tuple that unpacks cleanly in
+    training scripts.
     """
-    if not torch.cuda.is_available():
-        print("No GPU detected, running on CPU.")
-        return False
-
-    try:
-        device = torch.cuda.current_device()
-        capability = torch.cuda.get_device_capability(device)
-        major, minor = capability
-        gpu_name = torch.cuda.get_device_name(device)
-
-        if major >= 8:
-            # PyTorch 2.9+ API for TF32 configuration
-            torch.backends.cudnn.conv.fp32_precision = "tf32"
-            torch.backends.cuda.matmul.fp32_precision = "tf32"
-            print(f"{gpu_name} (compute {major}.{minor}) - TF32 enabled")
-            return True
-        else:
-            print(f"{gpu_name} (compute {major}.{minor}) - TF32 not supported")
-            return False
-    except Exception as e:
-        print(f"Error: failed to configure GPU: {e}")
-        return False
+    def _normalize(device_str: str) -> str:
+        return device_str.split(":", 1)[0]
+
+    requested = (force or os.getenv("FORCE_DEVICE", "")).strip().lower()
+    valid_types = {"cuda", "mps", "cpu"}
+
+    if requested:
+        requested_type = _normalize(requested)
+        if requested_type not in valid_types:
+            print(
+                f"Warning: unsupported FORCE_DEVICE='{requested}'. "
+                "Falling back to auto-detect."
+            )
+            requested = ""
+        elif requested_type == "cuda" and not torch.cuda.is_available():
+            print("Warning: CUDA requested but not available, falling back.")
+            requested = ""
+        elif requested_type == "mps" and not _mps_available():
+            print("Warning: MPS requested but not available, falling back.")
+            requested = ""
+
+    if requested:
+        try:
+            device = torch.device(requested)
+        except (RuntimeError, ValueError) as err:
+            print(f"Warning: could not create device '{requested}' ({err}).")
+            requested = ""
+        else:
+            device_type = _normalize(requested)
+            if device_type == "cuda":
+                index = device.index or 0
+                device_count = torch.cuda.device_count()
+                if index >= device_count:
+                    print(
+                        f"Warning: CUDA index {index} unavailable "
+                        f"(found {device_count} device(s)). Falling back."
+                    )
+                    requested = ""
+                else:
+                    name = torch.cuda.get_device_name(index)
+                    print(f"Using CUDA device {index}: {name}")
+                    return device, "cuda", torch.bfloat16
+            elif device_type == "mps":
+                print("Using Apple Silicon (MPS)")
+                return device, "mps", torch.bfloat16
+            else:
+                print("Using CPU (forced)")
+                return device, "cpu", torch.bfloat16
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        name = torch.cuda.get_device_name(0)
+        print(f"Using CUDA: {name}")
+        return device, "cuda", torch.bfloat16
+
+    if _mps_available():
+        device = torch.device("mps")
+        print("Using Apple Silicon (MPS)")
+        return device, "mps", torch.bfloat16
+
+    device = torch.device("cpu")
+    print("Using CPU (no GPU acceleration available)")
+    return device, "cpu", torch.bfloat16
 
 
 @dataclass
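A small usage check of the precedence rules implemented above. Nothing here is new API; it only exercises `get_optimal_device` as written: an explicit `force=` argument wins over `FORCE_DEVICE`, and forcing `cpu` always succeeds, so this runs on any machine.

```python
import os

import torch

from decoder_pytorch import get_optimal_device

os.environ["FORCE_DEVICE"] = "cuda:1"  # requests the second GPU via the env var
device, device_type, amp_dtype = get_optimal_device(force="cpu")  # force= wins

assert device_type == "cpu"
assert amp_dtype is torch.bfloat16
print(device)  # cpu
```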
""" # ---------- formatting helpers ---------- @@ -248,15 +304,60 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]: max(len(_format_shape(s.param_shape)) for s in summary_list), ) - params_col_width = 12 - trainable_col_width = 10 - col_spacing = " " + params_col_width = max( + len("Param #"), + max(len(_format_number(s.inclusive_total_params)) for s in summary_list), + ) header_parts = [f"{'Layer (type)':<{name_col_width}}"] if show_param_shapes: header_parts.append(f"{'Param Shape':>{shape_col_width}}") + header_parts.append(f"{'Param #':>{params_col_width}}") - header_parts.append(f"{'Trainable':>{trainable_col_width}}") + + if show_frozen_breakdown: + trainable_col_width = max( + len("Trainable #"), + max( + len(_format_number(s.inclusive_trainable_params)) for s in summary_list + ), + ) + frozen_col_width = max( + len("Frozen #"), + max( + len( + _format_number( + s.inclusive_total_params - s.inclusive_trainable_params + ) + ) + for s in summary_list + ), + ) + header_parts.append(f"{'Trainable #':>{trainable_col_width}}") + header_parts.append(f"{'Frozen #':>{frozen_col_width}}") + else: + + def _grad_state(total: int, trainable: int) -> str: + if trainable == 0: + return "frozen" + if trainable == total: + return "trainable" + return "mixed" + + grad_states = [ + _grad_state( + s.inclusive_total_params, + s.inclusive_trainable_params, + ) + for s in summary_list + ] + grad_state_width = max( + len("Grad State"), max(len(state) for state in grad_states) + ) + header_parts.append(f"{'Grad State':>{grad_state_width}}") + + col_spacing = " " + header = col_spacing.join(header_parts) sep = "=" * len(header) @@ -268,7 +369,15 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]: if show_param_shapes: parts.append(f"{_format_shape(e.param_shape):>{shape_col_width}}") parts.append(f"{_format_number(e.inclusive_total_params):>{params_col_width}}") - parts.append(f"{str(e.inclusive_trainable_params > 0):>{trainable_col_width}}") + if show_frozen_breakdown: + parts.append( + f"{_format_number(e.inclusive_trainable_params):>{trainable_col_width}}" + ) + frozen = e.inclusive_total_params - e.inclusive_trainable_params + parts.append(f"{_format_number(frozen):>{frozen_col_width}}") + else: + state = _grad_state(e.inclusive_total_params, e.inclusive_trainable_params) + parts.append(f"{state:>{grad_state_width}}") print(col_spacing.join(parts)) print(sep) print(f"Total params: {_format_number(total_params)}") diff --git a/train.py b/train.py index 776e3bc..3381183 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,7 @@ import argparse import gzip import json +from contextlib import nullcontext from pathlib import Path from typing import Optional @@ -13,7 +14,7 @@ from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm -from decoder_pytorch import Llama, configure_tf32, model_summary +from decoder_pytorch import Llama, get_optimal_device, model_summary # Data utilities @@ -65,9 +66,28 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): with open(config_path) as f: config = yaml.safe_load(f) - # Setup - device = "cuda" if torch.cuda.is_available() else "cpu" - configure_tf32() # Enable TF32 for Ampere+ GPUs + # Setup device + device, device_type, amp_dtype = get_optimal_device() + print(f"Device: {device}") + + device_caps = { + "supports_fused_optimizer": device_type == "cuda", + "supports_flash_attn": device_type in ("cuda", "mps"), + } + + # Setup autocast context + use_autocast = 
diff --git a/train.py b/train.py
index 776e3bc..3381183 100644
--- a/train.py
+++ b/train.py
@@ -3,6 +3,7 @@
 import argparse
 import gzip
 import json
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -13,7 +14,7 @@
 from torch.utils.data import DataLoader, Dataset
 from tqdm.auto import tqdm
 
-from decoder_pytorch import Llama, configure_tf32, model_summary
+from decoder_pytorch import Llama, get_optimal_device, model_summary
 
 
 # Data utilities
@@ -65,9 +66,28 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None):
     with open(config_path) as f:
         config = yaml.safe_load(f)
 
-    # Setup
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    configure_tf32()  # Enable TF32 for Ampere+ GPUs
+    # Setup device
+    device, device_type, amp_dtype = get_optimal_device()
+    print(f"Device: {device}")
+
+    device_caps = {
+        "supports_fused_optimizer": device_type == "cuda",
+        "supports_flash_attn": device_type in ("cuda", "mps"),
+    }
+
+    # Setup autocast context
+    use_autocast = bool(config.get("use_autocast", True))
+
+    def autocast_context():
+        if use_autocast:
+            return torch.autocast(device_type=device_type, dtype=amp_dtype)
+        return nullcontext()
+
+    print(
+        f"Mixed precision: {'enabled' if use_autocast else 'disabled'}"
+        f"{f' ({amp_dtype})' if use_autocast else ' (full fp32)'}"
+    )
+
     if config.get("seed"):
         torch.manual_seed(config["seed"])
 
@@ -85,6 +105,15 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None):
     val_loader = cycle(val_loader)
 
     # Create model
+    flash_attn_requested = config.get("flash_attn")
+    if flash_attn_requested is None:
+        flash_attn_requested = device_caps["supports_flash_attn"]
+    elif flash_attn_requested and not device_caps["supports_flash_attn"]:
+        print(
+            "Warning: flash attention requested but not supported on this device; using standard attention."
+        )
+        flash_attn_requested = False
+
     model = Llama(
         num_tokens=config.get("num_tokens", 256),
         dim=config.get("dim", 512),
@@ -93,17 +122,18 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None):
         dim_head=config.get("dim_head", 64),
         tied_embedding=config.get("tied_embedding", True),
         ffn_dim_multiplier=config.get("ffn_dim_multiplier"),
-        flash_attn=config.get("flash_attn", True),
+        flash_attn=bool(flash_attn_requested),
     ).to(device)
 
     model_summary(model, max_depth=3, show_param_shapes=True)
 
     # Optimizer
+    # Fused optimizer is available for CUDA only (not MPS or CPU)
     optimizer = Adam(
         model.parameters(),
         lr=config.get("learning_rate", 1e-3),
         weight_decay=config.get("weight_decay", 0.0),
-        fused=torch.cuda.is_available(),
+        fused=device_caps["supports_fused_optimizer"],
     )
 
     # Training state
@@ -140,7 +170,9 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None):
         model.train()
 
         # Training step with gradient accumulation
-        # Accumulate raw losses and token counts for correct normalization
+        # Accumulate raw losses and token counts for correct normalization.
+        # Critical pitfall: normalizing each micro-batch separately leads to
+        # incorrect gradients. Sum first, divide once at the end.
         total_loss_sum = 0
         total_tokens = 0
         loss_accumulator = None  # Keep gradient graph alive
@@ -151,7 +183,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None):
            inputs = data[:, :-1]
            targets = data[:, 1:]

-            with torch.autocast(device_type=device, dtype=torch.bfloat16):
+            with autocast_context():
                logits = model(inputs, return_loss=False)
                # Compute unnormalized loss
                loss_unreduced = torch.nn.functional.cross_entropy(
@@ -195,7 +227,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None):
            data = next(val_loader).to(device)
            inputs = data[:, :-1]
            targets = data[:, 1:]
-            with torch.no_grad():
+            with torch.no_grad(), autocast_context():
                logits = model(inputs, return_loss=False)
                loss_unreduced = torch.nn.functional.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
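The "sum first, divide once" comment added to train.py is the heart of the accumulation logic. A standalone sketch of that normalization follows; the exact reduction used by the training loop is cut off in the hunk above, so `reduction="sum"` here is an assumption, and `accumulation_loss` is a hypothetical name rather than a function in the repo.

```python
import torch.nn.functional as F


def accumulation_loss(model, micro_batches):
    """Sum unreduced token losses across micro-batches, then divide once by the
    total token count. Normalizing each micro-batch separately (the pitfall the
    train.py comment warns about) weights tokens unevenly when sizes differ."""
    loss_sum = None
    total_tokens = 0
    for data in micro_batches:
        inputs, targets = data[:, :-1], data[:, 1:]
        logits = model(inputs, return_loss=False)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            reduction="sum",  # assumption: keep the raw per-token loss sum
        )
        loss_sum = loss if loss_sum is None else loss_sum + loss
        total_tokens += targets.numel()
    return loss_sum / total_tokens
```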