Skip to content

Commit 8774e73

Browse files
author
Vincent Moens
committed
[CI] Fix 3.13t wheels
ghstack-source-id: c57cd1f Pull Request resolved: #1294
1 parent 3d2e6d2 commit 8774e73

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

.github/scripts/version_script.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
#!/bin/bash
22

33
export TENSORDICT_BUILD_VERSION=0.8.0
4+
${CONDA_RUN} pip install --upgrade pip
5+
6+
${CONDA_RUN} conda install conda-forge::rust -y
7+
# for orjson
8+
export UNSAFE_PYO3_BUILD_FREE_THREADED=1

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ requires = ["setuptools", "wheel", "torch"]
33

44
[tool.usort]
55
first_party_detection = false
6-
target-version = ["py38"]
6+
target-version = ["py39"]
77
excludes = [
88
"gallery",
99
"tutorials",
1010
]
1111

1212
[tool.black]
1313
line-length = 88
14-
target-version = ["py38"]
14+
target-version = ["py39"]

test/test_nn.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -812,8 +812,9 @@ def test_dispatch_deactivate(self):
812812
with _set_dispatch_td_nn_modules(True):
813813
out = tdm(a=torch.zeros(1, 1))
814814
assert (out == td["b"]).all()
815-
with _set_dispatch_td_nn_modules(False), pytest.raises(
816-
TypeError, match="missing 1 required positional argument"
815+
with (
816+
_set_dispatch_td_nn_modules(False),
817+
pytest.raises(TypeError, match="missing 1 required positional argument"),
817818
):
818819
tdm(a=torch.zeros(1, 1))
819820

@@ -2458,11 +2459,14 @@ def test_no_warning_single_key(self):
24582459
out_keys=[("dirich", "categ")],
24592460
return_log_prob=True,
24602461
)
2461-
with pytest.warns(
2462-
DeprecationWarning, match="You are querying the log-probability key"
2463-
), pytest.warns(
2464-
DeprecationWarning,
2465-
match="Composite log-prob aggregation wasn't defined explicitly",
2462+
with (
2463+
pytest.warns(
2464+
DeprecationWarning, match="You are querying the log-probability key"
2465+
),
2466+
pytest.warns(
2467+
DeprecationWarning,
2468+
match="Composite log-prob aggregation wasn't defined explicitly",
2469+
),
24662470
):
24672471
td = TensorDict(
24682472
params=TensorDict(

test/test_tensorclass.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,10 +1625,13 @@ class MyDataNested:
16251625
X = torch.randn(3, 4, 5)
16261626
z = ["a", "b", "c"]
16271627
batch_size = [3, 4]
1628-
with set_list_to_stack(list_to_stack), (
1629-
pytest.raises(RuntimeError, match="batch dimension mismatch")
1630-
if list_to_stack
1631-
else contextlib.nullcontext()
1628+
with (
1629+
set_list_to_stack(list_to_stack),
1630+
(
1631+
pytest.raises(RuntimeError, match="batch dimension mismatch")
1632+
if list_to_stack
1633+
else contextlib.nullcontext()
1634+
),
16321635
):
16331636
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
16341637
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
@@ -1662,10 +1665,13 @@ class MyDataNested:
16621665
X = torch.ones(3, 4, 5)
16631666
z = ["a", "b", "c"]
16641667
batch_size = [3, 4]
1665-
with set_list_to_stack(list_to_stack), (
1666-
pytest.raises(RuntimeError, match="batch dimension mismatch")
1667-
if list_to_stack
1668-
else contextlib.nullcontext()
1668+
with (
1669+
set_list_to_stack(list_to_stack),
1670+
(
1671+
pytest.raises(RuntimeError, match="batch dimension mismatch")
1672+
if list_to_stack
1673+
else contextlib.nullcontext()
1674+
),
16691675
):
16701676
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
16711677
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)

0 commit comments

Comments
 (0)