In [1]:
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel, AutoConfig, AutoModel, AutoProcessor, AutoFeatureExtractor, AutoTokenizer, TrOCRProcessor, AutoModelForCausalLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainingArguments, Trainer, default_data_collator


def get_processor(encoder_name, decoder_name):
    feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name)
    tokenizer = AutoTokenizer.from_pretrained(decoder_name)
    processor = TrOCRProcessor(feature_extractor, tokenizer)
    return processor

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(





In [2]:

def get_model(encoder_name, decoder_name):

    encoder_config = AutoConfig.from_pretrained(encoder_name)
    encoder_config.is_decoder = False
    encoder_config.add_cross_attention = False
    encoder = AutoModel.from_config(encoder_config)


    decoder_config = AutoConfig.from_pretrained(decoder_name)
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention=True
    decoder = AutoModelForCausalLM.from_config(decoder_config)

    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)

    processor = get_processor(encoder_name, decoder_name)

    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size
    
    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = 100
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    return model, processor


In [3]:
model, processor = get_model('facebook/deit-tiny-patch16-224', 'cl-tohoku/bert-base-japanese-char-v2')

# model, processor = get_model('google/vit-base-patch16-224', 'cl-tohoku/bert-base-japanese-char-v2')

  _torch_pytree._register_pytree_node(


In [4]:
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predicted = logits.argmax(-1)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}

In [8]:
from pathlib import Path
from datasets import load_dataset
import numpy as np

run_name='debug',
max_len=100,
num_decoder_layers=2,
batch_size=10,
num_epochs=5,
fp16=True,

#train_ds = load_dataset("imagefolder", data_dir="./dataset/", split="train")
#test_ds = load_dataset("imagefolder", data_dir="./dataset/", split="validation")

ds = load_dataset("imagefolder", data_dir="./manga_dataset/")
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]

# ds = load_dataset("lambdalabs/pokemon-blip-captions")
# ds = ds["train"].train_test_split(test_size=0.1) #{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x1280 at 0x1AE95F68E50>, 'text': 'a yellow and black cartoon character with a big smile'}
# train_ds = ds["train"]
# test_ds = ds["test"]

print(train_ds)
print(test_ds)

# def transforms(example_batch):
#     images = [x for x in example_batch["image"]]
#     captions = [x for x in example_batch["text"]]
#     inputs = processor(images=images, text=captions, padding="max_length")
#     inputs.update({"labels": inputs["input_ids"]})
#     return inputs

# from torchvision.transforms import (
#     CenterCrop,
#     Compose,
#     Normalize,
#     RandomHorizontalFlip,
#     RandomResizedCrop,
#     Resize,
#     ToTensor,
# )

# train_transforms = Compose(
#         [
#             RandomResizedCrop(224),
#             RandomHorizontalFlip(),
#             ToTensor(),
#         ]
#     )

# def preprocess_train(example_batch):
#     """Apply train_transforms across a batch."""
#     example_batch["pixel_values"] = [
#         train_transforms(image.convert("RGB")) for image in example_batch["image"]
#     ]
#     return example_batch

# def transform(example_batch):
#     # Take a list of PIL images and turn them to pixel values
#     inputs = processor([x for x in example_batch['image']], return_tensors='pt')

#     # Don't forget to include the labels!
#     inputs['text'] = example_batch['text']
#     return inputs

# def transform(example_batch):
#     # Take a list of PIL images and turn them to pixel values

#     inputs = processor.feature_extractor([x for x in example_batch['image']], return_tensors="pt").pixel_values.squeeze()#[None].to(model.device)
#     # Don't forget to include the labels!
#     #txt = processor.tokenizer(example_batch['text'], padding="max_length", max_length=max_len, truncation=True).input_ids

#     # txt = np.array(txt)

#     # encoding = {
#     #     "pixel_values": inputs,
#     #     "text": torch.tensor(txt),
#     # }
#     # txt[txt == processor.tokenizer.pad_token_id] = -100

#     # inputs['text'] = torch.tensor(txt)
#     inputs['text'] = example_batch['text']
#     return inputs

def transform(example_batch):
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    inputs['labels'] = torch.tensor(processor.tokenizer(example_batch['text'],
                                                    padding="max_length",
                                                    max_length=100,
                                                    truncation=True).input_ids)
                                   #add back np.array()                 
    return inputs



train_ds.set_transform(transform)
test_ds.set_transform(transform)

print(train_ds[0]['labels'])


training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='steps',
    save_strategy='steps',
    #per_device_train_batch_size=10,
    #per_device_eval_batch_size=10,
    fp16=fp16,
    fp16_full_eval=fp16,
    #dataloader_num_workers=8,
    output_dir = Path("./output"),
    #logging_steps=10,
    #save_steps=20000,
    #eval_steps=20000,
    num_train_epochs=15,
    #run_name=run_name,
    remove_unused_columns=False
)

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Dataset({
    features: ['image', 'text'],
    num_rows: 1490
})
Dataset({
    features: ['image', 'text'],
    num_rows: 166
})
tensor([   2,  879,  933,  892,  896,    1, 3464, 3614,  882,  922, 1026, 1026,
           1,  869,  933,  892, 2561, 3539,  898,  885,  861,    1,  918,  885,
         888,  933,  896,  893,  941, 1026,    3,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0])




  0%|          | 0/2805 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
 18%|█▊        | 500/2805 [08:16<37:47,  1.02it/s]

{'loss': 1.9585, 'grad_norm': 0.861111581325531, 'learning_rate': 4.114081996434938e-05, 'epoch': 2.67}



Non-default generation parameters: {'max_length': 100, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


{'eval_loss': 0.6595956087112427, 'eval_wer_score': 1.1269698334083746, 'eval_runtime': 148.0841, 'eval_samples_per_second': 1.121, 'eval_steps_per_second': 0.142, 'epoch': 2.67}


 36%|███▌      | 1000/2805 [19:00<29:12,  1.03it/s]  

{'loss': 0.6315, 'grad_norm': 0.7005841135978699, 'learning_rate': 3.222816399286987e-05, 'epoch': 5.35}


                                                   
Non-default generation parameters: {'max_length': 100, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


{'eval_loss': 0.6298285126686096, 'eval_wer_score': 1.1508329581269698, 'eval_runtime': 64.269, 'eval_samples_per_second': 2.583, 'eval_steps_per_second': 0.327, 'epoch': 5.35}


 53%|█████▎    | 1500/2805 [28:15<20:05,  1.08it/s]   

{'loss': 0.5426, 'grad_norm': 1.022025227546692, 'learning_rate': 2.3315508021390376e-05, 'epoch': 8.02}


                                                   
Non-default generation parameters: {'max_length': 100, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


{'eval_loss': 0.6239275336265564, 'eval_wer_score': 1.158937415578568, 'eval_runtime': 73.867, 'eval_samples_per_second': 2.247, 'eval_steps_per_second': 0.284, 'epoch': 8.02}


 71%|███████▏  | 2000/2805 [44:19<08:50,  1.52it/s]    

{'loss': 0.4455, 'grad_norm': 1.1543413400650024, 'learning_rate': 1.4402852049910875e-05, 'epoch': 10.7}


                                                   
Non-default generation parameters: {'max_length': 100, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


{'eval_loss': 0.6420084834098816, 'eval_wer_score': 1.1580369203061684, 'eval_runtime': 114.5584, 'eval_samples_per_second': 1.449, 'eval_steps_per_second': 0.183, 'epoch': 10.7}


 89%|████████▉ | 2500/2805 [53:21<05:16,  1.04s/it]  

{'loss': 0.378, 'grad_norm': 0.9236273169517517, 'learning_rate': 5.490196078431373e-06, 'epoch': 13.37}


                                                   
Non-default generation parameters: {'max_length': 100, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


{'eval_loss': 0.6490842700004578, 'eval_wer_score': 1.1391265195857723, 'eval_runtime': 116.9775, 'eval_samples_per_second': 1.419, 'eval_steps_per_second': 0.18, 'epoch': 13.37}


100%|██████████| 2805/2805 [1:00:03<00:00,  1.28s/it]

{'train_runtime': 3604.5989, 'train_samples_per_second': 6.2, 'train_steps_per_second': 0.778, 'train_loss': 0.7428336729977858, 'epoch': 15.0}





TrainOutput(global_step=2805, training_loss=0.7428336729977858, metrics={'train_runtime': 3604.5989, 'train_samples_per_second': 6.2, 'train_steps_per_second': 0.778, 'train_loss': 0.7428336729977858, 'epoch': 15.0})

In [41]:
from PIL import Image
import requests
from transformers import AutoFeatureExtractor, ViTForImageClassification

#url = 'https://assets.pokemon.com/assets/cms2/img/pokedex/full/123.png'
#image = Image.open(requests.get(url, stream=True).raw)

path = './manga_dataset/crops/00000d36.png'
image = Image.open(path)

feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-patch16-224')

image = image.convert('L').convert('RGB')
inputs = feature_extractor(image, return_tensors="pt").pixel_values.squeeze()
inputs = model.generate(inputs[None].to(model.device), max_length=100)[0].cpu()

outputs = processor.tokenizer.decode(inputs, skip_special_tokens=True)


In [42]:

print(outputs)

は い お じ い さ ん


In [43]:
trainer.save_model("./model1")

Non-default generation parameters: {'max_length': 100, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
