
Conversation

@izdeby (Contributor) commented Mar 19, 2019

Stack from ghstack:


This PR enables bool tensor creation and some basic operations for the CUDA backend. This is a part of the Bool Tensor feature implementation work. The whole plan looks like this:

  1. Storage Implementation [Done]
  2. Tensor Creation.
    a) CPU [Done]
    b) CUDA [This PR]
  3. Tensor Conversions.
  4. Tensor Indexing.
  5. Tensor Operations.
  6. Backward compatibility related changes.

Change:
Enable bool tensor in CUDA with the following operations:

torch.zeros
torch.tensor
torch.ones
torch.rand/rand_like/randint/randint_like
torch.full
torch.full_like
torch.empty
torch.empty_like

Tested via unit tests and local scripts.
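
The creation ops above can be sketched from Python; a minimal illustration (run on CPU here, since that support landed in the earlier CPU PR; on a CUDA build, passing device='cuda' exercises this PR's code path):

```python
import torch

# Bool tensor creation ops enabled by this stack.
# CPU shown here; on a CUDA build, add device="cuda" to hit this PR's path.
z = torch.zeros(3, dtype=torch.bool)             # all False
o = torch.ones(3, dtype=torch.bool)              # all True
f = torch.full((2, 2), True, dtype=torch.bool)   # filled with True
t = torch.tensor([True, False, True])            # dtype inferred as torch.bool
r = torch.randint(0, 2, (4,), dtype=torch.bool)  # random True/False values
e = torch.empty_like(t)                          # uninitialized bool tensor

print(z, t.dtype, r.shape)
```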

Differential Revision: D14605104

izdeby changed the title "Bool Tensor for CUDA" → "[WIP] Bool Tensor for CUDA" Mar 19, 2019
izdeby changed the title "[WIP] Bool Tensor for CUDA" → "Bool Tensor for CUDA" Mar 19, 2019
izdeby requested review from ezyang and gchanan March 19, 2019 04:04
izdeby changed the title "Bool Tensor for CUDA" → "[WIP]Bool Tensor for CUDA" Mar 19, 2019
izdeby changed the title "[WIP]Bool Tensor for CUDA" → "Bool Tensor for CUDA" Mar 19, 2019
@izdeby (Contributor, Author) commented Mar 19, 2019

@gchanan @ezyang, please take a look once you have time.
Please note:

  1. I enabled 'cat' for bool because test_print was failing in test_torch, and I figured out that we probably need it for bool.
  2. The GENERATE_KERNEL2-related changes are needed because 'generate_uniform' is used by 'uniform' in THC random.

izdeby requested a review from gchanan March 19, 2019 23:25
```
static void apply(Tensor& dst, const Tensor& src) {
  CUDA_tensor_apply2<dst_T, bool>(
      dst, src, [] __device__(dst_T & dst_val, const bool& src_val) {
        dst_val = static_cast<dst_T>(static_cast<native::inter_copy_type_t<dst_T>>(src_val));
```
Contributor:

Err, was this specialization necessary because __ldg doesn't have a bool overload?

Contributor:

That might be true, but it's pretty shocking to me. It's definitely better if we keep __ldg here. Maybe this means we should read it in as a uint8_t, cast it to bool, and then do the final static cast to destination?

Contributor:

cc @ngimel

Contributor:

Should we just implement __ldg for bool like we did for Half?

Contributor:

I'd probably wait until there is another case.

Collaborator:

I don't think there's a guarantee that sizeof(bool) == sizeof(uint8_t), so you cannot read it as uint8_t?
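
This can be probed empirically; a small ctypes sketch (an illustration only: the C++ standard leaves sizeof(bool) implementation-defined, so agreement on one platform is a platform fact, not a guarantee):

```python
import ctypes

# sizeof(uint8_t) is exactly 1 by definition; sizeof(bool) is
# implementation-defined in C++ (it happens to be 1 on the platforms
# PyTorch commonly targets, but nothing in the standard requires that).
print("sizeof(bool)    =", ctypes.sizeof(ctypes.c_bool))
print("sizeof(uint8_t) =", ctypes.sizeof(ctypes.c_uint8))
```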

@ezyang (Contributor) left a comment:

Patch seems basically reasonable. However, there are some major issues.

izdeby changed the title "Bool Tensor for CUDA" → "[WIP] Bool Tensor for CUDA" Mar 21, 2019
@izdeby (Contributor, Author) commented Apr 2, 2019

@pytorchbot retest this please

1 similar comment

@izdeby (Contributor, Author) commented Apr 2, 2019

@pytorchbot retest this please

1 similar comment

zdevito pushed a commit to zdevito/ATen that referenced this pull request Apr 3, 2019
Summary:
Pull Request resolved: pytorch/pytorch#18166
ghimport-source-id: a8e2ba2d966e49747a55701c4f6863c5e24d6f14

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18166 Bool Tensor for CUDA**
* #18165 Resolved comments from Bool Tensor for CPU PR

Differential Revision: D14605104

fbshipit-source-id: b7d7340a7d70edd03a109222d271e68becba762c
@facebook-github-bot (Contributor):
This pull request has been merged in b832b99.

facebook-github-bot pushed a commit that referenced this pull request Apr 3, 2019
Summary:
Pull Request resolved: #18505
ghimport-source-id: f3c9b92

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18505 [WIP]Added numpy conversion**
* #18166 Bool Tensor for CUDA

Differential Revision: D14646403

fbshipit-source-id: 79d39d692c778ce1981c1d35b1c33e3d93111041
facebook-github-bot pushed a commit that referenced this pull request Apr 3, 2019
Summary:
Pull Request resolved: #18583
ghimport-source-id: 2b19414

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18583 Added indexing for bool tensors and bool Indices**
* #18505 Added numpy conversion
* #18166 Bool Tensor for CUDA

-----------
This PR enables bool tensor indexing and indexing with bool indices. This is a part of Bool Tensor feature implementation work. The whole plan looks like this:
1. Storage Implementation [Done]
2. Tensor Creation.
    a) CPU [Done]
    b) CUDA [In review]
3. Tensor Conversions. [In review]
4. Tensor Indexing. [This PR]
5. Tensor Operations.
6. Backward compatibility related changes.

TODO:
As a follow-up, we should move the nonzero method from TH to ATen to make the code cleaner.

Change:
```
v = torch.tensor([True, False, True], dtype=torch.bool)
boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
v[boolIndices]
-> tensor([True], dtype=torch.bool)

v = torch.randn(5, 7, 3)
boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
v[boolIndices]
->
tensor([[[ 0.5885, -0.3322,  0.7388],
         [ 1.1182,  0.7808, -1.1492],
         [-0.7952,  0.5255, -0.0251],
         [ 0.7128,  0.8099,  1.2689],
         [-0.7018, -1.4733, -0.3732],
         [ 0.4503,  0.4986, -1.1605],
         [ 0.3348, -1.3767, -0.2976]],

        [[-2.0303, -0.4720, -0.1448],
         [-0.1914, -0.6821,  2.0061],
         [-1.0420, -0.1872, -0.3438],
         [ 1.7587, -0.4183, -0.7577],
         [ 1.0094, -0.1950, -0.2430],
         [ 0.1174,  0.3308, -0.5700],
         [ 0.1110, -0.2714,  1.3006]],

        [[-0.1946, -1.4747, -0.4650],
         [-1.0567,  1.0110, -0.2809],
         [ 0.3729, -0.5699,  0.0815],
         [-0.7733, -0.8316,  0.1674],
         [ 1.2000, -0.3745, -1.1679],
         [ 1.7105,  0.9851, -0.1907],
         [-1.1077,  0.2086, -0.0548]]])
```
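
For the 1-D case, the mask indexing shown above can be cross-checked against masked_select and an explicit nonzero-based gather; a quick sketch (assuming a torch build with bool tensor support):

```python
import torch

v = torch.tensor([True, False, True], dtype=torch.bool)
mask = torch.tensor([True, False, False], dtype=torch.bool)

a = v[mask]                       # select elements where the mask is True
b = torch.masked_select(v, mask)  # same semantics for 1-D input
c = v[mask.nonzero().squeeze(1)]  # gather through the True positions

print(a, b, c)
```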

Differential Revision: D14673403

fbshipit-source-id: 2b88ec2c7eb26a4f5ef64f8707fb68068d476fc9
@izdeby izdeby deleted the gh/izdeby/2/head branch April 10, 2019 15:20