Commit

Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 committed May 6, 2024
2 parents be64baf + 7bf6ed0 commit 79f542b
Showing 87 changed files with 1,223 additions and 1,031 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
- 73b915b55d96553a0e370b2bab01f47b8c2a9e7c
+ e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd
57 changes: 16 additions & 41 deletions caffe2/serialize/inline_container.cc
@@ -620,35 +620,15 @@ size_t ostream_write_func(
return ret;
}

- // This func will not update combined_uncomp_crc32_ with the uncomp_crc32
- // since there is no way to get the uncomp_crc32 when no buffer is provided.
- size_t ostream_seek_func(
- void* pOpaque,
- mz_uint64 file_ofs,
- size_t n) {
- auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
- if (self->current_pos_ != file_ofs) {
- CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
- }
- size_t ret = self->seek_func_(n);
- if (self->current_pos_ + n != ret) {
- self->err_seen_ = true;
- }
- self->current_pos_ += n;
- return n;
- }

PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
: archive_name_(basename(file_name)) {
setup(file_name);
}

PyTorchStreamWriter::PyTorchStreamWriter(
- const std::function<size_t(const void*, size_t)> writer_func,
- const std::function<size_t(size_t)> seek_func)
+ const std::function<size_t(const void*, size_t)> writer_func)
: archive_name_("archive"),
- writer_func_(writer_func),
- seek_func_(seek_func) {
+ writer_func_(writer_func) {
setup(archive_name_);
}

@@ -677,15 +657,10 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_stream_.write(static_cast<const char*>(buf), nbytes);
return !file_stream_ ? 0 : nbytes;
};
- seek_func_ = [this](size_t nbytes) -> size_t {
- file_stream_.seekp(nbytes, std::ios_base::cur);
- return file_stream_.tellp();
- };
}

ar_->m_pIO_opaque = this;
ar_->m_pWrite = ostream_write_func;
- ar_->m_pSeek = ostream_seek_func;

mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
valid("initializing archive ", file_name.c_str());
@@ -715,20 +690,20 @@ void PyTorchStreamWriter::writeRecord(
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
mz_zip_writer_add_mem_ex_v2(
- /*pZip=*/ar_.get(),
- /*pArchive_name=*/full_name.c_str(),
- /*pBuf=*/data,
- /*buf_size=*/size,
- /*pComment=*/nullptr,
- /*comment_size=*/0,
- /*level_and_flags=*/flags,
- /*uncomp_size=*/0,
- /*uncomp_crc32=*/0,
- /*last_modified=*/nullptr,
- /*user_extra_data=*/padding_.c_str(),
- /*user_extra_data_len=*/padding_size,
- /*user_extra_data_central=*/nullptr,
- /*user_extra_data_central_len=*/0);
+ ar_.get(),
+ full_name.c_str(),
+ data,
+ size,
+ nullptr,
+ 0,
+ flags,
+ 0,
+ 0,
+ nullptr,
+ padding_.c_str(),
+ padding_size,
+ nullptr,
+ 0);
valid("writing file ", name.c_str());
files_written_.insert(name);
}
17 changes: 1 addition & 16 deletions caffe2/serialize/inline_container.h
@@ -203,21 +203,11 @@ class TORCH_API PyTorchStreamReader final {
size_t additional_reader_size_threshold_;
};

- namespace {
-
- size_t default_seek_func(size_t nbytes) {
- TORCH_CHECK(false, "attempting to write record metadata but seek_func unimplemented, please implement seek_func");
- return 0;
- }
-
- } // namespace

class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(const std::string& archive_name);
explicit PyTorchStreamWriter(
- const std::function<size_t(const void*, size_t)> writer_func,
- const std::function<size_t(size_t)> seek_func = default_seek_func);
+ const std::function<size_t(const void*, size_t)> writer_func);

void setMinVersion(const uint64_t version);

@@ -256,7 +246,6 @@ class TORCH_API PyTorchStreamWriter final {
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
- std::function<size_t(size_t)> seek_func_;
uint64_t combined_uncomp_crc32_ = 0;
std::string serialization_id_;

@@ -270,10 +259,6 @@
uint64_t file_ofs,
const void* pBuf,
size_t n);
- friend size_t ostream_seek_func(
- void* pOpaque,
- uint64_t file_ofs,
- size_t n);
};

namespace detail {
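
Taken together, the inline_container changes drop the seek callback, leaving a PyTorchStreamWriter that is driven by a single writer callback. Below is a minimal, illustrative usage sketch of that remaining constructor; the writeRecord/writeEndOfFile signatures are assumed from the surrounding PyTorch sources rather than shown in this diff, and the 0-on-error return convention mirrors the file-stream lambda in setup() above.

#include <caffe2/serialize/inline_container.h>

#include <cstddef>
#include <string>

int main() {
  std::string archive_bytes;  // accumulate the ZIP archive in memory

  // The only remaining callback: receives a buffer and its size and
  // returns the number of bytes it consumed (0 signals a write error).
  caffe2::serialize::PyTorchStreamWriter writer(
      [&archive_bytes](const void* buf, size_t nbytes) -> size_t {
        archive_bytes.append(static_cast<const char*>(buf), nbytes);
        return nbytes;
      });

  const char payload[] = "hello";
  writer.writeRecord("data/hello.txt", payload, sizeof(payload) - 1);
  writer.writeEndOfFile();  // finalizes the archive's central directory
  return 0;
}
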
20 changes: 10 additions & 10 deletions test/dynamo/test_autograd_function.py
@@ -513,7 +513,7 @@ def f(x, weird, z):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ : torch.Tensor, L_z_ : torch.Tensor, L_weird_b : torch.Tensor, L_weird_c : torch.Tensor):
+ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "f32[]"):
l_x_ = L_x_
l_z_ = L_z_
l_weird_b = L_weird_b
@@ -522,23 +522,23 @@ def forward(self, L_x_ : torch.Tensor, L_z_ : torch.Tensor, L_weird_b : torch.Te
function_ctx = torch.autograd.function.FunctionCtx()
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
- autograd_function_apply = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
+ autograd_function_apply: "f32[]" = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
return (autograd_function_apply,)
class GraphModule(torch.nn.Module):
- def forward(self, function_ctx, l_x_, l_z_, l_weird_b, l_weird_c):
- mul = l_weird_b * l_weird_c
- clone = l_x_.clone(); l_x_ = None
- mul_1 = mul * clone; mul = clone = None
+ def forward(self, function_ctx, l_x_: "f32[]", l_z_: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
+ mul: "f32[]" = l_weird_b * l_weird_c
+ clone: "f32[]" = l_x_.clone(); l_x_ = None
+ mul_1: "f32[]" = mul * clone; mul = clone = None
return (mul_1, [l_weird_b, l_weird_c])
class GraphModule(torch.nn.Module):
- def forward(self, function_ctx, mul_1, l_weird_b, l_weird_c):
+ def forward(self, function_ctx, mul_1: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
_set_grad_enabled = torch._C._set_grad_enabled(False)
- mul = mul_1 * l_weird_b; l_weird_b = None
- mul_2 = mul * l_weird_c; mul = l_weird_c = None
- mul_3 = mul_1 * 2; mul_1 = None
+ mul: "f32[]" = mul_1 * l_weird_b; l_weird_b = None
+ mul_2: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
+ mul_3: "f32[]" = mul_1 * 2; mul_1 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True)
return (mul_2, mul_3)
16 changes: 8 additions & 8 deletions test/dynamo/test_backward_higher_order_ops.py
@@ -129,13 +129,13 @@ class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list):
l_inputs_ = L_inputs_
- getitem = l_inputs_[0]; l_inputs_ = None
+ getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
- new_grad = torch.clone(getitem)
+ new_grad: "f32[s0]" = torch.clone(getitem)
- result = getitem * getitem; getitem = None
+ result: "f32[s0]" = getitem * getitem; getitem = None
- new_grad_1 = torch.clone(result); result = None
+ new_grad_1: "f32[s0]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
@@ -195,13 +195,13 @@ class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list):
l_inputs_ = L_inputs_
- getitem = l_inputs_[0]; l_inputs_ = None
+ getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
- new_grad = torch.clone(getitem)
+ new_grad: "f32[s0]" = torch.clone(getitem)
- result = getitem * getitem; getitem = None
+ result: "f32[s0]" = getitem * getitem; getitem = None
- new_grad_1 = torch.clone(result); result = None
+ new_grad_1: "f32[s0]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
76 changes: 43 additions & 33 deletions test/dynamo/test_ctx_manager.py
@@ -1052,21 +1052,23 @@ def f(x, y):
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))

expected = """\
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self):
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
- x = torch.ones(1)
+ x: "f32[1]" = torch.ones(1)
- y = torch.zeros(1)
+ y: "f32[1]" = torch.zeros(1)
- add = x + y; x = y = None
+ add: "f32[1]" = x + y; x = y = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (add,)
"""
self.assertExpectedInline(actual, expected)
""",
)

def test_disable_saved_tensors_hooks_prev_disabled(self):
def fn(z):
@@ -1090,21 +1092,23 @@ def f(x, y):
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))

expected = """\
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self):
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
- x = torch.ones(1)
+ x: "f32[1]" = torch.ones(1)
- y = torch.zeros(1)
+ y: "f32[1]" = torch.zeros(1)
- add = x + y; x = y = None
+ add: "f32[1]" = x + y; x = y = None
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message')
return (add,)
"""
self.assertExpectedInline(actual, expected)
""",
)

def test_disable_saved_tensors_hooks_prev_disabled_nested(self):
def fn(z):
@@ -1134,27 +1138,29 @@ def inner_fn(x, y):
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))

expected = """\
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self):
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
- x = torch.ones(1)
+ x: "f32[1]" = torch.ones(1)
- y = torch.zeros(1)
+ y: "f32[1]" = torch.zeros(1)
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner')
- add = x + y; y = None
+ add: "f32[1]" = x + y; y = None
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
- add_1 = add + x; add = x = None
+ add_1: "f32[1]" = add + x; add = x = None
_saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message')
return (add_1,)
"""
self.assertExpectedInline(actual, expected)
""",
)

def test_disable_saved_tensors_hooks_graph_break(self):
def fn(x):
@@ -1171,37 +1177,41 @@ def fn(x):
def check_graph(actual, expected):
self.assertExpectedInline(actual, expected)

expected = """\
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ : torch.Tensor):
+ def forward(self, L_x_: "f32[]"):
l_x_ = L_x_
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
- y = l_x_ + 1; l_x_ = None
+ y: "f32[]" = l_x_ + 1; l_x_ = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (y,)
"""
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
check_graph(actual, expected)
""",
)

expected = """\
graph = eager.graphs[1]
actual = normalize_gm(graph.print_readable(False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_y_ : torch.Tensor):
+ def forward(self, L_y_: "f32[]"):
l_y_ = L_y_
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
- mul = l_y_ * 2; l_y_ = None
+ mul: "f32[]" = l_y_ * 2; l_y_ = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (mul,)
"""
graph = eager.graphs[1]
actual = normalize_gm(graph.print_readable(False))
check_graph(actual, expected)
""",
)

def test_context_wrapping_grad_mode_decorator(self):
ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]
