In [10]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
from pathlib import Path
import torch

# Cross-encoder that scores a (query, passage) pair with a relevance logit.
model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Index of the NLQ validation sample whose question is scored below.
# Later cells must not reuse `idx` as a loop variable, or this selection is silently lost.
idx = 12

NLQ_ANNOTATIONS_PATH = Path("/data/soyeonhong/GroundVQA/data/unified/annotations.NLQ_val.json")
# NOTE(review): this caption file belongs to one fixed video; confirm it is the video of nlq_list[idx].
LLAVA_CAPTION_PATH = Path("/data/soyeonhong/GroundVQA/llava-v1.6-34b/global/0aabdda2-d305-44c8-b085-f018ea62d872.json")

# Read with an explicit encoding so JSON decoding does not depend on the system locale.
nlq_list = json.loads(NLQ_ANNOTATIONS_PATH.read_text(encoding="utf-8"))
llava_caption = json.loads(LLAVA_CAPTION_PATH.read_text(encoding="utf-8"))['answers']

In [4]:
# One (query, caption) pair per captioned frame: the same NLQ question is paired
# with every caption so the cross-encoder scores each frame against it.
# Each caption entry is indexed as [frame index, ..., caption text].
frame_idx_list = [entry[0] for entry in llava_caption]
caption_list = [entry[2] for entry in llava_caption]
query_list = [nlq_list[idx]['question'] for _ in llava_caption]

# Pad/truncate every pair to exactly 172 tokens so all rows share one fixed length.
features = tokenizer(
    query_list,
    caption_list,
    max_length=172,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

In [5]:
# Every entry of query_list holds the same question, so the first one is representative.
query_list[0]

'How many wires did i pick from the floor?'

In [6]:
# Inference only: eval() switches off training-mode layers such as dropout,
# and no_grad() skips autograd bookkeeping for the forward pass.
model.eval()
with torch.no_grad():
    output = model(**features, output_hidden_states=True, return_dict=True)
    # Last encoder layer: one hidden vector per token for every (query, caption) pair.
    last_layer = output.hidden_states[-1]
    print(last_layer.shape)

torch.Size([49, 172, 384])


In [7]:
# Re-apply the BERT pooler to the last hidden layer to get one vector per
# (query, caption) pair. Wrapped in no_grad: the forward pass above ran without
# autograd, and building a graph here only wastes memory.
with torch.no_grad():
    pool = model.bert.pooler(output.hidden_states[-1])
pool.shape

torch.Size([49, 384])

In [54]:
# mean = torch.mean(output.hidden_states[-1], dim=1)
# mean.shape

In [43]:
# result_from_mean = model.classifier(mean)
# result_from_mean

In [9]:
# Classifier head on the pooled vectors: one relevance logit per (query, caption) pair.
with torch.no_grad():
    result = model.classifier(pool)

# Print the real frame index from the captions instead of assuming a 300-frame stride.
# `step` (not `idx`) keeps the NLQ sample index chosen in the first cell intact.
# NOTE(review): the second column keeps the original `* 10` scaling (presumably seconds
# between captions) — confirm.
for step, (frame_idx, res) in enumerate(zip(frame_idx_list, result)):
    print(f"{frame_idx}, {step * 10}: {round(res.item(), 2)}\n")

0, 0: -10.9
300, 10: -9.73
600, 20: -10.53
900, 30: -10.52
1200, 40: -10.97
1500, 50: -10.76
1800, 60: -10.77
2100, 70: -9.84
2400, 80: -10.83
2700, 90: -11.17
3000, 100: -11.04
3300, 110: -11.22
3600, 120: -11.26
3900, 130: -9.91
4200, 140: -10.41
4500, 150: -10.78
4800, 160: -10.86
5100, 170: -10.74
5400, 180: -10.93
5700, 190: -10.05
6000, 200: -9.7
6300, 210: -8.27
6600, 220: -10.06
6900, 230: -10.55
7200, 240: -11.01
7500, 250: -10.32
7800, 260: -6.2
8100, 270: -11.02
8400, 280: -10.77
8700, 290: -10.88
9000, 300: -10.56
9300, 310: -10.72
9600, 320: -11.07
9900, 330: -10.69
10200, 340: -11.04
10500, 350: -11.22
10800, 360: -10.99
11100, 370: -11.16
11400, 380: -7.62
11700, 390: -10.71
12000, 400: -4.66
12300, 410: -10.18
12600, 420: -11.21
12900, 430: -11.18
13200, 440: -11.21
13500, 450: -11.23
13800, 460: -11.22
14100, 470: -10.79
14400, 480: -11.22


In [None]:
# Show the shared query once, then every caption alongside its cross-encoder logit.
header = f"{query_list[0]}\n"
print(header)
for caption_text, caption_logit in zip(caption_list, output.logits):
    score = caption_logit.item()
    print(f"Logit: {score}\n{caption_text}\n")

In [109]:
# Per-frame embedding: mean of the last-layer token vectors over real (non-pad) tokens.
# NOTE(review): `mean_result` was never defined in this notebook (only a commented-out,
# unmasked mean existed), so this cell failed on a fresh kernel. It is rebuilt here with
# the attention mask: the batch is padded to max_length=172, and a plain mean over
# dim=1 would average pad-token states into every embedding.
last_hidden = output.hidden_states[-1]
token_mask = features['attention_mask'].unsqueeze(-1).to(last_hidden.dtype)
mean_result = (last_hidden * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp(min=1)

# zip would silently drop frames on a length mismatch; fail loudly instead.
assert len(frame_idx_list) == len(mean_result), "expected one embedding per captioned frame"
# `embedding` (not `result`) avoids clobbering the classifier logits computed earlier.
save_list = [(frame_idx, embedding) for frame_idx, embedding in zip(frame_idx_list, mean_result)]

In [110]:
# Persist the (frame index, embedding) pairs to a file named after the sample id.
# NOTE(review): the name uses nlq_list[0], but the scored query came from nlq_list[idx] — confirm.
sample_id = nlq_list[0]['sample_id']
torch.save(save_list, f"{sample_id}.pt")

In [111]:
# Reload the saved (frame index, embedding) pairs to check the round trip.
saved_path = f"{nlq_list[0]['sample_id']}.pt"
x = torch.load(saved_path)

In [112]:
# Sanity check: each reloaded entry is (frame index, one embedding tensor).
for saved_frame, saved_vec in x:
    print(f"{saved_frame:6d}: {saved_vec.shape}")

     0: torch.Size([384])
   300: torch.Size([384])
   600: torch.Size([384])
   900: torch.Size([384])
  1200: torch.Size([384])
  1500: torch.Size([384])
  1800: torch.Size([384])
  2100: torch.Size([384])
  2400: torch.Size([384])
  2700: torch.Size([384])
  3000: torch.Size([384])
  3300: torch.Size([384])
  3600: torch.Size([384])
  3900: torch.Size([384])
  4200: torch.Size([384])
  4500: torch.Size([384])
  4800: torch.Size([384])
  5100: torch.Size([384])
  5400: torch.Size([384])
  5700: torch.Size([384])
  6000: torch.Size([384])
  6300: torch.Size([384])
  6600: torch.Size([384])
  6900: torch.Size([384])
  7200: torch.Size([384])
  7500: torch.Size([384])
  7800: torch.Size([384])
  8100: torch.Size([384])
  8400: torch.Size([384])
  8700: torch.Size([384])
  9000: torch.Size([384])
  9300: torch.Size([384])
  9600: torch.Size([384])
  9900: torch.Size([384])
 10200: torch.Size([384])
 10500: torch.Size([384])
 10800: torch.Size([384])
 11100: torch.Size([384])
 11400: torc

In [106]:
# Ids and strings of the special tokens; the token-inspection loop compares ids against the first three.
tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.cls_token, tokenizer.sep_token

(0, 101, 102, '[CLS]', '[SEP]')

In [107]:
# Tokenize only the first (query, caption) pair. padding=True pads to the longest
# sequence in the call, which for a single pair means no padding.
feature_one_set = tokenizer(
    query_list[0],
    caption_list[0],
    max_length=172,
    truncation=True,
    padding=True,
    return_tensors='pt',
)
feature_one_set

{'input_ids': tensor([[  101,  2073,  2106,  1045,  2404,  1996,  2543,  3282,  1029,   102,
          1996,  3746,  3065,  2019,  5992,  5997,  2007,  2536,  6177,  1012,
          2045,  2024,  4984, 24742,  1010,  2029,  2024,  2109,  2000,  4047,
          2019,  5992,  4984,  2013,  4053,  3303,  2011,  2019,  2058, 11066,
          2030,  2460,  4984,  1012,  1996,  5997,  2036,  2038,  1037,  2417,
          2422,  2006,  2327,  1010,  2029,  2003,  3497,  2019, 17245,  2005,
          2019,  8598,  2030,  5432,  2291,  1012,  2000,  1996,  2157,  1010,
          2045,  2003,  1037,  3384, 10535,  6729,  2114,  1996,  2813,  1010,
          2029,  2003,  4141,  2109,  2005,  3229,  2075,  3020,  2752,  1010,
          2107,  2004, 14832,  2015,  2030, 15753,  1012,  1996,  4292,  3544,
          2000,  2022,  2019,  4592,  2686,  1010,  4298,  1037,  9710,  2282,
          2030,  1037,  2112,  1997,  1037,  2311,  2073,  5992,  3001,  2024,
          7431,  1012,   102]]), 'toke

In [108]:
# Walk the first row of the batch token by token, flagging special tokens.
# Tokens are decoded from the ids themselves rather than re-tokenized from the raw
# text, so the two columns cannot drift apart (truncation, stale query_list).
# Padding positions are skipped via the attention mask.
# `position` is used instead of `idx` so the NLQ sample index from the first cell survives.
special_token_ids = {tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id}
first_row_ids = features['input_ids'][0].tolist()
first_row_mask = features['attention_mask'][0].tolist()
first_row_tokens = tokenizer.convert_ids_to_tokens(first_row_ids)

for position, (token, token_id, is_real) in enumerate(zip(first_row_tokens, first_row_ids, first_row_mask)):
    if not is_real:
        continue  # padding added by padding='max_length'
    if token_id not in special_token_ids:
        print(f"{position:2d}     {token:15s} {token_id}")
    else:
        print(f"{position:2d}     {token:15s} {token_id} (special token)")

 0     [CLS]           101 (special token)
 1     where           2073
 2     did             2106
 3     i               1045
 4     put             2404
 5     the             1996
 6     fire            2543
 7     gun             3282
 8     ?               1029
 9     [SEP]           102 (special token)
10     the             1996
11     image           3746
12     shows           3065
13     an              2019
14     electrical      5992
15     panel           5997
16     with            2007
17     various         2536
18     components      6177
19     .               1012
20     there           2045
21     are             2024
22     circuit         4984
23     breakers        24742
24     ,               1010
25     which           2029
26     are             2024
27     used            2109
28     to              2000
29     protect         4047
30     an              2019
31     electrical      5992
32     circuit         4984
33     from            2013
34     damage    