6 changes: 6 additions & 0 deletions docs/source/checkpoint.rst
@@ -0,0 +1,6 @@
torch.utils.checkpoint
======================

.. currentmodule:: torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
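
A minimal usage sketch of the two functions documented above (illustrative only, not taken from this PR; the toy model and names are made up, and the "recompute all but the last chunk" behaviour is what the tests below assert):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint, checkpoint_sequential

    model = nn.Sequential(
        nn.Linear(100, 50), nn.ReLU(),
        nn.Linear(50, 20), nn.ReLU(),
    )
    x = torch.randn(1, 100, requires_grad=True)

    # checkpoint_sequential: split the model into 2 chunks; activations of all
    # but the last chunk are recomputed during backward instead of being stored.
    out = checkpoint_sequential(model, 2, x)
    out.sum().backward()  # use backward(); torch.autograd.grad() is unsupported

    # checkpoint: wrap an arbitrary callable in the same way.
    def run_segment(inp):
        return model(inp)

    y = torch.randn(1, 100, requires_grad=True)
    out = checkpoint(run_segment, y)
    out.sum().backward()
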
138 changes: 138 additions & 0 deletions test/test_utils.py
@@ -9,10 +9,12 @@
import unittest
import traceback
import torch
import torch.nn as nn
import torch.utils.data
import torch.cuda
import warnings
from torch.autograd import Variable
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torch.utils.trainer import Trainer
from torch.utils.trainer.plugins import *
from torch.utils.trainer.plugins.plugin import Plugin
@@ -112,6 +114,142 @@ def __len__(self):
return 10


class TestCheckpoint(TestCase):

    # Test that checkpointing is actually triggered by counting how many
    # times each module's forward pass runs.
def test_checkpoint_trigger(self):

class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.counter = 0

def forward(self, input_var):
self.counter += 1
return input_var

# checkpointed
modules = [Net() for _ in range(10)]
for m in modules:
self.assertEqual(m.counter, 0)
input_var = torch.randn(3, 4, requires_grad=True)
out = checkpoint_sequential(modules, 2, input_var)
for m in modules:
self.assertEqual(m.counter, 1)
out.sum().backward()
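        # The first half of the modules was checkpointed, so their forward
        # passes are re-run during backward (counter becomes 2); the final
        # segment is not checkpointed and keeps counter == 1.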
for m in modules[:(len(modules) // 2)]:
self.assertEqual(m.counter, 2)
for m in modules[(len(modules) // 2):]:
self.assertEqual(m.counter, 1)

def test_checkpoint_valid(self):
model = nn.Sequential(
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, 20),
nn.ReLU(),
nn.Linear(20, 5),
nn.ReLU()
)

input_var = torch.randn(1, 100, requires_grad=True)

# checkpointed
chunks = 2
modules = list(model.children())
out = checkpoint_sequential(modules, chunks, input_var)
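        # torch.autograd.grad() is not supported with checkpointing; only the
        # imperative torch.autograd.backward() call is.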
with self.assertRaisesRegex(RuntimeError, "Checkpointing is not compatible"):
torch.autograd.grad(
outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True
)

def test_checkpoint_sequential(self):
model = nn.Sequential(
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, 20),
nn.ReLU(),
nn.Linear(20, 5),
nn.ReLU()
)

x = torch.randn(1, 100, requires_grad=True)

# not checkpointed
out = model(x)
out_not_checkpointed = out.data.clone()
model.zero_grad()
out.sum().backward()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
grad_not_checkpointed[name] = param.grad.data.clone()
input_grad = x.grad.data.clone()

# checkpointed
input_var = x.detach()
input_var.requires_grad = True
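        # detach() gives a fresh leaf tensor, so the checkpointed run
        # accumulates its own input gradient for comparison.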
# pass the sequential itself
out = checkpoint_sequential(model, 2, input_var)
out_checkpointed = out.data.clone()
model.zero_grad()
out.sum().backward()
grad_checkpointed = {}
for name, param in model.named_parameters():
grad_checkpointed[name] = param.grad.data.clone()
checkpoint_input_grad = input_var.grad.data.clone()

# compare the output, input and parameters gradients
self.assertEqual(out_checkpointed, out_not_checkpointed)
self.assertEqual(input_grad, checkpoint_input_grad)
for name in grad_checkpointed:
self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])

def test_checkpoint_functions_list(self):
model = nn.Sequential(
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, 20),
nn.ReLU(),
nn.Linear(20, 5),
nn.ReLU()
)

x = torch.randn(1, 100, requires_grad=True)

# not checkpointed
out = model(x)
out_not_checkpointed = out.data.clone()
model.zero_grad()
out.sum().backward()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
grad_not_checkpointed[name] = param.grad.data.clone()
input_grad = x.grad.data.clone()

# checkpointed
chunks = 2
modules = list(model.children())

input_var = x.detach()
input_var.requires_grad = True
# pass list of modules to checkpoint
out = checkpoint_sequential(modules, chunks, input_var)
out_checkpointed = out.data.clone()
model.zero_grad()
out.sum().backward()
grad_checkpointed = {}
for name, param in model.named_parameters():
grad_checkpointed[name] = param.grad.data.clone()
checkpoint_input_grad = input_var.grad.data.clone()

# compare the output, input and parameters gradients
self.assertEqual(out_checkpointed, out_not_checkpointed)
self.assertEqual(input_grad, checkpoint_input_grad)
for name in grad_checkpointed:
self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])


class TestDataLoader(TestCase):
def setUp(self):
self.dataset = torch.randn(5, 3, 3, 2)
17 changes: 17 additions & 0 deletions torch/autograd/__init__.py
@@ -140,6 +140,23 @@ def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=Fal
inputs)


# This function is used by gradient checkpointing for memory optimization.
# Currently, gradient checkpointing only supports the imperative backwards
# call, i.e. torch.autograd.backward(); torch.autograd.grad() will not work.
# The reason is that torch.autograd.grad() only computes gradients for the
# inputs explicitly passed by the user, not for anything else such as model
# parameters (weights, biases, etc.), whereas torch.autograd.backward()
# computes gradients for the parameters as well.
#
# This function returns whether checkpointing is valid, i.e. whether the
# current reentrant backwards call came from torch.autograd.backward() rather
# than torch.autograd.grad(). It works via a thread-local variable in
# torch/csrc/autograd/engine.cpp: before a FunctionTask on the stack is
# executed in evaluate_function, the engine checks whether the reentrant
# backwards call is imperative or not.
def _is_checkpoint_valid():
return Variable._execution_engine.is_checkpoint_valid()


def variable(*args, **kwargs):
warnings.warn("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead")
return torch.tensor(*args, **kwargs)
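To make the comment above concrete, here is a hypothetical custom autograd Function that consults torch.autograd._is_checkpoint_valid() before recomputing its forward pass during backward. This is only a sketch of the general pattern; it is not the torch/utils/checkpoint.py implementation (which is not shown in this diff), and the class and helper names are made up:

    import torch

    class RecomputeInBackward(torch.autograd.Function):
        """Hypothetical sketch: run `fn` without a graph in forward, then
        rebuild the graph and backprop through it during backward."""

        @staticmethod
        def forward(ctx, fn, *args):
            ctx.fn = fn
            ctx.save_for_backward(*args)
            with torch.no_grad():
                # No graph is built here, so intermediate activations are freed.
                return fn(*args)

        @staticmethod
        def backward(ctx, *grad_outputs):
            # The nested backward below must also populate parameter .grad
            # fields, which only an imperative torch.autograd.backward() call
            # does, hence the validity check.
            if not torch.autograd._is_checkpoint_valid():
                raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                   "please use .backward() if possible")
            inputs = [t.detach().requires_grad_(t.requires_grad)
                      for t in ctx.saved_tensors]
            with torch.enable_grad():
                # Recompute the forward pass to rebuild the local graph.
                outputs = ctx.fn(*inputs)
            if not isinstance(outputs, tuple):
                outputs = (outputs,)
            torch.autograd.backward(outputs, grad_outputs)
            # The first None corresponds to the `fn` argument of forward().
            return (None,) + tuple(t.grad for t in inputs)
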
2 changes: 1 addition & 1 deletion torch/csrc/autograd/README.md
@@ -1,7 +1,7 @@
## Autograd

Autograd is a hotspot for PyTorch performance, so most of the heavy lifting is
implemented in C++. This implies that we have to do some shuffling between
implemented in C++. This implies that we have to do some shuffling between
Python and C++; and in general, we want data to be in a form that is convenient
to manipulate from C++.

19 changes: 18 additions & 1 deletion torch/csrc/autograd/engine.cpp
@@ -33,6 +33,11 @@ namespace torch { namespace autograd {
static constexpr int NO_DEVICE = -2;
static thread_local int worker_device = NO_DEVICE;

// This variable is true if ALL invocations in the stack of re-entrant engine
// invocations are imperative backwards. This special variable is needed for the
// gradient checkpointing feature only.
static thread_local bool checkpoint_valid = true;

// XXX: Changes to the way multithreading works in execute should be done with
// great care. Right now the implementation guarantees that a single function's
// apply will never be entered concurrently (even if multiple graphs are
@@ -103,13 +108,16 @@ struct GraphTask {
// run in a "default" mode, which means that all next_edges we encounter should
// get executed. If it's not empty, only functions that have an entry and this entry
// has needed == True should be executed.
// exec_info.empty() means it's .backward(), otherwise it's .grad().
std::unordered_map<Function*, ExecInfo> exec_info;
std::vector<Variable> captured_vars;

void init_to_execute(Function& graph_root, const edge_list& captures);

int owner;

bool can_checkpoint();

GraphTask(bool keep_graph, bool grad_mode)
: exception()
, has_error(false)
@@ -228,14 +236,16 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, variab
}

static variable_list call_function(FunctionTask& task) {
bool prev_checkpoint_valid_state = checkpoint_valid;
checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state;
auto& fn = *task.fn;
auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs)));

if(!task.base->keep_graph) {
fn.will_release_variables();
}
auto outputs = fn(inputs);

checkpoint_valid = prev_checkpoint_valid_state;

return call_post_hooks(fn, std::move(outputs), std::move(inputs));
}

@@ -423,6 +433,10 @@ void Engine::queue_callback(std::function<void()> callback) {
final_callbacks.emplace_back(std::move(callback));
}

bool Engine::is_checkpoint_valid() {
return checkpoint_valid;
}

auto Engine::ready_queue(int device) -> ReadyQueue& {
return *ready_queues.at(device + 1);
}
@@ -511,5 +525,8 @@ void GraphTask::init_to_execute(Function& graph_root, const edge_list& outputs)
}
}

bool GraphTask::can_checkpoint() {
return exec_info.empty();
}

}} // namespace torch::autograd
2 changes: 2 additions & 0 deletions torch/csrc/autograd/engine.h
@@ -41,6 +41,8 @@ struct Engine {

static Engine& getDefaultEngine();

bool is_checkpoint_valid();

protected:
void compute_dependencies(Function* root, GraphTask& task);
void evaluate_function(FunctionTask& task);
2 changes: 1 addition & 1 deletion torch/csrc/autograd/input_buffer.h
@@ -3,7 +3,7 @@
// The InputBuffer class accumulates a list of Variables for use by a
// function. It implements logic to avoid modifying the passed
// values in-place (adding an input twice will accumulate the result).
// This behaviour needed and used only in backward graphs.
// This behaviour is needed and used only in backward graphs.

#include <vector>
#include <utility>
14 changes: 11 additions & 3 deletions torch/csrc/autograd/python_engine.cpp
@@ -83,9 +83,6 @@ static void _maybe_reinitialize_engine_after_fork() {
}

// Implementation of torch._C._EngineBase.run_backward
//
// When inputs == nullptr, this is a torch.autograd.backward() call
// When inputs != nullptr, this is a torch.autograd.grad() call
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
@@ -198,6 +195,16 @@ PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_is_checkpoint_valid(PyObject *self) {
HANDLE_TH_ERRORS
if(engine.is_checkpoint_valid()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}

PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
return type->tp_alloc(type, 0);
@@ -206,6 +213,7 @@ PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
static struct PyMethodDef THPEngine_methods[] = {
{(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
{(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
{(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
{nullptr}
};
