Add option to automatically handle unsorted variable-length sequences in RNNs (#15225)

Summary:
Fixes #3584.

Motivation: manually sorting sequences, packing them, and then unsorting them
is something many users have complained about having to do, especially when we
can offer library support for it.

Overview: we internally sort sequences before packing them and store
`sorted_indices` and `unsorted_indices` inside the PackedSequence to record
how to undo the sort. The packing helper functions return a PackedSequence
carrying this permutation, and the unpacking helper functions use it to
restore the caller's original ordering, as the sketch below illustrates.
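
A minimal sketch of the intended round trip (the tensor values mirror the new
test_pack_sequence assertions; the printed reprs are illustrative):

```python
import torch
import torch.nn.utils.rnn as rnn_utils

a = torch.tensor([1, 2, 3])  # length 3
b = torch.tensor([4, 5])     # length 2
c = torch.tensor([6])        # length 1

# Sequences may now be passed in any order; packing sorts internally
# and records the permutation on the PackedSequence.
packed = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
print(packed.sorted_indices)    # tensor([2, 0, 1]): how the inputs were sorted
print(packed.unsorted_indices)  # tensor([1, 2, 0]): how to undo the sort

# Unpacking applies unsorted_indices, so batches come back in the
# caller's original [b, c, a] order.
padded, lengths = rnn_utils.pad_packed_sequence(packed)
print(lengths)  # tensor([2, 1, 3])
```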

To implement this, the following changes were made:
- PackedSequence now keeps `sorted_indices` and `unsorted_indices`.
  These two can be thought of as permutations and are inverses of each other:
  `sorted_indices` records how the sequences were sorted; `unsorted_indices`
  records how to unsort them (see the permutation sketch after this list).
- Added an `enforce_sorted` argument to pack_sequence and pack_padded_sequence
  that preserves the legacy behavior of erroring out on unsorted sequences.
  When `enforce_sorted=True`, these functions maintain their ONNX exportability.
- pack_sequence(sequences, enforce_sorted=False) now accepts unsorted sequences.
- pack_padded_sequence can take in a padded tensor that represents padded,
  unsorted sequences.
- pad_packed_sequence unsorts the PackedSequence such that it is still the
  inverse operation of pack_padded_sequence.
- RNNs apply `sorted_indices` to their input hidden state and apply
  `unsorted_indices` to their output hidden state. This ensures that the
  hidden-state batches correspond to the user's ordering of input sequences.
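
A small sketch of the inverse-permutation relationship (`invert_permutation`
here is a hypothetical helper written for illustration, not the library's
internal function):

```python
import torch

def invert_permutation(perm):
    # If perm[i] is the original position of the sequence at sorted slot i,
    # then inv[perm[i]] = i recovers each original sequence's sorted slot.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), dtype=perm.dtype)
    return inv

sorted_indices = torch.tensor([2, 0, 1])  # e.g. from pack_sequence([b, c, a])
unsorted_indices = invert_permutation(sorted_indices)
print(unsorted_indices)                   # tensor([1, 2, 0])

# Applying one permutation and then the other is the identity, which is
# why RNNs can sort hidden states on the way in and unsort on the way out.
batch = torch.tensor([10, 20, 30])
assert torch.equal(batch[sorted_indices][unsorted_indices], batch)
```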

NOT BC-Breaking
- The default for pack_sequence and pack_padded_sequence remains
  `enforce_sorted=True` to avoid breaking ONNX export. To use the new
  functionality, pass `enforce_sorted=False`; a sketch follows below.
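
A sketch of the compatibility contract (a toy LSTM; the shapes and the exact
error text are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

lstm = nn.LSTM(input_size=5, hidden_size=4)
padded = torch.randn(3, 2, 5)  # (seq_len, batch, features)
lengths = [2, 3]               # deliberately not sorted in decreasing order

# Legacy default (enforce_sorted=True): unsorted lengths still raise.
try:
    rnn_utils.pack_padded_sequence(padded, lengths)
except RuntimeError as exc:
    print(exc)  # complains that lengths must be sorted in decreasing order

# Opt-in path: sorting and unsorting are handled internally, and the
# output and hidden-state batches stay in the caller's original order.
packed = rnn_utils.pack_padded_sequence(padded, lengths, enforce_sorted=False)
out, (h, c) = lstm(packed)
unpacked, out_lengths = rnn_utils.pad_packed_sequence(out)
print(out_lengths)  # tensor([2, 3])
```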

Testing Plan
- Modified TestNN.test_pack_sequence, TestNN.test_pack_padded_sequence,
  and TestNN.test_variable_sequence (an RNN test) to check the behavior
  of unsorted sequences, sorted sequences, and sorted sequences with
  enforce_sorted=True
- test/test_jit.py has a test that checks RNNs remain exportable when
  enforce_sorted=True

cc colesbury
Pull Request resolved: #15225

Reviewed By: soumith

Differential Revision: D13507138

Pulled By: zou3519

fbshipit-source-id: b871dccd6abefffca81bc4e3efef1873faa242ef
zou3519 authored and facebook-github-bot committed Dec 21, 2018
1 parent 52699f0 commit 3353064
Showing 3 changed files with 306 additions and 121 deletions.
test/test_nn.py: 277 changes (186 additions, 91 deletions)
@@ -123,31 +123,36 @@ def test_type_casts(self):
"""Test type casting of `PackedSequence` against type casting of tensor"""
for _, (input_type, _) in self._type_by_name.items():
for expected_type_str, (_, cast_str) in self._type_by_name.items():
padded, lengths = self._padded_sequence(input_type)
packed = rnn_utils.pack_padded_sequence(padded, lengths)
# Apply cast to `PackedSequence` instance and unpack
masked = getattr(packed, cast_str)()
unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
self.assertEqual(unpacked.type(), expected_type_str)
for enforce_sorted in [True, False]:
padded, lengths = self._padded_sequence(input_type)
packed = rnn_utils.pack_padded_sequence(
padded, lengths, enforce_sorted=enforce_sorted)
# Apply cast to `PackedSequence` instance and unpack
masked = getattr(packed, cast_str)()
unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
self.assertEqual(unpacked.type(), expected_type_str)

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_cuda_mask(self):
-        tensor_type = torch.FloatTensor
-        cuda_type_str = 'torch.cuda.FloatTensor'
-        padded, lengths = self._padded_sequence(tensor_type)
-        packed = rnn_utils.pack_padded_sequence(padded, lengths)
-        self.assertFalse(packed.is_cuda)
-        packed = packed.cuda()
-        self.assertTrue(packed.is_cuda)
-        unpacked, _ = rnn_utils.pad_packed_sequence(packed)
-        self.assertEqual(unpacked.type(), cuda_type_str)
+        for enforce_sorted in [True, False]:
+            tensor_type = torch.FloatTensor
+            cuda_type_str = 'torch.cuda.FloatTensor'
+            padded, lengths = self._padded_sequence(tensor_type)
+            packed = rnn_utils.pack_padded_sequence(
+                padded, lengths, enforce_sorted=enforce_sorted)
+            self.assertFalse(packed.is_cuda)
+            packed = packed.cuda()
+            self.assertTrue(packed.is_cuda)
+            unpacked, _ = rnn_utils.pad_packed_sequence(packed)
+            self.assertEqual(unpacked.type(), cuda_type_str)

     def test_wrong_order(self):
         # https://github.com/pytorch/pytorch/issues/13324
         a = torch.ones(25, 300)
         b = torch.ones(22, 300)
         b_a = rnn_utils.pad_sequence([b, a])
-        self.assertRaises(RuntimeError, lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25]))
+        self.assertRaises(
+            RuntimeError,
+            lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True))
 
     def test_total_length(self):
         padded, lengths = self._padded_sequence(torch.FloatTensor)
@@ -183,22 +188,24 @@ def err_fn():
             self.assertEqual(unpacked, ref_output)
 
     def test_to(self):
-        padded, lengths = self._padded_sequence(torch.IntTensor)
-        a = rnn_utils.pack_padded_sequence(padded, lengths).cpu()
-
-        self.assertIs(a, a.to('cpu'))
-        self.assertIs(a, a.to('cpu', dtype=torch.int32))
-        self.assertEqual(a.long(), a.to(torch.int64))
-
-        if torch.cuda.is_available():
-            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
-                b = a.cuda(device=cuda)
-                self.assertIs(b, b.to(cuda))
-                self.assertEqual(a, b.to('cpu'))
-                self.assertEqual(b, a.to(cuda))
-                self.assertEqual(a, b.to('cpu', dtype=torch.int32))
-                self.assertIs(b, b.to(dtype=torch.int32))
-                self.assertEqual(b.long(), b.to(dtype=torch.int64))
+        for enforce_sorted in (True, False):
+            padded, lengths = self._padded_sequence(torch.IntTensor)
+            a = rnn_utils.pack_padded_sequence(
+                padded, lengths, enforce_sorted=enforce_sorted).cpu()
+
+            self.assertIs(a, a.to('cpu'))
+            self.assertIs(a, a.to('cpu', dtype=torch.int32))
+            self.assertEqual(a.long(), a.to(torch.int64))
+
+            if torch.cuda.is_available():
+                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
+                    b = a.cuda(device=cuda)
+                    self.assertIs(b, b.to(cuda))
+                    self.assertEqual(a, b.to('cpu'))
+                    self.assertEqual(b, a.to(cuda))
+                    self.assertEqual(a, b.to('cpu', dtype=torch.int32))
+                    self.assertIs(b, b.to(dtype=torch.int32))
+                    self.assertEqual(b.long(), b.to(dtype=torch.int64))


 def default_tensor_type(type):
@@ -4203,22 +4210,41 @@ def pad(tensor, length):
         self.assertEqual(padded, expected.transpose(0, 1))
 
     def test_pack_sequence(self):
-        def _compatibility_test(sequences, lengths, batch_first):
+        def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
             padded = rnn_utils.pad_sequence(sequences, batch_first)
-            packed = rnn_utils.pack_sequence(sequences)
+            packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
             unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
             self.assertEqual(padded, unpacked[0])
-            pack_padded = rnn_utils.pack_padded_sequence(padded, lengths, batch_first)
+            pack_padded = rnn_utils.pack_padded_sequence(
+                padded, lengths, batch_first, enforce_sorted)
             self.assertEqual(packed, pack_padded)
 
         # single dimensional
         a = torch.tensor([1, 2, 3])
         b = torch.tensor([4, 5])
         c = torch.tensor([6])
-        packed = rnn_utils.pack_sequence([a, b, c])
+        packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
         expected = torch.tensor([1, 4, 6, 2, 5, 3])
         self.assertEqual(packed.batch_sizes, [3, 2, 1])
         self.assertEqual(packed.data.data, expected)
+        self.assertEqual(packed.sorted_indices, [0, 1, 2])
+        self.assertEqual(packed.unsorted_indices, [0, 1, 2])
+
+        packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
+        self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
+        self.assertEqual(packed_unsorted.data.data, expected)
+        self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
+        self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])
+
+        # single dimensional, enforce_sorted = True
+        packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
+        self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
+        self.assertEqual(packed_enforce_sorted.data.data, expected)
+        self.assertTrue(packed_enforce_sorted.sorted_indices is None)
+        self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
+
+        with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'):
+            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
 
         # more dimensions
         maxlen = 9
@@ -4230,34 +4256,69 @@ def _compatibility_test(sequences, lengths, batch_first):
             seq_len = i * i
             lengths.append(seq_len)
             sequences.append(torch.rand(seq_len, 5, *trailing_dims))
+        unsorted_sequences = [s.clone() for s in sequences]
+        random.shuffle(unsorted_sequences)
+        unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]
 
         # compatibility with other utilities
         for batch_first in (True, False):
-            _compatibility_test(sequences, lengths, batch_first)
+            for enforce_sorted in (True, False):
+                _compatibility_test(sequences, lengths, batch_first, enforce_sorted)
+            _compatibility_test(unsorted_sequences, unsorted_sequences_lengths,
+                                batch_first)

     def test_pack_padded_sequence(self):
-        def pad(tensor, length):
-            return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()])
-        lengths = [10, 8, 4, 2, 2, 2, 1]
-        max_length = lengths[0]
-        batch_sizes = [sum(map(bool, filter(lambda x: x >= i, lengths))) for i in range(1, max_length + 1)]
-        offset = 0
-        padded = torch.cat([pad(i * 100 + torch.arange(1., 5 * l + 1).view(l, 1, 5), max_length)
-                            for i, l in enumerate(lengths, 1)], 1).requires_grad_()
-        expected_data = [[torch.arange(1., 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
-                         for n, batch_size in enumerate(batch_sizes)]
-        expected_data = list(itertools.chain.from_iterable(expected_data))
-        expected_data = torch.stack(expected_data, dim=0)
+        def generate_test_case(sorted_lengths, should_shuffle):
+            def pad(tensor, length):
+                return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()])
+
+            max_length = sorted_lengths[0]
+            batch_sizes = [sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
+                           for i in range(1, max_length + 1)]
+            offset = 0
+            padded = torch.cat([pad(i * 100 + torch.arange(1., 5 * l + 1).view(l, 1, 5), max_length)
+                                for i, l in enumerate(sorted_lengths, 1)], 1)
+            expected_data = [[torch.arange(1., 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
+                             for n, batch_size in enumerate(batch_sizes)]
+            expected_data = list(itertools.chain.from_iterable(expected_data))
+            expected_data = torch.stack(expected_data, dim=0)
+
+            if should_shuffle:
+                # Shuffle the padded sequence to create an unsorted sequence
+                permutation = list(range(len(sorted_lengths)))
+                random.shuffle(permutation)
+
+                unsorted_indices = torch.tensor(permutation)
+                padded = padded.index_select(1, unsorted_indices)
+                lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
+            else:
+                unsorted_indices = None
+                lengths = sorted_lengths
+
+            return padded.requires_grad_(), lengths, expected_data, batch_sizes, unsorted_indices
+
+        test_cases = [
+            # sorted_lengths, should_shuffle
+            [[10, 8, 4, 2, 2, 2, 1], False],
+            [[11, 10, 8, 6, 4, 3, 1], False],
+            [[11, 10, 8, 6, 4, 3, 1], True],
+        ]
+
+        for test_case, batch_first in itertools.product(test_cases, (True, False)):
+            sorted_lengths, should_shuffle = test_case
+            padded, lengths, expected_data, batch_sizes, unsorted_indices = generate_test_case(
+                sorted_lengths, should_shuffle)
 
-        for batch_first in (True, False):
             src = padded
             if batch_first:
                 src = src.transpose(0, 1)
 
             # check output
-            packed = rnn_utils.pack_padded_sequence(src, lengths, batch_first=batch_first)
+            packed = rnn_utils.pack_padded_sequence(src, lengths, batch_first=batch_first,
+                                                    enforce_sorted=not should_shuffle)
             self.assertEqual(packed.data.data, expected_data)
             self.assertEqual(packed.batch_sizes, batch_sizes)
+            self.assertEqual(packed.unsorted_indices, unsorted_indices)
 
             # test inverse
             unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed, batch_first=batch_first)
@@ -4282,46 +4343,80 @@ def pad(var, length):
                 return var
             return torch.cat([var, var.new_zeros(length - var.size(0), *var.size()[1:])])
 
-        lengths = [10, 10, 6, 2, 2, 1, 1]
-        max_length = lengths[0]
-        x_leaf = torch.randn(max_length, len(lengths), 3, device=device, dtype=dtype, requires_grad=True)
-        lstm = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to(device, dtype)
-        lstm2 = deepcopy(lstm).to(device, dtype)
-        x = x_leaf
-
-        # Compute sequences separately
-        seq_outs = []
-        seq_hiddens = []
-        for i, l in enumerate(lengths):
-            out, hid = lstm2(x[:l, i:i + 1])
-            out_pad = pad(out, max_length)
-            seq_outs.append(out_pad)
-            seq_hiddens.append(hid)
-        seq_out = torch.cat(seq_outs, 1)
-        seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))
-
-        # Use packed format
-        packed = rnn_utils.pack_padded_sequence(x, lengths)
-        packed_out, packed_hidden = lstm(packed)
-        unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
+        def maybe_index_tuple(maybe_tuple_of_tensors, index):
+            if maybe_tuple_of_tensors is None:
+                return None
+            return tuple(maybe_tuple_of_tensors[j][:, index:index + 1, :].contiguous()
+                         for j in range(2))
+
+        def check_lengths(lengths, enforce_sorted, use_default_hiddens):
+            input_size = 3
+            hidden_size = 4
+            num_layers = 2
+            bidirectional = True
+
+            max_length = max(lengths)
+            x_leaf = torch.randn(max_length, len(lengths), input_size, device=device,
+                                 dtype=dtype, requires_grad=True)
+            num_directions = 2 if bidirectional else 1
+            lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional,
+                           num_layers=num_layers).to(device, dtype)
+            lstm2 = deepcopy(lstm).to(device, dtype)
+            x = x_leaf
+
+            hidden0 = None
+            if not use_default_hiddens:
+                hidden0 = tuple(torch.randn(num_directions * num_layers, len(lengths), hidden_size,
+                                            device=device, dtype=dtype)
+                                for _ in range(2))
+
+            # Compute sequences separately
+            seq_outs = []
+            seq_hiddens = []
+            for i, l in enumerate(lengths):
+                hidden_i = maybe_index_tuple(hidden0, i)
+                out, hid = lstm2(x[:l, i:i + 1], hidden_i)
+                out_pad = pad(out, max_length)
+                seq_outs.append(out_pad)
+                seq_hiddens.append(hid)
+            seq_out = torch.cat(seq_outs, 1)
+            seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))
+
+            # Use packed format
+            packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=enforce_sorted)
+            packed_out, packed_hidden = lstm(packed, hidden0)
+            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
 
-        # Check forward
-        self.assertEqual(packed_hidden, seq_hidden)
-        self.assertEqual(unpacked, seq_out)
-        self.assertEqual(unpacked_len, lengths)
+            # Check forward
+            self.assertEqual(packed_hidden, seq_hidden)
+            self.assertEqual(unpacked, seq_out)
+            self.assertEqual(unpacked_len, lengths)
 
-        # Check backward
-        seq_out.sum().backward()
-        grad_x = x_leaf.grad.data.clone()
-        x_leaf.grad.data.zero_()
-        unpacked.sum().backward()
-
-        self.assertEqual(x_leaf.grad, grad_x, dtype2prec[dtype])
-        for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
-            prec = dtype2prec[dtype]
-            if dtype == torch.float16:
-                prec = 2e-2
-            self.assertEqual(p1.grad, p2.grad, prec)
+            # Check backward
+            seq_out.sum().backward()
+            grad_x = x_leaf.grad.data.clone()
+            x_leaf.grad.data.zero_()
+            unpacked.sum().backward()
+
+            self.assertEqual(x_leaf.grad, grad_x, dtype2prec[dtype])
+            for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
+                prec = dtype2prec[dtype]
+                if dtype == torch.float16:
+                    prec = 2e-2
+                self.assertEqual(p1.grad, p2.grad, prec)
+
+        tests = [
+            # enforce_sorted, lengths
+            [True, [5]],
+            [False, [5]],
+            [True, [10, 10, 6, 2, 2, 1, 1]],
+            [False, [10, 10, 6, 2, 2, 1, 1]],
+            [False, [2, 1, 3, 2, 10, 5, 3]],
+        ]
+
+        for enforce_sorted, seq_lens, in tests:
+            for use_default_hiddens in (True, False):
+                check_lengths(seq_lens, enforce_sorted, use_default_hiddens)
 
     def test_variable_sequence(self):
         self._test_variable_sequence()
