# Beyond PyTorch Native APIs

The XLA backend of PyTorch lets users write functions whose lowering is fully controlled by their own Python code, down to the exact XLA HLO operations that get generated.

The *xla_builder* module provides a thin wrapper around the *xla::XlaOp* objects described in the XLA reference documentation:

https://www.tensorflow.org/xla/operation_semantics
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h

While this makes it possible to create APIs whose semantics are not currently available in PyTorch, such APIs only work when used with an XLA device.
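The dynamic slice example in this notebook is built on XLA's DynamicSlice and DynamicUpdateSlice operations. According to the operation semantics page linked above, both clamp every start index into `[0, dim - size]` so the window always fits inside the operand. The NumPy sketch below models that behavior on the CPU. `dynamic_slice_ref` and `dynamic_update_slice_ref` are hypothetical helper names, not part of torch_xla or XLA, and NumPy is assumed to be available.

```python
import numpy as np


def _clamped_window(starts, sizes, shape):
  # XLA clamps each start index to [0, dim - size] so the window fits.
  window = []
  for start, size, dim in zip(starts, sizes, shape):
    assert 0 <= size <= dim, "slice size must fit in its dimension"
    start = min(max(int(start), 0), dim - size)
    window.append(slice(start, start + size))
  return tuple(window)


def dynamic_slice_ref(operand, start_indices, slice_sizes):
  """Hypothetical NumPy reference for XLA DynamicSlice (not a torch_xla API)."""
  assert len(start_indices) == operand.ndim == len(slice_sizes)
  return operand[_clamped_window(start_indices, slice_sizes, operand.shape)]


def dynamic_update_slice_ref(operand, update, start_indices):
  """Hypothetical NumPy reference for XLA DynamicUpdateSlice."""
  assert len(start_indices) == operand.ndim == update.ndim
  result = operand.copy()
  result[_clamped_window(start_indices, update.shape, operand.shape)] = update
  return result


x = np.arange(48).reshape(6, 8)
print(dynamic_slice_ref(x, [2, 4], (2, 3)))  # rows 2:4, columns 4:7
print(dynamic_slice_ref(x, [5, 7], (2, 3)))  # start clamped to [4, 5]

# Backward-style scatter into zeros: the values land where the forward pass
# read, because both operations clamp the start indices the same way.
grad = dynamic_update_slice_ref(np.zeros((6, 8)), np.ones((2, 3)), [5, 7])
print(grad[4:6, 5:8].sum(), grad.sum())
```

The clamping is why a backward pass of the form `input.zeros_like().dynamic_update_slice(grad_output, ...)` stays consistent with the forward slice even for out-of-range start indices.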


In [None]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl

### If you're using a GPU with this Colab notebook, uncomment and run the cell below to install the GPU-compatible PyTorch/XLA wheel and its dependencies

In [None]:
#!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp39-cp39-linux_x86_64.whl --force-reinstall 

### Only uncomment and run the cell below if you want a nightly release

In [None]:
# VERSION = "1.13"  #@param ["1.13", "nightly", "20220315"]  # or YYYYMMDD format
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION
# import os 
# os.environ['LD_LIBRARY_PATH']='/usr/local/lib'
# !echo $LD_LIBRARY_PATH

# !sudo ln -s /usr/local/lib/libmkl_intel_lp64.so /usr/local/lib/libmkl_intel_lp64.so.1
# !sudo ln -s /usr/local/lib/libmkl_intel_thread.so /usr/local/lib/libmkl_intel_thread.so.1
# !sudo ln -s /usr/local/lib/libmkl_core.so /usr/local/lib/libmkl_core.so.1

# !ldconfig
# !ldd /usr/local/lib/python3.7/dist-packages/torch/lib/libtorch.so

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


# Splits a rank 1 tensor into the scalar indices required by the XLA dynamic
# slicing APIs.
def _split_indices(index):
  ishape = index.shape()
  assert ishape.rank == 1
  indices = []
  for dim in range(0, ishape.sizes[0]):
    indices.append(index.slice_in_dim(dim, dim + 1, 0).reshape([]))
  return indices


# This is the XLA lowering API. Here input and start_indices are Op objects of
# the xla_builder module and can be manipulated with its API:
#   https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_builder.py
def _dynamic_slice_forward(input, start_indices, slice_sizes=None):
  return input.dynamic_slice(_split_indices(start_indices), slice_sizes)


# This is the XLA lowering API. Here grad_output, input and start_indices are Op
# objects of the xla_builder module and can be manipulated with its API:
#   https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_builder.py
def _dynamic_slice_backward(grad_output, input, start_indices, slice_sizes=None):
  return input.zeros_like().dynamic_update_slice(grad_output, _split_indices(start_indices))


# For efficiency, it is better to register the XLA builder based operations at
# global scope.
DYNAMIC_SLICE_FORWARD = xor.register('DynamicSliceForward', _dynamic_slice_forward)
DYNAMIC_SLICE_BACKWARD = xor.register('DynamicSliceBackward', _dynamic_slice_backward)


# Standard PyTorch way to create a differentiable function.
class DynamicSlice(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input, start_indices, slice_sizes):
    ctx.slice_sizes = slice_sizes
    output = DYNAMIC_SLICE_FORWARD(input, start_indices, slice_sizes=slice_sizes)
    ctx.save_for_backward(input, start_indices)
    return output

  @staticmethod
  def backward(ctx, grad_output):
    input, start_indices = ctx.saved_tensors
    grad = DYNAMIC_SLICE_BACKWARD(grad_output, input, start_indices,
                                  slice_sizes=ctx.slice_sizes)
    # We need to return as many gradients as the forward() inputs, or None if
    # such inputs are not differentiable.
    return grad, None, None


# Exposes the dynamic slice operation, which supports autograd differentiation.
def dynamic_slice(input, start_indices, slice_sizes):
  """Slices an input tensor.

  Args:
    input (torch.Tensor): The input tensor to be sliced.
    start_indices (torch.Tensor): The rank 1 tensor containing the start indices.
      The size of the tensor (its dimension 0) must be the same as the rank of
      the input tensor.
    slice_sizes (list of int): The size of the slice along each dimension. This
      is a list of Python integers whose length must be the same as the rank of
      the input tensor.
  Returns:
    The sliced input tensor.
  """
  return DynamicSlice.apply(input, start_indices, slice_sizes)


# Test implementation.
device = xm.xla_device()

x = torch.randn(6, 8, device=device, requires_grad=True)
index = torch.tensor([2, 4], dtype=torch.int32, device=device)
out = dynamic_slice(x, index, (2, 3))
loss = out.pow(2).sum()
loss.backward()
print(x.grad.cpu())
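Because `loss` is the sum of squares of the slice, the printed `x.grad` should equal `2 * x` inside the window `x[2:4, 4:7]` and be zero everywhere else. The NumPy sketch below checks that expectation against central finite differences. It assumes NumPy is available and uses a seeded random stand-in for `x`, since the notebook's values come from `torch.randn` on the XLA device.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 8))
start, sizes = (2, 4), (2, 3)
window = tuple(slice(s, s + n) for s, n in zip(start, sizes))


def loss(a):
  # Same loss as the notebook: sum of squares of the sliced window.
  return (a[window] ** 2).sum()


# Analytic gradient: 2 * x inside the window, zero elsewhere.
expected = np.zeros_like(x)
expected[window] = 2 * x[window]

# Central finite differences, perturbing one element at a time.
eps = 1e-6
numeric = np.zeros_like(x)
for i in np.ndindex(x.shape):
  step = np.zeros_like(x)
  step[i] = eps
  numeric[i] = (loss(x + step) - loss(x - step)) / (2 * eps)

print(np.allclose(numeric, expected, atol=1e-6))
```

Only the 2 x 3 = 6 elements inside the window receive a nonzero gradient, matching the zeros-plus-update structure of the backward lowering.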
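The `dynamic_slice` docstring requires `start_indices` to be a rank 1 tensor with one entry per input dimension, and `slice_sizes` to have one entry per input dimension. A hypothetical pre-check, `check_dynamic_slice_args` (not part of torch_xla), could enforce these constraints on plain shapes before calling `dynamic_slice`, so that a mismatch fails early with a readable message. It also assumes, following the XLA operation semantics, that each slice size must fit within its dimension. A plain-Python sketch:

```python
def check_dynamic_slice_args(input_shape, start_indices_shape, slice_sizes):
  """Hypothetical argument check mirroring the dynamic_slice docstring."""
  rank = len(input_shape)
  if len(start_indices_shape) != 1:
    raise ValueError(
        f"start_indices must be rank 1, got shape {tuple(start_indices_shape)}")
  if start_indices_shape[0] != rank:
    raise ValueError(
        f"start_indices has {start_indices_shape[0]} entries, input rank is {rank}")
  if len(slice_sizes) != rank:
    raise ValueError(
        f"slice_sizes has {len(slice_sizes)} entries, input rank is {rank}")
  for dim, (size, extent) in enumerate(zip(slice_sizes, input_shape)):
    # Assumption: each slice must fit inside its dimension.
    if not 0 <= size <= extent:
      raise ValueError(
          f"slice_sizes[{dim}]={size} does not fit a dimension of size {extent}")


check_dynamic_slice_args((6, 8), (2,), (2, 3))  # passes silently
```

With the notebook's tensors this would be called as `check_dynamic_slice_args(x.shape, index.shape, (2, 3))`, since `torch.Size` is a tuple subclass.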