In [1]:
import gradio as gr
import peft
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
from PIL import Image
import requests
import numpy as np
import torch.nn as nn

In [2]:
# Pretrained checkpoints on the Hugging Face Hub.
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

# phi-2 tokenizer; the end-of-text token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# CLIP image preprocessing (resize/normalise into pixel_values).
processor = AutoProcessor.from_pretrained(clip_model_name)

# Token whose embedding is appended after the projected image embeddings.
IMAGE_TOKEN_ID = 23893  # token for word comment

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Feature widths: CLIP vision outputs (input of `projection`) and phi-2 hidden size.
clip_embed = 768
phi_embed = 2560

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP applied to the projected image features.

    ``forward`` layer-normalises the input and then adds a two-layer GELU MLP
    of that normalised tensor back onto it. The residual branch starts from
    the normalised tensor, not from the raw input.

    Submodule names (``pre_norm``, ``proj.0``, ``proj.2``) must stay exactly
    as they are so saved state dicts such as ``step1_resblock.pth`` keep loading.
    """

    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        layers = [
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)

In [11]:
# Build the components on `device`: the CLIP vision encoder, the
# CLIP->phi projector (Linear + SimpleResBlock) and the phi-2 causal LM.
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)

projection = nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)

phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True)
phi_model = phi_model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
# Load the trained step-1 projector weights, mapping tensors onto `device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint_device = torch.device(device)
projection.load_state_dict(
    torch.load('./model_chkpt/step1_projection.pth', map_location=checkpoint_device)
)
resblock.load_state_dict(
    torch.load('./model_chkpt/step1_resblock.pth', map_location=checkpoint_device)
)

<All keys matched successfully>

In [17]:
# Display phi-2's module tree; `model.embed_tokens` is the embedding layer
# that `model_generate_ans` reads token embeddings from.
phi_model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (dense): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2560,),

In [19]:
def model_generate_ans(img, max_generate_length=50):
    """Generate an answer for ``img`` with greedy decoding.

    The CLIP patch features are mapped into phi-2's embedding space
    (``projection`` then ``resblock``), followed by the embedding of
    ``IMAGE_TOKEN_ID``. phi-2 then extends that prefix one token at a time
    until it emits the end-of-text token or ``max_generate_length`` tokens
    have been produced.

    Args:
        img: image accepted by the CLIP ``processor`` (the Gradio ``Image``
            component supplies it), or ``None`` when no image was uploaded.
        max_generate_length: maximum number of new tokens to generate.

    Returns:
        A one-element list with the decoded text, as returned by
        ``tokenizer.batch_decode`` for a batch of one.
    """
    if img is None:
        # Gradio passes None when Submit is pressed without an image.
        return ["Please upload an image first."]

    eos_token_id = tokenizer.eos_token_id
    embed_layer = phi_model.model.embed_tokens
    # Match phi-2's embedding dtype instead of hard-coding float16. The model is
    # loaded in its default dtype, so a fixed fp16 cast only lost precision.
    embed_dtype = embed_layer.weight.dtype

    # Inference only. Without no_grad every forward pass would keep its
    # autograd graph alive for the whole decoding loop.
    with torch.no_grad():
        image_inputs = processor(images=img, return_tensors="pt").to(device)
        # Drop position 0 (the class token in CLIP ViT) and keep the patch features.
        patch_features = clip_model(**image_inputs).last_hidden_state[:, 1:, :]
        image_embeds = resblock(projection(patch_features)).to(embed_dtype)

        img_token = torch.tensor([[IMAGE_TOKEN_ID]], device=device)
        img_token_embeds = embed_layer(img_token)  # (1, 1, phi_embed)

        # (1, num_patches + 1, phi_embed)
        combined_embeds = torch.cat([image_embeds, img_token_embeds], dim=1)

        generated_ids = []
        for _ in range(max_generate_length):
            logits = phi_model(inputs_embeds=combined_embeds)["logits"]
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)
            token_id = next_token.item()
            if token_id == eos_token_id:
                # Stop at end-of-text instead of generating (and decoding) past it.
                break
            generated_ids.append(token_id)
            next_embeds = embed_layer(next_token)  # (1, 1, phi_embed)
            combined_embeds = torch.cat([combined_embeds, next_embeds], dim=1)

    # skip_special_tokens drops any <|endoftext|> text. The previous
    # `ignore_index` keyword is not a decode option.
    return tokenizer.batch_decode([generated_ids], skip_special_tokens=True)
    

def _answer_text(img):
    """Return the generated answer for ``img`` as a plain string.

    ``model_generate_ans`` returns a list (one entry per decoded sequence).
    Passing that list straight to a ``gr.Text`` output renders its repr
    (``['...']``) instead of the text itself.
    """
    answers = model_generate_ans(img)
    return answers[0] if answers else ""


with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Chat with MultiModal GPT !
    Built by combining the CLIP vision model with the phi-2 language model.
    """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image')
        with gr.Column():
            img_answer = gr.Text(label='Answer')

    section_btn = gr.Button("Submit")
    section_btn.click(_answer_text, inputs=[img_input], outputs=[img_answer])

if __name__ == "__main__":
    # share=True also requests a temporary public gradio.live link.
    demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7868
Running on public URL: https://91cd4a354394c5067d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
