Enable GPU-to-GPU comm in TensorPipeAgent #44418

Closed · wants to merge 63 commits

Changes from 39 commits

Commits (63)
b0e83ac
Use streams from pool on RPC callees
mrshenli Sep 9, 2020
6ae4c16
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 10, 2020
0a23d4d
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 10, 2020
c6dec67
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
7b073dc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
7489147
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
2dc32dc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
af38592
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 18, 2020
77b06e7
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 20, 2020
2b84077
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 21, 2020
b445a87
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 21, 2020
3115ad3
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 21, 2020
7fab7fc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 22, 2020
5b33064
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 22, 2020
20ce42b
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 22, 2020
0314cfc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
a25d60a
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
f642020
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
5d74596
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
31a1c64
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 16, 2020
c215331
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 18, 2020
dc846b5
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
0ff581a
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
209010c
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
402fa03
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
eedec85
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 20, 2020
0494738
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 20, 2020
9a22ec3
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 23, 2020
0b2cb77
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 23, 2020
c87efe8
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 23, 2020
27b1a0d
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 25, 2020
46c49e5
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 26, 2020
124d160
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
ced8292
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
6e86c60
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
6988c90
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
5960d23
Update on "Use streams from pool on RPC callees"
mrshenli Dec 9, 2020
4ab8e85
Update on "Use streams from pool on RPC callees"
mrshenli Dec 9, 2020
5a90106
Update on "Use streams from pool on RPC callees"
mrshenli Dec 18, 2020
bb94e5f
Update on "Use streams from pool on RPC callees"
mrshenli Jan 10, 2021
6834072
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
ed0c81e
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
84ecc25
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
59d95ba
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
e9133d0
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
a1dcfe2
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
9c98bc9
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
b90a8af
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 12, 2021
3a0651f
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 12, 2021
43349e4
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
e2b9581
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
0e5e5b6
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
fb6e4ea
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
8c935e4
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
9b5afb1
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
20294df
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
0994ff8
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
4d6ce07
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
2880a01
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
f27608f
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
36fc05b
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
28209c4
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 14, 2021
b302be9
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 14, 2021
5 changes: 5 additions & 0 deletions torch/csrc/distributed/rpc/macros.h
@@ -0,0 +1,5 @@
#pragma once

#if defined(USE_CUDA) && !defined(__HIP_PLATFORM_HCC__)
#define USE_CUDA_NOT_ROCM
#endif
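
For context, the sketch below shows how a consumer is expected to use this macro. It is a hypothetical snippet (not part of this diff) that mirrors the `#ifdef USE_CUDA_NOT_ROCM` guards appearing later in tensorpipe_agent.h; the helper name is made up for illustration.

```cpp
#include <vector>

#include <torch/csrc/distributed/rpc/macros.h>

#ifdef USE_CUDA_NOT_ROCM
#include <c10/cuda/CUDAStream.h>

// Compiled only for CUDA builds that are not HIP/ROCm; CPU-only and ROCm
// builds skip this declaration entirely.
std::vector<c10::cuda::CUDAStream> pickStreamsForRequest();
#endif
```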
115 changes: 95 additions & 20 deletions torch/csrc/distributed/rpc/tensorpipe_agent.cpp
@@ -198,6 +198,42 @@ C10_REGISTER_CREATOR(TensorPipeCudaChannelRegistry, cuda_ipc, makeCudaIpcChannel

} // namespace

namespace {

struct FullDeviceContextGuard {
  FullDeviceContextGuard(const FullDeviceContextGuard& other) = delete;
  FullDeviceContextGuard(FullDeviceContextGuard&& other) = delete;
  FullDeviceContextGuard& operator=(const FullDeviceContextGuard& rhs) = delete;
  FullDeviceContextGuard& operator=(FullDeviceContextGuard&& rhs) = delete;

#ifndef USE_CUDA_NOT_ROCM
  FullDeviceContextGuard(
      const std::shared_ptr<FullDeviceContext>& /* unused */) {}
#else
  FullDeviceContextGuard(const std::shared_ptr<FullDeviceContext>& ctx) {
    const auto& streams = ctx->streams();
    prevStreams_.reserve(streams.size());
    for (const auto& stream : streams) {
      // Remember the stream that is currently set for this device, then make
      // the context's stream current; the destructor restores the originals.
      prevStreams_.emplace_back(
          at::cuda::getCurrentCUDAStream(stream.device_index()));
      at::cuda::setCurrentCUDAStream(stream);
    }
  }

  ~FullDeviceContextGuard() noexcept {
    for (auto& stream : prevStreams_) {
      at::cuda::setCurrentCUDAStream(stream);
    }
  }

 private:
  std::vector<CUDAStream> prevStreams_;
#endif
};

} // namespace
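
To make the guard's intent concrete, here is a hypothetical usage sketch (not part of the diff): construct the guard right before work that should see the context's streams as the current streams; the previously current streams are restored when the guard goes out of scope.

```cpp
// Hypothetical illustration only.
void runOnContextStreams(
    const std::shared_ptr<FullDeviceContext>& ctx,
    const std::function<void()>& work) {
  FullDeviceContextGuard guard(ctx); // make the context's streams current
  work(); // kernels launched here are queued on those streams
} // guard destructor restores the previously current streams
```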

////////////////////////// MetricsTracker /////////////////////////////////

TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
@@ -408,35 +444,45 @@ void TensorPipeAgent::onListenerAccepted(

void TensorPipeAgent::pipeRead(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
std::function<void(const tensorpipe::Error&, Message&&)> fn) noexcept {
pipe->readDescriptor([fn{std::move(fn)}, pipe](
std::function<void(
const tensorpipe::Error&,
Message&&,
std::shared_ptr<FullDeviceContext>)> fn) noexcept {
pipe->readDescriptor([fn{std::move(fn)}, pipe, this](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
// TODO: pass these streams to TensorPipe when it can accept streams
auto ctx = createFullDeviceContext(
reverseDeviceMaps_.empty() && opts_.deviceMaps.empty());
Review thread on this change:

pritamdamania87 (Contributor):
Shouldn't we only create a context for devices involved in the operation? It looks like we get a stream from the pool for all available devices.

Reviewer (Contributor):
I agree with @pritamdamania87. As mentioned above, I think it would be good if we could be "lazy" and only access devices once we need them, to avoid initializing them uselessly.

I suspect one reason for getting a stream for all devices is in case the remote user function returns a tensor on a device that wasn't used by any of the arguments. In such a case, we would need to have set a current stream for that device before invoking the remote user function.

However, we've had to deal with a similar concern regarding CUDAFuture callbacks, and there it proved to be impossible to get a stream for all devices (it ended up causing a deadlock with NCCL), hence we resorted to just getting a stream for the devices used by the future's value, and declaring that if the callback uses another device, that's undefined behavior.

Hence, I think we could do the same thing for remote user functions (it would be consistent!), and maybe we can find a better solution that covers both use cases in a later release. (As a general API approach, I'd rather start with more restrictions and gradually lift them than start with a generic but problematic solution that cannot be changed without breaking backwards compatibility.) A minimal sketch of this "only the devices involved" approach appears right after this thread.

mrshenli (Author):
> hence we resorted to just getting a stream for the devices used by the future's value, and declaring that if the callback uses another device, that's undefined behavior.

I am not sure this is OK for RPC user functions. Suppose we are running pipeline parallelism on an 8-GPU machine with 2 processes. The first process would take input on cuda:0 and produce output on cuda:3, and then send it to cuda:4 on the other process. If we only use the devices listed in the input args, does it mean users will have to move the intermediate output to cuda:0 before sending it?

mrshenli (Author):
> As mentioned above, I think it would be good if we could be "lazy" and only access devices once we need them, to avoid initializing them uselessly.

Yep, let me make this change first.

Reviewer (Contributor):
> The first process would take input on cuda:0 and produce output on cuda:3

Are we sure this is a common use case? Correct me if I'm wrong, but I believe that all CUDA kernels (i.e., the computations themselves) operate with inputs and outputs on the same device. Hence, if a function returns an output on a different device than its inputs, it must be because the function has manually and explicitly performed such a transfer, for example by calling .to(idx). I could try to argue that this should be considered non-idiomatic for RPC, since RPC should be used for all data transfers, in order to provide consistency, performance, and "scalability" (as in resiliency to topology changes, since it allows easily switching from local GPUs to remote GPUs without code changes). Hence, such a function should probably be split into two parts (each one operating on a single device) which are connected through RPC calls.

Admittedly, the code to do this might be a bit more convoluted than just using .to(...), but I think we can always provide helpers and examples to make this easier to use. (For example, I believe #48790 might already help.)

mrshenli (Author):
> Hence, such a function should probably be split into two parts (each one operating on a single device) which are connected through RPC calls.

This would mean we need an overhaul of our torch.distributed.pipe implementation. I am open to either revamping the existing one or implementing a new distributed pipe where each process exclusively uses one device.

Reviewer (Contributor):
Could you fill me in on what torch.distributed.pipe does? Is it calling .to() explicitly?
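
A minimal sketch of the "only the devices involved" approach discussed above (hypothetical helper, not part of this diff; it assumes the c10 CUDA stream-pool API):

```cpp
#include <set>
#include <vector>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

// Collect one pooled stream per CUDA device that actually appears in the
// message's tensors, rather than one stream per visible GPU.
std::vector<c10::cuda::CUDAStream> streamsForTensors(
    const std::vector<at::Tensor>& tensors) {
  std::set<c10::DeviceIndex> devices;
  for (const auto& t : tensors) {
    if (t.is_cuda()) {
      devices.insert(t.device().index());
    }
  }
  std::vector<c10::cuda::CUDAStream> streams;
  streams.reserve(devices.size());
  for (const auto deviceIndex : devices) {
    streams.push_back(
        c10::cuda::getStreamFromPool(/*isHighPriority=*/false, deviceIndex));
  }
  return streams;
}
```

This mirrors the restriction adopted for CUDAFuture callbacks described in the comment above: streams are taken only for devices that hold data, and touching any other device is left undefined.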


if (error) {
fn(error, Message());
fn(error, Message(), std::move(ctx));
return;
}

TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage);
TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx);

pipe->read(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), std::move(ctx));
return;
}

// FIXME This does some unpickling, which could be a bit expensive:
// perhaps it would be best to perform it inside the worker threads?
ctx->blockCurrentStreams();
Message rpcMessage = tensorpipeDeserialize(
std::move(tpMessage), std::move(*tpBuffers));

fn(error, std::move(rpcMessage));
ctx->recordDataPtrs(tpBuffers->tensors);
fn(error, std::move(rpcMessage), std::move(ctx));
});
});
}
@@ -445,27 +491,32 @@ void TensorPipeAgent::pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
Message&& rpcMessage,
std::vector<c10::DeviceIndex>&& devices,
std::shared_ptr<FullDeviceContext> ctx,
std::function<void(const tensorpipe::Error&)> fn) noexcept {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers tpBuffers;

std::tie(tpMessage, tpBuffers) =
tensorpipeSerialize(std::move(rpcMessage), std::move(devices));

std::tie(tpMessage, tpBuffers) = tensorpipeSerialize(
std::move(rpcMessage), std::move(devices), ctx);

pipe->write(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error, tensorpipe::Message /* unused */) {
ctx->recordTensors(tpBuffers->tensors);
fn(error);
});
}

void TensorPipeAgent::sendCompletedResponseMessage(
std::shared_ptr<tensorpipe::Pipe>& pipe,
std::shared_ptr<FutureMessage>& futureResponseMessage,
uint64_t messageId) {
uint64_t messageId,
std::shared_ptr<FullDeviceContext> ctx) {
if (!rpcAgentRunning_.load()) {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " won't send response to request #" << messageId << " to "
@@ -494,7 +545,9 @@ void TensorPipeAgent::sendCompletedResponseMessage(
pipe,
std::move(responseMessage),
std::move(devices),
[this, pipe, messageId](const tensorpipe::Error& error) {
std::move(ctx),
[this, pipe, messageId](
const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
<< "RPC agent for " << workerInfo_.name_
@@ -512,7 +565,8 @@
pipeWrite(
pipe,
createExceptionResponse(error->what(), responseMessage.id()),
{},
/* devices */{},
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -534,7 +588,9 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipeRead(
pipe,
[this, pipe](
const tensorpipe::Error& error, Message&& requestMessage) mutable {
const tensorpipe::Error& error,
Message&& requestMessage,
std::shared_ptr<FullDeviceContext> ctx) mutable {
if (error) {
// FIXME This is not a correct way to check whether this error was
// "intentionally" caused by the remote end shutting down. We should
@@ -567,7 +623,10 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
threadPool_.run([this,
pipe,
messageId,
requestMessage{std::move(requestMessage)}]() mutable {
requestMessage{std::move(requestMessage)},
ctx{std::move(ctx)}]() mutable {
// Re-create the stream guard, as this lambda runs on a different thread
FullDeviceContextGuard guard(ctx);
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " is running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
@@ -584,16 +643,21 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
if (futureResponseMessage->completed()) {
decreaseCallCount(serverActiveCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
pipe, futureResponseMessage, messageId, std::move(ctx));
} else {
// Not complete yet
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback(
[this, pipe, futureResponseMessage, messageId]() mutable {
[this,
pipe,
futureResponseMessage,
messageId,
ctx{std::move(ctx)}]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
FullDeviceContextGuard guard(ctx);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
pipe, futureResponseMessage, messageId, std::move(ctx));
});
}

@@ -682,10 +746,13 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();

auto ctx = createFullDeviceContext(devices.empty());
ctx->waitForCurrentStreams();
pipeWrite(
clientPipe.pipe_,
std::move(requestMessage),
std::move(devices),
std::move(ctx),
[this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
@@ -712,7 +779,10 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
pipeRead(
clientPipe.pipe_,
[this, &clientPipe](
const tensorpipe::Error& error, Message&& responseMessage) {
const tensorpipe::Error& error,
Message&& responseMessage,
std::shared_ptr<FullDeviceContext> ctx) {
ctx->blockCurrentStreams();
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
!rpcAgentRunning_.load()) {
@@ -1092,6 +1162,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
std::vector<c10::DeviceIndex> deviceIndices;
deviceIndices.reserve(message.tensors().size());
const auto& deviceMap = iter->second;
bool hasCudaTensor = false;
for (const auto& t : message.tensors()) {
if (t.device().is_cpu()) {
deviceIndices.push_back(-1);
Expand All @@ -1104,8 +1175,12 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
t.device(),
" but received a tensor on that device.");
deviceIndices.push_back(deviceIter->second);
hasCudaTensor = true;
}
}
if (!hasCudaTensor) {
deviceIndices.clear();
}
return deviceIndices;
}
}
17 changes: 11 additions & 6 deletions torch/csrc/distributed/rpc/tensorpipe_agent.h
@@ -9,6 +9,7 @@
#include <c10d/PrefixStore.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <torch/csrc/distributed/rpc/macros.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>


@@ -17,10 +18,6 @@

namespace tensorpipe {

#if defined(USE_CUDA) && !defined(__HIP_PLATFORM_HCC__)
#define USE_CUDA_NOT_ROCM
#endif

class CpuBuffer;

#ifdef USE_CUDA_NOT_ROCM
@@ -59,6 +56,8 @@ namespace torch {
namespace distributed {
namespace rpc {

class FullDeviceContext;

using steady_clock_time_point =
std::chrono::time_point<std::chrono::steady_clock>;

@@ -163,6 +162,7 @@ struct AggregatedNetworkData {
uint64_t totalErrors{0};
};


// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
@@ -229,14 +229,18 @@ class TensorPipeAgent : public RpcAgent {
// by client, and read request messages by server.
void pipeRead(
const std::shared_ptr<tensorpipe::Pipe>&,
std::function<void(const tensorpipe::Error&, Message&&)>) noexcept;
std::function<void(
const tensorpipe::Error&,
Message&&,
std::shared_ptr<FullDeviceContext>)>) noexcept;

// TensorPipe write function that could be used to write response
// messages by server, and write request messages by client.
void pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>&,
Message&& message,
std::vector<c10::DeviceIndex>&& devices,
std::shared_ptr<FullDeviceContext> ctx,
std::function<void(const tensorpipe::Error&)>) noexcept;

// Callback of listener accept()
@@ -250,7 +254,8 @@
void sendCompletedResponseMessage(
std::shared_ptr<tensorpipe::Pipe>& pipe,
std::shared_ptr<FutureMessage>& futureResponseMessage,
uint64_t messageId);
uint64_t messageId,
std::shared_ptr<FullDeviceContext> ctx);

// Collects metrics from successful RPC calls
void trackNetworkData(