In [None]:
import os

import evaluate
import torch
import pandas as pd
from box import Box
from PIL import Image
from torch.utils.data import Dataset
from transformers import (GPT2Tokenizer, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, VisionEncoderDecoderModel,
                          ViTFeatureExtractor, default_data_collator,
                          get_linear_schedule_with_warmup)
from transformers.optimization import AdamW

In [None]:
class ImageCaptionDataset(Dataset):
    """Map-style dataset that pairs an image file with its tokenized caption.

    Every item is a dict with two entries:

    * ``pixel_values`` - output of ``feature_extractor`` for the image stored
      at ``<image_dir>/<filmId>.jpg``, with the leading batch axis removed so
      that a collator can stack items into a batch.
    * ``labels`` - caption token ids wrapped in the tokenizer's BOS/EOS tokens,
      with every padding position replaced by -100.

    Args:
        df: DataFrame providing the ``filmId`` and ``description`` columns.
            Rows are addressed by position, so any index is accepted.
        feature_extractor: callable invoked as
            ``feature_extractor(image, return_tensors='pt')`` whose result
            exposes ``pixel_values``.
        tokenizer: callable tokenizer exposing ``bos_token``, ``eos_token``
            and ``pad_token_id``.
        image_dir: directory that contains the ``<filmId>.jpg`` files.
        max_length: value forwarded to the tokenizer as ``max_length``.
        eval_mode: when False (default) captions are padded to ``max_length``
            and truncated; when True they are neither padded nor truncated.
    """

    def __init__(
        self,
        df,
        feature_extractor,
        tokenizer,
        image_dir,
        max_length,
        eval_mode=False,
    ):
        self.df = df
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.image_dir = image_dir
        self.max_length = max_length
        self.eval_mode = eval_mode

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``pixel_values`` / ``labels`` dict for row position ``idx``."""
        return {
            'pixel_values': self._get_pixel_values(idx),
            'labels': self._get_labels(idx),
        }

    def _get_pixel_values(self, idx):
        """Load the image of row ``idx`` and run it through the feature extractor."""
        image_filepath = os.path.join(
            self.image_dir,
            # .iloc: ``idx`` is a position, not an index label.
            '{0}.jpg'.format(self.df['filmId'].iloc[idx]),
        )
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_filepath) as image_file:
            image = image_file.convert('RGB')
        pixel_values = self.feature_extractor(
            image,
            return_tensors='pt',
        ).pixel_values
        # The extractor returns a batch of one image; drop that axis so the
        # collator does not stack items into a 5-D tensor.
        return pixel_values.squeeze(0)

    def _get_labels(self, idx):
        """Tokenize the caption of row ``idx`` and mask its padding positions."""
        caption = self._add_special_tokens(
            self.df['description'].iloc[idx],
        )
        padding = False
        truncation = False
        if not self.eval_mode:
            padding = 'max_length'
            truncation = True
        # return_tensors='pt' is required: without it ``input_ids`` is a plain
        # list, which has no squeeze() and cannot be masked as a tensor.
        labels = self.tokenizer(
            caption,
            padding=padding,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors='pt',
        ).input_ids.squeeze(0)
        return self._labels_mask(labels)

    def _add_special_tokens(self, text):
        """Surround ``text`` with the tokenizer's BOS and EOS tokens."""
        return '{0} {1} {2}'.format(
            self.tokenizer.bos_token,
            text,
            self.tokenizer.eos_token,
        )

    def _labels_mask(self, labels):
        """Return a copy of ``labels`` with pad token ids replaced by -100."""
        return labels.masked_fill(
            labels == self.tokenizer.pad_token_id,
            -100,
        )

In [None]:
def _restore_padding(token_ids, pad_token_id):
    """Return a copy of ``token_ids`` with every -100 replaced by ``pad_token_id``."""
    token_ids = token_ids.copy()
    token_ids[token_ids == -100] = pad_token_id
    return token_ids


def compute_metrics(batch):
    """Decode predictions and labels and score them with the global ``metric``.

    Args:
        batch: object exposing ``predictions`` and ``label_ids``. Both are
            assumed to be integer numpy arrays of token ids (i.e. generated
            sequences, not logits) - confirm the trainer is configured to
            produce them. A tuple of predictions is reduced to its first item.

    Returns:
        dict with the single key ``'SacreBLEU'`` mapped to the ``'score'``
        entry of the result returned by ``metric.compute``.
    """
    y_pred = batch.predictions
    if isinstance(y_pred, tuple):
        y_pred = y_pred[0]
    y_true = batch.label_ids

    # -100 marks masked positions and is not a valid token id, so it has to be
    # mapped back to the pad token before decoding.
    pad_token_id = tokenizer.pad_token_id
    y_pred = _restore_padding(y_pred, pad_token_id)
    y_true = _restore_padding(y_true, pad_token_id)

    predictions = tokenizer.batch_decode(
        y_pred, skip_special_tokens=True,
    )
    references = tokenizer.batch_decode(
        y_true, skip_special_tokens=True,
    )

    # One hypothesis string per sample, and one list of reference strings per
    # sample (a single reference each here).
    predictions = [text.strip() for text in predictions]
    references = [[text.strip()] for text in references]

    score = metric.compute(
        predictions=predictions, references=references,
    )
    return {'SacreBLEU': score['score']}

In [None]:
# First CUDA device when one is available, CPU otherwise.
device = torch.device(
    'cuda:0' if torch.cuda.is_available() else 'cpu'
)

# Settings consumed when the datasets are built.
data_config = Box(
    {
        'max_length': 64,
        'image_dir': 'data/img',
    },
)

# Checkpoint names for the encoder/decoder plus generation settings.
model_config = Box(
    {
        'pretrained_encoder_name': 'google/vit-base-patch16-224-in21k',
        'pretrained_decoder_name': 'sberbank-ai/rugpt3small_based_on_gpt2',
        'num_beams': 4,
        'early_stopping': True,
        'no_repeat_ngram_size': 3,
        'length_penalty': 2.0,
        'repetition_penalty': 2.0,
    },
)

# Keyword arguments for Seq2SeqTrainingArguments.
trainer_config = {
    'num_train_epochs': 20,
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 1,
    'output_dir': './output',
    'do_train': True,
    'do_eval': True,
    # Mixed precision is only requested when training runs on CUDA; on the
    # CPU fallback selected above a hard-coded True is rejected by
    # the training arguments.
    'fp16': device.type == 'cuda',
    'learning_rate': 1e-5,
    'load_best_model_at_end': False,
    'evaluation_strategy': 'epoch',
    'save_strategy': 'epoch',
    'save_total_limit': 2,
    'report_to': 'none',
}

seq2seq_trainer_config = Seq2SeqTrainingArguments(**trainer_config)

In [None]:
# Image preprocessor matching the pretrained encoder checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    model_config.pretrained_encoder_name,
)

# Tokenizer of the pretrained decoder, extended with explicit PAD/BOS/EOS
# tokens.
tokenizer = GPT2Tokenizer.from_pretrained(
    model_config.pretrained_decoder_name, use_fast=True,
)
tokens_to_add = {
    'pad_token': '[PAD]',
    'bos_token': '[BOS]',
    'eos_token': '[EOS]',
}
tokenizer.add_special_tokens(tokens_to_add)

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    model_config.pretrained_encoder_name,
    model_config.pretrained_decoder_name,
)

# Point the model at the special tokens registered on the tokenizer above.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# The tokenizer grew by the added special tokens, so the decoder embeddings
# are resized to the new vocabulary size before it is copied to the top-level
# config.
model.decoder.resize_token_embeddings(len(tokenizer))
model.config.vocab_size = model.config.decoder.vocab_size

# model_config defines no 'max_length' key, so reading it from there fails;
# the caption length limit lives in data_config and is reused for generation.
model.config.max_length = data_config.max_length
model.config.num_beams = model_config.num_beams
model.config.early_stopping = model_config.early_stopping
model.config.no_repeat_ngram_size = model_config.no_repeat_ngram_size
model.config.length_penalty = model_config.length_penalty
model.config.repetition_penalty = model_config.repetition_penalty

In [None]:
# Read the train and validation tables, in that order.
train_df, eval_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../input/poster-images/train.csv',
        '../input/poster-images/val.csv',
    )
)

# Both datasets share every setting except eval_mode, which is switched on
# for the validation split only.
train_dataset, eval_dataset = (
    ImageCaptionDataset(
        df=split_df,
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        image_dir=data_config.image_dir,
        max_length=data_config.max_length,
        eval_mode=is_eval_split,
    )
    for split_df, is_eval_split in ((train_df, False), (eval_df, True))
)

In [None]:
# Metric object that compute_metrics looks up at call time.
metric = evaluate.load('sacrebleu')

optimizer = AdamW(
    model.parameters(),
    lr=seq2seq_trainer_config.learning_rate,
)

# Number of optimizer steps per epoch. Unless the trainer is told to drop the
# last incomplete batch, that batch is a step too, so the division has to
# round up; rounding down makes the linear schedule reach zero before training
# ends.
# NOTE(review): assumes a single device and no gradient accumulation - confirm.
train_batch_size = seq2seq_trainer_config.per_device_train_batch_size
if seq2seq_trainer_config.dataloader_drop_last:
    steps_per_epoch = len(train_dataset) // train_batch_size
else:
    steps_per_epoch = (
        (len(train_dataset) + train_batch_size - 1) // train_batch_size
    )
num_training_steps = steps_per_epoch * seq2seq_trainer_config.num_train_epochs

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=seq2seq_trainer_config.warmup_steps,
    num_training_steps=num_training_steps,
)

# Pair in the (optimizer, scheduler) order expected by the trainer's
# ``optimizers`` argument.
optimizers = (optimizer, lr_scheduler)

In [None]:
# Assemble the trainer from the objects prepared in the previous cells.
# NOTE(review): compute_metrics is defined in this notebook but is not passed
# here, so the trainer will not call it during evaluation - confirm intent.
trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_trainer_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    # The feature extractor is handed over through the ``tokenizer`` argument
    # (presumably so it is stored next to the checkpoints - confirm).
    tokenizer=feature_extractor,
    optimizers=optimizers,
)

# Start fine-tuning.
trainer.train()