In [None]:
# https://huggingface.co/yuanzhoulvpi/gpt2_chinese
# https://huggingface.co/google/vit-base-patch16-224
# https://huggingface.co/nlpconnect/vit-gpt2-image-captioning

In [None]:
from transformers import (VisionEncoderDecoderModel, 
                          BertConfig, ViTConfig,
                          ViTModel, GPT2Config, GPT2Model,GPT2LMHeadModel,
                          AutoTokenizer,ViTImageProcessor,
                          Trainer,TrainingArguments)
from typing import List, Any 
import torch
from torch import Tensor
from PIL import Image
from datasets import load_dataset,Dataset

from tqdm import tqdm 
import numpy as np 
import pandas as pd 

In [None]:
# Checkpoints for the two halves of the captioning model.
VIT_MODEL_NAME_OR_PATH = "google/vit-base-patch16-224"
GPT_MODEL_NAME_OR_PATH = "yuanzhoulvpi/gpt2_chinese"

# Image encoder.
VIT_model = ViTModel.from_pretrained(VIT_MODEL_NAME_OR_PATH)

# Text decoder, loaded with add_cross_attention=True in its config.
GPT_model = GPT2LMHeadModel.from_pretrained(
    GPT_MODEL_NAME_OR_PATH,
    add_cross_attention=True,
)

# Echo the flag to confirm it is set on the loaded decoder.
GPT_model.config.add_cross_attention

In [None]:
# Preprocessing objects, each loaded from the same checkpoint as its model.
tokenizer = AutoTokenizer.from_pretrained(GPT_MODEL_NAME_OR_PATH)
processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME_OR_PATH)

In [None]:
def process_image_2_pixel_value(x:str) -> Tensor:
    """Load an image file and convert it to the encoder's pixel-value tensor.

    Args:
        x: Path of the image file to read.

    Returns:
        The ``pixel_values`` tensor produced by ``processor`` for this single
        image, with the leading batch dimension squeezed away.
    """
    # The context manager closes the file handle promptly. Converting to RGB
    # guarantees three channels, so grayscale / RGBA / palette images do not
    # reach the processor with an unexpected channel count.
    with Image.open(x) as image_file:
        image = image_file.convert("RGB")
    res = processor(images=image, return_tensors='pt')['pixel_values'].squeeze(0)
    return res


# Quick check of the tensor shape on one sample image.
process_image_2_pixel_value(x = "bigdata/image_data/test-9282.jpg").shape

In [None]:
def process_text_2_input_id(x:str, max_length:int=100) -> List[int]:
    """Tokenize one caption into a fixed-length list of token ids.

    Args:
        x: Caption text.
        max_length: Length passed to the tokenizer for both truncation and
            ``padding="max_length"``. Defaults to 100, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The ``input_ids`` list produced by ``tokenizer``.
    """
    encoded = tokenizer(text=x, max_length=max_length, truncation=True, padding="max_length")
    return encoded['input_ids']

# Quick check of the encoded length for a short input.
len(process_text_2_input_id(x='hhh'))

# len(process_text_2_input_id(x="你好啊，csdhhchsh谁cdshhchshcsdhhhhhhhh")['input_ids'])

In [None]:
# Show the tokenizer's padding token id; the same value is later copied into the model config.
tokenizer.pad_token_id

In [None]:
# GPT_model.config.add_cross_attention = True
# # GPT_model.crossattention = False
# GPT_model.config.add_cross_attention

# # config.add_cross_attention=True
# # hasattr(GPT_model, "crossattention")

In [None]:
# Combine the pretrained ViT encoder and GPT-2 decoder into one encoder-decoder model.
new_encoder_decoder_model = VisionEncoderDecoderModel(encoder=VIT_model, decoder=GPT_model)

# Special-token ids on the combined config are taken from the decoder's tokenizer.
new_encoder_decoder_model.config.pad_token_id = tokenizer.pad_token_id
# NOTE(review): if the tokenizer defines no BOS token this assigns None — confirm.
new_encoder_decoder_model.config.decoder_start_token_id = tokenizer.bos_token_id

In [None]:
# Shape of one encoded caption once wrapped into a batch of size one.
torch.tensor([process_text_2_input_id(x='hhh')], dtype=torch.long).shape

In [None]:
# Set add_cross_attention on the combined model's config, then echo the value.
# NOTE(review): the decoder was already loaded with add_cross_attention=True,
# so this wrapper-level flag may be redundant — confirm.
new_encoder_decoder_model.config.add_cross_attention = True
new_encoder_decoder_model.config.add_cross_attention

In [None]:
# Smoke test: one forward pass with a random image-shaped tensor and a dummy caption as labels.
outputs = new_encoder_decoder_model(
    pixel_values=torch.randn(1, 3, 224, 224),
    labels=torch.tensor([process_text_2_input_id(x='hhh')], dtype=torch.long),
)

In [None]:
# Display the loss from the forward pass above.
outputs.loss

In [None]:
# new_encoder_decoder_model.config.decoder

In [None]:
# config.add_cross_attention
# new_encoder_decoder_model.decoder.config.add_cross_attention = True

In [None]:
# new_encoder_decoder_model.config.decoder_start_token_id is None

In [None]:
# torch.Tensor([1,2,3]).to(torch.long)

In [None]:
# Build the caption dataset from the CSV and hold out 2% of the rows for evaluation.
# NOTE(review): the source file is named test.csv — confirm it is the intended training data.
dataset = Dataset.from_pandas(df=pd.read_csv("bigdata/clean_train_test/test.csv"))
# A fixed seed keeps the train/eval split identical across kernel restarts.
dataset = dataset.train_test_split(test_size=0.02, seed=42)


def tokenizer_text(examples) :
    """Add a ``labels`` entry holding the token ids of every caption in the batch.

    Args:
        examples: Batch mapping with a ``text`` entry (an iterable of captions).

    Returns:
        The same ``examples`` object, with ``labels`` set.
    """
    examples['labels'] = list(map(process_text_2_input_id, examples['text']))
    return examples

def transform_images(examples):
    """Add a ``pixel_values`` entry with one processed image per ``image_path``.

    Args:
        examples: Batch mapping with an ``image_path`` entry (an iterable of paths).

    Returns:
        The same ``examples`` object, with ``pixel_values`` set to a list of tensors.
    """
    examples['pixel_values'] = list(map(process_image_2_pixel_value, examples['image_path']))
    return examples

# Tokenize every caption in batches and store the ids as `labels`.
dataset = dataset.map(tokenizer_text, batched=True)

# Images are attached through set_transform; the author left the equivalent
# map-based variant disabled.
dataset.set_transform(transform_images)


dataset

In [None]:
def collate_fn(examples, pad_token_id=None):
    """Collate dataset rows into one batch of model inputs.

    Args:
        examples: Sequence of rows, each with a ``pixel_values`` tensor and a
            ``labels`` list of token ids (all rows the same length).
        pad_token_id: Token id treated as padding. When omitted,
            ``tokenizer.pad_token_id`` from the notebook namespace is used;
            if that is ``None`` as well, the labels are left unmasked.

    Returns:
        Dict with ``pixel_values`` (rows stacked along a new leading
        dimension) and ``labels`` (a ``torch.long`` tensor in which repeated
        padding is replaced by -100).
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples], dtype=torch.long)

    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    if pad_token_id is not None:
        # Captions are padded to a fixed length, so unmasked labels would make
        # the loss mostly about predicting padding. -100 is the index the
        # cross-entropy loss ignores. The first pad token of each run is kept
        # as a target so the decoder still sees an end-of-caption signal when
        # the pad token doubles as EOS.
        is_pad = labels == pad_token_id
        follows_pad = torch.zeros_like(is_pad)
        follows_pad[:, 1:] = is_pad[:, :-1]
        labels = labels.masked_fill(is_pad & follows_pad, -100)

    return {
        "pixel_values": pixel_values,
        "labels": labels
    }


# Hyper-parameters for the fine-tuning run.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
train_argument = TrainingArguments(
    output_dir="vit-gpt2-image-chinese-captioning",
    # optimisation
    num_train_epochs=1,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    weight_decay=0.1,
    fp16=True,
    # batching
    per_device_train_batch_size=28,
    per_device_eval_batch_size=28,
    gradient_accumulation_steps=8,
    # evaluate, log and checkpoint on the same 30-step cadence
    evaluation_strategy="steps",
    eval_steps=30,
    logging_steps=30,
    save_steps=30,
    save_total_limit=4,
    # dataset columns are passed through to the collator untouched
    remove_unused_columns=False,
)

# Wire the model, both dataset splits and the custom collator into a Trainer.
trainer = Trainer(
    model=new_encoder_decoder_model,
    args=train_argument,
    data_collator=collate_fn,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)
trainer.train()

In [None]:
# Number of dimensions of one transformed training image after conversion to NumPy.
dataset['train'][0]['pixel_values'].numpy().ndim

In [None]:
# Same dimensionality check after an extra squeeze(0) on the tensor.
dataset['train'][0]['pixel_values'].squeeze(0).numpy().ndim