In [1]:
from dataclasses import dataclass, field
from typing import Optional

from transformers import AutoConfig, AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel, HfArgumentParser


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Destination directory for the assembled model and its preprocessors.
output_dir = "./vit_biomedlm_caption_model"

# Checkpoint identifiers for the two halves of the captioning model:
# the image encoder and the text decoder.
encoder_model_name_or_path = "google/vit-base-patch16-224-in21k"
decoder_model_name_or_path = "stanford-crfm/BioMedLM"

# Echo the selection so the notebook output records which checkpoints were used.
for _label, _checkpoint in (
    ("Encoder:", encoder_model_name_or_path),
    ("Decoder:", decoder_model_name_or_path),
):
    print(_label, _checkpoint)

Encoder: google/vit-base-patch16-224-in21k
Decoder: stanford-crfm/BioMedLM


In [5]:
from transformers import AutoConfig

# Load the configuration published with each checkpoint. Both objects are
# passed to the model constructor in the next cell.
encoder_config = AutoConfig.from_pretrained(encoder_model_name_or_path)
decoder_config = AutoConfig.from_pretrained(decoder_model_name_or_path)

# Mark the language model as a decoder and turn on its cross-attention layers.
# `update` sets each key as an attribute on the config object.
decoder_config.update({"is_decoder": True, "add_cross_attention": True})

In [6]:
from transformers import VisionEncoderDecoderModel

# Combine the pretrained image encoder and text decoder into one model.
# The first two positional arguments are the encoder and decoder checkpoints;
# the `encoder_config` / `decoder_config` keywords supply the configurations
# prepared in the previous cell (the decoder one has cross-attention enabled).
# The load log for this cell reports the decoder's cross-attention weights as
# newly initialized rather than loaded from the checkpoint.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model_name_or_path,
    decoder_model_name_or_path,
    encoder_config=encoder_config,
    decoder_config=decoder_config,
)
print("VisionEncoderDecoderModel created successfully!")


Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at stanford-crfm/BioMedLM and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.bias', 'tra

✅ VisionEncoderDecoderModel created successfully!


In [7]:
# Resolve the special-token ids for the combined model from the decoder config.
#
# The fallbacks test `is not None` instead of using `or`: token id 0 is a valid
# id but is falsy, so `a or b` would silently discard it and pick the fallback.

# Start token: prefer an explicit decoder_start_token_id, else the BOS token.
decoder_start_token_id = (
    decoder_config.decoder_start_token_id
    if decoder_config.decoder_start_token_id is not None
    else decoder_config.bos_token_id
)

# Padding token: prefer an explicit pad_token_id, else reuse the EOS token.
pad_token_id = (
    decoder_config.pad_token_id
    if decoder_config.pad_token_id is not None
    else decoder_config.eos_token_id
)

# Copy the resolved ids onto the top-level encoder-decoder config.
model.config.eos_token_id = decoder_config.eos_token_id
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id


In [8]:
from transformers import AutoTokenizer, AutoImageProcessor

# Image preprocessing comes from the encoder checkpoint, text tokenization
# from the decoder checkpoint.
image_processor = AutoImageProcessor.from_pretrained(encoder_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name_or_path)

# Derive the tokenizer's pad token from the pad id stored on the model config,
# so tokenizer and model cannot disagree about padding. The earlier
# `pad_token = eos_token` assignment was dropped: it was overwritten by this
# line immediately and had no effect on the final state.
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [9]:
# Write the model, image processor and tokenizer, in that order, into the same
# directory via each object's `save_pretrained` method.
for _artifact in (model, image_processor, tokenizer):
    _artifact.save_pretrained(output_dir)

print(f"Model, tokenizer, and image processor saved in: {output_dir}")


Model, tokenizer, and image processor saved in: ./vit_biomedlm_caption_model
