In [5]:
from transformers import (GPT2Config, PreTrainedTokenizerFast, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, VisionEncoderDecoderConfig,
                          VisionEncoderDecoderModel, ViTConfig, TrOCRProcessor,
                          ViTImageProcessor, default_data_collator)
from PIL import Image

In [9]:
# Locations of the artefacts loaded in this cell.
MLW_TOKENIZER_FILE = '/home/philko/Documents/Uni/WiSe2223/Consulting/mlw-consulting-project/models/tokenizer/MLW_Tokenizer.json'
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'

# Fast tokenizer built from the project's serialized tokenizer JSON file.
tokenizer = PreTrainedTokenizerFast(tokenizer_file=MLW_TOKENIZER_FILE)

# The same ViT image-processor checkpoint is loaded twice, under two names.
# Of the two, only `image_processor` is referenced by the cells shown in this
# part of the notebook.
feature_extractor: ViTImageProcessor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)
image_processor: ViTImageProcessor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)

In [10]:
# Encoder (ViT) and decoder (GPT-2) configurations, both created with their
# library defaults -- no arguments are passed to either constructor.
# NOTE(review): the decoder config is not given the tokenizer's vocabulary
# size here -- confirm the default matches the custom tokenizer.
config_encoder, config_decoder = ViTConfig(), GPT2Config()

# Combine both configurations into one vision-encoder-decoder configuration
# and construct the model from it. The model is built from the configuration
# alone; no `from_pretrained` call is made in this cell.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config=config_encoder,
    decoder_config=config_decoder,
)
model = VisionEncoderDecoderModel(config)

In [6]:
# Load the TrOCR processor for the checkpoint named below.
# NOTE(review): `processor` is not referenced by the other cells shown in this
# part of the notebook.
TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [29]:
# Single sample image used for a quick end-to-end generation check.
sample_image_path = "/home/philko/Documents/Uni/WiSe2223/Consulting/mlw-consulting-project/data/interim/lemmata_img/images/301450.jpg"

# Open the image and force three-channel RGB before preprocessing.
image = Image.open(sample_image_path).convert("RGB")

# Preprocess into PyTorch tensors and pick out the pixel tensor.
encoded_image = image_processor(image, return_tensors="pt")
pixel_values = encoded_image.pixel_values

# Autoregressive generation with the model's default generation settings
# (greedy decoding, according to the original note on this cell).
generated_ids = model.generate(pixel_values)

# Decode while keeping special tokens so they stay visible in the output.
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
generated_text

['<CLS>']

# Experiments with Tokenizer

In [19]:
# Plain name bindings, not copies: the `_test` names refer to the very same
# tokenizer and model objects, so any change made through them is also
# visible through `tokenizer` and `model`.
tokenizer_test, model_test = tokenizer, model

In [20]:
# Special tokens to register on the tokenizer. Note that bos, eos and unk all
# share the same string ('<|endoftext|>').
special_tokens_dict = {
    'pad_token': '<PAD>',
    'cls_token': '<CLS>',
    'bos_token': '<|endoftext|>',
    'eos_token': '<|endoftext|>',
    'unk_token': '<|endoftext|>'}
tokenizer_test.add_special_tokens(special_tokens_dict)

# Registering special tokens can grow the vocabulary, and the decoder in this
# notebook is created from a default GPT2Config() rather than from this
# tokenizer. Resize the decoder's token embeddings so every tokenizer id has
# an embedding row. The resize is done on the wrapped decoder, not on the
# encoder-decoder wrapper itself.
model_test.decoder.resize_token_embeddings(len(tokenizer_test))

# Generation starts from <CLS>; <PAD> is used as the padding id.
model_test.config.decoder_start_token_id = tokenizer_test.cls_token_id
model_test.config.pad_token_id = tokenizer_test.pad_token_id