
torch DDP OOM caused by weak protocol #106294


Description

@leiwen83

🐛 Describe the bug

The current torch DDP distributed training protocol is weak: it uses a simple TCP protocol, listening on the master port, to decide which action to execute. While this is efficient, it can cause problems when abnormal network traffic arrives.

Here we hit an OOM issue that quickly exhausted the machine's CPU memory. We root-caused it to torch DDP mistaking an nmap scan packet for an addHandler message and allocating a very large buffer to hold the expected incoming message.

Steps to reproduce the DDP OOM:

nmap -p[master port] -sS -sV [torch DDP master node IP]
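
From TCPStore's point of view, the scan amounts to opening a TCP connection to the master port and writing a few arbitrary bytes; if the first byte happens to collide with a valid query type (the trace below shows the probe being taken for an addHandler request), the following eight bytes are read as a string length. A minimal, purely illustrative C++ sketch of such a client, with a placeholder IP/port and a guessed query-type byte:

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

int main() {
  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) return 1;

  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_port = htons(29500);                    // placeholder master port
  inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); // placeholder master IP
  if (connect(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) != 0)
    return 1;

  // 0x03 is only a guess at the ADD query-type value; the remaining bytes are
  // arbitrary, roughly what a scanner's service-version probe looks like on the wire.
  const unsigned char probe[] = {0x03, 'n', 'm', 'a', 'p', ' ', 'p', 'r', 'o', 'b', 'e'};
  (void)write(fd, probe, sizeof(probe));

  close(fd);
  return 0;
}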

The trigger trace is as follows:

https://github.com/pytorch/pytorch/blob/v2.0.1/torch/csrc/distributed/c10d/TCPStore.cpp#L588
-> https://github.com/pytorch/pytorch/blob/v2.0.1/torch/csrc/distributed/c10d/TCPStore.cpp#L281
-> https://github.com/pytorch/pytorch/blob/v2.0.1/torch/csrc/distributed/c10d/TCPStore.cpp#L384
-> https://github.com/pytorch/pytorch/blob/v2.0.1/torch/csrc/distributed/c10d/Utils.hpp#L659

inline std::string recvString(int socket) {
  SizeType valueSize;
  // The length field is read straight off the wire with no validation.
  recvBytes<SizeType>(socket, &valueSize, 1);
  // valueSize bytes are allocated up front, before any payload arrives.
  std::vector<char> value(valueSize);
  recvBytes<char>(socket, value.data(), value.size());
  return std::string(value.data(), value.size());
}

In recvString, the nmap probe bytes are parsed as a message of more than 1 TB to receive, which makes torch request over 1 TB of memory from the system and leads to OOM.
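
For a sense of scale, assuming SizeType is a 64-bit unsigned integer as in Utils.hpp, eight arbitrary printable bytes decode to a length on the order of 10^18 bytes, far beyond the 1 TB+ observed here. A small standalone sketch (the probe bytes are hypothetical):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Eight arbitrary printable bytes, standing in for part of a scan probe.
  const char probe[8] = {'G', 'E', 'T', ' ', '/', ' ', 'H', 'T'};
  std::uint64_t valueSize = 0;
  // This is effectively what recvBytes<SizeType>(socket, &valueSize, 1) does
  // with whatever happens to arrive on the wire.
  std::memcpy(&valueSize, probe, sizeof(valueSize));
  std::printf("interpreted string length: %llu bytes\n",
              static_cast<unsigned long long>(valueSize));
  return 0;
}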

It is very common to scan ports with nmap in a data center, so should we consider making the torch DDP protocol more robust against this kind of "attack"? I think even just adding a magic number to the protocol would greatly reduce such issues, as sketched below.
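
As a rough sketch of that suggestion, reusing recvBytes and SizeType from the snippet above (kProtocolMagic and kMaxStringSize are hypothetical constants, not existing PyTorch symbols):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical protocol magic, checked once per accepted connection before
// dispatching any query: a scanner that does not send it is rejected immediately.
constexpr uint32_t kProtocolMagic = 0x3C85F7CE;

inline void validateMagic(int socket) {
  uint32_t magic = 0;
  recvBytes<uint32_t>(socket, &magic, 1);
  if (magic != kProtocolMagic) {
    throw std::runtime_error("TCPStore: bad protocol magic, closing connection");
  }
}

// Same logic as recvString, but refuses implausible sizes instead of allocating them.
constexpr SizeType kMaxStringSize = 1ull << 20; // hypothetical 1 MiB cap

inline std::string recvStringChecked(int socket) {
  SizeType valueSize;
  recvBytes<SizeType>(socket, &valueSize, 1);
  if (valueSize > kMaxStringSize) {
    throw std::runtime_error("TCPStore: refusing to receive oversized string");
  }
  std::vector<char> value(valueSize);
  recvBytes<char>(socket, value.data(), value.size());
  return std::string(value.data(), value.size());
}

The magic check alone would already reject a scanner's probe before any length field is trusted; the size cap guards against garbage that does get past the handshake.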

Thx

Versions

As far as I can tell, this issue has existed for as long as DDP has been supported.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu


Labels

module: ddp, module: memory usage, oncall: distributed
