FLAVA
====

**FLAVA: A Foundational Language And Vision Alignment Model**

* Paper: https://arxiv.org/abs/2112.04482


![FLAVA model overview](../assets/flava-model-overview.png)

```bash
pip install torch torchvision
pip install transformers
pip install matplotlib
pip install supervision
```

In [1]:
from PIL import Image
import torch
from transformers import FlavaProcessor, FlavaModel

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# The processor and the model are both loaded from the same checkpoint.
checkpoint = "facebook/flava-full"
processor = FlavaProcessor.from_pretrained(checkpoint)
model = FlavaModel.from_pretrained(checkpoint)
# Place the model weights on the selected device.
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Joint image + text pass: the processor prepares both modalities in one call
# and the model returns image, text and multimodal embeddings.
image1 = Image.open("../samples/fruits-01.jpg")
image2 = Image.open("../samples/plants.jpg")

# One caption per image, in the same order as the images.
captions = [
    "a photo of fruits and vegetables",
    "a photo of indoor plants",
]
inputs = processor(
    text=captions,
    images=[image1, image2],
    return_tensors="pt",
    # Pad every caption to a fixed length of 77 tokens.
    padding="max_length",
    max_length=77,
)
# Move every input tensor to the device the model lives on.
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs)

image_embeddings = outputs.image_embeddings
text_embeddings = outputs.text_embeddings
multimodal_embeddings = outputs.multimodal_embeddings

print("image embeddings shape:", image_embeddings.shape)
print("text embeddings shape:", text_embeddings.shape)
print("multimodal embeddings shape:", multimodal_embeddings.shape)

image embeddings shape: torch.Size([2, 197, 768])
text embeddings shape: torch.Size([2, 77, 768])
multimodal embeddings shape: torch.Size([2, 275, 768])




In [3]:
## image-only
# FlavaFeatureExtractor is a deprecated alias of FlavaImageProcessor and is
# slated for removal from Transformers, so the image processor is used directly.
from transformers import FlavaImageProcessor
## text-only
from transformers import BertTokenizer

# The variable keeps its original name so that any later cell referring to
# `feature_extractor` keeps working.
feature_extractor = FlavaImageProcessor.from_pretrained(
    "facebook/flava-full"
)
# Image-only pass: only the images are preprocessed, no text is supplied.
inputs = feature_extractor(
    images=[image1, image2], return_tensors="pt"
)
# Move every input tensor to the device the model lives on.
inputs = {k: v.to(device) for k, v in inputs.items()}
# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs)
image_embeddings = outputs.image_embeddings


# Text-only pass: tokenize the captions on their own and run the model
# without any image input.
tokenizer = BertTokenizer.from_pretrained("facebook/flava-full")
captions = [
    "a photo of fruits and vegetables",
    "a photo of indoor plants",
]
inputs = tokenizer(
    captions,
    return_tensors="pt",
    # Pad every caption to a fixed length of 77 tokens.
    padding="max_length",
    max_length=77,
)
# Move every input tensor to the device the model lives on.
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs)
text_embeddings = outputs.text_embeddings

# Report the shapes obtained from the two single-modality passes.
print("image embeddings shape:", image_embeddings.shape)
print("text embeddings shape:", text_embeddings.shape)




image embeddings shape: torch.Size([2, 197, 768])
text embeddings shape: torch.Size([2, 77, 768])
