
[ROCm] MI300X TunableOp causes 100 GB memory leak leading to OOM #138532

@functionstackx

Description


🐛 Describe the bug

Hi AMD Team,

Enabling TunableOp causes an OOM; with it disabled, the same workload runs fine. I have reproduced this on torch nightly.

Can you take a look?

Thanks,
Oren

cc: @hliuca

Working command (TunableOp disabled)

python ./reprod_oom.py --bsz=28

Failing command (TunableOp enabled, OOMs)

PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 python ./reprod_oom.py --bsz=28
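
For reference, the same switch is also exposed in-process. A minimal sketch assuming the torch.cuda.tunable Python API available in recent builds (the env vars above remain the documented way to enable it):

import torch

torch.cuda.tunable.enable(True)         # same effect as PYTORCH_TUNABLEOP_ENABLED=1
assert torch.cuda.tunable.is_enabled()  # sanity check that tuning is active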

Stack Trace

Traceback (most recent call last):
  File "/workspace/llm-train-bench/./reprod_oom.py", line 99, in <module>
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/workspace/llm-train-bench/./reprod_oom.py", line 87, in train
    
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/llm-train-bench/./reprod_oom.py", line 57, in forward
    
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 2.69 GiB. GPU 0 has a total capacity of 191.98 GiB of which 170.00 MiB is free. Of the allocated memory 97.66 GiB is allocated by PyTorch, and 2.30 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
additional tuning results available, rewriting file tunableop_results0.csv
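
Note the numbers in the error: 97.66 GiB is live PyTorch allocations and only 2.30 GiB is reserved but unallocated, so the suggested PYTORCH_HIP_ALLOC_CONF=expandable_segments:True (a fragmentation mitigation) is unlikely to help here; the growth is in real allocations. To quantify it per step, allocator statistics can be printed inside the training loop; a minimal sketch using the standard torch.cuda memory queries (log_mem is a hypothetical helper, called once per step in the repro script below):

import torch

def log_mem(step_idx: int) -> None:
    # Live tensor allocations vs. memory held by the caching allocator, in GiB
    alloc = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"step {step_idx}: allocated={alloc:.2f} GiB, reserved={reserved:.2f} GiB")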

Repro Script (reprod_oom.py)

import torch
import torch.nn.functional as F
import torch.nn as nn
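
# Tensor-name suffixes: B=batch, T=sequence length, E=embedding dim,
# H=attention heads, D=head dim, V=vocab size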

class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)
 
    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV


def train(
    bsz: int = 28,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(0)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    model = GPT(**cfg_json).to('cuda:0')

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    model.train()

    for step_idx in range(100):
        input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
        label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')

        with torch.amp.autocast('cuda', torch.bfloat16):
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        torch.cuda.synchronize()
        print(f"finish {step_idx} step")

if __name__ == "__main__":
    import fire
    fire.Fire(train)
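
If the leak scales with how long each GEMM is tuned, capping the tuning budget may bound it; a sketch assuming the documented TunableOp tuning-limit env var (the value is illustrative, not a confirmed fix):

PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS=10 python ./reprod_oom.py --bsz=28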

Versions

PyTorch Nightly Version

$ pip list | grep torch
pytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241021+rocm6.2
torchvision             0.18.0a0+68ba7ec

Dockerfile (uses the 2.3 release as the base image, then installs the PyTorch nightly)

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install -y nano wget

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]
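
For completeness, a typical build-and-run for this image (a sketch: the image tag and the standard ROCm device flags are assumptions, not from the report):

docker build -t llm-train-bench .
docker run -it --device=/dev/kfd --device=/dev/dri --group-add video \
    -v "$PWD":/workspace/llm-train-bench llm-train-bench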

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

Metadata

Labels: module: rocm · rocm priority · triage review · triaged
Status: Done