From 71bdfe83a5a25cfc6b4b5f93f713c38ae19bcd3c Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Fri, 19 Apr 2019 14:26:04 -0700 Subject: [PATCH 01/19] setup sparse half tensors --- aten/src/ATen/core/Type.h | 2 ++ aten/src/ATen/gen.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index d86877c736a2e..9cb080b32262b 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -63,6 +63,7 @@ enum class TypeID { SparseCPUInt, SparseCPULong, SparseCPUShort, + SparseCPUHalf, SparseCPUQInt8, MkldnnCPUFloat, CUDABool, @@ -83,6 +84,7 @@ enum class TypeID { SparseCUDAInt, SparseCUDALong, SparseCUDAShort, + SparseCUDAHalf, SparseCUDAQInt8, QuantizedCPUQInt8, MSNPUBool, diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index b89dd029c5ee0..7eabe0d904876 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -435,9 +435,6 @@ def iterate_types(): for scalar_type in scalar_types: if density == 'Mkldnn' and (backend != 'CPU' or scalar_type[0] != 'Float'): continue - if density == 'Sparse' and scalar_type[0] == 'Half': - # THS does not do half type yet. - continue if backend in quantized_backends: if density == 'Dense' and scalar_type in quantized_scalar_types: yield (backend, density, scalar_type) From 68a201af758f75c505c8525d8440c67b254cd5fc Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 24 Apr 2019 17:56:42 -0400 Subject: [PATCH 02/19] add test for sparse half cuda embedding --- test/test_nn.py | 49 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 2366806802a23..9d50bc6730463 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2089,6 +2089,42 @@ def test_embedding_sparse_backward(self): self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]])) self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3)) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_embedding_sparse_half_backward(self): + # same as test_embedding_sparse_backward above but testing half types in + # cuda. cpu sum not supported. + embedding = nn.Embedding(10, 3, sparse=True).half() + longTensor = torch.LongTensor([[7, 1, 3]]) + ones = torch.tensor(1.).expand(3, 3).half() + longTwice = torch.LongTensor([[7, 1, 3, 7, 1, 3]]) + onesTwice = torch.tensor(1.).expand(6, 3) + + embedding = embedding.cuda() + longTensor = longTensor.cuda() + ones = ones.cuda() + longTwice = longTwice.cuda() + onesTwice = onesTwice.cuda() + + embedding.zero_grad() + embedding(longTensor).sum().backward() + self.assertEqual(embedding.weight.grad._indices(), longTensor) + self.assertEqual(embedding.weight.grad._values(), ones) + + embedding.zero_grad() + embedding(longTensor).sum().backward() + embedding(longTensor).sum().backward() + self.assertEqual(embedding.weight.grad._indices(), longTwice) + self.assertEqual(embedding.weight.grad._values(), onesTwice) + + embedding.zero_grad() + embedding(longTensor[0]).sum().backward() + longTensor[0, 0] = 8 + embedding(longTensor).sum().backward() + longTwice[0, 3] = 8 + self.assertEqual(embedding.weight.grad._indices(), longTwice) + self.assertEqual(embedding.weight.grad._values(), onesTwice) + def test_embedding_padding_idx(self): embedding = nn.Embedding(10, 20, padding_idx=0) input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]])) @@ -2374,7 +2410,11 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, # We have more floating point error here because we are dealing with larger numbers if backward_prec is None: - needed_prec = dtype2prec[dtype] * 2 + # torch.half is particularly imprecise + if dtype == torch.half: + needed_prec = dtype2prec[dtype] * 3 + else: + needed_prec = dtype2prec[dtype] * 2 else: needed_prec = backward_prec self.assertEqual(es_weight_grad, e.weight.grad, needed_prec) @@ -2730,10 +2770,9 @@ def test_embedding_bag_cuda(self, dtype=torch.float): self._test_EmbeddingBag(True, 'sum', False, dtype) self._test_EmbeddingBag(True, 'mean', False, dtype) self._test_EmbeddingBag(True, 'max', False, dtype) - if dtype != torch.half: - # torch.cuda.sparse.HalfTensor is not enabled. - self._test_EmbeddingBag(True, 'sum', True, dtype) - self._test_EmbeddingBag(True, 'mean', True, dtype) + + self._test_EmbeddingBag(True, 'sum', True, dtype) + self._test_EmbeddingBag(True, 'mean', True, dtype) def test_fractional_max_pool2d(self): x = torch.randn(1, 2, 7, 7, requires_grad=True) From 227c4e99ff1c9a1760a612db0c67c0a38452de0d Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Mon, 29 Apr 2019 06:35:40 -0700 Subject: [PATCH 03/19] fix flake8 error --- test/test_nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index 2b633d3297ad1..90a7dfdade1a3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2089,7 +2089,6 @@ def test_embedding_sparse_backward(self): self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]])) self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3)) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_embedding_sparse_half_backward(self): # same as test_embedding_sparse_backward above but testing half types in From 1201f1d4ec15dc9f60c505f2acd44ffcffff8f41 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Mon, 29 Apr 2019 14:52:33 -0400 Subject: [PATCH 04/19] fix flaky low-precision test and clarify precision issue --- test/test_nn.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 90a7dfdade1a3..ae5c056b24d74 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2409,13 +2409,10 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, # We have more floating point error here because we are dealing with larger numbers if backward_prec is None: - # torch.half is particularly imprecise - if dtype == torch.half: - needed_prec = dtype2prec[dtype] * 3 - else: - needed_prec = dtype2prec[dtype] * 2 + needed_prec = dtype2prec[dtype] * 2 else: needed_prec = backward_prec + self.assertEqual(es_weight_grad, e.weight.grad, needed_prec) if test_per_sample_weights and trainable_per_sample_weights: @@ -2603,12 +2600,13 @@ def test_contig_wrong_stride_cudnn(self): def test_embedding_bag(self): for dtype in [torch.double, torch.float]: - # TODO: figure out why backward on float breaks - test_backward = dtype is not torch.float - self._test_EmbeddingBag(False, 'sum', False, test_backward=test_backward, dtype=dtype) - self._test_EmbeddingBag(False, 'mean', False, test_backward=test_backward, dtype=dtype) - self._test_EmbeddingBag(False, 'max', False, test_backward=test_backward, dtype=dtype) + self._test_EmbeddingBag(False, 'sum', False, dtype=dtype) + self._test_EmbeddingBag(False, 'mean', False, dtype=dtype) + self._test_EmbeddingBag(False, 'max', False, dtype=dtype) + # TODO: figure out why precision on sparse embeddings isn't the + # same as for dense. + test_backward = dtype is not torch.float self._test_EmbeddingBag(False, 'sum', True, test_backward=test_backward, dtype=dtype) self._test_EmbeddingBag(False, 'mean', True, test_backward=test_backward, dtype=dtype) @@ -2773,8 +2771,10 @@ def test_embedding_bag_cuda(self, dtype=torch.float): self._test_EmbeddingBag(True, 'mean', False, dtype) self._test_EmbeddingBag(True, 'max', False, dtype) - self._test_EmbeddingBag(True, 'sum', True, dtype) - self._test_EmbeddingBag(True, 'mean', True, dtype) + # see test_embedding_bag + test_backward = dtype is not torch.float16 + self._test_EmbeddingBag(True, 'sum', True, dtype, test_backward=test_backward) + self._test_EmbeddingBag(True, 'mean', True, dtype, test_backward=test_backward) def test_fractional_max_pool2d(self): x = torch.randn(1, 2, 7, 7, requires_grad=True) From 3900816c912290f33ad5b179f6c2cfe078f99150 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 1 May 2019 09:28:26 -0400 Subject: [PATCH 05/19] simplify tests --- test/test_nn.py | 62 ++++++++++++++++++++----------------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index ae5c056b24d74..bfd892b9b12b2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2071,57 +2071,45 @@ def test_embedding_dense_grad_cuda(self): self._test_embedding_dense_grad("cuda") def test_embedding_sparse_backward(self): - embedding = nn.Embedding(10, 3, sparse=True) - embedding.zero_grad() - embedding(torch.LongTensor([7, 1, 3])).sum().backward() - self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3]])) - self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(3, 3)) - - embedding.zero_grad() - embedding(torch.LongTensor([7, 1, 3])).sum().backward() - embedding(torch.LongTensor([7, 1, 3])).sum().backward() - self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 7, 1, 3]])) - self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3)) - - embedding.zero_grad() - embedding(torch.LongTensor([7, 1, 3])).sum().backward() - embedding(torch.LongTensor([8, 1, 3])).sum().backward() - self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]])) - self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3)) + self._test_embedding_backward(False) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_embedding_sparse_half_backward(self): # same as test_embedding_sparse_backward above but testing half types in - # cuda. cpu sum not supported. - embedding = nn.Embedding(10, 3, sparse=True).half() - longTensor = torch.LongTensor([[7, 1, 3]]) + # cuda. cpu sum not supported for half types. + self._test_embedding_backward(True) + + def _test_embedding_backward(self, half_cuda=True): + embedding = nn.Embedding(10, 3, sparse=True) + tensor = torch.tensor([[7, 1, 3]]) ones = torch.tensor(1.).expand(3, 3).half() - longTwice = torch.LongTensor([[7, 1, 3, 7, 1, 3]]) + tensorTwice = torch.tensor([[7, 1, 3, 7, 1, 3]]) onesTwice = torch.tensor(1.).expand(6, 3) - embedding = embedding.cuda() - longTensor = longTensor.cuda() - ones = ones.cuda() - longTwice = longTwice.cuda() - onesTwice = onesTwice.cuda() + if half_cuda: + embedding = embedding.half().cuda() + tensor = tensor.cuda() + ones = ones.cuda() + tensorTwice = tensorTwice.cuda() + onesTwice = onesTwice.cuda() embedding.zero_grad() - embedding(longTensor).sum().backward() - self.assertEqual(embedding.weight.grad._indices(), longTensor) + embedding(tensor).sum().backward() + self.assertEqual(embedding.weight.grad._indices(), tensor) self.assertEqual(embedding.weight.grad._values(), ones) embedding.zero_grad() - embedding(longTensor).sum().backward() - embedding(longTensor).sum().backward() - self.assertEqual(embedding.weight.grad._indices(), longTwice) + embedding(tensor).sum().backward() + embedding(tensor).sum().backward() + self.assertEqual(embedding.weight.grad._indices(), tensorTwice) self.assertEqual(embedding.weight.grad._values(), onesTwice) embedding.zero_grad() - embedding(longTensor[0]).sum().backward() - longTensor[0, 0] = 8 - embedding(longTensor).sum().backward() - longTwice[0, 3] = 8 - self.assertEqual(embedding.weight.grad._indices(), longTwice) + embedding(tensor[0]).sum().backward() + tensor[0, 0] = 8 + embedding(tensor).sum().backward() + tensorTwice[0, 3] = 8 + self.assertEqual(embedding.weight.grad._indices(), tensorTwice) self.assertEqual(embedding.weight.grad._values(), onesTwice) def test_embedding_padding_idx(self): @@ -2771,7 +2759,7 @@ def test_embedding_bag_cuda(self, dtype=torch.float): self._test_EmbeddingBag(True, 'mean', False, dtype) self._test_EmbeddingBag(True, 'max', False, dtype) - # see test_embedding_bag + # see 'todo' in test_embedding_bag. test_backward = dtype is not torch.float16 self._test_EmbeddingBag(True, 'sum', True, dtype, test_backward=test_backward) self._test_EmbeddingBag(True, 'mean', True, dtype, test_backward=test_backward) From 79787aabe369fa8265cb77d016b80076b140e22f Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 1 May 2019 13:47:31 -0400 Subject: [PATCH 06/19] add suggested test for constructing/moving sparse half embeddings --- test/test_nn.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index bfd892b9b12b2..b37d3e5670088 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2070,6 +2070,21 @@ def test_embedding_dense_grad(self): def test_embedding_dense_grad_cuda(self): self._test_embedding_dense_grad("cuda") + def test_move_sparse_half_embedding(self): + embedding = nn.Embedding(10, 3, sparse=True) + self.assertEqual(embedding.weight.device.type, 'cpu') + self.assertEqual(embedding.weight.dtype, torch.float64) + embedding.half() + self.assertEqual(embedding.weight.dtype, torch.float16) + self.assertEqual(embedding.embedding_dim, 3) + self.assertEqual(embedding.num_embeddings, 10) + + if torch.cuda.is_available(): + embedding.cuda() + self.assertEqual(embedding.weight.device.type, 'cuda') + embedding.cpu() + self.assertEqual(embedding.weight.device.type, 'cpu') + def test_embedding_sparse_backward(self): self._test_embedding_backward(False) @@ -2082,14 +2097,14 @@ def test_embedding_sparse_half_backward(self): def _test_embedding_backward(self, half_cuda=True): embedding = nn.Embedding(10, 3, sparse=True) tensor = torch.tensor([[7, 1, 3]]) - ones = torch.tensor(1.).expand(3, 3).half() + ones = torch.tensor(1.).expand(3, 3) tensorTwice = torch.tensor([[7, 1, 3, 7, 1, 3]]) onesTwice = torch.tensor(1.).expand(6, 3) if half_cuda: embedding = embedding.half().cuda() tensor = tensor.cuda() - ones = ones.cuda() + ones = ones.cuda().half() tensorTwice = tensorTwice.cuda() onesTwice = onesTwice.cuda() From bf286f79462f81b48e98d4038e40a1d932d095d1 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 1 May 2019 14:36:42 -0400 Subject: [PATCH 07/19] cleanup sparse half embedding tests per review --- test/test_nn.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index b37d3e5670088..7ec9e02e1aece 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2074,39 +2074,38 @@ def test_move_sparse_half_embedding(self): embedding = nn.Embedding(10, 3, sparse=True) self.assertEqual(embedding.weight.device.type, 'cpu') self.assertEqual(embedding.weight.dtype, torch.float64) - embedding.half() + embedding.to(torch.float16) self.assertEqual(embedding.weight.dtype, torch.float16) self.assertEqual(embedding.embedding_dim, 3) self.assertEqual(embedding.num_embeddings, 10) if torch.cuda.is_available(): - embedding.cuda() + embedding.to('cuda') self.assertEqual(embedding.weight.device.type, 'cuda') - embedding.cpu() + embedding.to('cpu') self.assertEqual(embedding.weight.device.type, 'cpu') def test_embedding_sparse_backward(self): - self._test_embedding_backward(False) + self._test_embedding_backward() @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_embedding_sparse_half_backward(self): # same as test_embedding_sparse_backward above but testing half types in # cuda. cpu sum not supported for half types. - self._test_embedding_backward(True) + self._test_embedding_backward('cuda', torch.float16) - def _test_embedding_backward(self, half_cuda=True): + def _test_embedding_backward(self, device='cpu', dtype=torch.float64): embedding = nn.Embedding(10, 3, sparse=True) tensor = torch.tensor([[7, 1, 3]]) ones = torch.tensor(1.).expand(3, 3) - tensorTwice = torch.tensor([[7, 1, 3, 7, 1, 3]]) + tensorTwice = tensor.repeat(1, 2) onesTwice = torch.tensor(1.).expand(6, 3) - if half_cuda: - embedding = embedding.half().cuda() - tensor = tensor.cuda() - ones = ones.cuda().half() - tensorTwice = tensorTwice.cuda() - onesTwice = onesTwice.cuda() + embedding = embedding.to(dtype=dtype).to(device) + tensor = tensor.to(device) + ones = ones.to(device) + tensorTwice = tensorTwice.to(device) + onesTwice = torch.cat((ones, ones)) embedding.zero_grad() embedding(tensor).sum().backward() @@ -2122,7 +2121,7 @@ def _test_embedding_backward(self, half_cuda=True): embedding.zero_grad() embedding(tensor[0]).sum().backward() tensor[0, 0] = 8 - embedding(tensor).sum().backward() + embedding(tensor[0]).sum().backward() tensorTwice[0, 3] = 8 self.assertEqual(embedding.weight.grad._indices(), tensorTwice) self.assertEqual(embedding.weight.grad._values(), onesTwice) From 6cfebce30cf7141cd95b83550ec6e4f3a5af74b3 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 1 May 2019 14:40:34 -0400 Subject: [PATCH 08/19] fix typo in test --- test/test_nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 7ec9e02e1aece..beb9fa3eef13c 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2099,13 +2099,13 @@ def _test_embedding_backward(self, device='cpu', dtype=torch.float64): tensor = torch.tensor([[7, 1, 3]]) ones = torch.tensor(1.).expand(3, 3) tensorTwice = tensor.repeat(1, 2) - onesTwice = torch.tensor(1.).expand(6, 3) + onesTwice = torch.cat((ones, ones)) embedding = embedding.to(dtype=dtype).to(device) tensor = tensor.to(device) ones = ones.to(device) tensorTwice = tensorTwice.to(device) - onesTwice = torch.cat((ones, ones)) + onesTwice = onesTwice.to(device) embedding.zero_grad() embedding(tensor).sum().backward() From f7a10b117912cbdc89af3f77c70a612a178e17c8 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 1 May 2019 14:44:50 -0400 Subject: [PATCH 09/19] update embedding inputs to consistently use one dimension --- test/test_nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index beb9fa3eef13c..ae9ab0ac01467 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2108,13 +2108,13 @@ def _test_embedding_backward(self, device='cpu', dtype=torch.float64): onesTwice = onesTwice.to(device) embedding.zero_grad() - embedding(tensor).sum().backward() + embedding(tensor[0]).sum().backward() self.assertEqual(embedding.weight.grad._indices(), tensor) self.assertEqual(embedding.weight.grad._values(), ones) embedding.zero_grad() - embedding(tensor).sum().backward() - embedding(tensor).sum().backward() + embedding(tensor[0]).sum().backward() + embedding(tensor[0]).sum().backward() self.assertEqual(embedding.weight.grad._indices(), tensorTwice) self.assertEqual(embedding.weight.grad._values(), onesTwice) From 29a72ea101cc8597d50b150345631883fe8d9b0b Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 1 May 2019 17:46:12 -0400 Subject: [PATCH 10/19] add tests and support for torch.sparse.HalfTensor constructor --- test/test_sparse.py | 73 +++++++++++++++++++------------ torch/csrc/utils/tensor_types.cpp | 2 +- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/test/test_sparse.py b/test/test_sparse.py index 764f0a38c5520..7c02c58ba57d7 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -563,6 +563,12 @@ def test_Sparse_to_Sparse_copy_(self): # test type conversion (when x1.copy_(x2), x1.dtype should stay the same) x1 = x1.to(torch.float32) + + x2 = x2.to(torch.float16) + x1_dtype = x1.dtype + x1.copy_(x2) + self.assertEqual(x1_dtype, x1.dtype) + x2 = x2.to(torch.float64) x1_dtype = x1.dtype x1.copy_(x2) @@ -630,6 +636,9 @@ def test_tensor(x): x = torch.sparse.FloatTensor(2, 3, 4) test_tensor(x) + x = torch.sparse.HalfTensor(2, 3, 4) + test_tensor(x) + x = torch.sparse.FloatTensor(2, 3, 4, 0) test_tensor(x) @@ -1512,33 +1521,33 @@ def test_factory(self): for use_tensor_idx in [True, False]: for use_tensor_val in [True, False]: for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]): - # have to include size with cuda sparse tensors - include_size = include_size or use_cuda - dtype = torch.float64 - long_dtype = torch.int64 - device = torch.device('cpu') if not use_cuda else \ - torch.device(torch.cuda.device_count() - 1) - indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2]) - if test_empty_tensor: - values = self.value_empty(1, 0) - else: - if use_tensor_val: - values = torch.tensor([1.], dtype=dtype) + for dtype in [torch.float64, torch.float16]: + # have to include size with cuda sparse tensors + include_size = include_size or use_cuda + long_dtype = torch.int64 + device = torch.device('cpu') if not use_cuda else \ + torch.device(torch.cuda.device_count() - 1) + indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2]) + if test_empty_tensor: + values = self.value_empty(1, 0) + else: + if use_tensor_val: + values = torch.tensor([1.], dtype=dtype) + else: + values = 1. + if include_size: + sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype, + device=device, requires_grad=True) else: - values = 1. - if include_size: - sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype, - device=device, requires_grad=True) - else: - sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype, - device=device, requires_grad=True) - self.assertEqual(indices, sparse_tensor._indices()) - self.assertEqual(values, sparse_tensor._values()) - self.assertEqual(size if include_size else default_size, sparse_tensor.size()) - self.assertEqual(dtype, sparse_tensor.dtype) - if use_cuda: - self.assertEqual(device, sparse_tensor._values().device) - self.assertEqual(True, sparse_tensor.requires_grad) + sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype, + device=device, requires_grad=True) + self.assertEqual(indices, sparse_tensor._indices()) + self.assertEqual(values, sparse_tensor._values()) + self.assertEqual(size if include_size else default_size, sparse_tensor.size()) + self.assertEqual(dtype, sparse_tensor.dtype) + if use_cuda: + self.assertEqual(device, sparse_tensor._values().device) + self.assertEqual(True, sparse_tensor.requires_grad) def test_factory_size_check(self): indices = self.index_tensor([[1, 2], @@ -1653,6 +1662,8 @@ def test_factory_dense_dim(self): @cpu_only def test_factory_type_inference(self): + t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=torch.float16)) + self.assertEqual(torch.float16, t.dtype) t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=torch.float32)) self.assertEqual(torch.float32, t.dtype) t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=torch.float64)) @@ -1660,6 +1671,8 @@ def test_factory_type_inference(self): t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1])) self.assertEqual(torch.int64, t.dtype) + t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.HalfTensor(1, 0)) + self.assertEqual(torch.float16, t.dtype) t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.FloatTensor(1, 0)) self.assertEqual(torch.float32, t.dtype) t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.DoubleTensor(1, 0)) @@ -1713,6 +1726,10 @@ def test_tensor(indices, values, indices_equal, values_equal): values = torch.tensor([1.], dtype=torch.float32) test_tensor(indices, values, True, False) + indices = torch.tensor(([0], [2]), dtype=torch.int64) + values = torch.tensor([1.], dtype=torch.float16) + test_tensor(indices, values, True, False) + indices = torch.tensor(([0], [2]), dtype=torch.int64) values = torch.FloatTensor(1, 0) test_tensor(indices, values, True, True) # An empty tensor's data_ptr is always equal to 0 @@ -1766,14 +1783,14 @@ def test_constructor_device_legacy(self): @cpu_only # not really, but we only really want to run this once def test_dtypes(self): - all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16] + all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes()] do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu')) if torch.cuda.is_available(): do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0')) @cpu_only # not really, but we only really want to run this once def test_empty_full(self): - all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16] + all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes()] do_test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu')) if torch.cuda.device_count() > 0: do_test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, None) diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 2b83d57d924a1..989fd9dfae448 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -86,7 +86,7 @@ std::vector> all_declared_types() { for (auto& backend : backends) { for (auto& scalar_type : scalar_types) { // there are no sparse half or bool types. - if ((scalar_type == ScalarType::Half || scalar_type == ScalarType::Bool) && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) { + if (scalar_type == ScalarType::Bool && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) { continue; } ret.emplace_back(std::make_pair(backend, scalar_type)); From 2d1be262b376c7b089c114382cb06f8988a93ae8 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Mon, 6 May 2019 12:25:21 -0400 Subject: [PATCH 11/19] add half ops needed for to_sparse/to_dense --- aten/src/ATen/Declarations.cwrap | 1 + aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 25 ++++--- aten/src/ATen/native/native_functions.yaml | 4 + aten/src/ATen/native/sparse/SparseTensor.cpp | 2 +- .../ATen/native/sparse/SparseTensorMath.cpp | 2 +- aten/src/TH/THBlas.cpp | 3 + aten/src/TH/THBlas.h | 3 + aten/src/TH/THBlasUtils.h | 6 +- aten/src/TH/THGenerateFloatType.h | 6 ++ aten/src/TH/THGenerateHalfType.h | 2 +- aten/src/TH/THTensor.h | 3 + aten/src/TH/THTensorEvenMoreMath.cpp | 3 + aten/src/TH/generic/THTensorEvenMoreMath.cpp | 6 +- test/test_sparse.py | 73 ++++++++++--------- torch/csrc/utils/tensor_types.cpp | 2 +- 15 files changed, 90 insertions(+), 51 deletions(-) diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index bb8cfa652a33c..f5a3176733e02 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -111,6 +111,7 @@ [[ name: _th_nonzero cname: nonzero + cpu_half: True cpu_bool: True cuda_bool: True variants: diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 431a6dc88a985..6fdf9372a0b2c 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -14,15 +14,22 @@ namespace { using namespace vec256; void add_kernel(TensorIterator& iter, Scalar alpha_scalar) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&]() { - auto alpha = alpha_scalar.to(); - auto alpha_vec = Vec256(alpha); - binary_kernel_vec(iter, - [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; }, - [=](Vec256 a, Vec256 b) { - return vec256::fmadd(b, alpha_vec, a); - }); - }); + if( iter.dtype() == ScalarType::Half ) { + auto alpha = alpha_scalar.to(); + binary_kernel(iter, [&](Half a, Half b) -> Half { + return a + (alpha * b); + }); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&]() { + auto alpha = alpha_scalar.to(); + auto alpha_vec = Vec256(alpha); + binary_kernel_vec(iter, + [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; }, + [=](Vec256 a, Vec256 b) { + return vec256::fmadd(b, alpha_vec, a); + }); + }); + } } void sub_kernel(TensorIterator& iter, Scalar alpha_scalar) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 27fc8921752c5..95dd4c5e512e0 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1656,6 +1656,7 @@ SparseCUDA: add_out_sparse_cuda - func: _sparse_dense_add(Tensor self, SparseTensorRef other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + cpu_half: True dispatch: CPU: add_out_dense_sparse_cpu CUDA: add_out_dense_sparse_cuda @@ -2386,6 +2387,7 @@ - func: coalesce(Tensor self) -> Tensor + cpu_half: True variants: method dispatch: SparseCPU: coalesce_sparse_cpu @@ -2473,12 +2475,14 @@ variants: function, method - func: to_sparse(Tensor self, int sparse_dim) -> Tensor + cpu_half: True variants: method dispatch: CPU: dense_to_sparse CUDA: dense_to_sparse - func: to_sparse(Tensor self) -> Tensor + cpu_half: True variants: method dispatch: CPU: dense_to_sparse diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index a8cf6c97fca1c..e84216f2bf536 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -376,7 +376,7 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) { auto indicesBufferAccessor = indicesBuffer.accessor(); int64_t i = -1; - AT_DISPATCH_ALL_TYPES( + AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, values.scalar_type(), "coalesce", [&] { int64_t prev = -1; int64_t blockSize = values.stride(0); diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index e7fc355d969a5..9a2e1dc005352 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -364,7 +364,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, SparseTensorRef dstBuffer.add_(srcBuffer, value); } } else { - AT_DISPATCH_ALL_TYPES( + AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, values.scalar_type(), "add_dense_sparse", [&] { add_dense_sparse_worker_cpu(r, value, sparse, indices, values); }); diff --git a/aten/src/TH/THBlas.cpp b/aten/src/TH/THBlas.cpp index a98097db9d005..ba695c352240a 100644 --- a/aten/src/TH/THBlas.cpp +++ b/aten/src/TH/THBlas.cpp @@ -2,3 +2,6 @@ #include #include + +#include +#include diff --git a/aten/src/TH/THBlas.h b/aten/src/TH/THBlas.h index 3911861f9005c..a62e186cf0f4f 100644 --- a/aten/src/TH/THBlas.h +++ b/aten/src/TH/THBlas.h @@ -8,4 +8,7 @@ #include #include +#include +#include + #endif diff --git a/aten/src/TH/THBlasUtils.h b/aten/src/TH/THBlasUtils.h index ffc39822fe9cb..9d61e5feab29a 100644 --- a/aten/src/TH/THBlasUtils.h +++ b/aten/src/TH/THBlasUtils.h @@ -16,7 +16,7 @@ inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy); TH ## name ## Blas_axpy(n, a, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(AXPY_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(AXPY_SPECIALIZATION) template @@ -29,7 +29,7 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy); TH ## name ## Blas_copy(n, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(COPY_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(COPY_SPECIALIZATION) template inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); @@ -40,4 +40,4 @@ inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); return TH ## name ## Blas_dot(n, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(DOT_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(DOT_SPECIALIZATION) diff --git a/aten/src/TH/THGenerateFloatType.h b/aten/src/TH/THGenerateFloatType.h index c4b97b52362cb..0a12007cb1984 100644 --- a/aten/src/TH/THGenerateFloatType.h +++ b/aten/src/TH/THGenerateFloatType.h @@ -2,6 +2,10 @@ #error "You must define TH_GENERIC_FILE before including THGenerateFloatType.h" #endif +// no BLAS support for half types +#pragma push_macro("USE_BLAS") +#undef USE_BLAS + #define scalar_t float #define accreal double #define TH_CONVERT_REAL_TO_ACCREAL(_val) (accreal)(_val) @@ -19,6 +23,8 @@ #undef TH_CONVERT_REAL_TO_ACCREAL #undef TH_CONVERT_ACCREAL_TO_REAL +#pragma pop_macro("USE_BLAS") + #ifndef THGenerateManyTypes #undef TH_GENERIC_FILE #endif diff --git a/aten/src/TH/THGenerateHalfType.h b/aten/src/TH/THGenerateHalfType.h index b075c683e009a..16e7d8fdb5a59 100644 --- a/aten/src/TH/THGenerateHalfType.h +++ b/aten/src/TH/THGenerateHalfType.h @@ -3,7 +3,7 @@ #endif #include -#define scalar_t THHalf +#define scalar_t at::Half #define accreal float #define TH_CONVERT_REAL_TO_ACCREAL(_val) (accreal)(_val) #define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val) diff --git a/aten/src/TH/THTensor.h b/aten/src/TH/THTensor.h index c73415dc08160..61c05160b1901 100644 --- a/aten/src/TH/THTensor.h +++ b/aten/src/TH/THTensor.h @@ -31,6 +31,9 @@ #include #include +#include +#include + /* fill and zero*/ #include #include diff --git a/aten/src/TH/THTensorEvenMoreMath.cpp b/aten/src/TH/THTensorEvenMoreMath.cpp index a0b9e190998d8..432deb26828d7 100644 --- a/aten/src/TH/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/THTensorEvenMoreMath.cpp @@ -8,3 +8,6 @@ #include #include + +#include +#include diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index af1096272f92e..8c39e8447a73e 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -11,7 +11,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) int64_t *subscript_data; int64_t i = 0; #ifdef TH_REAL_IS_HALF -#define IS_NONZERO(val) ((val.x & 0x7fff) != 0) +#define IS_NONZERO(val) (c10::Half(0)!=val) #else #define IS_NONZERO(val) ((val)!=0) #endif @@ -65,9 +65,11 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) ); delete [] sizes; delete [] idx; + +#undef IS_NONZERO } -#if !defined(TH_REAL_IS_BOOL) /* non bool only part */ +#if !defined(TH_REAL_IS_BOOL) && !defined(TH_REAL_IS_HALF) /* non bool or half only part */ void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value) { diff --git a/test/test_sparse.py b/test/test_sparse.py index 7c02c58ba57d7..215b937095f7c 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -234,33 +234,35 @@ def fn(x): [0, 0, 0, 3], [0, 0, 1, 4], ]) - v = self.value_tensor([2, 1, 3, 4]) - x = self.sparse_tensor(i, v, torch.Size([3, 4, 5])) - res = self.value_tensor([ - [[2, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]], - [[1, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]], - [[0, 3, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 4]], - ]) - test_tensor(x, res) - - i = self.index_tensor([ - [0, 1, 2, 2], - [0, 0, 0, 3], - [0, 0, 1, 4], - ]) - v = self.value_empty(4, 0) - x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0])) - res = self.value_empty(3, 4, 5, 0) - test_tensor(x, res) + for dtype in [torch.float16, torch.int, torch.float64]: + v = self.value_tensor([2, 1, 3, 4]).to(dtype=dtype) + x = self.sparse_tensor(i, v, torch.Size([3, 4, 5])) + res = self.value_tensor([ + [[2, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[1, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 3, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 4]], + ]).to(dtype=dtype) + + test_tensor(x, res) + + i = self.index_tensor([ + [0, 1, 2, 2], + [0, 0, 0, 3], + [0, 0, 1, 4], + ]) + v = self.value_empty(4, 0).to(dtype=dtype) + x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0])) + res = self.value_empty(3, 4, 5, 0).to(dtype=dtype) + test_tensor(x, res) def test_to_sparse(self): shape = [10, 5, 19, 8] @@ -269,12 +271,14 @@ def test_to_sparse(self): max_nnz *= dim_sz rnnz = torch.randint(2, max_nnz, (1,)).item() for nnz in [0, 1, rnnz]: - expected, _, _ = self._gen_sparse(dim, nnz, shape) - d = expected.to_dense() - result = d.to_sparse(dim) - self.assertEqual(d, result.to_dense()) # == not implemented for sparse tensors yet - self.assertEqual(expected.size(), result.size()) - self.assertEqual(dim, result.sparse_dim()) + for dtype in [torch.float16, torch.float64, torch.int]: + expected, _, _ = self._gen_sparse(dim, nnz, shape) + expected = expected.to(dtype) + d = expected.to_dense() + result = d.to_sparse(dim) + self.assertEqual(d, result.to_dense()) # == not implemented for sparse tensors yet + self.assertEqual(expected.size(), result.size()) + self.assertEqual(dim, result.sparse_dim()) sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3]) self.assertRaises(RuntimeError, lambda: sp.to_sparse()) @@ -639,6 +643,9 @@ def test_tensor(x): x = torch.sparse.HalfTensor(2, 3, 4) test_tensor(x) + x = torch.cuda.sparse.HalfTensor(2, 3, 4) + test_tensor(x) + x = torch.sparse.FloatTensor(2, 3, 4, 0) test_tensor(x) diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 989fd9dfae448..122fad160b643 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -85,7 +85,7 @@ std::vector> all_declared_types() { ScalarType::Int, ScalarType::Long, ScalarType::Short, ScalarType::Half, ScalarType::Bool}; for (auto& backend : backends) { for (auto& scalar_type : scalar_types) { - // there are no sparse half or bool types. + // there is no sparse bool type. if (scalar_type == ScalarType::Bool && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) { continue; } From 5dc8cc2df4ea9abc8e13f05ec6f236e229cea093 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Mon, 6 May 2019 14:59:14 -0400 Subject: [PATCH 12/19] add_ for cpu half needed by TestUncoalescedSparse --- aten/src/ATen/native/native_functions.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ba1790b2a5fde..10fa6f81fdd90 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -135,6 +135,7 @@ MkldnnCPU: mkldnn_add - func: add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + cpu_half: True variants: method dispatch: CPU: add_ From 22babf1cbf4a6b0f2d368c42141d1a1c9407b9e0 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Tue, 7 May 2019 13:15:40 -0400 Subject: [PATCH 13/19] Reverting to_dense and add support on half/cpu to_dense requires add_. add is much slower than float for half types on CPU. --- aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 25 +++++++------------- aten/src/ATen/native/native_functions.yaml | 2 -- aten/src/ATen/native/sparse/SparseTensor.cpp | 3 +++ test/test_sparse.py | 16 ++++++++++--- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 6fdf9372a0b2c..431a6dc88a985 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -14,22 +14,15 @@ namespace { using namespace vec256; void add_kernel(TensorIterator& iter, Scalar alpha_scalar) { - if( iter.dtype() == ScalarType::Half ) { - auto alpha = alpha_scalar.to(); - binary_kernel(iter, [&](Half a, Half b) -> Half { - return a + (alpha * b); - }); - } else { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&]() { - auto alpha = alpha_scalar.to(); - auto alpha_vec = Vec256(alpha); - binary_kernel_vec(iter, - [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; }, - [=](Vec256 a, Vec256 b) { - return vec256::fmadd(b, alpha_vec, a); - }); - }); - } + AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&]() { + auto alpha = alpha_scalar.to(); + auto alpha_vec = Vec256(alpha); + binary_kernel_vec(iter, + [=](scalar_t a, scalar_t b) -> scalar_t { return a + alpha * b; }, + [=](Vec256 a, Vec256 b) { + return vec256::fmadd(b, alpha_vec, a); + }); + }); } void sub_kernel(TensorIterator& iter, Scalar alpha_scalar) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 10fa6f81fdd90..7bf02008b988e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -135,7 +135,6 @@ MkldnnCPU: mkldnn_add - func: add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) - cpu_half: True variants: method dispatch: CPU: add_ @@ -2438,7 +2437,6 @@ - func: coalesce(Tensor self) -> Tensor - cpu_half: True variants: method dispatch: SparseCPU: coalesce_sparse_cpu diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index e84216f2bf536..531e9b1b7cdec 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -324,6 +324,9 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){ // NB: Dropped the resizeNd variants Tensor sparse_to_dense(const SparseTensor& self) { + if(self.scalar_type() == ScalarType::Half && !self.is_cuda()) { + AT_ERROR("to_dense() not supported for float16 on CPU"); + } Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided)); return dst.add_(self); } diff --git a/test/test_sparse.py b/test/test_sparse.py index 215b937095f7c..328918ff60afa 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -234,7 +234,9 @@ def fn(x): [0, 0, 0, 3], [0, 0, 1, 4], ]) - for dtype in [torch.float16, torch.int, torch.float64]: + # we don't have to_dense for half types on CPU because it is implemented + # with a slower add_ operation + for dtype in [torch.float16, torch.float64] if self.device == 'cuda' else [torch.float64]: v = self.value_tensor([2, 1, 3, 4]).to(dtype=dtype) x = self.sparse_tensor(i, v, torch.Size([3, 4, 5])) res = self.value_tensor([ @@ -264,6 +266,13 @@ def fn(x): res = self.value_empty(3, 4, 5, 0).to(dtype=dtype) test_tensor(x, res) + # half tesnors on cpu don't implement to_dense, so need to convert to float + def _half_safe_to_dense(self, tensor): + if(tensor.dtype == torch.half and tensor.device.type == 'cpu'): + return tensor.to(torch.float).to_dense().to(torch.half) + else: + return tensor.to_dense() + def test_to_sparse(self): shape = [10, 5, 19, 8] max_nnz = 1 @@ -274,9 +283,10 @@ def test_to_sparse(self): for dtype in [torch.float16, torch.float64, torch.int]: expected, _, _ = self._gen_sparse(dim, nnz, shape) expected = expected.to(dtype) - d = expected.to_dense() + + d = self._half_safe_to_dense(expected) result = d.to_sparse(dim) - self.assertEqual(d, result.to_dense()) # == not implemented for sparse tensors yet + self.assertEqual(d, self._half_safe_to_dense(result)) # == not implemented for sparse tensors yet self.assertEqual(expected.size(), result.size()) self.assertEqual(dim, result.sparse_dim()) From b0e7f08bcceb3c310f7f2cef6b321b5d0e3b8762 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 8 May 2019 10:45:22 -0400 Subject: [PATCH 14/19] remove more of unneeded half-cpu code --- aten/src/ATen/native/native_functions.yaml | 1 - aten/src/ATen/native/sparse/SparseTensor.cpp | 4 ++-- aten/src/ATen/native/sparse/SparseTensorMath.cpp | 2 +- aten/src/TH/THBlas.cpp | 3 --- aten/src/TH/THBlas.h | 3 --- aten/src/TH/THBlasUtils.h | 6 +++--- aten/src/TH/generic/THTensorEvenMoreMath.cpp | 7 +++++-- aten/src/TH/generic/THTensorMath.h | 3 +++ 8 files changed, 14 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7bf02008b988e..538bf9f1941f4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1705,7 +1705,6 @@ SparseCUDA: add_out_sparse_cuda - func: _sparse_dense_add(Tensor self, SparseTensorRef other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - cpu_half: True dispatch: CPU: add_out_dense_sparse_cpu CUDA: add_out_dense_sparse_cuda diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 531e9b1b7cdec..49ca7016e3d3a 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -324,7 +324,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){ // NB: Dropped the resizeNd variants Tensor sparse_to_dense(const SparseTensor& self) { - if(self.scalar_type() == ScalarType::Half && !self.is_cuda()) { + if(self.scalar_type() == ScalarType::Half && self.options().device().is_cpu()) { AT_ERROR("to_dense() not supported for float16 on CPU"); } Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided)); @@ -379,7 +379,7 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) { auto indicesBufferAccessor = indicesBuffer.accessor(); int64_t i = -1; - AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, + AT_DISPATCH_ALL_TYPES( values.scalar_type(), "coalesce", [&] { int64_t prev = -1; int64_t blockSize = values.stride(0); diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 9a2e1dc005352..e7fc355d969a5 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -364,7 +364,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, SparseTensorRef dstBuffer.add_(srcBuffer, value); } } else { - AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, + AT_DISPATCH_ALL_TYPES( values.scalar_type(), "add_dense_sparse", [&] { add_dense_sparse_worker_cpu(r, value, sparse, indices, values); }); diff --git a/aten/src/TH/THBlas.cpp b/aten/src/TH/THBlas.cpp index ba695c352240a..a98097db9d005 100644 --- a/aten/src/TH/THBlas.cpp +++ b/aten/src/TH/THBlas.cpp @@ -2,6 +2,3 @@ #include #include - -#include -#include diff --git a/aten/src/TH/THBlas.h b/aten/src/TH/THBlas.h index a62e186cf0f4f..3911861f9005c 100644 --- a/aten/src/TH/THBlas.h +++ b/aten/src/TH/THBlas.h @@ -8,7 +8,4 @@ #include #include -#include -#include - #endif diff --git a/aten/src/TH/THBlasUtils.h b/aten/src/TH/THBlasUtils.h index 9d61e5feab29a..ffc39822fe9cb 100644 --- a/aten/src/TH/THBlasUtils.h +++ b/aten/src/TH/THBlasUtils.h @@ -16,7 +16,7 @@ inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy); TH ## name ## Blas_axpy(n, a, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(AXPY_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(AXPY_SPECIALIZATION) template @@ -29,7 +29,7 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy); TH ## name ## Blas_copy(n, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(COPY_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(COPY_SPECIALIZATION) template inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); @@ -40,4 +40,4 @@ inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy); return TH ## name ## Blas_dot(n, x, incx, y, incy); \ } -AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(DOT_SPECIALIZATION) +AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(DOT_SPECIALIZATION) diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index 3b2b6c4a4692e..d7ec018fd1715 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -69,6 +69,8 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) #undef IS_NONZERO } +#if !defined(TH_REAL_IS_HALF) /* non bool or half only part */ + accreal THTensor_(sumall)(THTensor *tensor) { accreal sum = 0; @@ -88,8 +90,7 @@ accreal THTensor_(sumall)(THTensor *tensor) } return sum; } - -#if !defined(TH_REAL_IS_BOOL) && !defined(TH_REAL_IS_HALF) /* non bool or half only part */ +#if !defined(TH_REAL_IS_BOOL) void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value) { @@ -1033,4 +1034,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value) #endif +#endif + #endif /* TH_GENERIC_FILE */ diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 82823441aee81..c12ab74048c13 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -4,6 +4,8 @@ TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor); +#ifndef TH_REAL_IS_HALF + TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value); TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value); TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value); @@ -183,3 +185,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp #endif #endif +#endif From ab20267c1aff99cd1dbc8a8860df7369fde7ae6d Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 8 May 2019 11:10:38 -0400 Subject: [PATCH 15/19] remove no-longer needed half-cpu change --- aten/src/TH/THGenerateHalfType.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/TH/THGenerateHalfType.h b/aten/src/TH/THGenerateHalfType.h index 16e7d8fdb5a59..b075c683e009a 100644 --- a/aten/src/TH/THGenerateHalfType.h +++ b/aten/src/TH/THGenerateHalfType.h @@ -3,7 +3,7 @@ #endif #include -#define scalar_t at::Half +#define scalar_t THHalf #define accreal float #define TH_CONVERT_REAL_TO_ACCREAL(_val) (accreal)(_val) #define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val) From fdb90a2e19e9eaff30cda59cfc7560a23d9aafa9 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 8 May 2019 11:13:40 -0400 Subject: [PATCH 16/19] fix a stale comment --- aten/src/TH/generic/THTensorEvenMoreMath.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index d7ec018fd1715..46c546f3b2989 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -69,7 +69,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) #undef IS_NONZERO } -#if !defined(TH_REAL_IS_HALF) /* non bool or half only part */ +#if !defined(TH_REAL_IS_HALF) /* non half only part */ accreal THTensor_(sumall)(THTensor *tensor) { From 1974e1cb7bc1491d7e1dcb2c6c659bc2530df11e Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 8 May 2019 11:26:43 -0400 Subject: [PATCH 17/19] remove unnecessary blas guard --- aten/src/TH/THGenerateFloatType.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/aten/src/TH/THGenerateFloatType.h b/aten/src/TH/THGenerateFloatType.h index 0a12007cb1984..c4b97b52362cb 100644 --- a/aten/src/TH/THGenerateFloatType.h +++ b/aten/src/TH/THGenerateFloatType.h @@ -2,10 +2,6 @@ #error "You must define TH_GENERIC_FILE before including THGenerateFloatType.h" #endif -// no BLAS support for half types -#pragma push_macro("USE_BLAS") -#undef USE_BLAS - #define scalar_t float #define accreal double #define TH_CONVERT_REAL_TO_ACCREAL(_val) (accreal)(_val) @@ -23,8 +19,6 @@ #undef TH_CONVERT_REAL_TO_ACCREAL #undef TH_CONVERT_ACCREAL_TO_REAL -#pragma pop_macro("USE_BLAS") - #ifndef THGenerateManyTypes #undef TH_GENERIC_FILE #endif From ceebfe63ae239de62780e814272dc25feef2733a Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 8 May 2019 17:34:23 -0400 Subject: [PATCH 18/19] cleanup names from code-review --- test/test_sparse.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_sparse.py b/test/test_sparse.py index 328918ff60afa..37c79342f6ade 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -236,7 +236,7 @@ def fn(x): ]) # we don't have to_dense for half types on CPU because it is implemented # with a slower add_ operation - for dtype in [torch.float16, torch.float64] if self.device == 'cuda' else [torch.float64]: + for dtype in [torch.float16, torch.float64] if self.device != 'cpu' else [torch.float64]: v = self.value_tensor([2, 1, 3, 4]).to(dtype=dtype) x = self.sparse_tensor(i, v, torch.Size([3, 4, 5])) res = self.value_tensor([ @@ -266,8 +266,8 @@ def fn(x): res = self.value_empty(3, 4, 5, 0).to(dtype=dtype) test_tensor(x, res) - # half tesnors on cpu don't implement to_dense, so need to convert to float - def _half_safe_to_dense(self, tensor): + # half tensors on cpu don't implement to_dense, so need to convert to float + def _half_to_dense_safe(self, tensor): if(tensor.dtype == torch.half and tensor.device.type == 'cpu'): return tensor.to(torch.float).to_dense().to(torch.half) else: @@ -284,9 +284,9 @@ def test_to_sparse(self): expected, _, _ = self._gen_sparse(dim, nnz, shape) expected = expected.to(dtype) - d = self._half_safe_to_dense(expected) + d = self._half_to_dense_safe(expected) result = d.to_sparse(dim) - self.assertEqual(d, self._half_safe_to_dense(result)) # == not implemented for sparse tensors yet + self.assertEqual(d, self._half_to_dense_safe(result)) # == not implemented for sparse tensors yet self.assertEqual(expected.size(), result.size()) self.assertEqual(dim, result.sparse_dim()) @@ -1546,7 +1546,7 @@ def test_factory(self): torch.device(torch.cuda.device_count() - 1) indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2]) if test_empty_tensor: - values = self.value_empty(1, 0) + values = self.value_empty(1, 0).to(dtype) else: if use_tensor_val: values = torch.tensor([1.], dtype=dtype) From 79d1906bffc79fb230ed198f6d0f39e67fa1b980 Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 8 May 2019 17:39:10 -0400 Subject: [PATCH 19/19] rename a function to be clearer --- test/test_sparse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_sparse.py b/test/test_sparse.py index 37c79342f6ade..d484a998fecc4 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -267,7 +267,7 @@ def fn(x): test_tensor(x, res) # half tensors on cpu don't implement to_dense, so need to convert to float - def _half_to_dense_safe(self, tensor): + def _to_dense_half_safe(self, tensor): if(tensor.dtype == torch.half and tensor.device.type == 'cpu'): return tensor.to(torch.float).to_dense().to(torch.half) else: @@ -284,9 +284,9 @@ def test_to_sparse(self): expected, _, _ = self._gen_sparse(dim, nnz, shape) expected = expected.to(dtype) - d = self._half_to_dense_safe(expected) + d = self._to_dense_half_safe(expected) result = d.to_sparse(dim) - self.assertEqual(d, self._half_to_dense_safe(result)) # == not implemented for sparse tensors yet + self.assertEqual(d, self._to_dense_half_safe(result)) # == not implemented for sparse tensors yet self.assertEqual(expected.size(), result.size()) self.assertEqual(dim, result.sparse_dim())