Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 committed Apr 28, 2024
2 parents 35e1b6a + 6761b49 commit 287f741
Showing 35 changed files with 1,159 additions and 388 deletions.
1 change: 1 addition & 0 deletions .flake8
@@ -54,6 +54,7 @@ per-file-ignores =
    torch/ao/quantization/fx/_decomposed.py: TOR901
    torch/distributed/_functional_collectives.py: TOR901
    torch/distributed/_spmd/data_parallel.py: TOR901
    torch/distributed/_tensor/_collective_utils.py: TOR901
optional-ascii-coding = True
exclude =
    ./.git,
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.h
@@ -728,7 +728,7 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.

// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
#define KERNEL_PRIVATEUSEONE(OP, ...) \
#define KERNEL_PRIVATEUSEONE(...) \
  KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
10 changes: 10 additions & 0 deletions test/dynamo/test_functions.py
@@ -54,6 +54,16 @@ def constant3(a, b):
    return a - b + (1.0 + 2)


_variable = 0


def update_global(x):
    global _variable
    _variable += 1
    # Check that updated global variable value is picked up
    return x * _variable


def func_with_default(a, b, some_default_arg=True):
    if some_default_arg:
        return a - b
37 changes: 26 additions & 11 deletions test/dynamo/test_misc.py
@@ -10158,21 +10158,36 @@ def test_linear_module_free(self):
    def test_outside_linear_module_free(self):
        # Compared to test_linear_module_free, the linear
        # layer is not the code object that is directly compiled.
        def model_inp_ctr():
            fc = torch.nn.Linear(100, 100)

            class Mod(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.fc_ref = fc
        # This test does not use _test_compile_model_free because of difficulty
        # in handling variable fc.

                def forward(self, x):
                    return fc(x[0])
        fc = torch.nn.Linear(100, 100)

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc_ref = fc

            def forward(self, x):
                return fc(x[0])

            # return fc to keep it alive in _test_compile_model_free
            return Mod(), (torch.randn(100, 100), fc)
        cleared = False

        def finalize():
            nonlocal cleared
            cleared = True

        self._test_compile_model_free(model_inp_ctr, lambda mod: mod.fc_ref)
        def run():
            mod = Mod()
            inp = torch.randn(100, 100)
            weakref.finalize(mod.fc_ref, finalize)
            torch.compile(mod, backend="eager")(inp)

        run()
        del fc  # This should delete all the references
        gc.collect()
        self.assertTrue(cleared)

    @unittest.skipIf(sys.version_info >= (3, 12), "leaks in 3.12+")
    def test_parameter_free(self):
26 changes: 26 additions & 0 deletions test/dynamo/test_modules.py
@@ -28,6 +28,16 @@
import test_functions


_variable = 0
_variable1 = 0


def update_global():
    global _variable, _variable1
    _variable += 1
    _variable1 += 1


class BasicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -2435,6 +2445,22 @@ def forward(self, inp):

        self.assertEqual(model.x, compiled_model.x)

    def test_globals_change_in_other_file(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            update_global()
            a = test_functions.update_global(x)
            # Ensure that the updated global values are read
            return x * a * (_variable + _variable1 + test_functions._variable)

        res = fn(torch.ones(10))
        self.assertEqual(_variable, 1)
        self.assertEqual(_variable1, 1)
        # Ensure that the reconstructed bytecode updates the global value in the
        # other file.
        self.assertEqual(test_functions._variable, 1)
        self.assertEqual(res, 3 * torch.ones(10))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
66 changes: 66 additions & 0 deletions test/inductor/test_cpu_repro.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: cpu inductor"]
import contextlib
import copy
import functools
import itertools
import math
import platform
@@ -3048,6 +3049,38 @@ def forward(self, x):
            v2 = jit_func(input_tensor)
            self.assertEqual(v1, v2)

    def test_nn_param_assign_wrapped(self):
        class Model2(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
                self.batchnorm = nn.BatchNorm2d(num_features=5)
                self.conv_weight = torch.randn(5, 3, 3, 3)
                self.conv_bias = torch.randn(5)

            def forward(self, x):
                self.conv.weight = nn.Parameter(self.conv_weight)
                self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False)
                self.conv.eval()
                x = self.conv(x)
                x = self.batchnorm(x)
                x = F.relu(x)
                return x

        input_tensor = torch.randn(1, 3, 10, 10)
        func = Model2().to("cpu")

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        with torch.no_grad():
            func.train(False)
            v1 = func(input_tensor)
            jit_func = torch.compile(wrapper, fullgraph=True)
            v2 = jit_func(input_tensor)
            self.assertEqual(v1, v2)

    @config.patch(inplace_buffers=True)
    def test_in_out_buffer(self):
        def fn(x, y):
@@ -3650,6 +3683,39 @@ def forward(self, x):
        x = torch.randn(1, 4, 2, 2)
        self.common(fn, (x,))

    @requires_vectorization
    def test_vec_indirect_load_cse_cache(self):
        # https://github.com/pytorch/pytorch/issues/123502
        from math import inf

        def fn(arg0_1):
            full_default = torch.ops.aten.full.default([209985], 1)
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            select_1 = torch.ops.aten.select.int(arg0_1, 0, 1)
            view = torch.ops.aten.reshape.default(select_1, [-1])
            expand = torch.ops.aten.expand.default(view, [209985])
            full_default_1 = torch.ops.aten.full.default([10000], 0)
            scatter_add = torch.ops.aten.scatter_add.default(
                full_default_1, 0, expand, full_default
            )
            pow_1 = torch.ops.aten.pow.Tensor_Scalar(scatter_add, -0.5)
            eq = torch.ops.aten.eq.Scalar(pow_1, inf)
            full_default_2 = torch.ops.aten.full.default([], 0.0)
            where = torch.ops.aten.where.self(eq, full_default_2, pow_1)
            index = torch.ops.aten.index.Tensor(where, [select])
            index_1 = torch.ops.aten.index.Tensor(where, [select_1])
            mul_1 = torch.ops.aten.mul.Tensor(index, index_1)
            return (mul_1,)

        x = torch.zeros(2, 209985).to(torch.int64)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        _, code = run_and_get_cpp_code(opt_fn, x)
        FileCheck().check_count(
            "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(),",
            2,
            exactly=True,
        ).run(code)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
