In [1]:
import os

from omegaconf import OmegaConf
from functools import partial
from PIL import Image
import torch

from open_flamingo import create_model_and_transforms 
from open_flamingo.train.any_res_data_utils import process_images

  from .autonotebook import tqdm as notebook_tqdm


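## Inference

Load the `xgenmm_v1` model and checkpoint, preprocess two example images and a prompt, and generate an answer.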

In [10]:
# Set model configs.
# The checkpoint path may be overridden with the XGENMM_CKPT environment
# variable; otherwise replace the placeholder below with a local path.
model_ckpt = os.environ.get("XGENMM_CKPT", "path/to/your/local/checkpoint.pt")
cfg = OmegaConf.create(dict(
    model_family='xgenmm_v1',
    lm_path='microsoft/Phi-3-mini-4k-instruct',
    vision_encoder_path='google/siglip-so400m-patch14-384',
    vision_encoder_pretrained='google',
    num_vision_tokens=128,
    image_aspect_ratio='anyres',
    anyres_patch_sampling=True,
    anyres_grids=[(1, 2), (2, 1), (2, 2), (3, 1), (1, 3)],
    ckpt_pth=model_ckpt,
))

# Extra options forwarded to create_model_and_transforms.
additional_kwargs = {
    "num_vision_tokens": cfg.num_vision_tokens,
    "image_aspect_ratio": cfg.image_aspect_ratio,
    "anyres_patch_sampling": cfg.anyres_patch_sampling,
}

# Initialize the model.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path=cfg.vision_encoder_path,
    clip_vision_encoder_pretrained=cfg.vision_encoder_pretrained,
    lang_model_path=cfg.lm_path,
    tokenizer_path=cfg.lm_path,
    model_family=cfg.model_family,
    **additional_kwargs)

# Load the checkpoint weights.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# map_location="cpu" keeps the checkpoint tensors off the GPU even if they were
# saved from GPU memory; the model is only moved to the GPU after loading.
ckpt = torch.load(cfg.ckpt_pth, map_location="cpu")["model_state_dict"]
model.load_state_dict(ckpt, strict=True)
# Drop the checkpoint copy so the memory it held can actually be released.
del ckpt
torch.cuda.empty_cache()
model = model.eval().cuda()

# Scale each configured anyres grid (m, n) by the model's base image size and
# store the resulting sizes on the model.
base_img_size = model.base_img_size
anyres_grids = [[base_img_size * m, base_img_size * n] for (m, n) in cfg.anyres_grids]
model.anyres_grids = anyres_grids

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.60it/s]


xgenmm_v1 model initialized with 3,931,031,619 trainable parameters
Vision encoder: 0 trainable parameters
Vision tokenizer: 109,901,568 trainable parameters
Language model: 3,821,130,051 trainable parameters
Vision encoder: 428,225,600 parameters
Vision tokenizer: 109,901,568 parameters
Language model: 3,821,130,051 parameters


In [11]:
# Preprocessing utils.

# Pre-bind the model config and image processor; callers supply only the image list.
image_proc = partial(process_images, model_cfg=cfg, image_processor=image_processor)

def apply_prompt_template(prompt, cfg):
    """Wrap a user prompt in the chat template of the configured language model.

    Args:
        prompt: User message text. Any ``<image>`` placeholders are passed
            through unchanged.
        cfg: Config object exposing ``lm_path`` via attribute access.

    Returns:
        The full prompt string: system message, the user turn containing
        ``prompt``, and an opened assistant turn for the model to complete.

    Raises:
        NotImplementedError: if ``cfg.lm_path`` does not contain ``'Phi-3'``;
            no other template is implemented.
    """
    if 'Phi-3' not in cfg.lm_path:
        raise NotImplementedError(
            f"No prompt template implemented for language model {cfg.lm_path!r}; "
            "only Phi-3 models are supported."
        )
    return (
        '<|system|>\nA chat between a curious user and an artificial intelligence assistant. '
        "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
        f'<|user|>\n{prompt}<|end|>\n<|assistant|>\n'
    )

In [12]:
# Prep image input.
image_path_1 = 'example_images/image-1.jpeg'
image_path_2 = 'example_images/image-2.jpeg'

# Open each file in a context manager so its handle is closed deterministically;
# convert() loads the pixels into a new image that stays valid after closing.
images = []
for image_path in (image_path_1, image_path_2):
    with Image.open(image_path) as raw_image:
        images.append(raw_image.convert('RGB'))
image_1, image_2 = images

# Both model inputs get an extra outer list -- presumably one entry per batch
# sample, here a single sample holding both images (TODO confirm against
# model.generate). Per image: its PIL (width, height) size and processed tensor(s).
image_size = [[img.size for img in images]]
vision_x = [[image_proc([img]) for img in images]]

In [13]:
# Prep language input: one <image> placeholder per image in the sample,
# wrapped in the model's chat template, then tokenized as a batch of one.
prompt = apply_prompt_template(
    "Look at this image <image> and this image <image>. What is in the second image?",
    cfg,
)
lang_x = tokenizer([prompt], return_tensors="pt")

In [14]:
# Run inference.
# Deterministic decoding settings: no sampling, a single beam.
kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1)

# Send the text tensors to the device the model weights live on (the model was
# moved with .cuda()), rather than hard-coding cuda:0.
device = next(model.parameters()).device

# Inference only: no autograd bookkeeping needed during generation.
with torch.no_grad():
    generated_ids = model.generate(
        vision_x=vision_x,
        lang_x=lang_x['input_ids'].to(device),
        image_size=image_size,
        attention_mask=lang_x['attention_mask'].to(device),
        **kwargs_default)

generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# For Phi-3, keep only the text before the first end-of-turn marker, if any remains.
if 'Phi-3' in cfg.lm_path:
    text = generated_text.split('<|end|>')[0]
else:
    text = generated_text

print(text)

You are not running the flash-attention implementation, expect numerical differences.


A black and white cat. 
