Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 38 additions & 30 deletions torch_xla/csrc/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,9 @@ void XlaModule::Initialize(const TensorBatchVector& inputs) {

// Get forward graph.
const auto forward = script_module_->find_method("forward");
JIT_ASSERT(forward);
JIT_ASSERT(forward != nullptr);
std::shared_ptr<Graph> forward_graph = forward->graph()->copy();
// Run forward passes.
CanonicalizeOps(forward_graph);
InsertExplicitExpand(forward_graph);
EvalStaticSize(forward_graph);
ConstantPropagation(forward_graph);
ReplaceUntracedOperators(forward_graph);
RemoveInPlaceOutParamOps(forward_graph);
ReplaceInPlaceOps(forward_graph);
EliminateDeadCode(forward_graph);
LowerAllTuples(forward_graph);
RunForwardPasses(&forward_graph);

// Convert model parameters to vector of XLATensors.
std::vector<at::Tensor*> params_buffers_regather;
Expand Down Expand Up @@ -141,31 +132,48 @@ void XlaModule::Initialize(const TensorBatchVector& inputs) {
param_requires_grad.begin(),
param_requires_grad.end());

gradient_ = ComputeGradient(forward_graph);

TF_VLOG(4) << "Gradient F:\n" << gradient_.f->toString();
TF_VLOG(4) << "Gradient DF:\n" << gradient_.df->toString();
// Release the reference to the script module to mark initialization as done.
script_module_ = nullptr;
}

void XlaModule::RunForwardPasses(std::shared_ptr<Graph>* graph) {
// Run forward passes.
CanonicalizeOps(*graph);
InsertExplicitExpand(*graph);
EvalStaticSize(*graph);
ConstantPropagation(*graph);
ReplaceUntracedOperators(*graph);
RemoveInPlaceOutParamOps(*graph);
ReplaceInPlaceOps(*graph);
EliminateDeadCode(*graph);
LowerAllTuples(*graph);
}

// Differentiates a copy of `graph`, runs the forward passes over the resulting
// forward graph (`f`) and the backward passes over the resulting backward
// graph (`df`), and returns the gradient descriptor by value. Only the local
// `gradient` is touched: this function is declared static, so it must not read
// or write instance members.
Gradient XlaModule::ComputeGradient(const std::shared_ptr<Graph>& graph) {
  // Automatically differentiate the forward graph to get the backward graph.
  // Since differentiation is mutating the graph, do it on a copy.
  std::shared_ptr<Graph> graph_copy = graph->copy();
  Gradient gradient = differentiate(graph_copy);
  // Run the forward passes.
  CanonicalizeOps(gradient.f);
  InsertExplicitExpand(gradient.f);
  ConstantPropagation(gradient.f);
  ReplaceUntracedOperators(gradient.f);
  EliminateDeadCode(gradient.f);
  // Run the backward passes.
  specializeUndef(*gradient.df);
  ConstantPropagation(gradient.df);
  ThresholdBackwardPeephole(gradient.df);
  EliminateDeadCode(gradient.df);
  LowerAllTuples(gradient.df);
  // Run pass on forward and backward graphs that drops outputs that XLA doesn't
  // need.
  RemoveUnusedForwardOutputs(&gradient);
  return gradient;
}

void XlaModule::CheckInitialized() const {
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ struct XlaModule : public std::enable_shared_from_this<XlaModule> {
// computation have their accumulated operations sync to device memory.
void FlushTensorsOperations();

// Runs the forward graph passes over the graph that `graph` points to. The
// passes are invoked directly on the pointed-to shared_ptr, in a fixed order
// (see the definition in module.cpp for the exact sequence).
static void RunForwardPasses(std::shared_ptr<Graph>* graph);

// Computes the gradient structure of the given graph, and runs all the
// appropriate passes over the resulting forward and backward graphs.
static Gradient ComputeGradient(const std::shared_ptr<Graph>& graph);

// Propagate ret_size_op_values storing the aten::size values collected during
// forward pass translation to the backward pass. Uses the capture information
// from the gradient descriptor.
Expand Down
10 changes: 5 additions & 5 deletions torch_xla_py/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def extract_gradients(inputs, fill_fn=None):


# Run an XLA model with the given tensors.
def xla_run_model(xla_model, inputs, devices=None):
  """Runs an XLA model with the given tensors.

  `inputs` is converted with `convert_to_xla_tensors` (forwarding `devices`),
  and the converted values are unpacked as positional arguments to
  `xla_model`. Whatever `xla_model` returns is returned unchanged.
  """
  return xla_model(*convert_to_xla_tensors(inputs, devices=devices))


# Backward-compatible alias: the function was previously named run_xla_model.
run_xla_model = xla_run_model


Expand Down Expand Up @@ -518,12 +518,12 @@ def __call__(self, *args):
if self._loss_fn is None and self._num_cores == 1:
# Internally the interface to the XLA module is always in replicated
# mode, where num_replicas=1 is just a case of replication.
outputs = run_xla_model(self._xla_model, [args], devices=self._devices)
outputs = xla_run_model(self._xla_model, [args], devices=self._devices)
# If in legacy-API mode, convert the XLA tensor directly to PyTorch
# tensor.
return convert_to_tensors(outputs[0])
assert len(args) == self._num_cores
return run_xla_model(self._xla_model, args, devices=self._devices)
return xla_run_model(self._xla_model, args, devices=self._devices)

def backward(self, outputs):
xla_run_grad(
Expand Down Expand Up @@ -577,7 +577,7 @@ def train(self,
for batch_number, (inputs, targets) in wloader:
self._step += 1
optimizer.zero_grad()
xla_outputs = run_xla_model(
xla_outputs = xla_run_model(
self._xla_model, inputs, devices=self._devices)
xla_run_grad(
self._xla_model,
Expand Down Expand Up @@ -610,7 +610,7 @@ def test(self, samples_loader, eval_fn, batch_size, log_fn=print):
start_time = time.time()
with torch.no_grad():
for batch_number, (inputs, targets) in wloader:
xla_outputs = run_xla_model(
xla_outputs = xla_run_model(
self._xla_model, inputs, devices=self._devices)
for i, replica_xla_outputs in enumerate(xla_outputs):
# The original model output is ordinal 1 of the returned
Expand Down