In [None]:
import math
import os

import torch
import pandas as pd
from box import Box
from PIL import Image
from torch.utils.data import Dataset
from transformers import (AutoTokenizer, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, VisionEncoderDecoderModel,
                          ViTFeatureExtractor, default_data_collator,
                          get_linear_schedule_with_warmup)
from transformers.optimization import AdamW

In [None]:
class ImageCaptionDataset(Dataset):
    """Map-style dataset that pairs an image with its tokenized caption.

    Every item is a dict with two entries:

    * ``pixel_values`` - the feature extractor's ``pixel_values`` for one
      image, with the leading batch dimension removed.
    * ``labels`` - 1-D tensor of caption token ids, padded/truncated to
      ``max_length``, where every position equal to the tokenizer's pad token
      id is replaced by ``-100``.

    Args:
        data: Table exposing ``article_id`` and ``detail_desc`` columns via
            ``data[column]`` (the notebook passes a pandas DataFrame).
        image_dir: Root image directory. Images are looked up at
            ``<image_dir>/<first 3 chars of article_id>/<article_id>.jpg``.
        feature_extractor: Callable that turns a PIL image into an object
            with a ``pixel_values`` tensor when called with
            ``return_tensors='pt'``.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.
        max_length: Length every caption is padded/truncated to.
    """

    def __init__(
        self,
        data,
        image_dir,
        feature_extractor,
        tokenizer,
        max_length,
    ):
        self.data = data
        self.image_dir = image_dir
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``pixel_values`` / ``labels`` pair for row ``idx``."""
        return {
            'pixel_values': self._get_pixel_values(idx),
            'labels': self._get_labels(idx),
        }

    def _get_field(self, column, idx):
        """Fetch ``column`` for the row at *position* ``idx``.

        A pandas Series is indexed positionally via ``.iloc`` so that frames
        with a non-default index (filtered, shuffled, split) still work;
        plain sequences fall back to ordinary ``[]`` indexing.
        """
        values = self.data[column]
        if hasattr(values, 'iloc'):
            return values.iloc[idx]
        return values[idx]

    def _get_pixel_values(self, idx):
        """Load the image for row ``idx`` and run it through the feature extractor."""
        # str() keeps the prefix slicing below from raising on integer ids.
        # NOTE(review): ids with leading zeros must still be loaded as strings
        # upstream, otherwise the directory prefix will be wrong.
        product_id = str(self._get_field('article_id', idx))
        image_filepath = os.path.join(
            self.image_dir,
            product_id[:3],
            '{0}.jpg'.format(product_id),
        )
        # The context manager closes the file handle deterministically;
        # convert() returns an independent in-memory copy.
        with Image.open(image_filepath) as image:
            rgb_image = image.convert('RGB')
        pixel_values = self.feature_extractor(
            rgb_image,
            return_tensors='pt',
        ).pixel_values
        # Drop only the batch dimension; a bare squeeze() would also remove
        # any other size-1 dimension.
        return pixel_values.squeeze(0)

    def _get_labels(self, idx):
        """Tokenize the caption for row ``idx`` into a 1-D label tensor."""
        caption = self._get_field('detail_desc', idx)
        input_ids = self.tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        ).input_ids
        # With return_tensors='pt' a single caption comes back with a leading
        # batch dimension of 1. It must be removed here, otherwise stacking
        # items in a collator yields (batch, 1, max_length) labels.
        return self._labels_mask(input_ids.squeeze(0))

    def _labels_mask(self, labels):
        """Replace pad-token positions with ``-100``.

        NOTE: every position equal to ``pad_token_id`` is masked, so if the
        pad token shares its id with another token (e.g. EOS), that token is
        masked as well.
        """
        return labels.masked_fill(
            labels == self.tokenizer.pad_token_id,
            -100,
        )

In [None]:
# Configuration cell: device selection plus data, model and trainer settings.
device = torch.device(
    'cuda:0' if torch.cuda.is_available() else 'cpu'
)

# Caption length / decoding settings and the image root directory.
data_training_args = Box(
    {
        'max_target_length': 64,
        'num_beams': 4,
        'images_dir': 'data/img',
    },
)

# Pretrained checkpoints for the encoder and decoder, plus generation knobs.
model_args = Box(
    {
        'encoder_model_name_or_path': 'google/vit-base-patch16-224-in21k',
        'decoder_model_name_or_path': 'sberbank-ai/rugpt3small_based_on_gpt2',
        'no_repeat_ngram_size': 3,
        'length_penalty': 2.0,
    },
)

training_args = {
    'num_train_epochs': 20,
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 1,
    'output_dir': './output',
    'do_train': True,
    'do_eval': True,
    # fp16 mixed precision is only usable on CUDA; requesting it on a CPU-only
    # machine makes the training arguments fail validation, so enable it only
    # when a GPU is present (same condition as `device` above).
    'fp16': torch.cuda.is_available(),
    'learning_rate': 1e-5,
    'load_best_model_at_end': False,
    'evaluation_strategy': 'epoch',
    'save_strategy': 'epoch',
    'save_total_limit': 2,
    'report_to': 'none'
}

seq2seq_training_args = Seq2SeqTrainingArguments(**training_args)

In [None]:
# Load the image preprocessor, the decoder tokenizer and the combined
# encoder-decoder model from the configured checkpoints.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    model_args.encoder_model_name_or_path
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.decoder_model_name_or_path, use_fast=True
)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path
)

# Resize first so the vocab size copied into the top-level config below
# reflects the final embedding matrix.
model.decoder.resize_token_embeddings(len(tokenizer))

# Special tokens used for teacher forcing and generation.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# Use the tokenizer's EOS token as the stop token. GPT-2 style tokenizers
# define no SEP token, so `sep_token_id` would be None and generation would
# never terminate on an end-of-sequence token.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Default generation settings.
model.config.max_length = data_training_args.max_target_length
model.config.no_repeat_ngram_size = model_args.no_repeat_ngram_size
model.config.length_penalty = model_args.length_penalty
model.config.num_beams = data_training_args.num_beams

In [None]:
# Read `article_id` as text: the dataset slices the id as a string to build
# the image path, and numeric parsing would also strip any leading zeros.
train_df = pd.read_csv(
    '../input/poster-images/train.csv',
    dtype={'article_id': str},
)
eval_df = pd.read_csv(
    '../input/poster-images/val.csv',
    dtype={'article_id': str},
)

# Both splits share the same preprocessing; only the underlying frame differs.
train_dataset = ImageCaptionDataset(
    data=train_df,
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    image_dir=data_training_args.images_dir,
    max_length=data_training_args.max_target_length,
)
eval_dataset = ImageCaptionDataset(
    data=eval_df,
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    image_dir=data_training_args.images_dir,
    max_length=data_training_args.max_target_length,
)

In [None]:
# Custom optimizer + linear warmup/decay schedule handed to the trainer.
optimizer = AdamW(
    model.parameters(),
    lr=seq2seq_training_args.learning_rate,
)

# Count optimizer updates the way Trainer does, so the schedule reaches zero
# exactly at the end of training rather than earlier:
#   * the last partial batch counts unless `dataloader_drop_last` is set,
#   * `train_batch_size` already includes the number of GPUs,
#   * one update happens every `gradient_accumulation_steps` batches.
samples_per_batch = seq2seq_training_args.train_batch_size
if seq2seq_training_args.dataloader_drop_last:
    batches_per_epoch = len(train_dataset) // samples_per_batch
else:
    batches_per_epoch = math.ceil(len(train_dataset) / samples_per_batch)

steps_per_epoch = max(
    batches_per_epoch // seq2seq_training_args.gradient_accumulation_steps,
    1,
)
num_training_steps = math.ceil(
    steps_per_epoch * seq2seq_training_args.num_train_epochs
)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=seq2seq_training_args.warmup_steps,
    num_training_steps=num_training_steps,
)

optimizers = (optimizer, lr_scheduler)

In [None]:
# Assemble the trainer from the pieces built in the previous cells.
trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_training_args,
    # Datasets already yield fixed-size tensors, so the default collator
    # only needs to stack them.
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    # NOTE(review): the feature extractor is passed in the `tokenizer` slot,
    # presumably so it is saved alongside checkpoints; confirm this is intended.
    tokenizer=feature_extractor,
    # Pre-built (optimizer, scheduler) pair; the trainer will not create its own.
    optimizers=optimizers,
)

# Run fine-tuning; the returned training summary is shown as the cell output.
trainer.train()