In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict, load_from_disk
import torchvision.transforms as transforms
import json
import os
import re
import datasets
import numpy as np
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTModel, ViTImageProcessor, ViTFeatureExtractor
import wandb
from PIL import Image
import torch

In [None]:
# Show whether PyTorch can see a CUDA device (the value is the cell's output).
torch.cuda.is_available()

In [None]:
from PIL import PngImagePlugin
# Raise Pillow's PNG text-chunk size limit to 1000 * 1024**2 bytes (1000 MiB).
# NOTE(review): presumably needed because some dataset PNGs carry very large
# metadata chunks — confirm against the data.
LARGE_ENOUGH_NUMBER = 1000
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

In [None]:
# Checkpoints for the two halves of the captioning model.
ENCODER_CHECKPOINT = "google/vit-base-patch16-224-in21k"
DECODER_CHECKPOINT = "gpt2"

# Text side: tokenizer matching the GPT-2 decoder.
tokenizer = AutoTokenizer.from_pretrained(DECODER_CHECKPOINT)
# Image side: ViTFeatureExtractor is a deprecated alias of ViTImageProcessor,
# so load the image processor class directly (same preprocessing, no warning).
feature_extractor = ViTImageProcessor.from_pretrained(ENCODER_CHECKPOINT)

# Standalone ViT encoder instance; not referenced by the cells visible here,
# kept in case other cells rely on the name.
ViT = ViTModel.from_pretrained(ENCODER_CHECKPOINT)
# Encoder-decoder captioning model: ViT encoder + GPT-2 decoder.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(ENCODER_CHECKPOINT, DECODER_CHECKPOINT)

In [None]:
# The tokenizer has no dedicated padding token configured here, so EOS doubles as PAD.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

# Expose the decoder's vocabulary size on the top-level config.
model.config.vocab_size = model.config.decoder.vocab_size

model.config.eos_token_id = tokenizer.eos_token_id
# GPT-2's tokenizer defines no CLS token (cls_token_id is None), so the decoder
# start token is BOS; the earlier assignment from cls_token_id was dead code.
model.config.decoder_start_token_id = tokenizer.bos_token_id

# Beam-search settings used when generating captions.
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# Take the pad id from the tokenizer rather than from generation_config.eos_token_id,
# which is not guaranteed to be populated at this point.
model.generation_config.pad_token_id = tokenizer.pad_token_id

# Display the resulting generation config as the cell output.
model.generation_config

In [None]:

def tokenization_fn(captions, max_target_length):
    """Tokenize captions into fixed-length sequences of token ids.

    Args:
        captions: Text handed to the module-level ``tokenizer``.
        max_target_length: Length every sequence is padded or truncated to.

    Returns:
        The ``input_ids`` produced by the tokenizer.
    """
    encoded = tokenizer(
        captions,
        truncation=True,
        max_length=max_target_length,
        padding="max_length",
    )
    return encoded.input_ids


def feature_extraction_fn(images):
    """Run the module-level ``feature_extractor`` on ``images``.

    Args:
        images: Images forwarded unchanged to the feature extractor.

    Returns:
        The ``pixel_values`` of the extractor output, requested as PyTorch tensors.
    """
    return feature_extractor(images=images, return_tensors="pt").pixel_values

def preprocess_fn(examples, max_target_length, check_image = True):
    """Build model inputs from a batch: tokenized captions plus image features.

    Args:
        examples: Mapping with a ``'raw_image'`` entry (passed to the feature
            extractor) and a ``'caption'`` entry (passed to the tokenizer).
        max_target_length: Fixed token length for the captions.
        check_image: Accepted for interface compatibility; not used by this function.

    Returns:
        Dict with ``'labels'``, ``'input_ids'`` (the same object as ``'labels'``)
        and ``'pixel_values'``.
    """
    images = examples['raw_image']
    texts = examples['caption']

    token_ids = tokenization_fn(texts, max_target_length)
    pixel_values = feature_extraction_fn(images)

    return {
        'labels': token_ids,
        'input_ids': token_ids,
        'pixel_values': pixel_values,
    }

In [None]:
# processed_dataset.save_to_disk("processed_dataset")
# Load the dataset from the local 'processed_dataset' directory; the
# commented-out line above is the matching save call.
processed_dataset = load_from_disk('processed_dataset')

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Training configuration: batches of 4 per device for both training and
# evaluation, evaluation once per epoch, and evaluation predictions produced
# with generate(). Outputs go to ./image-captioning-output.
training_args = Seq2SeqTrainingArguments(
    output_dir="./image-captioning-output",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",
    predict_with_generate=True,
)


In [None]:
import evaluate
# ROUGE metric object; compute_metrics calls metric.compute on decoded text.
metric = evaluate.load("rouge")

In [None]:
import nltk

# Sentence-tokenizer data used by nltk.sent_tokenize in compute_metrics.
# Newer NLTK releases look up "punkt_tab" instead of "punkt", so fetch both;
# a resource that is already present is not downloaded again.
for _resource in ("punkt", "punkt_tab"):
    nltk.download(_resource, quiet=True)
metric = evaluate.load("rouge")

def compute_metrics(eval_preds):
    """Compute ROUGE between generated captions and reference captions.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair. ``predictions`` holds
            generated token ids, optionally wrapped in a tuple whose first
            element is the ids; ``label_ids`` holds reference token ids.
            In both, ``-100`` marks positions to ignore.

    Returns:
        The result of ``metric.compute`` on the decoded texts.
    """
    preds, labels = eval_preds

    # Some models return a tuple; the generated ids are its first element.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    # If each sample carries several candidate sequences, score the first one.
    if preds.ndim == 3:
        preds = preds[:, 0]

    # -100 is a padding/ignore marker, not a token id; swap in the pad id so the
    # tokenizer can decode, then let skip_special_tokens drop it.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode whole sequences (previously only the first token id of each
    # prediction was decoded, so the metric never saw the caption text).
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    assert len(decoded_preds) == len(decoded_labels)

    # Put each sentence on its own line before scoring.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    return metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import default_data_collator

# Adapter hyper-parameters.
LORA_R = 256
LORA_ALPHA = 512
LORA_DROPOUT = 0.05

# One config shared by the LoRA / DoRA / rsLoRA experiments; DoRA is the
# variant currently switched on.
# NOTE(review): the target names look like they may only match the ViT
# attention projections and not GPT-2's layers — confirm with
# print_trainable_parameters() that the decoder gets adapters if that is intended.
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=["query", "value", "key", "gate", "up", "down", "out"],
    bias="none",
    use_dora=True,
    # use_rslora=True,
)

# NOTE(review): no quantized model load is visible in this notebook — confirm
# that the k-bit preparation step is wanted here.
model = prepare_model_for_kbit_training(model)
# Wrap the model with the adapters and report how many parameters will train.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
from transformers import default_data_collator, DataCollatorForSeq2Seq

# Trainer over the 'train' / 'val' splits of the processed dataset.
# The feature extractor is passed in the `tokenizer` slot.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['val'],
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)


In [None]:
# Start a Weights & Biases run in the 'LLM_Project_lora' project, then train.
wandb.init(project='LLM_Project_lora')

trainer.train()


In [None]:
# Save the trained model through the trainer to models/dora.
trainer.save_model("models/dora")