In [None]:
"""
load.py

Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
"""
import json
import os
from pathlib import Path
from typing import List, Optional, Union

from huggingface_hub import hf_hub_download

from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY

from prismatic.models.vlms import PrismaticVLM
from prismatic.overwatch import initialize_overwatch

import torch

# Initialize Overwatch =>> Wraps `logging.Logger`; every `overwatch.info(...)` status message below goes through it
overwatch = initialize_overwatch(__name__)


# === HF Hub Repository ===
# Passed as `repo_id` to `hf_hub_download` when the requested model is not a local directory
HF_HUB_REPO = "TRI-ML/prismatic-vlms"


# === Available Models ===
def available_models() -> List[str]:
    """Return every key of `MODEL_REGISTRY` as a list, in the registry's iteration order."""
    return [model_id for model_id in MODEL_REGISTRY.keys()]


def available_model_names() -> List[str]:
    """
    Return every name that can be looked up in `GLOBAL_REGISTRY`.

    These are the keys that `get_model_description` accepts. The previous implementation returned
    `GLOBAL_REGISTRY.items()`, i.e. a list of `(name, entry)` tuples, which contradicted both the function name
    and the `List[str]` annotation.

    :return: List of registry keys, in the registry's iteration order.
    """
    return list(GLOBAL_REGISTRY.keys())


def get_model_description(model_id_or_name: str) -> str:
    """
    Look up the `"description"` entry registered for `model_id_or_name`, pretty-print it, and return it.

    :param model_id_or_name: Key into `GLOBAL_REGISTRY`.
    :return: The stored description object, unchanged (it is also printed as indented JSON).
    :raises ValueError: If `model_id_or_name` is not a key of `GLOBAL_REGISTRY`.
    """
    if model_id_or_name in GLOBAL_REGISTRY:
        description = GLOBAL_REGISTRY[model_id_or_name]["description"]

        # Echo the description so interactive callers see it without an extra `print`
        print(json.dumps(description, indent=2))
        return description

    raise ValueError(f"Couldn't find `{model_id_or_name = }; check `prismatic.available_model_names()`")

# === Inference Load Settings ===
# `model_id_or_path` is either a key of `GLOBAL_REGISTRY` or a local run directory holding `config.json` and
# `checkpoints/latest-checkpoint.pt`. The assignment was previously left empty, which is a syntax error.
# NOTE(review): defaulting to the run directory whose checkpoint is inspected further down this notebook -- confirm.
model_id_or_path = "/localdisk/ssrivas9/prismatic-vlms/runs/qlora-stage-0-after-llava-test"
hf_token = None
# Previously `None,` -- the trailing comma produced the tuple `(None,)`, which was then handed to `hf_hub_download`
cache_dir = None

if os.path.isdir(model_id_or_path):
    run_dir = Path(model_id_or_path)
    overwatch.info(f"Loading from local path `{run_dir}`")

    # Get paths for `config.json` and pretrained checkpoint; bind the names *before* asserting so they are still
    # defined when assertions are stripped (`python -O`)
    config_json = run_dir / "config.json"
    checkpoint_pt = run_dir / "checkpoints" / "latest-checkpoint.pt"
    assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
    assert checkpoint_pt.exists(), "Missing checkpoint!"
else:
    if model_id_or_path not in GLOBAL_REGISTRY:
        raise ValueError(f"Couldn't find `{model_id_or_path = }; check `prismatic.available_model_names()`")

    # Resolve the registry entry to the folder name used inside the HF Hub repository, then fetch both files
    model_id = GLOBAL_REGISTRY[model_id_or_path]["model_id"]
    overwatch.info(f"Downloading `{model_id} from HF Hub")
    config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
    checkpoint_pt = hf_hub_download(
        repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
    )

overwatch.info("\n\n🚀 [bold green](LATEST) Loading Prismatic VLM for Inference 🚀[/] 🚀\n\n")
# Load Model Config from `config.json`; `cfg` is the full run config, `model_cfg` its "model" section
with open(config_json, "r") as f:
    cfg = json.load(f)
model_cfg = cfg["model"]

# = Load Individual Components necessary for Instantiating a VLM =
#   =>> Print Minimal Config (each fragment ends in "\n" so every field is logged on its own line)
overwatch.info(
    f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
    f"             Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
    f"             LLM Backbone    =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
    f"             Arch Specifier  =>> [bold]{model_cfg['arch_specifier']}[/]\n"
    f"             Mitigation Strategy      =>> [bold]{cfg.get('mitigation', None)}[/]\n"
    f"             LoRA      =>> [bold]rank: {cfg.get('lora_rank', None)}, alpha: {cfg.get('lora_alpha', None)}, lora_target_modules: {cfg.get('lora_target_modules', None)}[/]\n"
    f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
)

In [None]:

# --- Vision side: backbone plus its matching image transform, both selected by identifiers in the config ---
vision_id = model_cfg["vision_backbone_id"]
overwatch.info(f"Loading Vision Backbone [bold]{vision_id}[/]")
vision_backbone, image_transform = get_vision_backbone_and_transform(vision_id, model_cfg["image_resize_strategy"])

# Disabled experiment, kept for reference: peek into the checkpoint and flag the case where it carries projector
# weights but no LLM weights, so that the default LLM weights would have to be loaded from HF instead.
# model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
# if ("projector" in model_state_dict) and ("llm_backbone" not in model_state_dict):
#     overwatch.info(f"[bold blue]LLM Weights not found![/]", )
#     load_from_hf_anyway = True
# del model_state_dict

# --- Language side: LLM backbone + tokenizer, built in inference mode; the full run config is forwarded as `cfg` ---
llm_id = model_cfg["llm_backbone_id"]
overwatch.info(f"Loading Pretrained LLM [bold]{llm_id}[/] via HF Transformers")
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
    llm_id,
    llm_max_length=model_cfg.get("llm_max_length", 2048),
    hf_token=hf_token,
    inference_mode=True,
    cfg=cfg,
)

# --- Assemble the VLM from the checkpoint and the two backbones (clobbers HF's `from_pretrained` naming) ---
vlm_model_id = model_cfg["model_id"]
overwatch.info(f"Loading VLM [bold blue]{vlm_model_id}[/] from Checkpoint; Freezing Weights 🥶")
vlm = PrismaticVLM.from_pretrained(
    checkpoint_pt,
    vlm_model_id,
    vision_backbone,
    llm_backbone,
    arch_specifier=model_cfg["arch_specifier"],
)


In [1]:
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from transformers import LlamaForCausalLM

#hf_hub_path = '/localdisk/ssrivas9/prismatic-vlms/runs/lora-stage-0-after-llava-test/'
# No Hugging Face access token is supplied for the downloads below
hf_token = None

# Single source of truth for the base LLM so the config and the weights always refer to the same checkpoint
base_llm_id = "lmsys/vicuna-7b-v1.5"

# Plain model config. An earlier draft also considered passing `quantization_config` / `torch_dtype` here
# (only for the 'qlora' mitigation); that variant is not active.
llm_config = AutoConfig.from_pretrained(base_llm_id, token=hf_token)

# Load the weights with `load_in_4bit=True`. The same call was previously tried through `AutoModelForCausalLM`
# with identical arguments; the explicit `LlamaForCausalLM` class is what is used now.
llm_model = LlamaForCausalLM.from_pretrained(
    base_llm_id,
    config=llm_config,
    token=hf_token,
    load_in_4bit=True,
)

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.70s/it]


In [2]:

from peft import (
    IA3Config,
    LoraConfig,
    PrefixTuningConfig,  # Prefix-Tuning
    PromptEncoderConfig,  # P-Tuning
    PromptTuningConfig,  # Prompt Tuning
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)

# --- Tokenizer for the base LLM, capped at 2048 tokens ---
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", model_max_length=2048, token=hf_token)

# --- Register a dedicated padding token, point the model config at it, and resize the embedding table to the new
#     tokenizer length (padded up to a multiple of 64) ---
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
llm_model.config.pad_token_id = tokenizer.pad_token_id
llm_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)

# --- LoRA hyperparameters ---
lora_r = 16
lora_target_modules = "all-linear"  # alternative tried earlier: ["q_proj", "v_proj"]
lora_alpha = 8
lora_dropout = 0.05

# --- Prepare the k-bit model, then wrap it with a LoRA adapter built from the values above ---
llm_model = prepare_model_for_kbit_training(llm_model)
loraconfig = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)
llm_model = get_peft_model(llm_model, loraconfig)

In [3]:
import torch
# Checkpoint written by a training run. NOTE(review): this is an absolute, machine-specific path -- anyone else
# running the notebook has to edit it.
model_path = "/localdisk/ssrivas9/prismatic-vlms/runs/qlora-stage-0-after-llava-test/checkpoints/latest-checkpoint.pt"
# Alternative checkpoint from a different run (disabled):
#model_path = "/localdisk/ssrivas9/prismatic-vlms/runs/lora-stage-0-after-llava-vqav2/checkpoints/latest-checkpoint.pt"
# Deserialize the whole checkpoint onto the CPU. NOTE(review): `torch.load` unpickles the file, so only point this
# at checkpoints you trust.
model_state_dict = torch.load(model_path, map_location="cpu")

In [4]:
import torch

#model_path = "/localdisk/ssrivas9/prismatic-vlms/runs/qlora-stage-0-after-llava-test/checkpoints/latest-checkpoint.pt"

# The per-module weights live under the checkpoint's "model" key. Note that `model_state_dicts` is an alias of
# (not a copy of) `model_state_dict['model']`, so mutating one mutates the other.
model_state_dicts = model_state_dict['model']
# Show which module groups the checkpoint contains
model_state_dicts.keys()

dict_keys(['projector', 'llm_backbone'])

In [5]:
# List the parameter names stored for the LLM backbone (in the recorded output every key starts with `llm.`)
model_state_dicts['llm_backbone'].keys()

odict_keys(['llm.base_model.model.model.embed_tokens.weight', 'llm.base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight', 'llm.base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.absmax', 'llm.base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.quant_map', 'llm.base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight.quant_state.bitsandbytes__fp4', 'llm.base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight', 'llm.base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight', 'llm.base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight', 'llm.base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.absmax', 'llm.base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.quant_map', 'llm.base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight.quant_state.bitsandbytes__fp4', 'llm.base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight', 'llm.base_model.mo

In [6]:
# Re-key the LLM backbone weights so they line up with `llm_model`'s own parameter names: drop the leading `llm.`
# prefix and leave keys without that prefix untouched. `removeprefix` only strips a *leading* match, whereas the
# previous `str.replace('llm.', '')` would also have removed any later occurrence of `llm.` inside a key.
new_model_state_dict = {
    key.removeprefix("llm."): tensor for key, tensor in model_state_dicts["llm_backbone"].items()
}

In [7]:
# Copy the re-keyed checkpoint weights into the PEFT-wrapped model; the recorded output reports that all keys matched
llm_model.load_state_dict(new_model_state_dict)

<All keys matched successfully>

In [20]:
# For every top-level entry of the checkpoint, look for a module name `mkey` such that the entry's key starts with
# "<mkey>." and, if found, store the entry under `model_state_dicts[mkey]` with that prefix removed.
# NOTE(review): `model_state_dicts` was assigned from `model_state_dict['model']` in an earlier cell, so both loops
# walk the same dict. With the layout shown earlier in this notebook (top-level keys `projector` and
# `llm_backbone`), no key starts with "projector." or "llm_backbone.", so the assignment below would never run --
# confirm whether this cell is still needed.
for key, param in model_state_dict['model'].items():
    for mkey in model_state_dicts:
        # `mprefix` is bound by the walrus so the exact same prefix is stripped on the next line
        if key.startswith(mprefix := f"{mkey}."):
            model_state_dicts[mkey][key.removeprefix(mprefix)] = param

In [21]:
# Turn `llm.base_model...` keys into `base_model...` so they match `llm_model`'s parameter names. Only the leading
# `llm.` prefix is stripped (`removeprefix`); the previous `str.replace('llm.', '')` would also have removed any
# later occurrence of `llm.` inside a key.
new_model_state_dict = {
    key.removeprefix("llm."): tensor for key, tensor in model_state_dict["model"]["llm_backbone"].items()
}

In [22]:
# Load the re-keyed weights into the model again; the recorded output reports that all keys matched
llm_model.load_state_dict(new_model_state_dict)

<All keys matched successfully>

: 

In [None]:
# Bare expression as the last line of the cell: displays the model's repr for inspection
llm_model