Description
🚀 Feature
We would like to add a new 'VE' device type to PyTorch to enable the NEC SX-Aurora TSUBASA (Vector Engine) processor.
Our working prototype integrates all necessary functionality into PyTorch through the currently unused HIP device (see https://arxiv.org/abs/2003.10688). The current implementation works with vanilla PyTorch and does not require any code changes!
Motivation
As we already have a working implementation using the unused HIP device, we would like to enable users to use the more suitable 'VE' name. Further, if we are not mistaken, upcoming PyTorch releases might use the HIP device for AMD ROCm support, which would create problems with our current implementation.
We have been integrating the VE into PyTorch starting with v1.4 and have continuously kept up with PyTorch's changes up to the most recent v1.8.1.
Pitch
In principle we only need a new device type, which requires adding VE to several enums and adding some function calls:
- Adding a new `c10::DeviceType::VE`
- Adding a new `DLDeviceType::kDLVE`
- Adding new `Backend::VE` and `Backend::SparseVE`
- Adding a new `DispatchKey::VE`
- `aten/src/ATen/native/Copy.cpp` needs to be made aware of the `kVE` device type.
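The additions above all have the same shape: extend an enum and teach the corresponding name/lookup tables about the new entry. The real change is in C++, but the following pure-Python sketch (illustrative only, with made-up enum values) shows that shape:

```python
from enum import Enum

class DeviceType(Enum):
    # subset of existing entries, for illustration only
    CPU = 0
    CUDA = 1
    HIP = 6
    # hypothetical new entry for the NEC SX-Aurora Vector Engine
    VE = 7

def device_type_name(t: DeviceType) -> str:
    # mirrors the kind of name-lookup table that must learn about VE
    return {
        DeviceType.CPU: "cpu",
        DeviceType.CUDA: "cuda",
        DeviceType.HIP: "hip",
        DeviceType.VE: "ve",
    }[t]

print(device_type_name(DeviceType.VE))  # prints "ve"
```

With such an entry in place, a device string like `"ve"` could resolve to the new type the same way `"cuda"` and `"hip"` do today.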
See this diff for all changes that we think are necessary: ve_pytorch.txt
Alternatives
We would need to keep using the HIP device and might encounter a conflict with ROCm in the future :(.
Additional context
Let me explain the structure of our current implementation, which consists of three components:
Neural Network Computations
For the main computational workloads, we utilize the SOL neural network optimization middleware (www.sol-ai.org). In a nutshell, SOL is comparable to TVM but is specifically designed for tight integration into AI frameworks and for supporting inference and training on different kinds of hardware. For this, SOL parses an existing PyTorch model and compiles an optimized implementation of it. This optimized implementation then gets integrated into PyTorch as a torch.nn.Module. SOL only replaces the computational part of the model and is therefore 100% compatible with the PyTorch workflow.
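The wrapping idea can be pictured without SOL itself. The sketch below is a hypothetical stand-in, not SOL's actual code: only the compute path is swapped out, while the object keeps the calling convention of the original model.

```python
# Illustrative sketch: the optimized implementation is wrapped so that it
# can be called like the original model, replacing only the compute part.
class OptimizedModel:
    def __init__(self, compiled_fn):
        self._fn = compiled_fn  # stand-in for the SOL-compiled implementation

    def __call__(self, x):
        # bookkeeping (train/eval state, hooks, ...) would live here;
        # the heavy lifting is delegated to the compiled implementation
        return self._fn(x)

# "compile" a toy function and wrap it
model = OptimizedModel(lambda xs: [2 * v for v in xs])
print(model([1, 2, 3]))  # [2, 4, 6]
```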
Memory Allocator and Basic Function Callbacks
To enable memory allocations, we load a shared library that registers a memory allocator implementation using REGISTER_ALLOCATOR(...)
. It uses the VEDA API (https://github.com/SX-Aurora/veda) and behaves similarly to the CUDA implementation, except that it does not preallocate any memory segment.
Further, our shared library registers a limited set of functions (binary, logical, +, -, *, /, etc.) with PyTorch using TORCH_LIBRARY_IMPL(aten, HIP, m)
, as the main part of our computations is performed by the SOL compiler. These function callbacks only translate between the PyTorch and SOL runtime APIs.
This enables the use of all basic operations (+, -, *, /, print, copy, ...) on the tensors.
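The effect of these callback registrations can be pictured as a dispatch table keyed by operator and device. The sketch below is a plain-Python analogy with hypothetical names, not the actual PyTorch dispatcher:

```python
# Illustrative analogy for TORCH_LIBRARY_IMPL(aten, HIP, m): kernels are
# registered under an (operator, device) key and looked up at call time.
KERNELS = {}

def register(op, device):
    def deco(fn):
        KERNELS[(op, device)] = fn
        return fn
    return deco

@register("add", "hip")  # our VE kernels currently hide behind the HIP key
def add_ve(a, b):
    # a real kernel would translate to the SOL/VEDA runtime here
    return [x + y for x, y in zip(a, b)]

def dispatch(op, device, *args):
    # the dispatcher routes the call to the kernel registered for the device
    return KERNELS[(op, device)](*args)

print(dispatch("add", "hip", [1, 2], [3, 4]))  # [4, 6]
```

A dedicated DispatchKey::VE would let these kernels register under their own key instead of borrowing HIP's.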
Python integration
As PyTorch does not provide function calls such as tensor.hip()
, our Python extension adds the following function calls to the torch
module, to enable the same functionality users are accustomed to with CUDA devices.
torch.Tensor.hip()
torch.nn.Module.hip()
torch.hip.synchronize()
torch.hip.is_available()
torch.hip.current_device()
torch.hip.set_device()
torch.hip.device_count()
torch.hip.memory_allocated()
CLASS torch.hip.device()
CLASS torch.hip.device_of()
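How such helpers can be attached to a module namespace is sketched below with a stand-in object. These are hypothetical stubs, not the real sol.pytorch code; the real versions query the VEDA runtime:

```python
# Illustrative monkey-patching sketch: an extension attaches CUDA-style
# helper functions to an existing namespace (here a stand-in, not torch).
import types

torch_stub = types.SimpleNamespace()   # stand-in for the torch module

hip = types.SimpleNamespace()
hip.is_available = lambda: True        # real version would query VEDA
hip.device_count = lambda: 1           # real version would count VEs
torch_stub.hip = hip

# callers then use the familiar CUDA-like idiom
if torch_stub.hip.is_available():
    print(torch_stub.hip.device_count())  # prints 1
```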
Usage
To use our implementation, no modified PyTorch is necessary, so just using pip3 install torch torchvision
suffices for running the following code snippet.
import torch
from torchvision import models
model = models.resnet50()
input = torch.rand(1, 3, 224, 224)
# 1. loads SOL
# 2. loads the shared library that adds the memory allocator and function callbacks to PyTorch
# 3. adds the Python API to the torch module
import sol.pytorch
# Parses `model` and returns an optimized `torch.nn.Module`.
sol_model = sol.optimize(model, input)
# Copies model to VE
sol_model.hip()
# Copies data to VE
input_ve = input.hip()
# Executes the model in eval mode
with torch.no_grad():
    sol_model.eval()
    output_ve = sol_model(input_ve)