
Commit

Update
[ghstack-poisoned]
yf225 committed Jun 19, 2024
1 parent 995eb18 commit 321a014
Showing 6 changed files with 9 additions and 55 deletions.
23 changes: 0 additions & 23 deletions test/test_autograd.py
@@ -7103,29 +7103,6 @@ def forward(self, dict_input):
         ):
             self.assertEqual(param.grad, checkpoint_param.grad)
 
-    def test_callback(self):
-        called = [0]
-
-        def callback_final():
-            called[0] += 1
-
-        class MyFunc(Function):
-            @staticmethod
-            def forward(ctx, input):
-                return input
-
-            @staticmethod
-            @once_differentiable
-            def backward(ctx, grad):
-                Variable._execution_engine.queue_callback(callback_final)
-                return grad
-
-        a = torch.rand((3, 3), requires_grad=True)
-        b = MyFunc.apply(a)
-        b.sum().backward()
-
-        self.assertEqual(called[0], 1)
-
     def test_callback_adds_callback(self):
         called = [0]
 
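The deleted test covered Variable._execution_engine.queue_callback, which lets a custom autograd Function register a callback that the engine invokes once the backward pass has finished. For reference, a standalone sketch of that pattern (essentially the removed test without the unittest harness):

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.autograd.variable import Variable

called = [0]

def callback_final():
    # Invoked by the autograd engine after backward() finishes.
    called[0] += 1

class MyFunc(Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, grad):
        # Queue the callback from inside backward; it runs at the end of the pass.
        Variable._execution_engine.queue_callback(callback_final)
        return grad

a = torch.rand((3, 3), requires_grad=True)
MyFunc.apply(a).sum().backward()
assert called[0] == 1
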
16 changes: 8 additions & 8 deletions torch/_dynamo/create_parameter_op.py
@@ -11,17 +11,17 @@
""".strip()


# lib = torch.library.Library("create_parameter_op", "FRAGMENT")
lib = torch.library.Library("create_parameter_op", "FRAGMENT")

# lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")
lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")

# @torch.library.impl(lib, "set_", "Meta")
# def set_(tensor, data):
# tensor.set_(data)
@torch.library.impl(lib, "set_", "Meta")
def set_(tensor, data):
tensor.set_(data)

# @torch.library.impl(lib, "set_", "CUDA")
# def set_(tensor, data):
# tensor.set_(data)
@torch.library.impl(lib, "set_", "CUDA")
def set_(tensor, data):
tensor.set_(data)


class TracableCreateParameter(torch.autograd.Function):
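The re-enabled block uses the torch.library API: a "FRAGMENT" Library object extends the create_parameter_op namespace, lib.define declares the schema of an in-place set_ op, and torch.library.impl registers implementations per dispatch key (Meta and CUDA here). A minimal runnable sketch of the same registration pattern follows; the demo_ns namespace and the CPU key are illustrative choices, not part of the diff.

import torch

# A FRAGMENT library adds ops to a namespace without claiming sole ownership of it.
lib = torch.library.Library("demo_ns", "FRAGMENT")  # "demo_ns" is a made-up namespace
lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")

@torch.library.impl(lib, "set_", "Meta")
def set_meta(tensor, data):
    tensor.set_(data)

@torch.library.impl(lib, "set_", "CPU")  # CPU instead of CUDA so the sketch runs anywhere
def set_cpu(tensor, data):
    tensor.set_(data)

t = torch.zeros(2)
torch.ops.demo_ns.set_(t, torch.ones(3))
print(t.shape)  # torch.Size([3]) -- t now aliases the new storage
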
1 change: 0 additions & 1 deletion torch/_dynamo/external_utils.py
@@ -6,7 +6,6 @@
 
 import torch
 import torch.utils._pytree as pytree
-from torch.autograd.variable import compiled_autograd_final_callbacks
 
 try:
     import numpy as np
14 changes: 1 addition & 13 deletions torch/_dynamo/output_graph.py
@@ -84,7 +84,7 @@
     same,
     set_example_value,
 )
-from .variables.base import MutableLocal, VariableTracker
+from .variables.base import VariableTracker
 from .variables.builder import (
     BackwardStateGraphArg,
     GraphArg,
@@ -408,11 +408,6 @@ def __init__(
 
         self.guard_on_key_order: Set[str] = set()
 
-        # Track compiled autograd final callbacks that must be called at the end of this graph.
-        # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
-        self.ca_final_callbacks: List[Callable] = []
-        self.ca_final_callbacks_var = None
-
     def install_builtins_dict_in_fglobals(self):
         # f_globals["__builtins__"] can be a dict or a module. This is an
         # implemenation detail -
@@ -450,13 +445,6 @@ def get_backward_state_proxy(self):
             self.backward_state_var = self.new_var()
         return self.backward_state_proxy
 
-    def get_ca_final_callbacks_var(self):
-        if self.ca_final_callbacks_var is None:
-            self.ca_final_callbacks_var = variables.ListVariable(
-                self.ca_final_callbacks, mutable_local=MutableLocal()
-            )
-        return self.ca_final_callbacks_var
-
     # This gets its own helper function so guards DEBUG logs are more informative
     def init_ambient_guards(self):
         # Register a SHAPE_ENV guard to make sure we setup shape guards
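The removed code let OutputGraph collect compiled-autograd final callbacks during tracing (ca_final_callbacks) and lazily wrap them in a ListVariable the first time they were needed. The collect-then-fire idea, stripped of Dynamo machinery, is sketched below; this is an illustration only, not how compiled autograd actually schedules its callbacks.

from typing import Callable, List

class FinalCallbacks:
    """Illustrative only: gather callbacks while 'tracing' and run them once at the end."""

    def __init__(self) -> None:
        self._callbacks: List[Callable[[], None]] = []

    def queue(self, cb: Callable[[], None]) -> None:
        # Consumers (like the removed get_ca_final_callbacks_var) would read this
        # list lazily; here we just store the callbacks until fire() is called.
        self._callbacks.append(cb)

    def fire(self) -> None:
        for cb in self._callbacks:
            cb()
        self._callbacks.clear()

cbs = FinalCallbacks()
cbs.queue(lambda: print("graph finished"))
cbs.fire()  # prints "graph finished" exactly once
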
1 change: 0 additions & 1 deletion torch/_dynamo/trace_rules.py
@@ -287,7 +287,6 @@
"torch._tensor._convert": UserFunctionVariable,
"torch.jit._unwrap_optional": UserFunctionVariable,
"torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
"torch.autograd.variable.queue_callback": UserFunctionVariable,
"torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable,
"torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable,
"torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable,
9 changes: 0 additions & 9 deletions torch/_dynamo/variables/torch.py
@@ -185,15 +185,6 @@ def __repr__(self):
 
     @staticmethod
     def is_matching_cls(value):
-        # Update supported_ctx_manager_classes here to avoid circular import
-        import torch.distributed._composable.fsdp
-        supported_ctx_manager_classes.update(
-            dict.fromkeys(
-                [
-                    torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup.use_training_state,
-                ]
-            )
-        )
         # Unwrap if it's a functools.lru_cache wrapper
         value = unwrap_if_wrapper(value)
         # We can't do isinstance(value, type) check because some ctx managers
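The surviving context in is_matching_cls first unwraps functools.lru_cache-style wrappers via unwrap_if_wrapper before inspecting the value. A self-contained approximation of that idea is sketched below; it simply follows the __wrapped__ links that functools.wraps and functools.lru_cache install, and the name unwrap_if_wrapper_sketch is invented for this example rather than taken from the Dynamo helper.

import functools

def unwrap_if_wrapper_sketch(fn):
    # Follow __wrapped__ links (set by functools.wraps / functools.lru_cache),
    # with a hop limit as a guard against pathological chains.
    for _ in range(8):
        inner = getattr(fn, "__wrapped__", None)
        if inner is None:
            return fn
        fn = inner
    return fn

@functools.lru_cache(None)
def cached_factory():
    return object()

assert unwrap_if_wrapper_sketch(cached_factory) is cached_factory.__wrapped__
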
