26 changes: 22 additions & 4 deletions README.md
@@ -20,19 +20,37 @@ Train on the [included enwik8 dataset](data/README.md), character-level modeling
# 100k batches on enwik8, 35M param Llama
python train.py --config configs/simple.yaml

# Quick test (tiny model, 10 batches)
# Nano run for CPU / MPS shakedowns (10k steps, L6 · H384 · ~9M params)
python train.py --config configs/nano.yaml

# Quick smoke test (tiny model, 10 batches)
python train.py --config configs/test.yaml
```

## Device Selection & Precision

- The training script calls `decoder_pytorch.get_optimal_device()`, which prefers `cuda → mps → cpu`, returns `(device, device_type, amp_dtype)`, and prints the accelerator it picked.
- Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument).
- Mixed precision uses `torch.autocast` with `torch.bfloat16`; set `use_autocast: false` in the config if you want full fp32 (see the sketch below).
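
A minimal sketch of how these pieces fit together in a training loop (`model` and `batch` are stand-ins for whatever your script provides):

```python
import torch
from decoder_pytorch import get_optimal_device

# Auto-detects cuda → mps → cpu; FORCE_DEVICE=cuda:1 (or force="cuda:1") overrides it.
device, device_type, amp_dtype = get_optimal_device()

model = model.to(device)  # `model` / `batch` come from your training setup
with torch.autocast(device_type=device_type, dtype=amp_dtype):
    loss = model(batch)
```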

## Device Support

| Device | Status | Notes |
| ------------- | ------ | --------------------------------------------------- |
| NVIDIA GPU | ✅ | Best performance, fused optimizer & flash attention |
| Apple Silicon | ✅ | Good performance, autocast can be flaky |
| CPU | ✅ | Slow but works; use `configs/nano.yaml` |
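
If autocast misbehaves (for example on MPS), you can skip the autocast context entirely. A minimal sketch, assuming the loop reads `use_autocast` from the parsed YAML config and already has `device_type` / `amp_dtype` from `get_optimal_device()`:

```python
import contextlib
import torch

# use_autocast: false in the config runs the forward pass in full fp32.
autocast_ctx = (
    torch.autocast(device_type=device_type, dtype=amp_dtype)
    if config.get("use_autocast", True)
    else contextlib.nullcontext()
)
with autocast_ctx:
    loss = model(batch)
```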

## Structure

```text
decoder-pytorch-template/
├── decoder_pytorch/ # Model implementation
│ ├── llama.py # Llama architecture
│ └── utils.py # Sampling utilities
│ └── utils.py # Sampling & device helpers
├── configs/ # Training configs
│ ├── simple.yaml # Default config
│ ├── nano.yaml # Quick CPU/MPS config
│ └── test.yaml # Quick test config
├── data/
│ └── enwik8.gz # Character-level dataset
@@ -61,7 +79,7 @@ To add your own model architecture:
4. **Update training script**: Modify `train.py` lines 16 and 88:

```python
from decoder_pytorch import YourModel, configure_tf32, model_summary
from decoder_pytorch import YourModel, model_summary
# ...
model = YourModel(
num_tokens=config.get("num_tokens", 256),
@@ -113,7 +131,7 @@ Dependencies:
- einops, pyyaml, tqdm
- [rotary-embedding-torch](https://github.com/lucidrains/rotary-embedding-torch)

[^2]: If using PyTorch <2.9, you'll need to modify the TF32 configuration in `decoder_pytorch/utils.py` to use the legacy API (`torch.set_float32_matmul_precision("high")`) or skip TF32 setup entirely.
[^2]: If using PyTorch <2.9, you may need to adjust the bfloat16/autocast behaviour or fall back to full fp32 depending on hardware support.

## License

36 changes: 36 additions & 0 deletions configs/nano.yaml
@@ -0,0 +1,36 @@
# Smaller than simple, larger than test (L6 · H384 · ~9M params)
run_dir: runs/nano

# Model
num_tokens: 256
dim: 384
depth: 6
heads: 6
dim_head: 64
tied_embedding: true
ffn_dim_multiplier: 1.5
flash_attn: true
compile: true # turn off if torch.compile is unavailable
use_autocast: true

# Training schedule
num_batches: 10000
batch_size: 1
grad_accum_every: 16
learning_rate: 0.002
weight_decay: 0.0003
grad_clip_norm: 1.0

# Data
data_path: data/enwik8.gz
seq_len: 512

# Training / validation / generation
validate_every: 250
val_batches: 20
generate_every: 250
save_every: 2000
temperature: 1.0
min_p: 0.1

seed: 7
1 change: 1 addition & 0 deletions configs/simple.yaml
@@ -12,6 +12,7 @@ tied_embedding: true # share in/out embeddings
ffn_dim_multiplier: 1.5 # hidden dim multiplier --> FFN size (here, 768)
flash_attn: true # use flash attn (through torch api)
compile: false # speed up training
use_autocast: true # enable mixed precision (bfloat16)

# Training
num_batches: 100000 # total steps
4 changes: 2 additions & 2 deletions configs/test.yaml
@@ -10,6 +10,7 @@ dim_head: 32
tied_embedding: true
flash_attn: true
compile: false
use_autocast: true

# Quick training
num_batches: 10
@@ -28,5 +29,4 @@ val_batches: 50
generate_every: 1000 # Don't generate during test
save_every: 1000 # Don't save during test

# Random seed for reproducibility
seed: 42
seed: 7
4 changes: 2 additions & 2 deletions decoder_pytorch/__init__.py
@@ -2,7 +2,7 @@

from .llama import Llama
from .utils import (
configure_tf32,
get_optimal_device,
gumbel_noise,
gumbel_sample,
log,
@@ -22,6 +22,6 @@
"top_k_filter",
"top_p_filter",
# Torch utilities
"configure_tf32",
"model_summary",
"get_optimal_device",
]
19 changes: 19 additions & 0 deletions decoder_pytorch/llama.py
@@ -140,6 +140,11 @@ def __init__(
self.dim_head = dim_head
self.causal = causal
self.flash_attn = flash_attn
if self.flash_attn and not self._flash_attn_available():
print(
"Warning: Flash attention requested but not available, using standard attention."
)
self.flash_attn = False
self.max_seq_len = max_seq_len

inner_dim = heads * dim_head
@@ -158,6 +163,20 @@ def __init__(
# Register causal mask buffer (will be created on first use)
self.register_buffer("causal_mask", None, persistent=False)

def _flash_attn_available(self) -> bool:
"""Return True if flash attention kernels are available."""
if torch.cuda.is_available():
return True
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return True
dynamo = getattr(torch, "_dynamo", None)
if dynamo is not None:
try:
return bool(dynamo.is_compiling())
except Exception:
return False
return False

def forward(
self,
x: torch.Tensor,
173 changes: 141 additions & 32 deletions decoder_pytorch/utils.py
@@ -1,3 +1,4 @@
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

@@ -108,36 +109,85 @@ def top_p_filter(logits: Tensor, p: float = 0.9) -> Tensor:
# --------------------------------------------------------------------------


def configure_tf32() -> bool:
"""Enable TF32 precision for GPUs with compute capability >= 8.0 (Ampere+).
def _mps_available() -> bool:
"""Return True if MPS is available."""
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

Uses the PyTorch 2.9+ API for TF32 configuration.

:return: True if TF32 was enabled, False otherwise
def get_optimal_device(
force: Optional[str] = None,
) -> Tuple[torch.device, str, torch.dtype]:
"""Return best available accelerator (device, device_type, amp_dtype).

The function tries CUDA → MPS → CPU, unless the user forces a choice via
the ``force`` argument or the ``FORCE_DEVICE`` environment variable. The
return value is intentionally simple—a tuple that works well with tuple
unpacking in training scripts.
"""
if not torch.cuda.is_available():
print("No GPU detected, running on CPU.")
return False

try:
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
major, minor = capability
gpu_name = torch.cuda.get_device_name(device)

if major >= 8:
# PyTorch 2.9+ API for TF32 configuration
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cuda.matmul.fp32_precision = "tf32"
print(f"{gpu_name} (compute {major}.{minor}) - TF32 enabled")
return True
else:
print(f"{gpu_name} (compute {major}.{minor}) - TF32 not supported")
return False

except Exception as e:
print(f"Error: failed to configure GPU: {e}")
return False
def _normalize(device_str: str) -> str:
return device_str.split(":", 1)[0]

requested = (force or os.getenv("FORCE_DEVICE", "")).strip().lower()
valid_types = {"cuda", "mps", "cpu"}

if requested:
requested_type = _normalize(requested)
if requested_type not in valid_types:
print(
f"Warning: unsupported FORCE_DEVICE='{requested}'. "
"Falling back to auto-detect."
)
requested = ""
elif requested_type == "cuda" and not torch.cuda.is_available():
print("Warning: CUDA requested but not available, falling back.")
requested = ""
elif requested_type == "mps" and not _mps_available():
print("Warning: MPS requested but not available, falling back.")
requested = ""

if requested:
try:
device = torch.device(requested)
except (RuntimeError, ValueError) as err:
print(f"Warning: could not create device '{requested}' ({err}).")
requested = ""
else:
device_type = _normalize(requested)
if device_type == "cuda":
index = device.index or 0
device_count = torch.cuda.device_count()
if index >= device_count:
print(
f"Warning: CUDA index {index} unavailable "
f"(found {device_count} device(s)). Falling back."
)
requested = ""
else:
name = torch.cuda.get_device_name(index)
print(f"Using CUDA device {index}: {name}")
return device, "cuda", torch.bfloat16
elif device_type == "mps":
print("Using Apple Silicon (MPS)")
return device, "mps", torch.bfloat16
else:
print("Using CPU (forced)")
return device, "cpu", torch.bfloat16

if torch.cuda.is_available():
device = torch.device("cuda")
name = torch.cuda.get_device_name(0)
print(f"Using CUDA: {name}")
return device, "cuda", torch.bfloat16

if _mps_available():
device = torch.device("mps")
print("Using Apple Silicon (MPS)")
return device, "mps", torch.bfloat16

device = torch.device("cpu")
print("Using CPU (no GPU acceleration available)")
return device, "cpu", torch.bfloat16


@dataclass
@@ -151,13 +201,19 @@ class _LayerSummary:


def model_summary(
model: nn.Module, max_depth: int = 4, show_param_shapes: bool = False
model: nn.Module,
max_depth: int = 4,
show_param_shapes: bool = False,
show_frozen_breakdown: bool = False,
) -> None:
"""Print hierarchical summary of model with parameter counts.

:param model: PyTorch model to summarize
:param max_depth: Maximum depth of hierarchy to display
:param show_param_shapes: Whether to show parameter shapes
:param show_frozen_breakdown: If True, display separate trainable/frozen counts
per module. Defaults to False for a simpler view that highlights whether
a module is fully trainable, fully frozen, or mixed.
"""

# ---------- formatting helpers ----------
@@ -248,15 +304,60 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]:
max(len(_format_shape(s.param_shape)) for s in summary_list),
)

params_col_width = 12
trainable_col_width = 10
col_spacing = " "
params_col_width = max(
len("Param #"),
max(len(_format_number(s.inclusive_total_params)) for s in summary_list),
)

header_parts = [f"{'Layer (type)':<{name_col_width}}"]
if show_param_shapes:
header_parts.append(f"{'Param Shape':>{shape_col_width}}")

header_parts.append(f"{'Param #':>{params_col_width}}")
header_parts.append(f"{'Trainable':>{trainable_col_width}}")

if show_frozen_breakdown:
trainable_col_width = max(
len("Trainable #"),
max(
len(_format_number(s.inclusive_trainable_params)) for s in summary_list
),
)
frozen_col_width = max(
len("Frozen #"),
max(
len(
_format_number(
s.inclusive_total_params - s.inclusive_trainable_params
)
)
for s in summary_list
),
)
header_parts.append(f"{'Trainable #':>{trainable_col_width}}")
header_parts.append(f"{'Frozen #':>{frozen_col_width}}")
else:

def _grad_state(total: int, trainable: int) -> str:
if trainable == 0:
return "frozen"
if trainable == total:
return "trainable"
return "mixed"

grad_states = [
_grad_state(
s.inclusive_total_params,
s.inclusive_trainable_params,
)
for s in summary_list
]
grad_state_width = max(
len("Grad State"), max(len(state) for state in grad_states)
)
header_parts.append(f"{'Grad State':>{grad_state_width}}")

col_spacing = " "

header = col_spacing.join(header_parts)
sep = "=" * len(header)

@@ -268,7 +369,15 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]:
if show_param_shapes:
parts.append(f"{_format_shape(e.param_shape):>{shape_col_width}}")
parts.append(f"{_format_number(e.inclusive_total_params):>{params_col_width}}")
parts.append(f"{str(e.inclusive_trainable_params > 0):>{trainable_col_width}}")
if show_frozen_breakdown:
parts.append(
f"{_format_number(e.inclusive_trainable_params):>{trainable_col_width}}"
)
frozen = e.inclusive_total_params - e.inclusive_trainable_params
parts.append(f"{_format_number(frozen):>{frozen_col_width}}")
else:
state = _grad_state(e.inclusive_total_params, e.inclusive_trainable_params)
parts.append(f"{state:>{grad_state_width}}")
print(col_spacing.join(parts))
print(sep)
print(f"Total params: {_format_number(total_params)}")