Closed

Changes from all commits (33 commits)
fa40a45  Build the keys of a dict depth-first (lezcano, Jan 17, 2024)
125bd38  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
fdb6e07  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
becb54f  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
f2d5fd5  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
4ea56b9  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
64c30dd  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
921dc04  Update on "Build the keys of a dict depth-first" (lezcano, Jan 17, 2024)
eb1fa85  Update on "Build the keys of a dict depth-first" (lezcano, Jan 18, 2024)
2c67f6a  Update on "Make variables in dict lazy and avoid using DICT_KEYS guard" (lezcano, Jan 18, 2024)
6812a60  Update on "Make variables in dict lazy and avoid using DICT_KEYS guard" (lezcano, Jan 18, 2024)
1002aca  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 22, 2024)
b7a7c00  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 22, 2024)
bfc393d  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 22, 2024)
5149aab  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 22, 2024)
61d4ea1  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 22, 2024)
f38b105  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 22, 2024)
10f480f  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 23, 2024)
74d81fd  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 23, 2024)
9786535  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 23, 2024)
4be4a98  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 24, 2024)
3f05363  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 24, 2024)
2ecd958  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 24, 2024)
267b853  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 24, 2024)
5fc6f32  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 25, 2024)
f0173b2  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 25, 2024)
6128e6b  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 26, 2024)
510baf4  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 26, 2024)
f8cd51d  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 29, 2024)
49ee382  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 30, 2024)
01b93b0  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 30, 2024)
16281ef  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Jan 30, 2024)
9931d42  Update on "Make variables in dict LazyTrackers (not lazily guarded ye… (lezcano, Feb 1, 2024)

24 changes: 12 additions & 12 deletions test/dynamo/test_higher_order_ops.py
@@ -322,22 +322,22 @@ def my_args_generator(t):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
l_d_x_ = L_d_x_
l_d_y_0_ = L_d_y_0_
l_d_y_1_2_ = L_d_y_1_2_
def forward(self, L_d_dict_keys_getitem_L_d_0_ : torch.Tensor, L_d_dict_keys_getitem_L_d_1_0_ : torch.Tensor, L_d_dict_keys_getitem_L_d_1_1_2_ : torch.Tensor):
l_d_dict_keys_getitem_l_d_0_ = L_d_dict_keys_getitem_L_d_0_
l_d_dict_keys_getitem_l_d_1_0_ = L_d_dict_keys_getitem_L_d_1_0_
l_d_dict_keys_getitem_l_d_1_1_2_ = L_d_dict_keys_getitem_L_d_1_1_2_

wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_dict_keys_getitem_l_d_0_, l_d_dict_keys_getitem_l_d_1_0_, l_d_dict_keys_getitem_l_d_1_1_2_); wrap_body_0 = l_d_dict_keys_getitem_l_d_0_ = l_d_dict_keys_getitem_l_d_1_0_ = l_d_dict_keys_getitem_l_d_1_1_2_ = None
getitem = wrap[0]; wrap = None
return (getitem,)

class GraphModule(torch.nn.Module):
def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
sin = l_d_x_.sin(); l_d_x_ = None
cos = l_d_y_0_.cos(); l_d_y_0_ = None
def forward(self, l_d_dict_keys_getitem_l_d_0_, l_d_dict_keys_getitem_l_d_1_0_, l_d_dict_keys_getitem_l_d_1_1_2_):
sin = l_d_dict_keys_getitem_l_d_0_.sin(); l_d_dict_keys_getitem_l_d_0_ = None
cos = l_d_dict_keys_getitem_l_d_1_0_.cos(); l_d_dict_keys_getitem_l_d_1_0_ = None
add = sin + cos; sin = cos = None
sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sin_1 = l_d_dict_keys_getitem_l_d_1_1_2_.sin(); l_d_dict_keys_getitem_l_d_1_1_2_ = None
sub = add - sin_1; add = sin_1 = None
return (sub,)
""", # NOQA: B950
@@ -2338,18 +2338,18 @@ def fn(pred, pytree_in):
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_dict_keys_getitem_L_pytree_in_4_0_ : torch.Tensor):
l_pred_ = L_pred_
l_pytree_in_0_ = L_pytree_in_0_
l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
l_pytree_in_2_ = L_pytree_in_2_
l_pytree_in_3_0_ = L_pytree_in_3_0_
l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
l_pytree_in_3_2_ = L_pytree_in_3_2_
l_pytree_in_4_g_ = L_pytree_in_4_g_
l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_ = L_pytree_in_4_dict_keys_getitem_L_pytree_in_4_0_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_dict_keys_getitem_l_pytree_in_4_0_ = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
14 changes: 14 additions & 0 deletions test/dynamo/test_misc.py
@@ -2516,6 +2516,20 @@ def fn(d):
# Extra calls don't recompile
self.assertEqual(cnts.frame_count, 2)

def test_dict_namedtuple(self):
def fn(d):
return d[3] * 2

args1 = {collections.namedtuple: None, 3: torch.randn(3)}
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(fn(args1), opt_fn(args1))
self.assertEqual(cnts.frame_count, 1)
# Test a failing namedtuple guard
args2 = {2: None, 3: torch.randn(3)}
self.assertEqual(fn(args2), opt_fn(args2))
self.assertEqual(cnts.frame_count, 2)

def test_dict_order_keys_tensors(self):
def fn(d, x):
return d[x] + 3
2 changes: 1 addition & 1 deletion test/dynamo/test_modules.py
@@ -1843,7 +1843,7 @@ def guard_fail_fn(failure):
handle.remove()
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 7)
self.assertTrue("forward_hooks.keys" in failure_reason)
self.assertTrue("forward_hooks" in failure_reason)
self.assertEqual(cc.frame_count, 1 + 1)
self.assertEqual(cc.op_count, 6 + 4)

10 changes: 4 additions & 6 deletions test/test_expanded_weights.py
@@ -688,24 +688,22 @@ def filter_supported_tests(t):
if 'constructor' not in test_param:
name = test_param.pop('module_name')
test_param['constructor'] = getattr(nn, name)
decorator = test_param.pop('decorator', None)
decorator = test_param.pop('decorator', lambda test: test)
test = ContextManagerTests(**test_param)
test_name = test.get_name()
if hasattr(TestExpandedWeightModule, test_name):
raise RuntimeError('Found two tests with the same name: ' + test_name)
test_name_multi_input = test.get_name() + "_multiple_inputs"
if hasattr(TestExpandedWeightModule, test_name_multi_input):
raise RuntimeError('Found two tests with the same name: ' + test_name)
if decorator is not None:
fn = decorator(fn) # noqa: F821
if test.test_cpu:
setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self, 'cpu'))
setattr(TestExpandedWeightModule, test_name, decorator(lambda self, test=test: test.test_context_manager(self, 'cpu')))
setattr(TestExpandedWeightModule, test_name_multi_input,
lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu'))
decorator(lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu')))
if TEST_CUDA and test.test_cuda:
# since this checks derivatives, only use double for precision
setattr(TestExpandedWeightModule, test_name + '_cuda_double',
lambda self, test=test: test.test_context_manager(self, 'cuda'))
decorator(lambda self, test=test: test.test_context_manager(self, 'cuda')))

# ------------- HELPER FUNCTIONS -----------------

2 changes: 1 addition & 1 deletion torch/_dynamo/guards.py
@@ -668,7 +668,7 @@ def get_sources(t_id, dim):
self._produce_guard_code(guard, [shape_guard], shape_env=True)

def TENSOR_MATCH(self, guard: Guard, value=None):
if guard.is_nn_module():
if guard.is_nn_module() or guard.originating_source.is_dict_key():
self.ID_MATCH(guard)
else:
if isinstance(value, TensorWeakRef):
3 changes: 3 additions & 0 deletions torch/_dynamo/source.py
@@ -346,6 +346,9 @@ def name(self):

@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
def is_dict_key(self):
return True

def reconstruct(self, codegen):
return [
*codegen.create_load_import_from(utils.__name__, "dict_keys_getitem"),
27 changes: 10 additions & 17 deletions torch/_dynamo/variables/builder.py
@@ -95,7 +95,6 @@
DataClassVariable,
DefaultDictVariable,
HFPretrainedConfigVariable,
is_hashable_python_var,
PythonSysModulesVariable,
SetVariable,
)
@@ -412,9 +411,7 @@ class Autotuner:
return ConstDictVariable(result, type(value))
elif value is sys.modules:
return PythonSysModulesVariable(source=self.source)
elif istype(
value, (dict, collections.defaultdict, collections.OrderedDict)
) and all(is_hashable_python_var(k) for k in value.keys()):
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
if not value and self.get_source().is_nn_module():
# It is faster to guard on 'false' property than to guard
# on actual dict keys, but we can't do this fast guard in general because
@@ -425,26 +422,22 @@ class Autotuner:
# but not completely secure job ensuring a property wasn't changed.
self.install_guards(GuardBuilder.BOOL_FALSE)
else:
self.install_guards(GuardBuilder.DICT_KEYS)
self.install_guards(GuardBuilder.LIST_LENGTH)
Contributor
Does it make sense to delete the DICT_KEYS guard entirely (we may have to clean up other places)?

Collaborator Author
That's the plan, yes. I'll put up a PR next week.

idx = 0

def build_key_value(k, v):
nonlocal idx
if ConstantVariable.is_literal(k):
key = ConstantVariable.create(k)
source_key = k
else:
source_key = ConstDictKeySource(self.get_source(), idx)
key = VariableBuilder(self.tx, source_key)(k)
# We need all the keys to be hashable. We do this within the
# _HashableTracker class in dicts.py
def build_key_value(i, k, v):
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)
Collaborator
Okay, so the only substantive difference is that where before we had a single DICT_KEYS guard, we now have len(dict) guards, one on each key individually. Wasn't this avoided previously because extracting a key at a particular index is O(n)?

Collaborator
Thereby making guard evaluation O(n^2) in the number of keys.

Collaborator Author
It is true that this is now O(n²), but it is also true that the previous approach was completely broken. It was alright when we just had constant keys, but now we accept anything and everything, so trying to generate a printable version of each object for DICT_KEYS is just too broken.

Let's see what the benchmarks have to say about the compilation times of this approach, though. If it's an issue, I could add a pass where, if all the objects within the keys of the dict are sourceless, we replace all those checks with a single check similar to the previous one.

Another way forward would be to really implement laziness on the keys of a dict. That would most probably offset the O(n) issue on its own.

Collaborator
Ideally I think we would have individual guards for each key, but the guard-checking code would group them together and only iterate over the dictionary's keys once.

Collaborator
If this performs okay and it fixes a real bug, then it's okay for now, I guess.

Contributor (anijain2305, Jan 18, 2024)
Curious: we are installing the LIST_LENGTH guard just once on the dict itself (and not on the keys). So why is guard evaluation O(n^2)?

Collaborator
The guards generated for the keys themselves will look something like:

guard_0(___dict_keys_getitem(dict, 0))
...
guard_i(___dict_keys_getitem(dict, i))

where each key's guard evaluation calls ___dict_keys_getitem to reconstruct the key at a particular index. This operation is O(n) because dict.keys() doesn't support indexing, so you need to iterate over the key set until you reach the i-th element.

If our guard codegen were smarter, though, we could just iterate over the key set once.

Contributor
Oh, I see. That makes sense. I missed that ___dict_keys_getitem itself is O(n).
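
A minimal sketch of the complexity being discussed, assuming ___dict_keys_getitem behaves like the dict_keys_getitem helper in torch/_dynamo/utils.py; the per-key guard bodies below are hypothetical stand-ins for the real ID_MATCH/EQUALS_MATCH checks, not the generated guard code:

    import itertools

    def dict_keys_getitem(d, n):
        # dict views do not support random access, so fetching the n-th key
        # walks the iterator from the start: O(n) for a single lookup.
        return next(itertools.islice(iter(d), n, n + 1))

    def check_per_key_guards(d, expected_keys):
        # One guard per key, each re-walking the keys: O(n^2) overall.
        return len(d) == len(expected_keys) and all(
            dict_keys_getitem(d, i) is k for i, k in enumerate(expected_keys)
        )

    def check_grouped_guards(d, expected_keys):
        # The grouped form suggested in the thread: a single pass, O(n) overall.
        return len(d) == len(expected_keys) and all(
            k1 is k2 for k1, k2 in zip(d.keys(), expected_keys)
        )

With n keys, the first form calls dict_keys_getitem n times, hence the quadratic cost the reviewers point out; the second visits each key exactly once.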


source_value = GetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)

idx += 1
return key, value

result = dict(build_key_value(k, v) for k, v in value.items())
result = dict(
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
)

if istype(value, collections.defaultdict):
result = DefaultDictVariable(
46 changes: 7 additions & 39 deletions torch/_dynamo/variables/dicts.py
@@ -2,15 +2,11 @@

import collections
import dataclasses
import enum
import functools
import inspect
import sys
from types import MethodWrapperType
from typing import Dict, List, Optional

import torch

from torch._subclasses.fake_tensor import is_fake

from .. import variables
@@ -28,32 +24,12 @@
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable

# Note: [Adding a new supported class the keys of ConstDictVarialble]
# You'll need to add it to:
# - `is_hashable_python_var` in this file
# - `is_hashable` in this file
# - `const_repr` in util.py, and perhaps modify DICT_KEYS in guards.py


def is_hashable_python_var(x):
# IMPORTANT: Keep me in sync with is_hashable!
# Even better, we should have a map of functions connecting the two
from torch import Tensor
from ..trace_rules import is_builtin_callable, is_numpy

return (
ConstantVariable.is_literal(x)
or isinstance(x, (Tensor, enum.Enum, type, torch.nn.Module, MethodWrapperType))
or is_builtin_callable(x)
or (isinstance(x, tuple) and all(is_hashable_python_var(e) for e in x))
or is_numpy(x)
)
# [Adding a new supported class within the keys of ConstDictVarialble]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl


def is_hashable(x):
# IMPORTANT: Keep me in sync with is_hashable_python_var!
# Even better, we should have a map of functions connecting the two

if isinstance(x, variables.TensorVariable):
# Tensors are hashable if they have an example_value (a fake tensor)
# Most VT's should have one.
@@ -201,16 +177,6 @@ def reconstruct(self, codegen):
else:
return [create_instruction("BUILD_MAP", arg=len(self.items))]

@staticmethod
def _wrap_keys_python_var(d):
"""Wrap the keys of a dictionary with python objs as keys into Hashable objects"""
assert all(is_hashable_python_var(k) for k in d.keys())
Hashable = ConstDictVariable._HashableTracker
from .builder import SourcelessBuilder

build = SourcelessBuilder()
return {Hashable(build(k)): v for k, v in d.items()}

def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
@@ -290,8 +256,10 @@ def call_method(
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
# all keys in kwargs are valid (`str`s)
kwargs = ConstDictVariable._wrap_keys_python_var(kwargs)
# Wrap strings
kwargs = {
Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
}
self.items.update(kwargs)
return ConstantVariable.create(None)
elif name in ("get", "__getattr__") and args[0] in self:
7 changes: 7 additions & 0 deletions torch/_guards.py
@@ -765,6 +765,9 @@ def tracing(context: Optional[TracingContext]):
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
def is_dict_key(self):
return False

def reconstruct(self, codegen):
raise NotImplementedError()

@@ -788,6 +791,10 @@ def is_nn_module(self) -> bool:
class ChainedSource(Source):
base: Source

def is_dict_key(self):
# Recurse until you either hit a ConstDictKey or a Source
return self.base.is_dict_key()


def detect_fake_mode(inputs: Any = None):
"""
4 changes: 3 additions & 1 deletion torch/testing/_internal/common_nn.py
@@ -16,7 +16,7 @@
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
gradcheck, gradgradcheck, set_default_dtype
gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
@@ -1709,6 +1709,7 @@ def unsqueeze_inp(inp):
input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
check_gradgrad=False,
default_dtype=torch.double,
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
),
dict(
module_name='Embedding',
@@ -1718,6 +1719,7 @@ def unsqueeze_inp(inp):
check_gradgrad=False,
desc='discontiguous',
default_dtype=torch.double,
decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
),
dict(
module_name='EmbeddingBag',
3 changes: 0 additions & 3 deletions torch/testing/_internal/dynamo_test_failures.py
@@ -1627,7 +1627,6 @@
"TestTorchTidyProfiler.test_tensorimpl_invalidation_full", # profiler/test_profiler
"TestProfiler.test_profiler_tracing", # profiler/test_profiler
"TestProfiler.test_is_profiler_enabled", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_idle_time", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer_parameters_sgd", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_name_pattern", # profiler/test_profiler
"TestTorchTidyProfiler.test_extra_fields", # profiler/test_profiler
Expand Down Expand Up @@ -1655,13 +1654,11 @@
"TestTorchTidyProfiler.test_sparse_tensors", # profiler/test_profiler
"TestTorchTidyProfiler.test_optimizer", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensorimpl_invalidation_keep_alive", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_queue_depth", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_pattern_match_helper", # profiler/test_profiler
"TestProfiler.test_export_stacks", # profiler/test_profiler
"TestProfiler.test_source_multithreaded_basic_work_in_main_thread_True", # profiler/test_profiler
"TestTorchTidyProfiler.test_mkldnn_tensors", # profiler/test_profiler
"TestRecordFunction.test_datapipe_with_record_function", # profiler/test_profiler
"TestProfiler.test_memory_profiler", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensor_lists", # profiler/test_profiler
"TestTorchTidyProfiler.test_pointers_and_ids", # profiler/test_profiler
"TestTorchTidyProfiler.test_nnmodule_params", # profiler/test_profiler