From 7eb8ad556cfa4866d6d09dad64511179691b10b8 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Wed, 19 Dec 2018 15:52:39 -0800 Subject: [PATCH] Some code restructuring and naming cleanup. --- torch_xla/csrc/module.cpp | 68 ++++++++++++++++++++++----------------- torch_xla/csrc/module.h | 6 ++++ torch_xla_py/xla_model.py | 10 +++--- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/torch_xla/csrc/module.cpp b/torch_xla/csrc/module.cpp index b55570a5c3f8..cbf92dbb08ef 100644 --- a/torch_xla/csrc/module.cpp +++ b/torch_xla/csrc/module.cpp @@ -82,18 +82,9 @@ void XlaModule::Initialize(const TensorBatchVector& inputs) { // Get forward graph. const auto forward = script_module_->find_method("forward"); - JIT_ASSERT(forward); + JIT_ASSERT(forward != nullptr); std::shared_ptr<Graph> forward_graph = forward->graph()->copy(); - // Run forward passes. - CanonicalizeOps(forward_graph); - InsertExplicitExpand(forward_graph); - EvalStaticSize(forward_graph); - ConstantPropagation(forward_graph); - ReplaceUntracedOperators(forward_graph); - RemoveInPlaceOutParamOps(forward_graph); - ReplaceInPlaceOps(forward_graph); - EliminateDeadCode(forward_graph); - LowerAllTuples(forward_graph); + RunForwardPasses(&forward_graph); // Convert model parameters to vector of XLATensors. std::vector params_buffers_regather; @@ -141,31 +132,48 @@ void XlaModule::Initialize(const TensorBatchVector& inputs) { param_requires_grad.begin(), param_requires_grad.end()); + gradient_ = ComputeGradient(forward_graph); + + TF_VLOG(4) << "Gradient F:\n" << gradient_.f->toString(); + TF_VLOG(4) << "Gradient DF:\n" << gradient_.df->toString(); + // Release the reference to the script module to mark initialization as done. + script_module_ = nullptr; +} + +void XlaModule::RunForwardPasses(std::shared_ptr<Graph>* graph) { + // Run forward passes.
+ CanonicalizeOps(*graph); + InsertExplicitExpand(*graph); + EvalStaticSize(*graph); + ConstantPropagation(*graph); + ReplaceUntracedOperators(*graph); + RemoveInPlaceOutParamOps(*graph); + ReplaceInPlaceOps(*graph); + EliminateDeadCode(*graph); + LowerAllTuples(*graph); +} + +Gradient XlaModule::ComputeGradient(const std::shared_ptr<Graph>& graph) { // Automatically differentiate the forward graph to get the backward graph. // Since differentiation is mutating the graph, do it on a copy. - auto forward_graph_copy = forward_graph->copy(); - gradient_ = differentiate(forward_graph_copy); - + std::shared_ptr<Graph> graph_copy = graph->copy(); + Gradient gradient = differentiate(graph_copy); // Run the forward passes. - CanonicalizeOps(gradient_.f); - InsertExplicitExpand(gradient_.f); - ConstantPropagation(gradient_.f); - ReplaceUntracedOperators(gradient_.f); - EliminateDeadCode(gradient_.f); + CanonicalizeOps(gradient.f); + InsertExplicitExpand(gradient.f); + ConstantPropagation(gradient.f); + ReplaceUntracedOperators(gradient.f); + EliminateDeadCode(gradient.f); // Run the backward passes. - specializeUndef(*(gradient_.df.get())); - ConstantPropagation(gradient_.df); - ThresholdBackwardPeephole(gradient_.df); - EliminateDeadCode(gradient_.df); - LowerAllTuples(gradient_.df); + specializeUndef(*(gradient.df.get())); + ConstantPropagation(gradient.df); + ThresholdBackwardPeephole(gradient.df); + EliminateDeadCode(gradient.df); + LowerAllTuples(gradient.df); // Run pass on forward and backward graphs that drops outputs that XLA doesn't // need. - RemoveUnusedForwardOutputs(&gradient_); - - TF_VLOG(4) << "Gradient F:\n" << gradient_.f->toString(); - TF_VLOG(4) << "Gradient DF:\n" << gradient_.df->toString(); - // Release the reference to the script module to mark initialization as done.
- script_module_ = nullptr; + RemoveUnusedForwardOutputs(&gradient); + return gradient; } void XlaModule::CheckInitialized() const { diff --git a/torch_xla/csrc/module.h b/torch_xla/csrc/module.h index 9d90e61b2ebb..99d134866fe3 100644 --- a/torch_xla/csrc/module.h +++ b/torch_xla/csrc/module.h @@ -98,6 +98,12 @@ struct XlaModule : public std::enable_shared_from_this<XlaModule> { // computation have their accumulated operations sync to device memory. void FlushTensorsOperations(); + static void RunForwardPasses(std::shared_ptr<Graph>* graph); + + // Computes the gradient structure of the given graph, and runs all the + // appropriate passes over the resulting forward and backward graphs. + static Gradient ComputeGradient(const std::shared_ptr<Graph>& graph); + // Propagate ret_size_op_values storing the aten::size values collected during // forward pass translation to the backward pass. Uses the capture information // from the gradient descriptor. diff --git a/torch_xla_py/xla_model.py b/torch_xla_py/xla_model.py index be240b7a2dc4..f25e515a2654 100644 --- a/torch_xla_py/xla_model.py +++ b/torch_xla_py/xla_model.py @@ -294,7 +294,7 @@ def extract_gradients(inputs, fill_fn=None): # Run an XLA model with the given tensors. -def run_xla_model(xla_model, inputs, devices=None): +def xla_run_model(xla_model, inputs, devices=None): return xla_model(*convert_to_xla_tensors(inputs, devices=devices)) @@ -518,12 +518,12 @@ def __call__(self, *args): if self._loss_fn is None and self._num_cores == 1: # Internally the interface to the XLA module is always in replicated # mode, where num_replicas=1 is just a case of replication. - outputs = run_xla_model(self._xla_model, [args], devices=self._devices) + outputs = xla_run_model(self._xla_model, [args], devices=self._devices) # If in legacy-API mode, convert the XLA tensor directly to PyTorch # tensor.
return convert_to_tensors(outputs[0]) assert len(args) == self._num_cores - return run_xla_model(self._xla_model, args, devices=self._devices) + return xla_run_model(self._xla_model, args, devices=self._devices) def backward(self, outputs): xla_run_grad( @@ -577,7 +577,7 @@ def train(self, for batch_number, (inputs, targets) in wloader: self._step += 1 optimizer.zero_grad() - xla_outputs = run_xla_model( + xla_outputs = xla_run_model( self._xla_model, inputs, devices=self._devices) xla_run_grad( self._xla_model, @@ -610,7 +610,7 @@ def test(self, samples_loader, eval_fn, batch_size, log_fn=print): start_time = time.time() with torch.no_grad(): for batch_number, (inputs, targets) in wloader: - xla_outputs = run_xla_model( + xla_outputs = xla_run_model( self._xla_model, inputs, devices=self._devices) for i, replica_xla_outputs in enumerate(xla_outputs): # The original model output is ordinal 1 of the returned