In [1]:
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import json
from pathlib import Path
import pandas as pd
import torch

# Checkpoint shared by the tokenizer and the bare encoder used to produce the saved token embeddings.
model_name = 'sentence-transformers/all-mpnet-base-v2'

# NOTE(review): CrossEncoder wraps this bi-encoder checkpoint in MPNetForSequenceClassification, whose
# classifier head is newly (randomly) initialised -- see the warning printed under this cell -- so its
# scores would be meaningless. It is not used by any cell visible in this review; consider dropping it.
cross_encoder = CrossEncoder(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Root of all GroundVQA inputs read by this notebook.
GROUNDVQA_ROOT = Path('/data/soyeonhong/GroundVQA')

# Path.glob returns a one-shot generator: it is exhausted after a single pass.
# Wrap it in sorted()/list() if it needs to be iterated more than once.
llava_result_list = (GROUNDVQA_ROOT / 'llava-v1.6-34b' / 'global').glob('*.json')

# Decode explicitly as UTF-8 rather than relying on the platform's locale encoding.
nlq_annotation_path = GROUNDVQA_ROOT / 'data' / 'unified' / 'annotations.NLQ_val.json'
nlq_list = json.loads(nlq_annotation_path.read_text(encoding='utf-8'))

Some weights of MPNetForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-mpnet-base-v2 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Peek at one NLQ annotation to see the fields the encoding loop reads (video_id, sample_id, question).
nlq_list[0]

{'video_id': 'f06d1935-550f-4caa-909c-b2db4c28f599',
 'sample_id': 'd5513548-e16e-486d-8def-740ea7b7fbb0_0',
 'question': 'what did I pick from the fridge?',
 'moment_start_frame': 517.7007,
 'moment_end_frame': 817.6800000000001,
 'clip_start_sec': 17.25669,
 'clip_end_sec': 27.256,
 'clip_duration': 480.0}

In [12]:
# Video IDs the encoding loop skips (original tag: nlq_train).
# NOTE(review): the reason for excluding them is not recorded here -- confirm (e.g. missing caption files?).
except_list = [
    '70a350cd-4f32-40ed-80dd-23a48e7c4e46',
    '04199001-307a-40fd-b20c-4b4128546b89',
]

In [None]:
# Encode every (LLaVA caption, NLQ question) pair with the MPNet encoder and save the per-token
# embeddings, one `cross_encoding/<sample_id>.pt` file per NLQ sample.
#
# Each saved file is a list with one record per caption:
#     [caption[0], (batch, seq_len, hidden) shape tuple, caption token embeddings, question token embeddings]
#
# NOTE(review): caption and question are tokenized as ONE padded batch, so both are padded to the longer
# of the two, and the attention mask is not saved. The question embeddings therefore carry trailing
# padding-token rows (the identical rows in the read-back output), and downstream code cannot tell real
# tokens from padding. Consider also saving tok_output['attention_mask'] (this would change the format).
output_dir = Path('cross_encoding')
# torch.save does not create missing directories; make sure the target exists up front.
output_dir.mkdir(parents=True, exist_ok=True)

# NLQ samples may share a video; parse each video's caption JSON only once.
caption_cache = {}

for annotation in nlq_list:
    video_id = annotation['video_id']
    sample_id = annotation['sample_id']
    question = annotation['question']
    if video_id in except_list:
        continue

    if video_id not in caption_cache:
        caption_path = Path(f"/data/soyeonhong/GroundVQA/llava-v1.6-34b/global/{video_id}.json")
        caption_cache[video_id] = json.loads(caption_path.read_text(encoding='utf-8'))['answers']
    llava_caption_list = caption_cache[video_id]

    save_data = []

    for caption in llava_caption_list:
        # caption[2] is the text that gets encoded; caption[0] is stored as the record key
        # (read back as the frame index in the inspection cell further down).
        tok_output = tokenizer([caption[2], question], max_length=150, padding=True,
                               truncation=True, return_tensors='pt')

        with torch.no_grad():
            outputs = model(**tok_output)
            token_embeddings = outputs.last_hidden_state  # (2, seq_len, hidden): row 0 caption, row 1 question

        temp_list = [caption[0],
                     tuple(token_embeddings.shape),
                     token_embeddings[0],
                     token_embeddings[1]]
        save_data.append(temp_list)

    torch.save(save_data, output_dir / f'{sample_id}.pt')
    print(f"{sample_id}.pt is saved")
        
    

In [22]:
# Shape of the caption-token embeddings from the last pair encoded by the loop above
# (relies on `token_embeddings` left over from that loop).
token_embeddings[0].size()

torch.Size([89, 768])

In [20]:
# Full batch shape as the plain tuple that the encoding loop stores in each record.
tuple(token_embeddings.size())

(2, 89, 768)

In [2]:
# Load one saved sample back for inspection. The file holds only ints, shape tuples and tensors, so
# weights_only=True suffices. It restricts unpickling to those safe types (no arbitrary code execution)
# and avoids the FutureWarning that newer torch versions emit when the argument is omitted.
x = torch.load("/data/soyeonhong/GroundVQA/cross_encoding/val/0a454217-ecf1-4bb0-9365-30eed8e80c1d_0.pt",
               weights_only=True)

In [8]:
# First saved record: [caption[0] (frame index), (batch, seq_len, hidden) tuple, caption token embeddings, question token embeddings].
x[0]

[0,
 (2, 128, 768),
 tensor([[-0.0646, -0.0799,  0.0160,  ...,  0.0276, -0.1183, -0.0521],
         [-0.0160, -0.1267, -0.0466,  ...,  0.0920, -0.0075, -0.1555],
         [ 0.1011, -0.2076, -0.0928,  ...,  0.0954,  0.0239, -0.0728],
         ...,
         [-0.0353,  0.0277,  0.0035,  ...,  0.0526, -0.0359, -0.0016],
         [ 0.0339,  0.1140, -0.0056,  ..., -0.0590, -0.1598, -0.0460],
         [ 0.0335, -0.0103, -0.0057,  ...,  0.0046, -0.1473, -0.0330]]),
 tensor([[ 0.1269, -0.2565,  0.0155,  ...,  0.0682, -0.0235,  0.0202],
         [ 0.2156, -0.2340, -0.0227,  ...,  0.1586,  0.0737,  0.0628],
         [ 0.1833, -0.2544,  0.0175,  ...,  0.1519, -0.0108,  0.0740],
         ...,
         [ 0.2280, -0.0289, -0.0055,  ...,  0.1417, -0.0306, -0.0196],
         [ 0.2280, -0.0289, -0.0055,  ...,  0.1417, -0.0306, -0.0196],
         [ 0.2280, -0.0289, -0.0055,  ...,  0.1417, -0.0306, -0.0196]])]

In [7]:
# Print the caption/query embedding shapes for every record of the loaded sample.
# The second field is the (batch, seq_len, hidden) tuple saved by the encoding loop. It is unpacked into
# its own names so this loop no longer rebinds the global `token_embeddings` tensor to an int, which
# previously broke the `token_embeddings[...]` inspection cells.
for frame_idx, (batch_size, seq_len, hidden_dim), caption_feature, query_feature in x:
    print(f'{frame_idx:6d}: caption -> {list(caption_feature.shape)}, query -> {list(query_feature.shape)}')

     0: caption -> [128, 768], query -> [128, 768]
   300: caption -> [99, 768], query -> [99, 768]
   600: caption -> [115, 768], query -> [115, 768]
   900: caption -> [142, 768], query -> [142, 768]
  1200: caption -> [131, 768], query -> [131, 768]
  1500: caption -> [99, 768], query -> [99, 768]
  1800: caption -> [132, 768], query -> [132, 768]
  2100: caption -> [106, 768], query -> [106, 768]
  2400: caption -> [113, 768], query -> [113, 768]
  2700: caption -> [143, 768], query -> [143, 768]
  3000: caption -> [129, 768], query -> [129, 768]
  3300: caption -> [106, 768], query -> [106, 768]
  3600: caption -> [121, 768], query -> [121, 768]
  3900: caption -> [100, 768], query -> [100, 768]
  4200: caption -> [113, 768], query -> [113, 768]
  4500: caption -> [105, 768], query -> [105, 768]
  4800: caption -> [120, 768], query -> [120, 768]
  5100: caption -> [81, 768], query -> [81, 768]
  5400: caption -> [126, 768], query -> [126, 768]
  5700: caption -> [106, 768], query 