Description
Your issue may already be reported!
Please search on the issue tracker before creating one.
Context
- Pytorch version: 2.7.1
- Operating System and version: Ubuntu
Your Environment
- Installed using source? [yes/no]: no
- Are you planning to deploy it using docker container? [yes/no]: no
- Is it a CPU or GPU environment?: yes
- Which example are you using: fsdp2
- Link to code or data to repro [if any]: https://github.com/pytorch/examples/blob/main/distributed/FSDP2/utils.py
Expected Behavior
def inspect_mixed_precision(model: FSDPModule):
    model.unshard()
    for param in model.parameters(recurse=False):
        assert param.dtype == torch.bfloat16
    model.reshard()
This function is misleading. In the provided example, the Transformer module itself has no direct parameters; every parameter lives in one of its submodules. Because the loop iterates over model.parameters(recurse=False), it only sees the root module's direct parameters and therefore effectively checks nothing. I would expect it to inspect some actual parameters.
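To illustrate (a quick check on my side, using the Transformer and ModelArgs defined in the standalone script below; the variable names are only for illustration), the root module's direct parameter list is empty, so the loop body never runs:

model = Transformer(ModelArgs())
# The root Transformer registers no parameters of its own, so with
# recurse=False there is nothing to assert on.
print(list(model.parameters(recurse=False)))  # []
# All parameters live in submodules (embeddings, layers, norm, output).
print(len(list(model.parameters())))          # > 0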
Current Behavior
As stated above.
Possible Solution
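One possible fix, just a sketch on my end (it assumes FSDPModule is importable from torch.distributed.fsdp as in the linked utils.py, and it is not tested against the full example): unshard every FSDP-wrapped module rather than only the root, check all parameters, and fail loudly if nothing was inspected.

import torch
from torch.distributed.fsdp import FSDPModule  # same import as in utils.py

def inspect_mixed_precision(model: FSDPModule):
    # Unshard every FSDP-wrapped module (root and per-layer), so the
    # unsharded bfloat16 parameters are actually registered on the modules.
    fsdp_modules = [m for m in model.modules() if isinstance(m, FSDPModule)]
    for m in fsdp_modules:
        m.unshard()
    checked = 0
    # Recurse through all submodules instead of only the root's direct params.
    for param in model.parameters():
        assert param.dtype == torch.bfloat16
        checked += 1
    # Guard against the silent no-op described above.
    assert checked > 0, "inspect_mixed_precision did not see any parameters"
    for m in fsdp_modules:
        m.reshard()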
Steps to Reproduce
- Replace the assertion with printing
- Run the example
or use this standalone script:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)
        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.wo.reset_parameters()
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

    def reset_parameters(self):
        self.w1.reset_parameters()
        self.w2.reset_parameters()
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def reset_parameters(self):
        self.attention_norm.reset_parameters()
        self.attention.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.feed_forward.reset_parameters()
# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()
model = Transformer(ModelArgs())
fsdp_kwargs = {
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
}
for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

# sharded parameters are float32
for name, param in model.named_parameters():
    # print("local", name, param.dtype, param.device)
    pass

# unsharded parameters are bfloat16
model.unshard()
for name, param in model.named_parameters(recurse=False):
    print("unsharded", name, param.dtype, param.device)
model.reshard()
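For completeness, here is a variant of the last few lines that looks at every parameter rather than only the root's direct ones (same API as above; it prints rather than asserts). My understanding, which this is meant to surface, is that after model.unshard() only the parameters managed by the root fully_shard call appear unsharded, while the per-layer ones stay sharded until their own unshard() is called.

model.unshard()
for name, param in model.named_parameters():
    # Expectation (not asserted): root-managed params (embeddings, norm,
    # output) show up as unsharded bfloat16; layer params remain sharded.
    print("after root unshard", name, param.dtype, type(param).__name__)
model.reshard()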
Failure Logs [if any]
[]