Remove unnecessary "T" before "tf.Tensor" in all_reduce.py #36713

Merged 1 commit on Feb 13, 2020.
tensorflow/python/distribute/all_reduce.py (82 changes: 41 additions, 41 deletions)
@@ -32,10 +32,10 @@ def _flatten_tensors(tensors):
"""Check tensors for isomorphism and flatten.

Args:
- tensors: list of T `tf.Tensor` which must all have the same shape.
+ tensors: list of `tf.Tensor` which must all have the same shape.

Returns:
- tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors
+ tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
shape: the original shape of each element of input tensors

Raises:
@@ -61,12 +61,12 @@ def _reshape_tensors(tensors, shape):
"""Reshape tensors flattened by _flatten_tensors.

Args:
- tensors: list of T `tf.Tensor` of identical length 1D tensors.
+ tensors: list of `tf.Tensor` of identical length 1D tensors.
shape: list of integers describing the desired shape. Product of
the elements must equal the length of each tensor.

Returns:
- list of T `tf.Tensor` which are the reshaped inputs.
+ list of `tf.Tensor` which are the reshaped inputs.
"""
reshaped = []
for t in tensors:
@@ -79,13 +79,13 @@ def _padded_split(tensor, pieces):
"""Like split for 1D tensors but pads-out case where len % pieces != 0.

Args:
- tensor: T `tf.Tensor` that must be 1D.
+ tensor: `tf.Tensor` that must be 1D.
pieces: a positive integer specifying the number of pieces into which
tensor should be split.

Returns:
- list of T `tf.Tensor` of length pieces, which hold the values of
- the input tensor, in order. The final tensor may
+ list of `tf.Tensor` of length pieces, which hold the values of
+ the input tensor, in order. The final tensor may
be zero-padded on the end to make its size equal to those of all
of the other tensors.

@@ -132,11 +132,11 @@ def _strip_padding(tensors, pad_len):
"""Strip the suffix padding added by _padded_split.

Args:
- tensors: list of T `tf.Tensor` of identical length 1D tensors.
+ tensors: list of `tf.Tensor` of identical length 1D tensors.
pad_len: number of elements to be stripped from the end of each tensor.

Returns:
- list of T `tf.Tensor` which are the stripped inputs.
+ list of `tf.Tensor` which are the stripped inputs.

Raises:
ValueError: tensors must be a non-empty list of 1D tensors, and
@@ -161,13 +161,13 @@ def _ragged_split(tensor, pieces):
"""Like split for 1D tensors but allows case where len % pieces != 0.

Args:
- tensor: T `tf.Tensor` that must be 1D.
+ tensor: `tf.Tensor` that must be 1D.
pieces: a positive integer specifying the number of pieces into which
tensor should be split.

Returns:
- list of T `tf.Tensor` of length pieces, which hold the values of
- the input tensor, in order. The final tensor may be shorter
+ list of `tf.Tensor` of length pieces, which hold the values of
+ the input tensor, in order. The final tensor may be shorter
than the others, which will all be of equal length.

Raises:
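The three split helpers above differ only in how they treat lengths not divisible by the piece count: _padded_split zero-pads the tail so every piece is equal (reporting the pad length so it can be undone), _strip_padding removes that padding, and _ragged_split simply lets the trailing pieces run short. A NumPy sketch of both behaviors, assuming the semantics stated in the docstrings:

```python
import numpy as np

def padded_split(vec, pieces):
    """Zero-pad a 1-D array to a multiple of `pieces`, then split evenly."""
    pad_len = -len(vec) % pieces
    padded = np.concatenate([vec, np.zeros(pad_len, dtype=vec.dtype)])
    return np.split(padded, pieces), pad_len

vec = np.arange(10)
even_chunks, pad_len = padded_split(vec, 4)  # four chunks of 3; pad_len == 2
ragged_chunks = np.array_split(vec, 4)       # ragged split: sizes 3, 3, 2, 2

# Stripping reverses the padding once chunks are concatenated back together.
restored = np.concatenate(even_chunks)[:len(vec)]
assert (restored == vec).all()
```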
@@ -256,7 +256,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
"""Construct a subgraph performing a ring-style all-reduce of input_tensors.

Args:
- input_tensors: a list of T `tf.Tensor` objects, which must all
+ input_tensors: a list of `tf.Tensor` objects, which must all
have the same shape and type.
num_workers: number of worker tasks spanned by input_tensors.
num_subchunks: number of subchunks each device should process in one tick.
Expand All @@ -272,7 +272,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
size.

Returns:
- a list of T `tf.Tensor` identical sum-reductions of input_tensors.
+ a list of `tf.Tensor` identical sum-reductions of input_tensors.
"""
if len(input_tensors) < 2:
raise ValueError("input_tensors must be length 2 or longer")
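For readers unfamiliar with the algorithm this builder constructs: ring all-reduce runs a reduce-scatter pass, in which each device accumulates one chunk while forwarding the others around the ring, followed by an all-gather pass that circulates the fully reduced chunks. The following is a pure-NumPy simulation of the single-subchunk case, a sketch of the algorithm rather than of this module's graph construction:

```python
import numpy as np

def ring_all_reduce(tensors):
    """Simulate a ring all-reduce (sum) over same-length 1-D arrays."""
    n = len(tensors)
    chunks = [list(np.array_split(t.astype(float), n)) for t in tensors]
    # Reduce-scatter: after n-1 steps device d holds the full sum of
    # chunk (d+1) % n, having accumulated one incoming chunk per step.
    for step in range(n - 1):
        for d in range(n):
            src, c = (d - 1) % n, (d - 1 - step) % n
            chunks[d][c] = chunks[d][c] + chunks[src][c]
    # All-gather: circulate each fully reduced chunk around the ring.
    for step in range(n - 1):
        for d in range(n):
            src, c = (d - 1) % n, (d - step) % n
            chunks[d][c] = chunks[src][c]
    return [np.concatenate(cs) for cs in chunks]

inputs = [np.arange(8.0) + d for d in range(4)]
outputs = ring_all_reduce(inputs)
assert all((o == sum(inputs)).all() for o in outputs)
```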
@@ -299,7 +299,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks,
"""Construct a subgraph for the first (reduction) pass of ring all-reduce.

Args:
- input_tensors: a list of T `tf.Tensor` 1D input tensors of same
+ input_tensors: a list of `tf.Tensor` 1D input tensors of same
shape and type.
devices: array of device name strings
num_subchunks: number of subchunks each device should process in one tick.
@@ -311,7 +311,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks,
ValueError: tensors must all be one dimensional.

Returns:
- list of list of T `tf.Tensor` of (partially) reduced values where
+ list of list of `tf.Tensor` of (partially) reduced values where
exactly num_subchunks chunks at each device are fully reduced.
"""
num_devices = len(input_tensors)
@@ -360,11 +360,11 @@ def _apply_unary_to_chunks(f, chunks_by_dev):
"""Apply a unary op to each tensor in chunks_by_dev, on same device.

Args:
- f: a unary function over T `tf.Tensor`.
- chunks_by_dev: list of lists of T `tf.Tensor`.
+ f: a unary function over `tf.Tensor`.
+ chunks_by_dev: list of lists of `tf.Tensor`.

Returns:
- new list of lists of T `tf.Tensor` with the same structure as
+ new list of lists of `tf.Tensor` with the same structure as
chunks_by_dev containing the derived tensors.
"""
output = []
@@ -381,14 +381,14 @@ def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
Args:
pred_by_s_d: as produced by _ring_permutations
rank_by_s_d: as produced by _ring_permutations
- chunks_by_dev: list of list of T `tf.Tensor` indexed by ints
+ chunks_by_dev: list of list of `tf.Tensor` indexed by ints
(device, chunk)

Raises:
ValueError: chunks_by_dev is not well-formed

Returns:
- list of T `tf.Tensor` which are the fully reduced tensors, one
+ list of `tf.Tensor` which are the fully reduced tensors, one
at each device corresponding to the outer dimension of chunks_by_dev.
"""
num_devices = len(chunks_by_dev)
@@ -448,12 +448,12 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
the future with edge-case specific logic.

Args:
- input_tensors: list of T `tf.Tensor` to be elementwise reduced.
+ input_tensors: list of `tf.Tensor` to be elementwise reduced.
red_op: a binary elementwise reduction Op.
un_op: an optional unary elementwise Op to apply to reduced values.

Returns:
- list of T `tf.Tensor` which are the fully reduced tensors, one
+ list of `tf.Tensor` which are the fully reduced tensors, one
at each device of input_tensors.

Raises:
@@ -481,13 +481,13 @@ def _build_recursive_hd_gather(input_tensors, devices, red_op):
"""Construct the gather phase of recursive halving-doubling all-reduce.

Args:
- input_tensors: list of T `tf.Tensor` to be elementwise reduced.
+ input_tensors: list of `tf.Tensor` to be elementwise reduced.
devices: a list of strings naming the devices hosting input_tensors,
which will also be used to host the (partial) reduction values.
red_op: a binary elementwise reduction Op.

Returns:
- list of T `tf.Tensor` which are the fully reduced tensor shards.
+ list of `tf.Tensor` which are the fully reduced tensor shards.

Raises:
ValueError: num_devices not a power of 2, or tensor len not divisible
@@ -522,12 +522,12 @@ def _build_recursive_hd_scatter(input_tensors, devices):
"""Construct the scatter phase of recursive halving-doublng all-reduce.

Args:
- input_tensors: list of T `tf.Tensor` that are fully-reduced shards.
+ input_tensors: list of `tf.Tensor` that are fully-reduced shards.
devices: a list of strings naming the devices on which the reconstituted
full tensors should be placed.

Returns:
- list of T `tf.Tensor` which are the fully reduced tensors.
+ list of `tf.Tensor` which are the fully reduced tensors.
"""
num_devices = len(devices)
num_hops = int(math.log(num_devices, 2))
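The halving-doubling family above pairs devices at power-of-two distances: the gather phase repeatedly halves the segment each device is responsible for while reducing with its partner, and the scatter phase doubles the segments back by exchanging reduced shards. A NumPy simulation, assuming (per the Raises clauses) a power-of-two device count and a tensor length divisible by it:

```python
import numpy as np

def recursive_hd_all_reduce(tensors):
    """Simulate recursive halving-doubling all-reduce (sum)."""
    n, length = len(tensors), len(tensors[0])
    bufs = [t.astype(float).copy() for t in tensors]
    seg = [(0, length)] * n               # segment each device is reducing
    dist = n // 2
    while dist >= 1:                      # halving: reduce-scatter
        snap = [b.copy() for b in bufs]   # model the simultaneous exchange
        for d in range(n):
            lo, hi = seg[d]
            mid = (lo + hi) // 2
            keep = (lo, mid) if d < (d ^ dist) else (mid, hi)
            a, b = keep
            bufs[d][a:b] += snap[d ^ dist][a:b]
            seg[d] = keep
        dist //= 2
    dist = 1
    while dist < n:                       # doubling: all-gather
        snap = [b.copy() for b in bufs]
        old_seg = list(seg)
        for d in range(n):
            a, b = old_seg[d ^ dist]      # copy partner's reduced segment
            bufs[d][a:b] = snap[d ^ dist][a:b]
            lo, hi = seg[d]
            seg[d] = (min(lo, a), max(hi, b))
        dist *= 2
    return bufs

inputs = [np.arange(8.0) + d for d in range(4)]
outputs = recursive_hd_all_reduce(inputs)
assert all((o == sum(inputs)).all() for o in outputs)
```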
@@ -570,14 +570,14 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
better in the other case.

Args:
- input_tensors: list of T @(tf.Tensor} values to be reduced.
+ input_tensors: list of `tf.Tensor` values to be reduced.
gather_devices: list of names of devices on which reduction shards
should be placed.
red_op: an n-array elementwise reduction Op
un_op: optional elementwise unary Op to be applied to fully-reduced values.

Returns:
- list of T `tf.Tensor` which are the fully reduced tensors.
+ list of `tf.Tensor` which are the fully reduced tensors.
"""
input_tensors, shape = _flatten_tensors(input_tensors)
dst_devices = [t.device for t in input_tensors]
@@ -593,14 +593,14 @@ def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
"""Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

Args:
- input_tensors: list of T @(tf.Tensor} values to be reduced.
+ input_tensors: list of `tf.Tensor` values to be reduced.
gather_devices: list of names of devices on which reduction shards
should be placed.
red_op: the binary reduction Op
un_op: optional elementwise unary Op to be applied to fully-reduced values.

Returns:
- list of T `tf.Tensor` which are the fully reduced shards.
+ list of `tf.Tensor` which are the fully reduced shards.

Raises:
ValueError: inputs not well-formed.
@@ -630,12 +630,12 @@ def _build_shuffle_scatter(reduced_shards, dst_devices):
"""Build the scatter phase of shuffle all-reduce.

Args:
- reduced_shards: list of T @(tf.Tensor} fully reduced shards
+ reduced_shards: list of `tf.Tensor` fully reduced shards
dst_devices: list of names of devices at which the fully-reduced value
should be reconstituted.

Returns:
- list of T `tf.Tensor` scattered tensors.
+ list of `tf.Tensor` scattered tensors.
"""
num_devices = len(dst_devices)
out_tensors = []
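The shuffle pattern built by these three functions is the simplest of the family: shard i of every input is gathered onto gather device i and reduced there, then the reduced shards are scattered back to every destination. A NumPy sketch with device placement omitted (num_shards stands in for gather_devices):

```python
import numpy as np

def shuffle_all_reduce(tensors, num_shards):
    """Simulate shuffle all-reduce (sum): gather shards, reduce, scatter back."""
    shards = [np.array_split(t, num_shards) for t in tensors]
    # Gather phase: shard i from every input meets on "device" i and is reduced.
    reduced = [sum(s[i] for s in shards) for i in range(num_shards)]
    # Scatter phase: every destination reassembles the full reduced tensor.
    full = np.concatenate(reduced)
    return [full.copy() for _ in tensors]

inputs = [np.arange(10.0) + d for d in range(3)]
outputs = shuffle_all_reduce(inputs, num_shards=2)
assert all((o == sum(inputs)).all() for o in outputs)
```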
@@ -650,7 +650,7 @@ def _split_by_task(devices, values):

Args:
devices: list of device name strings
- values: list of T `tf.tensor` of same length as devices.
+ values: list of `tf.Tensor` of same length as devices.

Returns:
(per_task_devices, per_task_values) where both values are
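_split_by_task feeds the hybrid builders below: device names carry a /job:.../task:N prefix, and values are bucketed per task so the intra-task and cross-task reduction phases can be built separately. A hedged sketch of that grouping; the regex and return layout are illustrative, not lifted from the module:

```python
import re
from collections import OrderedDict

def split_by_task(devices, values):
    """Group devices and values by the /job:...task:N portion of each name."""
    groups = OrderedDict()
    for dev, val in zip(devices, values):
        m = re.match(r"(/job:[^/]+(?:/replica:\d+)?/task:\d+)", dev)
        key = m.group(1) if m else dev   # fall back to the whole name
        groups.setdefault(key, ([], []))
        groups[key][0].append(dev)
        groups[key][1].append(val)
    per_task = list(groups.values())
    return [d for d, _ in per_task], [v for _, v in per_task]

devs = ["/job:worker/replica:0/task:0/device:GPU:0",
        "/job:worker/replica:0/task:0/device:GPU:1",
        "/job:worker/replica:0/task:1/device:GPU:0"]
per_task_devices, per_task_values = split_by_task(devs, [1, 2, 3])
# per_task_values == [[1, 2], [3]]
```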
@@ -686,14 +686,14 @@ def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
"""Build a subgraph that does one full all-reduce, using NCCL.

Args:
- input_tensors: list of T `tf.Tensor` of same-shape and type values to
+ input_tensors: list of `tf.Tensor` of same-shape and type values to
be reduced.
- red_op: binary elementwise reduction operator.  Must be one of
+ red_op: binary elementwise reduction operator. Must be one of
{tf.add}
un_op: optional unary elementwise Op to apply to fully-reduced values.

Returns:
- list of T `tf.Tensor` of reduced values.
+ list of `tf.Tensor` of reduced values.

Raises:
ValueError: red_op not supported.
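A hedged usage sketch for this builder. It assumes a machine with at least two NCCL-capable GPUs, TF 1.x-style graph mode, and the internal (unsupported) module path; per the constraint above, red_op must be tf.add:

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.distribute import all_reduce

tf.disable_eager_execution()

with tf.Graph().as_default():
    inputs = []
    for i in range(2):
        with tf.device("/gpu:%d" % i):           # one input tensor per GPU
            inputs.append(tf.constant([1.0, 2.0, 3.0]))
    # Only tf.add is accepted as red_op; anything else raises ValueError.
    summed = all_reduce.build_nccl_all_reduce(inputs, red_op=tf.add)
    with tf.Session() as sess:
        print(sess.run(summed))                   # [2., 4., 6.] on each device
```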
@@ -715,14 +715,14 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
"""Construct a subgraph for NCCL hybrid all-reduce.

Args:
- input_tensors: list of T `tf.Tensor` of same-shape and type values to
+ input_tensors: list of `tf.Tensor` of same-shape and type values to
be reduced.
red_op: binary elementwise reduction operator.
upper_level_f: function for reducing one value per worker, across
workers.

Returns:
- list of T `tf.Tensor` of reduced values.
+ list of `tf.Tensor` of reduced values.

Raises:
ValueError: inputs not well-formed.
@@ -804,15 +804,15 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
"""Construct a subgraph for Shuffle hybrid all-reduce.

Args:
- input_tensors: list of T `tf.Tensor` of same-shape and type values to
+ input_tensors: list of `tf.Tensor` of same-shape and type values to
be reduced.
gather_devices: list of device names on which to host gather shards.
red_op: binary elementwise reduction operator.
upper_level_f: function for reducing one value per worker, across
workers.

Returns:
- list of T `tf.Tensor` of reduced values.
+ list of `tf.Tensor` of reduced values.

Raises:
ValueError: inputs not well-formed.
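Both hybrid builders share one skeleton: reduce within each worker, hand exactly one value per worker to upper_level_f for the cross-worker all-reduce, then broadcast the result back to each worker's devices. A NumPy sketch of that two-level composition; the trivial upper_level_f here could be swapped for the ring_all_reduce simulation sketched earlier:

```python
import numpy as np

def hybrid_all_reduce(per_worker_tensors, upper_level_f):
    """Two-level all-reduce: local reduce, cross-worker reduce, broadcast."""
    local_sums = [sum(ts) for ts in per_worker_tensors]   # within each worker
    global_sums = upper_level_f(local_sums)               # across workers
    return [[g.copy() for _ in ts]                        # broadcast back
            for g, ts in zip(global_sums, per_worker_tensors)]

workers = [[np.arange(8.0), np.arange(8.0) + 1],          # worker 0: two devices
           [np.arange(8.0) + 2, np.arange(8.0) + 3]]      # worker 1: two devices
all_sum = lambda xs: [sum(xs)] * len(xs)                  # stand-in upper level
outputs = hybrid_all_reduce(workers, all_sum)
expected = sum(sum(ts) for ts in workers)
assert all((t == expected).all() for w in outputs for t in w)
```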