Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import tensordict

```bash
Traceback (most recent call last):
File ...
File ...
```

## Expected behavior
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run-clang-format.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def run_clang_format_diff(args, file):
# Hopefully, this is the correct thing to do.
#
# It's done due to the following assumptions (which may be incorrect):
# - clang-format will returns the bytes read from the files as-is,
# - clang-format will return the bytes read from the files as-is,
# without conversion, and it is already assumed that the files use utf-8.
# - if the diagnostics were internationalized, they would use utf-8:
# > Adding Translations to Clang
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/benchmarks_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
git checkout ${{ github.event.pull_request.base.sha }}
$RUN_BENCHMARK ${{ env.BASELINE_JSON }}
git checkout ${{ github.event.pull_request.head.sha }}
$RUN_BENCHMARK ${{ env.CONTENDER_JSON }}
$RUN_BENCHMARK ${{ env.CONTENDER_JSON }}
- name: Publish results
uses: apbard/pytest-benchmark-commenter@v3
with:
Expand Down
20 changes: 10 additions & 10 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
os=Linux

# 1. Install conda at ./conda
printf "* Installing conda\n"
wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
Expand All @@ -51,33 +51,33 @@ jobs:
conda create --prefix "${env_dir}" -y python=3.8
printf "* Activating\n"
conda activate "${env_dir}"

# 2. upgrade pip, ninja and packaging
apt-get install python3.8 python3-pip unzip -y
python3 -m pip install --upgrade pip
python3 -m pip install setuptools ninja packaging -U

# 3. check python version
python3 --version

# 4. Check git version
git version

# 5. Install PyTorch
python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --quiet --root-user-action=ignore

# 6. Install tensordict
python3 setup.py develop

# 7. Install requirements
python3 -m pip install -r docs/requirements.txt --quiet --root-user-action=ignore

# 8. Test tensordict installation
mkdir _tmp
cd _tmp
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl python3 -c """from tensordict import *"""
cd ..

# 9. Set sanitize version
if [[ ${{ github.event_name }} == push && (${{ github.ref_type }} == tag || (${{ github.ref_type }} == branch && ${{ github.ref_name }} == release/*)) ]]; then
echo '::group::Enable version string sanitization'
Expand All @@ -100,7 +100,7 @@ jobs:

upload:
needs: build-docs
if: github.repository == 'pytorch/tensordict' && github.event_name == 'push' &&
if: github.repository == 'pytorch/tensordict' && github.event_name == 'push' &&
((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
permissions:
contents: write
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@ jobs:
echo '::group::Setup environment'
CONDA_PATH=$(which conda)
eval "$(${CONDA_PATH} shell.bash hook)"
conda create --name ci --quiet --yes python=3.8 pip
conda create --name ci --quiet --yes python=3.9 pip
conda activate ci
echo '::endgroup::'

echo '::group::Install lint tools'
pip install --progress-bar=off pre-commit
echo '::endgroup::'

echo '::group::Lint Python source and configs'
set +e
pre-commit run --all-files

if [ $? -ne 0 ]; then
git --no-pager diff
exit 1
Expand All @@ -50,15 +50,15 @@ jobs:
repository: pytorch/tensordict
script: |
set -euo pipefail

echo '::group::Setup environment'
CONDA_PATH=$(which conda)
eval "$(${CONDA_PATH} shell.bash hook)"
conda create --name ci --quiet --yes -c conda-forge python=3.8 ncurses=5 libgcc
conda activate ci
export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}"
echo '::endgroup::'

echo '::group::Install lint tools'
curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o ./clang-format
chmod +x ./clang-format
Expand All @@ -67,7 +67,7 @@ jobs:
echo '::group::Lint C source'
set +e
./.github/unittest/linux/scripts/run-clang-format.py -r tensordict/csrc --clang-format-executable ./clang-format

if [ $? -ne 0 ]; then
git --no-pager diff
exit 1
Expand Down
16 changes: 11 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.6.0
hooks:
- id: check-docstring-first
- id: check-toml
Expand All @@ -9,16 +9,22 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/omnilib/ufmt
rev: v2.0.0b2
rev: v2.7.0
hooks:
- id: ufmt
additional_dependencies:
- black == 22.3.0
- black == 24.4.2
- usort == 1.0.3
- libcst == 0.4.7

- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black

- repo: https://github.com/pycqa/flake8
rev: 7.1.0
hooks:
Expand All @@ -27,12 +33,12 @@ repos:
additional_dependencies:
- flake8-bugbear==22.10.27
- flake8-comprehensions==3.10.1
- torchfix==0.0.2
- torchfix==0.5.0
- flake8-print==5.0.0
# - flake8-unused-arguments==0.0.13

- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
rev: 6.3.0
hooks:
- id: pydocstyle
files: ^tensordict/
12 changes: 6 additions & 6 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ We want to make contributing to this project as easy and transparent as
possible.

## Installing the library
Install the library as suggested in the README. For advanced features,
Install the library as suggested in the README. For advanced features,
it is preferable to install the nightly built of pytorch.

Make sure you install tensordict in develop mode by running
Expand All @@ -12,25 +12,25 @@ python setup.py develop
```
in your shell.

If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message
`(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears, then try

```
ARCHFLAGS="-arch arm64" python setup.py develop
```

## Formatting your code
**Type annotation**
**Type annotation**

tensordict is not strongly-typed, i.e. we do not enforce type hints, neither do we check that the ones that are present are valid. We rely on type hints purely for documentary purposes. Although this might change in the future, there is currently no need for this to be enforced at the moment.
tensordict is not strongly-typed, i.e. we do not enforce type hints, neither do we check that the ones that are present are valid. We rely on type hints purely for documentary purposes. Although this might change in the future, there is currently no need for this to be enforced at the moment.

**Linting**

Before your PR is ready, you'll probably want your code to be checked. This can be done easily by installing
Before your PR is ready, you'll probably want your code to be checked. This can be done easily by installing
```
pip install pre-commit
```
and running
and running
```
pre-commit run --all-files
```
Expand Down
52 changes: 26 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@

`TensorDict` is a dictionary-like class that inherits properties from tensors,
such as indexing, shape operations, casting to device or point-to-point communication
in distributed settings. Whenever you need to execute an operation over a batch of tensors,
in distributed settings. Whenever you need to execute an operation over a batch of tensors,
TensorDict is there to help you.

The primary goal of TensorDict is to make your code-bases more _readable_, _compact_, and _modular_.
It abstracts away tailored operations, making your code less error-prone as it takes care of
The primary goal of TensorDict is to make your code-bases more _readable_, _compact_, and _modular_.
It abstracts away tailored operations, making your code less error-prone as it takes care of
dispatching the operation on the leaves for you.

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:
Expand All @@ -66,13 +66,13 @@ For instance, the above example can be easily used across classification and seg

Unlike other [pytrees](https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py), TensorDict
carries metadata that make it easy to query the state of the container. The main metadata
are the [``batch_size``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.batch_size)
(also referred as ``shape``),
the [``device``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.device),
are the [``batch_size``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.batch_size)
(also referred as ``shape``),
the [``device``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.device),
the shared status
([``is_memmap``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDictBase.html#tensordict.TensorDictBase.is_memmap) or
[``is_shared``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDictBase.html#tensordict.TensorDictBase.is_shared)),
the dimension [``names``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.names)
([``is_memmap``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDictBase.html#tensordict.TensorDictBase.is_memmap) or
[``is_shared``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDictBase.html#tensordict.TensorDictBase.is_shared)),
the dimension [``names``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.names)
and the [``lock``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.lock_) status.

A tensordict is primarily defined by its `batch_size` (or `shape`) and its key-value pairs:
Expand All @@ -85,7 +85,7 @@ A tensordict is primarily defined by its `batch_size` (or `shape`) and its key-v
... }, batch_size=[3, 4])
```
The `batch_size` and the first dimensions of each of the tensors must be compliant.
The tensors can be of any dtype and device.
The tensors can be of any dtype and device.

Optionally, one can restrict a tensordict to
live on a dedicated ``device``, which will send each tensor that is written there:
Expand All @@ -101,22 +101,22 @@ TensorDict device:
>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")
```
Once the device is set, it can be cleared with the
Once the device is set, it can be cleared with the
[``clear_device_``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.clear_device_)
method.
method.

### TensorDict as a specialized dictionary
TensorDict possesses all the basic features of a dictionary such as
[``clear``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.clear),
[``copy``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.copy),
[``fromkeys``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.fromkeys),
[``get``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.get),
[``items``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.items),
[``keys``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.keys),
[``pop``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.pop),
[``popitem``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.popitem),
[``setdefault``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.setdefault),
[``update``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.update) and
TensorDict possesses all the basic features of a dictionary such as
[``clear``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.clear),
[``copy``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.copy),
[``fromkeys``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.fromkeys),
[``get``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.get),
[``items``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.items),
[``keys``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.keys),
[``pop``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.pop),
[``popitem``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.popitem),
[``setdefault``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.setdefault),
[``update``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.update) and
[``values``](https://pytorch.org/tensordict/reference/generated/tensordict.TensorDict.html#tensordict.TensorDict.values).

But that is not all, you can also store nested values in a tensordict:
Expand Down Expand Up @@ -471,7 +471,7 @@ conda install -c conda-forge tensordict
If you're using TensorDict, please refer to this BibTeX entry to cite this work:
```
@misc{bou2023torchrl,
title={TorchRL: A data-driven decision-making library for PyTorch},
title={TorchRL: A data-driven decision-making library for PyTorch},
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
year={2023},
eprint={2306.00577},
Expand All @@ -482,9 +482,9 @@ If you're using TensorDict, please refer to this BibTeX entry to cite this work:

## Disclaimer

TensorDict is at the *beta*-stage, meaning that there may be bc-breaking changes introduced, but
TensorDict is at the *beta*-stage, meaning that there may be bc-breaking changes introduced, but
they should come with a warranty.
Hopefully these should not happen too often, as the current roadmap mostly
Hopefully these should not happen too often, as the current roadmap mostly
involves adding new features and building compatibility with the broader
PyTorch ecosystem.

Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ distinguish on a high level parameters and buffers (they are all packed together

Ensembles
---------
The functional approach enables a straightforward ensemble implementation.
The functional approach enables a straightforward ensemble implementation.
We can duplicate and reinitialize model copies using the :class:`tensordict.nn.EnsembleModule`

.. code-block::
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[build-system]
requires = ["setuptools", "wheel", "torch"]

[tool.black]
safe = true
line-length = 88

[tool.usort]
first_party_detection = false
target-version = ["py38"]
excludes = [
"gallery",
"tutorials",
]

[tool.black]
line-length = 88
target-version = ["py38"]
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ max-line-length = 120
[flake8]
# note: we ignore all 501s (line too long) anyway as they're taken care of by black
max-line-length = 79
ignore = E203, E402, W503, W504, E501
ignore = E203, E402, W503, W504, E501, E701, E704
per-file-ignores =
__init__.py: F401, F403, F405
./hubconf.py: F401
Expand Down
Loading