
VE device for PyTorch #59296

@mergian

Description


🚀 Feature

We would like to add a new 'VE' device type to PyTorch to enable the NEC SX-Aurora TSUBASA (Vector Engine) processor.

Our working prototype integrates all necessary functionality into PyTorch using the otherwise unused HIP device (see https://arxiv.org/abs/2003.10688). The current implementation works with vanilla PyTorch and does not require any code changes!

Motivation

As we already have a working implementation using the unused HIP device, we would like to enable users to use the more suitable 'VE' name. Further, if we are not mistaken, upcoming PyTorch releases might use the HIP device for AMD ROCm support, which would conflict with our current implementation.

We have been integrating the VE into PyTorch starting with v1.4 and have continuously kept up with PyTorch's changes up to the most recent v1.8.1.

Pitch

In principle we only need a new device type, which requires adding VE to several enums and adding some function calls:

  1. Adding new c10::DeviceType::VE
  2. Adding new DLDeviceType::kDLVE
  3. Adding new Backend::VE and Backend::SparseVE
  4. Adding new DispatchKey::VE
  5. aten/src/ATen/native/Copy.cpp needs to be made aware of the kVE device type.

See this diff for all changes that we think are necessary: ve_pytorch.txt. A sketch of the kind of change is shown below.
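
To make step 1 concrete, here is a minimal sketch of the kind of enum addition we mean, modeled on PyTorch v1.8's c10/core/DeviceType.h; the numeric value of VE is an assumption, and the authoritative change set is the attached ve_pytorch.txt.

// Sketch of c10::DeviceType::VE; the numeric value is assumed.
enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1,
  // ... existing entries elided ...
  HIP = 6,
  // ...
  VE = 13,                             // new: NEC SX-Aurora TSUBASA
  COMPILE_TIME_MAX_DEVICE_TYPES = 14,  // bumped by one
};

constexpr DeviceType kVE = DeviceType::VE;  // mirrors kCPU, kCUDA, kHIP

// DeviceTypeName() and isValidDeviceType() get a matching case so that
// Copy.cpp (step 5) and error messages can report "VE" / "ve".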

Alternatives

We would need to keep using the HIP device and might encounter a conflict with ROCm in the future :(.

Additional context

Let me explain the structure of our current implementation, which consists of three components:

Neural Network Computations

For the main computational workloads, we utilize the SOL neural network optimization middleware (www.sol-ai.org). In a nutshell, SOL is comparable to TVM but is specifically designed for tight integration into AI frameworks and for supporting inference and training on different kinds of hardware. For this, SOL parses an existing PyTorch model and compiles an optimized implementation of it. This optimized implementation then gets integrated into PyTorch as a torch.nn.Module. SOL only replaces the computational part of the model and is therefore 100% compatible with the PyTorch workflow.

Memory Allocator and Basic Function Callbacks

To enable memory allocations, we load a shared library that registers a memory allocator implementation using REGISTER_ALLOCATOR(...). It uses the VEDA API (https://github.com/SX-Aurora/veda) and behaves similarly to the CUDA implementation, except that it does not preallocate any memory segments.
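
For illustration, here is a minimal sketch of such an allocator registration; the VEDA signatures (vedaMemAlloc/vedaMemFree) are assumptions modeled on the CUDA driver API, and error handling and multi-device support are omitted.

// Minimal sketch of the allocator registration; the VEDA calls are assumed
// to have CUDA-driver-style signatures. No memory is preallocated.
#include <c10/core/Allocator.h>
#include <veda.h>

namespace {

void deleteVE(void* ptr) {
  vedaMemFree(reinterpret_cast<VEDAdeviceptr>(ptr));  // assumed signature
}

struct VEAllocator final : c10::Allocator {
  c10::DataPtr allocate(size_t nbytes) const override {
    VEDAdeviceptr ptr = nullptr;
    vedaMemAlloc(&ptr, nbytes);  // allocates on demand, no pooling
    // The prototype tags memory with the (otherwise unused) HIP device type.
    return {reinterpret_cast<void*>(ptr), reinterpret_cast<void*>(ptr),
            &deleteVE, c10::Device(c10::DeviceType::HIP, 0)};
  }
  c10::DeleterFnPtr raw_deleter() const override { return &deleteVE; }
};

VEAllocator g_ve_allocator;
REGISTER_ALLOCATOR(c10::DeviceType::HIP, &g_ve_allocator);

} // anonymous namespace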

Further, our shared library registers a limited set of functions (binary, logical, +, -, *, /, etc.) with PyTorch using TORCH_LIBRARY_IMPL(aten, HIP, m), as the main part of our computations is performed by the SOL compiler. These function callbacks only translate between the PyTorch and SOL runtime APIs.

This enables the use of all basic operations (+, -, *, /, print, copy, ...) on tensors.
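
For illustration, here is a minimal sketch of such a callback registration. The "add.Tensor" schema is real PyTorch, but sol_launch_add is a placeholder for the SOL runtime call, and a complete backend must also register factory ops (e.g. empty.memory_format) for the same key.

// Minimal sketch of a function callback registered for the HIP dispatch key.
#include <ATen/ATen.h>
#include <torch/library.h>

at::Tensor ve_add(const at::Tensor& self, const at::Tensor& other,
                  const at::Scalar& alpha) {
  at::Tensor out = at::empty_like(self);  // served by the VE allocator above
  // sol_launch_add(out, self, other, alpha);  // translate to the SOL runtime
  return out;
}

TORCH_LIBRARY_IMPL(aten, HIP, m) {
  m.impl("add.Tensor", ve_add);
}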

Python integration

As PyTorch does not provide function calls such as tensor.hip(), our Python extension adds the following functions to the torch module, providing the same functionality that users are used to from CUDA devices (a sketch of the extension's C++ side follows the list):

torch.Tensor.hip()
torch.nn.Module.hip()
torch.hip.synchronize()
torch.hip.is_available()
torch.hip.current_device()
torch.hip.set_device()
torch.hip.device_count()
torch.hip.memory_allocated()
class torch.hip.device
class torch.hip.device_of
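
For illustration, a minimal sketch of the C++ side of such an extension; the VEDA calls in the comments and the device-state handling are assumptions. The Python glue then attaches the compiled module to torch as torch.hip and wraps torch.Tensor.hip() / torch.nn.Module.hip() around the usual .to() machinery.

// Minimal sketch of the extension module exposing the calls listed above.
#include <torch/extension.h>

namespace {

int g_current_device = 0;  // assumed process-wide current VE

int device_count() {
  int count = 0;
  // e.g. vedaDeviceGetCount(&count);  // assumed VEDA call
  return count;
}
bool is_available() { return device_count() > 0; }
int current_device() { return g_current_device; }
void set_device(int device) { g_current_device = device; }
void synchronize() { /* e.g. vedaCtxSynchronize(); assumed VEDA call */ }

}  // anonymous namespace

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("device_count", &device_count);
  m.def("is_available", &is_available);
  m.def("current_device", &current_device);
  m.def("set_device", &set_device);
  m.def("synchronize", &synchronize);
}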

Usage

No modified PyTorch is necessary to use our implementation, so just pip3 install torch torchvision suffices for running the following code snippet.

import torch
from torchvision import models

model = models.resnet50()
input = torch.rand(1, 3, 224, 224)

# 1. loads SOL
# 2. loads the shared library that adds the memory allocator and function callbacks to PyTorch
# 3. adds the Python API to the torch module
import sol.pytorch

# Parses model and returns an optimized torch.nn.Module.
sol_model = sol.optimize(model, input)

# Copies model to VE
sol_model.hip()

# Copies data to VE
input_ve = input.hip()

# Executes the model in eval mode
with torch.no_grad():
	sol_model.eval()
	output_ve = sol_model(input_ve)
