simple and small optimization: default device argument for custom device #103828

Closed
heidongxianhua opened this issue Jun 19, 2023 · 12 comments
Labels
module: python frontend, needs design, triaged


heidongxianhua (Contributor) commented Jun 19, 2023

🚀 The feature, motivation and pitch

1. For many operators (such as pin_memory), the device argument defaults to cuda if it is not given; for any other device we have to pass an extra device_type argument that cuda users never need. So we propose an API to set the default argument device once at the beginning, keeping usage consistent with cuda.
2. Some APIs defined in Python take a device_type argument whose default value is cuda (for example https://github.com/pytorch/pytorch/blob/main/torch/random.py#L104); with a settable default we could support more devices (such as a privateuse1 device).
In short, we want to add an API to set the default argument device once at the beginning, plus an API to get that default when device_type is not given, so that other devices can be used exactly like cuda.
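For illustration, a minimal sketch of the intended usage; the function names follow this proposal, and the exact names and namespace are assumptions, not an existing PyTorch API:

import torch

# Hypothetical API proposed here (names and namespace are assumptions):
torch.set_default_argument_device_type("npu")  # run once at startup
torch.get_default_argument_device_type()       # -> "npu"

# Operators whose device argument is hard-coded to cuda today would
# consult the getter instead, so plain calls keep working:
t = torch.empty(2)
t.is_pinned()  # would check "npu" pinned memory rather than cuda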

Alternatives

No response

Additional context

No response

cc @albanD

albanD (Collaborator) commented Jun 26, 2023

cc @ezyang who was discussing this on the proposed PRs

malfet added the triaged label and removed the triage review label on Jun 26, 2023
malfet (Contributor) commented Jun 26, 2023

There is already torch.set_default_device, but some APIs seem to be device specific (for example is_pinned).

heidongxianhua (Author) commented Jun 27, 2023

Yes, the torch.set_default_device function is rather special: it makes every operator's device default to the value that was set.

torch.set_default_device('cuda')
a = torch.rand(2)
a.device  # device(type='cuda', index=0)

But we want to add an API for the operators whose device argument is hard-coded to cuda; that hard-coding makes them difficult to extend to other devices and forces a different usage. Take is_pinned or pin_memory: for cuda you simply call is_pinned(), but for another device you must call is_pinned("foo").
Another case is torch.device: if we call torch.device(2), we get "cuda:2".
So we want an API to set the device argument for the operators that hard-code cuda as the default. It would only need to be run once to switch to another device, and those operators could then keep their current usage on other devices. @malfet @albanD @ezyang could you consider our idea again?
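For concreteness, the current hard-coded behavior looks like this ("foo" stands for a renamed privateuse1 backend, as above):

import torch

torch.device(2)      # device(type='cuda', index=2): a bare index implies cuda

t = torch.empty(3)
t.is_pinned()        # checks cuda pinned memory by default
t.is_pinned("foo")   # a custom backend must be named explicitly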

ezyang (Contributor) commented Jul 4, 2023

As I said, I have proposed a context manager similar to the torch.device context manager, but one which ONLY applies to things like the is_pinned device argument. You can see the existing implementation of the device context manager at torch/utils/_device.py; it is easy to adapt.

heidongxianhua (Author) commented

> As I said, I have proposed a context manager similar to the torch.device context manager, but one which ONLY applies to things like the is_pinned device argument. You can see the existing implementation of the device context manager at torch/utils/_device.py; it is easy to adapt.

@ezyang yeah, thank you so much. We have run some tests with DeviceContext, and it solves the problem for operators.
I would also like to discuss this issue in more depth. The current situation is that in the PyTorch framework, on both the C++ and Python sides, many operators and interfaces take a device parameter, but when the device is not specified, or only a device index is given, a hard-coded cuda default is used. This is very unfriendly when extending to other device types, whether mps/xla or a custom device, and it is the problem we want to solve.
We propose the following ideas to address it.

  1. As in the PR add default argument device type api #103575, we add an API to set the default device type, and the hard-coded cuda type can then be replaced by calling a function named get_default_argument_device_type, available in both C++ and Python. This idea solves the whole problem and is very friendly to other device types, whether mps/xla or a custom device.

  2. In another PR, Deprecated the device usage without device_type #104457, the core idea is to discard the inappropriate usage: we want to deprecate device usage like a = torch.device(0) in favor of a = torch.device("cuda:0"). This idea covers all operators with a device argument and the core API torch.device, but the APIs defined in Python cannot be solved this way.

  3. The idea you suggested, using DeviceContext or TorchFunctionMode: we have run many tests, and it also covers all operators with a device argument, but the core API torch.device and the APIs defined in Python cannot be solved this way.

So right now we do not know how to handle torch.device when it is given only an index argument; maybe we should deprecate the torch.device(0) usage.

ezyang (Contributor) commented Jul 24, 2023

Is it just torch.device? We can make torch.device interposable by TorchFunctionMode; would that be sufficient?

heidongxianhua (Author) commented

> Is it just torch.device? We can make torch.device interposable by TorchFunctionMode; would that be sufficient?

Ehha, except for torch.device there are also some APIs defined in Python, and some hard-coded cuda in C++, that are hard to extend to any device other than cuda.

heidongxianhua (Author) commented

And I have tried to use TorchFunctionMode to solve torch.device, but it does not work. Are there any examples? @ezyang

ezyang (Contributor) commented Jul 25, 2023

You need to modify it, but the modification is a lot smaller.

PyObject* THPDevice_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"Device(Device device)",
       "Device(c10::string_view type, int64_t? index=-1)"});
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0); // this works, because device can take strings
    auto device_type = r.string(0);
    if (as_device.has_index()) {
      throw std::runtime_error(
          "type (string) must not include an index because index "
          "was passed explicitly: " +
          device_type);
    }
    int32_t device_index = -1;
    if (!r.isNone(1)) {
      device_index = r.toInt64(1);
      // -1 is allowed in ATen/C++, to mean the default device, but not in
      // Python.
      TORCH_CHECK(device_index >= 0, "Device index must not be negative");
    }
    at::Device device(as_device.type(), device_index);
    return THPDevice_New(device);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

change it to something like

PyObject* THPDevice_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"device(Device device)",
       "device(c10::string_view type, int64_t? index=-1)"});
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
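  // Added relative to the original: when a TorchFunctionMode is active,
  // dispatch through __torch_function__ so the mode can interpose on
  // torch.device(...) constructions.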
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0); // this works, because device can take strings
    auto device_type = r.string(0);
    if (as_device.has_index()) {
      throw std::runtime_error(
          "type (string) must not include an index because index "
          "was passed explicitly: " +
          device_type);
    }
    int32_t device_index = -1;
    if (!r.isNone(1)) {
      device_index = r.toInt64(1);
      // -1 is allowed in ATen/C++, to mean the default device, but not in
      // Python.
      TORCH_CHECK(device_index >= 0, "Device index must not be negative");
    }
    at::Device device(as_device.type(), device_index);
    return THPDevice_New(device);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

and then you should be able to interpose on torch.device constructions. Then, based on your other patch, you just need to modify is_pinned (it should already be interposable) and fork_rng: do something similar, but use the Python-side torch function handling idiom, as in this storage() example:

    def storage(self):
        r"""
        storage() -> torch.TypedStorage

        Returns the underlying :class:`TypedStorage`.

        .. warning::

            :class:`TypedStorage` is deprecated. It will be removed in the future, and
            :class:`UntypedStorage` will be the only storage class. To access the
            :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.storage, (self,), self)

        torch.storage._warn_typed_storage_removal(stacklevel=2)
        return self._typed_storage()

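A minimal sketch of that idiom applied to a module-level function such as fork_rng might look as follows; the argument plumbing here is an assumption for illustration, not the actual patch:

from torch.overrides import handle_torch_function, has_torch_function

def fork_rng(devices=None, enabled=True, device_type="cuda"):
    # When a TorchFunctionMode is active, has_torch_function reports True
    # even for plain arguments, so the mode gets a chance to rewrite
    # device_type before the real implementation runs.
    if has_torch_function((devices,)):
        return handle_torch_function(
            fork_rng, (devices,), devices=devices,
            enabled=enabled, device_type=device_type)
    ...  # existing fork_rng body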

heidongxianhua (Author) commented Jul 26, 2023

Yeah, I gave a wrong string, and I have fixed it now; this is PR #106017.
But this PR causes errors in the prims tests: unlike the operators defined in ATen, torch.device is an independent class. I am not familiar with prims, so I do not have a good idea how to solve that.
I have also made a simple test; TorchFunctionMode may cause performance degradation as well:

import time

import torch
from torch.overrides import TorchFunctionMode

_device_constructors = {torch.tensor, torch.device, torch.rand}

class DeviceContext(TorchFunctionMode):
    def __init__(self, device):
        pass  # the device is ignored here; this mode only measures overhead

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Rewrite a bare integer device index for the intercepted constructors.
        if func in _device_constructors and isinstance(kwargs.get("device"), int):
            kwargs["device"] = "cpu"
        return func(*args, **kwargs)

def run():
    # Baseline: no TorchFunctionMode active.
    start_time = time.time()
    for i in range(100000):
        a = torch.rand(2, 3)
    print("origin time:", time.time() - start_time)

    # Same loop with the mode entered (never exited, for the test).
    DeviceContext("npu").__enter__()
    start_time = time.time()
    for i in range(100000):
        a = torch.rand(2, 3, device=0)
    print("torch_function time:", time.time() - start_time)

run()

The result is:

origin time: 0.3201165199279785
torch_function time: 0.6074237823486328

And on the Python side, as in the storage example you gave, we may need to add a handle_torch_function call to many APIs, because many APIs have a device-type argument whose default value is cuda.
So given the questions above, perhaps the simpler way is to add a default device parameter setting in C++, as in PR #103575? @ezyang

ezyang (Contributor) commented Jul 28, 2023

Are you more comfortable with the overhead if you assume you're going to torch.compile the model anyway?

heidongxianhua (Author) commented

> Are you more comfortable with the overhead if you assume you're going to torch.compile the model anyway?

Yeah, we want to support training models in eager mode as well as in torch.compile mode. @ezyang
