In [1]:
# Reload imported modules automatically so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Autosave interval for the notebook front end; the saved output reports
# "Autosaving every 2 seconds".
%autosave 2

Autosaving every 2 seconds


In [2]:
import os
import re
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
from torch import nn

import transformers
import tokenizers

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava import conversation as conversation_lib

from llava.mm_utils import tokenizer_image_token
# from llava.train.gcloud_rsync_callback import GCloudRsyncCallback
from llava.train.wandb_nan_alert_callback import NanInfAlertWandbCallback
from llava.model import LlavaLlamaForCausalLM, LlavaMptForCausalLM
# , \
#     LlavaMistralForCausalLM, LlavaCohereForCausalLM, LlavaMixtralForCausalLM

from PIL import Image

from packaging import version

from ezcolorlog import root_logger as logger


  from .autonotebook import tqdm as notebook_tqdm


06:04:49 [32m__init__.py:40 [I][0m → Notebook logger initialized.


In [3]:

@dataclass
class ModelArguments:
    """Model-side options: base checkpoint, conversation version, vision tower and projector settings."""

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "v0"
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False
    vision_tower: Optional[str] = None
    # -1 selects the last layer by default.
    mm_vision_select_layer: Optional[int] = -1
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = "linear"
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = "flat"
    mm_vision_select_feature: Optional[str] = "patch"


@dataclass
class DataArguments:
    """Data-side options: where the training data lives and how images are represented."""

    # Annotated Optional because the default is None (no path supplied),
    # consistent with the other None-default fields in this file.
    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'
    image_token_len: int = 576  # (336 // 14)**2
    image_position: int = 35  # depends on v1 conv

 
# Standalone variant of the training arguments. The original header,
# `class TrainingArguments(transformers.TrainingArguments):`, is left out here,
# so only the fields declared below exist on instances.
@dataclass
class TrainingArguments:
    """Training-side options: optimizer, sequence length, quantization, LoRA and GCS output."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False
    unfreeze_mm_vision_tower: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    # Quantization
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})

    # LoRA
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

    # Per-module learning rates and batching
    mm_projector_lr: Optional[float] = field(default=None)
    group_by_modality_length: bool = False
    mm_vision_tower_lr: Optional[float] = field(default=None)

    # GCSFS
    # gcp_project: can also be set via the GCP_PROJECT environment variable.
    gcp_project: Optional[str] = None
    # gcs_output_dir: gs://<bucket>/<prefix>
    gcs_output_dir: Optional[str] = None

In [4]:

# Command-line style overrides for this session. To try another encoder, swap
# the active `--vision_tower` / `--image_token_len` pair for one of the
# commented alternatives (keep the two values together).
cli_args = [
    "--model_name_or_path", "lmsys/vicuna-7b-v1.5",
    "--version", "v1",
    "--data_path", "/mnt/disks/storage/data/finetune_data/5565kL.jsonl",
    "--image_folder", "/mnt/disks/storage/data/finetune_data",
    # --- vision tower (active) ---
    "--vision_tower", "mae-vit-l-16",
    "--image_token_len", "196",
    # --- vision tower (alternatives) ---
    # "--vision_tower", "apple/DFN2B-CLIP-ViT-L-14",
    # "--image_token_len", "256",
    # "--vision_tower", "timm/ViT-SO400M-14-SigLIP-384",
    # "--image_token_len", "729",
    # "--vision_tower", "timm/ViT-SO400M-14-SigLIP",
    # "--image_token_len", "256",
    # "--vision_tower", "openai/clip-vit-large-patch14-336",
    # "--image_token_len", "576",
    "--mm_projector_type", "mlp2x_gelu",
    "--mm_vision_select_layer", "-2",
    "--mm_use_im_start_end", "False",
    "--mm_use_im_patch_token", "False",
    "--image_aspect_ratio", "pad",
    "--group_by_modality_length", "True",
    "--model_max_length", "2048",
    "--lazy_preprocess", "True",
    # --- disabled flags; none of these are fields of the dataclasses parsed here ---
    # "--bf16", "False",
    # "--output_dir", "./checkpoints/dummy",
    # "--num_train_epochs", "1",
    # "--per_device_train_batch_size", "16",
    # "--per_device_eval_batch_size", "4",
    # "--gradient_accumulation_steps", "1",
    # "--evaluation_strategy", "no",
    # "--save_strategy", "steps",
    # "--save_steps", "100000",
    # "--save_total_limit", "1",
    # "--learning_rate", "2e-5",
    # "--weight_decay", "0.",
    # "--warmup_ratio", "0.03",
    # "--lr_scheduler_type", "cosine",
    # "--logging_steps", "1",
    # "--tf32", "False",
    # "--gradient_checkpointing", "True",
    # "--dataloader_num_workers", "4",
]

dataclass_types = (ModelArguments, DataArguments, TrainingArguments)
parser = transformers.HfArgumentParser(dataclass_types)
model_args, data_args, training_args = parser.parse_args_into_dataclasses(cli_args)

# Display the parsed argument objects.
model_args, data_args, training_args

(ModelArguments(model_name_or_path='lmsys/vicuna-7b-v1.5', version='v1', freeze_backbone=False, tune_mm_mlp_adapter=False, vision_tower='mae-vit-l-16', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch'),
 DataArguments(data_path='/mnt/disks/storage/data/finetune_data/5565kL.jsonl', lazy_preprocess=True, is_multimodal=False, image_folder='/mnt/disks/storage/data/finetune_data', image_aspect_ratio='pad', image_token_len=196, image_position=35),
 TrainingArguments(cache_dir=None, optim='adamw_torch', remove_unused_columns=False, freeze_mm_mlp_adapter=False, unfreeze_mm_vision_tower=False, mpt_attn_impl='triton', model_max_length=2048, double_quant=True, quant_type='nf4', bits=16, lora_enable=False, lora_r=64, lora_alpha=16, lora_dropout=0.05, lora_weight_path='', lora_bias='none', mm_projector_lr=None, group_by_modality_length=True, mm

In [5]:

# Mirror the image-token settings from data_args onto model_args.
model_args.image_position = data_args.image_position
model_args.image_token_len = data_args.image_token_len

# Checkpoint identifier used for loading.
model_name = model_args.model_name_or_path

# No extra (e.g. quantization) keyword arguments are passed in this notebook.
bnb_model_from_pretrained_args = {}

logger.warning(f"Vision tower, loading LlavaLlamaForCausalLM: {model_args.model_name_or_path}")
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    cache_dir=training_args.cache_dir,
    do_sample=True,
    torch_dtype=None,
    **bnb_model_from_pretrained_args,
)

# Generation / cache flags set right after loading.
model.generation_config.do_sample = True
model.config.use_cache = False

06:05:01 [33m1205543874.py:10 [W][0m → Vision tower, loading LlavaLlamaForCausalLM: lmsys/vicuna-7b-v1.5


You are using a model of type llama to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.20it/s]


In [6]:
# Slow (non-fast) tokenizer for the same checkpoint, right-padded and capped at
# the configured maximum sequence length.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    use_fast=False,
    padding_side="right",
    model_max_length=training_args.model_max_length,
    cache_dir=training_args.cache_dir,
)

In [7]:
def log_rank0(log, stacklevel=2):
    """Log `log` at INFO level, attributed to the caller's source line.

    Despite the name, no rank check is performed here.

    Args:
        log: Message to emit.
        stacklevel: Forwarded to `logger.info`. The default of 2 reports the
            line that called `log_rank0` rather than this helper; wrappers
            pass a larger value so that their own caller is reported.
    """
    logger.info(log, stacklevel=stacklevel)


def print_rank0(*args):
    """Concatenate `str()` of every argument (no separator) and log it at INFO level.

    Args:
        *args: Values to stringify and join into a single message.
    """
    # stacklevel=3 skips both this wrapper and `log_rank0`, so the record points
    # at the code that called `print_rank0` instead of at this helper.
    log_rank0("".join(str(arg) for arg in args), stacklevel=3)

print_rank0("tokenizer id before operation is ", tokenizer.pad_token_id)

log_rank0(f"Model Conv Version: {model_args.version}")
log_rank0(f"Default conversation version: {conversation_lib.default_conversation.version}")

# Reuse the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

# Use the requested conversation template when it exists; otherwise warn and
# fall back to `vicuna_v1`.
if model_args.version not in conversation_lib.conv_templates:
    logger.warning(f"Conversation version {model_args.version} not found. Using default `vicuna_v1`")
    conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
else:
    conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]

log_rank0(f"Default conversation version: {conversation_lib.default_conversation.version}")

06:05:06 [32m1421403862.py:8 [I][0m → tokenizer id before operation is 0
06:05:06 [32m1421403862.py:12 [I][0m → Model Conv Version: v1
06:05:06 [32m1421403862.py:13 [I][0m → Default conversation version: v1
06:05:06 [32m1421403862.py:21 [I][0m → Default conversation version: v1


In [8]:
# Build the vision modules on the model, then copy the related settings onto
# `model.config`, `data_args` and `training_args`. Statement order is kept
# as-is: later lines read objects that the earlier calls create.
log_rank0("Initializing vision modules...")
model_args.unfreeze_mm_vision_tower = training_args.unfreeze_mm_vision_tower
model.get_model().initialize_vision_modules(
    model_args=model_args,
)
model.config.unfreeze_mm_vision_tower = training_args.unfreeze_mm_vision_tower
vision_tower = model.get_vision_tower()
# Dtype/device placement is disabled in this notebook: the standalone
# TrainingArguments dataclass defines neither `bf16` nor `device`.
# vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

# if not training_args.unfreeze_mm_vision_tower:
#     vision_tower.to(dtype=torch.bfloat16, device=training_args.device)
# else:
#     vision_tower.to(device=training_args.device)
# `image_processor` is attached dynamically; it is not a declared DataArguments field.
data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True

# Record data/tokenizer settings on the model config.
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length

# Adapter-only tuning: freeze every parameter, then re-enable the projector's.
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
    log_rank0("Tuning multimodal mlp adapter only...")
    model.requires_grad_(False)
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = True

# Optional freezing of the projector / unfreezing of the vision tower.
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
    log_rank0("Freezing multimodal mlp adapter...")
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False
if training_args.unfreeze_mm_vision_tower:
    for p in model.get_model().get_vision_tower().parameters():
        p.requires_grad = True

# Propagate the image-token flags and per-module learning rates, then let the
# model set up its vision tokenizer with these settings.
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
log_rank0("Vision modules initialized.")

06:05:06 [32m2108420027.py:1 [I][0m → Initializing vision modules...
06:05:06 [32mbuilder.py:41 [I][0m → Loading **MAE** Vision Tower: mae-vit-l-16
06:05:11 [32m_builder.py:186 [I][0m → Loading pretrained weights from Hugging Face hub (timm/vit_large_patch16_224.mae)
06:05:34 [32m_hub.py:180 [I][0m → [timm/vit_large_patch16_224.mae] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
06:05:34 [32m2108420027.py:43 [I][0m → Vision modules initialized.


In [9]:
from llava.train.train_fsdp import make_supervised_data_module

log_rank0("Configuring data module...")

# Fail fast when the configured image-token count disagrees with the number of
# patches reported by the vision tower; the message carries both values.
tower_num_patches = model.get_model().get_vision_tower().num_patches
assert tower_num_patches == data_args.image_token_len, (tower_num_patches, data_args.image_token_len)

data_module = make_supervised_data_module(data_args=data_args, tokenizer=tokenizer)


In [10]:
# Inspect the data module; the saved output shows a dict with 'train_dataset',
# 'eval_dataset' and 'data_collator' entries.
data_module

{'train_dataset': <llava.train.train_fsdp.LazySupervisedDataset at 0x7f03e41964a0>,
 'eval_dataset': None,
 'data_collator': DataCollatorForSupervisedDataset(tokenizer=LlamaTokenizer(name_or_path='lmsys/vicuna-7b-v1.5', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
 	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 }, image_token_len=196, image_position=35)}

In [11]:
# Pull the training dataset out of the data module and display it.
train_dataset = data_module["train_dataset"]
train_dataset

<llava.train.train_fsdp.LazySupervisedDataset at 0x7f03e41964a0>

In [12]:
# Fetch the first sample and list the keys it provides.
item = train_dataset[0]
item.keys()

dict_keys(['input_ids', 'labels', 'image'])

In [13]:
# Shapes of the token ids, labels and image tensor of the first sample.
tuple(item[key].shape for key in ("input_ids", "labels", "image"))

(torch.Size([136]), torch.Size([136]), torch.Size([3, 224, 224]))

In [14]:
# Gather the first ten samples to form a small batch for the collator.
data = [train_dataset[i] for i in range(10)]
len(data)

10

In [15]:
# Collate the sampled items into one batch and list the resulting fields.
data_collator = data_module["data_collator"]
batch = data_collator(data)
batch.keys()

dict_keys(['input_ids', 'labels', 'attention_mask', 'position_ids', 'images'])

In [16]:
# Print the shape of every field in the collated batch. In the saved output the
# sequence fields have length 2048, the value passed as --model_max_length.
for k, v in batch.items():
    print(k, v.shape)

input_ids torch.Size([10, 2048])
labels torch.Size([10, 2048])
attention_mask torch.Size([10, 2048])
position_ids torch.Size([10, 2048])
images torch.Size([10, 3, 224, 224])


In [17]:
# Unpack the collated batch into individually named variables for the cells below.
input_ids, labels, attention_mask, position_ids, images = (
    batch[key]
    for key in ("input_ids", "labels", "attention_mask", "position_ids", "images")
)

In [18]:
# Check the container type of the batched images (a single tensor in the saved output).
type(images)

torch.Tensor

In [19]:
# Concatenate the per-sample image tensors along their first dimension; in the
# saved output (10, 3, 224, 224) becomes (30, 224, 224).
concat_images = torch.cat(list(images), dim=0)
concat_images.shape

torch.Size([30, 224, 224])

In [20]:
# Alias the model as `self`.
# NOTE(review): presumably so method bodies can be pasted and run verbatim — confirm.
self = model
# Vision tower output for the batch; the saved output shows shape (10, 196, 1024).
image_features1 = self.get_model().get_vision_tower()(images)
print(image_features1.shape)
# Projector output; the saved output shows shape (10, 196, 4096).
image_features = self.get_model().mm_projector(image_features1)
print(image_features.shape)

torch.Size([10, 196, 1024])
torch.Size([10, 196, 4096])


In [21]:
# from llava.model.multimodal_encoder.siglip_encoder import SiglipVisionTower

# # SiglipVisionTower.__base__.__base__

# issubclass(model.get_model().get_vision_tower().__class__, SiglipVisionTower.__base__.__base__)

In [49]:
# NOTE(review): the saved execution count (49) and output (729) do not match the
# run above, which asserts num_patches == image_token_len == 196 — this looks
# like a stale result from a different vision tower; re-run after Restart & Run All.
model.model.vision_tower.num_patches

729

In [None]:
# Display the model configuration.
# NOTE(review): the saved output lists a SigLIP vision tower, while the arguments
# above select "mae-vit-l-16" — the output appears stale; re-run to refresh it.
model.config

LlavaConfig {
  "_name_or_path": "lmsys/vicuna-7b-v1.5",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "do_sample": true,
  "eos_token_id": 2,
  "freeze_mm_mlp_adapter": false,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "image_aspect_ratio": "pad",
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "mm_hidden_size": 1152,
  "mm_patch_merge_type": "flat",
  "mm_projector_lr": null,
  "mm_projector_type": "mlp2x_gelu",
  "mm_use_im_patch_token": false,
  "mm_use_im_start_end": false,
  "mm_vision_select_feature": "patch",
  "mm_vision_select_layer": -2,
  "mm_vision_tower": "timm/ViT-SO400M-14-SigLIP",
  "mm_vision_tower_lr": null,
  "model_type": "llava_llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000