[Doc] Added TensorDict tutorial #255
Conversation
A few comments:
td[1:] = td2
(1) class TransformerBlockTensorDict(nn.Module):  # don't do that
should be
(2) class TransformerBlockTensorDict(TensorDictModule):
I would advise explicitly telling people not to do (1) but (2).
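A minimal sketch of what pattern (2) could look like, assuming the standalone `tensordict` package's `tensordict.nn.TensorDictModule` API (a wrapped module plus `in_keys`/`out_keys`); the class and argument names below are illustrative only, not the tutorial's actual code:

```python
import torch
from torch import nn
from tensordict.nn import TensorDictModule


class TransformerBlockTensorDict(TensorDictModule):
    """A transformer block that reads from and writes to a TensorDict (illustrative sketch)."""

    def __init__(self, d_model: int, n_heads: int, in_keys: list, out_keys: list):
        # Wrap a standard PyTorch layer; TensorDictModule handles the key routing.
        block = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        super().__init__(module=block, in_keys=in_keys, out_keys=out_keys)
```

Calling an instance on a tensordict containing the `in_keys` then writes the block's output back under `out_keys`.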
Some more features we can show:
- td.__setitem__(...)
- memory sharing
- stacking: stacked_td = torch.stack([td1, td2, td3], dim)
  if stacked_td[0] is td1:
      print("every tensordict is awesome!")  # see LazyStackedTensorDict
- reshaping: td.view(-1), td.unsqueeze(0), td.permute(...), td.squeeze(dim)
- changing batch size: td.batch_size = [3]
- changing device: td.cuda(), td.to("cuda:1")
- masking: td[mask]
- unbind: td.unbind(0)
- cat: torch.cat([td1, td2, td3], 1)

You can check the tests; every feature is tested there, and it's a good collection of features to look at. A runnable sketch of a few of these features follows this list.
# necessary to run a TensorDictModule. If a key is an intermediary in
# the chain, there is no reason why it should belong to the input
# TensorDict.
in_keys += [key for key in module.in_keys if key not in out_keys]
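To illustrate the reasoning in the quoted comment, here is a small standalone sketch of how such an aggregation can work; the module key sets are hypothetical and the exact update order is an assumption, not the PR's code:

```python
# Hypothetical chain: module 1 maps "obs" -> "hidden", module 2 maps "hidden" -> "action".
modules = [
    {"in_keys": ["obs"], "out_keys": ["hidden"]},
    {"in_keys": ["hidden"], "out_keys": ["action"]},
]
in_keys, out_keys = [], []
for module in modules:
    # Only keys that no earlier module produces must come from the input TensorDict.
    in_keys += [key for key in module["in_keys"] if key not in out_keys]
    out_keys += module["out_keys"]

assert in_keys == ["obs"]                 # "hidden" is an intermediary, not an input
assert out_keys == ["hidden", "action"]
```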
these changes should not be part of this PR
Once you update your branch from main after #261 is merged, this should disappear.
See changes in the PR over this PR. Comments I couldn't put there:
return {"image": image, "mask": mask} to be more explicit?
This is not true. Cloning returns the same TensorDict class as the original one (unlike ...).
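A hedged sketch of that distinction, assuming the standalone `tensordict` package and its `LazyStackedTensorDict` class (not the tutorial's code): `clone()` preserves the class, while `contiguous()` materializes a regular TensorDict.

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

td1 = TensorDict({"a": torch.randn(4)}, batch_size=[])
td2 = TensorDict({"a": torch.randn(4)}, batch_size=[])
lazy = LazyStackedTensorDict(td1, td2, stack_dim=0)

# clone() keeps the original (lazy) class...
assert isinstance(lazy.clone(), LazyStackedTensorDict)
# ...whereas contiguous() materializes a regular TensorDict.
assert isinstance(lazy.contiguous(), TensorDict)
```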
# Create an empty tensordict with batch size 10, then fill it index by index
# with sub-tensordicts:
tensordict = TensorDict({}, [10])
for i in range(10):
    tensordict[i] = TensorDict({"a": torch.randn(3, 4)}, [])
print(tensordict)
There is some more magic to squeeze and unsqueeze, which is this:
Same for …
Rather than printing stuff, perhaps you could put
# Converting a LazyStackedTensorDict to a regular TensorDict is easy:
assert isinstance(stacked_tensordict.contiguous(), TensorDict)
rather than print(stacked_tensordict.contiguous()). That way users can understand what you're printing; otherwise they need to read the log and figure out what there is to see.
Some corrections in tensordict tutorial
…ict_tutorial
Retrieving nested tensordict
cleaning
Can we remove the png from here?
LGTM thanks!
Tutorial on TensorDict and TensorDictModule.
Contains simple examples of TensorDict and TensorDictModule operations.
Also contains an implementation of a Transformer model using TensorDict and TensorDictModule to showcase how these modules work.
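For readers skimming this thread, a minimal hedged illustration of the TensorDictModule pattern the tutorial covers, assuming the standalone `tensordict` package (`tensordict.nn.TensorDictModule`); key names and shapes below are made up:

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Wrap a plain nn.Module so it reads its input from, and writes its output to, a TensorDict.
module = TensorDictModule(nn.Linear(4, 2), in_keys=["obs"], out_keys=["action"])

td = TensorDict({"obs": torch.randn(8, 4)}, batch_size=[8])
module(td)  # reads "obs", writes "action" into the same tensordict
assert td["action"].shape == (8, 2)
```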