In [1]:
import os
import sys

import torch
from einops import rearrange
from torch import nn
from torchsummary import summary
from transformers import AutoModelForCausalLM
from transformers import (
    Blip2QFormerConfig,
    Blip2QFormerModel,
)

# Project-local modules (kept in their original relative order).
from model.memory_bank_ours.models import VLChatProcessor
from model.deepseek_vl.utils.io import load_pil_images

  from .autonotebook import tqdm as notebook_tqdm


[2024-07-19 10:52:48,063] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Python version is above 3.10, patching the collections module.
Python version is above 3.10, patching the collections module.


In [None]:
# Directory of the local DeepSeek-VL 1.3B base checkpoint. The absolute path is only
# a default: set DEEPSEEK_VL_MODEL_PATH to point the notebook at another copy.
model_path = os.environ.get(
    "DEEPSEEK_VL_MODEL_PATH",
    "/data/Users/xyq/developer/happy_code/model_repo/deepseek-vl-1.3b-base",
)

In [None]:
# Load the DeepSeek-VL checkpoint found at `model_path` with default loading options.
model = AutoModelForCausalLM.from_pretrained(model_path)

In [None]:
# The chat processor turns conversations + images into model inputs; keep a direct
# handle on its tokenizer for vocabulary queries later on.
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = getattr(vl_chat_processor, "tokenizer")

In [None]:
# Reuse the checkpoint already loaded into `model` rather than reading it from disk a
# second time: the duplicate `from_pretrained` kept a second, otherwise unused copy of
# the weights alive. `nn.Module.to` converts the module in place and returns it, so
# `model` and `vl_gpt` now name the same bf16 model on cuda:0, in eval mode.
vl_gpt = model.to(device="cuda:0", dtype=torch.bfloat16).eval()

In [None]:
# Inspect the Q-Former config attached to the loaded model.
vl_gpt.qformer_config

In [None]:
# Inspect the loaded model's own Q-Former module.
vl_gpt.qformer

In [None]:
# Tokenizer vocabulary size (passed as vocab_size to Blip2QFormerConfig in the next cell).
tokenizer.vocab_size

In [None]:
# Config for a standalone Q-Former: 1024-d hidden state with 16 heads, cross-attending
# to 1024-d encoder features, with the LM tokenizer's vocabulary size.
qformer_config = Blip2QFormerConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=1024,
    num_attention_heads=16,
    encoder_hidden_size=1024,
)
qformer_config

In [None]:
# Build a freshly initialised Q-Former from the config (no pretrained weights loaded).
qformer = Blip2QFormerModel(config=qformer_config)

In [None]:
# Total number of parameters in the standalone Q-Former. Every parameter is counted,
# whether or not it requires grad. An earlier note recorded 101M.
sum(map(lambda param: param.numel(), qformer.parameters()))

In [None]:
# Display the standalone Q-Former's module tree.
qformer

In [None]:
# Learnable query tokens for the Q-Former, shape [1, 32, hidden_size]; the batch
# dimension is broadcast later with `.expand`.
# Previously all-zero: every query was then identical, and HF's Blip2QFormerModel adds
# no position embedding to `query_embeds`, so all 32 outputs came back identical
# (NOTE(review): confirm the same holds for `vl_gpt.qformer`). BLIP-2 (LAVIS
# `init_Qformer`) draws them from N(0, initializer_range) instead; a dedicated seeded
# generator keeps the init reproducible without touching the global RNG.
_query_init_generator = torch.Generator().manual_seed(0)
query_tokens = nn.Parameter(
    torch.randn(
        1,
        32,
        vl_gpt.qformer_config.hidden_size,
        generator=_query_init_generator,
    )
    * getattr(vl_gpt.qformer_config, "initializer_range", 0.02)
)  # [1,32,hidden_size]

In [None]:
# Expected [1, 32, hidden_size].
query_tokens.shape

In [None]:
# One training-style sample. The user turn carries the task prompt with two
# <image_placeholder> slots, matched by the two frame paths; the assistant turn
# carries the target actions.
user_turn = {
    "role": "User",
    "content": "Current task: craft_stone_pickaxe\nBased on current task, historical observations and actions, predict the four actions that masked as <action>.\n<image_placeholder><a><attack></a><a><attack></a><a><attack></a><a><action></a><a><action></a><a><action></a><a><action></a><a><attack><x>-5.81</x><y>-1.61</y></a><a><attack><x>-10.00</x><y>-1.61</y></a><a><attack><forward><x>-5.81</x><y>0.00</y></a><image_placeholder>",
    "images": [
        "/data/Users/xyq/developer/happy_code/data/action_dpo/v1/mc_dataset_v1/craft_stone_pickaxe_1385/craft_stone_pickaxe_1385_frame_34.jpg",
        "/data/Users/xyq/developer/happy_code/data/action_dpo/v1/mc_dataset_v1/craft_stone_pickaxe_1385/craft_stone_pickaxe_1385_frame_35.jpg",
    ],
}
assistant_turn = {
    "role": "Assistant",
    "content": "<a><attack></a><a><attack></a><a><attack></a><a><attack></a>",
}
conversation = [user_turn, assistant_turn]

# Load the referenced frames, run the processor as a batch of one, and move the
# result to the model's device.
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation,
    images=pil_images,
    force_batchify=True,
)
prepare_inputs = prepare_inputs.to(vl_gpt.device)

In [None]:
# List the tensors produced by the processor (the cells below read pixel_values,
# input_ids, images_seq_mask and images_emb_mask).
prepare_inputs.keys()


In [None]:
# Pull out the processor outputs used by the following cells.
pixel_values, input_ids, images_seq_mask = (
    prepare_inputs[key] for key in ("pixel_values", "input_ids", "images_seq_mask")
)

In [None]:
# Fold the per-sample image axis into the batch axis so the vision tower sees a flat
# stack of images: [b, n, c, h, w] -> [(b n), c, h, w].
bs, n, *_ = pixel_values.shape
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
bs, n, images.shape

In [None]:
# Encode every image with the model's vision tower.
# NOTE(review): the next cell rearranges this with "(b n) t d", so it is expected to be
# [(b n), T2, D]; confirm T2 and D for this checkpoint.
images_features = vl_gpt.vision_model(images)
images_features.shape

In [None]:
# Concatenate each sample's per-image token sequences into one sequence:
# [(b n), t, d] -> [b, (n t), d].
images_embeds = rearrange(images_features, "(b n) t d -> b (n t) d", b=bs, n=n)
images_embeds.shape

In [None]:
# Broadcast the single set of query tokens over the batch (an expand view), then copy
# it to cuda:0: [1, 32, D] -> [b, 32, D].
query_token_bs = query_tokens.expand(images_embeds.size(0), -1, -1).to("cuda:0")
query_token_bs.shape

In [None]:
# Cross-attention mask over the image tokens: all ones, one entry per token ([b, n*t]).
image_attention_mask = images_embeds.new_ones(
    images_embeds.shape[:-1], dtype=torch.long, device="cuda:0"
)

In [None]:
# Display the batched query tokens.
query_token_bs

In [None]:
# Run the model's own Q-Former: the queries cross-attend to the image features. Both
# are cast to bfloat16, the dtype the model was converted to.
query_outputs = vl_gpt.qformer(
    encoder_hidden_states=images_embeds.to(torch.bfloat16),
    encoder_attention_mask=image_attention_mask,
    query_embeds=query_token_bs.to(torch.bfloat16),
)

In [None]:
# First element of the Q-Former output (the last hidden state for HF's Blip2QFormerModel;
# NOTE(review): confirm for `vl_gpt.qformer`).
query_output = query_outputs[0]
query_output.shape # expected to match the query tokens' leading shape, [b, 32, ...]

In [None]:
# Display the raw Q-Former outputs.
query_output

In [None]:
# Project the Q-Former outputs with the model's aligner, whose outputs are what get
# written into the language model's input embeddings.
image_embeds_to_language = vl_gpt.aligner(query_output)
image_embeds_to_language.shape # [1, 32, 2048]

In [None]:
# Keep only the first 32 image-placeholder positions of EACH sequence: the Q-Former
# yields 32 query embeddings per sample, so only that many slots are filled later.
# 100015 is the image-placeholder token id here (NOTE(review): presumably the
# processor's image token id; prefer reading it from `vl_chat_processor`).
# The previous version sliced the flattened (row, col) nonzero indices with [32:],
# which kept 32 positions in total (all in row 0) rather than 32 per row when the
# batch size is > 1. It also stored the DROPPED columns in `first_32_true_indices`.
is_image_token = input_ids == 100015
# 1-based rank of each image token within its own row.
image_token_rank = is_image_token.long().cumsum(dim=1)
images_seq_mask = is_image_token & (image_token_rank <= 32)
# Column indices of the kept positions, in row-major order across the batch.
first_32_true_indices = torch.nonzero(images_seq_mask, as_tuple=True)[1]

In [None]:
# Inspect the index tensor computed in the masking cell.
first_32_true_indices

In [None]:
# Inspect the boolean mask of image-placeholder positions.
images_seq_mask

In [None]:
# Token ids for the Q-Former variant: clear every image-placeholder id (100015), then
# re-insert it only at the positions kept in `images_seq_mask`.
# `input_ids` was taken straight from `prepare_inputs`, so clone before writing. The
# in-place writes would otherwise also rewrite `prepare_inputs["input_ids"]`, and even
# re-running the unpacking cell could not restore the original ids.
input_ids = input_ids.clone()
input_ids[input_ids < 0] = 0  # negative ids are not valid embedding indices
input_ids[input_ids == 100015] = 0
input_ids[images_seq_mask] = 100015

In [None]:
# Shape of the remapped token ids ([b, T]).
input_ids.shape

In [None]:
# Look up the remapped ids in the language model's input embedding table.
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(input_ids)
inputs_embeds.shape

In [None]:
images_emb_mask = prepare_inputs["images_emb_mask"]
# Count of True entries in the processor's image-embedding mask.
images_emb_mask.eq(True).sum()

In [None]:
# Q-Former version of DeepSeek-VL's input-embedding preparation:
#   image features -> standalone Q-Former (32 queries) -> aligner
#   -> written into the LM input embeddings at the positions set in `images_seq_mask`.
# Results go to new names instead of overwriting `images_embeds`, `images_emb_mask` or
# `input_ids`, so the cell can be re-run without corrupting earlier cells' state.
# The per-patch `images_emb_mask` is not used: it describes the n x T2 vision tokens,
# not the 32 Q-Former outputs.

# The standalone Q-Former and query tokens are created on CPU in fp32. Move them to
# the image features' device/dtype; otherwise the call fails on a device mismatch.
qformer = qformer.to(device=images_embeds.device, dtype=images_embeds.dtype)
# [1, Q, D] -> [b, Q, D]
batch_query_tokens = query_tokens.expand(images_embeds.shape[0], -1, -1).to(
    device=images_embeds.device, dtype=images_embeds.dtype
)

# Every image token is attended to: all-ones mask of shape [b, n x T2].
image_attention_mask = torch.ones(
    images_embeds.shape[:-1], dtype=torch.long, device=images_embeds.device
)

query_output = qformer(
    query_embeds=batch_query_tokens,
    encoder_hidden_states=images_embeds,
    encoder_attention_mask=image_attention_mask,
)[0]  # [b, Q, D_qformer]

# Project the query outputs into the LM embedding space: [b, Q, D_lm].
query_embeds_lm = vl_gpt.aligner(query_output)

# [b, T] -> [b, T, D_lm]; clamp instead of writing into `input_ids` in place
# (negative ids are not valid embedding indices).
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(input_ids.clamp(min=0))

# Replace the kept image-placeholder slots with the Q-Former embeddings (only 32 tokens
# per sample). The old code indexed the [b, Q, D] aligner output with the [b, n x T2]
# per-patch `images_emb_mask`, and those shapes do not match.
num_queries = query_embeds_lm.shape[1]
assert bool((images_seq_mask.sum(dim=1) == num_queries).all()), (
    "images_seq_mask must select exactly num_queries positions in every row"
)
# Boolean-mask assignment fills positions in row-major order, matching flatten(0, 1).
inputs_embeds[images_seq_mask] = query_embeds_lm.flatten(0, 1).to(dtype=inputs_embeds.dtype)

In [1]:
f = {"a": True, "b": False, "c": True}

# Keys whose value is truthy, in insertion order.
keys = [key for key, flag in f.items() if flag]
keys

['a', 'c']