In [1]:
from PIL import Image
import requests
import torch

In [2]:
from factory import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="togethercomputer/RedPajama-INCITE-Base-3B-v1",
    tokenizer_path="togethercomputer/RedPajama-INCITE-Base-3B-v1",
    cross_attn_every_n_layers=2,
    use_local_files=True
)

print(model)

Flamingo(
  (vision_encoder): VisionTransformer(
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-23): 24 x ResidualAttentionBlock(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((1024,), eps=1e-05, elementwis

In [3]:
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-4B-vitl-rpj3b", "checkpoint.pt")
mismatched_keys = model.load_state_dict(torch.load(checkpoint_path), strict=False)

# print("Missing keys:", mismatched_keys.missing_keys)
# print("Unexpected keys:", mismatched_keys.unexpected_keys)

  mismatched_keys = model.load_state_dict(torch.load(checkpoint_path), strict=False)


In [4]:
from PIL import Image
import requests
import torch

"""
Step 1: Load images
"""
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)

demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)

query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 
        stream=True
    ).raw
)


"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 
 batch_size x num_media x num_frames x channels x height x width. 
 In this case batch_size = 1, num_media = 3, num_frames = 1,
 channels = 3, height = 224, width = 224.
"""
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)

"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
 We also expect an <|endofchunk|> special token to indicate the end of the text 
 portion associated with an image.
"""
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)

print(vision_x.shape)


"""
Step 4: Generate text
"""
generated_text = model.generate(
    media=vision_x,
    x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=30,
    num_beams=1,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))

torch.Size([1, 3, 1, 3, 224, 224])


Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.


torch.Size([64, 1024])
Generated text:  <image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of a bathroom sink..</image>An image of a bathroom sink..</image>An image of a bathroom sink..</image>An image of
