# Phi-3-V-3.8B-Inference-Example

In [None]:
# Clone the LLaVA-pp repository and make it the working directory for the following cells.
!git clone https://github.com/mbzuai-oryx/LLaVA-pp.git
%cd LLaVA-pp
# Initialise and fetch all (nested) git submodules of the repository.
# NOTE(review): presumably this is what populates the LLaVA/ directory the later cells copy into - confirm.
!git submodule update --init --recursive

In [None]:
# Sanity check: list the contents of the current working directory.
!ls

In [None]:
# Overlay the Phi-3-V versions of these files onto the LLaVA checkout.
# Each cp overwrites the destination file; the paths are relative to the LLaVA-pp directory.
! cp Phi-3-V/train.py LLaVA/llava/train/train.py
! cp Phi-3-V/llava_phi3.py LLaVA/llava/model/language_model/llava_phi3.py
! cp Phi-3-V/builder.py LLaVA/llava/model/builder.py
# The two *__init__.py sources are renamed to __init__.py at their destinations.
! cp Phi-3-V/model__init__.py LLaVA/llava/model/__init__.py
! cp Phi-3-V/main__init__.py LLaVA/llava/__init__.py
! cp Phi-3-V/conversation.py LLaVA/llava/conversation.py

In [None]:
%cd LLaVA
! pip install --upgrade pip
! pip install -e .
! pip install git+https://github.com/huggingface/transformers@a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3

! export PYTHONPATH="./:$PYTHONPATH"

In [None]:
# Set up Git LFS, then clone the model repository from the Hugging Face Hub
# into ./LLaVA-Phi-3-mini-4k-instruct (relative to the current working directory).
# NOTE(review): presumably LFS is needed so the weight files are downloaded as real files
# rather than pointer stubs - confirm against the model repository.
! git lfs install
! git clone https://huggingface.co/MBZUAI/LLaVA-Phi-3-mini-4k-instruct

In [None]:
from llava.utils import disable_torch_init

In [None]:
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
import requests
from PIL import Image
from io import BytesIO
import re
from llava.utils import disable_torch_init

In [None]:
def image_parser(args):
    """Split ``args.image_file`` on ``args.sep`` and return the resulting list of strings.

    Args:
        args: Any object exposing ``image_file`` (a string) and ``sep``
            (the separator string) as attributes.

    Returns:
        list[str]: The pieces of ``args.image_file`` between occurrences of ``args.sep``.
    """
    separator = args.sep
    return args.image_file.split(separator)


def load_image(image_file, timeout=30):
    """Load an image from a URL or a local path and return it converted to RGB.

    Args:
        image_file: Either an ``http://`` / ``https://`` URL or a local file path.
        timeout: Seconds to wait for the HTTP server before giving up
            (only used for URLs).

    Returns:
        The image converted to RGB mode via ``.convert("RGB")``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Match on the full scheme so a local path that merely begins with "http"
    # (e.g. "http_cache/img.jpg") is not mistaken for a URL.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file

    # convert() returns a new, fully loaded image, so the lazily opened
    # original can be closed as soon as the conversion is done.
    with Image.open(source) as raw_image:
        return raw_image.convert("RGB")


def load_images(image_files):
    """Load every entry of ``image_files`` with ``load_image``.

    Args:
        image_files: Iterable of image locations accepted by ``load_image``.

    Returns:
        list: The loaded images, in the same order as ``image_files``.
    """
    return [load_image(path) for path in image_files]

In [None]:
# Load the model: tokenizer, model, image processor and context length.

# NOTE(review): presumably skips PyTorch's default weight initialisation to speed up
# model construction - confirm in llava.utils.
disable_torch_init()

# Relative path to the checkpoint directory; it must exist under the current working directory.
model_path = "LLaVA-Phi-3-mini-4k-instruct"
model_name = get_model_name_from_path(model_path)
# The second positional argument is passed as None.
# NOTE(review): presumably this is the base-model argument (unused for a full checkpoint) -
# confirm against llava/model/builder.py.
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)

In [None]:
# Create the prompt in Phi3 format
qs = "Describe the image in detail"

# The original cell used the generic "llava_v0" template, which contradicts the
# Phi3 format named in the line above and the "<|end|>" marker stripped from the
# output after generation. Prefer the Phi-3 template.
# NOTE(review): assumes the patched llava/conversation.py registers it under
# "phi3_instruct" - confirm. If it does not, fall back to the original template
# and say so instead of raising a KeyError.
preferred_conv_mode = "phi3_instruct"
if preferred_conv_mode in conv_templates:
    conv_mode = preferred_conv_mode
else:
    conv_mode = "llava_v0"
    print(f"WARNING: conversation template '{preferred_conv_mode}' not found; falling back to '{conv_mode}'.")

# Pick the image token once: wrapped in start/end markers when the model was
# configured with mm_use_im_start_end, the bare image token otherwise.
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN

if IMAGE_PLACEHOLDER in qs:
    # The placeholder is a literal marker, so a plain string replacement is
    # used rather than interpreting it as a regular expression.
    qs = qs.replace(IMAGE_PLACEHOLDER, image_token)
else:
    # No placeholder in the question: put the image token on its own line first.
    qs = image_token + "\n" + qs

# Copy the template so the shared entry in conv_templates is not mutated.
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
# The second role's message is left as None so the prompt ends where the model should answer.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

print(prompt)

In [None]:
# Download and display the image

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

! wget http://images.cocodataset.org/val2017/000000281759.jpg

img = mpimg.imread('000000281759.jpg')
plt.imshow(img)
plt.axis('off')
plt.show()

In [None]:
# Preprocess the image and the prompt into model inputs for Phi3-V

image_name = "000000281759.jpg"
image_files = [image_name]
images = load_images(image_files)
# PIL's .size of each loaded image, passed to generate() as image_sizes.
image_sizes = [x.size for x in images]
images_tensor = process_images(
    images,
    image_processor,
    model.config
).to(model.device, dtype=torch.float16)

# Place the token ids on the same device as the model (and the image tensor
# above) rather than on whatever the current CUDA device happens to be.
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)  # add a leading dimension of size 1
    .to(model.device)
)

In [None]:
# Sample a response from the model and print it

# Decoding settings
max_new_tokens = 512
num_beams = 1
top_p = 0.7
temperature = 0.2

# Keyword arguments forwarded to model.generate(); sampling is switched on
# only when the temperature is positive.
generation_kwargs = dict(
    images=images_tensor,
    image_sizes=image_sizes,
    do_sample=temperature > 0,
    temperature=temperature,
    top_p=top_p,
    num_beams=num_beams,
    max_new_tokens=max_new_tokens,
    use_cache=True,
)

with torch.inference_mode():
    output_ids = model.generate(input_ids, **generation_kwargs)

# Decode the first sequence of the batch and drop any literal "<|end|>" marker left in the text.
decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
outputs = decoded[0].strip().replace("<|end|>", "").strip()
print(f"\n{outputs}\n")