In [None]:
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

In [None]:
# Root of the LLaVA-CC3M-Pretrain-595K dump (expects chat.json and images-dl/ inside).
# Set the LLAVA_DATA_DIR environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
data_dir = os.environ.get(
    "LLAVA_DATA_DIR",
    "/home/yujiaxi/DLModels/llavalearn/data/liuhaotian/LLaVA-CC3M-Pretrain-595K",
)
data_dir

In [None]:
class LlavaDataset(Dataset):
    """Map-style dataset over a LLaVA pretraining dump.

    The directory is expected to contain ``chat.json`` (a list of records) and an
    ``images-dl`` folder. Each item is a ``(human_input, gpt_output, image_path)``
    tuple built from the first two turns of a record's ``conversations`` list.
    """

    def __init__(self, dataset_dir: str) -> None:
        super().__init__()
        records, images_root = self.build_dataset(dataset_dir)
        self.chat_data = records
        self.image_dir = images_root

    def build_dataset(self, data_dir: str) -> tuple[List, Path]:
        """Load ``<data_dir>/chat.json`` as a list of dicts and locate ``<data_dir>/images-dl``."""
        root = Path(data_dir)
        images_root = root / "images-dl"
        records = pd.read_json(root / "chat.json").to_dict(orient="records")
        return records, images_root

    def __len__(self):
        return len(self.chat_data)

    def __getitem__(self, index):
        record = self.chat_data[index]
        conversations = record["conversations"]

        # Turn 0 is the human prompt, turn 1 the model's reply.
        human_input = conversations[0]["value"]
        gpt_output = conversations[1]["value"]
        image_path = self.image_dir.joinpath(record.get("image"))

        return human_input, gpt_output, image_path
    
# Build the dataset over the directory configured above.
test_llavatest = LlavaDataset(data_dir)
    

In [None]:
# Number of records loaded from chat.json.
len(test_llavatest)

In [None]:
# Inspect one raw sample: (human_input, gpt_output, image_path).
test_llavatest[19]

In [None]:
# Preview the image of sample 19; the image path is the last element of the sample tuple.
Image.open(test_llavatest[19][-1])

In [None]:
from transformers import AutoProcessor

# Checkpoint directory for the processor (and, later, the model weights).
# Set LLAVA_MODEL_DIR to use a different checkpoint instead of editing this absolute path.
llava_model_name_or_path = os.environ.get(
    "LLAVA_MODEL_DIR", "/home/yujiaxi/DLModels/llavalearn/show_model/model001"
)
llava_processor = AutoProcessor.from_pretrained(llava_model_name_or_path)

In [None]:
# Keep sample 19 as a (human_input, gpt_output, image_path) tuple for the build_qaimage check below.
test002 = test_llavatest[19]

In [None]:
from dataclasses import dataclass

@dataclass
class QaImageOutput:
    """Tensors produced for one (question, image, answer) sample."""

    # Prompt token ids from the processor; expected shape (1, q_len) — TODO confirm.
    q_input_ids: torch.Tensor
    # Preprocessed image tensor from the processor.
    pixel_values: torch.Tensor
    # Answer token ids from the tokenizer; expected shape (1, a_len) — TODO confirm.
    a_input_ids: torch.Tensor


def build_qaimage(processor: AutoProcessor, q_text: str, a_text: str, image_path: Path) -> "QaImageOutput":
    """Encode one (question, answer, image) triple for LLaVA training.

    The question is wrapped in the tokenizer's chat template (fixed system prompt,
    generation prompt appended) and encoded together with the image by the processor.
    The answer is tokenized separately so the collator can build labels for it alone.

    Args:
        processor: loaded multimodal processor exposing ``.tokenizer``.
        q_text: user/question text; presumably already contains the ``<image>``
            placeholder from the dataset — TODO confirm.
        a_text: target answer text.
        image_path: path to the image file (``str`` or ``Path``).

    Returns:
        QaImageOutput with the prompt ids, pixel values and answer ids.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": q_text},
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    with Image.open(image_path) as image_file:
        # copy() loads the pixels into a standalone image, so the file handle is
        # released here instead of whenever the object is garbage-collected.
        raw_image = image_file.copy()

    # Pass text/images by keyword: newer transformers versions declare
    # LlavaProcessor.__call__(images, text, ...), so positional order is ambiguous.
    inputs = processor(text=prompt, images=raw_image, return_tensors="pt")

    # The answer is appended after the prompt by the collator, so it must not get its
    # own BOS (or any other special token); the collator adds the trailing EOS itself.
    a_input_ids = processor.tokenizer(
        a_text,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=False,
    )["input_ids"]

    return QaImageOutput(
        q_input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        a_input_ids=a_input_ids,
    )
    

# test002 is a (human_input, gpt_output, image_path) tuple, matching build_qaimage's
# (q_text, a_text, image_path) parameters.
c = build_qaimage(llava_processor, *test002)
c

In [None]:
# Prompt token ids (chat template applied) produced by the processor.
c.q_input_ids

In [None]:
# Decode the ids actually produced above (prompt, then answer) rather than a
# hand-copied id list, so this sanity check stays valid if the sample,
# chat template or tokenizer changes.
llava_processor.decode(c.q_input_ids[0]), llava_processor.decode(c.a_input_ids[0])

In [None]:
# Shape of the preprocessed image tensor.
c.pixel_values.size()

In [None]:
from typing import Any


class TrainLlavaModelCollector:
    """Collate raw ``(question, answer, image_path)`` samples into a padded LLaVA batch.

    Every sample becomes ``[prompt ids][answer ids][eos]``. The labels copy that
    sequence but replace the prompt part with ``IGNORE_INDEX``, so only the answer
    and the trailing EOS are supervised. Sequences are left-padded to the longest
    sample in the batch.
    """

    def __init__(self, processor: "AutoProcessor", IGNORE_INDEX: int) -> None:
        """
        Args:
            processor: processor handed to ``build_qaimage``; its ``tokenizer``
                supplies the EOS and pad token ids.
            IGNORE_INDEX: label value for positions excluded from the loss (e.g. -100,
                the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``).
        """
        self.processor = processor
        self.ignore_index = IGNORE_INDEX

    def convert_one_piece(self,
                          q_input_ids: torch.Tensor,
                          a_input_ids: torch.Tensor):
        """Join prompt, answer and EOS into one sample and build its labels.

        Both inputs are expected to be ``(1, seq_len)`` id tensors.

        Returns:
            ``(input_ids, labels)``, each of shape ``(1, q_len + a_len + 1)``.
        """
        eos = torch.tensor([[self.processor.tokenizer.eos_token_id]])
        input_ids = torch.concat([q_input_ids, a_input_ids, eos], dim=1)
        labels = torch.concat([
            torch.full_like(q_input_ids, fill_value=self.ignore_index),
            a_input_ids,
            eos,
        ], dim=1)
        return input_ids, labels

    @staticmethod
    def _left_pad(sequences: List, max_len: int, fill_value: int) -> torch.Tensor:
        """Left-pad each ``(1, n)`` tensor with ``fill_value`` to ``max_len``; stack to ``(batch, max_len)``."""
        return torch.concat([
            torch.concat([
                torch.full((1, max_len - seq.shape[1]), fill_value, dtype=seq.dtype),
                seq,
            ], dim=1)
            for seq in sequences
        ], dim=0)

    def __call__(self, features: List) -> Any:
        """Build a batch dict with ``input_ids``, ``labels``, ``pixel_values`` and ``attention_mask``.

        Raises:
            ValueError: if ``features`` is empty.
        """
        if not features:
            raise ValueError("TrainLlavaModelCollector received an empty batch")

        input_ids_list = []
        labels_list = []
        pixel_values = []

        for feature in features:
            qaimage_output = build_qaimage(
                self.processor,
                feature[0],
                feature[1],
                feature[2]
            )
            temp_input_ids, temp_labels = self.convert_one_piece(
                qaimage_output.q_input_ids,
                qaimage_output.a_input_ids
            )
            input_ids_list.append(temp_input_ids)
            labels_list.append(temp_labels)
            pixel_values.append(qaimage_output.pixel_values)

        lengths = [ids.shape[1] for ids in input_ids_list]
        max_input_len = max(lengths)

        tokenizer = self.processor.tokenizer
        # Some tokenizers define no pad token; EOS is a safe filler because the
        # attention mask below is built from lengths, not from the pad id.
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )

        final_input_ids = self._left_pad(input_ids_list, max_input_len, pad_token_id)
        # Labels are padded with ignore_index (not the pad id) so padding never
        # contributes to the loss.
        final_labels = self._left_pad(labels_list, max_input_len, self.ignore_index)
        final_pixel_values = torch.concat(pixel_values, dim=0)

        # Mask from lengths instead of `ids == pad_token_id`: when pad and EOS share an
        # id, the comparison would also hide the appended EOS (and any real pad-valued token).
        positions = torch.arange(max_input_len).unsqueeze(0)
        first_real = torch.tensor([max_input_len - n for n in lengths]).unsqueeze(1)
        attention_mask = (positions >= first_real).to(final_input_ids.dtype)

        return {
            "input_ids": final_input_ids,
            "labels": final_labels,
            "pixel_values": final_pixel_values,
            "attention_mask": attention_mask
        }

# -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
tlmc = TrainLlavaModelCollector(llava_processor, IGNORE_INDEX=-100)

# Collate a single-sample batch as a smoke test of the collator.
d = tlmc([test_llavatest[13]])
d.keys()


In [None]:
# Only a cell's final bare expression is displayed, so show all three collated
# tensors together as one tuple: (input_ids, labels, attention_mask).
d["input_ids"], d["labels"], d["attention_mask"]

In [None]:
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the weights from the same checkpoint the processor was loaded from
# (`llava_model_name_or_path`, set in the processor cell). A separate relative path
# here only resolved when the kernel's working directory was the project root.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    llava_model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)

In [None]:
# Keys of the collated batch; they are passed as keyword arguments to the model below.
d.keys()

In [None]:
# Move every batch tensor onto the model's device, then run one forward pass.
d = {key: tensor.to(llava_model.device) for key, tensor in d.items()}

model_output = llava_model(**d)

In [None]:
# Loss returned by the forward pass (the batch passed to the model includes `labels`).
model_output.loss