In [1]:
import pandas as pd
import os
import datasets
import torch
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer, ViTImageProcessor, AutoImageProcessor
from PIL import Image
from datasets import Dataset, DatasetDict
from transformers import DataCollatorForSeq2Seq

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["WANDB_PROJECT"]="vit-gpt"
os.environ["TOKENIZERS_PARALLELISM"]="true"

In [3]:
df_train = pd.read_csv('/users/snranepuradewage/roco-dataset-master/data-master/train_captions.csv')
df_val = pd.read_csv('/users/snranepuradewage/roco-dataset-master/data-master/valid_captions.csv')
df_test = pd.read_csv('/users/snranepuradewage/roco-dataset-master/data-master/test_captions.csv')

In [4]:
df_train["ID"] = "/users/snranepuradewage/roco-dataset-master/data-master/train/" + df_train["ID"]+ ".jpg"
df_val["ID"] = "/users/snranepuradewage/roco-dataset-master/data-master/valid/" + df_val["ID"]+ ".jpg"
df_test["ID"] = "/users/snranepuradewage/roco-dataset-master/data-master/test/" + df_test["ID"]+ ".jpg"

In [5]:
df_train = pd.concat([df_train, df_val], ignore_index=True)

train_dataset = Dataset.from_pandas(df_train)
test_dataset = Dataset.from_pandas(df_test)

dataset_clef = DatasetDict({"train": train_dataset, "test": test_dataset})

In [6]:
model_name = "vit_biomedlm_caption_model"

model = VisionEncoderDecoderModel.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.to("cuda:0")

Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.71s/it]


VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_featur

In [7]:
# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions, 
                        padding="max_length", 
                        max_length=max_target_length, truncation=True).input_ids

    return labels

In [8]:
# image preprocessing step
def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images
    If `check_image` is `True`, the examples that fails during `Image.open()` will be caught and discarded.
    Otherwise, an exception will be thrown.
    """
    model_inputs = {}

    if check_image:
        images = []
        for image_path in image_paths:
            i_image = Image.open(image_path)
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")

            images.append(i_image)

    encoder_inputs = feature_extractor(images=images, return_tensors="pt")
    return encoder_inputs.pixel_values


In [9]:
def preprocess_fn(examples, max_target_length, check_image = True):
    """Run tokenization + image feature extraction"""
    image_paths = examples['ID']
    captions = examples['Caption']    
        
    model_inputs = {}
    # This contains image path column
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)

    return model_inputs


In [10]:
processed_dataset = dataset_clef.map(
        function=preprocess_fn,
        batched=True,
        fn_kwargs={"max_target_length": 128},
        num_proc=1
            )

Map: 100%|██████████| 69862/69862 [07:42<00:00, 151.00 examples/s]
Map: 100%|██████████| 9927/9927 [01:09<00:00, 142.58 examples/s]


In [11]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
        predict_with_generate=False,
        num_train_epochs=2,
        #eval_steps=1000,
        evaluation_strategy= "no",
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=2,
        output_dir="./image-captioning-output-vit-biomedlm-roco2",
        optim="adafactor",
        fp16=True,
        report_to="wandb"
            )



In [12]:
import evaluate
metric = evaluate.load("rouge")

In [13]:
import numpy as np
import nltk

try:
        nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
        nltk.download("punkt", quiet=True)

ignore_pad_token_for_loss = True


In [14]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

In [15]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds,  references=decoded_labels, use_stemmer=False)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [ np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

In [16]:
from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['test'],
    data_collator=default_data_collator
        )

  trainer = Seq2SeqTrainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
#trainer.train()

In [None]:
#trainer.train(resume_from_checkpoint="./image-captioning-output-vit-biomedlm-roco2/checkpoint-17000")

In [None]:
# trainer.save_model("./image-captioning-output-vit-biomedlm-roco2")
# tokenizer.save_pretrained("./image-captioning-output-vit-biomedlm-roco2")
# feature_extractor.save_pretrained("./image-captioning-output-vit-biomedlm-roco2")

Caption generation + evaluation

In [17]:
path = "./image-captioning-output-vit-biomedlm-roco2"
model = VisionEncoderDecoderModel.from_pretrained(path).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(path)
image_processor = AutoImageProcessor.from_pretrained(path)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.32it/s]


In [18]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = getattr(tokenizer, "bos_token_id", None) or tokenizer.eos_token_id
model.config.eos_token_id = tokenizer.eos_token_id

In [19]:
model.generation_config.pad_token_id = model.config.pad_token_id
model.generation_config.eos_token_id = model.config.eos_token_id
model.generation_config.max_new_tokens = 64
model.generation_config.num_beams = 3

In [20]:
test_image_dir = "/users/snranepuradewage/roco-dataset-master/data-master/test/"
df_test = pd.read_csv('/users/snranepuradewage/roco-dataset-master/data-master/test_captions.csv')

df_test["image_path"] = df_test["ID"].apply(lambda x: os.path.join(test_image_dir, f"{x}.jpg"))

images = [Image.open(p).convert("RGB") for p in df_test["image_path"]]
captions_true = df_test["Caption"].tolist()

In [21]:
# --- Generate captions ---
preds = []
for i, image in enumerate(images[:100]):   # try first 100 for speed
    inputs = image_processor(image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to("cuda")
    attention_mask = torch.ones(pixel_values.shape[:-1], dtype=torch.long).to("cuda")  # dummy mask

    output_ids = model.generate(
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        pad_token_id=model.config.pad_token_id,
        eos_token_id=model.config.eos_token_id,
        max_new_tokens=64,
        num_beams=3,
    )
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    preds.append(caption.strip())


In [22]:
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
bertscore = evaluate.load("bertscore")

In [23]:
result_rouge = rouge.compute(predictions=preds, references=captions_true[:len(preds)])
result_bleu = bleu.compute(predictions=preds, references=captions_true[:len(preds)])
result_bertscore = bertscore.compute(predictions=preds, references=captions_true[:len(preds)], lang="en")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
print(result_rouge,'\n')
print(result_bleu,'\n')
print(result_bertscore["f1"][0:10])

{'rouge1': 0.23708412701572446, 'rouge2': 0.08168826226621287, 'rougeL': 0.20921788797644297, 'rougeLsum': 0.20900434415085767} 

{'bleu': 0.03418002835753525, 'precisions': [0.24855491329479767, 0.0719207579672696, 0.02025202520252025, 0.003770028275212064], 'brevity_penalty': 1.0, 'length_ratio': 1.0288870008496176, 'translation_length': 2422, 'reference_length': 2354} 

[0.8490434288978577, 0.8765621781349182, 0.8977518081665039, 0.8716443181037903, 0.8715817332267761, 0.8917078375816345, 0.889004111289978, 0.8263463377952576, 0.8919671773910522, 0.8868426084518433]
