In [1]:
#| default_exp 30_ngame-for-msmarco-with-hard-negatives

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.PPP0XX import DBT009,DBT011

In [4]:
os.environ['WANDB_MODE'] = 'disabled'

In [5]:
#| export
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['WANDB_PROJECT'] = 'mogicX_00-msmarco'

## Setup

In [6]:
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/30_ngame-for-msmarco-with-hard-negatives'
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
config_key = 'data_entity-gpt_exact'

mname = 'sentence-transformers/msmarco-distilbert-dot-v5'

In [7]:
do_train_inference = False
do_test_inference = False

save_train_inference = False
save_test_inference = False

save_representation = False

use_sxc_sampler, only_test = True, False

In [8]:
pkl_file = f'{pkl_dir}/mogicX/msmarco_data-meta_distilbert-base-uncased'
pkl_file = f'{pkl_file}_sxc' if use_sxc_sampler else f'{pkl_file}_xcs'
if only_test: pkl_file = f'{pkl_file}_only-test'
pkl_file = f'{pkl_file}.joblib'

In [9]:
do_inference = do_train_inference or do_test_inference or save_train_inference or save_test_inference or save_representation

In [10]:
%%time
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=False, only_test=False)

CPU times: user 14 s, sys: 1.99 s, total: 16 s
Wall time: 16.4 s


### `Negatives`

In [45]:
import pickle, scipy.sparse as sp, numpy as np
from tqdm.auto import tqdm
from typing import Optional, List

from sugar.core import *

In [19]:
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/negatives"
fname = f"{data_dir}/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl"

In [20]:
with open(fname, 'rb') as file:
    hard_negatives = pickle.load(file)

In [22]:
trn_ids = [int(i) for i in block.train.dset.data.data_info['identifier']]
tst_ids = [int(i) for i in block.test.dset.data.data_info['identifier']]

In [42]:
def load_msmarco_hard_negatives(fname:str, data_ids:Optional[List]=None):
    with open(fname, 'rb') as file:
        negatives = pickle.load(file)

    data_ids = list(negatives) if data_ids is None else data_ids

    lbl_id2idx = dict()
    data, indices, indptr = [], [], [0]
    for idx in tqdm(data_ids):
        if idx in negatives:
            data.extend(list(negatives[idx].values()))
            for i in negatives[idx]:
                index = lbl_id2idx.setdefault(i, len(lbl_id2idx))
                indices.append(index)
        indptr.append(len(data))

    lbl_ids = sorted(lbl_id2idx, key=lambda x: lbl_id2idx[x])
    return data_ids, lbl_ids, sp.csr_matrix((data, indices, indptr), dtype=np.float32)
    

In [44]:
data_ids, neg_ids, data_neg = load_msmarco_hard_negatives(fname, trn_ids)

  0%|          | 0/502939 [00:00<?, ?it/s]

In [47]:
lbl_neg = sp.csr_matrix((block.n_lbl, data_neg.shape[1]), dtype=np.float32)

In [54]:
sp.save_npz(f'{data_dir}/negatives_trn_X_Y.npz', data_neg)
sp.save_npz(f'{data_dir}/negatives_lbl_X_Y_exact.npz', lbl_neg)

In [55]:
fname = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/raw_data/label.raw.txt'
lbl_ids, lbl_txt = load_raw_file(fname)
lbl_map = {k:v for k,v in zip(lbl_ids, lbl_txt)}

In [57]:
neg_txt = [lbl_map[str(i)] for i in neg_ids]

In [64]:
save_raw_file(f'{data_dir}/raw_data/negatives.raw.txt', neg_ids, neg_txt)

### `Block`

In [12]:
config_file = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/configs/negatives_exact.json"
config_key = 'data_negatives_exact'

In [13]:
pkl_file = get_pkl_file(pkl_dir, 'msmarco_data-neg_distilbert-base-uncased', True, True)

In [14]:
pkl_file

'/scratch/scai/phd/aiz218323/datasets/processed//msmarco_data-neg_distilbert-base-uncased_sxc_exact.joblib'

In [None]:
%%time
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=True, only_test=False, n_sdata_meta_samples=10)

In [None]:
o = block.train.dset.__getitems__([10, 20, 30, 40])    

In [None]:
o

### `Model`

In [15]:
#| export
from typing import Optional
from xcai.models.PPP0XX import XCModelOutput

In [16]:
#| export
class DBT021(DBT009):

    def __init__(
        self,
        config,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,
        n_negatives:Optional[int]=10,
        **kwargs
    ):
        super().__init__(config, margin=margin, tau=tau, apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        self.loss_fn = None
    
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,

        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        
        neg2data_data2ptr:Optional[torch.Tensor]=None,
        neg2data_idx:Optional[torch.Tensor]=None,
        neg2data_input_ids:Optional[torch.Tensor]=None,
        neg2data_attention_mask:Optional[torch.Tensor]=None,
        
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
                                    output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states,
                                    return_dict=return_dict)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
                                                output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states,
                                                return_dict=return_dict)
            
            if neg2data_input_ids is not None:
                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
                                                    output_attentions=output_attentions, 
                                                    output_hidden_states=output_hidden_states,
                                                    return_dict=return_dict)
                loss = None
                
        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

### `loss`

In [20]:
from xcai.losses import MultiTripletFromScores

In [19]:
class MultiTripletWithNegatives(MultiTripletFromScores):

    def get_scores(self, inp:torch.Tensor, targ:torch.Tensor, ntarg:Optional[torch.Tensor]=None, n_inp2ntarg:Optional[int]=None):
        scores = inp @ targ.T
        if ntarg is not None:
            nscores = inp.unsqueeze(1) @ ntarg.view(len(inp), n_inp2ntarg, -1).transpose(1, 2)
            nscores = nscores.squeeze(1)
            scores = torch.hstack([scores, nscores])
        return scores

    def get_indices(self, inp2targ_idx:torch.Tensor, bsz:int, inp2ntarg_idx:Optional[torch.Tensor]=None, n_inp2ntarg:Optional[int]=None):
        inp2targ_idx = torch.repeat_interleave(inp2targ_idx.unsqueeze(0), bsz, 0)
        if inp2ntarg_idx is not None:
            inp2ntarg_idx = inp2ntarg_idx.view(bsz, n_inp2ntarg)
            inp2targ_idx = torch.hstack([inp2targ_idx, inp2ntarg_idx])
        return inp2targ_idx

    def forward(
        self, 
        inp:torch.FloatTensor,
        
        targ:torch.FloatTensor,
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,

        ntarg:torch.FloatTensor,
        n_inp2ntarg:torch.LongTensor,
        inp2ntarg_idx:torch.LongTensor,
        
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        assert torch.all(n_inp2ntarg == n_inp2ntarg.max()), "All datapoints should same number of negatives"
        scores = self.get_scores(inp, targ, ntarg, n_inp2ntarg.max())
        inp2targ_idx = self.get_indices(inp2targ_idx, len(inp), inp2ntarg_idx, n_inp2ntarg.max())
        return super().forward(scores, inp2targ_idx, n_pinp2targ=n_pinp2targ, pinp2targ_idx=pinp2targ_idx)
    

### Training

In [62]:
args = XCLearningArguments(
    output_dir=output_dir,
    logging_first_step=True,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=10,
    save_steps=10,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    adam_epsilon=1e-6,                                                                                                                                          warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-5,

    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',

    use_encoder_parallel=True,
    max_grad_norm=None,
    fp16=True,
)

In [63]:
metric = PrecReclMrr(block.test.dset.n_lbl, block.test.data_lbl_filterer,
                     pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

In [64]:
learn = XCLearner(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [65]:
dataloader = learn.get_train_dataloader()
batch = next(iter(dataloader))

In [71]:
batch.keys()

dict_keys(['data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'neg2data_idx', 'neg2data_data2ptr', 'neg2data_input_ids', 'neg2data_attention_mask'])

In [75]:
batch = batch.to(model.device)

In [74]:
def func():
    import pdb; pdb.set_trace()
    o = model(**batch)
    

In [76]:
func()

> [0;32m/tmp/ipykernel_19285/3800815954.py[0m(3)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mbatch[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(59)[0;36mforward[0;34m()[0m
[0;32m     57 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m    ):
[0m[0;32m---> 59 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(61)[0;36mforward[0;34m()[0m
[0;32m     59 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m---> 61 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m            [0mencoder[0m [0;34m=[0m [0mnn[0m[0;34m.[0m[0mDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(63)[0;36mforward[0;34m()[0m
[0;32m     61 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m            [0mencoder[0m [0;34m=[0m [0mnn[0m[0;34m.[0m[0mDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 63 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
[0m[0;32m     66 [0;31m                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(66)[0;36mforward[0;34m()[0m
[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
[0m[0;32m---> 66 [0;31m                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m                                    return_dict=return_dict)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(67)[0;36mforward[0;34m()[0m
[0;32m     65 [0;31m        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
[0m[0;32m     66 [0;31m                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 67 [0;31m                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m                                    return_dict=return_dict)
[0m[0;32m     69 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m                                    return_dict=return_dict)
[0m[0;32m     69 [0;31m[0;34m[0m[0m
[0m[0;32m     70 [0;31m        [0mloss[0m[0;34m,[0m [0mlbl2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
[0m[0;32m     66 [0;31m                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m                                    return_dict=return_dict)
[0m[0;32m     69 [0;31m[0;34m[0m[0m
[0m[0;32m---> 70 [0;31m        [0mloss[0m[0;34m,[0m [0mlbl2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
[0m


ipdb>  data_repr.shape


torch.Size([20, 768])


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(71)[0;36mforward[0;34m()[0m
[0;32m     69 [0;31m[0;34m[0m[0m
[0m[0;32m     70 [0;31m        [0mloss[0m[0;34m,[0m [0mlbl2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 71 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
[0m[0;32m     73 [0;31m                                                [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m        [0mloss[0m[0;34m,[0m [0mlbl2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
[0m[0;32m     73 [0;31m                                                [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m                                                [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
[0m[0;32m---> 73 [0;31m                                                [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m                                                [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                                                return_dict=return_dict)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(74)[0;36mforward[0;34m()[0m
[0;32m     72 [0;31m            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
[0m[0;32m     73 [0;31m                                                [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 74 [0;31m                                                [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                                                return_dict=return_dict)
[0m[0;32m     76 [0;31m            [0mneg2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m                                                [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m                                                [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m                                                return_dict=return_dict)
[0m[0;32m     76 [0;31m            [0mneg2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m            [0;32mif[0m [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m        [0mloss[0m[0;34m,[0m [0mlbl2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
[0m[0;32m     73 [0;31m                                                [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m                                                [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m                                                [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                                                return_dict=return_dict)
[0m[0;32m---> 76 [0;31m            [0mneg2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m            [0;32mif[0m [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(77)[0;36mforward[0;34m()[0m
[0;32m     75 [0;31m                                                return_dict=return_dict)
[0m[0;32m     76 [0;31m            [0mneg2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 77 [0;31m            [0;32mif[0m [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
[0m[0;32m     79 [0;31m                                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(78)[0;36mforward[0;34m()[0m
[0;32m     76 [0;31m            [0mneg2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m            [0;32mif[0m [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 78 [0;31m                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
[0m[0;32m     79 [0;31m                                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m                                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m            [0;32mif[0m [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
[0m[0;32m---> 79 [0;31m                                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m                                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m                                                    return_dict=return_dict)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(80)[0;36mforward[0;34m()[0m
[0;32m     78 [0;31m                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
[0m[0;32m     79 [0;31m                                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 80 [0;31m                                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m                                                    return_dict=return_dict)
[0m[0;32m     82 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m                                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m                                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m                                                    return_dict=return_dict)
[0m[0;32m     82 [0;31m[0;34m[0m[0m
[0m[0;32m     83 [0;31m                [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mneg2data_data2ptr[0m [0;34m==[0m [0mneg2data_data2ptr[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have equal negatives'[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(78)[0;36mforward[0;34m()[0m
[0;32m     76 [0;31m            [0mneg2data_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m            [0;32mif[0m [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 78 [0;31m                neg2data_o, neg2data_repr = encoder(neg2data_input_ids, neg2data_attention_mask,
[0m[0;32m     79 [0;31m                                                    [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m                                                    [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m                                                    return_dict=return_dict)
[0m[0;32m     82 [0;31m[0;34m[0m[0m
[0m[0;32m---> 83 [0;31m                [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mneg2data_data2ptr[0m [0;34m==[0m [0mneg2data_data2ptr[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have equal negatives'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m            [0mscores[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_scores[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m,[0m [0mneg2data_repr[0m[0;34m)[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0m_get_indices[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(85)[0;36mforward[0;34m()[0m
[0;32m     83 [0;31m                [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mneg2data_data2ptr[0m [0;34m==[0m [0mneg2data_data2ptr[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have equal negatives'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m---> 85 [0;31m            [0mscores[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_scores[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m,[0m [0mneg2data_repr[0m[0;34m)[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0m_get_indices[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     86 [0;31m            [0mloss[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mloss_fn[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0midx[0

ipdb>  data_repr.shape


torch.Size([20, 768])


ipdb>  lbl2data_repr.shape


torch.Size([20, 768])


ipdb>  neg2data_repr.shape


torch.Size([200, 768])


ipdb>  s


--Call--
> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(16)[0;36m_get_scores[0;34m()[0m
[0;32m     14 [0;31m                                                 reduce='mean')
[0m[0;32m     15 [0;31m[0;34m[0m[0m
[0m[0;32m---> 16 [0;31m    [0;32mdef[0m [0m_get_scores[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_repr[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m        [0mbsz[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m        [0mn_meta[0m [0;34m=[0m [0mneg2data_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m//[0m 

ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(17)[0;36m_get_scores[0;34m()[0m
[0;32m     15 [0;31m[0;34m[0m[0m
[0m[0;32m     16 [0;31m    [0;32mdef[0m [0m_get_scores[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_repr[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 17 [0;31m        [0mbsz[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m        [0mn_meta[0m [0;34m=[0m [0mneg2data_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m//[0m [0mbsz[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(18)[0;36m_get_scores[0;34m()[0m
[0;32m     16 [0;31m    [0;32mdef[0m [0m_get_scores[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_repr[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m        [0mbsz[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 18 [0;31m        [0mn_meta[0m [0;34m=[0m [0mneg2data_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m//[0m [0mbsz[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m[0;34m[0m[0m
[0m[0;32m     20 [0;31m        [0mlbl_scores[0m [0;34m=

ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(20)[0;36m_get_scores[0;34m()[0m
[0;32m     18 [0;31m        [0mn_meta[0m [0;34m=[0m [0mneg2data_repr[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m//[0m [0mbsz[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m[0;34m[0m[0m
[0m[0;32m---> 20 [0;31m        [0mlbl_scores[0m [0;34m=[0m [0mdata_repr[0m [0;34m@[0m [0mlbl2data_repr[0m[0;34m.[0m[0mT[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m     22 [0;31m        [0mneg_scores[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n_meta


10


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(22)[0;36m_get_scores[0;34m()[0m
[0;32m     20 [0;31m        [0mlbl_scores[0m [0;34m=[0m [0mdata_repr[0m [0;34m@[0m [0mlbl2data_repr[0m[0;34m.[0m[0mT[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m        [0mneg_scores[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m        [0;32mif[0m [0mneg2data_repr[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m@[0m [0mneg2data_repr[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mtranspose[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  lbl_scores.shape


torch.Size([20, 20])


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(23)[0;36m_get_scores[0;34m()[0m
[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m     22 [0;31m        [0mneg_scores[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 23 [0;31m        [0;32mif[0m [0mneg2data_repr[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m@[0m [0mneg2data_repr[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mtranspose[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mneg_scores[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(24)[0;36m_get_scores[0;34m()[0m
[0;32m     22 [0;31m        [0mneg_scores[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m        [0;32mif[0m [0mneg2data_repr[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 24 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m@[0m [0mneg2data_repr[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mtranspose[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mneg_scores[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m


ipdb>  neg2data_repr.view(bsz, n_meta, -1).transpose(1, 2).shape


torch.Size([20, 768, 10])


ipdb>  data_repr.unsqueeze(1).shape


torch.Size([20, 1, 768])


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(25)[0;36m_get_scores[0;34m()[0m
[0;32m     23 [0;31m        [0;32mif[0m [0mneg2data_repr[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m@[0m [0mneg2data_repr[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mtranspose[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mneg_scores[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m     27 [0;31m        [0;32mreturn[0m [0mlbl_scores[0m [0;32mif[0m [0mneg_scores[0m [0;32mis[0m [0

ipdb>  neg_scores.shape


torch.Size([20, 1, 10])


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(27)[0;36m_get_scores[0;34m()[0m
[0;32m     25 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mneg_scores[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m---> 27 [0;31m        [0;32mreturn[0m [0mlbl_scores[0m [0;32mif[0m [0mneg_scores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mtorch[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0mlbl_scores[0m[0;34m,[0m [0mneg_scores[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0;32mdef[0m [0m_get_indices[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mlbl2data_idx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[

ipdb>  neg_scores.shape


torch.Size([20, 10])


ipdb>  n


--Return--
tensor([[0.80...CatBackward0>)
> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(27)[0;36m_get_scores[0;34m()[0m
[0;32m     25 [0;31m            [0mneg_scores[0m [0;34m=[0m [0mneg_scores[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m---> 27 [0;31m        [0;32mreturn[0m [0mlbl_scores[0m [0;32mif[0m [0mneg_scores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mtorch[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0mlbl_scores[0m[0;34m,[0m [0mneg_scores[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0;32mdef[0m [0m_get_indices[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mlbl2data_idx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m

ipdb>  n


--Call--
> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(29)[0;36m_get_indices[0;34m()[0m
[0;32m     27 [0;31m        [0;32mreturn[0m [0mlbl_scores[0m [0;32mif[0m [0mneg_scores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mtorch[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0mlbl_scores[0m[0;34m,[0m [0mneg_scores[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m[0;34m[0m[0m
[0m[0;32m---> 29 [0;31m    [0;32mdef[0m [0m_get_indices[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mlbl2data_idx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m        [0mbsz[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mn_

ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(30)[0;36m_get_indices[0;34m()[0m
[0;32m     28 [0;31m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0;32mdef[0m [0m_get_indices[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mlbl2data_idx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m        [0mbsz[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mn_meta[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mneg2data_idx[0m[0;34m)[0m [0;34m//[0m [0mbsz[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(31)[0;36m_get_indices[0;34m()[0m
[0;32m     29 [0;31m    [0;32mdef[0m [0m_get_indices[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mlbl2data_idx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m        [0mbsz[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 31 [0;31m        [0mn_meta[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mneg2data_idx[0m[0;34m)[0m [0;34m//[0m [0mbsz[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m[0;34m[0m[0m
[0m[0;32m     33 [0;31m        [0mlbl_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mrepeat_interleave[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36

ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(33)[0;36m_get_indices[0;34m()[0m
[0;32m     31 [0;31m        [0mn_meta[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mneg2data_idx[0m[0;34m)[0m [0;34m//[0m [0mbsz[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m[0;34m[0m[0m
[0m[0;32m---> 33 [0;31m        [0mlbl_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mrepeat_interleave[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m,[0m [0mbsz[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m        [0mneg_idx[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mneg2data_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mneg2data_idx[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     35 [0;31m[0;34m[0m[0m
[0m


ipdb>  n_meta


10


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(34)[0;36m_get_indices[0;34m()[0m
[0;32m     32 [0;31m[0;34m[0m[0m
[0m[0;32m     33 [0;31m        [0mlbl_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mrepeat_interleave[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m,[0m [0mbsz[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 34 [0;31m        [0mneg_idx[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mneg2data_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mneg2data_idx[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     35 [0;31m[0;34m[0m[0m
[0m[0;32m     36 [0;31m        [0;32mreturn[0m [0mlbl_idx[0m [0;32mif[0m [0mneg_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mtorch[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0mlbl_idx[0m[0;34m,[0m [0mneg_id

ipdb>  lbl_idx.shape


torch.Size([20, 20])


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(36)[0;36m_get_indices[0;34m()[0m
[0;32m     34 [0;31m        [0mneg_idx[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mneg2data_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mneg2data_idx[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     35 [0;31m[0;34m[0m[0m
[0m[0;32m---> 36 [0;31m        [0;32mreturn[0m [0mlbl_idx[0m [0;32mif[0m [0mneg_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mtorch[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0mlbl_idx[0m[0;34m,[0m [0mneg_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m[0;34m[0m[0m
[0m[1;31m1[0;32m    38 [0;31m    def forward(
[0m


ipdb>  neg_idx.shape


torch.Size([20, 10])


ipdb>  n


--Return--
tensor([[ 451...vice='cuda:0')
> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(36)[0;36m_get_indices[0;34m()[0m
[0;32m     34 [0;31m        [0mneg_idx[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mneg2data_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mneg2data_idx[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mn_meta[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     35 [0;31m[0;34m[0m[0m
[0m[0;32m---> 36 [0;31m        [0;32mreturn[0m [0mlbl_idx[0m [0;32mif[0m [0mneg_idx[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mtorch[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0mlbl_idx[0m[0;34m,[0m [0mneg_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m[0;34m[0m[0m
[0m[1;31m1[0;32m    38 [0;31m    def forward(
[0m


ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(86)[0;36mforward[0;34m()[0m
[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m            [0mscores[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_scores[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m,[0m [0mneg2data_repr[0m[0;34m)[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0m_get_indices[0m[0;34m([0m[0mlbl2data_idx[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 86 [0;31m            [0mloss[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mloss_fn[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0midx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m[0;34m[0m[0m
[0m[0;32m     88 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  scores.shape


torch.Size([20, 30])


ipdb>  idx.shape


torch.Size([20, 30])


ipdb>  b self.loss_fn.forward


Breakpoint 2 at /scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py:341


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(353)[0;36mforward[0;34m()[0m
[0;32m    351 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    352 [0;31m    ):
[0m[0;32m--> 353 [0;31m        [0mstore_attr[0m[0;34m([0m[0;34m'margin,tau,apply_softmax,n_negatives'[0m[0;34m,[0m [0mis_none[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    354 [0;31m[0;34m[0m[0m
[0m[0;32m    355 [0;31m        [0;32massert[0m [0mscores[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(355)[0;36mforward[0;34m()[0m
[0;32m    353 [0;31m        [0mstore_attr[0m[0;34m([0m[0;34m'margin,tau,apply_softmax,n_negatives'[0m[0;34m,[0m [0mis_none[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    354 [0;31m[0;34m[0m[0m
[0m[0;32m--> 355 [0;31m        [0;32massert[0m [0mscores[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    356 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    357 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(356)[0;36mforward[0;34m()[0m
[0;32m    354 [0;31m[0;34m[0m[0m
[0m[0;32m    355 [0;31m        [0;32massert[0m [0mscores[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 356 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    357 [0;31m[0;34m[0m[0m
[0m[0;32m    358 [0;31m        mask = self.get_row_mask(inp2targ_idx.new_full((inp2targ_idx.size(0),), inp2targ_idx.size(1)), inp2targ_idx.flatten(), 
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(358)[0;36mforward[0;34m()[0m
[0;32m    356 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    357 [0;31m[0;34m[0m[0m
[0m[0;32m--> 358 [0;31m        mask = self.get_row_mask(inp2targ_idx.new_full((inp2targ_idx.size(0),), inp2targ_idx.size(1)), inp2targ_idx.flatten(), 
[0m[0;32m    359 [0;31m                                 n_pinp2targ, pinp2targ_idx)
[0m[0;32m    360 [0;31m        [0mmask[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(359)[0;36mforward[0;34m()[0m
[0;32m    357 [0;31m[0;34m[0m[0m
[0m[0;32m    358 [0;31m        mask = self.get_row_mask(inp2targ_idx.new_full((inp2targ_idx.size(0),), inp2targ_idx.size(1)), inp2targ_idx.flatten(), 
[0m[0;32m--> 359 [0;31m                                 n_pinp2targ, pinp2targ_idx)
[0m[0;32m    360 [0;31m        [0mmask[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(358)[0;36mforward[0;34m()[0m
[0;32m    356 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    357 [0;31m[0;34m[0m[0m
[0m[0;32m--> 358 [0;31m        mask = self.get_row_mask(inp2targ_idx.new_full((inp2targ_idx.size(0),), inp2targ_idx.size(1)), inp2targ_idx.flatten(), 
[0m[0;32m    359 [0;31m                                 n_pinp2targ, pinp2targ_idx)
[0m[0;32m    360 [0;31m        [0mmask[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(360)[0;36mforward[0;34m()[0m
[0;32m    358 [0;31m        mask = self.get_row_mask(inp2targ_idx.new_full((inp2targ_idx.size(0),), inp2targ_idx.size(1)), inp2targ_idx.flatten(), 
[0m[0;32m    359 [0;31m                                 n_pinp2targ, pinp2targ_idx)
[0m[0;32m--> 360 [0;31m        [0mmask[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m[0;34m[0m[0m
[0m[0;32m    362 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mscores[0m[0;34m[[0m[0mmask[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  mask


tensor([ True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, 

ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(362)[0;36mforward[0;34m()[0m
[0;32m    360 [0;31m        [0mmask[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m[0;34m[0m[0m
[0m[0;32m--> 362 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mscores[0m[0;34m[[0m[0mmask[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    363 [0;31m        [0mpos_n_inp2targ[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    364 [0;31m[0;34m[0m[0m
[0m


ipdb>  mask.shape


torch.Size([20, 30])


ipdb>  mask


tensor([[ True, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False],
        [False, False,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False,  True, Fals

ipdb>  mask[0]


tensor([ True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False],
       device='cuda:0')


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(363)[0;36mforward[0;34m()[0m
[0;32m    361 [0;31m[0;34m[0m[0m
[0m[0;32m    362 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mscores[0m[0;34m[[0m[0mmask[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 363 [0;31m        [0mpos_n_inp2targ[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    364 [0;31m[0;34m[0m[0m
[0m[0;32m    365 [0;31m        [0mpos_scores[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpos_scores[0m[0;34m,[0m [0mpos_n_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(365)[0;36mforward[0;34m()[0m
[0;32m    363 [0;31m        [0mpos_n_inp2targ[0m [0;34m=[0m [0mmask[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    364 [0;31m[0;34m[0m[0m
[0m[0;32m--> 365 [0;31m        [0mpos_scores[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpos_scores[0m[0;34m,[0m [0mpos_n_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    366 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    367 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(366)[0;36mforward[0;34m()[0m
[0;32m    364 [0;31m[0;34m[0m[0m
[0m[0;32m    365 [0;31m        [0mpos_scores[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpos_scores[0m[0;34m,[0m [0mpos_n_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 366 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    367 [0;31m[0;34m[0m[0m
[0m[0;32m    368 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(368)[0;36mforward[0;34m()[0m
[0;32m    366 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    367 [0;31m[0;34m[0m[0m
[0m[0;32m--> 368 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    369 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m*[0m [0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    370 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(369)[0;36mforward[0;34m()[0m
[0;32m    367 [0;31m[0;34m[0m[0m
[0m[0;32m    368 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 369 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m*[0m [0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    370 [0;31m[0;34m[0m[0m
[0m[0;32m    371 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(371)[0;36mforward[0;34m()[0m
[0;32m    369 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m*[0m [0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    370 [0;31m[0;34m[0m[0m
[0m[0;32m--> 371 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    372 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    373 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(372)[0;36mforward[0;34m()[0m
[0;32m    370 [0;31m[0;34m[0m[0m
[0m[0;32m    371 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 372 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    373 [0;31m[0;34m[0m[0m
[0m[0;32m    374 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(374)[0;36mforward[0;34m()[0m
[0;32m    372 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    373 [0;31m[0;34m[0m[0m
[0m[0;32m--> 374 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    375 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(375)[0;36mforward[0;34m()[0m
[0;32m    373 [0;31m[0;34m[0m[0m
[0m[0;32m    374 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 375 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    376 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0m

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(376)[0;36mforward[0;34m()[0m
[0;32m    374 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    375 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 376 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0mneg_incidence[0m[0;34m.[0m[0mgather[0

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(378)[0;36mforward[0;34m()[0m
[0;32m    376 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0mneg_incidence[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    377 [0;31m[0;34m[0m[0m
[0m[0;32m--> 378 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    379 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    380 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(379)[0;36mforward[0;34m()[0m
[0;32m    377 [0;31m[0;34m[0m[0m
[0m[0;32m    378 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 379 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    380 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    381 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(380)[0;36mforward[0;34m()[0m
[0;32m    378 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    379 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 380 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    381 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    382 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(381)[0;36mforward[0;34m()[0m
[0;32m    379 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    380 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 381 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    382 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    383 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(382)[0;36mforward[0;34m()[0m
[0;32m    380 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    381 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 382 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    383 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m*[0m[0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    384 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(383)[0;36mforward[0;34m()[0m
[0;32m    381 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    382 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 383 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m*[0m[0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    384 [0;31m[0;34m[0m[0m
[0m[0;32m    385 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(385)[0;36mforward[0;34m()[0m
[0;32m    383 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m*[0m[0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    384 [0;31m[0;34m[0m[0m
[0m[0;32m--> 385 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-9[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    386 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    387 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(386)[0;36mforward[0;34m()[0m
[0;32m    384 [0;31m[0;34m[0m[0m
[0m[0;32m    385 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-9[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 386 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    387 [0;31m[0;34m[0m[0m
[0m[0;32m    388 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(388)[0;36mforward[0;34m()[0m
[0;32m    386 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    387 [0;31m[0;34m[0m[0m
[0m[0;32m--> 388 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    389 [0;31m        [0;32melif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'sum'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    390 [0;31m        [0;32melse[0m[0;34m:[0m [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;

ipdb>  


--Return--
tensor(0.0284...eanBackward0>)
> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/losses.py[0m(388)[0;36mforward[0;34m()[0m
[0;32m    386 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    387 [0;31m[0;34m[0m[0m
[0m[0;32m--> 388 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    389 [0;31m        [0;32melif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'sum'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    390 [0;31m        [0;32melse[0m[0;34m:[0m [0;32mra

ipdb>  


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(88)[0;36mforward[0;34m()[0m
[0;32m     86 [0;31m            [0mloss[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mloss_fn[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0midx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m[0;34m[0m[0m
[0m[0;32m---> 88 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     89 [0;31m            [0mo[0m [0;34m=[0m [0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mlbl2data_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     90 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif[0m [0mloss[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mo[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_19285/1054349523.py[0m(92)[0;36mforward[0;34m()[0m
[0;32m     90 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif[0m [0mloss[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mo[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     91 [0;31m[0;34m[0m[0m
[0m[0;32m---> 92 [0;31m        return XCModelOutput(
[0m[0;32m     93 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     94 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m



In [173]:
batch = batch.to(model.device)
o = model(**batch)

## Driver

In [None]:
#| export
if __name__ == '__main__':
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/30_ngame-for-msmarco-with-hard-negatives'

    config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/negatives_exact.json.json'
    config_key = 'data_negatives_exact'
    
    mname = 'sentence-transformers/msmarco-distilbert-cos-v5'

    input_args = parse_args()

    pkl_file = f'{input_args.pickle_dir}/mogicX/msmarco_data-neg_distilbert-base-uncased'
    pkl_file = f'{pkl_file}_sxc' if input_args.use_sxc_sampler else f'{pkl_file}_xcs'
    if input_args.only_test: pkl_file = f'{pkl_file}_only-test'
    pkl_file = f'{pkl_file}.joblib'

    do_inference = input_args.do_train_inference or input_args.do_test_inference or input_args.save_train_prediction or input_args.save_test_prediction or input_args.save_representation

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, 
                        only_test=input_args.only_test, use_meta_distribution=True, meta_oversample=True, n_sdata_meta_samples=10)

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
    
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,
    
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
    
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    def model_fn(mname, bsz):
        model = DBT021.from_pretrained(mname, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10, 
                                       apply_softmax=True, use_encoder_parallel=True)
        return model
    
    def init_fn(model): 
        model.init_dr_head()

    metric = PrecReclMrr(block.test.dset.n_lbl, block.test.data_lbl_filterer,
                     pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

    model = load_model(args.output_dir, model_fn, {"mname": mname, "bsz": bsz}, init_fn, do_inference=do_inference, use_pretrained=input_args.use_pretrained)
    
    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    main(learn, input_args, n_lbl=block.test.dset.n_lbl)
    