*Note: Please refer to "vit(gemma)_from_head.ipynb" notebooks for introdution to ViT and Gemma models*

*In this notebook we will just combine the two*

In [88]:
import sys
sys.path.append('../src')
from gemma import Gemma
from vit import ViT

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load reference HF model

In [108]:
model_id = "google/paligemma-3b-mix-224"
pg_model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

Loading checkpoint shards:  67%|██████▋   | 2/3 [00:17<00:09,  9.27s/it]

In [None]:
prompt = "What is on the flower?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors="pt")
output = pg_model.generate(**inputs, max_new_tokens=20)

print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



bee


In [15]:
# original preprocessor tokenizes the prompt for Gemma
# and preprocesses the image for ViT, and then combines the two in the dictionary
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'pixel_values'])

In [51]:
text_weights = pg_model.language_model.state_dict()
video_weights = pg_model.vision_tower.vision_model.state_dict()
multi_modal_projector_weights = pg_model.multi_modal_projector.state_dict()
config = pg_model.config
pg_text_config = config.text_config
pg_vision_config = config.vision_config

# Load custom Gemma and ViT models

In [82]:
gemma = Gemma(
    dim = pg_text_config.hidden_size,
    n_layers = pg_text_config.num_hidden_layers,
    n_heads = pg_text_config.num_attention_heads,
    num_key_value_heads = pg_text_config.num_key_value_heads,
    fc_intermediate_size = pg_text_config.intermediate_size,
    vocab_size = pg_text_config.vocab_size,
    rms_norm_eps = pg_text_config.rms_norm_eps,
    max_position_embeddings = pg_text_config.max_position_embeddings,
    with_embedding = False,
)
gemma.load_hf_weights(text_weights)

In [30]:
x = torch.tensor([[  5706, 125942, 151223, 139977,  96629, 160977, 251909, 209214,   6190,
         102413, 247227,  84615,  12321, 102069, 250598, 165257, 213011, 223305,
         108701, 223214]])

cutom_out = gemma(x)
print(cutom_out.shape)
cutom_out

torch.Size([1, 20, 257216])


tensor([[[ -0.0916,   5.1240, -15.7540,  ...,  -0.8838,  -0.8866,  -0.8847],
         [  0.6308,   4.3901, -20.1562,  ...,  -0.8476,  -0.8535,  -0.8455],
         [ -0.1848,   5.5989, -21.0075,  ...,  -0.8385,  -0.8447,  -0.8407],
         ...,
         [ -0.1042,   0.3676, -10.0569,  ...,  -1.7666,  -1.7766,  -1.7793],
         [  0.3203,   0.7445, -15.5378,  ...,  -1.6099,  -1.6231,  -1.6265],
         [  1.3958,   5.7269, -15.0676,  ...,  -0.5363,  -0.5472,  -0.5411]]],
       grad_fn=<UnsafeViewBackward0>)

In [43]:
vit = ViT(
    dim=pg_vision_config.hidden_size,
    n_channels=pg_vision_config.num_channels,
    n_layers=pg_vision_config.num_hidden_layers,
    n_heads=pg_vision_config.num_attention_heads,
    image_size=pg_vision_config.image_size,
    patch_size=pg_vision_config.patch_size,
    fc_intermediate_size=pg_vision_config.intermediate_size,
    norm_eps=pg_vision_config.layer_norm_eps,
)
vit.load_hf_weights(video_weights)

In [56]:
visual_embedding_to_text_embedding = torch.nn.Linear(
    pg_vision_config.hidden_size,
    pg_text_config.hidden_size,
    bias=True
)
visual_embedding_to_text_embedding.weight.data = multi_modal_projector_weights['linear.weight']
visual_embedding_to_text_embedding.bias.data = multi_modal_projector_weights['linear.bias']

In [63]:
text_embedding_layer = nn.Embedding(pg_text_config.vocab_size, pg_text_config.hidden_size)
text_embedding_layer.weight.data = text_weights['model.embed_tokens.weight']

# Run dirty Paligemma inference 

In [46]:
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
img = processor.image_processor(raw_image).pixel_values[0]
img = torch.tensor(img)

torch.Size([1, 256, 1152])


In [57]:
# produce the visual embedding
with torch.no_grad():
    visual_emb = vit(img.unsqueeze(0))

print(visual_emb.shape)

torch.Size([1, 256, 1152])


In [70]:
# project the visual embedding to the text embedding
visual_emb = visual_embedding_to_text_embedding(visual_emb)
print(text_emb.shape)

torch.Size([1, 7, 2048])


In [68]:
# let's emmbed the text
prompt = "What is on the flower?"
tokens = processor.tokenizer(prompt, return_tensors="pt")
text_emb = text_embedding_layer(tokens['input_ids'])
print(text_emb.shape)

torch.Size([1, 7, 2048])


In [71]:
# let's concatenate the visual and text embeddings
multi_modal_emb = torch.cat((visual_emb, text_emb), dim=1)
print(multi_modal_emb.shape)

torch.Size([1, 263, 2048])


In [83]:
# let's normalize the multi-modal embedding as in gamma
normalizer = torch.tensor(pg_text_config.hidden_size**0.5, dtype=multi_modal_emb.dtype, device=multi_modal_emb.device)
multi_modal_emb_norm = multi_modal_emb * normalizer
multi_modal_emb_norm


tensor([[[ 95.2791, -16.8974,   0.9396,  ..., -10.6272,   5.6287,  21.7257],
         [-41.0327,   1.5128,  27.5928,  ...,  -6.2910,  17.8286,   1.0437],
         [ 81.8929, -15.0253,  18.9915,  ..., -15.1030,  26.8091,  -9.5997],
         ...,
         [ 10.9351,  -1.4130,  -4.7609,  ...,  -0.3168,   2.1305,   1.1936],
         [ 13.5130,  -2.6200,   2.3359,  ...,   1.4821,   4.2262,  -1.0913],
         [  9.9381,  -1.2802,  -5.5774,  ...,  -1.7720,   0.5118,  -3.4634]]],
       grad_fn=<MulBackward0>)

In [107]:
# let's pass the multi-modal embedding through the gemma model
gemma_out = gemma(multi_modal_emb_norm)
# gemma_out.shape
gemma_out

tensor([[[164.1517,  44.3966,  45.8237,  ..., 106.8878, 107.1223, 107.1785],
         [198.8348,  62.3247,  34.0407,  ..., 132.9104, 133.1945, 133.2679],
         [187.5081,  53.3461,  37.0409,  ..., 124.0491, 124.3092, 124.3831],
         ...,
         [ -4.9033,   4.9402, -19.8332,  ...,  -4.1096,  -4.1100,  -4.1222],
         [ -3.5158,   6.9640, -11.5104,  ...,  -3.1971,  -3.2035,  -3.2032],
         [ -7.5126,   6.9775, -13.5781,  ...,  -5.2280,  -5.2349,  -5.2321]]],
       grad_fn=<UnsafeViewBackward0>)

In [103]:
next_token = F.softmax(gemma_out[0,-1], dim=-1).argmax()
processor.tokenizer.decode(next_token.item())

' is'