8 changes: 6 additions & 2 deletions tensordict/base.py
@@ -7916,7 +7916,9 @@ def lerp(
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        if _is_tensor_collection(type(weight)):
+        if isinstance(weight, (float, torch.Tensor)):
+            weight_val = weight
+        elif _is_tensor_collection(type(weight)):
             weight_val = weight._values_list(True, True)
         else:
             weight_val = weight
@@ -7936,7 +7938,9 @@ def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        if _is_tensor_collection(type(weight)):
+        if isinstance(weight, (float, torch.Tensor)):
+            weight_val = weight
+        elif _is_tensor_collection(type(weight)):
             weight_val = weight._values_list(True, True)
         else:
             weight_val = weight
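
For reference, a minimal usage sketch (illustrative values only, assuming the public TensorDict API) of the two weight types the new isinstance branch passes through unchanged, a plain float and a torch.Tensor:

    # Illustrative only; not part of the diff.
    import torch
    from tensordict import TensorDict

    start = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
    end = TensorDict({"a": torch.ones(3)}, batch_size=[3])

    half = start.lerp(end, 0.5)                    # float weight
    third = start.lerp(end, torch.tensor(1 / 3))   # tensor weight
    print(half["a"], third["a"])
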
157 changes: 80 additions & 77 deletions tensordict/nn/cudagraphs.py
@@ -208,20 +208,33 @@ def _call(
tensordict_out: TensorDictBase | None = None,
**kwargs: Any,
) -> Any:
if self.counter < self._warmup:
if self._warmup_stream is not None:
self._warmup_stream.wait_stream(torch.cuda.current_stream())
if self.counter >= self._warmup:
self._tensordict.update_(tensordict, non_blocking=True)
self.graph.replay()
if self._out_matches_in:
result = tensordict.update(
self._out, keys_to_update=self._selected_keys
)
elif tensordict_out is not None:
result = tensordict_out.update(self._out, clone=True)
else:
result = self._out.clone() if self._out is not None else None
return result

if not self._has_cuda or self.counter < self._warmup - 1:
if self._has_cuda:
torch.cuda.synchronize()
with self._warmup_stream_cm:
if tensordict_out is not None:
kwargs["tensordict_out"] = tensordict_out
out = self.module(tensordict, *args, **kwargs)
if self._out_matches_in is None:
self._out_matches_in = out is tensordict
self.counter += self._has_cuda
if self._warmup_stream is not None:
torch.cuda.current_stream().wait_stream(self._warmup_stream)
if self._has_cuda:
torch.cuda.synchronize()

Review discussion on this change:

Reviewer: @vmoens What's the purpose of this change? IIUC this is forcing more syncs than the old version. Was the old version causing bugs?

vmoens (Collaborator, Author): I think so, data could be read before it was fully updated. We can make this optional if you want to experiment.

Reviewer: Thanks for the offer. I don't think I'll need to experiment with this for now.
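
To make the concern concrete, here is a minimal, hedged sketch of the warmup/capture/replay pattern torch.cuda.CUDAGraph expects; the net and static_x names are hypothetical, only the torch.cuda calls are real. Synchronizing after the eager warmup ensures the capture does not record against tensors that are still being written, which appears to be the race the extra torch.cuda.synchronize() calls in this diff guard against.

    # Minimal CUDA-graph warmup/capture/replay sketch (requires a CUDA device).
    import torch

    net = torch.nn.Linear(8, 8).cuda()
    static_x = torch.randn(4, 8, device="cuda")

    # Warm up eagerly on a side stream, then sync so no pending kernels leak into capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            net(static_x)
    torch.cuda.current_stream().wait_stream(s)
    torch.cuda.synchronize()

    # Capture once against the static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph), torch.no_grad():
        static_y = net(static_x)

    # Replay: refresh the static input in place, then re-run the recorded kernels.
    static_x.copy_(torch.randn(4, 8, device="cuda"))
    graph.replay()
    result = static_y.clone()
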

return out
elif self.counter == self._warmup:
else:
if tensordict.device is None:
tensordict.apply(self._check_device_and_grad, filter_empty=True)
elif tensordict.device.type != "cuda":
@@ -230,15 +243,20 @@ def _call(
)

tree_map(self._check_non_tensor, (args, kwargs))
self._tensordict = tensordict.copy()
if tensordict_out is not None:
td_out_save = tensordict_out.copy()
kwargs["tensordict_out"] = tensordict_out

self.graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
self._tensordict = tensordict.copy()
this_out = self.module(tensordict, *args, **kwargs)
torch.cuda.synchronize()

self.graph = torch.cuda.CUDAGraph()
if tensordict_out is not None:
kwargs["tensordict_out"] = td_out_save
with torch.cuda.graph(self.graph):
if tensordict_out is not None:
kwargs["tensordict_out"] = tensordict_out
out = self.module(self._tensordict, *args, **kwargs)
self.graph.replay()

if not is_tensor_collection(out) and out is not None:
raise RuntimeError(
@@ -265,62 +283,50 @@ def check_tensor_id(name, t0, t1):
default=None,
filter_empty=True,
)
return tensordict.update(
self._out, keys_to_update=self._selected_keys
)
if tensordict_out is not None:
return tensordict_out.update(out, clone=True)
return out.clone() if self._out is not None else None
else:
self._tensordict.update_(tensordict)
torch.cuda.synchronize()
self.graph.replay()
torch.cuda.synchronize()
if self._out_matches_in:
return tensordict.update(
self._out, keys_to_update=self._selected_keys
)
if tensordict_out is not None:
return tensordict_out.update(self._out, clone=True)
return self._out.clone() if self._out is not None else None
return this_out

else:

def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
if self.counter < self._warmup:
if self._warmup_stream is not None:
self._warmup_stream.wait_stream(torch.cuda.current_stream())
if self.counter >= self._warmup:
tree_map(
lambda x, y: x.copy_(y, non_blocking=True),
(self._args, self._kwargs),
(args, kwargs),
)
self.graph.replay()
if self._return_unchanged == "clone":
result = self._out.clone()
elif self._return_unchanged:
result = self._out
else:
result = tree_map(
lambda x: x.detach().clone() if x is not None else x,
self._out,
)
return result

if not self._has_cuda or self.counter < self._warmup - 1:
if self._has_cuda:
torch.cuda.synchronize()
with self._warmup_stream_cm:
out = self.module(*args, **kwargs)
if self._warmup_stream is not None:
torch.cuda.current_stream().wait_stream(self._warmup_stream)
if self._has_cuda:
torch.cuda.synchronize()
self.counter += self._has_cuda
return out
elif self.counter == self._warmup:

def check_device_and_clone(x):
if isinstance(x, torch.Tensor) or is_tensor_collection(x):
if x.requires_grad:
raise RuntimeError(self._REQUIRES_GRAD_ERROR)
if x.device is None:
# Check device of leaves of tensordict
x.apply(self._check_device_and_grad, filter_empty=True)

elif x.device.type != "cuda":
raise ValueError(
f"All tensors must be stored on CUDA. Got {x.device.type}."
)

return x.clone()
return x

else:
self._args, self._kwargs = tree_map(
check_device_and_clone, (args, kwargs)
self._check_device_and_clone, (args, kwargs)
)

torch.cuda.synchronize()
this_out = self.module(*args, **kwargs)
torch.cuda.synchronize()

self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
out = self.module(*self._args, **self._kwargs)
self.graph.replay()
self._out = out
self.counter += 1
# Check that there is no intersection between the identities of inputs and outputs, otherwise warn
@@ -341,33 +347,30 @@ def check_device_and_clone(x):
self._return_unchanged = (
"clone" if self._out is not None else True
)
return (
self._out.clone()
if self._return_unchanged == "clone"
else self._out
)
self._return_unchanged = False
return tree_map(lambda x: x.clone(), out)
else:
tree_map(
lambda x, y: x.copy_(y),
(self._args, self._kwargs),
(args, kwargs),
)
torch.cuda.synchronize()
self.graph.replay()
torch.cuda.synchronize()
if self._return_unchanged == "clone":
return self._out.clone()
elif self._return_unchanged:
return self._out
return tree_map(
lambda x: x.clone() if x is not None else x, self._out
)
else:
self._return_unchanged = False
return this_out

_call_func = functools.wraps(self.module)(_call)
self._call_func = _call_func

@classmethod
def _check_device_and_clone(cls, x):
if isinstance(x, torch.Tensor) or is_tensor_collection(x):
if x.requires_grad:
raise RuntimeError(cls._REQUIRES_GRAD_ERROR)
if x.device is None:
# Check device of leaves of tensordict
x.apply(cls._check_device_and_grad, filter_empty=True)

elif x.device.type != "cuda":
raise ValueError(
f"All tensors must be stored on CUDA. Got {x.device.type}."
)

return x.clone()
return x

@classmethod
def _check_device_and_grad(cls, x):
if isinstance(x, torch.Tensor):
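
For orientation, a hedged usage sketch of the wrapper this file implements; it assumes the tensordict.nn.CudaGraphModule entry point and its warmup argument (check the class signature), plus a CUDA device. The first calls run eagerly as warmup, the next call captures the graph, and later calls replay it with updated inputs.

    # Hedged usage sketch; shapes and values are illustrative only.
    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule, TensorDictModule

    net = torch.nn.Linear(8, 8).cuda()
    mod = TensorDictModule(net, in_keys=["x"], out_keys=["y"])
    gmod = CudaGraphModule(mod, warmup=2)

    td = TensorDict({"x": torch.randn(4, 8, device="cuda")}, batch_size=[4])
    with torch.no_grad():
        for _ in range(4):  # warmup calls first, then capture, then replay
            out = gmod(td)
    print(out["y"].shape)
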