Skip to content

Commit 9ff1092

Browse files
committed Mar 25, 2025
[Refactor] vLLMWrapper class
ghstack-source-id: 21ed901c0428451312cad5fb85e149a9dd694819 Pull Request resolved: #2870
1 parent 7df8317 commit 9ff1092

File tree

6 files changed

+590
-13
lines changed

6 files changed

+590
-13
lines changed
 

‎test/test_actors.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
from torchrl.envs import LLMEnv
2424
from torchrl.modules import (
2525
from_hf_transformers,
26-
from_vllm,
2726
MLP,
2827
SafeModule,
2928
TanhDelta,
3029
TanhNormal,
30+
vLLMWrapper,
3131
)
3232
from torchrl.modules.tensordict_module.actors import (
3333
_process_action_space_spec,
@@ -1031,7 +1031,7 @@ def test_from_vllm(
10311031
torch.manual_seed(0)
10321032

10331033
model = vllm_instance
1034-
m = from_vllm(
1034+
m = vLLMWrapper(
10351035
model,
10361036
from_text=from_text,
10371037
generate=generate,
@@ -1207,14 +1207,14 @@ def test_from_vllm_logprobs(
12071207
torch.manual_seed(0)
12081208

12091209
model = vllm_instance
1210-
m_generate = from_vllm(
1210+
m_generate = vLLMWrapper(
12111211
model,
12121212
from_text=from_text,
12131213
generate=True,
12141214
return_log_probs=True,
12151215
pad_output=pad_output,
12161216
)
1217-
m_logprobs = from_vllm(
1217+
m_logprobs = vLLMWrapper(
12181218
model, from_text=from_text, generate=False, pad_output=pad_output
12191219
)
12201220
self._check_lps(
@@ -1264,7 +1264,7 @@ def _check_lps(
12641264
@pytest.mark.parametrize("use_tensorclass", [True, False])
12651265
def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
12661266
# Test generate - padding combinations
1267-
policy = from_vllm(
1267+
policy = vLLMWrapper(
12681268
vllm_instance,
12691269
from_text=True,
12701270
generate=generate,
@@ -1331,7 +1331,7 @@ def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
13311331
assert isinstance(tokens, list)
13321332

13331333
def test_vllm_collection(self, vllm_instance):
1334-
policy = from_vllm(
1334+
policy = vLLMWrapper(
13351335
vllm_instance,
13361336
from_text=True,
13371337
generate=True,

‎torchrl/modules/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
)
9494
from .utils import get_primers_from_module
9595
from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip
96-
from .llm import from_hf_transformers, from_vllm
96+
from .llm import from_hf_transformers, vLLMWrapper
9797

9898
__all__ = [
9999
"Actor",
@@ -178,7 +178,7 @@
178178
"WorldModelWrapper",
179179
"distributions_maps",
180180
"from_hf_transformers",
181-
"from_vllm",
181+
"vLLMWrapper",
182182
"get_primers_from_module",
183183
"recurrent_mode",
184184
"reset_noise",

‎torchrl/modules/llm/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .common import CategoricalSequential
77
from .transformers_policy import from_hf_transformers
8-
from .vllm_policy import from_vllm
98

10-
__all__ = ["from_hf_transformers", "from_vllm", "CategoricalSequential"]
9+
from .vllm_wrapper import vLLMWrapper
10+
11+
__all__ = ["from_hf_transformers", "vLLMWrapper", "CategoricalSequential"]

‎torchrl/modules/llm/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from tensordict import NestedKey, TensorDictBase
88
from tensordict.nn import (
99
ProbabilisticTensorDictModule,
10-
ProbabilisticTensorDictSequential,
10+
TensorDictModuleBase,
1111
TensorDictSequential,
1212
)
1313
from torch import distributions as D
1414
from torch.distributions import Categorical
1515

1616

17-
class CategoricalSequential(ProbabilisticTensorDictSequential):
17+
class CategoricalSequential(TensorDictModuleBase):
1818
"""A ProbabilisticTensorDictSequential subclass meant to work with LLMs.
1919
2020
.. seealso:: :class:`~tensordict.nn.ProbabilisticTensorDictSequential` class.

‎torchrl/modules/llm/transformers_policy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def from_hf_transformers(
181181
>>> print(output_data.text_response)
182182
[' heritageillon rooft rooft Pear Tes grantingalde 58ocrocrocrocrcubecubecubecubecubecubecube']
183183
184-
.. seealso:: :func:`~torchrl.modules.from_vllm` for a similar interface using the vLLM library.
184+
.. seealso:: :func:`~torchrl.modules.vLLMWrapper` for a similar interface using the vLLM library.
185185
186186
"""
187187
# TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks

0 commit comments

Comments (0)
Failed to load comments.