In [None]:
!pip install transformers datasets evaluate rouge_score

In [None]:
import os
import datasets
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer
# Turn off the Weights & Biases reporting integration for the Trainer runs below.
os.environ.update(WANDB_DISABLED="true")

In [None]:
import nltk
# Make sure the sentence-tokenizer data used by `nltk.sent_tokenize` is available.
# Older NLTK releases load the pickled "punkt" model. Recent releases load
# "punkt_tab" instead. Fetch whichever is missing so post-processing works with
# either version.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except (LookupError, OSError):
        nltk.download(_punkt_resource, quiet=True)

In [None]:
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor
import torch

# Pretrained checkpoints: ViT image encoder + GPT-2 text decoder.
image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2-large"
# text_decode_model = "gpt2"   # smaller decoder for quick experiments

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Image preprocessor matching the encoder checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)

# Text tokenizer matching the decoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)
# GPT-2 defines bos/eos tokens but neither a decoder-start nor a pad token,
# so padding reuses the eos token.
tokenizer.pad_token = tokenizer.eos_token


In [None]:
# Combine the pretrained encoder and decoder into one captioning model.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=image_encoder_model,
    decoder_pretrained_model_name_or_path=text_decode_model,
)

# Copy the tokenizer's special-token ids into the model config.
# GPT-2 has no dedicated decoder-start token, so decoding starts from BOS.
_special_token_ids = {
    "eos_token_id": tokenizer.eos_token_id,
    "decoder_start_token_id": tokenizer.bos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
for _config_key, _token_id in _special_token_ids.items():
    setattr(model.config, _config_key, _token_id)

# Persist model, image preprocessor and tokenizer side by side in one directory.
output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
import sys
if 'google.colab' in sys.modules:
    from google.colab import output
    output.enable_custom_widget_manager()

!pip install huggingface_hub

from huggingface_hub import notebook_login
notebook_login()

In [None]:
import datasets

# Load the captioned CelebA-HQ dataset from the Hub (image + caption_* columns,
# per the preprocessing below).
# `token=True` uses the stored Hub credentials. It replaces the deprecated
# `use_auth_token` argument, which recent `datasets` releases no longer accept.
ds = datasets.load_dataset("vlordier/mm-celeba-hq", token=True)
ds

In [None]:
# Reproducible 90% / 5% / 5% train / test / validation split.
# NOTE(review): this cell rebinds `ds`; re-running it would re-split the already
# reduced train split. Reload the dataset before re-running.
SPLIT_SEED = 42

# 90% train, 10% held out for test + validation.
train_testvalid = ds['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)

# Split the held-out 10% in half: one half test, one half validation.
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=SPLIT_SEED)

# Gather all splits into a single DatasetDict.
ds = datasets.DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train'],
})
ds

In [None]:
from PIL import Image
from torchvision.transforms import Compose, ColorJitter, RandomRotation, RandomResizedCrop, ToTensor
import random

# Random image augmentation, applied through `aug_img`.
# NOTE(review): ToTensor yields float tensors in [0, 1]. If these are fed to the
# ViT feature extractor, it may rescale them a second time. Verify before
# enabling augmentation in training.
_augment_steps = [
    RandomRotation(20),                                   # rotate within ±20 degrees
    RandomResizedCrop(size=224, scale=(0.8, 1.0)),        # crop 80-100% of the area, resize to 224
    ColorJitter(brightness=0.3, contrast=0.3, hue=0.2),
    ToTensor(),
]
augment = Compose(_augment_steps)

def aug_img(images):
    """Run every image through the random `augment` pipeline and return the tensors as a list."""
    return list(map(augment, images))

# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Tokenize captions into fixed-length label id sequences.

    Args:
        captions: caption string(s); a list of captions when called from a
            batched ``Dataset.map``.
        max_target_length: every sequence is padded *and truncated* to exactly
            this many tokens.

    Returns:
        The tokenizer's ``input_ids``: one list of token ids per caption.
    """
    # `truncation=True` is required. With `padding="max_length"` alone, captions
    # longer than `max_target_length` are not truncated. That yields ragged label
    # lengths, which `default_data_collator` cannot stack into a tensor.
    labels = tokenizer(captions,
                       padding="max_length",
                       truncation=True,
                       max_length=max_target_length).input_ids
    # NOTE(review): padding uses the pad token (set to eos for GPT-2), and pads are
    # not replaced with -100, so the loss is also computed on padding positions.
    return labels

# image preprocessing step
def feature_extraction_fn(images, check_image=False):
    """Convert images into the encoder's ``pixel_values`` array.

    Images that have a ``convert`` method (e.g. PIL images) are converted to RGB
    first, as ``predict`` does at inference time. This way grayscale/RGBA inputs
    get the 3 channels the encoder presumably expects.

    Args:
        images: a batch (list) of images, or a single image.
        check_image: accepted for API compatibility but currently unused. No
            per-image validation or filtering is done, and a bad image raises.

    Returns:
        NumPy array of pixel values produced by ``feature_extractor``.
    """
    if hasattr(images, "convert"):
        # A single image rather than a batch: wrap it so the loop below works.
        images = [images]
    rgb_images = [img.convert("RGB") if hasattr(img, "convert") else img
                  for img in images]
    encoder_inputs = feature_extractor(images=rgb_images, return_tensors="np")
    return encoder_inputs.pixel_values

def preprocess_fn(examples, max_target_length, check_image=True, caption_column='caption_1'):
    """Tokenize captions and extract image features for a batch of examples.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch dict with an ``'image'`` column and caption columns
            (``'caption_1'`` ... ``'caption_5'``).
        max_target_length: fixed length of the label sequences.
        check_image: forwarded to ``feature_extraction_fn``.
        caption_column: which caption column to train on (default ``'caption_1'``).

    Returns:
        dict with ``'labels'`` (token ids) and ``'pixel_values'`` (image features).
    """
    # NOTE(review): to train on augmented images, pass `aug_img(examples['image'])`
    # here. First check that its tensor output composes with the feature extractor.
    images = examples['image']
    captions = examples[caption_column]

    model_inputs = {}
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(images, check_image=check_image)
    return model_inputs

In [None]:
# Preprocess every split once up front. All original columns are removed, so only
# the model inputs returned by `preprocess_fn` ('labels', 'pixel_values') remain.
_preprocess_kwargs = {"max_target_length": 128}
processed_dataset = ds.map(
    preprocess_fn,
    batched=True,
    fn_kwargs=_preprocess_kwargs,
    remove_columns=ds['train'].column_names,
)


In [None]:
# Inspect the processed splits (expected columns: 'labels', 'pixel_values').
processed_dataset

In [None]:
import torch
from transformers import pipeline

# Abstractive long-T5 summarization pipeline (max 128 output tokens).
# NOTE(review): `summarizer` is not called anywhere in this notebook, yet building
# it downloads and loads a full model. Remove it or gate it if it stays unused.
_summarizer_device = 0 if torch.cuda.is_available() else -1   # first GPU, else CPU
summarizer = pipeline(
    task="summarization",
    model="pszemraj/long-t5-tglobal-base-16384-book-summary",
    device=_summarizer_device,
    max_length=128,
)
# Example usage:
# long_text = "Here is a lot of text I don't want to read. Replace me"
# result = summarizer(long_text)
# print(result[0]["summary_text"])

def collate_txt(examples, rng=None):
    """Join the five captions of one example, in random order, into one text.

    The result can, for instance, be passed to ``summarizer`` to merge the
    captions into a single description.

    Args:
        examples: a single (unbatched) example mapping ``'caption_1'`` ...
            ``'caption_5'`` to caption strings.
        rng: optional object with a ``shuffle`` method (e.g. ``random.Random(seed)``)
            for reproducible ordering; defaults to the global ``random`` module.

    Returns:
        The captions joined with single spaces.
    """
    caption_keys = ['caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5']
    shuffler = random if rng is None else rng
    shuffler.shuffle(caption_keys)
    return ' '.join(examples[key] for key in caption_keys)



In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

import torch

# Trainer configuration.
# - bf16 mixed precision only where the GPU supports it. An unconditional
#   `bf16=True` can make these arguments fail on CPU or on GPUs without bf16.
# - `generation_max_length` matches the 128-token label length. Without it,
#   evaluation generation falls back to the model's default generation length,
#   which can cut captions short and depress ROUGE.
# - `eval_strategy` is the current name of the deprecated `evaluation_strategy`.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,      # generate captions at eval time so compute_metrics can score them
    eval_strategy="epoch",
    per_device_train_batch_size=28,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    output_dir="./image-captioning-output",
    bf16=use_bf16,
    generation_max_length=128,
)

In [None]:
import evaluate
# ROUGE metric used by `compute_metrics` during evaluation.
metric = evaluate.load(path="rouge")

In [None]:
import numpy as np

# Whether `compute_metrics` maps -100 entries in the labels back to the pad id
# before decoding.
ignore_pad_token_for_loss = True


def postprocess_text(preds, labels):
    """Prepare decoded predictions and references for ROUGE scoring.

    Each string is stripped of surrounding whitespace and rewritten with one
    sentence per line, because rougeLsum splits sentences on newlines.

    Returns:
        ``(preds, labels)`` as two new lists of strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return ([_one_sentence_per_line(p) for p in preds],
            [_one_sentence_per_line(t) for t in labels])


def compute_metrics(eval_preds):
    """Compute ROUGE scores and mean generated length for ``Seq2SeqTrainer``.

    Args:
        eval_preds: ``(predictions, label_ids)`` from the trainer. Predictions are
            generated token ids (``predict_with_generate=True``), possibly wrapped
            in a tuple whose first element holds the ids.

    Returns:
        dict of ROUGE scores scaled to percentages (rounded to 4 decimals), plus
        ``gen_len``: the mean number of non-pad tokens per prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Generated sequences of different lengths may be padded with -100 when the
    # trainer concatenates eval batches (depends on the transformers version).
    # -100 is not a valid token id and makes `batch_decode` fail, so map it to
    # the pad id, as done for the labels below.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace and put one sentence per line (needed by rougeLsum).
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    # Count only real tokens; pad positions (including former -100s) are excluded.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

In [None]:
from transformers import default_data_collator

# Wire model, data, collator and metrics into a Seq2SeqTrainer.
# The feature extractor is handed over as `tokenizer`, presumably so that
# `trainer.save_model` also writes its preprocessing config.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['valid'],
    data_collator=default_data_collator,   # labels are padded to a fixed length by tokenization_fn
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune, then write the trained model (plus whatever preprocessing the trainer
# holds) and the GPT-2 tokenizer into the directory the inference cell reloads.
_final_dir = "./image-captioning-output"
trainer.train()
trainer.save_model(_final_dir)
tokenizer.save_pretrained(_final_dir)

In [None]:
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel 
from PIL import Image

import torch

# Reload the fine-tuned captioner. Image preprocessing, tokenizer and model all
# live in the directory written after training.
# Use CPU when no GPU is present instead of failing on `.to('cuda')`.
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder_checkpoint = "./image-captioning-output"
decoder_checkpoint = "./image-captioning-output"
model_checkpoint = "./image-captioning-output"
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)


def predict(image, max_length=64, num_beams=4):
    """Generate a caption for a single PIL image.

    Args:
        image: PIL image; converted to RGB before feature extraction.
        max_length: maximum length of the generated sequence, in tokens.
        num_beams: beam-search width passed to ``model.generate``.

    Returns:
        The decoded caption, with ``<|endoftext|>`` markers removed and cut at the
        first newline.
    """
    pixel_values = feature_extractor(image.convert('RGB'), return_tensors="pt").pixel_values.to(device)
    # Forward `num_beams`: it was previously accepted but never used.
    caption_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]
    # Drop GPT-2's end-of-text marker and keep only the first line of output.
    return tokenizer.decode(caption_ids).replace('<|endoftext|>', '').split('\n')[0]

In [None]:
# Caption one example image and show it alongside the prediction.
# NOTE(review): absolute, machine-specific path; point this at a local file.
image_path = '/volume/person/0-11d72820bf47948440f94f2d5f7d0f3cdd4c0073.jpg'
image = Image.open(image_path)
display(image)
txt = predict(image)
print(txt)