Closed
Labels
module: rocm (AMD GPU support for PyTorch)
rocm priority (high priority ROCm PRs from performance or other aspects)
triage review
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
Hi AMD Team,
Turning on TunableOp causes an OOM; with it disabled, the same run does not OOM. I have reproduced this on the PyTorch nightly build.
Can you take a look?
Thanks,
Oren
cc: @hliuca
Without enabling TunableOp (working)

Command:
python ./reprod_oom.py --bsz=28

Enabling TunableOp causes OOM

Command:
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 python ./reprod_oom.py --bsz=28
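For reference, the same toggle can also be driven from Python instead of environment variables. This is a minimal sketch, assuming the torch.cuda.tunable module exposed by recent nightlies; the commented-out calls are only needed if you want to replay previously recorded results instead of tuning from scratch:

import torch

# Assumption: equivalent of PYTORCH_TUNABLEOP_ENABLED=1 via the torch.cuda.tunable API.
torch.cuda.tunable.enable(True)
print("TunableOp enabled:", torch.cuda.tunable.is_enabled())

# Optional: replay previously recorded tunings without re-tuning, to check
# whether only the tuning phase itself triggers the OOM.
# torch.cuda.tunable.read_file("tunableop_results0.csv")
# torch.cuda.tunable.tuning_enable(False)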
Stack Trace

Traceback (most recent call last):
File "/workspace/llm-train-bench/./reprod_oom.py", line 99, in <module>
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/workspace/llm-train-bench/./reprod_oom.py", line 87, in train
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/llm-train-bench/./reprod_oom.py", line 57, in forward
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 2.69 GiB. GPU 0 has a total capacity of 191.98 GiB of which 170.00 MiB is free. Of the allocated memory 97.66 GiB is allocated by PyTorch, and 2.30 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
additional tuning results available, rewriting file tunableop_results0.csv
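The error text suggests trying the expandable-segments allocator. A minimal sketch of how that could be wired into the repro, assuming the variable is picked up as long as it is set before torch makes its first HIP allocation (setting it before the import is the safest):

import os

# Assumption: PYTORCH_HIP_ALLOC_CONF is read when the caching allocator is
# initialized, so export it before importing torch.
os.environ.setdefault("PYTORCH_HIP_ALLOC_CONF", "expandable_segments:True")

import torch  # noqa: E402

This only targets the "reserved but unallocated" part of the report; it does not explain why TunableOp needs more memory in the first place.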
Reprod Script

import torch
import torch.nn.functional as F
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)

    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE


class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE


class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV


def train(
    bsz: int = 28,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(0)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }
    model = GPT(**cfg_json).to('cuda:0')
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    model.train()
    for step_idx in range(100):
        input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
        label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')

        with torch.amp.autocast('cuda', torch.bfloat16):
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.synchronize()
        print(f"finish {step_idx} step")


if __name__ == "__main__":
    import fire
    fire.Fire(train)
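To compare the two runs more directly, here is a small helper that could be dropped into the training loop above. It uses the standard torch.cuda memory-stat APIs (which also cover the HIP backend); the log_mem name and the GiB formatting are just for illustration:

import torch

def log_mem(tag: str, device: int = 0) -> None:
    # Current, reserved, and peak allocations in GiB.
    gib = 1024 ** 3
    print(
        f"{tag}: allocated={torch.cuda.memory_allocated(device) / gib:.2f} GiB, "
        f"reserved={torch.cuda.memory_reserved(device) / gib:.2f} GiB, "
        f"peak={torch.cuda.max_memory_allocated(device) / gib:.2f} GiB"
    )

Calling log_mem(f"step {step_idx}") right after optimizer.step() in both the plain run and the PYTORCH_TUNABLEOP_ENABLED=1 run should show where the extra allocations appear.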
Versions

PyTorch nightly version:
$ pip list | grep torch
pytorch-triton-rocm 3.1.0+cf34004b8a
torch 2.6.0.dev20241021+rocm6.2
torchvision 0.18.0a0+68ba7ec

Dockerfile (uses the 2.3 release as the base image, then installs the PyTorch nightly)
FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
RUN apt install -y nano wget
RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic pybind11
RUN pip3 uninstall -y torch
RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
WORKDIR /workspace/llm-train-bench/
CMD ["/usr/bin/bash"]cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd