From 1b7beaf2fb4fadeb957b83bdcd2de14c84d2ba6e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Sep 2021 18:55:20 +0200 Subject: [PATCH] [DPR] Correct init (#13796) * update * add to docs and init * make fix-copies --- docs/source/model_doc/dpr.rst | 7 +++ src/transformers/__init__.py | 2 + src/transformers/models/dpr/__init__.py | 2 + src/transformers/models/dpr/modeling_dpr.py | 59 ++++++++++----------- src/transformers/utils/dummy_pt_objects.py | 9 ++++ tests/test_modeling_dpr.py | 14 +++++ 6 files changed, 61 insertions(+), 32 deletions(-) diff --git a/docs/source/model_doc/dpr.rst b/docs/source/model_doc/dpr.rst index 005faf8cff9621..0dbc7c32f7ac1e 100644 --- a/docs/source/model_doc/dpr.rst +++ b/docs/source/model_doc/dpr.rst @@ -41,6 +41,13 @@ DPRConfig :members: +DPRPreTrainedModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DPRPreTrainedModel + :members: + + DPRContextEncoderTokenizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 42be4bbd96e3f5..a3ab3765d1f8b8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -773,6 +773,7 @@ "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST", "DPRContextEncoder", "DPRPretrainedContextEncoder", + "DPRPreTrainedModel", "DPRPretrainedQuestionEncoder", "DPRPretrainedReader", "DPRQuestionEncoder", @@ -2512,6 +2513,7 @@ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, DPRContextEncoder, DPRPretrainedContextEncoder, + DPRPreTrainedModel, DPRPretrainedQuestionEncoder, DPRPretrainedReader, DPRQuestionEncoder, diff --git a/src/transformers/models/dpr/__init__.py b/src/transformers/models/dpr/__init__.py index e94ce7ca225a8b..24358c11341fae 100644 --- a/src/transformers/models/dpr/__init__.py +++ b/src/transformers/models/dpr/__init__.py @@ -46,6 +46,7 @@ "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST", "DPRContextEncoder", "DPRPretrainedContextEncoder", + "DPRPreTrainedModel", "DPRPretrainedQuestionEncoder", "DPRPretrainedReader", "DPRQuestionEncoder", @@ -89,6 +90,7 @@ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, DPRContextEncoder, DPRPretrainedContextEncoder, + DPRPreTrainedModel, DPRPretrainedQuestionEncoder, DPRPretrainedReader, DPRQuestionEncoder, diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index c1a3fa618d4eb1..091479af4b3c4e 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -class DPREncoder(PreTrainedModel): +class DPRPreTrainedModel(PreTrainedModel): + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + +class DPREncoder(DPRPreTrainedModel): base_model_prefix = "bert_model" @@ -200,13 +222,8 @@ def embeddings_size(self) -> int: return self.encode_proj.out_features return self.bert_model.config.hidden_size - def init_weights(self): - self.bert_model.init_weights() - if self.projection_dim > 0: - self.encode_proj.apply(self.bert_model._init_weights) - -class DPRSpanPredictor(PreTrainedModel): +class DPRSpanPredictor(DPRPreTrainedModel): base_model_prefix = "encoder" @@ -262,16 +279,13 @@ def forward( attentions=outputs.attentions, ) - def init_weights(self): - self.encoder.init_weights() - ################## # PreTrainedModel ################## -class DPRPretrainedContextEncoder(PreTrainedModel): +class DPRPretrainedContextEncoder(DPRPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel): base_model_prefix = "ctx_encoder" _keys_to_ignore_on_load_missing = [r"position_ids"] - def init_weights(self): - self.ctx_encoder.init_weights() - -class DPRPretrainedQuestionEncoder(PreTrainedModel): +class DPRPretrainedQuestionEncoder(DPRPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): base_model_prefix = "question_encoder" _keys_to_ignore_on_load_missing = [r"position_ids"] - def init_weights(self): - self.question_encoder.init_weights() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BertEncoder): - module.gradient_checkpointing = value - - -class DPRPretrainedReader(PreTrainedModel): +class DPRPretrainedReader(DPRPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel): base_model_prefix = "span_predictor" _keys_to_ignore_on_load_missing = [r"position_ids"] - def init_weights(self): - self.span_predictor.encoder.init_weights() - self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights) - self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BertEncoder): - module.gradient_checkpointing = value - ############### # Actual Models diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index fc4f28b4a3602d..c91cb17c0de5fc 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1462,6 +1462,15 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DPRPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPRPretrainedQuestionEncoder: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_dpr.py b/tests/test_modeling_dpr.py index 8c7d17b5428b80..5636c2a8025830 100644 --- a/tests/test_modeling_dpr.py +++ b/tests/test_modeling_dpr.py @@ -14,6 +14,7 @@ # limitations under the License. +import tempfile import unittest from transformers import DPRConfig, is_torch_available @@ -213,6 +214,19 @@ def test_reader_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reader(*config_and_inputs) + def test_init_changed_config(self): + config = self.model_tester.prepare_config_and_inputs()[0] + + model = DPRQuestionEncoder(config=config) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmp_dirname: + model.save_pretrained(tmp_dirname) + model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512) + + self.assertIsNotNone(model) + @slow def test_model_from_pretrained(self): for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: