Skip to content

Commit

Permalink
Add the changes done in huggingface#14016
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jan 6, 2022
1 parent d852154 commit 68b32f2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
# Copyright 2022 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
""" Classes to support TF Vision-Encoder-Text-Decoder architectures"""


import tempfile
from typing import Optional

import tensorflow as tf
Expand Down Expand Up @@ -410,6 +411,14 @@ def from_encoder_decoder_pretrained(
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_encoder.get("from_pt", None):
del kwargs_encoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder.save_pretrained(tmp_dirname)
del encoder
encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)

decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
Expand Down Expand Up @@ -445,6 +454,14 @@ def from_encoder_decoder_pretrained(
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_decoder.get("from_pt", None):
del kwargs_decoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
decoder.save_pretrained(tmp_dirname)
del decoder
decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)

# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
if encoder.name != "encoder":
raise ValueError("encoder model must be created with the name `encoder`.")
Expand Down
14 changes: 9 additions & 5 deletions tests/test_modeling_tf_vision_encoder_decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2021- HuggingFace Inc. team.
# Copyright 2022 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """


import os
Expand Down Expand Up @@ -689,13 +690,16 @@ def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)

# TensorFlow => PyTorch
# Make sure `from_pretrained` following `save_pretrained` work and give the same result
# (See https://github.com/huggingface/transformers/pull/14016)
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_pt = VisionEncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)

max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)
logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits

max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)

@require_vision
@slow
Expand Down

0 comments on commit 68b32f2

Please sign in to comment.