-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodule_device.py
86 lines (57 loc) · 2.33 KB
/
module_device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from functools import wraps
from typing import List, Sequence

import torch
from torch import is_tensor
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_unflatten
# provides a .device (or custom-named) property for your model, backed by a dummy scalar buffer
def module_device(
    device_property_name = 'device'
):
    """Class decorator factory that gives an ``nn.Module`` subclass a device property.

    After the class's own ``__init__`` runs, a scalar tensor is registered as the
    non-persistent buffer ``_dummy``. Being a buffer, it follows the module through
    device moves; being non-persistent, it is not included in ``state_dict()``.

    Args:
        device_property_name: name of the read-only property installed on the
            class; reading it returns ``self._dummy.device``.

    Returns:
        A decorator that patches the class in place and returns it.
    """

    def read_device(self):
        return self._dummy.device

    def decorator(klass):
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        wrapped_init = klass.__init__

        def patched_init(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)
            # registered after the user's __init__, which is expected to have run Module.__init__
            self.register_buffer('_dummy', torch.tensor(0), persistent = False)

        klass.__init__ = wraps(wrapped_init)(patched_init)
        setattr(klass, device_property_name, property(read_device))
        return klass

    return decorator
# a decorator that automatically moves all tensors coming into the wrapped methods
# (by default just .forward) onto the module's device
def autocast_device(
    methods: Sequence[str] = ('forward',)
):
    """Class decorator factory that moves tensor arguments onto the module's device.

    Every listed method is replaced by a wrapper. On each call the wrapper:
    1. finds the module's device;
    2. flattens ``(args, kwargs)`` with torch's pytree utilities;
    3. calls ``.to(device)`` on every tensor leaf;
    4. rebuilds the arguments and calls the original method, returning its result.

    The device comes from the ``_dummy`` buffer registered by ``module_device``
    when present. Otherwise it comes from the first parameter, then the first
    buffer. A module with none of these has no device to infer, so its
    arguments are passed through unchanged.

    Args:
        methods: names of the methods to wrap. A single string or any iterable
            of strings is accepted. Defaults to ``('forward',)``.

    Returns:
        A decorator that patches the methods on the class in place and
        returns the class.
    """
    # materialize once: the names are iterated twice below, and a one-shot
    # iterator (e.g. a generator) would otherwise be exhausted by the first pass
    method_names = [methods] if isinstance(methods, str) else list(methods)

    def _infer_device(module):
        """Return the device to move inputs to, or None if it cannot be inferred."""
        if hasattr(module, '_dummy'):
            return module._dummy.device
        # next(..., None) instead of bare next(): a module without parameters
        # would otherwise leak StopIteration out of the wrapped call
        anchor = next(module.parameters(), None)
        if anchor is None:
            anchor = next(module.buffers(), None)
        return None if anchor is None else anchor.device

    def _move_to_device(device, args, kwargs):
        """Return (args, kwargs) with every tensor leaf moved to ``device``."""
        leaves, spec = tree_flatten([args, kwargs])
        leaves = [leaf.to(device) if is_tensor(leaf) else leaf for leaf in leaves]
        args, kwargs = tree_unflatten(leaves, spec)
        return args, kwargs

    def _wrap(orig_fn):
        # one factory call per method gives each wrapper its own `orig_fn`; a closure
        # defined directly in the loop would late-bind to the last method in the list
        @wraps(orig_fn)
        def fn(self, *args, **kwargs):
            device = _infer_device(self)
            if device is not None:
                args, kwargs = _move_to_device(device, args, kwargs)
            return orig_fn(self, *args, **kwargs)
        return fn

    def decorator(klass):
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'
        # look every method up before patching any
        orig_fns = [getattr(klass, name) for name in method_names]
        for name, orig_fn in zip(method_names, orig_fns):
            setattr(klass, name, _wrap(orig_fn))
        return klass

    return decorator