In [1]:
%pip install datasets

Collecting datasets
  Downloading datasets-2.19.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=12.0.0 (from datasets)
  Downloading pyarrow-16.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (19 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting aiohttp (from datasets)
  Downloading aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.5 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->data

In [1]:
import re
from llava.model.builder import load_pretrained_model
from llava.mm_utils import (
    tokenizer_image_token,
    process_images,
    get_model_name_from_path)
from llava.eval.run_llava import eval_model
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_PLACEHOLDER
)
from llava.conversation import conv_templates

# Hugging Face Hub id of the LLaVA-1.6 (Mistral-7B) checkpoint.
model_path = "liuhaotian/llava-v1.6-mistral-7b"
model_name = get_model_name_from_path(model_path)

# load_4bit=True requests a 4-bit quantized load from LLaVA's builder.
loaded = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    load_4bit=True,
)
tokenizer, model, image_processor, context_len = loaded

  from .autonotebook import tqdm as notebook_tqdm
Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
def create_prompt(query, model, model_name, caption=None, conv_mode="mistral_instruct"):
    """Build a LLaVA conversation prompt that contains the image token.

    Args:
        query: user text. Every occurrence of ``IMAGE_PLACEHOLDER`` is replaced by the
            image token; if there is none, the image token is prepended on its own line.
        model: model whose ``config.mm_use_im_start_end`` decides whether the image token
            is wrapped in ``DEFAULT_IM_START_TOKEN`` / ``DEFAULT_IM_END_TOKEN``.
        model_name: unused; kept so existing call sites keep working.
        caption: assistant reply appended as the second turn. ``None`` is passed through to
            ``append_message`` unchanged (presumably leaving the assistant turn open for
            generation — confirm against llava's Conversation.get_prompt).
        conv_mode: key into ``conv_templates``; defaults to ``"mistral_instruct"``.

    Returns:
        The prompt string produced by ``conv.get_prompt()``.
    """
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    if IMAGE_PLACEHOLDER in query:
        # The placeholder is a literal string, so a plain replace is used rather than a regex.
        query = query.replace(IMAGE_PLACEHOLDER, image_token)
    else:
        query = image_token + "\n" + query

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], query)
    conv.append_message(conv.roles[1], caption)
    return conv.get_prompt()

In [None]:
import json
import torch
from PIL import Image
def process_prepare_img(data, image_processor, model, device, image_dir="./images", dtype=torch.bfloat16):
    """Load images and turn them into a model-ready pixel tensor.

    Args:
        data: iterable whose items are each one of:
            * a dict with an ``'image'`` key holding a file name relative to `image_dir`
              (the dataset-row format used by the training pipeline);
            * a PIL image, used as-is;
            * anything else, treated as a path and passed to ``Image.open``.
        image_processor: vision preprocessor returned by ``load_pretrained_model``.
        model: model whose ``config`` is handed to LLaVA's ``process_images``.
        device: device the returned tensor is moved to.
        image_dir: directory that dict entries are resolved against (default ``./images``).
        dtype: dtype of the returned tensor (default ``torch.bfloat16``).

    Returns:
        ``(images_tensor, image_sizes)``, where ``image_sizes`` holds each RGB image's
        PIL ``.size`` in input order.

    Raises:
        ValueError: if `data` yields no items.
    """
    images = []
    for item in data:
        if isinstance(item, Image.Image):
            images.append(item.convert("RGB"))
            continue
        # Different outer quotes: reusing the same quote inside an f-string is a
        # SyntaxError before Python 3.12.
        path = f"{image_dir}/{item['image']}" if isinstance(item, dict) else item
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))

    if not images:
        raise ValueError("process_prepare_img: no images given")

    images_tensor = process_images(images, image_processor, model.config).to(device, dtype=dtype)
    image_sizes = [img.size for img in images]
    return images_tensor, image_sizes

In [None]:
from torch.nn.utils.rnn import pad_sequence
def tokenize_and_create_labels(example_batch, image_processor, tokenizer, model, device, ignore_index=-100):
    """Turn a collated batch of (image, caption) rows into LLaVA training inputs.

    Args:
        example_batch: mapping of column name -> list with equal-length ``"image"``
            (file names, resolved under ``./images`` by ``process_prepare_img``) and
            ``"caption"`` columns, e.g. a slice of a ``datasets.Dataset`` or a batch from
            ``DataLoader``'s default collate.
        image_processor, tokenizer, model: objects returned by ``load_pretrained_model``.
        device: device all returned tensors are moved to.
        ignore_index: label value for positions excluded from the loss (prompt and padding).
            ``-100`` is ``torch.nn.CrossEntropyLoss``'s default ignore index.

    Returns:
        dict with right-padded ``input_ids``, ``attention_mask`` and ``labels`` (each of
        shape ``(batch, longest_sequence)``), plus ``images`` and ``image_sizes``.

    Raises:
        ValueError: if the batch is empty or its columns differ in length.
    """
    image_names = list(example_batch["image"])
    captions = list(example_batch["caption"])
    if not captions:
        raise ValueError("tokenize_and_create_labels: empty batch")
    if len(image_names) != len(captions):
        raise ValueError(f"got {len(image_names)} images but {len(captions)} captions")

    # Wrap each name as a row dict so process_prepare_img resolves it under ./images.
    image_tensor, image_sizes = process_prepare_img(
        [{"image": name} for name in image_names], image_processor, model, device
    )

    def _encode(prompt):
        # tokenizer_image_token puts IMAGE_TOKEN_INDEX where the image token appears in the prompt.
        return tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").reshape(-1)

    query = "What is this picture about?"
    # The caption-free prompt is identical for every row; its length is where supervision starts.
    prompt_len = _encode(create_prompt(query, model, model_name, None)).numel()
    sequences = [_encode(create_prompt(query, model, model_name, caption)) for caption in captions]
    lengths = [seq.numel() for seq in sequences]

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:  # some tokenizers define no pad token
        pad_token_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 0

    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
    # Mask from the true lengths rather than `!= pad_token_id`, which fails when the pad id
    # is missing or also occurs inside real text.
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)
    attention_mask = (positions < torch.tensor(lengths).unsqueeze(1)).long()

    labels = torch.full_like(input_ids, ignore_index)
    for row, seq_len in enumerate(lengths):
        # Supervise caption tokens only; the prompt and right padding keep ignore_index.
        labels[row, prompt_len:seq_len] = input_ids[row, prompt_len:seq_len]

    return {
        "input_ids": input_ids.to(device),
        "attention_mask": attention_mask.to(device),
        "images": image_tensor,
        "image_sizes": image_sizes,
        "labels": labels.to(device),
    }

In [None]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "gate_proj",
        # LoRA cannot wrap a container module, and LLaVA's MLP projector `mm_projector` is
        # presumably an nn.Sequential here. Target its Linear layers instead. Indices 0 and 2
        # assume the "mlp2x_gelu" layout (Linear, GELU, Linear) — TODO confirm via
        # model.get_model().mm_projector.
        "mm_projector.0", "mm_projector.2",
    ],
    lora_dropout=0.05,
    # peft only accepts "none", "all" or "lora_only"; None is not a valid value.
    bias="none",
)
# NOTE(review): bare "q_proj"/"k_proj"/"v_proj" may also match attention layers inside the
# CLIP vision tower, not just the language model — confirm that is intended.

model = get_peft_model(model, config)

In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset

SEED = 42
BATCH_SIZE = 4
N_TRAIN = 50_000  # leading rows used for training; the remainder is the eval split
IGNORE_INDEX = -100  # label value skipped by the causal-LM cross-entropy loss


class PreparedBatches(torch.utils.data.Dataset):
    """Map-style dataset whose items are complete, ready-to-use model input batches.

    Item ``i`` covers rows ``[i * batch_size, (i + 1) * batch_size)`` of `split` and is
    tokenized on access by ``tokenize_and_create_labels``. This avoids materialising
    every batch's image tensors on the GPU up front.
    """

    def __init__(self, split, batch_size):
        self.split = split
        self.batch_size = batch_size

    def __len__(self):
        return (len(self.split) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.batch_size
        # Slicing a datasets.Dataset yields {column: [values]}, the collated format expected.
        batch_rows = self.split[start:start + self.batch_size]
        return tokenize_and_create_labels(
            batch_rows, image_processor, tokenizer, model, device="cuda", ignore_index=IGNORE_INDEX
        )


# load_dataset("json", ...) returns a DatasetDict; the rows live in its "train" split.
raw_ds = load_dataset("json", data_files="final_data.json")["train"]
split_at = min(N_TRAIN, len(raw_ds))
train_ds = PreparedBatches(raw_ds.select(range(split_at)).shuffle(seed=SEED), BATCH_SIZE)
eval_ds = PreparedBatches(raw_ds.select(range(split_at, len(raw_ds))), BATCH_SIZE)


output_model_name = "llava-mistral-1.6-finetuned-med"

In [None]:
from transformers import TrainingArguments, Trainer

def take_prebatched(features):
    """Collator for datasets whose items are already complete batches.

    With a per-device batch size of 1, the DataLoader hands over a one-element list.
    Return that element unchanged instead of trying to re-collate it.
    """
    return features[0]


trainingArgs = TrainingArguments(
    output_dir=output_model_name,
    learning_rate=1e-4,
    bf16=True,
    # train_ds / eval_ds items are already batches (BATCH_SIZE rows each), so the Trainer
    # must take exactly one item per step; larger values would try to merge batches.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    dataloader_pin_memory=False,
    save_total_limit=2,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=0.2,
    eval_steps=0.2,
    logging_steps=1,
    num_train_epochs=3,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    # In transformers 4.x, None falls back to all installed integrations; "none" disables them.
    report_to="none",
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=trainingArgs,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=take_prebatched,
)

trainer.train()

In [None]:
from llava.utils import disable_torch_init

def eval_model(tokenizer, model, image_processor, context_len, image_file,query,model_name=model_name, temperature=1.0, max_new_tokens=512, num_beams = 1,sep=","):
    """Generate, print and return the model's answer to `query` about one or more images.

    NOTE: this definition shadows ``llava.eval.run_llava.eval_model`` imported earlier.

    Args:
        tokenizer, model, image_processor: objects from ``load_pretrained_model``
            (the model may be PEFT-wrapped).
        context_len: unused; kept for signature compatibility.
        image_file: a file path, a PIL image, or a list/tuple of either.
        query: user question; the image token is inserted by ``create_prompt``.
        model_name: forwarded to ``create_prompt``.
        temperature: sampling temperature; ``0`` disables sampling.
        max_new_tokens, num_beams: forwarded to ``model.generate``.
        sep: unused; kept for signature compatibility.

    Returns:
        The decoded, stripped answer text (it is also printed).
    """
    disable_torch_init()
    # The model may still be in train mode after trainer.train(); disable dropout
    # (including LoRA dropout) for generation.
    model.eval()
    prompt = create_prompt(query, model, model_name)

    # Prepare images here so that both paths and PIL images work.
    items = list(image_file) if isinstance(image_file, (list, tuple)) else [image_file]
    images = []
    for item in items:
        if isinstance(item, Image.Image):
            images.append(item.convert("RGB"))
        else:
            with Image.open(item) as img:  # close the file once converted
                images.append(img.convert("RGB"))
    images_tensor = process_images(images, image_processor, model.config).to(
        model.device, dtype=torch.bfloat16
    )
    image_sizes = [img.size for img in images]

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=temperature != 0,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)
    return outputs

In [None]:
# Load the test image and close the file handle once it has been converted to RGB.
with Image.open('./hand_image.jpeg') as img:
    raw_image = img.convert('RGB')
prompt = "What is this image about?"

# Trailing ';' stops the notebook from echoing a returned value after it is printed.
eval_model(
    tokenizer,
    model,
    image_processor,
    context_len,
    raw_image,
    prompt,
);