In [5]:
import os
os.chdir('/fsx/wpq/github/metasummer2024/external/LLaVA') # jupyter lab moving ipynb does not change !pwd properly.
os.environ['WANDB_DIR'] = '/fsx/wpq/github/metasummer2024/cache'
os.makedirs(os.environ['WANDB_DIR'], exist_ok=True)
os.environ['WANDB_PROJECT'] = 'meta'

import re
import pathlib
import torch
import transformers

from llava import conversation as conversation_lib
from llava.model import *
from llava.train.llava_trainer import LLaVATrainer

from llava.train.train import (
    ModelArguments, DataArguments, TrainingArguments,
    maybe_zero_3, get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, get_mm_adapter_state_maybe_zero_3,
    find_all_linear_names, safe_save_model_for_hf_trainer,
    smart_tokenizer_and_embedding_resize,
    _tokenize_fn,
    _mask_targets,
    _add_speaker_and_signal,
    preprocess_multimodal,
    preprocess,
    LazySupervisedDataset,
    DataCollatorForSupervisedDataset,
    make_supervised_data_module,
)

In [6]:
attn_implementation = 'flash_attention_2'

# ## pretrain
# model_name_or_path = './results/baselines/lmsys/vicuna-7b-v1.5'
# data_path = './data/liuhaotian/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
# image_folder = './data/liuhaotian/LLaVA-Pretrain/images'
# vision_tower = './results/baselines/openai/clip-vit-large-patch14-336'
# mm_projector_type = 'mlp2x_gelu'
# train_size = 96
# output_dir = './results/pretrain/llava-v1.5-7b'


## finetune 
model_name_or_path = './results/baselines/lmsys/vicuna-7b-v1.5'
pretrain_mm_mlp_adapter = './results/pretrain/llava-v1.5-7b/mm_projector.bin'
data_path = './data/liuhaotian/LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
image_folder = './data/'
vision_tower = './results/baselines/openai/clip-vit-large-patch14-336'
mm_projector_type = 'mlp2x_gelu'
train_size = 96
output_dir = './results/sft/llava-v1.5-7b'


cmd = f"""
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path {model_name_or_path} \
    --version plain \
    --data_path {data_path} \
    --image_folder {image_folder} \
    --vision_tower {vision_tower} \
    --mm_projector_type {mm_projector_type} \
    --tune_mm_mlp_adapter True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir {output_dir} \
    {"--train_size " + str(train_size) if train_size else ""} \
    --num_train_epochs 1 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 24000 \
    --save_total_limit 1 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb
"""
import shlex
args = shlex.split(cmd)


parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args)
local_rank = training_args.local_rank
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

model_args, data_args, training_args


 \
	--deepspeed ./scripts/zero2.json \
	--model_name_or_path ./results/baselines/lmsys/vicuna-7b-v1.5 \
	--version plain \
	--data_path ./data/liuhaotian/LLaVA-Instruct-150K/llava_v1_5_mix665k.json \
	--image_folder ./data/ \
	--vision_tower ./results/baselines/openai/clip-vit-large-patch14-336 \
	--mm_projector_type mlp2x_gelu \
	--tune_mm_mlp_adapter True \
	--mm_vision_select_layer -2 \
	--mm_use_im_start_end False \
	--mm_use_im_patch_token False \
	--bf16 True \
	--output_dir ./results/sft/llava-v1.5-7b \
	--train_size 96 \
	--num_train_epochs 1 \
	--per_device_train_batch_size 32 \
	--per_device_eval_batch_size 4 \
	--gradient_accumulation_steps 1 \
	--evaluation_strategy "no" \
	--save_strategy "steps" \
	--save_steps 24000 \
	--save_total_limit 1 \
	--learning_rate 1e-3 \
	--weight_decay 0. \
	--warmup_ratio 0.03 \
	--lr_scheduler_type "cosine" \
	--logging_steps 1 \
	--tf32 True \
	--model_max_length 2048 \
	--gradient_checkpointing True \
	--dataloader_num_workers 4 \
	--laz

(ModelArguments(model_name_or_path='./results/baselines/lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='./results/baselines/openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch'),
 DataArguments(data_path='./data/liuhaotian/LLaVA-Instruct-150K/llava_v1_5_mix665k.json', lazy_preprocess=True, is_multimodal=False, image_folder='./data/', image_aspect_ratio='square', train_size=96),

In [None]:
import json

data_path = './data/liuhaotian/LLaVA-Instruct-150K/llava_v1_5_mix665k.json'


image_folder = './data/'

list_data_dict = json.load(open(data_path, "r"))
print(f'#examples: {len(list_data_dict)}')

In [None]:

example = list_data_dict[0]

def file_missing(example):
    if 'image' in example:
        image_file = example['image']
        image_path = os.path.join(image_folder, image_file)
        return not os.path.isfile(image_path)
    else:
        # text-only example, assume file is not missing.
        return False

list_data_dict_file_missing = list(filter(file_missing, list_data_dict))


In [None]:

bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
    from transformers import BitsAndBytesConfig
    bnb_model_from_pretrained_args.update(dict(
        device_map={"": training_args.device},
        load_in_4bit=training_args.bits == 4,
        load_in_8bit=training_args.bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            llm_int8_skip_modules=["mm_projector"],
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
        )
    ))

bnb_model_from_pretrained_args

In [None]:

if model_args.vision_tower is not None:
    if 'mpt' in model_args.model_name_or_path:
        config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
        config.attn_config['attn_impl'] = training_args.mpt_attn_impl
        model = LlavaMptForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args
        )
    else:
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
else:
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        attn_implementation=attn_implementation,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        **bnb_model_from_pretrained_args
    )
model.config.use_cache = False

model


In [None]:


if model_args.freeze_backbone:
    model.model.requires_grad_(False)

if training_args.bits in [4, 8]:
    from peft import prepare_model_for_kbit_training
    model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

if training_args.gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

if training_args.lora_enable:
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    rank0_print("Adding LoRA adapters...")
    model = get_peft_model(model, lora_config)



In [None]:

if 'mpt' in model_args.model_name_or_path:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right"
    )
else:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

if model_args.version == "v0":
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            tokenizer=tokenizer,
            model=model,
        )
elif model_args.version == "v0.5":
    tokenizer.pad_token = tokenizer.unk_token
else:
    tokenizer.pad_token = tokenizer.unk_token
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

tokenizer

In [None]:

if model_args.vision_tower is not None:
    model.get_model().initialize_vision_modules(
        model_args=model_args,
        fsdp=training_args.fsdp
    )
    
    vision_tower = model.get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

    data_args.image_processor = vision_tower.image_processor
    data_args.is_multimodal = True

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length

    # wpq: not sure why set `tune_mm_mlp_adapter` for `training_args`.
    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter

    # wpq:
    # pretrain: tune_mm_mlp_adapter=True, freeze LLM, optimize adapter only.
    # sft: tune_mm_mlp_adapter=False, optimize LLM & adapter jointly.
    #
    # 
    
    if model_args.tune_mm_mlp_adapter:
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True

    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
    if training_args.freeze_mm_mlp_adapter:
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = False

    if training_args.bits in [4, 8]:
        model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    training_args.use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
    

In [None]:


if training_args.bits in [4, 8]:
    from peft.tuners.lora import LoraLayer
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if training_args.bf16:
                module = module.to(torch.bfloat16)
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)


In [None]:

data_module = make_supervised_data_module(tokenizer=tokenizer,
                                          data_args=data_args)
data_module

In [None]:
data_module['train_dataset'].list_data_dict

In [None]:
ds = data_module['train_dataset']
data = ds[6]
import matplotlib.pyplot as plt
plt.imshow(data['image'].numpy().transpose((1,2,0)))
print(tokenizer.decode(data['input_ids'].tolist()[2:]))

In [None]:
trainer = LLaVATrainer(model=model,
                tokenizer=tokenizer,
                args=training_args,
                **data_module)

In [None]:

if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
trainer.save_state()

In [None]:

model.config.use_cache = True

if training_args.lora_enable:
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), training_args.lora_bias
    )
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
        model.named_parameters()
    )
    if training_args.local_rank == 0 or training_args.local_rank == -1:
        model.config.save_pretrained(training_args.output_dir)
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
        torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
else:
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)
