Apply the fix for checkpoint loading as in huggingface#14016
ydshieh committed Dec 1, 2021
1 parent 9648654 · commit 314a239
Showing 1 changed file with 23 additions and 1 deletion.
@@ -15,6 +15,7 @@
""" Classes to support TF Vision-Encoder-Text-Decoder architectures """


+import tempfile
from typing import Optional

import tensorflow as tf
@@ -241,7 +242,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
If there are only pytorch checkpoints for a particular encoder-decoder model, a workaround is::

    >>> # a workaround to load from pytorch checkpoint
-    >>> _model = VisionEncoderDecoderModel.from_pretrained("ydshieh/ydshieh/vit-gpt2-coco-en")
+    >>> _model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
    >>> _model.encoder.save_pretrained("./encoder")
    >>> _model.decoder.save_pretrained("./decoder")
    >>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
@@ -250,6 +251,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    >>> # This is only for copying some specific attributes of this particular model.
    >>> model.config = _model.config

+Example::
+
+    >>> from transformers import TFVisionEncoderDecoderModel
+    >>> model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
"""

from_pt = kwargs.pop("from_pt", False)
@@ -364,6 +370,14 @@ def from_encoder_decoder_pretrained(
    kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
    encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

+    # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+    if kwargs_encoder.get("from_pt", None):
+        del kwargs_encoder["from_pt"]
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            encoder.save_pretrained(tmp_dirname)
+            del encoder
+            encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
+
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
    if decoder_pretrained_model_name_or_path is None:
@@ -392,6 +406,14 @@ def from_encoder_decoder_pretrained(
    kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
    decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

+    # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+    if kwargs_decoder.get("from_pt", None):
+        del kwargs_decoder["from_pt"]
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            decoder.save_pretrained(tmp_dirname)
+            del decoder
+            decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
+
# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
if encoder.name != "encoder":
    raise ValueError("encoder model must be created with the name `encoder`.")
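For reference, the save-and-reload round-trip both hunks perform can be sketched on its own. A minimal illustration, assuming network access; `google/vit-base-patch16-224-in21k` is used here only as a stand-in for any checkpoint that ships PyTorch weights:

import tempfile

from transformers import TFAutoModel

# Cross-load PyTorch weights into a TF model; `from_pt=True` converts the
# PyTorch checkpoint on the fly instead of looking for TF weights.
model = TFAutoModel.from_pretrained("google/vit-base-patch16-224-in21k", from_pt=True)

# Save natively in TF format and reload, so that a later `save_pretrained`
# followed by `from_pretrained` finds the weight names it expects.
with tempfile.TemporaryDirectory() as tmp_dirname:
    model.save_pretrained(tmp_dirname)
    model = TFAutoModel.from_pretrained(tmp_dirname)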
