From 9e1fb98e7fb2f81bca49626ccf2aaa8ad5eb8a92 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 25 Oct 2025 01:19:42 +0000 Subject: [PATCH 1/8] Add device utility with MPS support and improve device handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds comprehensive device detection and selection utilities that support CUDA, MPS (Apple Silicon), and CPU backends with automatic fallback logic. Changes: - Add decoder_pytorch/device.py with DeviceSelection dataclass and get_optimal_device() function - Update decoder_pytorch/__init__.py to export new device utilities - Refactor train.py to use get_optimal_device() instead of hardcoded device selection - Use device-specific autocast dtype (bfloat16 for CUDA/MPS, float32 for CPU) - Integrate TF32 configuration into get_optimal_device for CUDA - Update fused optimizer check to only enable on CUDA (not MPS/CPU) The get_optimal_device() function provides: - Automatic device detection with configurable priority order - Force device selection via parameter or FORCE_DEVICE env var - Integrated TF32 configuration for CUDA devices - Appropriate autocast dtype selection per device type - Detailed device info logging This ensures the codebase works seamlessly across CUDA, MPS, and CPU devices with optimal settings for each platform. πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- decoder_pytorch/__init__.py | 4 ++ decoder_pytorch/device.py | 115 ++++++++++++++++++++++++++++++++++++ train.py | 18 ++++-- 3 files changed, 131 insertions(+), 6 deletions(-) create mode 100644 decoder_pytorch/device.py diff --git a/decoder_pytorch/__init__.py b/decoder_pytorch/__init__.py index ec1f767..72c3b46 100644 --- a/decoder_pytorch/__init__.py +++ b/decoder_pytorch/__init__.py @@ -1,5 +1,6 @@ """Llama-style transformer for language modeling experiments.""" +from .device import DeviceSelection, get_optimal_device from .llama import Llama from .utils import ( configure_tf32, @@ -24,4 +25,7 @@ # Torch utilities "configure_tf32", "model_summary", + # Device utilities + "DeviceSelection", + "get_optimal_device", ] diff --git a/decoder_pytorch/device.py b/decoder_pytorch/device.py new file mode 100644 index 0000000..85175f1 --- /dev/null +++ b/decoder_pytorch/device.py @@ -0,0 +1,115 @@ +# decoder_pytorch/device.py + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple + +import torch + + +__all__ = ["DeviceSelection", "get_optimal_device"] + + +@dataclass(frozen=True) +class DeviceSelection: + """Result of device selection.""" + + device: torch.device # For .to() / model placement + device_type: str # For torch.autocast(device_type=...) 
+ device_info: str # Human-readable description + amp_dtype: torch.dtype # Suggested autocast dtype (bfloat16 on accel) + + +def _cuda_info(index: int) -> str: + name = torch.cuda.get_device_name(index) + major, minor = torch.cuda.get_device_capability(index) + return f"CUDA GPU: {name} (compute {major}.{minor})" + + +def _mps_available() -> bool: + # torch.backends.mps is present only on macOS builds + return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +def get_optimal_device( + preferred_order: Iterable[str] = ("cuda", "mps", "cpu"), + device_index: int = 0, + *, + force: Optional[str] = None, + enable_tf32: Optional[bool] = None, + logger: Optional[logging.Logger] = None, +) -> DeviceSelection: + """ + Detect and return the best available device for PyTorch. + + Priority is defined by `preferred_order` (default: CUDA > MPS > CPU). + You can force a specific device via `force="cuda"|"mps"|"cpu"` or the + environment variable FORCE_DEVICE with the same values. + + Returns: + DeviceSelection(device, device_type, device_info, amp_dtype) + + Notes: + - For CUDA & MPS, the suggested autocast dtype is torch.bfloat16. + - CPU autocast is limited and typically uses float32. + - If `enable_tf32` is provided, it toggles TF32 on CUDA (Ampere+). + """ + log = logger or logging.getLogger(__name__) + choice = (force or os.getenv("FORCE_DEVICE", "")).lower().strip() + + # Normalize and validate preference list + normalized_order: Tuple[str, ...] = ( + tuple( + d + for d in (choice,) + tuple(preferred_order) + if d in ("cuda", "mps", "cpu") and d + ) + if choice + else tuple(d for d in preferred_order if d in ("cuda", "mps", "cpu")) + ) + + # Try each device in order + for dev_type in normalized_order: + if dev_type == "cuda" and torch.cuda.is_available(): + # Optional CUDA knobs + if enable_tf32 is not None: + try: + torch.backends.cuda.matmul.allow_tf32 = bool(enable_tf32) + torch.backends.cudnn.allow_tf32 = bool(enable_tf32) + except Exception: + # TF32 toggles aren't critical-ignore if unavailable + pass + + try: + info = _cuda_info(device_index) + except Exception as e: + info = f"CUDA GPU (info unavailable: {e})" + + device = torch.device("cuda", index=device_index) + amp_dtype = torch.bfloat16 + log.info("Using %s", info) + return DeviceSelection(device, "cuda", info, amp_dtype) + + if dev_type == "mps" and _mps_available(): + info = "Apple Silicon (MPS)" + device = torch.device("mps") + amp_dtype = torch.bfloat16 + log.info("Using %s", info) + return DeviceSelection(device, "mps", info, amp_dtype) + + if dev_type == "cpu": + info = "CPU" + device = torch.device("cpu") + amp_dtype = torch.float32 + log.info("Using %s (no GPU acceleration available)", info) + return DeviceSelection(device, "cpu", info, amp_dtype) + + # Absolute fallback (shouldn't be reached) + info = "CPU" + device = torch.device("cpu") + amp_dtype = torch.float32 + log.warning("Falling back to CPU; no preferred devices were available.") + return DeviceSelection(device, "cpu", info, amp_dtype) diff --git a/train.py b/train.py index 776e3bc..5ffb5c6 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm -from decoder_pytorch import Llama, configure_tf32, model_summary +from decoder_pytorch import Llama, get_optimal_device, model_summary # Data utilities @@ -65,9 +65,11 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): with open(config_path) as f: config = yaml.safe_load(f) - # Setup - device = "cuda" if 
torch.cuda.is_available() else "cpu" - configure_tf32() # Enable TF32 for Ampere+ GPUs + # Setup device + device_sel = get_optimal_device(enable_tf32=True) # Auto-detects cuda/mps/cpu + device = device_sel.device + print(f"Device: {device_sel.device_info}") + if config.get("seed"): torch.manual_seed(config["seed"]) @@ -99,11 +101,13 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): model_summary(model, max_depth=3, show_param_shapes=True) # Optimizer + # Fused optimizer is available for CUDA only (not MPS or CPU) + use_fused = device_sel.device_type == "cuda" optimizer = Adam( model.parameters(), lr=config.get("learning_rate", 1e-3), weight_decay=config.get("weight_decay", 0.0), - fused=torch.cuda.is_available(), + fused=use_fused, ) # Training state @@ -151,7 +155,9 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): inputs = data[:, :-1] targets = data[:, 1:] - with torch.autocast(device_type=device, dtype=torch.bfloat16): + with torch.autocast( + device_type=device_sel.device_type, dtype=device_sel.amp_dtype + ): logits = model(inputs, return_loss=False) # Compute unnormalized loss loss_unreduced = torch.nn.functional.cross_entropy( From cfc97da6fb5113dcf733bfe238b2d4ba409124cf Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 25 Oct 2025 01:29:39 +0000 Subject: [PATCH 2/8] Fix CPU autocast to use bfloat16 instead of float32 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed CPU device selection to use torch.bfloat16 for autocast, consistent with the repo's assumption of bfloat16-compatible hardware (2025 AD standard). This eliminates the warning: "CPU Autocast only supports dtype of torch.bfloat16, torch.float16" All devices (CUDA, MPS, CPU) now uniformly use bfloat16 for mixed precision training. πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- decoder_pytorch/device.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/decoder_pytorch/device.py b/decoder_pytorch/device.py index 85175f1..c817da9 100644 --- a/decoder_pytorch/device.py +++ b/decoder_pytorch/device.py @@ -53,8 +53,8 @@ def get_optimal_device( DeviceSelection(device, device_type, device_info, amp_dtype) Notes: - - For CUDA & MPS, the suggested autocast dtype is torch.bfloat16. - - CPU autocast is limited and typically uses float32. + - All devices (CUDA, MPS, CPU) use torch.bfloat16 for autocast. + - bfloat16-compatible hardware is assumed (2025 AD standard). - If `enable_tf32` is provided, it toggles TF32 on CUDA (Ampere+). 
""" log = logger or logging.getLogger(__name__) @@ -103,13 +103,13 @@ def get_optimal_device( if dev_type == "cpu": info = "CPU" device = torch.device("cpu") - amp_dtype = torch.float32 + amp_dtype = torch.bfloat16 log.info("Using %s (no GPU acceleration available)", info) return DeviceSelection(device, "cpu", info, amp_dtype) # Absolute fallback (shouldn't be reached) info = "CPU" device = torch.device("cpu") - amp_dtype = torch.float32 + amp_dtype = torch.bfloat16 log.warning("Falling back to CPU; no preferred devices were available.") return DeviceSelection(device, "cpu", info, amp_dtype) From b5ac79fbf28b044c488f0c064193ed5a41724929 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 25 Oct 2025 01:43:07 +0000 Subject: [PATCH 3/8] Add use_autocast config option for mixed precision control MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added configurable autocast support to allow users to enable/disable mixed precision training via config files without modifying code. Changes: - Add use_autocast config option to simple.yaml and test.yaml (default: true) - Update train.py to conditionally use autocast based on config - Use contextlib.nullcontext() when autocast is disabled - Print mixed precision status on startup Usage: use_autocast: true # Enable bfloat16 mixed precision (default) use_autocast: false # Disable, use full fp32 precision Both configurations tested successfully with no warnings. πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/simple.yaml | 1 + configs/test.yaml | 1 + train.py | 16 +++++++++++++--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/configs/simple.yaml b/configs/simple.yaml index f544948..5e23c0e 100644 --- a/configs/simple.yaml +++ b/configs/simple.yaml @@ -12,6 +12,7 @@ tied_embedding: true # share in/out embeddings ffn_dim_multiplier: 1.5 # hidden dim multiplier --> FFN size (here, 768) flash_attn: true # use flash attn (through torch api) compile: false # speed up training +use_autocast: true # enable mixed precision (bfloat16) # Training num_batches: 100000 # total steps diff --git a/configs/test.yaml b/configs/test.yaml index 42e0dba..c8ceb1d 100644 --- a/configs/test.yaml +++ b/configs/test.yaml @@ -10,6 +10,7 @@ dim_head: 32 tied_embedding: true flash_attn: true compile: false +use_autocast: true # Quick training num_batches: 10 diff --git a/train.py b/train.py index 5ffb5c6..c515dc1 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,7 @@ import argparse import gzip import json +from contextlib import nullcontext from pathlib import Path from typing import Optional @@ -70,6 +71,17 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): device = device_sel.device print(f"Device: {device_sel.device_info}") + # Setup autocast context + use_autocast = config.get("use_autocast", True) + if use_autocast: + autocast_ctx = lambda: torch.autocast( + device_type=device_sel.device_type, dtype=device_sel.amp_dtype + ) + print(f"Mixed precision: enabled ({device_sel.amp_dtype})") + else: + autocast_ctx = lambda: nullcontext() + print("Mixed precision: disabled (full fp32)") + if config.get("seed"): torch.manual_seed(config["seed"]) @@ -155,9 +167,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): inputs = data[:, :-1] targets = data[:, 1:] - with torch.autocast( - device_type=device_sel.device_type, dtype=device_sel.amp_dtype - ): + with autocast_ctx(): logits = model(inputs, return_loss=False) # Compute unnormalized loss 
loss_unreduced = torch.nn.functional.cross_entropy( From 4fa720b5b18ec947faf5880f2ee592176d480876 Mon Sep 17 00:00:00 2001 From: Peter Szemraj <74869040+pszemraj@users.noreply.github.com> Date: Sat, 25 Oct 2025 11:13:00 -0400 Subject: [PATCH 4/8] :sparkles: Ergonomic accelerator detection, safer selection, AMP hook, and docs * Add `DeviceSelection.autocast_context()`; parse `cuda:N`, dedupe prefs, and warn on bad input (decoder_pytorch/device.py:26,53). * Honor forced indices; guard out-of-range CUDA; loud CPU fallback for debug (decoder_pytorch/device.py:107). * Use context helper for train/val; fix E731; AMP toggle driven by config (train.py:74). * Document detection flow, `FORCE_DEVICE`, and autocast usage (README.md:27). --- README.md | 7 +++ decoder_pytorch/device.py | 110 +++++++++++++++++++++++++++++++++----- train.py | 11 ++-- 3 files changed, 107 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 3bb1c7c..d49a850 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,13 @@ python train.py --config configs/simple.yaml python train.py --config configs/test.yaml ``` +## Device Selection & Precision + +- The training script calls `decoder_pytorch.get_optimal_device()` which prefers `cuda -> mps -> cpu` and prints a human-readable summary of the accelerator in use. +- Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument). +- The returned `DeviceSelection` exposes `.autocast_context(enabled=...)` so you can opt into bfloat16 autocast wherever it makes sense; configs set `use_autocast: true` by default. +- CUDA runs enable TF32 when requested (Ampere or newer); invalid device preferences are ignored with a debug log, keeping things easy to reason about. + ## Structure ```text diff --git a/decoder_pytorch/device.py b/decoder_pytorch/device.py index c817da9..1af7865 100644 --- a/decoder_pytorch/device.py +++ b/decoder_pytorch/device.py @@ -4,8 +4,9 @@ import logging import os +from contextlib import nullcontext from dataclasses import dataclass -from typing import Iterable, Optional, Tuple +from typing import ContextManager, Iterable, Optional, Tuple import torch @@ -22,6 +23,19 @@ class DeviceSelection: device_info: str # Human-readable description amp_dtype: torch.dtype # Suggested autocast dtype (bfloat16 on accel) + def autocast_context(self, enabled: bool = True) -> ContextManager[None]: + """ + Return a context manager for automatic mixed precision. + + Args: + enabled: Whether autocast should be enabled. If False, a no-op + context manager is returned. + """ + if not enabled: + return nullcontext() + + return torch.autocast(device_type=self.device_type, dtype=self.amp_dtype) + def _cuda_info(index: int) -> str: name = torch.cuda.get_device_name(index) @@ -34,6 +48,35 @@ def _mps_available() -> bool: return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() +def _parse_device_request(request: str) -> Tuple[Optional[str], Optional[int]]: + """ + Parse a device preference string and optional index. 
+ + Supported formats: + "cuda", "cuda:1", "mps", "cpu" + """ + if request is None: + return None, None + + normalized = request.strip().lower() + if not normalized: + return None, None + + if ":" in normalized: + base, index_str = normalized.split(":", 1) + if base in ("cuda", "mps", "cpu"): + try: + return base, int(index_str) + except ValueError: + return base, None + return None, None + + if normalized in ("cuda", "mps", "cpu"): + return normalized, None + + return None, None + + def get_optimal_device( preferred_order: Iterable[str] = ("cuda", "mps", "cpu"), device_index: int = 0, @@ -47,7 +90,8 @@ def get_optimal_device( Priority is defined by `preferred_order` (default: CUDA > MPS > CPU). You can force a specific device via `force="cuda"|"mps"|"cpu"` or the - environment variable FORCE_DEVICE with the same values. + environment variable FORCE_DEVICE with the same values. Suffix an index + (e.g. "cuda:1") to target a specific accelerator. Returns: DeviceSelection(device, device_type, device_info, amp_dtype) @@ -56,24 +100,57 @@ def get_optimal_device( - All devices (CUDA, MPS, CPU) use torch.bfloat16 for autocast. - bfloat16-compatible hardware is assumed (2025 AD standard). - If `enable_tf32` is provided, it toggles TF32 on CUDA (Ampere+). + - Invalid entries in `preferred_order` are ignored with a debug log. """ log = logger or logging.getLogger(__name__) - choice = (force or os.getenv("FORCE_DEVICE", "")).lower().strip() - - # Normalize and validate preference list - normalized_order: Tuple[str, ...] = ( - tuple( - d - for d in (choice,) + tuple(preferred_order) - if d in ("cuda", "mps", "cpu") and d + force_raw = force or os.getenv("FORCE_DEVICE", "") + forced_type, forced_index = _parse_device_request(force_raw) + + if forced_index is not None: + device_index = forced_index + log.info("Device override requested: %s:%s", forced_type, device_index) + + if forced_type is None and force_raw: + log.warning( + "FORCE_DEVICE=%s ignored; supported values are 'cuda', 'mps', or 'cpu'.", + force_raw, ) - if choice - else tuple(d for d in preferred_order if d in ("cuda", "mps", "cpu")) + + # Normalize and validate preference list (deduplicate while preserving order) + candidate_order = [] + if forced_type: + candidate_order.append(forced_type) + + for entry in preferred_order: + parsed_type, _ = _parse_device_request(str(entry)) + if parsed_type: + candidate_order.append(parsed_type) + else: + log.debug("Ignoring unknown device preference entry: %s", entry) + + if not candidate_order: + candidate_order = ["cuda", "mps", "cpu"] + + seen = set() + normalized_order: Tuple[str, ...] = tuple( + dev for dev in candidate_order if not (dev in seen or seen.add(dev)) ) # Try each device in order for dev_type in normalized_order: if dev_type == "cuda" and torch.cuda.is_available(): + device_count = torch.cuda.device_count() + if device_index < 0 or device_index >= device_count: + available_range = ( + f"0-{device_count - 1}" if device_count else "no devices found" + ) + log.warning( + "Requested CUDA device index %s is out of range (available: %s).", + device_index, + available_range, + ) + continue + # Optional CUDA knobs if enable_tf32 is not None: try: @@ -108,8 +185,15 @@ def get_optimal_device( return DeviceSelection(device, "cpu", info, amp_dtype) # Absolute fallback (shouldn't be reached) + if normalized_order: + log.warning( + "No preferred devices were available (requested order: %s). 
Falling back to CPU.", + ", ".join(normalized_order), + ) + else: + log.warning("No device preferences provided. Falling back to CPU.") + info = "CPU" device = torch.device("cpu") amp_dtype = torch.bfloat16 - log.warning("Falling back to CPU; no preferred devices were available.") return DeviceSelection(device, "cpu", info, amp_dtype) diff --git a/train.py b/train.py index c515dc1..064b7cd 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,6 @@ import argparse import gzip import json -from contextlib import nullcontext from pathlib import Path from typing import Optional @@ -72,14 +71,10 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): print(f"Device: {device_sel.device_info}") # Setup autocast context - use_autocast = config.get("use_autocast", True) + use_autocast = bool(config.get("use_autocast", True)) if use_autocast: - autocast_ctx = lambda: torch.autocast( - device_type=device_sel.device_type, dtype=device_sel.amp_dtype - ) print(f"Mixed precision: enabled ({device_sel.amp_dtype})") else: - autocast_ctx = lambda: nullcontext() print("Mixed precision: disabled (full fp32)") if config.get("seed"): @@ -167,7 +162,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): inputs = data[:, :-1] targets = data[:, 1:] - with autocast_ctx(): + with device_sel.autocast_context(enabled=use_autocast): logits = model(inputs, return_loss=False) # Compute unnormalized loss loss_unreduced = torch.nn.functional.cross_entropy( @@ -211,7 +206,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): data = next(val_loader).to(device) inputs = data[:, :-1] targets = data[:, 1:] - with torch.no_grad(): + with torch.no_grad(), device_sel.autocast_context(enabled=use_autocast): logits = model(inputs, return_loss=False) loss_unreduced = torch.nn.functional.cross_entropy( logits.reshape(-1, logits.size(-1)), From 27594bdb4c2cb0b0f494f894ff4085c275fc7b6d Mon Sep 17 00:00:00 2001 From: Peter Szemraj <74869040+pszemraj@users.noreply.github.com> Date: Sat, 25 Oct 2025 13:35:26 -0400 Subject: [PATCH 5/8] clarify, format --- decoder_pytorch/device.py | 1 - decoder_pytorch/utils.py | 71 +++++++++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/decoder_pytorch/device.py b/decoder_pytorch/device.py index 1af7865..d7647cb 100644 --- a/decoder_pytorch/device.py +++ b/decoder_pytorch/device.py @@ -10,7 +10,6 @@ import torch - __all__ = ["DeviceSelection", "get_optimal_device"] diff --git a/decoder_pytorch/utils.py b/decoder_pytorch/utils.py index be6a821..4a9c81c 100644 --- a/decoder_pytorch/utils.py +++ b/decoder_pytorch/utils.py @@ -151,13 +151,19 @@ class _LayerSummary: def model_summary( - model: nn.Module, max_depth: int = 4, show_param_shapes: bool = False + model: nn.Module, + max_depth: int = 4, + show_param_shapes: bool = False, + show_frozen_breakdown: bool = False, ) -> None: """Print hierarchical summary of model with parameter counts. :param model: PyTorch model to summarize :param max_depth: Maximum depth of hierarchy to display :param show_param_shapes: Whether to show parameter shapes + :param show_frozen_breakdown: If True, display separate trainable/frozen counts + per module. Defaults to False for a simpler view that highlights whether + a module is fully trainable, fully frozen, or mixed. 
""" # ---------- formatting helpers ---------- @@ -248,15 +254,60 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]: max(len(_format_shape(s.param_shape)) for s in summary_list), ) - params_col_width = 12 - trainable_col_width = 10 - col_spacing = " " + params_col_width = max( + len("Param #"), + max(len(_format_number(s.inclusive_total_params)) for s in summary_list), + ) header_parts = [f"{'Layer (type)':<{name_col_width}}"] if show_param_shapes: header_parts.append(f"{'Param Shape':>{shape_col_width}}") + header_parts.append(f"{'Param #':>{params_col_width}}") - header_parts.append(f"{'Trainable':>{trainable_col_width}}") + + if show_frozen_breakdown: + trainable_col_width = max( + len("Trainable #"), + max( + len(_format_number(s.inclusive_trainable_params)) for s in summary_list + ), + ) + frozen_col_width = max( + len("Frozen #"), + max( + len( + _format_number( + s.inclusive_total_params - s.inclusive_trainable_params + ) + ) + for s in summary_list + ), + ) + header_parts.append(f"{'Trainable #':>{trainable_col_width}}") + header_parts.append(f"{'Frozen #':>{frozen_col_width}}") + else: + + def _grad_state(total: int, trainable: int) -> str: + if trainable == 0: + return "frozen" + if trainable == total: + return "trainable" + return "mixed" + + grad_states = [ + _grad_state( + s.inclusive_total_params, + s.inclusive_trainable_params, + ) + for s in summary_list + ] + grad_state_width = max( + len("Grad State"), max(len(state) for state in grad_states) + ) + header_parts.append(f"{'Grad State':>{grad_state_width}}") + + col_spacing = " " + header = col_spacing.join(header_parts) sep = "=" * len(header) @@ -268,7 +319,15 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]: if show_param_shapes: parts.append(f"{_format_shape(e.param_shape):>{shape_col_width}}") parts.append(f"{_format_number(e.inclusive_total_params):>{params_col_width}}") - parts.append(f"{str(e.inclusive_trainable_params > 0):>{trainable_col_width}}") + if show_frozen_breakdown: + parts.append( + f"{_format_number(e.inclusive_trainable_params):>{trainable_col_width}}" + ) + frozen = e.inclusive_total_params - e.inclusive_trainable_params + parts.append(f"{_format_number(frozen):>{frozen_col_width}}") + else: + state = _grad_state(e.inclusive_total_params, e.inclusive_trainable_params) + parts.append(f"{state:>{grad_state_width}}") print(col_spacing.join(parts)) print(sep) print(f"Total params: {_format_number(total_params)}") From 3f35f0fc300ef097acb920613ec3858a76fe9ec2 Mon Sep 17 00:00:00 2001 From: Peter Szemraj <74869040+pszemraj@users.noreply.github.com> Date: Thu, 30 Oct 2025 22:22:26 -0400 Subject: [PATCH 6/8] Simplify device handling with tuple API and improve training stability - Replace device.py with lightweight tuple-based API and auto-fallbacks - Centralize device checks in training; guard autocast and document grad quirks - Graceful Flash Attention degradation when kernels unavailable - Add nano.yaml config for quick CPU/MPS testing - Update docs to reflect new device API and config --- README.md | 25 +++-- decoder_pytorch/__init__.py | 6 +- decoder_pytorch/device.py | 198 ------------------------------------ decoder_pytorch/llama.py | 19 ++++ decoder_pytorch/utils.py | 102 ++++++++++++++----- train.py | 55 +++++++--- 6 files changed, 156 insertions(+), 249 deletions(-) delete mode 100644 decoder_pytorch/device.py diff --git a/README.md b/README.md index d49a850..2e99aee 100644 --- a/README.md +++ b/README.md @@ -20,16 +20,26 @@ Train on 
the [included enwik8 dataset](data/README.md), character-level modeling # 100k batches on enwik8, 35M param Llama python train.py --config configs/simple.yaml -# Quick test (tiny model, 10 batches) +# Nano run for CPU / MPS quick checks (~30 seconds) +python train.py --config configs/nano.yaml + +# Quick smoke test (tiny model, 10 batches) python train.py --config configs/test.yaml ``` ## Device Selection & Precision -- The training script calls `decoder_pytorch.get_optimal_device()` which prefers `cuda -> mps -> cpu` and prints a human-readable summary of the accelerator in use. +- The training script calls `decoder_pytorch.get_optimal_device()` which prefers `cuda β†’ mps β†’ cpu`, returning `(device, device_type, amp_dtype)` and printing the accelerator picked. - Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument). -- The returned `DeviceSelection` exposes `.autocast_context(enabled=...)` so you can opt into bfloat16 autocast wherever it makes sense; configs set `use_autocast: true` by default. -- CUDA runs enable TF32 when requested (Ampere or newer); invalid device preferences are ignored with a debug log, keeping things easy to reason about. +- Mixed precision uses `torch.autocast` with `torch.bfloat16` when it provides a benefit; CPU and unstable MPS paths disable autocast automatically. + +## Device Support + +| Device | Status | Notes | +|---------------|--------|----------------------------------------------------| +| NVIDIA GPU | βœ… | Best performance, fused optimizer & flash attention | +| Apple Silicon | βœ… | Good performance, autocast can be flaky | +| CPU | βœ… | Slow but works; use `configs/nano.yaml` | ## Structure @@ -37,9 +47,10 @@ python train.py --config configs/test.yaml decoder-pytorch-template/ β”œβ”€β”€ decoder_pytorch/ # Model implementation β”‚ β”œβ”€β”€ llama.py # Llama architecture -β”‚ └── utils.py # Sampling utilities +β”‚ └── utils.py # Sampling & device helpers β”œβ”€β”€ configs/ # Training configs β”‚ β”œβ”€β”€ simple.yaml # Default config +β”‚ β”œβ”€β”€ nano.yaml # Quick CPU/MPS config β”‚ └── test.yaml # Quick test config β”œβ”€β”€ data/ β”‚ └── enwik8.gz # Character-level dataset @@ -68,7 +79,7 @@ To add your own model architecture: 4. **Update training script**: Modify `train.py` line 16 and 88: ```python - from decoder_pytorch import YourModel, configure_tf32, model_summary + from decoder_pytorch import YourModel, model_summary # ... model = YourModel( num_tokens=config.get("num_tokens", 256), @@ -120,7 +131,7 @@ Dependencies: - einops, pyyaml, tqdm - [rotary-embedding-torch](https://github.com/lucidrains/rotary-embedding-torch) -[^2]: If using PyTorch <2.9, you'll need to modify the TF32 configuration in `decoder_pytorch/utils.py` to use the legacy API (`torch.set_float32_matmul_precision("high")`) or skip TF32 setup entirely. +[^2]: If using PyTorch <2.9, you may need to adjust the bfloat16/autocast behaviour or fall back to full fp32 depending on hardware support. 
## License diff --git a/decoder_pytorch/__init__.py b/decoder_pytorch/__init__.py index 72c3b46..71a7f68 100644 --- a/decoder_pytorch/__init__.py +++ b/decoder_pytorch/__init__.py @@ -1,9 +1,8 @@ """Llama-style transformer for language modeling experiments.""" -from .device import DeviceSelection, get_optimal_device from .llama import Llama from .utils import ( - configure_tf32, + get_optimal_device, gumbel_noise, gumbel_sample, log, @@ -23,9 +22,6 @@ "top_k_filter", "top_p_filter", # Torch utilities - "configure_tf32", "model_summary", - # Device utilities - "DeviceSelection", "get_optimal_device", ] diff --git a/decoder_pytorch/device.py b/decoder_pytorch/device.py deleted file mode 100644 index d7647cb..0000000 --- a/decoder_pytorch/device.py +++ /dev/null @@ -1,198 +0,0 @@ -# decoder_pytorch/device.py - -from __future__ import annotations - -import logging -import os -from contextlib import nullcontext -from dataclasses import dataclass -from typing import ContextManager, Iterable, Optional, Tuple - -import torch - -__all__ = ["DeviceSelection", "get_optimal_device"] - - -@dataclass(frozen=True) -class DeviceSelection: - """Result of device selection.""" - - device: torch.device # For .to() / model placement - device_type: str # For torch.autocast(device_type=...) - device_info: str # Human-readable description - amp_dtype: torch.dtype # Suggested autocast dtype (bfloat16 on accel) - - def autocast_context(self, enabled: bool = True) -> ContextManager[None]: - """ - Return a context manager for automatic mixed precision. - - Args: - enabled: Whether autocast should be enabled. If False, a no-op - context manager is returned. - """ - if not enabled: - return nullcontext() - - return torch.autocast(device_type=self.device_type, dtype=self.amp_dtype) - - -def _cuda_info(index: int) -> str: - name = torch.cuda.get_device_name(index) - major, minor = torch.cuda.get_device_capability(index) - return f"CUDA GPU: {name} (compute {major}.{minor})" - - -def _mps_available() -> bool: - # torch.backends.mps is present only on macOS builds - return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() - - -def _parse_device_request(request: str) -> Tuple[Optional[str], Optional[int]]: - """ - Parse a device preference string and optional index. - - Supported formats: - "cuda", "cuda:1", "mps", "cpu" - """ - if request is None: - return None, None - - normalized = request.strip().lower() - if not normalized: - return None, None - - if ":" in normalized: - base, index_str = normalized.split(":", 1) - if base in ("cuda", "mps", "cpu"): - try: - return base, int(index_str) - except ValueError: - return base, None - return None, None - - if normalized in ("cuda", "mps", "cpu"): - return normalized, None - - return None, None - - -def get_optimal_device( - preferred_order: Iterable[str] = ("cuda", "mps", "cpu"), - device_index: int = 0, - *, - force: Optional[str] = None, - enable_tf32: Optional[bool] = None, - logger: Optional[logging.Logger] = None, -) -> DeviceSelection: - """ - Detect and return the best available device for PyTorch. - - Priority is defined by `preferred_order` (default: CUDA > MPS > CPU). - You can force a specific device via `force="cuda"|"mps"|"cpu"` or the - environment variable FORCE_DEVICE with the same values. Suffix an index - (e.g. "cuda:1") to target a specific accelerator. - - Returns: - DeviceSelection(device, device_type, device_info, amp_dtype) - - Notes: - - All devices (CUDA, MPS, CPU) use torch.bfloat16 for autocast. 
- - bfloat16-compatible hardware is assumed (2025 AD standard). - - If `enable_tf32` is provided, it toggles TF32 on CUDA (Ampere+). - - Invalid entries in `preferred_order` are ignored with a debug log. - """ - log = logger or logging.getLogger(__name__) - force_raw = force or os.getenv("FORCE_DEVICE", "") - forced_type, forced_index = _parse_device_request(force_raw) - - if forced_index is not None: - device_index = forced_index - log.info("Device override requested: %s:%s", forced_type, device_index) - - if forced_type is None and force_raw: - log.warning( - "FORCE_DEVICE=%s ignored; supported values are 'cuda', 'mps', or 'cpu'.", - force_raw, - ) - - # Normalize and validate preference list (deduplicate while preserving order) - candidate_order = [] - if forced_type: - candidate_order.append(forced_type) - - for entry in preferred_order: - parsed_type, _ = _parse_device_request(str(entry)) - if parsed_type: - candidate_order.append(parsed_type) - else: - log.debug("Ignoring unknown device preference entry: %s", entry) - - if not candidate_order: - candidate_order = ["cuda", "mps", "cpu"] - - seen = set() - normalized_order: Tuple[str, ...] = tuple( - dev for dev in candidate_order if not (dev in seen or seen.add(dev)) - ) - - # Try each device in order - for dev_type in normalized_order: - if dev_type == "cuda" and torch.cuda.is_available(): - device_count = torch.cuda.device_count() - if device_index < 0 or device_index >= device_count: - available_range = ( - f"0-{device_count - 1}" if device_count else "no devices found" - ) - log.warning( - "Requested CUDA device index %s is out of range (available: %s).", - device_index, - available_range, - ) - continue - - # Optional CUDA knobs - if enable_tf32 is not None: - try: - torch.backends.cuda.matmul.allow_tf32 = bool(enable_tf32) - torch.backends.cudnn.allow_tf32 = bool(enable_tf32) - except Exception: - # TF32 toggles aren't critical-ignore if unavailable - pass - - try: - info = _cuda_info(device_index) - except Exception as e: - info = f"CUDA GPU (info unavailable: {e})" - - device = torch.device("cuda", index=device_index) - amp_dtype = torch.bfloat16 - log.info("Using %s", info) - return DeviceSelection(device, "cuda", info, amp_dtype) - - if dev_type == "mps" and _mps_available(): - info = "Apple Silicon (MPS)" - device = torch.device("mps") - amp_dtype = torch.bfloat16 - log.info("Using %s", info) - return DeviceSelection(device, "mps", info, amp_dtype) - - if dev_type == "cpu": - info = "CPU" - device = torch.device("cpu") - amp_dtype = torch.bfloat16 - log.info("Using %s (no GPU acceleration available)", info) - return DeviceSelection(device, "cpu", info, amp_dtype) - - # Absolute fallback (shouldn't be reached) - if normalized_order: - log.warning( - "No preferred devices were available (requested order: %s). Falling back to CPU.", - ", ".join(normalized_order), - ) - else: - log.warning("No device preferences provided. Falling back to CPU.") - - info = "CPU" - device = torch.device("cpu") - amp_dtype = torch.bfloat16 - return DeviceSelection(device, "cpu", info, amp_dtype) diff --git a/decoder_pytorch/llama.py b/decoder_pytorch/llama.py index dfb9d25..a8ffa32 100644 --- a/decoder_pytorch/llama.py +++ b/decoder_pytorch/llama.py @@ -140,6 +140,11 @@ def __init__( self.dim_head = dim_head self.causal = causal self.flash_attn = flash_attn + if self.flash_attn and not self._flash_attn_available(): + print( + "Warning: Flash attention requested but not available, using standard attention." 
+ ) + self.flash_attn = False self.max_seq_len = max_seq_len inner_dim = heads * dim_head @@ -158,6 +163,20 @@ def __init__( # Register causal mask buffer (will be created on first use) self.register_buffer("causal_mask", None, persistent=False) + def _flash_attn_available(self) -> bool: + """Return True if flash attention kernels are available.""" + if torch.cuda.is_available(): + return True + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return True + dynamo = getattr(torch, "_dynamo", None) + if dynamo is not None: + try: + return bool(dynamo.is_compiling()) + except Exception: + return False + return False + def forward( self, x: torch.Tensor, diff --git a/decoder_pytorch/utils.py b/decoder_pytorch/utils.py index 4a9c81c..a44511a 100644 --- a/decoder_pytorch/utils.py +++ b/decoder_pytorch/utils.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple @@ -108,36 +109,85 @@ def top_p_filter(logits: Tensor, p: float = 0.9) -> Tensor: # -------------------------------------------------------------------------- -def configure_tf32() -> bool: - """Enable TF32 precision for GPUs with compute capability >= 8.0 (Ampere+). +def _mps_available() -> bool: + """Return True if MPS is available.""" + return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() - Uses the PyTorch 2.9+ API for TF32 configuration. - :return: True if TF32 was enabled, False otherwise +def get_optimal_device( + force: Optional[str] = None, +) -> Tuple[torch.device, str, torch.dtype]: + """Return best available accelerator (device, device_type, amp_dtype). + + The function tries CUDA β†’ MPS β†’ CPU, unless the user forces a choice via + the ``force`` argument or the ``FORCE_DEVICE`` environment variable. The + return value is intentionally simpleβ€”a tuple that works well with tuple + unpacking in training scripts. """ - if not torch.cuda.is_available(): - print("No GPU detected, running on CPU.") - return False - - try: - device = torch.cuda.current_device() - capability = torch.cuda.get_device_capability(device) - major, minor = capability - gpu_name = torch.cuda.get_device_name(device) - - if major >= 8: - # PyTorch 2.9+ API for TF32 configuration - torch.backends.cudnn.conv.fp32_precision = "tf32" - torch.backends.cuda.matmul.fp32_precision = "tf32" - print(f"{gpu_name} (compute {major}.{minor}) - TF32 enabled") - return True - else: - print(f"{gpu_name} (compute {major}.{minor}) - TF32 not supported") - return False - except Exception as e: - print(f"Error: failed to configure GPU: {e}") - return False + def _normalize(device_str: str) -> str: + return device_str.split(":", 1)[0] + + requested = (force or os.getenv("FORCE_DEVICE", "")).strip().lower() + valid_types = {"cuda", "mps", "cpu"} + + if requested: + requested_type = _normalize(requested) + if requested_type not in valid_types: + print( + f"Warning: unsupported FORCE_DEVICE='{requested}'. " + "Falling back to auto-detect." 
+ ) + requested = "" + elif requested_type == "cuda" and not torch.cuda.is_available(): + print("Warning: CUDA requested but not available, falling back.") + requested = "" + elif requested_type == "mps" and not _mps_available(): + print("Warning: MPS requested but not available, falling back.") + requested = "" + + if requested: + try: + device = torch.device(requested) + except (RuntimeError, ValueError) as err: + print(f"Warning: could not create device '{requested}' ({err}).") + requested = "" + else: + device_type = _normalize(requested) + if device_type == "cuda": + index = device.index or 0 + device_count = torch.cuda.device_count() + if index >= device_count: + print( + f"Warning: CUDA index {index} unavailable " + f"(found {device_count} device(s)). Falling back." + ) + requested = "" + else: + name = torch.cuda.get_device_name(index) + print(f"Using CUDA device {index}: {name}") + return device, "cuda", torch.bfloat16 + elif device_type == "mps": + print("Using Apple Silicon (MPS)") + return device, "mps", torch.bfloat16 + else: + print("Using CPU (forced)") + return device, "cpu", torch.bfloat16 + + if torch.cuda.is_available(): + device = torch.device("cuda") + name = torch.cuda.get_device_name(0) + print(f"Using CUDA: {name}") + return device, "cuda", torch.bfloat16 + + if _mps_available(): + device = torch.device("mps") + print("Using Apple Silicon (MPS)") + return device, "mps", torch.bfloat16 + + device = torch.device("cpu") + print("Using CPU (no GPU acceleration available)") + return device, "cpu", torch.bfloat16 @dataclass diff --git a/train.py b/train.py index 064b7cd..18fd0db 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,7 @@ import argparse import gzip import json +from contextlib import nullcontext from pathlib import Path from typing import Optional @@ -66,16 +67,34 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): config = yaml.safe_load(f) # Setup device - device_sel = get_optimal_device(enable_tf32=True) # Auto-detects cuda/mps/cpu - device = device_sel.device - print(f"Device: {device_sel.device_info}") + device, device_type, amp_dtype = get_optimal_device() + print(f"Device: {device}") + + device_caps = { + "supports_fused_optimizer": device_type == "cuda", + "supports_flash_attn": device_type in ("cuda", "mps"), + "stable_autocast": device_type != "mps", + } # Setup autocast context use_autocast = bool(config.get("use_autocast", True)) - if use_autocast: - print(f"Mixed precision: enabled ({device_sel.amp_dtype})") - else: - print("Mixed precision: disabled (full fp32)") + if use_autocast and not device_caps["stable_autocast"]: + print("Warning: autocast can be unstable on MPS; disabling mixed precision.") + use_autocast = False + + if use_autocast and device_type == "cpu": + print("Warning: autocast has no benefit on CPU; disabling mixed precision.") + use_autocast = False + + def autocast_context(): + if use_autocast and device_type != "cpu": + return torch.autocast(device_type=device_type, dtype=amp_dtype) + return nullcontext() + + print( + f"Mixed precision: {'enabled' if use_autocast else 'disabled'}" + f"{f' ({amp_dtype})' if use_autocast else ' (full fp32)'}" + ) if config.get("seed"): torch.manual_seed(config["seed"]) @@ -94,6 +113,15 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): val_loader = cycle(val_loader) # Create model + flash_attn_requested = config.get("flash_attn") + if flash_attn_requested is None: + flash_attn_requested = device_caps["supports_flash_attn"] + elif flash_attn_requested and 
not device_caps["supports_flash_attn"]: + print( + "Warning: flash attention requested but not supported on this device; using standard attention." + ) + flash_attn_requested = False + model = Llama( num_tokens=config.get("num_tokens", 256), dim=config.get("dim", 512), @@ -102,19 +130,18 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): dim_head=config.get("dim_head", 64), tied_embedding=config.get("tied_embedding", True), ffn_dim_multiplier=config.get("ffn_dim_multiplier"), - flash_attn=config.get("flash_attn", True), + flash_attn=bool(flash_attn_requested), ).to(device) model_summary(model, max_depth=3, show_param_shapes=True) # Optimizer # Fused optimizer is available for CUDA only (not MPS or CPU) - use_fused = device_sel.device_type == "cuda" optimizer = Adam( model.parameters(), lr=config.get("learning_rate", 1e-3), weight_decay=config.get("weight_decay", 0.0), - fused=use_fused, + fused=device_caps["supports_fused_optimizer"], ) # Training state @@ -151,7 +178,9 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): model.train() # Training step with gradient accumulation - # Accumulate raw losses and token counts for correct normalization + # Accumulate raw losses and token counts for correct normalization. + # Critical pitfall: normalizing each micro-batch separately leads to + # incorrect gradients. Sum first, divide once at the end. total_loss_sum = 0 total_tokens = 0 loss_accumulator = None # Keep gradient graph alive @@ -162,7 +191,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): inputs = data[:, :-1] targets = data[:, 1:] - with device_sel.autocast_context(enabled=use_autocast): + with autocast_context(): logits = model(inputs, return_loss=False) # Compute unnormalized loss loss_unreduced = torch.nn.functional.cross_entropy( @@ -206,7 +235,7 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): data = next(val_loader).to(device) inputs = data[:, :-1] targets = data[:, 1:] - with torch.no_grad(), device_sel.autocast_context(enabled=use_autocast): + with torch.no_grad(), autocast_context(): logits = model(inputs, return_loss=False) loss_unreduced = torch.nn.functional.cross_entropy( logits.reshape(-1, logits.size(-1)), From f05282ae0044da0864ec7980de4cd118ab6e1a13 Mon Sep 17 00:00:00 2001 From: Peter Szemraj <74869040+pszemraj@users.noreply.github.com> Date: Thu, 30 Oct 2025 22:48:23 -0400 Subject: [PATCH 7/8] Enforce autocast and fix nano config alignment - Stop silently disabling autocast; always respect use_autocast flag - Wrap autocast context manager on all devices (no silent fp32 fallback) - Align nano.yaml to ~20M Llama with bf16 autocast enabled - Clarify autocast behavior in README --- README.md | 14 +++++++------- configs/nano.yaml | 36 ++++++++++++++++++++++++++++++++++++ configs/test.yaml | 3 +-- train.py | 10 +--------- 4 files changed, 45 insertions(+), 18 deletions(-) create mode 100644 configs/nano.yaml diff --git a/README.md b/README.md index 2e99aee..742705f 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Train on the [included enwik8 dataset](data/README.md), character-level modeling # 100k batches on enwik8, 35M param Llama python train.py --config configs/simple.yaml -# Nano run for CPU / MPS quick checks (~30 seconds) +# Nano run for CPU / MPS shakedowns (10k steps, ~20M params) python train.py --config configs/nano.yaml # Quick smoke test (tiny model, 10 batches) @@ -31,15 +31,15 @@ python train.py --config configs/test.yaml - The training script calls 
`decoder_pytorch.get_optimal_device()` which prefers `cuda β†’ mps β†’ cpu`, returning `(device, device_type, amp_dtype)` and printing the accelerator picked. - Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument). -- Mixed precision uses `torch.autocast` with `torch.bfloat16` when it provides a benefit; CPU and unstable MPS paths disable autocast automatically. +- Mixed precision uses `torch.autocast` with `torch.bfloat16`; toggle via config if you want full fp32. ## Device Support -| Device | Status | Notes | -|---------------|--------|----------------------------------------------------| -| NVIDIA GPU | βœ… | Best performance, fused optimizer & flash attention | -| Apple Silicon | βœ… | Good performance, autocast can be flaky | -| CPU | βœ… | Slow but works; use `configs/nano.yaml` | +| Device | Status | Notes | +| ------------- | ------ | --------------------------------------------------- | +| NVIDIA GPU | βœ… | Best performance, fused optimizer & flash attention | +| Apple Silicon | βœ… | Good performance, autocast can be flaky | +| CPU | βœ… | Slow but works; use `configs/nano.yaml` | ## Structure diff --git a/configs/nano.yaml b/configs/nano.yaml new file mode 100644 index 0000000..d2d373f --- /dev/null +++ b/configs/nano.yaml @@ -0,0 +1,36 @@ +# Smaller than simple, larger than test +run_dir: runs/nano + +# Model +num_tokens: 256 +dim: 256 +depth: 8 +heads: 8 +dim_head: 32 +tied_embedding: true +ffn_dim_multiplier: 1.5 +flash_attn: true +compile: false # true to compile model (faster) +use_autocast: true + +# Training schedule +num_batches: 10000 +batch_size: 1 +grad_accum_every: 16 +learning_rate: 0.002 +weight_decay: 0.0003 +grad_clip_norm: 1.0 + +# Data +data_path: data/enwik8.gz +seq_len: 512 + +# training/validation/generation +validate_every: 250 +val_batches: 20 +generate_every: 250 +save_every: 2000 +temperature: 1.0 +min_p: 0.1 + +seed: 7 diff --git a/configs/test.yaml b/configs/test.yaml index c8ceb1d..6e0f54e 100644 --- a/configs/test.yaml +++ b/configs/test.yaml @@ -29,5 +29,4 @@ val_batches: 50 generate_every: 1000 # Don't generate during test save_every: 1000 # Don't save during test -# Random seed for reproducibility -seed: 42 \ No newline at end of file +seed: 7 \ No newline at end of file diff --git a/train.py b/train.py index 18fd0db..3381183 100644 --- a/train.py +++ b/train.py @@ -73,21 +73,13 @@ def train(config_path: str, resume_checkpoint: Optional[str] = None): device_caps = { "supports_fused_optimizer": device_type == "cuda", "supports_flash_attn": device_type in ("cuda", "mps"), - "stable_autocast": device_type != "mps", } # Setup autocast context use_autocast = bool(config.get("use_autocast", True)) - if use_autocast and not device_caps["stable_autocast"]: - print("Warning: autocast can be unstable on MPS; disabling mixed precision.") - use_autocast = False - - if use_autocast and device_type == "cpu": - print("Warning: autocast has no benefit on CPU; disabling mixed precision.") - use_autocast = False def autocast_context(): - if use_autocast and device_type != "cpu": + if use_autocast: return torch.autocast(device_type=device_type, dtype=amp_dtype) return nullcontext() From 13ab4ffec3a2b0b90a6ba423ddc7caa4c4d3ec24 Mon Sep 17 00:00:00 2001 From: Peter Szemraj <74869040+pszemraj@users.noreply.github.com> Date: Thu, 30 Oct 2025 22:57:37 -0400 Subject: [PATCH 8/8] =?UTF-8?q?Set=20nano=20preset=20to=20L6=C2=B7H384=20(?= 
=?UTF-8?q?~9M)=20and=20update=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update nano.yaml: depth 6, dim 384, torch.compile on - Clarify model scale in README --- README.md | 2 +- configs/nano.yaml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 742705f..522ef7c 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Train on the [included enwik8 dataset](data/README.md), character-level modeling # 100k batches on enwik8, 35M param Llama python train.py --config configs/simple.yaml -# Nano run for CPU / MPS shakedowns (10k steps, ~20M params) +# Nano run for CPU / MPS shakedowns (10k steps, L6 Β· H384 Β· ~9M params) python train.py --config configs/nano.yaml # Quick smoke test (tiny model, 10 batches) diff --git a/configs/nano.yaml b/configs/nano.yaml index d2d373f..10e1e05 100644 --- a/configs/nano.yaml +++ b/configs/nano.yaml @@ -1,16 +1,16 @@ -# Smaller than simple, larger than test +# Smaller than simple, larger than test (L6 Β· H384 Β· ~9M params) run_dir: runs/nano # Model num_tokens: 256 -dim: 256 -depth: 8 -heads: 8 -dim_head: 32 +dim: 384 +depth: 6 +heads: 6 +dim_head: 64 tied_embedding: true ffn_dim_multiplier: 1.5 flash_attn: true -compile: false # true to compile model (faster) +compile: true # turn off if torch.compile is unavailable use_autocast: true # Training schedule
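
The patches above converge on a simple tuple-based device API (`get_optimal_device()` returning `(device, device_type, amp_dtype)`) plus a config-driven autocast toggle in `train.py`. Below is a minimal usage sketch of that final API, separate from the patches themselves. It assumes this repo's `decoder_pytorch` package is importable and a reasonably recent PyTorch with autocast support on the selected backend; the small `nn.Linear` is a stand-in for the real Llama model and the tensor shapes are placeholders chosen for illustration.

```python
# Illustrative sketch of the tuple-based device API from patches 6-8 (not part of the diffs).
# Assumes `decoder_pytorch` (this repo) is on the path; nn.Linear stands in for the Llama model.
from contextlib import nullcontext

import torch

from decoder_pytorch import get_optimal_device

# Auto-detects cuda -> mps -> cpu; can be overridden at launch, e.g.
#   FORCE_DEVICE=cuda:1 python my_script.py
device, device_type, amp_dtype = get_optimal_device()

model = torch.nn.Linear(384, 256).to(device)  # placeholder model
use_autocast = True  # mirrors the `use_autocast` config flag in train.py

# Wrap the forward pass exactly as train.py does: bfloat16 autocast when
# enabled, otherwise a no-op context (full fp32).
autocast_ctx = (
    torch.autocast(device_type=device_type, dtype=amp_dtype)
    if use_autocast
    else nullcontext()
)

x = torch.randn(8, 384, device=device)
with autocast_ctx:
    y = model(x)

print(device, device_type, amp_dtype, y.dtype)
```

Under autocast the matmul output comes back in `bfloat16` regardless of backend, which is why the training loop accumulates the unnormalized loss and divides once at the end rather than normalizing each micro-batch separately.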