From 4894d1e51f1db23b07290ab572791dc776f67b6e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 26 Sep 2024 16:56:32 +0100
Subject: [PATCH 1/7] [Refactor] async copy

ghstack-source-id: d12b596cce3db900ca584d0956cef03105db510f
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1011
---
 tensordict/nn/cudagraphs.py | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index e4b597dd2..b30b3ed4c 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -272,20 +272,20 @@ def check_tensor_id(name, t0, t1):
                         return tensordict_out.update(out, clone=True)
                     return out.clone() if self._out is not None else None
                 else:
-                    self._tensordict.update_(tensordict)
+                    self._tensordict.update_(tensordict, non_blocking=True)
                     torch.cuda.synchronize()
                     self.graph.replay()
-                    torch.cuda.synchronize()
                     if self._out_matches_in:
-                        return tensordict.update(
+                        result = tensordict.update(
                             self._out, keys_to_update=self._selected_keys
                         )
-                    if tensordict_out is not None:
-                        return tensordict_out.update(self._out, clone=True)
-                    return self._out.clone() if self._out is not None else None
-
+                    elif tensordict_out is not None:
+                        result = tensordict_out.update(self._out, clone=True)
+                    else:
+                        result = self._out.clone() if self._out is not None else None
+                    torch.cuda.synchronize()
+                    return result
         else:
-
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 if self.counter < self._warmup:
                     if self._warmup_stream is not None:
@@ -350,20 +350,23 @@ def check_device_and_clone(x):
                     return tree_map(lambda x: x.clone(), out)
                 else:
                     tree_map(
-                        lambda x, y: x.copy_(y),
+                        lambda x, y: x.copy_(y, non_blocking=True),
                         (self._args, self._kwargs),
                         (args, kwargs),
                     )
                     torch.cuda.synchronize()
                     self.graph.replay()
-                    torch.cuda.synchronize()
                     if self._return_unchanged == "clone":
-                        return self._out.clone()
+                        result = self._out.clone()
                     elif self._return_unchanged:
-                        return self._out
-                    return tree_map(
-                        lambda x: x.clone() if x is not None else x, self._out
-                    )
+                        result = self._out
+                    else:
+                        result = tree_map(
+                            lambda x: x.clone() if x is not None else x, self._out
+                        )
+                    torch.cuda.synchronize()
+                    return result
+
         _call_func = functools.wraps(self.module)(_call)
         self._call_func = _call_func
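Note (illustration, not part of the series): the replay path that PATCH 1 converges on — copy fresh inputs into the captured buffers with `non_blocking=True`, replay, stage the result, then synchronize once — can be sketched standalone as below. This is a minimal sketch assuming a CUDA device; names such as `make_graphed`, `static_in` and `static_out` are placeholders, not identifiers from the diff.

    import torch

    def make_graphed(fn, example):
        static_in = example.clone()
        # CUDA graphs require a warmup run on a side stream before capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            fn(static_in)
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = fn(static_in)

        def replay(x):
            # Stream-ordered async copy into the captured input buffer.
            static_in.copy_(x, non_blocking=True)
            graph.replay()
            result = static_out.clone()
            # One synchronize after the result is staged, as in the hunks above.
            torch.cuda.synchronize()
            return result

        return replay

    run = make_graphed(torch.sin, torch.randn(8, device="cuda"))
    print(run(torch.zeros(8, device="cuda")))  # sin(0) == 0 everywhere

The copy_/replay/clone sequence is what the `result = ...` refactor enables: the single synchronize can sit after the output is staged instead of fencing twice around `graph.replay()`.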
From 1f2e931187900e1321071f85f0399052f269404b Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 30 Sep 2024 01:59:18 -0700
Subject: [PATCH 2/7] init

---
 tensordict/base.py          |  8 +++----
 tensordict/nn/cudagraphs.py | 48 ++++++++++++++-----------------------
 2 files changed, 22 insertions(+), 34 deletions(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index e0451faa2..c7fb1f93d 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -7936,10 +7936,10 @@ def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        if _is_tensor_collection(type(weight)):
-            weight_val = weight._values_list(True, True)
-        else:
-            weight_val = weight
+        # if isinstance(weight, TensorDictBase) or _is_tensor_collection(type(weight)):
+        #     weight_val = weight._values_list(True, True)
+        # else:
+        weight_val = weight
         torch._foreach_lerp_(self._values_list(True, True), end_val, weight_val)
         return self

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index b30b3ed4c..16cd01800 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -36,7 +36,6 @@ def tree_leaves(pytree):
     """Torch 2.0 compatible version of tree_leaves."""
     return tree_flatten(pytree)[0]
 
-
 class CudaGraphModule:
     """A cudagraph wrapper for PyTorch callables.
 
@@ -209,8 +208,7 @@ def _call(
                 **kwargs: Any,
             ) -> Any:
                 if self.counter < self._warmup:
-                    if self._warmup_stream is not None:
-                        self._warmup_stream.wait_stream(torch.cuda.current_stream())
+                    torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         if tensordict_out is not None:
                             kwargs["tensordict_out"] = tensordict_out
@@ -218,8 +216,7 @@ def _call(
                     if self._out_matches_in is None:
                         self._out_matches_in = out is tensordict
                     self.counter += self._has_cuda
-                    if self._warmup_stream is not None:
-                        torch.cuda.current_stream().wait_stream(self._warmup_stream)
+                    torch.cuda.synchronize()
                     return out
                 elif self.counter == self._warmup:
                     if tensordict.device is None:
@@ -230,15 +227,17 @@ def _call(
                         )
                     tree_map(self._check_non_tensor, (args, kwargs))
 
+                    self._tensordict = tensordict.copy()
-                    self.graph = torch.cuda.CUDAGraph()
                     torch.cuda.synchronize()
-                    self._tensordict = tensordict.copy()
+                    this_out = self.module(self._tensordict, *args, **kwargs)
+                    torch.cuda.synchronize()
+
+                    self.graph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.graph):
                         if tensordict_out is not None:
                             kwargs["tensordict_out"] = tensordict_out
                         out = self.module(self._tensordict, *args, **kwargs)
-                    self.graph.replay()
 
                     if not is_tensor_collection(out) and out is not None:
                         raise RuntimeError(
@@ -265,15 +264,9 @@ def check_tensor_id(name, t0, t1):
                         default=None,
                         filter_empty=True,
                     )
-                    return tensordict.update(
-                        self._out, keys_to_update=self._selected_keys
-                    )
-                    if tensordict_out is not None:
-                        return tensordict_out.update(out, clone=True)
-                    return out.clone() if self._out is not None else None
+                    return this_out
                 else:
                     self._tensordict.update_(tensordict, non_blocking=True)
-                    torch.cuda.synchronize()
                     self.graph.replay()
                     if self._out_matches_in:
                         result = tensordict.update(
                             self._out, keys_to_update=self._selected_keys
                         )
                     elif tensordict_out is not None:
                         result = tensordict_out.update(self._out, clone=True)
                     else:
                         result = self._out.clone() if self._out is not None else None
-                    torch.cuda.synchronize()
                     return result
         else:
 
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 if self.counter < self._warmup:
-                    if self._warmup_stream is not None:
-                        self._warmup_stream.wait_stream(torch.cuda.current_stream())
+                    torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         out = self.module(*args, **kwargs)
-                    if self._warmup_stream is not None:
-                        torch.cuda.current_stream().wait_stream(self._warmup_stream)
+                    torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
                 elif self.counter == self._warmup:
@@ -317,6 +307,9 @@ def check_device_and_clone(x):
                     self._args, self._kwargs = tree_map(
                         check_device_and_clone, (args, kwargs)
                     )
+                    torch.cuda.synchronize()
+                    this_out = self.module(*self._args, **self._kwargs)
+                    torch.cuda.synchronize()
                     self.graph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.graph):
                         out = self.module(*self._args, **self._kwargs)
@@ -341,20 +334,15 @@ def check_device_and_clone(x):
                         self._return_unchanged = (
                             "clone" if self._out is not None else True
                         )
-                        return (
-                            self._out.clone()
-                            if self._return_unchanged == "clone"
-                            else self._out
-                        )
-                    self._return_unchanged = False
-                    return tree_map(lambda x: x.clone(), out)
+                    else:
+                        self._return_unchanged = False
+                    return this_out
                 else:
                     tree_map(
                         lambda x, y: x.copy_(y, non_blocking=True),
                         (self._args, self._kwargs),
                         (args, kwargs),
                     )
-                    torch.cuda.synchronize()
                     self.graph.replay()
                     if self._return_unchanged == "clone":
                         result = self._out.clone()
                     elif self._return_unchanged:
                         result = self._out
                     else:
                         result = tree_map(
-                            lambda x: x.clone() if x is not None else x, self._out
+                            lambda x: x.detach().clone() if x is not None else x, self._out
                         )
-                    torch.cuda.synchronize()
+                    # torch.cuda.synchronize()
                     return result
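Note on the `tensordict/base.py` hunk (the commented-out branch): `lerp_` lowers to `torch._foreach_lerp_`, which accepts either a scalar weight or one weight per tensor, so the tensordict-flattening branch is bypassed here while PATCH 3 below settles the dispatch. A small CPU-runnable illustration of the underlying op:

    import torch

    params = [torch.zeros(3), torch.zeros(3)]
    targets = [torch.ones(3), torch.full((3,), 2.0)]

    # In-place: params[i] <- params[i] + 0.5 * (targets[i] - params[i])
    torch._foreach_lerp_(params, targets, 0.5)
    print(params)  # [tensor([0.5000, 0.5000, 0.5000]), tensor([1., 1., 1.])]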
From 6d1b1dd1db70049149033a748e84b022426bfaad Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 30 Sep 2024 11:43:48 +0100
Subject: [PATCH 3/7] amend

---
 tensordict/base.py          | 14 ++++++---
 tensordict/nn/cudagraphs.py | 62 ++++++++++++++++++++++---------------
 2 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index c7fb1f93d..7d846bd7b 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -7916,7 +7916,9 @@ def lerp(
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        if _is_tensor_collection(type(weight)):
+        if isinstance(weight, (float, torch.Tensor)):
+            weight_val = weight
+        elif _is_tensor_collection(type(weight)):
             weight_val = weight._values_list(True, True)
         else:
             weight_val = weight
@@ -7936,10 +7938,12 @@ def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        # if isinstance(weight, TensorDictBase) or _is_tensor_collection(type(weight)):
-        #     weight_val = weight._values_list(True, True)
-        # else:
-        weight_val = weight
+        if isinstance(weight, (float, torch.Tensor)):
+            weight_val = weight
+        elif _is_tensor_collection(type(weight)):
+            weight_val = weight._values_list(True, True)
+        else:
+            weight_val = weight
         torch._foreach_lerp_(self._values_list(True, True), end_val, weight_val)
         return self

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index 16cd01800..9a555ebab 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -36,6 +36,7 @@ def tree_leaves(pytree):
     """Torch 2.0 compatible version of tree_leaves."""
     return tree_flatten(pytree)[0]
 
+
 class CudaGraphModule:
     """A cudagraph wrapper for PyTorch callables.
 
@@ -198,6 +199,8 @@ def __init__(
             )
             _exclude_td_from_pytree().set()
 
+        functools.update_wrapper(self, module)
+
         if self._is_tensordict_module:
 
             @dispatch(source=self.in_keys, dest=self.out_keys, auto_batch_size=False)
@@ -208,7 +211,8 @@ def _call(
                 **kwargs: Any,
             ) -> Any:
                 if self.counter < self._warmup:
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         if tensordict_out is not None:
                             kwargs["tensordict_out"] = tensordict_out
@@ -216,9 +220,10 @@ def _call(
                     if self._out_matches_in is None:
                         self._out_matches_in = out is tensordict
                     self.counter += self._has_cuda
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     return out
-                elif self.counter == self._warmup:
+                elif self.counter == self._warmup - 1:
                     if tensordict.device is None:
                         tensordict.apply(self._check_device_and_grad, filter_empty=True)
                     elif tensordict.device.type != "cuda":
@@ -277,39 +282,29 @@ def check_tensor_id(name, t0, t1):
                     else:
                         result = self._out.clone() if self._out is not None else None
                     return result
+
         else:
+
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 if self.counter < self._warmup:
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         out = self.module(*args, **kwargs)
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
-                elif self.counter == self._warmup:
-
-                    def check_device_and_clone(x):
-                        if isinstance(x, torch.Tensor) or is_tensor_collection(x):
-                            if x.requires_grad:
-                                raise RuntimeError(self._REQUIRES_GRAD_ERROR)
-                            if x.device is None:
-                                # Check device of leaves of tensordict
-                                x.apply(self._check_device_and_grad, filter_empty=True)
-
-                            elif x.device.type != "cuda":
-                                raise ValueError(
-                                    f"All tensors must be stored on CUDA. Got {x.device.type}."
-                                )
-
-                            return x.clone()
-                        return x
+                elif self.counter == self._warmup - 1:
 
                     self._args, self._kwargs = tree_map(
-                        check_device_and_clone, (args, kwargs)
+                        self._check_device_and_clone, (args, kwargs)
                     )
+
                     torch.cuda.synchronize()
                     this_out = self.module(*self._args, **self._kwargs)
                     torch.cuda.synchronize()
+
                     self.graph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.graph):
                         out = self.module(*self._args, **self._kwargs)
@@ -345,15 +345,32 @@ def check_device_and_clone(x):
                     elif self._return_unchanged:
                         result = self._out
                     else:
                         result = tree_map(
-                            lambda x: x.detach().clone() if x is not None else x, self._out
+                            lambda x: x.detach().clone() if x is not None else x,
+                            self._out,
                         )
                     # torch.cuda.synchronize()
                     return result
 
             _call_func = functools.wraps(self.module)(_call)
             self._call_func = _call_func
 
+    @classmethod
+    def _check_device_and_clone(cls, x):
+        if isinstance(x, torch.Tensor) or is_tensor_collection(x):
+            if x.requires_grad:
+                raise RuntimeError(cls._REQUIRES_GRAD_ERROR)
+            if x.device is None:
+                # Check device of leaves of tensordict
+                x.apply(cls._check_device_and_grad, filter_empty=True)
+
+            elif x.device.type != "cuda":
+                raise ValueError(
+                    f"All tensors must be stored on CUDA. Got {x.device.type}."
+                )
+
+            return x.clone()
+        return x
+
     @classmethod
     def _check_device_and_grad(cls, x):
         if isinstance(x, torch.Tensor):
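Note (sketch, not the library API): `_check_device_and_clone`, factored out as a classmethod above, stages every input leaf into a static buffer the graph can own. A simplified standalone version over plain tensors — the tensordict branch and the `_REQUIRES_GRAD_ERROR` message are omitted — looks like this:

    import torch
    from torch.utils._pytree import tree_map

    def check_device_and_clone(x):
        if isinstance(x, torch.Tensor):
            if x.requires_grad:
                raise RuntimeError("inputs to a captured graph cannot require grad")
            if x.device.type != "cuda":
                raise ValueError(f"All tensors must be stored on CUDA. Got {x.device.type}.")
            return x.clone()  # private, graph-owned copy of the input
        return x

    args = (torch.randn(2, device="cuda"),)
    kwargs = {"scale": torch.ones(2, device="cuda")}
    static_args, static_kwargs = tree_map(check_device_and_clone, (args, kwargs))

The clone matters: capture records the storage addresses of its inputs, so the wrapper must own buffers that the caller cannot resize or free.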
From b3bc198d0b7a75b3647ffa2aef8730c2a57c7a82 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 30 Sep 2024 12:25:32 +0100
Subject: [PATCH 4/7] amend

---
 tensordict/nn/cudagraphs.py | 71 +++++++++++++++++++------------------
 1 file changed, 36 insertions(+), 35 deletions(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index 9a555ebab..ecaef85b6 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -210,7 +210,20 @@ def _call(
                 tensordict_out: TensorDictBase | None = None,
                 **kwargs: Any,
             ) -> Any:
-                if self.counter < self._warmup:
+                if self.counter >= self._warmup:
+                    self._tensordict.update_(tensordict, non_blocking=True)
+                    self.graph.replay()
+                    if self._out_matches_in:
+                        result = tensordict.update(
+                            self._out, keys_to_update=self._selected_keys
+                        )
+                    elif tensordict_out is not None:
+                        result = tensordict_out.update(self._out, clone=True)
+                    else:
+                        result = self._out.clone() if self._out is not None else None
+                    return result
+
+                if not self._has_cuda or self.counter < self._warmup - 1:
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     with self._warmup_stream_cm:
@@ -223,7 +236,7 @@ def _call(
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     return out
-                elif self.counter == self._warmup - 1:
+                else:
                     if tensordict.device is None:
                         tensordict.apply(self._check_device_and_grad, filter_empty=True)
                     elif tensordict.device.type != "cuda":
@@ -270,23 +283,30 @@ def check_tensor_id(name, t0, t1):
                         filter_empty=True,
                     )
                     return this_out
-                else:
-                    self._tensordict.update_(tensordict, non_blocking=True)
-                    self.graph.replay()
-                    if self._out_matches_in:
-                        result = tensordict.update(
-                            self._out, keys_to_update=self._selected_keys
-                        )
-                    elif tensordict_out is not None:
-                        result = tensordict_out.update(self._out, clone=True)
-                    else:
-                        result = self._out.clone() if self._out is not None else None
-                    return result
         else:
 
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
-                if self.counter < self._warmup:
+                if self.counter >= self._warmup:
+                    tree_map(
+                        lambda x, y: x.copy_(y, non_blocking=True),
+                        (self._args, self._kwargs),
+                        (args, kwargs),
+                    )
+                    self.graph.replay()
+                    if self._return_unchanged == "clone":
+                        result = self._out.clone()
+                    elif self._return_unchanged:
+                        result = self._out
+                    else:
+                        result = tree_map(
+                            lambda x: x.detach().clone() if x is not None else x,
+                            self._out,
+                        )
+                    # torch.cuda.synchronize()
+                    return result
+
+                if not self._has_cuda or self.counter < self._warmup - 1:
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         out = self.module(*args, **kwargs)
                     torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
-                elif self.counter == self._warmup - 1:
+                else:
                     self._args, self._kwargs = tree_map(
                         self._check_device_and_clone, (args, kwargs)
                     )
@@ -332,24 +351,6 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     else:
                         self._return_unchanged = False
                     return this_out
-                else:
-                    tree_map(
-                        lambda x, y: x.copy_(y, non_blocking=True),
-                        (self._args, self._kwargs),
-                        (args, kwargs),
-                    )
-                    self.graph.replay()
-                    if self._return_unchanged == "clone":
-                        result = self._out.clone()
-                    elif self._return_unchanged:
-                        result = self._out
-                    else:
-                        result = tree_map(
-                            lambda x: x.detach().clone() if x is not None else x,
-                            self._out,
-                        )
-                    # torch.cuda.synchronize()
-                    return result
 
             _call_func = functools.wraps(self.module)(_call)
             self._call_func = _call_func
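Note (toy model, not library code): the reordering in PATCH 4 makes the call schedule explicit — eager warmups, one capture step at `counter == warmup - 1`, then replay forever, with the replay branch tested first since it is the steady state. A CPU-runnable toy of that schedule, using a hypothetical `Phases` class:

    class Phases:
        def __init__(self, warmup=2):
            self.warmup = warmup
            self.counter = 0

        def step(self):
            if self.counter >= self.warmup:   # hot path, checked first
                return "replay"
            if self.counter < self.warmup - 1:
                self.counter += 1
                return "eager-warmup"
            self.counter += 1
            return "capture"

    p = Phases(warmup=2)
    print([p.step() for _ in range(4)])
    # ['eager-warmup', 'capture', 'replay', 'replay']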
From 2608f16cf2a7f9debbb6d4af09966ac9941a87cf Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 30 Sep 2024 13:39:58 +0100
Subject: [PATCH 5/7] amend

---
 tensordict/nn/cudagraphs.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index ecaef85b6..c17aea104 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -248,7 +248,7 @@ def _call(
                     self._tensordict = tensordict.copy()
 
                     torch.cuda.synchronize()
-                    this_out = self.module(self._tensordict, *args, **kwargs)
+                    this_out = self.module(tensordict, *args, **kwargs)
                     torch.cuda.synchronize()
 
                     self.graph = torch.cuda.CUDAGraph()
@@ -303,7 +303,6 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     else:
                         result = tree_map(
                             lambda x: x.detach().clone() if x is not None else x,
                             self._out,
                         )
-                    # torch.cuda.synchronize()
                     return result
@@ -321,13 +320,12 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     )
 
                     torch.cuda.synchronize()
-                    this_out = self.module(*self._args, **self._kwargs)
+                    this_out = self.module(*args, **kwargs)
                     torch.cuda.synchronize()
 
                     self.graph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.graph):
                         out = self.module(*self._args, **self._kwargs)
-                    self.graph.replay()
                     self._out = out
                     self.counter += 1
                     # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
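Note (illustration, assumes a CUDA device): PATCH 5 runs the module on the caller's `tensordict`/`args` for the value returned from the capture step, while capture itself records on the static copies. The reason is that capture does not execute work — tensors produced under `torch.cuda.graph` only hold valid data after `replay()`:

    import torch

    x = torch.ones(4, device="cuda")
    _ = x * 2          # warm up the kernel once before capture
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y = x * 2      # recorded, not executed; `y` is a static output buffer

    graph.replay()     # now the recorded kernels actually run
    torch.cuda.synchronize()
    print(y)           # tensor([2., 2., 2., 2.], device='cuda:0')

Hence `this_out` (a real eager result) is what the user sees on the capture call, and the stray `self.graph.replay()` after capture becomes unnecessary.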
From cb5cdc5d96efce9a85edf71696b548d88cad1862 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 30 Sep 2024 14:01:26 +0100
Subject: [PATCH 6/7] amend

---
 tensordict/nn/cudagraphs.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index c17aea104..60a20a508 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -199,8 +199,6 @@ def __init__(
             )
             _exclude_td_from_pytree().set()
 
-        functools.update_wrapper(self, module)
-
         if self._is_tensordict_module:
 
             @dispatch(source=self.in_keys, dest=self.out_keys, auto_batch_size=False)
@@ -235,6 +233,7 @@ def _call(
                     self.counter += self._has_cuda
                     if self._has_cuda:
                         torch.cuda.synchronize()
+                    print('self.module', self.module, 'out', out, 'tensordict', tensordict)
                     return out
                 else:
                     if tensordict.device is None:
@@ -246,15 +245,18 @@ def _call(
                     tree_map(self._check_non_tensor, (args, kwargs))
 
                     self._tensordict = tensordict.copy()
+                    if tensordict_out is not None:
+                        td_out_save = tensordict_out.copy()
+                        kwargs["tensordict_out"] = tensordict_out
 
                     torch.cuda.synchronize()
                     this_out = self.module(tensordict, *args, **kwargs)
                     torch.cuda.synchronize()
 
                     self.graph = torch.cuda.CUDAGraph()
+                    if tensordict_out is not None:
+                        kwargs["tensordict_out"] = td_out_save
                     with torch.cuda.graph(self.graph):
-                        if tensordict_out is not None:
-                            kwargs["tensordict_out"] = tensordict_out
                         out = self.module(self._tensordict, *args, **kwargs)
 
                     if not is_tensor_collection(out) and out is not None:
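Note (CPU-runnable illustration): the `td_out_save` dance above appears to guard against one preallocated output buffer serving both the eager warm run and the captured run — if they shared it, the value returned from the warm run would alias storage that every later write clobbers. In miniature, with a plain tensor standing in for `tensordict_out`:

    import torch

    def fn(x, out):
        out.copy_(x * 2)
        return out

    x = torch.ones(3)
    out = torch.empty(3)
    warm_result = fn(x, out)          # warm run: warm_result aliases `out`
    fn(torch.full((3,), 5.0), out)    # stand-in for the captured run reusing `out`
    print(warm_result)                # tensor([10., 10., 10.]) — warm result clobbered

Keeping a pristine copy (`td_out_save`) for the capture call leaves the eagerly produced output untouched.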
From 4489acd785d0022cb56e21848ff7c5149da96865 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 30 Sep 2024 14:01:39 +0100
Subject: [PATCH 7/7] amend

---
 tensordict/nn/cudagraphs.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index 60a20a508..5cd1fa301 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -233,7 +233,6 @@ def _call(
                     self.counter += self._has_cuda
                     if self._has_cuda:
                         torch.cuda.synchronize()
-                    print('self.module', self.module, 'out', out, 'tensordict', tensordict)
                     return out
                 else:
                     if tensordict.device is None:
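With the series applied, the wrapper is used by wrapping a CUDA-resident callable, calling it a few times to warm up and capture, after which every call is a graph replay. A minimal sketch assuming a CUDA device (the `warmup` keyword mirrors the `self._warmup` attribute used throughout the diffs):

    import torch
    from tensordict.nn import CudaGraphModule

    fn = CudaGraphModule(lambda x: x + 1, warmup=2)
    x = torch.rand(1, device="cuda")
    for _ in range(4):   # warmup calls, one capture, then replays
        y = fn(x)
    print(y)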