
Conversation

nicolas-dufour (Contributor)

Tutorial on TensorDict and TensorDictModule.
Contains simple examples of TensorDict and TensorDictModule operations.
Also contains an implementation of a Transformer model using TensorDict and TensorDictModule to showcase how these modules work.

@facebook-github-bot added the CLA Signed label on Jul 7, 2022
@vmoens added the documentation label on Jul 7, 2022
vmoens (Collaborator) commented on Jul 8, 2022

A few comments:

  1. On a high level, I think it would make sense to split the file into a tutorial for TensorDict and a tutorial for TensorDictModule.
  2. The structure of the TensorDict tutorial should be like this IMO:
  • What's a TensorDict, what's the 10,000-foot view of it, what it can do (broadly speaking): e.g. a TensorDict is a horse that can carry Teletubbies that must be high on eucalyptus. It can crawl and jump but not walk. It is aimed at being nice to watch.
  • Some exciting stuff about it first: don't start with "it shares a lot with dict", otherwise people may be bored already. Find something you feel excited about (that is not too long) and that will make people think "oh gosh, I wish I knew about this when I was doing my latest project!"
  • We can ditch the initial part about PPO and DQN.
  3. td.update is cool, but you can also talk about setindex, which is cool too:

td[1:] = td2

  4. A crucial method is to_tensordict. Unlike clone (which returns the same tensordict subclass), to_tensordict returns a regular TensorDict. This is super useful to make stuff contiguous in memory (see the sketch after this list).
  5. Let's try to get the linting right: {"a":torch.randn(3)} should be {"a": torch.randn(3)}, [3,4,5] should be [3, 4, 5], etc.
  6. Let's tell people that their TensorDictModules should inherit from the official TensorDictModule class:

(1)

class TransformerBlockTensorDict(nn.Module):  # don't do that

should be (2)

class TransformerBlockTensorDict(TensorDictModule):

I would advise explicitly telling people not to do (1) but (2).
  7. nn.Sequential(TensorDictModule(...), ...) should be SequentialTensorDict(...)
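A minimal sketch of points 3 and 4 above (setindex and to_tensordict vs clone). The import path is an assumption: TensorDict lived inside torchrl at the time of this PR and later moved to the standalone tensordict package.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(4, 3)}, batch_size=[4])
td2 = TensorDict({"a": torch.zeros(3, 3)}, batch_size=[3])

# setindex: write a sub-tensordict in place along the batch dimension
td[1:] = td2
assert (td["a"][1:] == 0).all()

# to_tensordict() returns a regular, contiguous TensorDict whatever the
# subclass it is called on (clone() would preserve the subclass instead)
assert isinstance(td.to_tensordict(), TensorDict)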

vmoens (Collaborator) commented on Jul 8, 2022

Some more features we can show:

set index

td.__setitem__(...)

memory sharing

td.memmap_()
td.share_memory_()
td.to(SavedTensorDict)

stacking

stacked_td = torch.stack([td1, td2, td3], dim)
if stacked_td[0] is td1:
    print("every tensordict is awesome!")  # see LazyStackedTensorDict

reshaping

td.view(-1)
td.unsqueeze(0)
td.permute(...)
td.squeeze(dim)

changing batch size

td.batch_size = [3]

changing device

td.cuda()
td.to("cuda:1")

masking

td[mask]

unbind

td.unbind(0)

cat

torch.cat([td1, td2, td3], 1)

You can check the tests: every feature is tested there, so they are a good collection of features to look at (a short runnable sketch of a few of them follows).
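A minimal runnable sketch of some of the features above (stacking, reshaping, masking, cat). The import path is an assumption; at the time of this PR TensorDict was part of torchrl, it now lives in the standalone tensordict package.

import torch
from tensordict import TensorDict

td1 = TensorDict({"a": torch.randn(3, 4)}, batch_size=[3, 4])
td2 = TensorDict({"a": torch.randn(3, 4)}, batch_size=[3, 4])

# stacking is lazy: the stacked tensordict keeps references to td1 and td2
stacked_td = torch.stack([td1, td2], dim=0)
assert stacked_td[0] is td1  # see LazyStackedTensorDict

# reshaping acts on the batch dimensions of every entry at once
assert td1.view(-1).batch_size == torch.Size([12])
assert td1.unsqueeze(0).batch_size == torch.Size([1, 3, 4])

# masking with a boolean mask of the same shape as the batch size
mask = torch.zeros(3, 4, dtype=torch.bool)
mask[0] = True
assert td1[mask].batch_size == torch.Size([4])

# cat along an existing batch dimension
assert torch.cat([td1, td2], 1).batch_size == torch.Size([3, 8])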

# necessary to run a TensorDictModule. If a key is an intermediary in
# the chain, there is no reason why it should belong to the input
# TensorDict.
in_keys += [key for key in module.in_keys if key not in out_keys]
Collaborator: these changes should not be part of this PR

Collaborator: Once you update your branch from main after #261 is merged, this should disappear

@vmoens changed the title from "Added TensorDict tutorial" to "[Doc] Added TensorDict tutorial" on Jul 11, 2022
vmoens (Collaborator) commented on Jul 11, 2022

See the changes in the PR over this PR.

Comments I couldn't put there:

1. What are "modality_A" and "modality_B" in the fake dataset? Why do you call torch.Tensor around randn? It's not obvious either why you'd want a random dataset. Can we say something like

        return {"image": image, "mask": mask}

to be more explicit?
2. One thing to point out also is that we can (almost, as soon as we merge #256) have nested tensordicts. Imagine doing a stack in a collate function for nested dicts!
3. Before showing the properties of a tensordict, let's make explicit what it can contain: tensors, memmap-tensors, tensordicts. Also, let's make it clear that all tensors have to share the same device and batch_size, but not dtype.
4. Don't forget to lint your code. We don't test lint for jupyter notebooks, so it's important that the code is compliant.
5. Let's avoid print(tensor) when it leads to long outputs with [[1., 1. ...],[...]]
6. When talking about batch_size it is important not just to print it but to say what it means. It is also crucial to tell people that it has to be specified explicitly (it is not inferred from the tensors). There are similar classes in other repos that do that. It's a very important feature of TensorDict, let's spend some time explaining what it is (a minimal illustration follows this list). Also I would reprioritize stack and cat and put them earlier. We should make it clear that one can give an out=tensordict kwarg if wanted.
7. Same with tensordict.update({"c": torch.zeros(4, 3, 1)}): you should say what you're trying to do and why it's not working.
8. For tensordict.batch_size = [4, 4], perhaps start by showing that it can be set to a value that is compliant with the tensors' shape: tensordict.batch_size = [3] works fine.
9. Let's not write "tensor dict" but always "tensordict" or "TensorDict". Let's keep the latter for the class and the former when talking about an instantiated object.
10. Try to revise the ### and ####, it is unclear which one is at which level.
11. For cloning, why

### Cloning
TensorDict supports cloning. Cloning returns the same SubTensorDict item than the original item.

This is not true. Cloning returns the same TensorDict class as the original one (unlike to_tensordict()). It is important to state that difference.
12. There are lots of titles for keys, values etc. but many are missing for the tensor features (e.g. masking etc). Let's make sure that if we use titles they are present everywhere.
13. Let's avoid long continuous lines if they don't divide big topics.
14. Let's try to minimize the number of cells in the notebook: if there's a title, let's merge it with the comment that follows if possible.
15. When you say "We can perform masking on the indexes", what does that mean? Shouldn't it just be "We can mask tensordicts as we mask tensors", or something like it?
16. Let's not create tensors like this: mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]).bool(), but like this: mask = torch.tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]], dtype=torch.bool). It's better if we play by the rules.
17. There's a section about view but none about permute and reshape. I'm in favour of grouping them, but if so then let's make an appropriate title.
18. The #Cat title is broken. Also let's say something about cat: for instance, can you cat a tensordict along dim 3 if batch_dims=2?
19. TensorDict also supports squeeze and unsqueeze. "Use `to_tensordict` to retrieve a tensordict": you should be clear about what happens if you don't call to_tensordict: those are lazy operations, they are only executed when the tensor is queried.
20. There's a problem with the title "## How to use them in practice?"; in any case, I guess this section should go away now that we have another file about TensorDictModule?
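A minimal illustration of points 6, 8 and 11 above (explicit batch_size, resetting it to a compliant value, and clone vs to_tensordict). The import path and the exact exception type are assumptions.

import torch
from tensordict import TensorDict

# batch_size is declared explicitly; it is not inferred from the tensors.
# All entries must comply with it along the leading dims (dtypes may
# differ, devices and batch_size may not).
td = TensorDict(
    {"a": torch.randn(3, 4, 5), "b": torch.zeros(3, 4, dtype=torch.bool)},
    batch_size=[3, 4],
)

# the batch size can be reset to any value compliant with the tensor shapes
td.batch_size = [3]            # fine: every entry has a leading dim of 3
try:
    td.batch_size = [4, 4]     # not compliant with the stored tensors
except (RuntimeError, ValueError):
    print("incompatible batch size rejected")

# clone() returns the same TensorDict class as the original object, while
# to_tensordict() always returns a regular TensorDict
assert type(td.clone()) is type(td)
assert isinstance(td.to_tensordict(), TensorDict)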

vmoens (Collaborator) commented on Jul 11, 2022

1. Also let's talk about this:

tensordict = TensorDict({}, [10])
for i in range(10):
    tensordict[i] = TensorDict({"a": torch.randn(3, 4)}, [])
print(tensordict)

2. Let's make sure that the notebook runs without exception. If you want to raise one, place it in a try / except.

vmoens (Collaborator) commented on Jul 11, 2022

Let's talk about the fact that tensordicts do not support algebraic operations.
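A small sketch of how the notebook could show this without crashing, following the try / except advice above. The exact exception type is an assumption, and newer tensordict releases may add arithmetic support.

import torch
from tensordict import TensorDict

td1 = TensorDict({"a": torch.ones(3)}, batch_size=[3])
td2 = TensorDict({"a": torch.ones(3)}, batch_size=[3])

# tensordicts are containers, not tensors: elementwise algebra is not defined
try:
    td1 + td2
except TypeError:
    print("TensorDict does not support algebraic operations")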

vmoens (Collaborator) commented on Jul 11, 2022

There is some more magic to squeeze and unsqueeze, which is this:

tensordict = TensorDict({}, [10])
assert tensordict.unsqueeze(0).squeeze(0) is tensordict

The same holds for LazyStackedTensorDict.

vmoens (Collaborator) commented on Jul 11, 2022

Rather than printing stuff, perhaps you could put asserts with a small comment, e.g.

# Converting a LazyStackedTensorDict to a regular TensorDict is easy:
assert isinstance(stacked_tensordict.contiguous(), TensorDict)

rather than

print(stacked_tensordict.contiguous())

That way, users can understand what you're printing. Otherwise they need to read the log and figure out what there is to see.

vmoens (Collaborator) commented on Jul 12, 2022

Can we remove the png from here?

vmoens (Collaborator) left a comment

LGTM thanks!

@vmoens merged commit 8729851 into pytorch:main on Jul 12, 2022
@nicolas-dufour deleted the tensordict_tutorial branch on September 16, 2022