# Fine-tuning notebook for the Llava using the Quilt-Pretrained dataset
This code will use the existing code repo as much as possible.
Its goal is for quickly prototyping the Llava-based chat system.
Reference
- [Llava doc](https://huggingface.co/docs/transformers/en/model_doc/llava)

In [1]:
from huggingface_hub import login

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
login("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")

In [3]:
import torch
from typing import Dict, List, Union
from transformers import BitsAndBytesConfig, AutoTokenizer
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
from llava.conversation import conv_llava_plain
from llava.train.train import ModelArguments



[2024-10-28 19:18:08,877] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [4]:
_VISION_TOWER = "wisdomik/QuiltNet-B-32"  # model_args.vision_tower
_MODEL_ID = "lmsys/vicuna-7b-v1.5" # model_args.model_name_or_path - Before pretraining
#_MODEL_ID = "wisdomik/Quilt-Llava-v1.5-7b" # Pretrained



In [5]:
_MODEL_ID

'lmsys/vicuna-7b-v1.5'

In [6]:
model_arguments = ModelArguments(
    model_name_or_path=_MODEL_ID, 
    version='plain', 
    freeze_backbone=False, 
    tune_mm_mlp_adapter=True, 
    vision_tower=_VISION_TOWER, 
    mm_vision_select_layer=-2, 
    pretrain_mm_mlp_adapter=None, 
    mm_projector_type='mlp2x_gelu', 
    mm_use_im_start_end=False, 
    mm_use_im_patch_token=False, 
    mm_vision_select_feature='patch',
)

In [7]:
# More details about this model can be found at https://huggingface.co/docs/transformers/en/model_doc/llava
model = LlavaLlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path=_MODEL_ID,
)

You are using a model of type llama to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards: 100%|██████████| 2/2 [00:29<00:00, 14.71s/it]


In [8]:
tokenizer = AutoTokenizer.from_pretrained(
    _MODEL_ID,
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token

In [9]:
conversation_lib = conv_llava_plain  # model_args.version == 'plain'
print(conv_llava_plain)

Conversation(system='', roles=('', ''), messages=(), offset=0, sep_style=<SeparatorStyle.PLAIN: 4>, sep='\n', sep2=None, version='Unknown', skip_next=False)


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
model.get_model().initialize_vision_modules(
    model_args=model_arguments,
    fsdp=None,  # Don't use FSD https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
)
model = model.to(device=device, dtype=torch.float16)
llamma_model = model.get_model()
lm_head = model.lm_head
vision_tower = llamma_model.get_vision_tower()
mm_projector = llamma_model.mm_projector
token_embedder = llamma_model.embed_tokens
image_processor = vision_tower.image_processor   # cropsize of (224, 224) for the "lmsys/vicuna-7b-v1.5" model IDS


In [12]:
# Tune the mm projector only. Disable the gradient on all other blocks
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
    p.requires_grad = True

In [13]:
model.initialize_vision_tokenizer(model_args=model_arguments, tokenizer=tokenizer)

# Dataset definition

In [14]:
_IMAGE_FOLDER = "/jupyter-users-home/tan-2enguyen/datasets/pathology/quilt/quilt_llava/Quilt-LLaVA-Pretrain/quilt_1m"

In [15]:
from datasets import load_dataset
import numpy as np
from pathlib import Path
from datasets.arrow_dataset import Dataset
from transformers import AutoProcessor
from typing import Dict, Tuple

In [16]:
quilt_llava_dataset = load_dataset("wisdomik/Quilt-LLaVA-Pretrain")

In [17]:
ds_with_split = quilt_llava_dataset['train'].train_test_split(test_size=0.1)
train_ds = ds_with_split['train']
test_ds = ds_with_split['test']
print(f"train dataset = {train_ds}")
print(f"test dataset = {test_ds}")

train dataset = Dataset({
    features: ['image', 'conversations', 'id'],
    num_rows: 650995
})
test dataset = Dataset({
    features: ['image', 'conversations', 'id'],
    num_rows: 72333
})


In [18]:
IMAGE_TOKEN_INDEX = -200
IGNORE_INDEX = -100 
DEFAULT_IMAGE_TOKEN = "<image>"

In [19]:
from PIL import Image
from llava.conversation import conv_llava_plain

In [20]:
def _tokenizize_prompt(prompt: str, tokenizer: AutoTokenizer, return_tensor: str) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Tokenize the prompt.
    
    This function will do the following:
        - Separate the prompt into multiple parts, tokenize them separately.
        - For each chunk, get rid of the BOS token if it is the first token, insert the IMAGE_TOKEN_INDEX in between the chunks
        - Convert to tensor from list of token ids if requested.
    Args:
        prompt: The string for the input prompt
        tokenizer: The tokenizer to tokenize the prompt.
        return_tensor (optional): If True, return the tensor of token ids. Otherwise, return the list of token ids. Defaults to True.
        
    Reference:
        tokenizer_image_token - https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/mm_utils.py#L43
    """
    tokenized_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    input_ids = [tokenizer.bos_token_id]  # Add the BOS token
    
    for chunk in tokenized_chunks[:-1]:
        if chunk[0] == tokenizer.bos_token_id:
            input_ids.extend(chunk[1:])
        else:
            input_ids.extend(chunk)
        input_ids.append(IMAGE_TOKEN_INDEX)
    
    # Add the last part of the prompt
    if tokenized_chunks[-1][0] == tokenizer.bos_token_id:
        input_ids.extend(tokenized_chunks[-1][1:])
    else:
        input_ids.extend(tokenized_chunks[-1])
        
    if return_tensor == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids


In [21]:
from torch.nn.utils.rnn import pad_sequence

In [22]:
class CaptionDataset(Dataset):
    """A dataset for caption generation from a chat dataset.
    
    Args:
        image_folder: The path to the image folder
        tokenizer: The tokenizer to tokenize the caption.
        dataset: The dataset containing the image file name and the caption from the chat dataset
        sequence_length: The length of the sequence to pad.
    
    References:
        llava.train.train import LazySupervisedDataset
    """
    def __init__(self, image_folder: str, tokenizer: AutoTokenizer, image_processor, dataset: Dataset, max_sequence_length: int = 256) -> None:
        self._image_folder = Path(image_folder)
        self._tokenizer = tokenizer
        self._dataset = dataset
        self._image_processor = image_processor
        self._max_sequence_length = max_sequence_length 
    
    def __len__(self) -> int:
        return len(self._dataset)
    
    def __getitem__(self, idx: Union[int, List[int]]) -> Dict[str, torch.Tensor]:
        """Get one text samples from the conversation.
        
        Returns:
        
        Reference:
            LazySupervisedDataset
        """    
        if isinstance(idx, int):
            idx = [idx]
            
        raw_samples = self._dataset[idx]
        images = torch.tensor([np.array(Image.open(str(Path(self._image_folder) / file_name)).convert('RGB')) for file_name in raw_samples['image']])
        
        formated_conversation = self._generate_prompt_from_conversation(raw_samples['conversations'])
        raw_input_ids = self._tokenize_prompt_to_input_ids(formated_conversation)
        
        padded_input_ids = pad_sequence(
            raw_input_ids, 
            batch_first=True, 
            padding_value=self._tokenizer.pad_token_id
        )  # Reference DataCollatorForSupervisedDataset.__call__()
        
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            self._compute_target_labels(input_ids=raw_input_ids, prompts=formated_conversation),
            batch_first=True,
            padding_value=IGNORE_INDEX
        )  #https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/train/train.py#L725
        
        return {
            'image': self._image_processor.preprocess(images,  return_tensors='pt')['pixel_values'].type(torch.float16),
            'input_ids': padded_input_ids.type(torch.int64),
            'target': padded_labels.type(torch.int64),
            'description': [conv[1]['value'] for conv in raw_samples['conversations']],
            'attention_mask': padded_input_ids.ne(self._tokenizer.pad_token_id),  # Ref: https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/train/train.py#L733
        }
    
    @staticmethod
    def _generate_prompt_from_conversation(conversations: List[List[Dict[str, str]]]) -> List[str]:
        """Generate the prompt from the conversation.
        
        Args:
            conversations: A list of dictionarly dictionary containing the conversation between the user and the bot. 
            One sample is for 1 minibatch sample
        
        Returns:
            A string containing the prompt with the image token following the format at
            https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/train/train.py#L568
        """
        return [DEFAULT_IMAGE_TOKEN + conversation[1]['value'] + conv_llava_plain.sep for conversation in conversations]
    
    def _tokenize_prompt_to_input_ids(self, prompts: List[str]) -> List[torch.Tensor]:
        """Tokenizes the prompts.
        
        Returns:
            A list of tensors containing the tokenized ids of all prompts in the training minibatch. Each prompt has 
            the format of <bos><IMAGE_TOKEN_INDEX><caption ids><\n> where IMAGE_TOKEN_INDEX = -200 is the image token index
            
        Reference:
            https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/mm_utils.py#L43
        """
        return [_tokenizize_prompt(prompt=prompt, tokenizer=self._tokenizer, return_tensor='pt') for prompt in prompts]
                
    def _compute_target_labels(self, input_ids: List[torch.Tensor], prompts: List[str]) -> List[torch.Tensor]:
        target_ids = []
        for input_id in input_ids:
            target_id = input_id.clone()
            target_id[:2] = IGNORE_INDEX
            target_ids.append(target_id)
        return target_ids


In [23]:
caption_train_ds = CaptionDataset(
    tokenizer=tokenizer,
    image_folder=_IMAGE_FOLDER,
    dataset=train_ds,
    image_processor=image_processor,
)

val_caption_ds = CaptionDataset(
    tokenizer=tokenizer,
    image_folder=_IMAGE_FOLDER,
    dataset=test_ds,
    image_processor=image_processor,
)


In [77]:
print(f"Dataset length = {len(caption_train_ds)}")
print(caption_train_ds[10].keys())
print(caption_train_ds[10]['target'])

Dataset length = 650995
dict_keys(['image', 'input_ids', 'target', 'description', 'attention_mask'])
tensor([[ -100,  -100,   450,   282,   682,   470, 16749, 10161,   526, 26718,
           310,  4457,   375,  9825, 29875,   542,  3637, 19263,   411, 20364,
           301,   962,   561,  4858,   459,   493, 29891, 29889,    13]])


# Dataloader

In [1]:
from torch.utils.data import DataLoader

In [2]:
train_dl = DataLoader(caption_train_ds, batch_size=3, shuffle=True, num_workers=2)

NameError: name 'caption_train_ds' is not defined

In [26]:
batch = next(iter(train_dl))

  images = torch.tensor([np.array(Image.open(str(Path(self._image_folder) / file_name)).convert('RGB')) for file_name in raw_samples['image']])
  images = torch.tensor([np.array(Image.open(str(Path(self._image_folder) / file_name)).convert('RGB')) for file_name in raw_samples['image']])


In [3]:
print(batch.keys())
print(batch['input_ids'].shape)
print(batch['image'].shape)
print(batch['target'].shape)
print(batch['attention_mask'].shape)
print(batch['input_ids'])

NameError: name 'batch' is not defined

# Process the image input

In [28]:
# Prepare the inputs for the VLM pretraining. LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal() at https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/model/llava_arch.py#L99
# Get image projection LlavaMetaForCausalLM.encode_images()
def _prepare_input_labels_for_multimodal(
    input_ids: torch.Tensor,
    image: torch.Tensor,
    labels: torch.Tensor,
    attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Prepare the input and labels for the multimodal model.
    
    Args:
        input_ids: The input ids of the text and image tokens. A tensor of shape [B, S] that stores the tokenized ids of the input prompt.
            Each prompt has the format of <bos><IMAGE_TOKEN_INDEX><caption ids><pad_token>...<pad_token> where 
            IMAGE_TOKEN_INDEX = -200 is the token to be replaced by the image.
            S = 1 + 1 + L + P where L is the length of the caption and P is the padding length. The first 2 ones are for the BOS and IMAGE_TOKEN_INDEX.
        image: An image tensor of shape [B, C, H, W] where H, W are the height and width of the image outpt of the image processor.    
        labels: The target labels tokenized vector of shape. A tensor of shape [B, S] that stores the tokenized ids of the input prompt.
            Each prompt has the format of <ignore token><ignore token><caption ids><ignore token>...<ignore token> where 
            ignore token = -100 is the ignore token.
        attention_mask: The attention mask of the input_ids of shape of shape [B, S] that stores the attention mask of the input prompt.
            Each item has a form of [TRUE][TRUE][...all TRUEs for captions ...][...all FALSES for padding...].
        
    Returns:
        An attention mask of shape [B, Sout] where Sout is the length of the output mask. Sout = 1 + Npatch + L + P = (Npatch - 1) + S. Each row has the form of
            [TRUE][...TRUE... for image features][...all TRUEs for captions ...][...all FALSES for padding...].
        Inputs image beddings of shape [B, Sout, D] where D is the dimension of the image embeddings.
        Target labels of shape [B, Sout] where each items contains the tokenized ids of the target labels. It has the shape of
            <ignore token>[...all ignore_tokens for image...]<caption ids><ignore token>...<ignore token>
            
    """
    image_features = mm_projector(vision_tower(image))   # encode images
    batch_size, num_image_feature_token =image_features.shape[:2]
    all_samples_input_embeds = []
    all_samples_labels = []
    for sample_idx, input_id in enumerate(input_ids):
        image_token_start = torch.where(input_id == IMAGE_TOKEN_INDEX)[0].item()
        cur_new_input_embeds = []
        cur_new_labels = []
        image_feature = image_features[sample_idx]
        cur_new_input_embeds.append(token_embedder(input_id[:image_token_start]))
        cur_new_input_embeds.append(image_feature)
        cur_input_ids = input_id[image_token_start+1:]
        cur_new_input_embeds.append(token_embedder(cur_input_ids))
        cur_new_input_embeds = [x.to(image_features.device) for x in cur_new_input_embeds]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
        all_samples_input_embeds.append(cur_new_input_embeds)
        if labels is not None:
            cur_label = labels[sample_idx]
            cur_new_labels.append(cur_label[:image_token_start])
            cur_new_labels.append(torch.full((num_image_feature_token,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
            cur_label = cur_label[image_token_start+1:]
            cur_new_labels.append(cur_label)
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            all_samples_labels.append(cur_new_labels)
        
    all_samples_input_embeds = torch.stack(all_samples_input_embeds, dim=0)
    all_samples_labels = torch.stack(all_samples_labels, dim=0)
    
    #-1 for the image token.
    attention_mask = torch.cat(
        (
            torch.full((batch_size, num_image_feature_token - 1), True, dtype=attention_mask.dtype, device=attention_mask.device), 
            attention_mask
        ), 
        dim=1
    )
    return attention_mask, all_samples_input_embeds, all_samples_labels

In [29]:
attention_mask, input_embeds, all_labels = _prepare_input_labels_for_multimodal(
    input_ids = batch['input_ids'].to(device=device),
    image = batch['image'].to(device=device),
    labels = batch['target'].to(device=device),
    attention_mask = batch['attention_mask'].to(device=device),
)
print(f"attention_mask.shape = {attention_mask.shape}")
print(f"input_embeds.shape = {input_embeds.shape}")
print(f"all_labels.shape = {all_labels.shape}")

attention_mask.shape = torch.Size([3, 88])
input_embeds.shape = torch.Size([3, 88, 4096])
all_labels.shape = torch.Size([3, 88])


# Pre-training forward pass
Reference: 
- [LlavaLlamaForCausalLM.forward()](https://github.com/thnguyn2/quilt-llava/blob/7e70fc39f792ac55de010eb37bff0a6d6f491c13/llava/model/language_model/llava_llama.py#L56)
- [LlamaModel.forward()]() - Transformer package.

In [30]:
llama_output = llamma_model(
    input_ids=None,
    attention_mask=attention_mask,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=input_embeds,
    use_cache=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True,
)
output_logits = lm_head(llama_output.last_hidden_state)


In [31]:
print(f"llama_output.keys() = {llama_output.keys()}")
print(f"llama_output.last_hidden_state.shape = {llama_output.last_hidden_state.shape}")
print(f"lm_out.shape = {output_logits.shape}")

llama_output.keys() = odict_keys(['last_hidden_state', 'past_key_values'])
llama_output.last_hidden_state.shape = torch.Size([3, 88, 4096])
lm_out.shape = torch.Size([3, 88, 32000])


# Compute the loss function

In [32]:
from torch.nn import CrossEntropyLoss

In [33]:
# The predicted logits is delayed by 1 token compared to the target labels i.e. out_logits[0] is the prediction for target[1], out_logits[1] is the prediction for target[2], etc.
SHIFT_AMOUNT = 1
logits = output_logits[:,:-SHIFT_AMOUNT,:].contiguous()
targets = all_labels[:,SHIFT_AMOUNT:].contiguous()
logits = logits.view(-1, logits.size(-1))
targets = targets.view(-1)
loss = CrossEntropyLoss()(logits, targets)

print(f"logits.shape = {logits.shape}")
print(f"targets.shape = {targets.shape}")
print(f"loss = {loss}")

logits.shape = torch.Size([261, 32000])
targets.shape = torch.Size([261])
loss = 5.765625
