-
Notifications
You must be signed in to change notification settings - Fork 21.3k
/
thnn.py
48 lines (36 loc) · 1.57 KB
/
thnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from .backend import FunctionBackend
class THNNFunctionBackend(FunctionBackend):
def __reduce__(self):
return (_get_thnn_function_backend, ())
def __deepcopy__(self, memo):
memo[id(self)] = self
return self
def __copy__(self):
return self
def _get_thnn_function_backend():
return backend
def _initialize_backend():
from .._functions.thnn import _all_functions as _thnn_functions
from .._functions.rnn import RNN, \
RNNTanhCell, RNNReLUCell, GRUCell, LSTMCell
from .._functions.dropout import Dropout, FeatureDropout
from .._functions.activation import Softsign
from .._functions.loss import CosineEmbeddingLoss, \
HingeEmbeddingLoss, MarginRankingLoss
backend.register_function('RNN', RNN)
backend.register_function('RNNTanhCell', RNNTanhCell)
backend.register_function('RNNReLUCell', RNNReLUCell)
backend.register_function('LSTMCell', LSTMCell)
backend.register_function('GRUCell', GRUCell)
backend.register_function('Dropout', Dropout)
backend.register_function('Dropout2d', FeatureDropout)
backend.register_function('Dropout3d', FeatureDropout)
backend.register_function('CosineEmbeddingLoss', CosineEmbeddingLoss)
backend.register_function('HingeEmbeddingLoss', HingeEmbeddingLoss)
backend.register_function('MarginRankingLoss', MarginRankingLoss)
backend.register_function('Softsign', Softsign)
for cls in _thnn_functions:
name = cls.__name__
backend.register_function(name, cls)
backend = THNNFunctionBackend()
_initialize_backend()