diff --git a/offlinerllib/__init__.py b/offlinerllib/__init__.py
index 36e61c6..abe9369 100644
--- a/offlinerllib/__init__.py
+++ b/offlinerllib/__init__.py
@@ -1,2 +1,2 @@
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
diff --git a/offlinerllib/module/net/attention/base.py b/offlinerllib/module/net/attention/base.py
index 7c6b1ce..a7ded03 100644
--- a/offlinerllib/module/net/attention/base.py
+++ b/offlinerllib/module/net/attention/base.py
@@ -4,6 +4,13 @@
 from offlinerllib.module.net.attention.positional_encoding import BasePosEncoding
 
 
+class DecayParameter(nn.Parameter):
+    pass
+
+class NoDecayParameter(nn.Parameter):
+    pass
+
+
 class BaseTransformer(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
@@ -31,6 +38,10 @@ def configure_params(self):
                 no_decay.add(fpn)
             elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                 decay.add(fpn)
+            elif isinstance(p, DecayParameter):
+                decay.add(fpn)
+            elif isinstance(p, NoDecayParameter):
+                no_decay.add(fpn)
 
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters()}