In [1]:
import pandas as pd
from datasets import Dataset, DatasetDict, load_from_disk
import torchvision.transforms as transforms
import json
import os
import re
import datasets
import numpy as np
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTModel, ViTImageProcessor, ViTFeatureExtractor
import wandb
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
raw_train_dataset = load_from_disk("processed_train")
raw_val_dataset = load_from_disk("processed_val")

In [3]:
train_dataset = raw_val_dataset.select(range(1000))
val_dataset = raw_val_dataset.select(range(1000,1125))

ds = DatasetDict({'train':train_dataset, 'val':val_dataset})
ds

DatasetDict({
    train: Dataset({
        features: ['image_id', 'id', 'caption', 'img_path', 'is_file', 'raw_image', 'img'],
        num_rows: 100
    })
    val: Dataset({
        features: ['image_id', 'id', 'caption', 'img_path', 'is_file', 'raw_image', 'img'],
        num_rows: 25
    })
})

In [4]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

ViT =  ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "gpt2")

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

In [5]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

model.generation_config.pad_token_id = model.generation_config.eos_token_id

model.generation_config

GenerationConfig {
  "bos_token_id": 50256,
  "eos_token_id": 50256,
  "pad_token_id": 50256
}

In [6]:
# text preprocessing step
def tokenization_fn(captions, max_target_length):
    labels = tokenizer(captions, 
                      padding="max_length", 
                      max_length=max_target_length, truncation=True).input_ids

    return labels

# image preprocessing step
def feature_extraction_fn(images):
    encoder_inputs = feature_extractor(images=images, return_tensors="pt")
    return encoder_inputs.pixel_values

def preprocess_fn(examples, max_target_length, check_image = True):
    """Run tokenization + image feature extraction"""
    image_paths = examples['raw_image']
    captions = examples['caption']
    
    model_inputs = {}
    # This contains image path column
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['input_ids'] = model_inputs['labels']
    model_inputs['pixel_values'] = feature_extraction_fn(image_paths)

    return model_inputs

In [7]:
processed_dataset = ds.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": 128},
    remove_columns=ds['train'].column_names
)

processed_dataset['val']

Dataset({
    features: ['labels', 'input_ids', 'pixel_values'],
    num_rows: 25
})

In [8]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
)


In [9]:
import evaluate
metric = evaluate.load("rouge")

In [10]:
import nltk

nltk.download("punkt", quiet=True)
metric = evaluate.load("rouge", "bleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds

    preds = [pred[0].tolist() for pred in preds]
    # decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    assert len(decoded_preds) == len(decoded_labels)

    # rougeLSum expects newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # result = {'dummy_metric': 0.7}
    return result

In [11]:
from transformers import default_data_collator, DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['val'],
    data_collator=default_data_collator,
)


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [12]:
len(processed_dataset['val']['labels'])

25

In [13]:

for batch in trainer.get_eval_dataloader(processed_dataset['val']):
    for label in batch['labels']:
        if len(label) != 128:
            print(len(label))

In [14]:
wandb.init(project='LLM_Project_few_shot')

trainer.train()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvaradnev[0m ([33mvarnevnlp2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


                                               
 33%|███▎      | 25/75 [01:40<02:33,  3.07s/it]

{'eval_loss': 0.5511487722396851, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 18.3417, 'eval_samples_per_second': 1.363, 'eval_steps_per_second': 0.382, 'epoch': 1.0}


                                               
 67%|██████▋   | 50/75 [03:30<01:22,  3.32s/it]

{'eval_loss': 0.42961522936820984, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 18.5315, 'eval_samples_per_second': 1.349, 'eval_steps_per_second': 0.378, 'epoch': 2.0}


                                               
100%|██████████| 75/75 [05:19<00:00,  4.26s/it]

{'eval_loss': 0.37983497977256775, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 22.8254, 'eval_samples_per_second': 1.095, 'eval_steps_per_second': 0.307, 'epoch': 3.0}
{'train_runtime': 319.7874, 'train_samples_per_second': 0.938, 'train_steps_per_second': 0.235, 'train_loss': 0.6545019022623698, 'epoch': 3.0}





TrainOutput(global_step=75, training_loss=0.6545019022623698, metrics={'train_runtime': 319.7874, 'train_samples_per_second': 0.938, 'train_steps_per_second': 0.235, 'train_loss': 0.6545019022623698, 'epoch': 3.0})

In [15]:
trainer.save_model("models/few_shot_1000")

Non-default generation parameters: {'max_length': 128, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
