Fix #514 (#521)

Merged · 9 commits · May 13, 2025
14 changes: 5 additions & 9 deletions apps/nccl/src/nccl.cu

@@ -270,13 +270,12 @@ static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_pt
                                                  mscclpp::TransportFlags transport) {
  std::vector<mscclpp::RegisteredMemory> remoteMemories;
  mscclpp::RegisteredMemory memory = comm->registerMemory(buff, bytes, transport);
-  std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
+  std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
  for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
    if (i == rank) continue;
-    remoteRegMemoryFutures.push_back(comm->recvMemoryOnSetup(i, 0));
-    comm->sendMemoryOnSetup(memory, i, 0);
+    remoteRegMemoryFutures.push_back(comm->recvMemory(i, 0));
+    comm->sendMemory(memory, i, 0);
  }
-  comm->setup();
  std::transform(remoteRegMemoryFutures.begin(), remoteRegMemoryFutures.end(), std::back_inserter(remoteMemories),
                 [](const auto& future) { return future.get(); });
  return remoteMemories;
@@ -602,15 +601,13 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,

static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_ptr<mscclpp::Communicator> mscclppComm,
                                               int rank) {
-  std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
+  std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;

  for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) {
    if (i == rank) continue;
    mscclpp::Transport transport = getTransport(rank, i);
-    connectionFutures.push_back(mscclppComm->connectOnSetup(i, 0, transport));
+    connectionFutures.push_back(mscclppComm->connect(i, 0, transport));
  }
-  mscclppComm->setup();
-
  std::vector<std::shared_ptr<mscclpp::Connection>> connections;
  std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
                 [](const auto& future) { return future.get(); });
@@ -625,7 +622,6 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
    }
  }

-  mscclppComm->setup();
  commPtr->connections = std::move(connections);
  if (mscclpp::isNvlsSupported()) {
    commPtr->nvlsConnections = setupNvlsConnections(commPtr, NVLS_BUFFER_SIZE);
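These hunks show the pattern that recurs throughout the PR (and again in executor.cc and semaphore.cc below): the *OnSetup calls and the comm->setup() barrier are replaced by an eager sendMemory plus a lazily resolved std::shared_future from recvMemory. A minimal standalone sketch of the resulting flow, using only the public mscclpp API that appears in the diff (the helper name and structure are illustrative, not code from the PR):

#include <future>
#include <memory>
#include <vector>

#include <mscclpp/core.hpp>

// Hypothetical helper: send our registered memory to every peer, collect
// deferred futures for theirs, then resolve them -- no setup() barrier needed.
std::vector<mscclpp::RegisteredMemory> exchangeAllToAll(std::shared_ptr<mscclpp::Communicator> comm,
                                                        mscclpp::RegisteredMemory memory, int rank) {
  std::vector<std::shared_future<mscclpp::RegisteredMemory>> futures;
  for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
    if (i == rank) continue;
    comm->sendMemory(memory, i, 0);             // sends immediately over the bootstrap
    futures.push_back(comm->recvMemory(i, 0));  // deferred: the recv runs on first get()
  }
  std::vector<mscclpp::RegisteredMemory> remote;
  for (auto& f : futures) remote.push_back(f.get());  // each get() performs its recv
  return remote;
}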
14 changes: 5 additions & 9 deletions docs/getting-started/tutorials/initialization.md

@@ -32,29 +32,25 @@ void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {

  std::vector<mscclpp::SemaphoreId> semaphoreIds;
  std::vector<mscclpp::RegisteredMemory> localMemories;
-  std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
-  std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
+  std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
+  std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemories;

  for (int r = 0; r < world_size; ++r) {
    if (r == rank) continue;
    mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
    // Connect with all other ranks
-    connections[r] = comm.connectOnSetup(r, 0, transport);
+    connections[r] = comm.connect(r, 0, transport);
    auto memory = comm.registerMemory(data, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
    localMemories.push_back(memory);
-    comm.sendMemoryOnSetup(memory, r, 0);
-    remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0));
+    comm.sendMemory(memory, r, 0);
+    remoteMemories.push_back(comm.recvMemory(r, 0));
  }

-  comm.setup();
-
  for (int r = 0; r < world_size; ++r) {
    if (r == rank) continue;
    semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
  }

-  comm.setup();
-
  std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
  for (size_t i = 0; i < semaphoreIds.size(); ++i) {
    portChannels.push_back(mscclpp::deviceHandle(mscclpp::PortChannel(
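Because the futures are deferred, the tutorial no longer needs comm.setup() fences: connections[r].get() above performs the connection exchange on first use, and since std::shared_future is copyable, calling get() again later is safe. Continuing the excerpt, resolving the collected memory futures might look like this (a sketch, not part of the tutorial source):

std::vector<mscclpp::RegisteredMemory> resolvedMemories;
for (auto& fut : remoteMemories) {
  // The first get() on each future performs the underlying bootstrap recv;
  // later calls return the cached RegisteredMemory.
  resolvedMemories.push_back(fut.get());
}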
13 changes: 5 additions & 8 deletions python/mscclpp/comm.py

@@ -101,13 +101,12 @@ def make_connection(
            if endpoint.transport == Transport.Nvls:
                return connect_nvls_collective(self.communicator, all_ranks, 2**30)
            else:
-                connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
-        self.communicator.setup()
+                connections[rank] = self.communicator.connect(rank, 0, endpoint)
        connections = {rank: connections[rank].get() for rank in connections}
        return connections

    def register_tensor_with_connections(
-        self, tensor: Type[cp.ndarray] or Type[np.ndarray], connections: dict[int, Connection]
+        self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
    ) -> dict[int, RegisteredMemory]:
        transport_flags = TransportFlags()
        for rank in connections:
@@ -125,22 +124,20 @@ def register_tensor_with_connections(
        all_registered_memories[self.my_rank] = local_reg_memory
        future_memories = {}
        for rank in connections:
-            self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
-            future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
-        self.communicator.setup()
+            self.communicator.send_memory(local_reg_memory, rank, 0)
+            future_memories[rank] = self.communicator.recv_memory(rank, 0)
        for rank in connections:
            all_registered_memories[rank] = future_memories[rank].get()
        return all_registered_memories

    def make_semaphore(
        self,
        connections: dict[int, Connection],
-        semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[MemoryDevice2DeviceSemaphore],
+        semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
    ) -> dict[int, Host2HostSemaphore]:
        semaphores = {}
        for rank in connections:
            semaphores[rank] = semaphore_type(self.communicator, connections[rank])
-        self.communicator.setup()
        return semaphores

    def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
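Beyond the API migration, the annotation changes in register_tensor_with_connections and make_semaphore are real fixes: Type[cp.ndarray] or Type[np.ndarray] evaluates to just Type[cp.ndarray] at runtime (Python's or returns its first truthy operand), whereas Type[cp.ndarray] | Type[np.ndarray] builds an actual union type (PEP 604, Python 3.10+).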
23 changes: 12 additions & 11 deletions python/mscclpp/core_py.cpp

@@ -27,9 +27,9 @@ extern void register_npkit(nb::module_& m);
extern void register_gpu_utils(nb::module_& m);

template <typename T>
-void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
-  std::string pyclass_name = std::string("NonblockingFuture") + typestr;
-  nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str()).def("get", &NonblockingFuture<T>::get);
+void def_shared_future(nb::handle& m, const std::string& typestr) {
+  std::string pyclass_name = std::string("shared_future_") + typestr;
+  nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
}

void register_core(nb::module_& m) {
@@ -158,8 +158,8 @@ void register_core(nb::module_& m) {
      .def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
      .def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));

-  def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
-  def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
+  def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
+  def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");

  nb::class_<Communicator>(m, "Communicator")
      .def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
@@ -172,14 +172,15 @@ void register_core(nb::module_& m) {
            return self->registerMemory((void*)ptr, size, transports);
          },
          nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
-      .def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
-           nb::arg("tag"))
-      .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
-      .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
-           nb::arg("localConfig"))
+      .def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
+      .def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
+      .def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
+      .def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
+      .def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
+      .def("connect_on_setup", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
      .def("remote_rank_of", &Communicator::remoteRankOf)
      .def("tag_of", &Communicator::tagOf)
-      .def("setup", &Communicator::setup);
+      .def("setup", [](Communicator*) {});
}

NB_MODULE(_mscclpp, m) {
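Binding std::shared_future rather than a one-shot future matters for Python, where nothing stops a caller from invoking get() twice on the same object. A minimal illustration in standard C++, independent of mscclpp:

#include <future>
#include <iostream>

int main() {
  // std::future::get() may be called only once; shared_future::get() is repeatable.
  std::shared_future<int> f = std::async(std::launch::deferred, [] { return 42; }).share();
  std::cout << f.get() << "\n";  // first call runs the deferred task
  std::cout << f.get() << "\n";  // later calls return the stored result
  return 0;
}

Note also that setup is kept as a no-op lambda and the *_on_setup names are re-bound to the new methods, so existing Python callers keep working while the old spellings become deprecated aliases.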
5 changes: 3 additions & 2 deletions src/bootstrap/bootstrap.cc

@@ -52,14 +52,14 @@ MSCCLPP_API_CPP void Bootstrap::groupBarrier(const std::vector<int>& ranks) {
MSCCLPP_API_CPP void Bootstrap::send(const std::vector<char>& data, int peer, int tag) {
  size_t size = data.size();
  send((void*)&size, sizeof(size_t), peer, tag);
-  send((void*)data.data(), data.size(), peer, tag + 1);
+  send((void*)data.data(), data.size(), peer, tag);
}

MSCCLPP_API_CPP void Bootstrap::recv(std::vector<char>& data, int peer, int tag) {
  size_t size;
  recv((void*)&size, sizeof(size_t), peer, tag);
  data.resize(size);
-  recv((void*)data.data(), data.size(), peer, tag + 1);
+  recv((void*)data.data(), data.size(), peer, tag);
}

struct UniqueIdInternal {
@@ -528,6 +528,7 @@ std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag)
    if (recvPeer == peer && recvTag == tag) {
      return sock;
    }
+    // TODO(chhwang): set an exit condition or timeout
  }
}

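Sending the payload on tag instead of tag + 1 keeps a framed message confined to a single (peer, tag) stream. Previously a vector sent on tag t also occupied tag t + 1, where its payload could be matched against the size header of an unrelated message legitimately using that tag. A sketch of the hazard the change removes (hypothetical helper, using only the Bootstrap::send overload shown above):

#include <memory>
#include <vector>

#include <mscclpp/core.hpp>

// With the old framing, messageA (tag 0) put its payload on tag 1 -- exactly
// where messageB (tag 1) put its size header -- so a receiver could pair the
// wrong pieces. After this change each message stays entirely on its own tag.
void sendTwoMessages(std::shared_ptr<mscclpp::Bootstrap> bootstrap, int peer,
                     const std::vector<char>& messageA, const std::vector<char>& messageB) {
  bootstrap->send(messageA, peer, 0);  // size header and payload both on tag 0
  bootstrap->send(messageB, peer, 1);  // size header and payload both on tag 1
}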
67 changes: 50 additions & 17 deletions src/communicator.cc

@@ -17,6 +17,22 @@ Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<C
  }
}

+void Communicator::Impl::setLastRecvItem(int remoteRank, int tag, std::shared_ptr<BaseRecvItem> item) {
+  lastRecvItems_[{remoteRank, tag}] = item;
+}
+
+std::shared_ptr<BaseRecvItem> Communicator::Impl::getLastRecvItem(int remoteRank, int tag) {
+  auto it = lastRecvItems_.find({remoteRank, tag});
+  if (it == lastRecvItems_.end()) {
+    return nullptr;
+  }
+  if (it->second->isReady()) {
+    lastRecvItems_.erase(it);
+    return nullptr;
+  }
+  return it->second;
+}
+
MSCCLPP_API_CPP Communicator::~Communicator() = default;

MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context)
@@ -31,30 +47,47 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
}

MSCCLPP_API_CPP void Communicator::sendMemory(RegisteredMemory memory, int remoteRank, int tag) {
-  pimpl_->bootstrap_->send(memory.serialize(), remoteRank, tag);
+  bootstrap()->send(memory.serialize(), remoteRank, tag);
}

MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(int remoteRank, int tag) {
-  return std::async(std::launch::deferred, [this, remoteRank, tag]() {
-    std::vector<char> data;
-    bootstrap()->recv(data, remoteRank, tag);
-    return RegisteredMemory::deserialize(data);
-  });
+  auto future = std::async(std::launch::deferred,
+                           [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() {
+                             if (lastRecvItem) {
+                               // Recursive call to the previous receive items
+                               lastRecvItem->wait();
+                             }
+                             std::vector<char> data;
+                             bootstrap()->recv(data, remoteRank, tag);
+                             return RegisteredMemory::deserialize(data);
+                           });
+  auto shared_future = std::shared_future<RegisteredMemory>(std::move(future));
+  pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<RegisteredMemory>>(shared_future));
+  return shared_future;
}

MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
                                                                                      EndpointConfig localConfig) {
-  auto localEndpoint = pimpl_->context_->createEndpoint(localConfig);
-  pimpl_->bootstrap_->send(localEndpoint.serialize(), remoteRank, tag);
-
-  return std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint = std::move(localEndpoint)]() mutable {
-    std::vector<char> data;
-    bootstrap()->recv(data, remoteRank, tag);
-    auto remoteEndpoint = Endpoint::deserialize(data);
-    auto connection = context()->connect(localEndpoint, remoteEndpoint);
-    pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
-    return connection;
-  });
+  auto localEndpoint = context()->createEndpoint(localConfig);
+  bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
+
+  auto future =
+      std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
+                                         localEndpoint = std::move(localEndpoint)]() mutable {
+        if (lastRecvItem) {
+          // Recursive call to the previous receive items
+          lastRecvItem->wait();
+        }
+        std::vector<char> data;
+        bootstrap()->recv(data, remoteRank, tag);
+        auto remoteEndpoint = Endpoint::deserialize(data);
+        auto connection = context()->connect(localEndpoint, remoteEndpoint);
+        pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
+        return connection;
+      });
+  auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
+  pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
+  return shared_future;
}

MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
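This is the heart of the fix. The futures are std::launch::deferred, so nothing forces callers to resolve them in the order the receives were posted; without chaining, resolving a later future first would make its bootstrap()->recv consume an earlier message on the same (remoteRank, tag) stream. Each new future therefore captures the previous unresolved RecvItem for its key and waits on it before receiving. A distilled, mscclpp-free sketch of that chaining:

#include <future>
#include <iostream>

int main() {
  // Stand-ins for two receives posted on the same (peer, tag) key.
  std::shared_future<int> first = std::async(std::launch::deferred, [] {
                                    std::cout << "recv #1\n";  // plays the role of bootstrap()->recv(...)
                                    return 1;
                                  }).share();
  std::shared_future<int> second = std::async(std::launch::deferred, [first] {
                                     first.wait();  // chain: run the earlier recv before this one
                                     std::cout << "recv #2\n";
                                     return 2;
                                   }).share();

  second.get();  // prints "recv #1" then "recv #2" even though it is resolved first
  first.get();   // already resolved; returns the cached value immediately
  return 0;
}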
13 changes: 5 additions & 8 deletions src/executor/executor.cc

@@ -212,13 +212,12 @@ struct Executor::Impl {
  void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize,
                        size_t recvBufferSize) {
    std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
-    std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
+    std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
    for (int peer : connectedPeers) {
      Transport transport =
          inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
-      connectionFutures.push_back(this->comm->connectOnSetup(peer, 0, transport));
+      connectionFutures.push_back(this->comm->connect(peer, 0, transport));
    }
-    this->comm->setup();
    for (size_t i = 0; i < connectionFutures.size(); i++) {
      context.connections[connectedPeers[i]] = connectionFutures[i].get();
    }
@@ -262,16 +261,15 @@ struct Executor::Impl {
    RegisteredMemory memory =
        this->comm->registerMemory(getBufferInfo(bufferType).first, getBufferInfo(bufferType).second, transportFlags);
    std::vector<int> connectedPeers = getConnectedPeers(channelInfos);
-    std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
+    std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
    for (int peer : connectedPeers) {
-      comm->sendMemoryOnSetup(memory, peer, 0);
+      comm->sendMemory(memory, peer, 0);
    }
    channelInfos = plan.impl_->getChannelInfos(rank, bufferType);
    connectedPeers = getConnectedPeers(channelInfos);
    for (int peer : connectedPeers) {
-      remoteRegMemoryFutures.push_back(comm->recvMemoryOnSetup(peer, 0));
+      remoteRegMemoryFutures.push_back(comm->recvMemory(peer, 0));
    }
-    comm->setup();
    for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
      context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
    }
@@ -307,7 +305,6 @@ struct Executor::Impl {
      channelInfos = plan.impl_->getUnpairedChannelInfos(rank, nranks, channelType);
      processChannelInfos(channelInfos);
    }
-    this->comm->setup();
    context.memorySemaphores = std::move(memorySemaphores);
    context.proxySemaphores = std::move(proxySemaphores);

35 changes: 34 additions & 1 deletion src/include/communicator.hpp

@@ -9,9 +9,29 @@
#include <unordered_map>
#include <vector>

+#include "utils_internal.hpp"
+
namespace mscclpp {

-class ConnectionBase;
+class BaseRecvItem {
+ public:
+  virtual ~BaseRecvItem() = default;
+  virtual void wait() = 0;
+  virtual bool isReady() const = 0;
+};
+
+template <typename T>
+class RecvItem : public BaseRecvItem {
+ public:
+  RecvItem(std::shared_future<T> future) : future_(future) {}
+
+  void wait() { future_.wait(); }
+
+  bool isReady() const { return future_.wait_for(std::chrono::seconds(0)) == std::future_status::ready; }
+
+ private:
+  std::shared_future<T> future_;
+};

struct ConnectionInfo {
  int remoteRank;
@@ -22,9 +42,22 @@ struct Communicator::Impl {
  std::shared_ptr<Bootstrap> bootstrap_;
  std::shared_ptr<Context> context_;
  std::unordered_map<const Connection*, ConnectionInfo> connectionInfos_;
+  std::shared_ptr<BaseRecvItem> lastRecvItem_;
+
+  // Temporary storage for the latest RecvItem of each {remoteRank, tag} pair.
+  // If the RecvItem gets ready, it will be removed at the next call to getLastRecvItem.
+  std::unordered_map<std::pair<int, int>, std::shared_ptr<BaseRecvItem>, PairHash> lastRecvItems_;

  Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context);

+  // Set the last RecvItem for a {remoteRank, tag} pair.
+  // This is used to store the corresponding RecvItem of a future returned by recvMemory() or connect().
+  void setLastRecvItem(int remoteRank, int tag, std::shared_ptr<BaseRecvItem> item);
+
+  // Return the last RecvItem that is not ready.
+  // If the item is ready, it will be removed from the map and nullptr will be returned.
+  std::shared_ptr<BaseRecvItem> getLastRecvItem(int remoteRank, int tag);
+
  struct Connector;
};

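Two details worth noting here. First, lastRecvItems_ is keyed by std::pair<int, int>, which std::unordered_map cannot hash out of the box; that is presumably what PairHash from utils_internal.hpp (not shown in this diff) provides. A minimal stand-in, written purely for illustration:

#include <cstddef>
#include <functional>
#include <utility>

// Illustrative only: std::hash has no std::pair specialization, so the map
// needs a custom hasher. The real PairHash lives in utils_internal.hpp.
struct PairHash {
  std::size_t operator()(const std::pair<int, int>& p) const noexcept {
    std::size_t h1 = std::hash<int>{}(p.first);
    std::size_t h2 = std::hash<int>{}(p.second);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));  // boost-style hash_combine
  }
};

Second, RecvItem::isReady() uses wait_for(std::chrono::seconds(0)), which returns std::future_status::deferred (not ready) for a deferred task that has not yet run, so an unresolved receive correctly reads as not ready and stays in the map.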
8 changes: 4 additions & 4 deletions src/semaphore.cc

@@ -9,14 +9,14 @@

namespace mscclpp {

-static NonblockingFuture<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
-                                                                   void* localInboundSemaphoreId) {
+static std::shared_future<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
+                                                                    void* localInboundSemaphoreId) {
  auto localInboundSemaphoreIdsRegMem =
      communicator.registerMemory(localInboundSemaphoreId, sizeof(uint64_t), connection->transport());
  int remoteRank = communicator.remoteRankOf(*connection);
  int tag = communicator.tagOf(*connection);
-  communicator.sendMemoryOnSetup(localInboundSemaphoreIdsRegMem, remoteRank, tag);
-  return communicator.recvMemoryOnSetup(remoteRank, tag);
+  communicator.sendMemory(localInboundSemaphoreIdsRegMem, remoteRank, tag);
+  return communicator.recvMemory(remoteRank, tag);
}

static detail::UniqueGpuPtr<uint64_t> createGpuSemaphoreId() {