FP16 support for GPU tensors in all frameworks (#529)
* Initial support for FP16

Bump version to a dev release

Cast vars to fp16 before allreduce to compress gradients

Abstracted the compression algorithm into a class hierarchy and added an algorithm flag to the optimizer and allreduce signatures (see the sketch after the changed-files summary below)

Changed compressor to set the dtype on initialization

Resolved conflicts

Additional conflicts

Formatting

More formats

Updated license

Added fp16 compression for Keras

Added arguments to keras examples

Fixed imports

* Added compression to tf.keras

* Added PyTorch compression API

Added unit tests

Whitespace

* Added C interfaces and types

* Forward declare

* Removed Half from older versions of PyTorch

* Added error for old version of PyTorch

* Removed reference to float16

* Updated examples, added compression to the Keras model load

* Cleaned imports

* Removed dependency on enums

* Updated unit tests

* Test compatibility fix

* Reverted version updates

* Fixed message

* Removed imports

* Added cuda.HalfTensor to all PyTorch tests with CUDA

* Only compare versions once

* Renamed --fp16 in examples to --fp16-allreduce for clarity

* Replaced assignment with set_

* Modified compression algorithms to be stateless with optional context parameters

* Removed optional ctx parameter

* Replaced 0.4.2 with 1.0.0

* Only run GPU tests with HalfTensors if fp16 is supported
tgaddair committed Sep 28, 2018
1 parent 81c92b7 commit b2e6c06
Showing 22 changed files with 438 additions and 71 deletions.
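The compression hierarchy referenced throughout the commit lives in the framework packages (e.g. horovod/tensorflow/compression.py), which are not among the files excerpted below. The following is a minimal sketch of how such a stateless hierarchy could look, written against the public names used in the examples (hvd.Compression.none, hvd.Compression.fp16); the method signatures and internals are assumptions, not the committed code.

```python
# Sketch only: illustrates the "stateless compressor" design described in the
# commit message. Internals are assumptions; only the Compression.none/fp16
# selector names are taken from the diff below.
import tensorflow as tf


class Compressor(object):
    """Base class for stateless gradient compression algorithms."""

    @staticmethod
    def compress(tensor):
        """Shrink a tensor before it is handed to allreduce."""
        raise NotImplementedError

    @staticmethod
    def decompress(tensor, dtype):
        """Restore the allreduced tensor to its original dtype."""
        raise NotImplementedError


class NoneCompressor(Compressor):
    @staticmethod
    def compress(tensor):
        return tensor

    @staticmethod
    def decompress(tensor, dtype):
        return tensor


class FP16Compressor(Compressor):
    @staticmethod
    def compress(tensor):
        # Only floating point tensors are cast down; others pass through untouched.
        return tf.cast(tensor, tf.float16) if tensor.dtype.is_floating else tensor

    @staticmethod
    def decompress(tensor, dtype):
        # Cast the summed fp16 result back to the caller's original dtype.
        return tf.cast(tensor, dtype) if tensor.dtype == tf.float16 else tensor


class Compression(object):
    """Selector passed to DistributedOptimizer, allreduce, and load_model."""
    none = NoneCompressor
    fp16 = FP16Compressor
```

Because the compressors hold no state, a single class object can be shared across all gradient tensors and across the TensorFlow, Keras, and PyTorch wrappers, which is what makes the plain `compression=hvd.Compression.fp16` flag in the examples below possible.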
10 changes: 8 additions & 2 deletions examples/keras_imagenet_resnet50.py
@@ -31,6 +31,8 @@
help='tensorboard log directory')
parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.h5',
help='checkpoint file format')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
@@ -91,11 +93,15 @@
# Set up standard ResNet-50 model.
model = keras.applications.resnet50.ResNet50(weights=None)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast both model and optimizer weights
# to other workers.
if resume_from_epoch > 0 and hvd.rank() == 0:
model = hvd.load_model(args.checkpoint_format.format(epoch=resume_from_epoch))
model = hvd.load_model(args.checkpoint_format.format(epoch=resume_from_epoch),
compression=compression)
else:
# ResNet-50 model that is included with Keras is optimized for inference.
# Add L2 weight decay & adjust BN settings.
@@ -117,7 +123,7 @@
momentum=args.momentum)

# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)
opt = hvd.DistributedOptimizer(opt, compression=compression)

model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=opt,
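Condensing the two hunks above into one snippet (illustrative only; the optimizer settings and checkpoint path are placeholders, not part of the commit):

```python
# Illustrative condensation of examples/keras_imagenet_resnet50.py above.
# The optimizer settings are placeholders; only the compression wiring is the point.
import argparse
import keras
import horovod.keras as hvd

parser = argparse.ArgumentParser()
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
args = parser.parse_args()

hvd.init()

# Horovod: (optional) compression algorithm, chosen once and reused everywhere.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

model = keras.applications.resnet50.ResNet50(weights=None)
opt = keras.optimizers.SGD(lr=0.0125 * hvd.size(), momentum=0.9)

# Gradients are cast to fp16 for the allreduce and cast back before the update.
opt = hvd.DistributedOptimizer(opt, compression=compression)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=opt,
              metrics=['accuracy'])

# When resuming, pass the same compression so the restored optimizer is wrapped
# identically:
#   model = hvd.load_model('./checkpoint-10.h5', compression=compression)
```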
10 changes: 8 additions & 2 deletions examples/pytorch_mnist.py
@@ -25,6 +25,8 @@
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

@@ -92,9 +94,13 @@ def forward(self, x):
optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(),
momentum=args.momentum)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters())
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression)


def train(epoch):
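The PyTorch example follows the same pattern; a minimal sketch (model and hyperparameters are placeholders) assuming a CUDA build of PyTorch and Horovod:

```python
# Minimal PyTorch sketch of the --fp16-allreduce flag; illustrative, not a file in
# this commit. Assumes CUDA is available and horovod.torch was built with NCCL.
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import horovod.torch as hvd

parser = argparse.ArgumentParser()
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
args = parser.parse_args()

hvd.init()
torch.cuda.set_device(hvd.local_rank())

model = nn.Linear(784, 10).cuda()  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.01 * hvd.size(), momentum=0.5)

# Horovod: with fp16 compression each gradient is cast to half precision for the
# allreduce and cast back before optimizer.step().
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)

hvd.broadcast_parameters(model.state_dict(), root_rank=0)
```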
3 changes: 3 additions & 0 deletions horovod/common/mpi_message.cc
@@ -41,6 +41,9 @@ const std::string& MPIDataType_Name(MPIDataType value) {
case HOROVOD_INT64:
static const std::string int64("int64");
return int64;
case HOROVOD_FLOAT16:
static const std::string float16("float16");
return float16;
case HOROVOD_FLOAT32:
static const std::string float32("float32");
return float32;
7 changes: 4 additions & 3 deletions horovod/common/mpi_message.h
@@ -30,9 +30,10 @@ enum MPIDataType {
HOROVOD_INT16 = 3,
HOROVOD_INT32 = 4,
HOROVOD_INT64 = 5,
HOROVOD_FLOAT32 = 6,
HOROVOD_FLOAT64 = 7,
HOROVOD_BOOL = 8
HOROVOD_FLOAT16 = 6,
HOROVOD_FLOAT32 = 7,
HOROVOD_FLOAT64 = 8,
HOROVOD_BOOL = 9
};

const std::string& MPIDataType_Name(MPIDataType value);
25 changes: 21 additions & 4 deletions horovod/common/operations.cc
@@ -182,6 +182,9 @@ struct HorovodGlobalState {
// COMM_WORLD ranks of processes running on this node.
std::vector<int> local_comm_ranks;

// MPI custom data type for float16.
MPI_Datatype mpi_float16_t;

// Private MPI communicator for Horovod to ensure no collisions with other
// threads using MPI.
MPI_Comm mpi_comm;
@@ -520,6 +523,8 @@ MPI_Datatype GetMPIDataType(const std::shared_ptr<Tensor> tensor) {
return MPI_INT32_T;
case HOROVOD_INT64:
return MPI_INT64_T;
case HOROVOD_FLOAT16:
return horovod_global.mpi_float16_t;
case HOROVOD_FLOAT32:
return MPI_FLOAT;
case HOROVOD_FLOAT64:
@@ -539,6 +544,8 @@ ncclDataType_t GetNCCLDataType(const std::shared_ptr<Tensor> tensor) {
return ncclInt32;
case HOROVOD_INT64:
return ncclInt64;
case HOROVOD_FLOAT16:
return ncclFloat16;
case HOROVOD_FLOAT32:
return ncclFloat32;
case HOROVOD_FLOAT64:
@@ -1010,7 +1017,7 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) {
#else
if (horovod_global.hierarchical_allreduce) {
int element_size;
MPI_Type_size(GetMPIDataType(first_entry.tensor), &element_size);
MPI_Type_size(GetMPIDataType(first_entry.tensor), &element_size);

// If cluster is homogeneous and we are using fusion buffer, include
// dummy elements from the buffer (if necessary) to make sure the data
@@ -1110,7 +1117,7 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) {
WAIT_FOR_EVENTS(entries, timeline, event_queue)

// According to https://docs.nvidia.com/cuda/cuda-runtime-api/
// api-sync-behavior.html#api-sync-behavior__memcpy-async,
// api-sync-behavior.html#api-sync-behavior__memcpy-async,
// cudaMemcpyAsync is synchronous with respect to the host, so we
// memcpy (effectively) synchronously to generate an accurate timeline
ACTIVITY_START_ALL(entries, timeline, MEMCPY_IN_HOST_BUFFER)
@@ -1508,6 +1515,11 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
MPI_Comm_rank(cross_comm, &cross_rank);
MPI_Comm_size(cross_comm, &cross_size);

// Create custom MPI float16 data type.
MPI_Datatype mpi_float16_t;
MPI_Type_contiguous(2, MPI_BYTE, &mpi_float16_t);
MPI_Type_commit(&mpi_float16_t);

state.rank = rank;
state.local_rank = local_rank;
state.cross_rank = cross_rank;
@@ -1516,6 +1528,7 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
state.cross_size = cross_size;
state.local_comm = local_comm;
state.cross_comm = cross_comm;
state.mpi_float16_t = mpi_float16_t;
state.mpi_threads_supported = (provided == MPI_THREAD_MULTIPLE);
state.local_comm_ranks = local_comm_ranks;

@@ -1558,8 +1571,8 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
}

// Override Tensor Fusion threshold, if it's set.
auto horovod_fusion_threshold = std::getenv("HOROVOD_FUSION_THRESHOLD");
int64_t proposed_fusion_threshold = (horovod_fusion_threshold != nullptr) ?
auto horovod_fusion_threshold = std::getenv("HOROVOD_FUSION_THRESHOLD");
int64_t proposed_fusion_threshold = (horovod_fusion_threshold != nullptr) ?
std::strtol(horovod_fusion_threshold, nullptr, 10) :
state.tensor_fusion_threshold;

@@ -1924,6 +1937,10 @@ void horovod_shutdown() {
MPI_Comm_free(&horovod_global.cross_comm);
}

if (horovod_global.mpi_float16_t != MPI_DATATYPE_NULL) {
MPI_Type_free(&horovod_global.mpi_float16_t);
}

if (horovod_global.should_finalize) {
#if HAVE_DDL
// ddl_finalize calls MPI_Finalize
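The custom datatype above relies on an IEEE half-precision value occupying exactly two bytes, so MPI can treat fp16 buffers as runs of byte pairs for data movement; the element-wise reduction on GPUs goes through the ncclFloat16 mapping added earlier in this file. A quick check of the size assumption (not part of the commit):

```python
# Not part of the commit: confirms the 2-byte layout that
# MPI_Type_contiguous(2, MPI_BYTE, &mpi_float16_t) depends on.
import numpy as np

assert np.dtype(np.float16).itemsize == 2
print(np.float16(1.0).tobytes())  # b'\x00<' -- 0x3C00, two bytes per element
```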
7 changes: 4 additions & 3 deletions horovod/common/wire/mpi_message.fbs
@@ -24,9 +24,10 @@ enum MPIDataType:byte {
HOROVOD_INT16 = 3,
HOROVOD_INT32 = 4,
HOROVOD_INT64 = 5,
HOROVOD_FLOAT32 = 6,
HOROVOD_FLOAT64 = 7,
HOROVOD_BOOL = 8
HOROVOD_FLOAT16 = 6,
HOROVOD_FLOAT32 = 7,
HOROVOD_FLOAT64 = 8,
HOROVOD_BOOL = 9
}

// An MPIRequest is a message sent from a rank greater than zero to the
8 changes: 5 additions & 3 deletions horovod/common/wire/mpi_message_generated.h
@@ -40,9 +40,10 @@ enum MPIDataType {
MPIDataType_HOROVOD_INT16 = 3,
MPIDataType_HOROVOD_INT32 = 4,
MPIDataType_HOROVOD_INT64 = 5,
MPIDataType_HOROVOD_FLOAT32 = 6,
MPIDataType_HOROVOD_FLOAT64 = 7,
MPIDataType_HOROVOD_BOOL = 8,
MPIDataType_HOROVOD_FLOAT16 = 6,
MPIDataType_HOROVOD_FLOAT32 = 7,
MPIDataType_HOROVOD_FLOAT64 = 8,
MPIDataType_HOROVOD_BOOL = 9,
MPIDataType_MIN = MPIDataType_HOROVOD_UINT8,
MPIDataType_MAX = MPIDataType_HOROVOD_BOOL
};
@@ -55,6 +56,7 @@ inline const char **EnumNamesMPIDataType() {
"HOROVOD_INT16",
"HOROVOD_INT32",
"HOROVOD_INT64",
"HOROVOD_FLOAT16",
"HOROVOD_FLOAT32",
"HOROVOD_FLOAT64",
"HOROVOD_BOOL",
18 changes: 14 additions & 4 deletions horovod/keras/__init__.py
@@ -23,12 +23,15 @@
from horovod.tensorflow import rank
from horovod.tensorflow import local_rank
from horovod.tensorflow import mpi_threads_supported
from horovod.tensorflow import Compression

from horovod.keras import callbacks
from horovod.keras import impl as _impl


def DistributedOptimizer(optimizer, name=None, device_dense='', device_sparse=''):
def DistributedOptimizer(optimizer, name=None,
device_dense='', device_sparse='',
compression=Compression.none):
"""
An optimizer that wraps another keras.optimizers.Optimizer, using an allreduce to
average gradient values before applying gradients to model weights.
@@ -42,8 +45,12 @@ def DistributedOptimizer(optimizer, name=None, device_dense='', device_sparse=''
if Horovod was built with HOROVOD_GPU_ALLREDUCE.
device_sparse: Device to be used for sparse tensors. Uses GPU by default
if Horovod was built with HOROVOD_GPU_ALLGATHER.
compression: Compression algorithm used to reduce the amount of data
sent and received by each worker node. Defaults to not
using compression.
"""
return _impl.create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse)
return _impl.create_distributed_optimizer(keras, optimizer, name,
device_dense, device_sparse, compression)


def broadcast_global_variables(root_rank):
@@ -99,7 +106,7 @@ def broadcast(value, root_rank, name=None):
return _impl.broadcast(K, value, root_rank, name)


def load_model(filepath, custom_optimizers=None, custom_objects=None):
def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none):
"""
Loads a saved Keras model with a Horovod DistributedOptimizer.
@@ -119,6 +126,9 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None):
during loading.
custom_objects: Optional dictionary mapping names (strings) to custom
classes or functions to be considered during deserialization.
compression: Compression algorithm used to reduce the amount of data
sent and received by each worker node. Defaults to not
using compression.
# Returns
A Keras model instance.
@@ -128,5 +138,5 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None):
ValueError: In case of an invalid savefile.
"""
def wrap_optimizer(cls):
return lambda **kwargs: DistributedOptimizer(cls(**kwargs))
return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression)
return _impl.load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects)
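A hypothetical call using the updated load_model signature (the checkpoint path is illustrative):

```python
# Hypothetical usage of hvd.load_model with the new compression argument.
import horovod.keras as hvd

hvd.init()
model = hvd.load_model('./checkpoint-20.h5', compression=hvd.Compression.fp16)
# Any optimizer deserialized from the checkpoint is re-wrapped in a
# DistributedOptimizer that compresses gradients to fp16 for the allreduce.
```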
17 changes: 10 additions & 7 deletions horovod/keras/impl.py
@@ -17,14 +17,15 @@
import tensorflow as tf


def create_distributed_optimizer(keras, optimizer, name=None, device_dense='', device_sparse=''):
def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse, compression):
class _DistributedOptimizer(keras.optimizers.Optimizer):
def __init__(self, name, device_dense, device_sparse, **kwargs):
if name is None:
name = "Distributed%s" % self.__class__.__base__.__name__
self._name = name
self._device_dense = device_dense
self._device_sparse = device_sparse
self._compression = compression
super(self.__class__, self).__init__(**kwargs)

def get_gradients(self, loss, params):
@@ -42,8 +43,10 @@ def get_gradients(self, loss, params):
with tf.name_scope(self._name + "_Allreduce"):
for grad in gradients:
if grad is not None:
avg_grad = hvd.allreduce(grad, device_dense=self._device_dense,
device_sparse=self._device_sparse)
avg_grad = hvd.allreduce(grad,
device_dense=self._device_dense,
device_sparse=self._device_sparse,
compression=self._compression)
averaged_gradients.append(avg_grad)
else:
averaged_gradients.append(None)
@@ -65,22 +68,22 @@ def broadcast_global_variables(backend, root_rank):
return backend.get_session().run(bcast_op)


def allreduce(backend, value, name=None, average=True):
def allreduce(backend, value, name, average):
allreduce_op = hvd.allreduce(tf.constant(value, name=name), average=average)
return backend.get_session().run(allreduce_op)


def allgather(backend, value, name=None):
def allgather(backend, value, name):
allgather_op = hvd.allgather(tf.constant(value, name=name))
return backend.get_session().run(allgather_op)


def broadcast(backend, value, root_rank, name=None):
def broadcast(backend, value, root_rank, name):
bcast_op = hvd.broadcast(tf.constant(value, name=name), root_rank)
return backend.get_session().run(bcast_op)


def load_model(keras, wrap_optimizer, filepath, custom_optimizers=None, custom_objects=None):
def load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects):
horovod_objects = {
subclass.__name__.lower(): wrap_optimizer(subclass)
for subclass in keras.optimizers.Optimizer.__subclasses__()
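get_gradients above hands each gradient to hvd.allreduce together with the chosen compression. The TensorFlow-side allreduce wrapper is not among the excerpted files; conceptually it would do something like the following (a sketch under that assumption, reusing the decompress signature from the earlier sketch, not the committed code):

```python
# Sketch of a compression-aware allreduce wrapper; horovod/tensorflow/__init__.py
# is not shown in this excerpt, so treat the details as assumptions.
import tensorflow as tf
from horovod.tensorflow.mpi_ops import _allreduce, size


def allreduce(tensor, average=True, device_dense='', compression=None):
    with tf.device(device_dense):
        # Cast down before the ring allreduce to halve the bytes on the wire...
        compressed = compression.compress(tensor) if compression else tensor
        summed = _allreduce(compressed)
        # ...then cast back up before averaging and applying the update.
        restored = compression.decompress(summed, tensor.dtype) if compression else summed
        return restored / size() if average else restored
```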
