LISA
====

**LISA: Reasoning Segmentation via Large Language Model**

 * Paper: https://arxiv.org/pdf/2308.00692

![LISA overview](../assets/lisa_overview.png)

```bash
git clone https://github.com/dvlab-research/LISA.git lisa_repo

# Environment used: Ubuntu 20.04, NVIDIA CUDA 12.2
# Modify requirements.txt before installing:
#   - comment out the line: --extra-index-url https://download.pytorch.org/whl/cu117
#   - change the torch pin to: torch==2.1.2
```

 * **Installation:**
```bash
conda create -n lisa python=3.10 -y
conda activate lisa

cd lisa_repo/
pip install -r requirements.txt
# note: flash-attn could not be installed on Ubuntu 20.04
```

In [1]:
import os
import sys

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

sys.path.append("lisa_repo")
from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.utils import (
    DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Inference settings. Values follow the argument-parser defaults, except
# LOAD_IN_8BIT, which is switched on here to reduce GPU memory use.

# Checkpoint and architecture
VERSION = "xinlai/LISA-13B-llama2-v1"
VISION_TOWER = "openai/clip-vit-large-patch14"
CONV_TYPE = "llava_v1"  # one of: "llava_v1", "llava_llama_2"
USE_MM_START_END = True
LORA_R = 8

# Sequence / image sizes
IMAGE_SIZE = 1_024
MODEL_MAX_LENGTH = 512

# Numerics and device placement
PRECISION = "bf16"
LOAD_IN_4BIT = False
LOAD_IN_8BIT = True  # quantize the LLM weights to 8 bit to save memory
LOCAL_RANK = 0

# Where visualizations are written
OUTPUT_PATH = "./vis_output"


In [3]:
# Create the tokenizer, the LISA model (optionally 4-/8-bit quantized via
# bitsandbytes) and the image pre-processors used for inference.

# --- Tokenizer ----------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    VERSION,
    cache_dir=None,
    model_max_length=MODEL_MAX_LENGTH,
    padding_side="right",
    use_fast=False,
)
# Padding reuses the <unk> token.
tokenizer.pad_token = tokenizer.unk_token
# Token id of "[SEG]"; handed to the model as `seg_token_idx`.
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

# --- Model weights ------------------------------------------------------------
# Compute dtype chosen by PRECISION ("bf16" / "fp16"; anything else -> fp32).
torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(PRECISION, torch.float32)
is_quantized = LOAD_IN_4BIT or LOAD_IN_8BIT

kwargs = {"torch_dtype": torch_dtype}
if LOAD_IN_4BIT:
    kwargs.update(
        {
            "torch_dtype": torch.half,
            # All quantization settings go through `quantization_config` only
            # (as in the 8-bit branch); a separate `load_in_4bit=True` kwarg is
            # redundant and is rejected by newer transformers releases.
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                # Keep the module named "visual_model" out of quantization.
                llm_int8_skip_modules=["visual_model"],
            ),
        }
    )
elif LOAD_IN_8BIT:
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        }
    )

model = LISAForCausalLM.from_pretrained(
    VERSION,
    low_cpu_mem_usage=True,
    vision_tower=VISION_TOWER,
    seg_token_idx=seg_token_idx,
    **kwargs,
)

# Keep the model's special-token ids in sync with the tokenizer.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Initialize the vision tower and cast it to the requested precision.
model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)

# Cast the model to the requested precision and move it to the GPU.
if PRECISION == "bf16":
    # NOTE(review): with quantized loading this is likely what emits the
    # accelerate warning "You shouldn't move a model that is dispatched using
    # accelerate hooks." seen in this cell's output.
    model = model.bfloat16().cuda()
elif PRECISION == "fp16" and not is_quantized:
    # Non-quantized fp16 previously had no branch here, so the LLM stayed
    # where from_pretrained left it while the vision tower is moved to the
    # GPU below. fp16 with 4/8-bit loading is left as loaded (torch_dtype=half).
    model = model.half().cuda()
elif PRECISION == "fp32":
    model = model.float().cuda()

vision_tower = model.get_model().get_vision_tower()
vision_tower.to(device=LOCAL_RANK)

# --- Image pre-processing -----------------------------------------------------
# CLIP processor for the vision-tower input, plus a resizer that scales images
# so their longest side equals IMAGE_SIZE.
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(IMAGE_SIZE)

model.eval();


You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|██████████| 3/3 [00:13<00:00,  4.47s/it]
You shouldn't move a model that is dispatched using accelerate hooks.


In [4]:
# Show the special image-token constants imported from LISA's utils module.
for token_name, token_value in (
    ("DEFAULT_IM_END_TOKEN", DEFAULT_IM_END_TOKEN),
    ("DEFAULT_IM_START_TOKEN", DEFAULT_IM_START_TOKEN),
    ("DEFAULT_IMAGE_TOKEN", DEFAULT_IMAGE_TOKEN),
    ("IMAGE_TOKEN_INDEX", IMAGE_TOKEN_INDEX),
):
    print(f"{token_name}:", token_value)

DEFAULT_IM_END_TOKEN: <im_end>
DEFAULT_IM_START_TOKEN: <im_start>
DEFAULT_IMAGE_TOKEN: <image>
IMAGE_TOKEN_INDEX: -200


In [5]:
# Per-channel normalization statistics on the 0-255 pixel scale (the ImageNet
# mean/std multiplied by 255), viewed as (C, 1, 1) so they broadcast over the
# spatial dims of a (..., C, H, W) image.
PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)


def preprocess(
    x: torch.Tensor,
    pixel_mean=PIXEL_MEAN,
    pixel_std=PIXEL_STD,
    img_size: int = 1024,
) -> torch.Tensor:
    """Normalize pixel values and zero-pad the image to ``img_size`` x ``img_size``.

    Args:
        x: image tensor whose last two dims are (H, W); with the default
            statistics the dim before them must be the 3 channels.
            Integer inputs (e.g. uint8) are promoted to float by the arithmetic.
        pixel_mean: value(s) subtracted from ``x``; a tensor or a scalar.
        pixel_std: value(s) ``x`` is divided by after mean subtraction.
        img_size: side length of the padded square output.

    Returns:
        The normalized tensor, padded with zeros on the bottom and right so
        that its last two dims are (img_size, img_size).

    Raises:
        ValueError: if H or W exceeds ``img_size``. Negative padding would
            otherwise silently crop the image. Resize first, e.g. with
            ``ResizeLongestSide``.
    """
    h, w = x.shape[-2:]
    if h > img_size or w > img_size:
        raise ValueError(
            f"image of size {h}x{w} exceeds img_size={img_size}; "
            "resize it before calling preprocess"
        )

    # Put the statistics on the input's device so CUDA inputs work too.
    mean = torch.as_tensor(pixel_mean, device=x.device)
    std = torch.as_tensor(pixel_std, device=x.device)
    x = (x - mean) / std

    # Pad only bottom/right, keeping the image anchored at the top-left corner.
    return F.pad(x, (0, img_size - w, 0, img_size - h))
