In [1]:
import sys
sys.path.append('/workspace/evidence_retrieval/lxmert/src')
from lxrt.modeling import *
import torch
from torch.cuda.amp import GradScaler, autocast
import easydict
from collections import OrderedDict
import pickle as pkl
from torch.utils.data import DataLoader
import numpy as np
import os
from tqdm.auto import tqdm
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="2"
from nltk.translate.meteor_score import single_meteor_score
from nltk.corpus import stopwords
stopwords = stopwords.words('english')

In [2]:
mode = 'newfinal'
exp_name = 'base'

In [3]:
class LXRTImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Obj-level image embedding layer
        self.visn_fc = VisualFeatEncoder(config)
        self.visn_cls = torch.nn.Parameter(torch.rand((1,768)))
        # Number of layers
        self.num_l_layers = VISUAL_CONFIG.l_layers
        self.num_x_layers = VISUAL_CONFIG.x_layers
        self.num_r_layers = VISUAL_CONFIG.r_layers
        print("LXRT encoder with %d l_layers, %d x_layers, and %d r_layers." %
              (self.num_l_layers, self.num_x_layers, self.num_r_layers))
        self.visn_cls = torch.nn.Parameter(torch.rand((1,1,768)) , requires_grad=True)
        # Layers
        # Using self.layer instead of self.l_layer to support loading BERT weights.
        self.r_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(self.num_r_layers)]
        )

    def forward(self, lang_feats=None, lang_attention_mask=None,
                visn_feats=None, visn_attention_mask=None):
        # Run visual embedding layer
        # Note: Word embedding layer was executed outside this module.
        #       Keep this design to allow loading BERT weights.
        visn_feats = self.visn_fc(visn_feats)
        visn_cls = self.visn_cls.repeat((visn_feats.shape[0],1,1))
        visn_feats = torch.cat((visn_cls, visn_feats),dim = -2 )
        # Run cross-modality layers
        for layer_module in self.r_layers:
            visn_feats = layer_module(visn_feats, visn_attention_mask)

        return visn_feats

class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class LXRTImageModel(BertPreTrainedModel):
    """LXRT Model."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = LXRTImageEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None,
                visual_feats=None, visual_attention_mask=None):
        if visual_attention_mask is not None:
            extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
            extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
        else:
            extended_visual_attention_mask = None
        # Run LXRT backbone
        visn_feats = self.encoder(
            None,
            None,
            visn_feats=visual_feats,
            visn_attention_mask=extended_visual_attention_mask)
        pooled_output = self.pooler(visn_feats)

        return visn_feats, pooled_output

In [18]:
import os
import pickle as pkl
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from statistics import mean
from tqdm import tqdm
import csv
import sys
import time
import numpy as np
import base64
class segmentDataset(Dataset):
    """Segment dataset."""
    def __init__(self, vidlistpkl, feat_base, cook2_IVD_dir ="/workspace/evidence_retrieval/COOK2_IVD", mode = 'train' ):
        with open(os.path.join(vidlistpkl), "rb") as f:
            vid_pkl = pkl.load(f)
            
        self.vid_list = vid_pkl
        self.feat_base = feat_base
        with open(os.path.join(cook2_IVD_dir, f"pkl/{mode}.pkl"), "rb") as f:
            self.vid_pkl = pkl.load(f)
    def __len__(self):
        return len(self.vid_list)

    def __getitem__(self, idx):
        

        vid = self.vid_list[idx]
        with open(os.path.join(self.feat_base, vid+'.pkl'), 'rb') as fp:
                  item = pkl.load(fp)
        visn_feats, encode, temporal_label, label = item
        query = self.vid_pkl[vid]["query"]
        
        
        
        return {"vid": vid, "visn_feats":visn_feats, 'encode':encode, 'temporal_label':temporal_label,'label':label, 'query':query}

In [19]:
segment_dataset = segmentDataset('newsplit_train_vid_list.pkl','feat_dump')
valid_dataset = segmentDataset('newsplit_valid_vid_list.pkl','feat_dump', mode = 'vid')
test_dataset = segmentDataset('newsplit_test_vid_list.pkl','feat_dump', mode = 'test')

In [6]:
class prf_expansion(nn.Module):
    def __init__(self, is_cuda = True, modality = 'lang'):
        super(prf_expansion, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.confounder_modaltiy = modality
        if confounder_modality == 'object':
            self.Z = torch.from_numpy( np.load('dic_v.npy') ).cuda().float()
            self.prior = torch.from_numpy(np.load('prior_v.npy')).cuda().float()
            self.y_att_matrix = nn.Linear(768,768)
            self.z_att_matrix = nn.Linear(2048,768) 
            self.e_z_x_transform = nn.Linear(2048,768)
        else:
            self.Z = torch.from_numpy( np.load('dic_t.npy') ).cuda().float()
            self.prior = torch.from_numpy(np.load('prior_t.npy')).cuda().float()
            self.y_att_matrix = nn.Linear(768,768)
            self.z_att_matrix = nn.Linear(768,768)
            self.e_z_x_transform = nn.Linear(768,768)
    def forward(self, y, vid = None):
            y = self.y_att_matrix(y)
            z = self.z_att_matrix(self.Z)
            a = self.softmax(torch.matmul(y, z.transpose(1,0)))
            a = self.prior * a        
            E_z_x = torch.matmul(a,self.Z)
            E_z_x = self.e_z_x_transform(E_z_x)
            return E_z_x

"class Do_Calculus(nn.Module):\n    def __init__(self, is_cuda = True, confounder_modality = 'lang'):\n        super(Do_Calculus, self).__init__()\n        with open(confounder_modality+'_confounder.pkl', 'rb') as fp:\n            self.Z = pkl.load( fp)\n        with open(confounder_modality+'_prior.pkl', 'rb') as fp:\n            self.prior = pkl.load( fp)\n        self.softmax = nn.Softmax(dim=-1)\n        self.confounder_modaltiy = confounder_modality\n        if confounder_modality == 'object':\n            for key in self.Z.keys():\n                if is_cuda:\n                    self.Z[key] = torch.from_numpy(self.Z[key]).cuda().float()\n                else:\n                    self.Z[key] = torch.from_numpy(self.Z[key]).float()\n            for key in self.prior.keys():\n                if is_cuda:\n                    self.prior[key] = torch.from_numpy(self.prior[key]).cuda().float()\n                else:\n                    self.prior[key] = torch.from_numpy(self.prior[ke

In [7]:
class expansion_merge(nn.Module):
    def __init__(self):
        super(Do_Evaluator, self).__init__()
        
        self.activation = nn.GELU()
        self.m1 = nn.Linear(768*2,768, bias = True)
        self.m2 = nn.Linear(768*2,768, bias = True)
    def forward(self, x):
        x = self.m1(x)
        #x = self.activation(x)
        #x = self.m2(x)
        return x

In [8]:
with open('yc2_recipes.json', 'r') as fp:
    recipe_dict = json.load(fp)

In [9]:
config = BertConfig('bert_config.json')
image_encoder = LXRTImageModel(config)
state_dict_path = os.path.join('lxmert', 'snap', 'pretrained', 'model_LXRT.pth') 
state_dict = torch.load(state_dict_path)
new_state_dict = OrderedDict()
for key, value in state_dict.items():
    splittedkey = key.split('.')
    if 'bert' in splittedkey:
        newkey  = '.'.join(splittedkey[splittedkey.index('bert')+1:])
    else:
        newkey  = '.'.join(splittedkey[splittedkey.index('module')+1:])
    new_state_dict[newkey] = value
image_encoder.load_state_dict(new_state_dict, strict=False)
image_encoder.cuda().train()
print()

LXRT encoder with 9 l_layers, 5 x_layers, and 5 r_layers.



In [10]:
from transformers import BertModel, BertTokenizer
from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertConfig
config = BertConfig.from_pretrained('bert-base-uncased')
lang_encoder = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
temporal_classifier = torch.nn.Linear(768*2,3).cuda().train()
lang_encoder.cuda().train()
expansion_obj = prf_expansion(modality = 'object')
expansion_obj.cuda().train()
expansion_word = prf_expansion(modality = 'lang')
expansion_word.cuda().train()
merge1 =  expansion_merge()
merge1.cuda()
merge2 =  expansion_merge()
merge2.cuda()
print()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).





In [12]:
from lxrt.optimization import BertAdam
epoch = 100
batch_per_epoch = len(segment_dataset)
t_total = int(batch_per_epoch * epoch)
warmup_ratio = 0.05
warmup_iters = int(t_total * warmup_ratio)
optim = BertAdam(list(image_encoder.parameters()) + list(lang_encoder.parameters())+list(temporal_classifier.parameters())+list(do_cal.parameters())+list(do_evaluator.parameters())+list(do_cal2.parameters())+list(do_evaluator2.parameters()),
                       lr=1e-4, warmup=warmup_ratio, t_total=t_total)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
scaler = GradScaler()

In [13]:
flatten_recipe = {}
for foodname in recipe_dict.keys():
    flatten_recipe[foodname] = []
    for item in recipe_dict[foodname]:
        flatten_recipe[foodname].extend(item['split_ins'])

In [15]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
writer = SummaryWriter(f'{mode}/{exp_name}_{timestamp}')
print(timestamp)

2021-12-24 08:37:13


In [16]:
exp_name

'base'

In [17]:
epochs = 100

In [None]:
#train
min_loss = 0
sampling_num = 3
temporal_weight = 0.1
print(mode)
for epoch in tqdm(range(epochs)):
    random_idx = list(range(len(segment_dataset)))
    random.shuffle(random_idx)
    for batch_idx in range(len(segment_dataset)):
        batch = segment_dataset.__getitem__(random_idx[batch_idx])
        vid, visn_feats, encode, temporal_label, label = batch['vid'], batch["visn_feats"],batch['encode'], batch['temporal_label'],batch['label']
        frame_feats, box_feats = visn_feats
        foodname = batch['query']
        optim.zero_grad()
        image_encoder.train()
        lang_encoder.train()
        expansion_obj.train()
        expansion_word.train()
        merge1.train()
        merge2.train()
        temporal_classifier.train()
        with autocast():
            temporal_label = torch.tensor(temporal_label).long().cuda()
            label = torch.tensor(label).long().cuda()
            visn_feats = torch.tensor(frame_feats).float().cuda(), torch.tensor(box_feats).float().cuda()
            seq_output, pooled_output = image_encoder(visual_feats = visn_feats)
            encode = encode.cuda()
            output = lang_encoder(encode)
            sequence_output, lang_pooled_output = output[0], output[1]
            expanded_word = expansion_word(pooled_output, vid = vid)
            concat1 = torch.cat(( pooled_output, expanded_word), dim = -1)
            concat1 = merge1(concat1)
            expanded_obj = expansion_obj(lang_pooled_output, vid = vid)
            concat2 = torch.cat((lang_pooled_output, do_output2), dim = -1)
            concat2 = merge2(concat2)
            dotoutput = torch.matmul(concat1, concat2.transpose(1,0))
            loss = loss_fn(dotoutput, label)
            writer.add_scalar('Loss/train',loss.detach().cpu(), epoch*len(segment_dataset)+batch_idx)
            temporal_sample = []
            for visn_i in range(len(pooled_output)):
                    for lang_i in range(len(lang_pooled_output)):
                        temporal_sample.append(torch.cat((pooled_output[visn_i],
                                                          lang_pooled_output[lang_i]), -1).unsqueeze(0))
            temporal_sample = torch.cat(temporal_sample, 0)
            temporal_output = temporal_classifier(temporal_sample)
            temporal_loss = loss_fn(temporal_output, temporal_label)
            loss = loss+temporal_weight*temporal_loss
            writer.add_scalar('Loss/train_total',loss.detach().cpu(), epoch*len(segment_dataset)+batch_idx)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(list(image_encoder.parameters()) + list(lang_encoder.parameters())+list(temporal_classifier.parameters())+list(do_cal.parameters())+list(do_evaluator.parameters()),1 )
        scaler.step(optim)
        scaler.update()
    torch.save(lang_encoder.state_dict(), os.path.join(f'{mode}','lang_encoder_epoch{}.pth'.format(epoch)))
    torch.save(image_encoder.state_dict(),  os.path.join(f'{mode}','image_encoder_epoch{}.pth'.format(epoch)))
    torch.save(expansion_obj.state_dict(), os.path.join(f'{mode}','do_cal_epoch{}.pth'.format(epoch)))
    torch.save(expansion_word.state_dict(), os.path.join(f'{mode}','do_eval_{}.pth'.format(epoch)))
    torch.save(merge1.state_dict(),os.path.join(f'{mode}','do_cal2_epoch{}.pth'.format(epoch)))
    torch.save(merge2.state_dict(), os.path.join(f'{mode}','do_eval2_{}.pth'.format(epoch)))
    with torch.no_grad():
            ks = [1,3,5]
            recalls = {k:[] for k in ks}
            image_encoder.eval()
            lang_encoder.eval()
            expansion_obj.eval()
            expansion_word.eval()
            merge1.eval()
            merge2.eval()
            temporal_classifier.eval()
            with autocast():
                losses = 0
                for batch_idx in range(len(valid_dataset)):
                    batch = valid_dataset.__getitem__(batch_idx)
                    vid, visn_feats, encode, temporal_label, label = batch['vid'], batch["visn_feats"],batch['encode'], batch['temporal_label'],batch['label']
                    frame_feats, box_feats = visn_feats
                    foodname = batch['query']
                    temporal_label = torch.tensor(temporal_label).long().cuda()
                    label = torch.tensor(label).long().cuda()
                    visn_feats = torch.tensor(frame_feats).float().cuda(), torch.tensor(box_feats).float().cuda()
                    seq_output, pooled_output = image_encoder(visual_feats = visn_feats)
                    encode = encode.cuda()
                    output = lang_encoder(encode)
                    sequence_output, lang_pooled_output = output[0], output[1]
                    expanded_word = expansion_word(pooled_output, vid = vid)
                    concat1 = torch.cat(( pooled_output, expanded_word), dim = -1)
                    concat1 = merge1(concat1)
                    expanded_obj = expansion_obj(lang_pooled_output, vid = vid)
                    concat2 = torch.cat((lang_pooled_output, do_output2), dim = -1)
                    concat2 = merge2(concat2)
                    dotoutput = torch.matmul(concat1, concat2.transpose(1,0))
                    label = list(label.detach().cpu().numpy())
                    for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
                for k in ks:
                    print(epoch, k, sum(recalls[k])/len(recalls[k]))
                    writer.add_scalar(f'Recall/{k}',sum(recalls[k])/len(recalls[k]), epoch)


  0%|          | 0/100 [00:00<?, ?it/s]

newfinal


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


0 1 0.19477656702470966
0 3 0.5197850611254626
0 5 0.7503630852031805


  1%|          | 1/100 [03:30<5:46:47, 210.18s/it]

0 1 0.18926839276841556
0 3 0.5059511440939308
0 5 0.7381105282081996
1 1 0.20875102113822047
1 3 0.5289107248720955
1 5 0.7500553751300595


  2%|▏         | 2/100 [07:00<5:43:17, 210.18s/it]

1 1 0.19679316042671882
1 3 0.5175379612885255
1 5 0.7484007330327531
2 1 0.21040250497224755
2 3 0.5465086995140811
2 5 0.7697570276258016


  3%|▎         | 3/100 [10:29<5:39:12, 209.82s/it]

2 1 0.2018633780941282
2 3 0.5227114919581974
2 5 0.7639056064392304
3 1 0.26693942425924083
3 3 0.5934126300003663
3 5 0.8080714652428006


  4%|▍         | 4/100 [13:59<5:35:48, 209.88s/it]

3 1 0.25601512830092427
3 3 0.5846159431278186
3 5 0.8058208245440731
4 1 0.2785332944588686
4 3 0.6289419543104768
4 5 0.8288953060671378


  5%|▌         | 5/100 [17:28<5:32:02, 209.71s/it]

4 1 0.26817561686608554
4 3 0.6230110220617776
4 5 0.8280048827835386
5 1 0.2911223163915781
5 3 0.6415884217634462
5 5 0.8329179587949673


  6%|▌         | 6/100 [20:58<5:28:47, 209.87s/it]

5 1 0.2922538867091046
5 3 0.6435747233217225
5 5 0.8409589921493222
6 1 0.2961511783833965
6 3 0.645748812180138
6 5 0.8343017384104536


  7%|▋         | 7/100 [24:29<5:25:31, 210.02s/it]

6 1 0.3018925307396609
6 3 0.639429899133046
6 5 0.8363330877313226
7 1 0.32415307719454856
7 3 0.6707928320193501
7 5 0.8511564237999217


  8%|▊         | 8/100 [27:58<5:21:30, 209.68s/it]

7 1 0.33116306415537
7 3 0.6760722625231058
7 5 0.8589055619672612
8 1 0.33984923176678317
8 3 0.6878111965175043
8 5 0.8599057509483512


  9%|▉         | 9/100 [31:26<5:17:24, 209.28s/it]

8 1 0.3421613725020853
8 3 0.6816411013937719
8 5 0.8612159539819239
9 1 0.3423520310047439
9 3 0.6812619594057262
9 5 0.8572709126760832


 10%|█         | 10/100 [34:55<5:13:54, 209.28s/it]

9 1 0.3433729014858584
9 3 0.6799960745844703
9 5 0.8586345300523562
10 1 0.3528217454342772
10 3 0.6991903214514869
10 5 0.8640913082826653


 11%|█         | 11/100 [38:25<5:10:34, 209.38s/it]

10 1 0.3482542203092355
10 3 0.6925376418128488
10 5 0.8677541810106788
11 1 0.35520756478720245
11 3 0.7032953188597997
11 5 0.8655819182889399


 12%|█▏        | 12/100 [41:54<5:07:09, 209.42s/it]

11 1 0.36393862988192116
11 3 0.7092256705813361
11 5 0.8667713208502705
12 1 0.356821959500987
12 3 0.7022532207822688
12 5 0.8684353241657933


 13%|█▎        | 13/100 [45:24<5:03:44, 209.47s/it]

12 1 0.3522790962685266
12 3 0.7058188868076927
12 5 0.8748715541737242
13 1 0.3634093517302202
13 3 0.7006220550039698
13 5 0.8679342345330286


 14%|█▍        | 14/100 [48:54<5:00:32, 209.68s/it]

13 1 0.3631712872698616
13 3 0.706527966533062
13 5 0.8683563763754814
14 1 0.3771578908087649
14 3 0.7111353225888936
14 5 0.8698222136840624


 15%|█▌        | 15/100 [52:24<4:57:11, 209.78s/it]

14 1 0.36633519390613695
14 3 0.717211416176506
14 5 0.8725685671963618
15 1 0.3836498157968173
15 3 0.7162923687953121
15 5 0.8740787981220896


 16%|█▌        | 16/100 [55:54<4:53:35, 209.71s/it]

15 1 0.38563710646196103
15 3 0.7234648186806294
15 5 0.8749571849810347
16 1 0.3821812668014359
16 3 0.7181458513946444
16 5 0.8730526200921934


 17%|█▋        | 17/100 [59:24<4:50:13, 209.81s/it]

16 1 0.3815342770515268
16 3 0.7257177662130136
16 5 0.8716793289068009
17 1 0.3767707148825731
17 3 0.7068825362206359
17 5 0.8696124261553976


 18%|█▊        | 18/100 [1:02:54<4:46:51, 209.90s/it]

17 1 0.3729724302180129
17 3 0.7097179283825135
17 5 0.8746096161008236
18 1 0.3856262113833704
18 3 0.7131795777786601
18 5 0.8726483325799541


 19%|█▉        | 19/100 [1:06:22<4:42:44, 209.44s/it]

18 1 0.38029386015747657
18 3 0.7200143667368986
18 5 0.874131628463078
19 1 0.3833101140495185
19 3 0.7170034504340591
19 5 0.8747538528107192


 20%|██        | 20/100 [1:09:50<4:38:42, 209.04s/it]

19 1 0.3850368324630197
19 3 0.7270863510565776
19 5 0.8745795628501808
20 1 0.3859434540747815
20 3 0.7218651426894724
20 5 0.8772390171578488


 21%|██        | 21/100 [1:13:20<4:35:26, 209.19s/it]

20 1 0.38170475046367097
20 3 0.7272167700437021
20 5 0.8788008359193203
21 1 0.38301719482300495
21 3 0.7213708240240173
21 5 0.8726928370419201


 22%|██▏       | 22/100 [1:16:49<4:31:52, 209.14s/it]

21 1 0.37992629367606046
21 3 0.7305056743085517
21 5 0.8782040203255759
22 1 0.3836661275335266
22 3 0.7192324372121427
22 5 0.8741530584304014


 23%|██▎       | 23/100 [1:20:20<4:29:02, 209.64s/it]

22 1 0.3807502742441199
22 3 0.7259576074334299
22 5 0.8779649010536861
23 1 0.39281508239982266
23 3 0.7219740661973829
23 5 0.8757734094212878


 24%|██▍       | 24/100 [1:23:50<4:25:47, 209.84s/it]

23 1 0.39642943092929456
23 3 0.7226285046305423
23 5 0.8752401518522226
24 1 0.38526415927251656
24 3 0.720452654327065
24 5 0.8733486573161975


 25%|██▌       | 25/100 [1:27:21<4:22:40, 210.14s/it]

24 1 0.3828268345628145
24 3 0.7228266409101481
24 5 0.8692307916002335
25 1 0.38521261898179165
25 3 0.7223892034494648
25 5 0.874149947264344


 26%|██▌       | 26/100 [1:30:47<4:17:50, 209.06s/it]

25 1 0.3843645208356483
25 3 0.7293471732495509
25 5 0.8748352449294785
26 1 0.381945699136348
26 3 0.7185700329337024
26 5 0.8765655557217963


 27%|██▋       | 27/100 [1:34:14<4:13:26, 208.31s/it]

26 1 0.3695170442294796
26 3 0.7225889934080472
26 5 0.8785834283836932
27 1 0.38084314407270176
27 3 0.7177508121433073
27 5 0.8713801462688


 28%|██▊       | 28/100 [1:37:41<4:09:33, 207.96s/it]

27 1 0.3755618255354009
27 3 0.7164948342892488
27 5 0.867994281540041
28 1 0.3857877869576549
28 3 0.7231187594307289
28 5 0.8775029122070592


 29%|██▉       | 29/100 [1:41:07<4:05:20, 207.34s/it]

28 1 0.38746484955485483
28 3 0.7236069218769198
28 5 0.8767372803669744
29 1 0.3836057660531
29 3 0.7169043307646807
29 5 0.8728946455804806


 30%|███       | 30/100 [1:44:34<4:01:50, 207.29s/it]

29 1 0.38367041967368837
29 3 0.7188926877969849
29 5 0.8710106921414864
30 1 0.3822098620314254
30 3 0.7161348760227947
30 5 0.8743561760392154


 31%|███       | 31/100 [1:47:56<3:56:36, 205.74s/it]

30 1 0.3831015128054702
30 3 0.7140037035549454
30 5 0.8710464462739479
31 1 0.3857063988483347
31 3 0.7154560207173546
31 5 0.8714510867616249


 32%|███▏      | 32/100 [1:51:22<3:53:01, 205.61s/it]

31 1 0.37676994914099965
31 3 0.7110742065808
31 5 0.869031345668396
32 1 0.3779654573569768
32 3 0.7133010218860162
32 5 0.8710677920660717


 33%|███▎      | 33/100 [1:54:47<3:49:36, 205.62s/it]

32 1 0.3807602899400786
32 3 0.7075738862189364
32 5 0.8676465996067877
33 1 0.38317745215174054
33 3 0.7133054529726621
33 5 0.8708441137400206


 34%|███▍      | 34/100 [1:58:14<3:46:24, 205.83s/it]

33 1 0.3841968899936905
33 3 0.7116671827344211
33 5 0.8649857839758419
34 1 0.3779446842640757
34 3 0.7146506059103374
34 5 0.8681517512279183


 35%|███▌      | 35/100 [2:01:38<3:42:31, 205.41s/it]

34 1 0.38342971528824765
34 3 0.7139709077121212
34 5 0.8651861949529288
35 1 0.37965725626514174
35 3 0.7179952415933014
35 5 0.8702217890900094


 36%|███▌      | 36/100 [2:05:01<3:38:23, 204.74s/it]

35 1 0.3729909245195554
35 3 0.7173512576235255
35 5 0.8699546548206784
36 1 0.38653221810994914
36 3 0.7137812141148324
36 5 0.8716930754961754


 37%|███▋      | 37/100 [2:08:29<3:35:50, 205.57s/it]

36 1 0.38962162481337864
36 3 0.7099786645993665
36 5 0.8705271243863997
37 1 0.3800466081036877
37 3 0.7167622357937767
37 5 0.8722976214165831


 38%|███▊      | 38/100 [2:11:55<3:32:39, 205.79s/it]

37 1 0.38360877074599425
37 3 0.7192774354168584
37 5 0.8686647630343299
38 1 0.378526188457327
38 3 0.7097151042124822
38 5 0.8710682404788503


 39%|███▉      | 39/100 [2:15:21<3:29:14, 205.82s/it]

38 1 0.3810581339150242
38 3 0.704467043797446
38 5 0.8662019351750878
39 1 0.3791505142445169
39 3 0.713274425980578
39 5 0.8717523608923384


 40%|████      | 40/100 [2:18:47<3:25:53, 205.89s/it]

39 1 0.381258322003126
39 3 0.7114110680389123
39 5 0.8698238047715918
40 1 0.38315004883550496
40 3 0.70957857706991
40 5 0.8739113285183293


 41%|████      | 41/100 [2:22:15<3:23:04, 206.52s/it]

40 1 0.38103207514447557
40 3 0.7063957784468669
40 5 0.8710190988918098
41 1 0.37996269586392734
41 3 0.7139687589500656
41 5 0.8718292308176016


 42%|████▏     | 42/100 [2:25:43<3:20:12, 207.11s/it]

41 1 0.37912549940947526
41 3 0.7127001821152786
41 5 0.869518684846762
42 1 0.37254825478795855
42 3 0.7078974665507107
42 5 0.8698979047911715


 43%|████▎     | 43/100 [2:29:09<3:16:25, 206.77s/it]

42 1 0.3748598596977842
42 3 0.7079154927135275
42 5 0.8689391039098595
43 1 0.3818575205463893
43 3 0.7149025769665149
43 5 0.8741523522064076


 44%|████▍     | 44/100 [2:32:36<3:12:54, 206.69s/it]

43 1 0.3806429874920316
43 3 0.7139877519217754
43 5 0.872364741202292
44 1 0.38217711234834933
44 3 0.7110328766168239
44 5 0.8731742964722992


 45%|████▌     | 45/100 [2:36:01<3:09:03, 206.24s/it]

44 1 0.3790607114805931
44 3 0.7101960776602768
44 5 0.8736723232917789
45 1 0.3764237887762211
45 3 0.7103586416666275
45 5 0.8742776827891927


 46%|████▌     | 46/100 [2:39:26<3:05:15, 205.84s/it]

45 1 0.3760200500247392
45 3 0.7109603229942136
45 5 0.8722979320829839
46 1 0.3742527057891732
46 3 0.709223917834848
46 5 0.8675517845765602


 47%|████▋     | 47/100 [2:42:52<3:01:53, 205.92s/it]

46 1 0.3779712382693172
46 3 0.7061787317785898
46 5 0.8668615938270727
47 1 0.373715153255752
47 3 0.7059531632064326
47 5 0.8688262038395411


 48%|████▊     | 48/100 [2:46:18<2:58:23, 205.84s/it]

47 1 0.37249747617528384
47 3 0.701143575331323
47 5 0.8687216752474763
48 1 0.3773737831493294
48 3 0.7067572932297334
48 5 0.8682257607985928


 49%|████▉     | 49/100 [2:49:43<2:54:53, 205.76s/it]

48 1 0.37653976376286075
48 3 0.7067559280054448
48 5 0.8632316239183955
49 1 0.3752092285190204
49 3 0.7053793537129163
49 5 0.86724834829635


 50%|█████     | 50/100 [2:53:10<2:51:34, 205.89s/it]

49 1 0.37663342082330786
49 3 0.707012971251291
49 5 0.8648116716339261
50 1 0.3809976856062606
50 3 0.7102192600905745
50 5 0.8719248003945109


 51%|█████     | 51/100 [2:56:35<2:48:05, 205.84s/it]

50 1 0.38665342546507536
50 3 0.7127383954739884
50 5 0.8707332777917545
51 1 0.3773624082691862
51 3 0.7105463474439716
51 5 0.8698462505888227


 52%|█████▏    | 52/100 [2:59:58<2:43:52, 204.84s/it]

51 1 0.3729910286114342
51 3 0.7069412810884488
51 5 0.865071602241693
52 1 0.3753039208662643
52 3 0.7081302934784505
52 5 0.8681185138750337


 53%|█████▎    | 53/100 [3:03:24<2:40:43, 205.19s/it]

52 1 0.3789575993287441
52 3 0.70152190597326
52 5 0.8629447221206882
53 1 0.3676792379245922
53 3 0.703442955932083
53 5 0.8684308531641247


 54%|█████▍    | 54/100 [3:06:48<2:37:03, 204.87s/it]

53 1 0.37011533008279995
53 3 0.7029823318913986
53 5 0.8668548214001714
54 1 0.3692651637346985
54 3 0.7050394096176521
54 5 0.8666502678051623


 55%|█████▌    | 55/100 [3:10:15<2:34:05, 205.45s/it]

54 1 0.36428348146259737
54 3 0.7002228719346544
54 5 0.8616480948652985
55 1 0.3689712713510545
55 3 0.7010102299170461
55 5 0.8664762777915774


 56%|█████▌    | 56/100 [3:13:41<2:30:51, 205.72s/it]

55 1 0.3772447813325952
55 3 0.7024914530027512
55 5 0.8632570622717489
56 1 0.3735037793539117
56 3 0.7093836752613406
56 5 0.8695202673142802


 57%|█████▋    | 57/100 [3:17:07<2:27:35, 205.94s/it]

56 1 0.37705773821866473
56 3 0.7098180219685021
56 5 0.8669172135317914
57 1 0.3681102222427377
57 3 0.7074765466061018
57 5 0.8663109728560164


 58%|█████▊    | 58/100 [3:20:33<2:24:09, 205.94s/it]

57 1 0.3660917362520982
57 3 0.7096416490716229
57 5 0.8639035940724854
58 1 0.3793028915140154
58 3 0.7085297321694066
58 5 0.8698986042309493


 59%|█████▉    | 59/100 [3:23:57<2:20:15, 205.27s/it]

58 1 0.3820391848345664
58 3 0.7061899792786739
58 5 0.8635434583280149
59 1 0.373988329955571
59 3 0.7046222534040003
59 5 0.8692747959744046


 60%|██████    | 60/100 [3:27:23<2:16:56, 205.40s/it]

59 1 0.37780542004393264
59 3 0.7011760785690387
59 5 0.861968357859748
60 1 0.36981083152711064
60 3 0.7023295008656388
60 5 0.8685180432809667


 61%|██████    | 61/100 [3:30:49<2:13:33, 205.48s/it]

60 1 0.3723090105334006
60 3 0.6970740421225579
60 5 0.865022661219864


In [16]:
#load trained model
epoch = 23
lang_encoder.load_state_dict(torch.load(os.path.join(f'{mode}','lang_encoder_epoch{}.pth'.format(epoch))))
image_encoder.load_state_dict(torch.load(  os.path.join(f'{mode}','image_encoder_epoch{}.pth'.format(epoch))))
expansion_obj.load_state_dict(torch.load(os.path.join(f'{mode}','expansion_obj_epoch{}.pth'.format(epoch))))
expansion_word.load_state_dict(torch.load( os.path.join(f'{mode}','expansion_word_{}.pth'.format(epoch))))
merge1.load_state_dict(torch.load(os.path.join(f'{mode}','merge1_epoch{}.pth'.format(epoch))))
merge2.load_state_dict(torch.load( os.path.join(f'{mode}','merge2_{}.pth'.format(epoch))))

<All keys matched successfully>

In [20]:
    #eval recall on test set
    with torch.no_grad():
            ks = [1,3,5]
            recalls = {k:[] for k in ks}
            image_encoder.eval()
            lang_encoder.eval()
            expansion_obj.eval()
            expansion_word.eval()
            merge1.eval()
            merge2.eval()
            with autocast():
                losses = 0
                for batch_idx in range(len(test_dataset)):
                    batch = valid_dataset.__getitem__(batch_idx)
                    vid, visn_feats, encode, temporal_label, label = batch['vid'], batch["visn_feats"],batch['encode'], batch['temporal_label'],batch['label']
                    frame_feats, box_feats = visn_feats
                    foodname = batch['query']
                    temporal_label = torch.tensor(temporal_label).long().cuda()
                    label = torch.tensor(label).long().cuda()
                    visn_feats = torch.tensor(frame_feats).float().cuda(), torch.tensor(box_feats).float().cuda()
                    seq_output, pooled_output = image_encoder(visual_feats = visn_feats)
                    encode = encode.cuda()
                    output = lang_encoder(encode)
                    sequence_output, lang_pooled_output = output[0], output[1]
                    expanded_word = expansion_word(pooled_output, vid = vid)
                    concat1 = torch.cat(( pooled_output, expanded_word), dim = -1)
                    concat1 = merge1(concat1)
                    expanded_obj = expansion_obj(lang_pooled_output, vid = vid)
                    concat2 = torch.cat((lang_pooled_output, do_output2), dim = -1)
                    concat2 = merge2(concat2)
                    dotoutput = torch.matmul(concat1, concat2.transpose(1,0))
                    label = list(label.detach().cpu().numpy())
                    for k in ks:
                        pred = np.argsort(-1*dotoutput.detach().cpu().numpy(), axis = -1)[:,:k].squeeze()
                        recallatk = 0
                        examples = 0
                        for i, gt in enumerate(label):
                            if k > 1:
                                if gt == -1:
                                    continue
                                if gt in pred[i]:
                                    recallatk +=1
                                examples += 1
                            else:
                                if gt == -1:
                                    continue
                                if gt == pred[i]:
                                    recallatk +=1
                                examples += 1
                        recallatk = recallatk/examples
                        recalls[k].append(recallatk)
                for k in ks:
                    print(epoch, k, sum(recalls[k])/len(recalls[k]))

  app.launch_new_instance()


23 1 0.39642943092929456
23 3 0.7226285046305423
23 5 0.8752401518522226


In [21]:
import os
import pickle as pkl
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from statistics import mean
from tqdm import tqdm
import csv
import sys
import time
import numpy as np
import base64
cook2_IVD_dir = "/workspace/evidence_retrieval/COOK2_IVD"


with open(os.path.join(cook2_IVD_dir, "pkl/{}.pkl".format("train")), "rb") as f:
    vid_pkl = pkl.load(f)
    
csv.field_size_limit(sys.maxsize)
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]


def load_obj_tsv(fname, topk=None):
    """Load object features from tsv file.
    :param fname: The path to the tsv file.
    :param topk: Only load features for top K images (lines) in the tsv file.
        Will load all the features if topk is either -1 or None.
    :return: A list of image object features where each feature is a dict.
        See FILENAMES above for the keys in the feature dict.
    """
    data = []
    start_time = time.time()
    #print("Start to load Faster-RCNN detected objects from %s" % fname)
    with open(fname) as f:
        reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
        for i, item in enumerate(reader):

            for key in ['img_h', 'img_w', 'num_boxes']:
                item[key] = int(item[key])
            
            boxes = item['num_boxes']
            decode_config = [
                ('objects_id', (boxes, ), np.int64),
                ('objects_conf', (boxes, ), np.float32),
                ('attrs_id', (boxes, ), np.int64),
                ('attrs_conf', (boxes, ), np.float32),
                ('boxes', (boxes, 4), np.float32),
                ('features', (boxes, -1), np.float32),
            ]
            for key, shape, dtype in decode_config:
                item[key] = np.frombuffer(base64.b64decode(item[key]), dtype=dtype)
                item[key] = item[key].reshape(shape)
                item[key].setflags(write=False)

            data.append(item)
            if topk is not None and len(data) == topk:
                break
    elapsed_time = time.time() - start_time
    #print("Loaded %d images in file %s in %d seconds." % (len(data), fname, elapsed_time))
    return data

class segmentDataset(Dataset):
    """Segment dataset."""
    def __init__(self, cook2_IVD_dir, mode = "train"):
        with open(os.path.join(cook2_IVD_dir, "pkl/{}.pkl".format(mode)), "rb") as f:
            vid_pkl = pkl.load(f)
        
        self.vid_list = list(vid_pkl.keys())
        self.vid_pkl = vid_pkl

        self.max_frame_len = 552 ### IF you change this, segment_collate also needs to change !!!!!!
    def __len__(self):
        return len(self.vid_pkl)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        vid = self.vid_list[idx]
        max_frame_len = self.max_frame_len
        
        index_transcript = self.vid_pkl[vid]["index_trainscirpt"]
        segments_ = self.vid_pkl[vid]["segments"]
        
        trans_idx = torch.tensor([idx for idx, trans in index_transcript])
        trans = [trans for idx, trans in index_transcript]
        
        segs_idx = torch.tensor([(idx_start, idx_end) for idx_start, idx_end, segs in segments_])
        segs = [segs for idx_start, idx_end, segs in segments_]
        
        begin_idxs = list( int(b.cpu().tolist()) for b, e in segs_idx)
        end_idxs = list( int(e.cpu().tolist()) for b, e in segs_idx)

        
        query = self.vid_pkl[vid]["query"]
        frame_info = load_obj_tsv(os.path.join(cook2_IVD_dir, "features/{}_obj36.tsv".format(vid)))
        
        ## Make beign target  ( EX) [0, 0, 0, 1, ..... ,0, 1, 0, 0]
        begin_target = []
        end_target = []
        
        begin_distance = []
        end_distance = []
        
        begin_ratio = []
        end_ratio = []
        
        
        mask = []
        
        for idx in range(max_frame_len):
            __flag__ = 0
            target_b = -1
            target_e = -1
            middle = -1
            for b, e in segs_idx:
                if(idx >= b and idx <= e): __flag__ = 1; target_b = b; target_e = e; middle = (b + e) / 2.0
                    
            
            if(__flag__): 
                begin_target.append(1.); end_target.append(1.);
                begin_distance.append(abs(target_b - idx)); end_distance.append(abs(target_e - idx)) 
                begin_ratio.append(torch.sigmoid(middle - idx)); end_ratio.append(torch.sigmoid(idx - middle))
                
                
            else: 
                begin_target.append(0.); end_target.append(0.)
                begin_distance.append(999.); end_distance.append(999.)
                begin_ratio.append(0.); end_ratio.append(0.)
                
        
            if(idx < len(frame_info)): 
                mask.append(True)
            else: 
                mask.append(False)
        """
        for idx in range(max_frame_len):
            if(idx in begin_idxs): begin_target.append(1.)
            else: begin_target.append(0.)
                
            if(idx in end_idxs): end_target.append(1.)
            else: end_target.append(0.)
                
        
            if(idx < len(frame_info)): mask.append(True)
            else: mask.append(False)
        """

        
        
        
        return {"vid": vid, "trans_idx": trans_idx, "trans": trans, "segs_idx": segs_idx, "segs": segs, "query": query, "frame_info": frame_info, \
               "begin_target": begin_target, "end_target": end_target, "mask": mask,\
               "begin_distance": begin_distance, "end_distance": end_distance, "begin_ratio": begin_ratio, "end_ratio": end_ratio}
    
    
def segment_collate(samples):
    vid, query = [], []
    trans_idx, trans = [], []
    segs_idx, segs = [], []
    frame_info = []
    
    begin_target, end_target = [], []
    begin_distance, end_distance =  [], []
    begin_ratio, end_ratio = [], []
    mask = []
    
    max_trans_len = max( list(len(s["trans_idx"]) for s in samples))
    max_segs_len = max( list(len(s["segs_idx"]) for s in samples))
    #max_frame_len = max(list(len(s["frame_info"]) for s in samples))
    max_frame_len = 552

    for sample in samples:
        vid.append(sample["vid"])
        query.append(sample["query"])
        
        trans_idx.append(sample["trans_idx"])
        trans.append(sample["trans"] + ["<PAD>"] * (max_trans_len - len(sample["trans"])) )
        
        segs_idx.append(sample["segs_idx"])
        segs.append(sample["segs"] + ["<PAD>"] * (max_segs_len - len(sample["segs"])) )
        begin_target.append(sample["begin_target"])
        end_target.append(sample["end_target"])
        
        begin_distance.append(sample["begin_distance"]); end_distance.append(sample["end_distance"])
        begin_ratio.append(sample["begin_ratio"]); end_ratio.append(sample["end_ratio"])
        mask.append(sample["mask"])
        
        frame_info.append(sample["frame_info"] + ["<PAD>"] * (max_frame_len - len(sample["frame_info"])))
        
    padded_trans_idx = torch.nn.utils.rnn.pad_sequence(trans_idx, batch_first=True, padding_value = -4444).contiguous()
    padded_trans = trans 
    padded_segs_idx = torch.nn.utils.rnn.pad_sequence(segs_idx, batch_first=True, padding_value = -4444).contiguous()
    padded_segs = segs 


    return {"vid": vid, "trans_idx": padded_trans_idx, "trans": padded_trans, \
            "segs_idx": padded_segs_idx, "segs": padded_segs, "query": query, "frame_info": frame_info,\
           "begin_target": torch.tensor(begin_target), "end_target": torch.tensor(end_target), "mask": torch.tensor(mask),\
           "begin_distance": begin_distance, "end_distance": end_distance, "begin_ratio": begin_ratio, "end_ratio": end_ratio}

In [22]:

segment_dataset = segmentDataset(cook2_IVD_dir)
valid_dataset = segmentDataset(cook2_IVD_dir, mode = 'test')


In [23]:
dataloader = DataLoader(segment_dataset, collate_fn = segment_collate, batch_size = 1, shuffle = True, num_workers = 4 )
validloader = DataLoader(valid_dataset, collate_fn = segment_collate, batch_size = 1, num_workers = 4)

In [24]:
with open('yc2_recipes.json', 'r') as fp:
    recipes = json.load(fp)

In [25]:
    sampling_num = 3
    matched_recipe = {}
    with torch.no_grad():
        image_encoder.eval()
        lang_encoder.eval()
        with autocast():
            losses = 0
            for i, batch in tqdm(enumerate(dataloader), total = len(dataloader)):
                maxlen = max([len(i) for i in batch['frame_info']])
                annot, segs_idx, item =batch['segs'][0],batch['segs_idx'][0],batch['frame_info'][0]
                query = batch['query'][0]
                vid = batch['vid'][0]
                vid_recipe = recipes[query]
                flatten_recipe = []
                for recipe in vid_recipe:
                    for recipe_doc in recipe['split_ins']:
                        if recipe_doc!= '':
                            flatten_recipe.append(recipe_doc)
                frame_feats = []
                box_feats = []
                frame_indices = []
                for frame_idx in range(len(item)//sampling_num):
                    frame_idx = sampling_num*frame_idx+1
                    if item[frame_idx] != '<PAD>':
                        frame_feats.append(item[frame_idx]['features'])
                        box_feats.append(item[frame_idx]['boxes'])
                        frame_indices.append(frame_idx)
                    else:
                        break
                visn_feats = torch.tensor(frame_feats).float().cuda(), torch.tensor(box_feats).float().cuda()
                seq_output, pooled_output = image_encoder(visual_feats = visn_feats)
                encode = [tokenizer.encode(sent) for sent in flatten_recipe]
                maxlen = max([len(e) for e in encode])
                encode = torch.tensor([e+[tokenizer.pad_token_id]*(maxlen-len(e)) for e in encode]).cuda()
                output = lang_encoder(encode)
                sequence_output, lang_pooled_output = output[0], output[1]
                expanded_word = expansion_word(pooled_output, vid = vid)
                concat1 = torch.cat(( pooled_output, expanded_word), dim = -1)
                concat1 = merge1(concat1)
                expanded_obj = expansion_obj(lang_pooled_output, vid = vid)
                concat2 = torch.cat((lang_pooled_output, do_output2), dim = -1)
                concat2 = merge2(concat2)
                dotoutput = torch.matmul(concat1, concat2.transpose(1,0))
                pred = np.argmax(-1*dotoutput.detach().cpu().numpy(), axis = -1)
                vid_matched_recipes = {frame_indices[j]:flatten_recipe[pred[j]] for j in range(len(pred))}
                matched_recipe[vid] = vid_matched_recipes
            for i, batch in tqdm(enumerate(validloader), total = len(validloader)):
                maxlen = max([len(i) for i in batch['frame_info']])
                annot, segs_idx, item =batch['segs'][0],batch['segs_idx'][0],batch['frame_info'][0]
                query = batch['query'][0]
                vid = batch['vid'][0]
                vid_recipe = recipes[query]
                flatten_recipe = []
                for recipe in vid_recipe:
                    for recipe_doc in recipe['split_ins']:
                        if recipe_doc!= '':
                            flatten_recipe.append(recipe_doc)
                frame_feats = []
                box_feats = []
                frame_indices = []
                for frame_idx in range(len(item)//sampling_num):
                    frame_idx = sampling_num*frame_idx+1
                    if item[frame_idx] != '<PAD>':
                        frame_feats.append(item[frame_idx]['features'])
                        box_feats.append(item[frame_idx]['boxes'])
                        frame_indices.append(frame_idx)
                    else:
                        break
                visn_feats = torch.tensor(frame_feats).float().cuda(), torch.tensor(box_feats).float().cuda()
                seq_output, pooled_output = image_encoder(visual_feats = visn_feats)
                encode = [tokenizer.encode(sent) for sent in flatten_recipe]
                maxlen = max([len(e) for e in encode])
                encode = torch.tensor([e+[tokenizer.pad_token_id]*(maxlen-len(e)) for e in encode]).cuda()
                output = lang_encoder(encode)
                sequence_output, lang_pooled_output = output[0], output[1]
                expanded_word = expansion_word(pooled_output, vid = vid)
                concat1 = torch.cat(( pooled_output, expanded_word), dim = -1)
                concat1 = merge1(concat1)
                expanded_obj = expansion_obj(lang_pooled_output, vid = vid)
                concat2 = torch.cat((lang_pooled_output, do_output2), dim = -1)
                concat2 = merge2(concat2)
                dotoutput = torch.matmul(concat1, concat2.transpose(1,0))
                pred = np.argmax(-1*dotoutput.detach().cpu().numpy(), axis = -1)
                vid_matched_recipes = {frame_indices[j]:flatten_recipe[pred[j]] for j in range(len(pred))}
                matched_recipe[vid] = vid_matched_recipes

100%|██████████| 1090/1090 [16:57<00:00,  1.07it/s]
100%|██████████| 267/267 [04:26<00:00,  1.00it/s]


In [26]:
with open('dual_encoder_final_matched_result.json', 'w') as fp:
    json.dump(matched_recipe, fp)

In [28]:
with open('dual_encoder_final_matched_result.pkl', 'wb') as fp:
    pkl.dump(matched_recipe, fp)