Enabling Infiniband support for Gloo data channel with auto IB detection #4795

Merged 5 commits on Jan 24, 2018
5 changes: 5 additions & 0 deletions setup.py
@@ -25,6 +25,7 @@
from tools.setup_helpers.split_types import split_types
from tools.setup_helpers.generate_code import generate_code
from tools.setup_helpers.ninja_builder import NinjaBuilder, ninja_build_ext
from tools.setup_helpers.ib_detect import WITH_IB_DEVICES

DEBUG = check_env_flag('DEBUG')

@@ -138,6 +139,10 @@ def build_libs(libs):
my_env["CUDNN_LIBRARY"] = CUDNN_LIBRARY
my_env["CUDNN_INCLUDE_DIR"] = CUDNN_INCLUDE_DIR

if WITH_DISTRIBUTED and (WITH_IB_DEVICES or
check_env_flag("WITH_GLOO_IBVERBS")):
build_libs_cmd += ['--with-gloo-ibverbs']

if subprocess.call(build_libs_cmd + libs, env=my_env) != 0:
sys.exit(1)
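
In short, the ibverbs flag is passed down to build_libs.sh when distributed support is enabled and either an IB device was auto-detected or the user forces it through the WITH_GLOO_IBVERBS environment variable. The sketch below is illustrative only; the exact set of values accepted by check_env_flag lives in tools/setup_helpers/env.py and is an assumption here, not taken from this diff.

# Sketch only: mirrors the condition added above, with an assumed check_env_flag.
import os

def check_env_flag(name):
    # Assumed behavior of tools.setup_helpers.env.check_env_flag:
    # common truthy spellings enable the flag.
    return os.getenv(name, '').upper() in ['1', 'ON', 'YES', 'TRUE', 'Y']

WITH_DISTRIBUTED = True     # assumed for illustration
WITH_IB_DEVICES = False     # value computed by tools/setup_helpers/ib_detect.py

if WITH_DISTRIBUTED and (WITH_IB_DEVICES or check_env_flag("WITH_GLOO_IBVERBS")):
    print("passing --with-gloo-ibverbs to build_libs.sh")
else:
    print("building the Gloo backend with TCP only")

So, for example, a machine on which ibv_devinfo is not installed can still be built with IB support by setting WITH_GLOO_IBVERBS=1 in the environment before running setup.py.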

76 changes: 76 additions & 0 deletions tools/setup_helpers/ib_detect.py
@@ -0,0 +1,76 @@
import os
import subprocess
import re


WITH_IB_DEVICES = False
IB_DEVINFO_CMD = "ibv_devinfo"


def get_command_path(command):
    """
    Helper function that gets the full path of a given Linux command.
    """
    def executable(command_path):
        return os.path.isfile(command_path) and os.access(command_path, os.X_OK)

    for path in os.environ["PATH"].split(os.pathsep):
        command_path = os.path.join(path, command)
        if executable(command_path):
            return command_path

    return None


def detect_ib_devices():
    """
    Helper function that detects whether there are InfiniBand devices on the
    host, and returns the number of IB devices detected, or None if detection
    fails.
    """
    try:
        full_cmd_path = get_command_path(IB_DEVINFO_CMD)
        if not full_cmd_path:
            return None
        out = subprocess.check_output([full_cmd_path, "--list"])
        # Parse the first line of the output. The output should be either:
        #
        # > ibv_devinfo --list
        # 0 HCAs found:
        #
        # or
        #
        # > ibv_devinfo --list
        # 4 HCAs found:
        #     mlx5_3
        #     mlx5_2
        #     mlx5_1
        #     mlx5_0
        first_line = out.decode().split('\n')[0]
        res = re.findall(r"\d+", first_line)
        if len(res) != 1:
            raise Exception("-- IB_detect: unexpected parsing error while "
                            "trying to find the number of available devices.")
        return int(res[0])

    except Exception as ex:
        # Swallow all exceptions here so that detection never breaks the build
        print("-- IB_detect: encountered an exception: {}".format(str(ex)))
        return None


num_ib_devices = detect_ib_devices()

if num_ib_devices is None:
    print("-- IB_detect: unable to detect IB devices, "
          "compiling with no IB support by default unless overridden "
          "by WITH_GLOO_IBVERBS")

elif num_ib_devices > 0:
    print("-- IB_detect: {} IB devices detected, compiling with IB support."
          .format(num_ib_devices))
    WITH_IB_DEVICES = True

else:
    print("-- IB_detect: no IB device detected, compiling with no IB support "
          "by default unless overridden by WITH_GLOO_IBVERBS")
7 changes: 7 additions & 0 deletions torch/lib/THD/CMakeLists.txt
@@ -67,6 +67,13 @@ ENDIF()

IF(GLOO_FOUND)
ADD_DEFINITIONS(-DWITH_GLOO=1)
MESSAGE(STATUS "Found Gloo, will compile with Gloo distributed backend")
IF(WITH_GLOO_IBVERBS)
MESSAGE(STATUS "Building the gloo backend with both TCP and infiniband support")
ADD_DEFINITIONS(-DWITH_GLOO_IBVERBS=1)
ELSE()
MESSAGE(STATUS "Building the gloo backend with TCP support only")
ENDIF()
ENDIF()

IF(NCCL_FOUND)
41 changes: 35 additions & 6 deletions torch/lib/THD/base/data_channels/DataChannelGloo.cpp
@@ -3,6 +3,10 @@
#include "GlooCache.hpp"
#include "Store.hpp"

#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
#include "gloo/transport/ibverbs/device.h"
#endif

#include "gloo/transport/tcp/device.h"

#include <algorithm>
@@ -69,11 +73,36 @@ DataChannelGloo::DataChannelGloo(InitMethod::Config config)
{
_num_processes = config.world_size;

// Default options listen on this host's name.
// NOTE: when hostname has bad configuration in `/etc/hosts` processes
// will not connect to each other.
::gloo::transport::tcp::attr attr(config.public_address.c_str());
_device = ::gloo::transport::tcp::CreateDevice(attr);
#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS

// This helper function automatically detects the IB device in the system
auto ibDeviceNames = ::gloo::transport::ibverbs::getDeviceNames();

// If there are IB devices, we will use IB
if (!ibDeviceNames.empty()) {
// Currently, gloo only supports a single IB device and will use the first
auto ibDeviceToUse = ibDeviceNames[0];

::gloo::transport::ibverbs::attr attr = {
.name = ibDeviceToUse,
.port = 1,
.index = 0,
};

_deviceList.push_back(::gloo::transport::ibverbs::CreateDevice(attr));

// Otherwise, fall back to TCP instead
} else

#endif

{
// Default options listen on this host's name.
// NOTE: when hostname has bad configuration in `/etc/hosts` processes
// will not connect to each other.
::gloo::transport::tcp::attr attr(config.public_address.c_str());
_deviceList.push_back(::gloo::transport::tcp::CreateDevice(attr));
}

if (_rank == 0) {
_addr = "localhost";
@@ -91,7 +120,7 @@ DataChannelGloo::~DataChannelGloo() {}
void DataChannelGloo::destroy() {}

bool DataChannelGloo::init() {
_cache = std::unique_ptr<GlooCache>(new GlooCache(_rank, _device));
_cache = std::unique_ptr<GlooCache>(new GlooCache(_rank, _deviceList));

std::vector<rank_type> ranks;
ranks.reserve(_num_processes);
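
Nothing changes at the Python API level: when the build enabled ibverbs and an IB device is present, this constructor picks the ibverbs transport, otherwise it falls back to TCP. A minimal two-process sketch that would exercise this data channel (assuming the THD-era torch.distributed API; argument names may differ across releases):

# Sketch: launch one copy with RANK=0 and another with RANK=1. The transport
# (ibverbs vs. TCP) is chosen transparently inside DataChannelGloo.
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:23456',
                        world_size=2,
                        rank=int(os.environ['RANK']))

t = torch.ones(4)
dist.all_reduce(t)  # sums the tensor across both ranks
print(t)            # each rank ends up with a tensor of 2s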
11 changes: 10 additions & 1 deletion torch/lib/THD/base/data_channels/DataChannelGloo.hpp
@@ -103,7 +103,16 @@ struct DataChannelGloo : DataChannel {
std::string _addr;
port_type _port;
rank_type _num_processes; // Number of processes in network
std::shared_ptr<::gloo::transport::Device> _device;
/**
* The list of network devices (such as Infiniband) that will be used by Gloo.
* Currently Gloo only supports a single network device. Therefore:
*
* _deviceList.size() will always be less than or equal to 1.
*
* We make it a vector for the purpose of future extension to support multiple
* network devices.
*/
std::vector<std::shared_ptr<::gloo::transport::Device>> _deviceList;

std::unordered_map<THDGroup, Group> _groups;
int _listen_socket;

75 changes: 46 additions & 29 deletions torch/lib/THD/base/data_channels/GlooCache.hpp
@@ -81,9 +81,10 @@ struct GlooCache {
std::shared_ptr<std::mutex> // mutex to protect same algorithm from running concurrently
>;

GlooCache(rank_type rank, std::shared_ptr<::gloo::transport::Device> device)
GlooCache(rank_type rank,
std::vector<std::shared_ptr<::gloo::transport::Device>> deviceList)
: _rank(rank)
, _device(device)
, _deviceList(deviceList)
{}

GlooCache(GlooCache const&) = delete;
@@ -113,10 +114,25 @@ struct GlooCache {
const DataChannelGloo::Group& group,
const std::string& prefix
) {
/**
 * We currently support only a single InfiniBand interface. In other words,
 * if there are multiple InfiniBand devices in the system, Gloo will detect
 * all of them and use the first device.
 *
 * TODO: this can be extended later to utilize multiple InfiniBand devices.
 *
 * For Ethernet, _deviceList[0] will always hold the default Ethernet device
 * that is detected from the user-provided IP address, and there will never
 * be more than one device in _deviceList.
 *
 * For InfiniBand, _deviceList[0], which is the first IB interface found,
 * will be used by all Gloo operations.
 */
size_t curDevice = 0;
auto context = std::make_shared<context_type>(
group.mustGetGroupRank(_rank), group.size());
prefix_store_type prefix_store(prefix, *group._store);
context->connectFullMesh(prefix_store, _device);
context->connectFullMesh(prefix_store, _deviceList[curDevice]);
return context;
}

@@ -210,7 +226,7 @@ struct GlooCache {
}

rank_type _rank;
std::shared_ptr<::gloo::transport::Device> _device;
std::vector<std::shared_ptr<::gloo::transport::Device>> _deviceList;
std::shared_ptr<store_type> _store;

std::mutex _mutex;
@@ -309,7 +325,7 @@ struct algorithm_spec<CollectiveType::ALL_REDUCE, T> {
}
auto stream = THCState_getCurrentStream(THDGetCudaState());

#if defined(GLOO_USE_IBVERBS) && GLOO_USE_IBVERBS
#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
// Only enable GPU direct if the device supports it
if (context->getDevice()->hasGPUDirect()) {
algo = std::make_shared<::gloo::CudaAllreduceHalvingDoublingPipelined<T,
@@ -375,31 +391,32 @@ struct algorithm_spec<CollectiveType::BROADCAST, T> {
#ifdef WITH_CUDA
} else if (device == DeviceType::CUDA) {
auto stream = THCState_getCurrentStream(THDGetCudaState());

#if defined(GLOO_USE_IBVERBS) && GLOO_USE_IBVERBS
// Only enable GPU direct if the device supports it
if (context->getDevice()->hasGPUDirect()) {
algo = std::make_shared<::gloo::CudaBroadcastOneToAll<T,
::gloo::CudaDeviceWorkspace<T>>>(
context,
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
count,
src_rank,
0,
std::vector<cudaStream_t>{stream});
} else
#endif
{
algo = std::make_shared<::gloo::CudaBroadcastOneToAll<T,
::gloo::CudaHostWorkspace<T>>>(
context,
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
count,
src_rank,
0,
std::vector<cudaStream_t>{stream});
}

#if defined(WITH_GLOO_IBVERBS) && WITH_GLOO_IBVERBS
// Only enable GPU direct if the device supports it
if (context->getDevice()->hasGPUDirect()) {
algo = std::make_shared<::gloo::CudaBroadcastOneToAll<T,
::gloo::CudaDeviceWorkspace<T>>>(
context,
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
count,
src_rank,
0,
std::vector<cudaStream_t>{stream});
} else
#endif
{
algo = std::make_shared<::gloo::CudaBroadcastOneToAll<T,
::gloo::CudaHostWorkspace<T>>>(
context,
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
count,
src_rank,
0,
std::vector<cudaStream_t>{stream});
}
#endif

} else {
throw std::runtime_error("unsupported tensor device in Gloo broadcast");
}
14 changes: 14 additions & 0 deletions torch/lib/build_libs.sh
@@ -21,6 +21,12 @@ if [[ "$1" == "--with-nnpack" ]]; then
shift
fi

WITH_GLOO_IBVERBS=0
if [[ "$1" == "--with-gloo-ibverbs" ]]; then
WITH_GLOO_IBVERBS=1
shift
fi

cd "$(dirname "$0")/../.."
PWD=`printf "%q\n" "$(pwd)"`
BASE_DIR="$PWD"
@@ -47,10 +53,16 @@ else
fi
CPP_FLAGS=" -std=c++11 "
GLOO_FLAGS=""
THD_FLAGS=""
NCCL_ROOT_DIR=${NCCL_ROOT_DIR:-$INSTALL_DIR}
if [[ $WITH_CUDA -eq 1 ]]; then
GLOO_FLAGS="-DUSE_CUDA=1 -DNCCL_ROOT_DIR=$NCCL_ROOT_DIR"
fi
# Gloo infiniband support
if [[ $WITH_GLOO_IBVERBS -eq 1 ]]; then
GLOO_FLAGS+=" -DUSE_IBVERBS=1 -DBUILD_SHARED_LIBS=1"
THD_FLAGS="-DWITH_GLOO_IBVERBS=1"
fi
CWRAP_FILES="\
$BASE_DIR/torch/lib/ATen/Declarations.cwrap;\
$BASE_DIR/torch/lib/THNN/generic/THNN.h;\
@@ -181,6 +193,8 @@ for arg in "$@"; do
build gloo $GLOO_FLAGS
elif [[ "$arg" == "ATen" ]]; then
build_aten
elif [[ "$arg" == "THD" ]]; then
build THD $THD_FLAGS
else
build $arg
fi