-
Notifications
You must be signed in to change notification settings - Fork 117
/
__init__.py
45 lines (41 loc) · 1.13 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Serialization module."""
import os
from torch.nn.modules import activation
def _parse_env_flag(raw_value: str) -> bool:
    """Interpret an environment-variable string as a boolean flag.

    The comparison ignores surrounding whitespace and letter case. The values "", "0",
    "false", "no" and "off" disable the flag; any other value enables it.

    Args:
        raw_value (str): the raw string read from the environment.

    Returns:
        bool: False if the value is one of the recognized "off" spellings, True otherwise.
    """
    # bool() on the raw string is not enough: every non-empty string, including "0",
    # is truthy, which would make the flag impossible to turn off with USE_SKOPS=0.
    return raw_value.strip().lower() not in {"", "0", "false", "no", "off"}


# Whether Skops is used. Enabled by default; set the USE_SKOPS environment variable to
# "0", "false", "no" or "off" (case-insensitive), or to an empty string, to disable it.
USE_SKOPS = _parse_env_flag(os.environ.get("USE_SKOPS", "1"))
# Torch activation modules that Concrete ML currently supports, in the order listed.
# Each class is looked up by name on `torch.nn.modules.activation`; a name that does not
# exist there raises AttributeError when this module is imported.
SUPPORTED_TORCH_ACTIVATIONS = [
    getattr(activation, activation_name)
    for activation_name in (
        "CELU",
        "ELU",
        "GELU",
        "Hardshrink",
        "Hardsigmoid",
        "Hardswish",
        "Hardtanh",
        "LeakyReLU",
        "LogSigmoid",
        "LogSoftmax",
        "Mish",
        "PReLU",
        "ReLU",
        "ReLU6",
        "SELU",
        "SiLU",
        "Sigmoid",
        "Softmin",
        "Softplus",
        "Softshrink",
        "Softsign",
        "Tanh",
        "Tanhshrink",
        "Threshold",
    )
]
# Torch activation modules that Concrete ML does not support yet, in the order listed.
# Each class is looked up by name on `torch.nn.modules.activation`.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/335
UNSUPPORTED_TORCH_ACTIVATIONS = [
    getattr(activation, activation_name)
    for activation_name in (
        "GLU",
        "MultiheadAttention",
        "RReLU",
        "Softmax",
        "Softmax2d",
    )
]