# Losses

In [1]:
#| default_exp losses

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import functools, torch, torch.nn as nn, torch.nn.functional as F, pickle
from typing import MutableSequence, Union, Tuple

from fastcore.utils import *
from fastcore.meta import *

from xcai.torch_core import *
from xcai.core import *

In [4]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [5]:
import pickle, torch.autograd.profiler as profiler, copy
from xcai.block import *
from xcai.models.MMM0XX import *
from xcai.main import *

### `Data`

In [6]:
data_dir = '/Users/suchith720/Projects/data'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

pkl_dir = '/Users/suchith720/Projects/data/processed/'
pkl_file = f'{pkl_dir}/mogicX/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [7]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_slbl_samples=2, 
                    n_sdata_meta_samples=3, do_build=True)



In [8]:
block.train.dset.data.main_oversample = False

In [9]:
batch = block.train.one_batch(4)

In [194]:
batch = block.train.one_batch(10)

In [195]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [196]:
batch['lbl2data_data2ptr'], batch['plbl2data_data2ptr']

(tensor([2, 1, 2, 2, 2, 2, 1, 2, 2, 2]),
 tensor([ 3,  1,  2,  4,  4, 25,  1, 14, 10,  5]))

In [197]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [198]:
batch['cat2data_data2ptr'], batch['pcat2data_data2ptr']

(tensor([3, 3, 3, 3, 1, 3, 3, 1, 2, 3]),
 tensor([ 8,  9, 38, 10,  1,  9, 11,  1,  2,  4]))

### `Model`

In [206]:
import types
from xcai.models.PPP0XX import DBTConfig, DBT009

In [207]:
config = DBTConfig(
    margin = 0.3,
    num_negatives = 10,
    tau = 0.1,
    apply_softmax = True,
    reduction = "mean",

    normalize = True,
    use_layer_norm = True,
    
    use_encoder_parallel = False,
    loss_function = "triplet"
)

In [208]:
model = DBT009.from_pretrained('distilbert-base-uncased', config=config)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [209]:
def forward(self, *args, **kwargs): return 1.0
model.loss_fn.forward = types.MethodType(forward, model.loss_fn)

In [210]:
output = model(**batch.to(model.device))
inp, pos_targ, targ = output.data_repr, output.lbl2data_repr, output.lbl2data_repr

In [211]:
output = model(data_input_ids=batch['data_input_ids'], data_attention_mask=batch['data_attention_mask'], 
               lbl2data_data2ptr=batch['cat2data_data2ptr'], lbl2data_idx=batch['cat2data_idx'], 
               lbl2data_input_ids=batch['cat2data_input_ids'], lbl2data_attention_mask=batch['cat2data_attention_mask'], 
               plbl2data_data2ptr=batch['pcat2data_data2ptr'], plbl2data_idx=batch['pcat2data_idx'])
_, neg_targ = output.data_repr, output.lbl2data_repr

In [212]:
pos_targ.shape, neg_targ.shape

(torch.Size([18, 768]), torch.Size([25, 768]))

In [213]:
pos_idx, n_pos = batch['lbl2data_idx'], batch['lbl2data_data2ptr']
ppos_idx, n_ppos = batch['plbl2data_idx'], batch['plbl2data_data2ptr']

neg_idx, n_neg = batch['cat2data_idx'], batch['cat2data_data2ptr']

In [214]:
inp2targ_idx, n_inp2targ = batch['lbl2data_idx'], batch['lbl2data_data2ptr']
pinp2targ_idx, n_pinp2targ = batch['plbl2data_idx'], batch['plbl2data_data2ptr']

In [215]:
pos_idx, n_pos, pos_idx.shape, n_pos.shape

(tensor([    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102]),
 tensor([2, 1, 2, 2, 2, 2, 1, 2, 2, 2]),
 torch.Size([18]),
 torch.Size([10]))

In [216]:
ppos_idx, n_ppos, ppos_idx.shape, n_ppos.shape

(tensor([    0,     1,     2,     3,     9, 26766,    12,    13,    14,    15,
            16,    17,    18, 56258,    19,    20,    21,    22,    23,    24,
            25,    26,    27,    28,    29,    30,    31,    32,    33,    34,
            35,    36,    37,    38,    39,    40,    41,    42, 10243,    45,
            48,    49,    50,    51,    52,    53,    54,    55,    56,    57,
            58,    59,    60,    61,    62,    63,    64,    65,    66,    67,
            68,    69,    70, 81953,   101,   102,   103,   104,   105]),
 tensor([ 3,  1,  2,  4,  4, 25,  1, 14, 10,  5]),
 torch.Size([69]),
 torch.Size([10]))

In [217]:
neg_idx, n_neg, neg_idx.shape, n_neg.shape

(tensor([130669, 130668,  94657,  56026,  79350, 165108, 131288, 131262, 144686,
          71496,   3056,  71494, 144199,  60100,  52253,  77773,  53995,  65499,
          96896, 143634,  83602,  86585,  82231, 138912,  28165]),
 tensor([3, 3, 3, 3, 1, 3, 3, 1, 2, 3]),
 torch.Size([25]),
 torch.Size([10]))

## Helper

In [218]:
#| export
def get_sparse_matrix(data_idx:torch.Tensor, n_data:torch.Tensor, scores:Optional[torch.Tensor]=None, 
                      size:Optional[Tuple]=None):
    data_ptr = torch.cat([torch.zeros(1, device=n_data.device, dtype=n_data.dtype), n_data.cumsum(0)])
    if scores is None: scores = torch.ones_like(data_idx)
    if data_idx.shape != scores.shape: raise ValueError(f'`data_idx` and `scores` should have same shape.')
    return (
        torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)
        if size is None else
        torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device, size=size)
    )
    

In [219]:
#| export
def mix_classes(class_a, class_b, name=None):
    name = name or f"{class_a.__name__}And{class_b.__name__}"
    return type(name, (class_a, class_b), {})
    

## BaseLoss

In [29]:
#| export
class BaseLoss(nn.Module):

    def __init__(
        self, 
        reduce:Optional[str]=None, 
        **kwargs
    ):
        super().__init__()
        self.reduce = reduce

    @property
    def reduction(self) -> str: return self.reduce
    
    @reduction.setter
    def reduction(self, v:str):
        "Sets the reduction style (typically 'mean', 'sum', or 'none')" 
        self.reduce = v
        

## MultiCrossEntropy

In [177]:
#| export
class MultiCrossEntropy(BaseLoss):

    def __init__(
        self,
        tn_targ:Optional[int]=None, 
        ig_tok:Optional[int]=0,
        vocab_weights:Optional[torch.Tensor]=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('tn_targ,ig_tok,vocab_weights')
        self.o = torch.ones(tn_targ, dtype=torch.int64) if tn_targ is not None else None
        

In [None]:
vocab_weights = torch.rand(m.config.vocab_size)
mce_fn = MultiCrossEntropy(1000, vocab_weights=vocab_weights, reduce='mean')

### `Method 1`

In [178]:
#| export
@patch
def forward(
    cls:MultiCrossEntropy,
    inp:torch.FloatTensor,
    targ:torch.LongTensor,
    n_inp2targ:Optional[torch.LongTensor]=None,
    tn_targ:Optional[int]=None, 
    ig_tok:Optional[int]=None,
    vocab_weights:Optional[torch.Tensor]=None,
    **kwargs
):
    store_attr('tn_targ,ig_tok,vocab_weights', is_none=False)
    
    cls.o = cls.o.to(inp.device) if cls.o is not None else None
    cls.vocab_weights = cls.vocab_weights.to(inp.device) if cls.vocab_weights is not None else None
    
    tn_targ, targ_len = targ.shape
    bsz, inp_len, vocab_sz = inp.shape
    
    if cls.vocab_weights is not None and cls.vocab_weights.shape[0] != vocab_sz: 
        raise ValueError(f"`vocab_weights` should have {vocab_sz} elements.")
    
    seq_len = min(targ_len, inp_len)
    inp, targ = -F.log_softmax(inp, dim=2)[:, :seq_len].transpose(1,2), targ[:, :seq_len]
    if cls.vocab_weights is not None: inp *= cls.vocab_weights.unsqueeze(1)
    
    if n_inp2targ is not None:
        mn_targ = n_inp2targ.max()
    
        inp2targ_ptr = n_inp2targ.cumsum(dim=0)-1
        xn_inp2targ = mn_targ-n_inp2targ+1
        r_targ = (
            torch.ones(tn_targ, dtype=torch.int64, device=inp.device).scatter(0, inp2targ_ptr, xn_inp2targ)
            if cls.tn_targ is None or tn_targ > cls.tn_targ else
            cls.o[:tn_targ].scatter(0, inp2targ_ptr, xn_inp2targ)
        )
        xtarg = targ.repeat_interleave(r_targ, dim=0)
        s = inp.gather(1, xtarg.view(bsz, -1, seq_len)).view(-1, seq_len)
        s /= r_targ.repeat_interleave(r_targ, dim=0).view(-1, 1)
    else:
        if bsz != tn_targ: raise ValueError("`inp` and `targ` should have same number of elements as `n_inp2targ` is empty.")
        s = inp.gather(1, targ.view(bsz, -1, seq_len)).view(-1, seq_len); xtarg = targ
    
    idx = torch.where(xtarg != cls.ig_tok)
    loss = s[idx[0], idx[1]]
    
    if cls.reduction == 'mean': return (loss/len(torch.where(targ != cls.ig_tok)[0])).sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')


In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(7.8086, device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
loss = mce_fn(lbl2data_logits, data_input_ids); loss

tensor(9.0191, device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
    print(loss)

tensor(16.0431, device='cuda:0', grad_fn=<SumBackward0>)


STAGE:2024-04-26 08:09:57 4294:4294 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-26 08:09:57 4294:4294 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-26 08:09:57 4294:4294 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::max        15.76%     398.000us        17.19%     434.000us     434.000us           0 b           0 b         512 b           0 b             1  
                                            aten::empty         0.67%      17.000us         0.67%      17.000us      17.000us           0 b           0 b         512 b         512 b             1  
         

### `Method 2`

In [179]:
@patch
def forward(
    cls:MultiCrossEntropy, 
    inp:torch.FloatTensor, 
    targ:torch.LongTensor, 
    n_inp2targ:torch.LongTensor, 
    **kwargs
):
    inp_len, targ_len = inp.shape[1], targ.shape[1]
    seq_len = min(inp_len, targ_len)
    
    inp, targ = -F.log_softmax(inp, dim=2)[:, :seq_len], targ[:, :seq_len].unsqueeze(2)
    if cls.vocab_weights is not None: inp *= cls.vocab_weights
    inp = inp.repeat_interleave(n_inp2targ, dim=0)
    
    s = inp.gather(2, targ)
    idx = torch.where(targ != cls.ig_tok)
    loss = s[idx[0], idx[1]]
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(7.8086, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
loss = mce_fn(lbl2data_logits, data_input_ids, torch.ones(len(data_input_ids), dtype=data_input_ids.dtype, device=data_input_ids.device)); loss

tensor(9.0191, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
    print(loss)

STAGE:2024-05-05 16:20:26 189949:189949 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


tensor(14.2102, grad_fn=<MeanBackward0>)


STAGE:2024-05-05 16:20:28 189949:189949 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-05 16:20:28 189949:189949 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::log_softmax         0.18%       5.000us        15.61%     441.000us     441.000us           0 b           0 b     112.00 Mb           0 b             1  
                                     aten::_log_softmax        14.86%     420.000us        15.43%     436.000us     436.000us           0 b           0 b     112.00 Mb     112.00 Mb             1  
         

### `Method 3`

In [180]:
@patch
def forward(
    cls:MultiCrossEntropy, 
    inp:torch.FloatTensor, 
    targ:torch.LongTensor, 
    n_inp2targ:torch.LongTensor, 
    **kwargs
):
    inp_len, targ_len = inp.shape[1], targ.shape[1]
    seq_len = min(inp_len, targ_len)
    inp, targ = -F.log_softmax(inp, dim=2)[:, :seq_len], targ[:, :seq_len]
    num, s = 0, []
    for i,n in zip(inp, n_inp2targ):
        for _ in range(n):
            s.append(i.gather(1, targ[num].view(-1, 1)).view(1, -1))
            num += 1
    s = torch.vstack(s)
    idx = torch.where(targ != cls.ig_tok)
    loss = s[idx[0], idx[1]]
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
    

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
    print(loss)

In [None]:
print(prof)

<unfinished torch.autograd.profile>


## `Calibration`

In [103]:
#| export
class Calibration(BaseLoss):

    def __init__(
        self,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        n_negatives:Optional[int]=10,
        apply_softmax:Optional[bool]=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('margin,tau,n_negatives,apply_softmax')
        

In [104]:
#| export
@patch
def forward(
    cls:Calibration,
    einp:torch.FloatTensor,
    inp:torch.FloatTensor, 
    targ:torch.LongTensor, 
    n_inp2targ:torch.LongTensor,
    inp2targ_idx:torch.LongTensor,
    n_pinp2targ:torch.LongTensor,
    pinp2targ_idx:torch.LongTensor,
    margin:Optional[float]=None,
    tau:Optional[float]=None,
    n_negatives:Optional[int]=None,
    apply_softmax:Optional[bool]=None,
    **kwargs
):
    store_attr('margin', is_none=False)

    einp, inp, targ = einp.float(), inp.float(), targ.float()
    esc, sc = einp@targ.T, inp@targ.T
    
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    pos = get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ, size=(len(n_pinp2targ), idx.max()+1)).to_dense()[:, idx[:len(inp2targ_idx)]]

    mul = 2 * pos - 1
    loss = F.relu((sc-esc) * mul + cls.margin)

    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[1]), dim=1, largest=True)
        esc, sc, mul = esc.gather(1, idx), sc.gather(1, idx), mul.gather(1, idx)
    
    if cls.apply_softmax:
        m = loss != 0
        s = torch.where(mul == 1, sc, esc)
        p = s/cls.tau * m
        p = torch.softmax(p, dim=1)
        loss = loss*p
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### Example

In [None]:
loss_fn = Calibration(0.3, reduce='mean')

In [None]:
loss = loss_fn(data_repr+torch.randn(data_repr.shape), data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
               kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx']); loss

tensor(0.8689, grad_fn=<MeanBackward0>)

## `MultiTriplet (deprecated)`

In [183]:
class MultiTriplet(BaseLoss):

    def __init__(
        self,
        bsz:Optional[int]=None, 
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.8,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('bsz,tn_targ,margin,tau,apply_softmax,n_negatives')
        self.u = torch.arange(bsz, dtype=torch.int64) if bsz is not None else None
        self.v = torch.ones(tn_targ, dtype=torch.int64) if tn_targ is not None else None
        

In [None]:
mtl_fn = MultiTriplet(bsz, 10_000, 0.8, tau=0.1, n_negatives=5, apply_softmax=True, reduce='mean')

### `Method 1`

In [184]:
@patch
def forward(cls:MultiTriplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    
    cls.u = cls.u.to(inp.device) if cls.u is not None else None
    cls.v = cls.v.to(inp.device) if cls.v is not None else None
    
    bsz, tn_targ, mn_targ = inp.shape[0], targ.shape[0], n_inp2targ.max()
    u = torch.arange(bsz, dtype=torch.int64, device=inp.device) if cls.u is None or cls.bsz < bsz else cls.u[:bsz]
    v = (
        torch.ones(tn_targ, dtype=torch.int64, device=targ.device)
        if cls.tn_targ is None or tn_targ > cls.tn_targ else cls.v[:tn_targ]
    )
    
    targ2inp_ptr = u.repeat_interleave(n_inp2targ)
    sc = targ@inp.T
    ps = sc.gather(1, targ2inp_ptr.view(-1,1))
    
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    ne = ne.unsqueeze(1)
    
    inp2targ_ptr = n_inp2targ.cumsum(dim=0)-1
    xn_inp2targ = mn_targ-n_inp2targ+1
    r_targ = v.scatter(0, inp2targ_ptr, xn_inp2targ)

    psx = ps.repeat_interleave(r_targ).view(bsz, -1, 1)
    sc = sc.T.view(bsz, 1, -1)
    loss = F.relu((sc - psx + cls.margin)*ne)
    
    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[2]-n_inp2targ.max()), dim=2, largest=True)
        sc, ne = sc.expand(-1, mn_targ, -1).gather(2, idx), ne.expand(-1, mn_targ, -1).gather(2, idx)
    else: ne = ne.expand(-1, mn_targ, -1)
    
    if cls.apply_softmax:
        m = loss != 0
        p = sc/cls.tau * m
        p[ne == 0] = torch.finfo(p.dtype).min
        p = torch.softmax(p, dim=2)
        loss = loss*p
        
    loss /= (ne.sum(dim=2, keepdim=True) + 1e-9)
    
    xr_targ = r_targ.repeat_interleave(r_targ).view(bsz, -1, 1)
    loss /= xr_targ
    
    if cls.reduction == 'mean': return loss.sum()/n_inp2targ.sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, kwargs['plbl2data_data2ptr'],
              kwargs['plbl2data_idx'], n_negatives=5); loss

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


tensor(1.2115, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx)
    print(loss)

In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::max         0.95%     519.000us         1.11%     610.000us     610.000us           0 b           0 b         512 b           0 b             1  
                                            aten::empty         0.08%      43.000us         0.08%      43.000us      43.000us           0 b           0 b         512 b         512 b             1  
         

### `Method 2`

This the below function is not upto date, it has errors in it.

In [185]:
@patch
def forward(
    cls:MultiTriplet, 
    inp:torch.FloatTensor, 
    targ:torch.LongTensor, 
    n_inp2targ:torch.LongTensor,
    inp2targ_idx:torch.LongTensor,
    margin:Optional[float]=None,
    tau:Optional[float]=None,
    apply_softmax:Optional[bool]=None,
    n_negatives:Optional[int]=None,
    **kwargs
):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    bsz, tn_targ, mn_targ = inp.shape[0], targ.shape[0], n_inp2targ.max()
    u = torch.arange(bsz, dtype=torch.int64, device=inp.device) if cls.u is None or cls.bsz < bsz else cls.u[:bsz]
    v = (
        torch.ones(tn_targ, dtype=torch.int64, device=targ.device)
        if cls.tn_targ is None or tn_targ > cls.tn_targ else cls.v[:tn_targ]
    )
    
    targ2inp_ptr = u.repeat_interleave(n_inp2targ)
    sc = targ@inp.T
    ps = sc.gather(1, targ2inp_ptr.view(-1,1))
    
    _, idx = torch.unique(inp2targ_idx, return_inverse=True)
    ne = 1 - get_sparse_matrix(idx, n_inp2targ).to_dense()[:, idx]
    ne = ne.unsqueeze(1)
    
    inp2targ_ptr = n_inp2targ.cumsum(dim=0)-1
    xn_inp2targ = mn_targ-n_inp2targ+1
    r_targ = v.scatter(0, inp2targ_ptr, xn_inp2targ)

    psx = ps.repeat_interleave(r_targ).view(bsz, -1, 1)
    sc = sc.T.view(bsz, 1, -1)
    loss = torch.clamp((sc - psx + cls.margin)*ne, 0)
    
    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[2]-n_inp2targ.max()), dim=2, largest=True)
        sc, ne = sc.expand(-1, mn_targ, -1).gather(2, idx), ne.expand(-1, mn_targ, -1).gather(2, idx)
    else: ne = ne.expand(-1, mn_targ, -1)
    
    if cls.apply_softmax:
        m = loss != 0
        p = sc/cls.tau * m
        p[ne == 0] = -1e9
        p = torch.softmax(p, dim=2)
        loss = loss*p
        
    loss /= (ne.sum(dim=2, keepdim=True) + 1e-9)
    
    xr_targ = r_targ.repeat_interleave(r_targ).view(bsz, -1, 1)
    loss /= xr_targ
    
    if cls.reduction == 'mean': return loss.sum()/n_inp2targ.sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### `Method 3`

In [186]:
@patch
def forward(
    cls:MultiTriplet, 
    inp:torch.FloatTensor, 
    targ:torch.LongTensor, 
    n_inp2targ:torch.LongTensor, 
    inp2targ_idx:torch.LongTensor,
    margin:Optional[float]=None,
    tau:Optional[float]=0.1,
    apply_softmax=False,
    **kwargs
):
    store_attr('margin,tau,apply_softmax', is_none=False)
    score = inp@targ.T
    ptr, loss, num_ne = 0, [], 0
    
    _, idx = torch.unique(inp2targ_idx, return_inverse=True)
    ne = 1 - get_sparse_matrix(idx, n_inp2targ).to_dense()[:, idx]
    
    for i, n in enumerate(n_inp2targ):
        ps = score[i, ptr:ptr+n].view(-1, 1)
        fs = torch.clamp((score[i] - ps + cls.margin)*ne[i], 0)
        if cls.apply_softmax: 
            m = fs != 0
            p = torch.softmax(score[i]/self.tau * m, dim=1)
            fs = fs*p
        loss.append(fs)
        ptr += n.item()
        num_ne += n*ne[i].sum()
    loss = torch.vstack(loss)
    if cls.reduction == 'mean': return loss.sum()/num_ne
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
             

In [None]:
loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx); loss

tensor(0.7512, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr)
    print(loss)

STAGE:2024-04-19 01:54:10 13866:13866 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-19 01:54:10 13866:13866 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-19 01:54:10 13866:13866 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


tensor(1.4957, device='cuda:0', grad_fn=<MeanBackward0>)


In [None]:
print(prof)

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                    aten::numpy_T         0.00%       5.000us         0.03%      53.000us      53.000us           0 b           0 b           0 b           0 b             1  
                                    aten::permute         0.03%      38.000us         0.03%      48.000us      48.000us           0 b           0 b           0 b           0 b             1  
                                 aten::

## `LossOperations`

In [43]:
#| export
class LossOperations:

    # BaseMultiTriplet
    
    @staticmethod
    def align_indices(indices:torch.Tensor, group_lengths:torch.Tensor):
        n, num_groups, max_len = len(indices), len(group_lengths), group_lengths.max()
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=indices.device), group_lengths)
    
        row_indices = torch.arange(n, device=indices.device)
    
        group_start = torch.cat([torch.zeros(1, dtype=group_lengths.dtype, device=group_lengths.device), group_lengths.cumsum(0)[:-1]], dim=0)
    
        within_idx = row_indices - group_start[group_ids]
    
        output = torch.zeros((num_groups, max_len), dtype=indices.dtype, device=indices.device)
        mask = torch.zeros((num_groups, max_len), device=indices.device)
        output[group_ids, within_idx] = indices
        mask[group_ids, within_idx] = 1.0
    
        return output, mask

    @staticmethod
    def remove_redundant_indices(inp2targ_idx:torch.Tensor, n_inp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor):
        mask = torch.isin(pinp2targ_idx, inp2targ_idx)
        new_pinp2targ_idx = pinp2targ_idx[mask]
    
        num_groups = len(n_pinp2targ)
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=n_pinp2targ.device), n_pinp2targ)
        new_n_pinp2targ = torch.bincount(group_ids[mask], minlength=num_groups)
    
        return new_pinp2targ_idx, new_n_pinp2targ

    @staticmethod
    def reset_indices(inp2targ_idx:torch.Tensor, n_inp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor):
        _, reset_indices, counts = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True, return_counts=True)
    
        _, idx_sorted = torch.sort(reset_indices, stable=True)
        cum_sum = torch.cat((torch.zeros((1,), dtype=counts.dtype, device=counts.device), counts.cumsum(0)[:-1]))
        indices = idx_sorted[cum_sum]
    
        inp2targ_idx = reset_indices[:len(inp2targ_idx)]
        pinp2targ_idx = reset_indices[len(inp2targ_idx):]
    
        return inp2targ_idx, pinp2targ_idx, indices

    # MultiTripletFromScores
    
    @staticmethod
    def get_incidence(n_inp2targ:int, inp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor):
        row_idx = torch.arange(len(n_pinp2targ), device=n_pinp2targ.device)
        inp2targ_row_idx = row_idx.repeat_interleave(n_inp2targ)
        pinp2targ_row_idx = row_idx.repeat_interleave(n_pinp2targ)
        
        max_col_idx = max(int(inp2targ_idx.max()), int(pinp2targ_idx.max()))
        offset = max_col_idx + 1
    
        inp2targ_keys = inp2targ_row_idx * offset + inp2targ_idx
        pinp2targ_keys = pinp2targ_row_idx * offset + pinp2targ_idx
    
        return torch.isin(inp2targ_keys, pinp2targ_keys)

    @staticmethod
    def get_pos_scores(scores:torch.FloatTensor, n_inp2targ:torch.FloatTensor):
        row_idx = torch.arange(len(n_inp2targ), device=n_inp2targ.device)
        inp2targ_row_idx = row_idx.repeat_interleave(n_inp2targ)
        inp2targ_col_idx = torch.arange(n_inp2targ.sum(), device=n_inp2targ.device)
        return scores[inp2targ_row_idx, inp2targ_col_idx]

    # BaseWithNegatives
    
    @staticmethod
    def get_scores(inp:torch.Tensor, pos_targ:torch.Tensor, neg_targ:Optional[torch.Tensor]=None, n_neg:Optional[int]=None):
        scores = inp @ pos_targ.T
        if neg_targ is not None:
            neg_scores = inp.unsqueeze(1) @ neg_targ.view(len(inp), n_neg, -1).transpose(1, 2)
            neg_scores = neg_scores.squeeze(1)
            scores = torch.hstack([scores, neg_scores])
        return scores

    @staticmethod
    def get_indices(pos_idx:torch.Tensor, bsz:int, neg_idx:Optional[torch.Tensor]=None, n_neg:Optional[int]=None):
        indices = torch.repeat_interleave(pos_idx.unsqueeze(0), bsz, 0)
        if neg_idx is not None:
            neg_idx = neg_idx.view(bsz, n_neg)
            indices = torch.hstack([indices, neg_idx])
        return indices

    # MultiRankingFromScores
    
    @staticmethod
    def masked_inclusive_topk(values:torch.Tensor, mask:torch.Tensor, k:int):
        biased_vals = torch.where(mask.bool(), torch.finfo(values.dtype).max, values)
        
        provisional_vals, provisional_idx = torch.topk(biased_vals, k, dim=-1)
        topk_vals = values.gather(1, provisional_idx)
        
        sorted_vals, sort_idx = torch.sort(topk_vals, descending=True, dim=-1)
        sorted_idx = provisional_idx.gather(1, sort_idx)
        
        return sorted_vals, sorted_idx
        

## `MultiTriplet`

In [254]:
#| export
class BaseMultiTriplet(BaseLoss, LossOperations):

    def __init__(
        self,
        margin:Optional[float]=0.8,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')

    def compute_scores(self, inp, targ, indices=None):
        if indices is not None: targ = targ[indices]
        return inp@targ.T

    def forward(
        self, 
        
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,

        inp:Optional[torch.FloatTensor]=None, 
        targ:Optional[torch.FloatTensor]=None,
        scores:Optional[torch.FloatTensor]=None,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)

        inp, targ = inp.float(), targ.float()
        scores = scores if scores is None else scores.float()
        
        pinp2targ_idx, n_pinp2targ = self.remove_redundant_indices(inp2targ_idx, n_inp2targ, pinp2targ_idx, n_pinp2targ)
        inp2targ_idx, pinp2targ_idx, indices = self.reset_indices(inp2targ_idx, n_inp2targ, pinp2targ_idx, n_pinp2targ)

        scores = self.compute_scores(inp, targ, indices=indices) if scores is None else scores[:, indices]

        pos_indices, pos_mask = self.align_indices(inp2targ_idx, n_inp2targ)
        pos_scores = scores.gather(1, pos_indices)

        pos_incidence = torch.zeros_like(scores)
        
        ppos_indices, ppos_mask = self.align_indices(pinp2targ_idx, n_pinp2targ)
        pos_incidence = pos_incidence.scatter(1, ppos_indices, 1)

        ppos_indices[~ppos_mask.bool()] = -1
        row_idx = torch.where(torch.all(ppos_indices != 0, dim=1))[0]
        pos_incidence[row_idx, 0] = 0
        
        neg_incidence = 1 - pos_incidence

        loss = scores.unsqueeze(1) - pos_scores.unsqueeze(2) + self.margin
        loss = F.relu(loss * neg_incidence.unsqueeze(1))

        scores = scores.unsqueeze(1).expand_as(loss)
        neg_incidence = neg_incidence.unsqueeze(1).expand_as(loss)

        if self.n_negatives is not None:
            loss, idx = torch.topk(loss, min(self.n_negatives, loss.shape[2]), dim=2, largest=True)
            scores, neg_incidence = scores.gather(2, idx), neg_incidence.gather(2, idx)

        if self.apply_softmax:
            mask = loss != 0
            penalty = scores / self.tau * mask
            penalty[neg_incidence == 0] = torch.finfo(penalty.dtype).min
            penalty = torch.softmax(penalty, dim=2)
            loss = loss * penalty
            
        loss /= (neg_incidence.sum(dim=2, keepdim=True) + 1e-6)
        loss /= (n_inp2targ.unsqueeze(1).unsqueeze(1) + 1e-6)
        loss = loss[pos_mask.bool()].sum()
        
        if self.reduction == 'mean': return loss/len(n_inp2targ)
        elif self.reduction == 'sum': return loss
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
    

In [255]:
#| export
class MultiTriplet(BaseMultiTriplet):

    def forward(
        self, 
        inp:torch.FloatTensor, # bs x dim
        targ:torch.FloatTensor, # total labels in batch (t) x dim
        n_inp2targ:torch.LongTensor, # bs x dim (like indptr in sp.csr_matrix)
        inp2targ_idx:torch.LongTensor, # t x dim (index of label)
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
                               apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        

In [222]:
#| export
class MultiTripletFromInBatchScores(BaseMultiTriplet):

    def forward(
        self, 
        scores:torch.FloatTensor,  
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, scores=scores, margin=margin, tau=tau, 
                               apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        

### Example

In [256]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [251]:
loss_fn = MultiTriplet(margin, tau, apply_softmax, n_negatives, reduce='mean')

In [252]:
loss = loss_fn(inp, targ, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)

In [253]:
loss

tensor(0.0221, grad_fn=<DivBackward0>)

In [258]:
def func():
    import pdb; pdb.set_trace()
    loss = loss_fn(inp, targ, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)
    

In [259]:
func()

> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3723332974.py[0m(2)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    [0mloss[0m [0;34m=[0m [0mloss_fn[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mtarg[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0minp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  b loss_fn.forward


Breakpoint 4 at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py:18


ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py[0m(18)[0;36mforward[0;34m()[0m
[0;32m     16 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m--> 18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[0;32m     19 [0;31m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negatives[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  n_inp2targ


tensor([2, 1, 2, 2, 2, 2, 1, 2, 2, 2])


ipdb>  inp2targ_idx


tensor([    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
           42,    45,    51,    52,    66,    67,   105,   102])


ipdb>  n_pinp2targ


tensor([ 3,  1,  2,  4,  4, 25,  1, 14, 10,  5])


ipdb>  pinp2targ_idx


tensor([    0,     1,     2,     3,     9, 26766,    12,    13,    14,    15,
           16,    17,    18, 56258,    19,    20,    21,    22,    23,    24,
           25,    26,    27,    28,    29,    30,    31,    32,    33,    34,
           35,    36,    37,    38,    39,    40,    41,    42, 10243,    45,
           48,    49,    50,    51,    52,    53,    54,    55,    56,    57,
           58,    59,    60,    61,    62,    63,    64,    65,    66,    67,
           68,    69,    70, 81953,   101,   102,   103,   104,   105])


ipdb>  l


[1;32m     13 [0m        [0mtau[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m     14 [0m        [0mapply_softmax[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mbool[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m     15 [0m        [0mn_negatives[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m     16 [0m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[1;32m     17 [0m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;31m4[0;32m--> 18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[1;32m     19 [0m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negati

ipdb>  s


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py[0m(19)[0;36mforward[0;34m()[0m
[0;32m     16 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m    18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[0;32m---> 19 [0;31m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negatives[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py[0m(18)[0;36mforward[0;34m()[0m
[0;32m     16 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m--> 18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[0;32m     19 [0;31m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negatives[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  s


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py[0m(19)[0;36mforward[0;34m()[0m
[0;32m     16 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m    18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[0;32m---> 19 [0;31m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negatives[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py[0m(18)[0;36mforward[0;34m()[0m
[0;32m     16 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m--> 18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[0;32m     19 [0;31m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negatives[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  s


--Call--
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(19)[0;36mforward[0;34m()[0m
[0;32m     17 [0;31m        [0;32mreturn[0m [0minp[0m[0;34m@[0m[0mtarg[0m[0;34m.[0m[0mT[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m    def forward(
[0m[0;32m     20 [0;31m        [0mself[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(37)[0;36mforward[0;34m()[0m
[0;32m     35 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     36 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 37 [0;31m        [0mstore_attr[0m[0;34m([0m[0;34m'margin,tau,apply_softmax,n_negatives'[0m[0;34m,[0m [0mis_none[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0minp[0m[0;34m,[0m [0mtarg[0m [0;34m=[0m [0minp[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mtarg[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(39)[0;36mforward[0;34m()[0m
[0;32m     37 [0;31m        [0mstore_attr[0m[0;34m([0m[0;34m'margin,tau,apply_softmax,n_negatives'[0m[0;34m,[0m [0mis_none[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m[0;34m[0m[0m
[0m[0;32m---> 39 [0;31m        [0minp[0m[0;34m,[0m [0mtarg[0m [0;34m=[0m [0minp[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mtarg[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m [0;32mif[0m [0mscores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mscores[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(40)[0;36mforward[0;34m()[0m
[0;32m     38 [0;31m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0minp[0m[0;34m,[0m [0mtarg[0m [0;34m=[0m [0minp[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mtarg[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m [0;32mif[0m [0mscores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mscores[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m     42 [0;31m        [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mremove_redundant_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m

ipdb>  inp.shape


torch.Size([10, 768])


ipdb>  targ.shape


torch.Size([18, 768])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(42)[0;36mforward[0;34m()[0m
[0;32m     40 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m [0;32mif[0m [0mscores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mscores[0m[0;34m.[0m[0mfloat[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m        [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mremove_redundant_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m        [0minp2targ_idx[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mreset_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m

ipdb>  pinp2targ_idx


tensor([    0,     1,     2,     3,     9, 26766,    12,    13,    14,    15,
           16,    17,    18, 56258,    19,    20,    21,    22,    23,    24,
           25,    26,    27,    28,    29,    30,    31,    32,    33,    34,
           35,    36,    37,    38,    39,    40,    41,    42, 10243,    45,
           48,    49,    50,    51,    52,    53,    54,    55,    56,    57,
           58,    59,    60,    61,    62,    63,    64,    65,    66,    67,
           68,    69,    70, 81953,   101,   102,   103,   104,   105])


ipdb>  n_pinp2targ


tensor([ 3,  1,  2,  4,  4, 25,  1, 14, 10,  5])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(43)[0;36mforward[0;34m()[0m
[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m     42 [0;31m        [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mremove_redundant_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 43 [0;31m        [0minp2targ_idx[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mreset_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m[0;34m[0m[0m
[0m[0;32m     45 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcompute_scores[0m[0;34m([0m[0minp[

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(45)[0;36mforward[0;34m()[0m
[0;32m     43 [0;31m        [0minp2targ_idx[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mreset_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m[0;34m[0m[0m
[0m[0;32m---> 45 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcompute_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mtarg[0m[0;34m,[0m [0mindices[0m[0;34m=[0m[0mindices[0m[0;34m)[0m [0;32mif[0m [0mscores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mscores[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mindices[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m[0;34m[0m[0m
[0m[0;32m     47 [0;31m        [0mpos_indices

ipdb>  inp2targ_idx


tensor([ 0,  1,  2, 16,  3,  4,  5,  6, 17,  7,  8,  9, 10, 11, 12, 13, 15, 14])


ipdb>  pinp2targ_idx


tensor([ 0,  1,  2,  3, 16,  4,  5,  6, 17,  7,  8,  9, 10, 11, 12, 13, 14, 15])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(47)[0;36mforward[0;34m()[0m
[0;32m     45 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcompute_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mtarg[0m[0;34m,[0m [0mindices[0m[0;34m=[0m[0mindices[0m[0;34m)[0m [0;32mif[0m [0mscores[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mscores[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mindices[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m[0;34m[0m[0m
[0m[0;32m---> 47 [0;31m        [0mpos_indices[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mpos_indices[0m[0;34m)[0m[0;34m[0m[0;34m

ipdb>  scores.shape


torch.Size([10, 18])


ipdb>  inp2targ_idx.shape


torch.Size([18])


ipdb>  n_inp2targ.shape


torch.Size([10])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(48)[0;36mforward[0;34m()[0m
[0;32m     46 [0;31m[0;34m[0m[0m
[0m[0;32m     47 [0;31m        [0mpos_indices[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0minp2targ_idx[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 48 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mpos_indices[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mzeros_like[0m[0;34m([0m[0mscores[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  pos_indices


tensor([[ 0,  1],
        [ 2,  0],
        [16,  3],
        [ 4,  5],
        [ 6, 17],
        [ 7,  8],
        [ 9,  0],
        [10, 11],
        [12, 13],
        [15, 14]])


ipdb>  pos_mask


tensor([[1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(50)[0;36mforward[0;34m()[0m
[0;32m     48 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mpos_indices[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m[0;34m[0m[0m
[0m[0;32m---> 50 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mzeros_like[0m[0;34m([0m[0mscores[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m[0;34m[0m[0m
[0m[0;32m     52 [0;31m        [0mppos_indices[0m[0;34m,[0m [0mppos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(52)[0;36mforward[0;34m()[0m
[0;32m     50 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mzeros_like[0m[0;34m([0m[0mscores[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m[0;34m[0m[0m
[0m[0;32m---> 52 [0;31m        [0mppos_indices[0m[0;34m,[0m [0mppos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     53 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mscatter[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mppos_indices[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m[0;34m[0m[0m
[0m


ipdb>  pos_incidence


tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(53)[0;36mforward[0;34m()[0m
[0;32m     51 [0;31m[0;34m[0m[0m
[0m[0;32m     52 [0;31m        [0mppos_indices[0m[0;34m,[0m [0mppos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 53 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mscatter[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mppos_indices[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m[0;34m[0m[0m
[0m[0;32m     55 [0;31m        [0mppos_indices[0m[0;34m[[0m[0;34m~[0m[0mppos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m-[0m[0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(55)[0;36mforward[0;34m()[0m
[0;32m     53 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mscatter[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mppos_indices[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m[0;34m[0m[0m
[0m[0;32m---> 55 [0;31m        [0mppos_indices[0m[0;34m[[0m[0;34m~[0m[0mppos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m-[0m[0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m        [0mrow_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mppos_indices[0m [0;34m!=[0m [0;36m0[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m        [0m

ipdb>  pos_incidence


tensor([[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.]])


ipdb>  ~ppos_mask.bool()


tensor([[False, False],
        [False,  True],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False,  True],
        [False, False],
        [False, False],
        [False, False]])


ipdb>  ppos_mask


tensor([[1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])


ipdb>  l


[1;32m     50 [0m        [0mpos_incidence[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mzeros_like[0m[0;34m([0m[0mscores[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m     51 [0m[0;34m[0m[0m
[1;32m     52 [0m        [0mppos_indices[0m[0;34m,[0m [0mppos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpinp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m     53 [0m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mscatter[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mppos_indices[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m     54 [0m[0;34m[0m[0m
[0;32m---> 55 [0;31m        [0mppos_indices[0m[0;34m[[0m[0;34m~[0m[0mppos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m-[0m[0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m     56 [0m        [0mrow_idx[0m [0;34m=[0m [0mtorch[0m

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(56)[0;36mforward[0;34m()[0m
[0;32m     54 [0;31m[0;34m[0m[0m
[0m[0;32m     55 [0;31m        [0mppos_indices[0m[0;34m[[0m[0;34m~[0m[0mppos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m-[0m[0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 56 [0;31m        [0mrow_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mppos_indices[0m [0;34m!=[0m [0;36m0[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m        [0mpos_incidence[0m[0;34m[[0m[0mrow_idx[0m[0;34m,[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m[0;34m[0m[0m
[0m


ipdb>  ppos_indices


tensor([[ 0,  1],
        [ 2, -1],
        [ 3, 16],
        [ 4,  5],
        [ 6, 17],
        [ 7,  8],
        [ 9, -1],
        [10, 11],
        [12, 13],
        [14, 15]])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(57)[0;36mforward[0;34m()[0m
[0;32m     55 [0;31m        [0mppos_indices[0m[0;34m[[0m[0;34m~[0m[0mppos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m-[0m[0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m        [0mrow_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mppos_indices[0m [0;34m!=[0m [0;36m0[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 57 [0;31m        [0mpos_incidence[0m[0;34m[[0m[0mrow_idx[0m[0;34m,[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m[0;34m[0m[0m
[0m[0;32m     59 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;36m1[0m [0;34m-[0m [0mpos_i

ipdb>  row_idx


tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(59)[0;36mforward[0;34m()[0m
[0;32m     57 [0;31m        [0mpos_incidence[0m[0;34m[[0m[0mrow_idx[0m[0;34m,[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m[0;34m[0m[0m
[0m[0;32m---> 59 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;36m1[0m [0;34m-[0m [0mpos_incidence[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  pos_incidence


tensor([[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.]])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(61)[0;36mforward[0;34m()[0m
[0;32m     59 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;36m1[0m [0;34m-[0m [0mpos_incidence[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m---> 61 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mloss[0m [0;34m*[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(62)[0;36mforward[0;34m()[0m
[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 62 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mloss[0m [0;34m*[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m     64 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(64)[0;36mforward[0;34m()[0m
[0;32m     62 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mloss[0m [0;34m*[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m---> 64 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m


ipdb>  loss


tensor([[[0.0000, 0.0000, 0.1512, 0.0717, 0.0613, 0.0000, 0.0894, 0.1141,
          0.0000, 0.1403, 0.2108, 0.2186, 0.1381, 0.1448, 0.1164, 0.1302,
          0.1450, 0.1255],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0222, 0.0300, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1983,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1093, 0.1494,
          0.1060, 0.0000],
         [0.3000, 0.3517, 0.0000, 0.2445, 0.2485, 0.1161, 0.3500, 0.5874,
          0.1484, 0.2714, 0.2944, 0.3103, 0.2877, 0.2918, 0.4984, 0.5384,
          0.4950, 0.3407]],

        [[0.0130, 0.0000, 0.1060, -0.0000, 0.0000, 0.0000, 0.0000, 0.1053,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0585, 0.0830,
          0.0000, 0.0000],
         [0.4655, 0.4088, 0.5584, 0.0000, 0.3249, 0.2112, 0.3950, 0.5578,
          0.2175, 0.3474, 0.3432, 0.3748, 0.30

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m     64 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(67)[0;36mforward[0;34m()[0m
[0;32m     65 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m---> 67 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0

ipdb>  min(self.n_negatives, loss.shape[2])


10


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 69 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0mneg_incidence[0m[0;

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(71)[0;36mforward[0;34m()[0m
[0;32m     69 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0mneg_incidence[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m---> 71 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0

ipdb>  mask


tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True, False, False, False, False, False, False, False, False]],

        [[ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(74)[0;36mforward[0;34m()[0m
[0;32m     72 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 74 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m            [0mloss[0m [0;3

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m [0;34m*[0m [0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 76 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m [0;34m*[0m [0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m     78 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(78)[0;36mforward[0;34m()[0m
[0;32m     76 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m [0;34m*[0m [0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m---> 78 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     79 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mn_inp2targ[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mboo

ipdb>  loss


tensor([[[0.0400, 0.0357, 0.0141, 0.0127, 0.0127, 0.0117, 0.0113, 0.0099,
          0.0091, 0.0077],
         [0.0155, 0.0106, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0819, 0.0378, 0.0185, 0.0174, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.2049, 0.1152, 0.0714, 0.0686, 0.0116, 0.0114, 0.0101, 0.0068,
          0.0059, 0.0055]],

        [[0.0278, 0.0275, 0.0173, 0.0096, 0.0014, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.1141, 0.1132, 0.0870, 0.0650, 0.0376, 0.0238, 0.0202, 0.0187,
          0.0157, 0.0122]],

        [[0.0303, 0.0263, 0.0101, 0.0067, 0.0064, 0.0058, 0.0053, 0.0024,
          0.0013, 0.0004],
         [0.0367, 0.0322, 0.0135, 0.0096, 0.0092, 0.0086, 0.0079, 0.0045,
          0.0030, 0.0020]],

        [[0.0317, 0.0257, 0.0239, 0.0192, 0.0189, 0.0186, 0.0179, 0.0157,
          0.0152, 0.0145],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m     78 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 79 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mn_inp2targ[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     8

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(80)[0;36mforward[0;34m()[0m
[0;32m     78 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     79 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mn_inp2targ[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 80 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m[0;34m[0m[0m
[0m[0;32m     8

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m/[0m[0mlen[0m[0;34m([0m[0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m        [0;32melif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'sum'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m        [0;32melse[0m[0;34m:[0m [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34mf'[0m[0;34m`reduction` cannot be `[0m

ipdb>  loss


tensor(0.2206, grad_fn=<SumBackward0>)


ipdb>  n


--Return--
tensor(0.0221...DivBackward0>)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/1672367867.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     81 [0;31m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m/[0m[0mlen[0m[0;34m([0m[0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m        [0;32melif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'sum'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m        [0;32melse[0m[0;34m:[0m [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34

ipdb>  


--Return--
tensor(0.0221...DivBackward0>)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3851174484.py[0m(18)[0;36mforward[0;34m()[0m
[0;32m     16 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m--> 18 [0;31m        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
[0m[0;32m     19 [0;31m                               [0mapply_softmax[0m[0;34m=[0m[0mapply_softmax[0m[0;34m,[0m [0mn_negatives[0m[0;34m=[0m[0mn_negatives[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m



## `MultiTripletFromScores`

In [184]:
#| export
class MultiTripletFromScores(BaseMultiTriplet):
    
    def forward(
        self, 
        scores:torch.FloatTensor,  
        inp2targ_idx:torch.LongTensor,
        
        n_inp2targ:torch.LongTensor,
        
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
        
        assert scores.dim() == 2, "`scores` should be two dimensional matrix."
        assert inp2targ_idx.dim() == 2, "`inp2targ_idx` should be two dimensional matrix."
        
        pos_incidence = self.get_incidence(inp2targ_idx.shape[1], inp2targ_idx.flatten(), n_pinp2targ, pinp2targ_idx)
        pos_incidence = pos_incidence.view(inp2targ_idx.shape)

        pos_scores = self.get_pos_scores(scores, n_inp2targ)
        pos_scores, pos_mask = self.align_indices(pos_scores, n_inp2targ)
        neg_incidence = ~pos_incidence

        loss = scores.unsqueeze(1) - pos_scores.unsqueeze(2) + self.margin
        loss = F.relu(loss * neg_incidence.unsqueeze(1))
        
        scores = scores.unsqueeze(1).expand_as(loss)
        neg_incidence = neg_incidence.unsqueeze(1).expand_as(loss)

        if self.n_negatives is not None:
            loss, idx = torch.topk(loss, min(self.n_negatives, loss.shape[2]), dim=2, largest=True)
            scores, neg_incidence = scores.gather(2, idx), neg_incidence.gather(2, idx)

        if self.apply_softmax:
            mask = loss != 0
            penalty = scores / self.tau * mask
            penalty[neg_incidence == 0] = torch.finfo(penalty.dtype).min
            penalty = torch.softmax(penalty, dim=2)
            loss = loss * penalty

        loss /= (neg_incidence.sum(dim=2, keepdim=True) + 1e-6)
        loss /= (n_inp2targ.unsqueeze(1).unsqueeze(1) + 1e-6)
        loss = loss[pos_mask.bool()].sum()
        
        if self.reduction == 'mean': return loss/len(n_inp2targ)
        elif self.reduction == 'sum': return loss
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
        

## `BaseWithNegatives`

In [117]:
#| export
class BaseWithNegatives:

    def forward(
        self, 
        inp: torch.FloatTensor,
        
        pos_targ: torch.FloatTensor,
        n_pos: torch.LongTensor,
        pos_idx: torch.LongTensor,

        n_ppos: torch.LongTensor,
        ppos_idx: torch.LongTensor,
        
        neg_targ: Optional[torch.FloatTensor] = None,
        n_neg: Optional[torch.LongTensor] = None,
        neg_idx: Optional[torch.LongTensor] = None,
        
        **kwargs
    ):  
        if n_neg is not None:
            assert torch.all(n_neg == n_neg.max()), "All datapoints should same number of negatives"

        max_n_neg = None if n_neg is None else n_neg.max()
        scores = self.get_scores(inp, pos_targ, neg_targ, max_n_neg)
        indices = self.get_indices(pos_idx, len(inp), neg_idx, max_n_neg)
        return super().forward(scores, indices, n_pos, n_pinp2targ=n_ppos, pinp2targ_idx=ppos_idx)


## `MultiTripletWithNegatives`

In [227]:
#| export
MultiTripletWithNegatives = mix_classes(BaseWithNegatives, MultiTripletFromScores)

### Example

In [246]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [260]:
loss_fn = MultiTripletWithNegatives(margin, tau, apply_softmax, n_negatives, reduce='mean')

In [233]:
loss = loss_fn(inp, pos_targ, n_pos, pos_idx, n_ppos, ppos_idx, neg_targ=neg_targ, 
               n_neg=n_neg, neg_idx=neg_idx)

In [261]:
loss = loss_fn(inp, pos_targ, n_pos, pos_idx, n_ppos, ppos_idx)

In [249]:
loss

tensor(0.0223, grad_fn=<DivBackward0>)

In [262]:
def func():
    import pdb; pdb.set_trace()
    loss = loss_fn(inp, targ, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)
    

In [263]:
func()

> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/3723332974.py[0m(2)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    [0mloss[0m [0;34m=[0m [0mloss_fn[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mtarg[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m,[0m [0minp2targ_idx[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  b loss_fn.forward


Breakpoint 5 at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py:21


ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py[0m(21)[0;36mforward[0;34m()[0m
[0;32m     19 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m5[0;32m--> 21 [0;31m        [0;32mif[0m [0mn_neg[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mn_neg[0m [0;34m==[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0;34m"All datapoints should same number of negatives"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m[0;34m[0m[0m
[0m


ipdb>  n_neg
ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py[0m(24)[0;36mforward[0;34m()[0m
[0;32m     22 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mn_neg[0m [0;34m==[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0;34m"All datapoints should same number of negatives"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m---> 24 [0;31m        [0mmax_n_neg[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mn_neg[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mpos_targ[0m[0;34m,[0m [0mneg_targ[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0mind

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py[0m(25)[0;36mforward[0;34m()[0m
[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m     24 [0;31m        [0mmax_n_neg[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mn_neg[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mpos_targ[0m[0;34m,[0m [0mneg_targ[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_indices[0m[0;34m([0m[0mpos_idx[0m[0;34m,[0m [0mlen[0m[0;34m([0m[0minp[0m[0;34m)[0m[0;34m,[0m [0mneg_idx[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m        [0;32mreturn[0m [0

ipdb>  max_n_neg
ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py[0m(26)[0;36mforward[0;34m()[0m
[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m     24 [0;31m        [0mmax_n_neg[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mn_neg[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mpos_targ[0m[0;34m,[0m [0mneg_targ[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 26 [0;31m        [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_indices[0m[0;34m([0m[0mpos_idx[0m[0;34m,[0m [0mlen[0m[0;34m([0m[0minp[0m[0;34m)[0m[0;34m,[0m [0mneg_idx[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m        [0;32mreturn[0m [0

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py[0m(27)[0;36mforward[0;34m()[0m
[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m     24 [0;31m        [0mmax_n_neg[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mn_neg[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mpos_targ[0m[0;34m,[0m [0mneg_targ[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_indices[0m[0;34m([0m[0mpos_idx[0m[0;34m,[0m [0mlen[0m[0;34m([0m[0minp[0m[0;34m)[0m[0;34m,[0m [0mneg_idx[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 27 [0;31m        [0;32mreturn[0m [0

ipdb>  indices


tensor([[    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102],
        [    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102],
        [    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102],
        [    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102],
        [    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102],
        [    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67,   105,   102],
        [    0,     2,     3, 26766,     9,    12,    14,    17, 56258,    24,
            42,    45,    51,    52,    66,    67, 

ipdb>  s


--Call--
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(4)[0;36mforward[0;34m()[0m
[0;32m      2 [0;31m[0;32mclass[0m [0mMultiTripletFromScores[0m[0;34m([0m[0mBaseMultiTriplet[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m[0;34m[0m[0m
[0m[0;32m----> 4 [0;31m    def forward(
[0m[0;32m      5 [0;31m        [0mself[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m        [0mscores[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mFloatTensor[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(20)[0;36mforward[0;34m()[0m
[0;32m     18 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 20 [0;31m        [0mstore_attr[0m[0;34m([0m[0;34m'margin,tau,apply_softmax,n_negatives'[0m[0;34m,[0m [0mis_none[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m     22 [0;31m        [0;32massert[0m [0mscores[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m,[0m [0;34m"`scores` should be two dimensional matrix."[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(22)[0;36mforward[0;34m()[0m
[0;32m     20 [0;31m        [0mstore_attr[0m[0;34m([0m[0;34m'margin,tau,apply_softmax,n_negatives'[0m[0;34m,[0m [0mis_none[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m        [0;32massert[0m [0mscores[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m,[0m [0;34m"`scores` should be two dimensional matrix."[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m,[0m [0;34m"`inp2targ_idx` should be two dimensional matrix."[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(23)[0;36mforward[0;34m()[0m
[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m     22 [0;31m        [0;32massert[0m [0mscores[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m,[0m [0;34m"`scores` should be two dimensional matrix."[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 23 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m,[0m [0;34m"`inp2targ_idx` should be two dimensional matrix."[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_incidence[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m [0minp2targ_idx[0m[0;34m.[0m[0mflatten[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mn_pinp2ta

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(25)[0;36mforward[0;34m()[0m
[0;32m     23 [0;31m        [0;32massert[0m [0minp2targ_idx[0m[0;34m.[0m[0mdim[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m,[0m [0;34m"`inp2targ_idx` should be two dimensional matrix."[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_incidence[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m [0minp2targ_idx[0m[0;34m.[0m[0mflatten[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(26)[0;36mforward[0;34m()[0m
[0;32m     24 [0;31m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_incidence[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m [0minp2targ_idx[0m[0;34m.[0m[0mflatten[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mn_pinp2targ[0m[0;34m,[0m [0mpinp2targ_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 26 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m     28 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_pos_scores[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m

ipdb>  pos_incidence


tensor([ True,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, 

ipdb>  pos_incidence.long()


tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(28)[0;36mforward[0;34m()[0m
[0;32m     26 [0;31m        [0mpos_incidence[0m [0;34m=[0m [0mpos_incidence[0m[0;34m.[0m[0mview[0m[0;34m([0m[0minp2targ_idx[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_pos_scores[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m        [0mpos_scores[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpos_scores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mpos_incidence[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  pos_incidence


tensor([[ True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True,  True, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False,  True,  True, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False,  True,
          True, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False,  True, False, False, False, False, 

ipdb>  pos_incidence.long()


tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(29)[0;36mforward[0;34m()[0m
[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m     28 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_pos_scores[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 29 [0;31m        [0mpos_scores[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpos_scores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mpos_incidence[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(30)[0;36mforward[0;34m()[0m
[0;32m     28 [0;31m        [0mpos_scores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_pos_scores[0m[0;34m([0m[0mscores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m        [0mpos_scores[0m[0;34m,[0m [0mpos_mask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0malign_indices[0m[0;34m([0m[0mpos_scores[0m[0;34m,[0m [0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mpos_incidence[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m[0;34m[0m[0m
[0m[0;32m     32 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(32)[0;36mforward[0;34m()[0m
[0;32m     30 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0;34m~[0m[0mpos_incidence[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m[0;34m[0m[0m
[0m[0;32m---> 32 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     33 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mloss[0m [0;34m*[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(33)[0;36mforward[0;34m()[0m
[0;32m     31 [0;31m[0;34m[0m[0m
[0m[0;32m     32 [0;31m        [0mloss[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m-[0m [0mpos_scores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmargin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 33 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mloss[0m [0;34m*[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m[0;34m[0m[0m
[0m[0;32m     35 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(35)[0;36mforward[0;34m()[0m
[0;32m     33 [0;31m        [0mloss[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mrelu[0m[0;34m([0m[0mloss[0m [0;34m*[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m[0;34m[0m[0m
[0m[0;32m---> 35 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     36 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m[0;34m[0m[0m
[0m


ipdb>  loss


tensor([[[0.0000, 0.0000, 0.1512, 0.1450, 0.0717, 0.0613, 0.0000, 0.0894,
          0.1255, 0.1141, 0.0000, 0.1403, 0.2108, 0.2186, 0.1381, 0.1448,
          0.1302, 0.1164],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0222, 0.0300, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.1060, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1983, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.1494, 0.1093],
         [0.9109, 0.9626, 0.0000, 1.1060, 0.8554, 0.8595, 0.7270, 0.9610,
          0.9516, 1.1983, 0.7593, 0.8824, 0.9053, 0.9212, 0.8986, 0.9027,
          1.1494, 1.1093]],

        [[0.0130, 0.0000, 0.1060, 0.0000, -0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1053, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0830, 0.0585],
         [0.4655, 0.4088, 0.5584, 0.0000, 0.0000, 0.3249, 0.2112, 0.3950,
          0.4283, 0.5578, 0.2175, 0.3474, 0.34

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(36)[0;36mforward[0;34m()[0m
[0;32m     34 [0;31m[0;34m[0m[0m
[0m[0;32m     35 [0;31m        [0mscores[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 36 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m[0;34m[0m[0m
[0m[0;32m     38 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(38)[0;36mforward[0;34m()[0m
[0;32m     36 [0;31m        [0mneg_incidence[0m [0;34m=[0m [0mneg_incidence[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mexpand_as[0m[0;34m([0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m[0;34m[0m[0m
[0m[0;32m---> 38 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;3

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(39)[0;36mforward[0;34m()[0m
[0;32m     37 [0;31m[0;34m[0m[0m
[0m[0;32m     38 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 39 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(40)[0;36mforward[0;34m()[0m
[0;32m     38 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_negatives[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m            [0mloss[0m[0;34m,[0m [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtopk[0m[0;34m([0m[0mloss[0m[0;34m,[0m [0mmin[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_negatives[0m[0;34m,[0m [0mloss[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mlargest[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0mneg_incidence[0m[0;34

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(42)[0;36mforward[0;34m()[0m
[0;32m     40 [0;31m            [0mscores[0m[0;34m,[0m [0mneg_incidence[0m [0;34m=[0m [0mscores[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m,[0m [0mneg_incidence[0m[0;34m.[0m[0mgather[0m[0;34m([0m[0;36m2[0m[0;34m,[0m [0midx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(43)[0;36mforward[0;34m()[0m
[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m     42 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 43 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(44)[0;36mforward[0;34m()[0m
[0;32m     42 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mapply_softmax[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 44 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m 

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(45)[0;36mforward[0;34m()[0m
[0;32m     43 [0;31m            [0mmask[0m [0;34m=[0m [0mloss[0m [0;34m!=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 45 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     47 [0;31m            [0mloss[0m [0;34m

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(46)[0;36mforward[0;34m()[0m
[0;32m     44 [0;31m            [0mpenalty[0m [0;34m=[0m [0mscores[0m [0;34m/[0m [0mself[0m[0;34m.[0m[0mtau[0m [0;34m*[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 46 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     47 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m [0;34m*[0m [0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(47)[0;36mforward[0;34m()[0m
[0;32m     45 [0;31m            [0mpenalty[0m[0;34m[[0m[0mneg_incidence[0m [0;34m==[0m [0;36m0[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfinfo[0m[0;34m([0m[0mpenalty[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m.[0m[0mmin[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m            [0mpenalty[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mpenalty[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 47 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m [0;34m*[0m [0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m[0;34m[0m[0m
[0m[0;32m     49 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;3

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(49)[0;36mforward[0;34m()[0m
[0;32m     47 [0;31m            [0mloss[0m [0;34m=[0m [0mloss[0m [0;34m*[0m [0mpenalty[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m[0;34m[0m[0m
[0m[0;32m---> 49 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mn_inp2targ[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool

ipdb>  loss


tensor([[[0.0400, 0.0357, 0.0141, 0.0127, 0.0127, 0.0117, 0.0113, 0.0099,
          0.0091, 0.0077],
         [0.0155, 0.0106, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0819, 0.0378, 0.0185, 0.0174, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.4181, 0.2458, 0.1590, 0.1532, 0.0318, 0.0312, 0.0282, 0.0201,
          0.0180, 0.0169]],

        [[0.0278, 0.0275, 0.0173, 0.0096, 0.0014, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.1141, 0.1132, 0.0870, 0.0650, 0.0376, 0.0238, 0.0202, 0.0187,
          0.0157, 0.0122]],

        [[0.0303, 0.0263, 0.0101, 0.0067, 0.0064, 0.0058, 0.0053, 0.0024,
          0.0013, 0.0004],
         [0.0367, 0.0322, 0.0135, 0.0096, 0.0092, 0.0086, 0.0079, 0.0045,
          0.0030, 0.0020]],

        [[0.0317, 0.0257, 0.0239, 0.0192, 0.0189, 0.0186, 0.0179, 0.0157,
          0.0152, 0.0145],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(50)[0;36mforward[0;34m()[0m
[0;32m     48 [0;31m[0;34m[0m[0m
[0m[0;32m     49 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 50 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mn_inp2targ[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 

ipdb>  loss


tensor([[[3.9988e-03, 3.5693e-03, 1.4111e-03, 1.2715e-03, 1.2663e-03,
          1.1744e-03, 1.1307e-03, 9.8511e-04, 9.0616e-04, 7.6707e-04],
         [1.5515e-03, 1.0642e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

        [[9.0991e-03, 4.2019e-03, 2.0602e-03, 1.9306e-03, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [4.1809e-02, 2.4585e-02, 1.5896e-02, 1.5324e-02, 3.1814e-03,
          3.1232e-03, 2.8168e-03, 2.0129e-03, 1.7953e-03, 1.6873e-03]],

        [[3.0925e-03, 3.0543e-03, 1.9253e-03, 1.0624e-03, 1.5037e-04,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.1409e-02, 1.1323e-02, 8.6951e-03, 6.4952e-03, 3.7559e-03,
          2.3838e-03, 2.0158e-03, 1.8723e-03, 1.5745e-03, 1.2207e-03]],

        [[3.0315e-03, 2.6349e-03, 1.0056e-03, 6.7048e-04, 6.3825e-04,
          5.8277e-04, 5.3167e-04, 2.4303e-04, 1.3016e-04, 4.4896e-05],
       

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(51)[0;36mforward[0;34m()[0m
[0;32m     49 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mneg_incidence[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0mloss[0m [0;34m/=[0m [0;34m([0m[0mn_inp2targ[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0;36m1e-6[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 51 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m[0;34m[0m[0m
[0m[0;32m     53 

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(53)[0;36mforward[0;34m()[0m
[0;32m     51 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m[0;34m[0m[0m
[0m[0;32m---> 53 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m/[0m[0mlen[0m[0;34m([0m[0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m        [0;32melif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'sum'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m        [0;32melse[0m[0;34m:[0m [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34mf'[0m[0;34m`reduction` cannot be `[0m[0

ipdb>  loss


tensor(0.2228, grad_fn=<SumBackward0>)


ipdb>  n


--Return--
tensor(0.0223...DivBackward0>)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/86789608.py[0m(53)[0;36mforward[0;34m()[0m
[0;32m     51 [0;31m        [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m[[0m[0mpos_mask[0m[0;34m.[0m[0mbool[0m[0;34m([0m[0;34m)[0m[0;34m][0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m[0;34m[0m[0m
[0m[0;32m---> 53 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'mean'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m/[0m[0mlen[0m[0;34m([0m[0mn_inp2targ[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m        [0;32melif[0m [0mself[0m[0;34m.[0m[0mreduction[0m [0;34m==[0m [0;34m'sum'[0m[0;34m:[0m [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m        [0;32melse[0m[0;34m:[0m [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34mf

ipdb>  


--Return--
tensor(0.0223...DivBackward0>)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_32368/832060994.py[0m(27)[0;36mforward[0;34m()[0m
[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m     24 [0;31m        [0mmax_n_neg[0m [0;34m=[0m [0;32mNone[0m [0;32mif[0m [0mn_neg[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0mn_neg[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m        [0mscores[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_scores[0m[0;34m([0m[0minp[0m[0;34m,[0m [0mpos_targ[0m[0;34m,[0m [0mneg_targ[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0mindices[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_indices[0m[0;34m([0m[0mpos_idx[0m[0;34m,[0m [0mlen[0m[0;34m([0m[0minp[0m[0;34m)[0m[0;34m,[0m [0mneg_idx[0m[0;34m,[0m [0mmax_n_neg[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--

ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m



## `MultiRankingWithNegatives`

In [115]:
#| export
class MultiRankingFromScores(BaseLoss, LossOperations):

    def __init__(
        self,
        tau:Optional[float]=1.0,
        n_negatives:Optional[int]=10,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('n_negatives')
        self.tau = nn.Parameter(torch.tensor(tau, dtype=torch.float32))

    def forward(
        self, 
        scores:torch.FloatTensor,  
        inp2targ_idx:torch.LongTensor,

        n_inp2targ:torch.LongTensor,
        
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        
        **kwargs
    ):        
        assert scores.dim() == 2, "`scores` should be two dimensional matrix."
        assert inp2targ_idx.dim() == 2, "`inp2targ_idx` should be two dimensional matrix."
        
        pos_incidence = self.get_incidence(inp2targ_idx.shape[1], inp2targ_idx.flatten(), n_pinp2targ, pinp2targ_idx)
        pos_incidence = pos_incidence.view(inp2targ_idx.shape)
        
        if self.n_negatives is not None:
            scores, idx = self.masked_inclusive_topk(scores, pos_incidence, k=self.n_negatives)
            pos_incidence = pos_incidence.gather(1, idx)

        loss = -F.log_softmax(scores/self.tau, dim=1)
        
        loss /= (pos_incidence.sum(dim=1, keepdim=True) + 1e-6)
        loss = loss[pos_incidence.bool()].sum()
        
        if self.reduction == 'mean': return loss/len(n_inp2targ)
        elif self.reduction == 'sum': return loss
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
        

In [116]:
#| export
MultiRankingWithNegatives = mix_classes(BaseWithNegatives, MultiRankingFromScores)

### Examples

In [92]:
loss_fn = MultiRankingWithNegatives(tau=1.0, n_negatives=5, reduce='mean')

In [99]:
loss = loss_fn(inp, pos_targ, n_pos, pos_idx,  neg_targ, n_neg, neg_idx, n_ppos, ppos_idx)

In [100]:
loss

tensor(1.5485, grad_fn=<DivBackward0>)

## `MultiSoupConWithNegatives`

In [78]:
#| export
class MultiSoupConFromScores(BaseLoss, LossOperations):

    def __init__(
        self,
        tau:Optional[float]=1.0,
        n_negatives:Optional[int]=10,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('n_negatives')
        self.tau = nn.Parameter(torch.tensor(tau, dtype=torch.float32))

    def forward(
        self, 
        scores:torch.FloatTensor,  
        inp2targ_idx:torch.LongTensor,

        n_inp2targ:torch.LongTensor,
        
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        
        **kwargs
    ):        
        assert scores.dim() == 2, "`scores` should be two dimensional matrix."
        assert inp2targ_idx.dim() == 2, "`inp2targ_idx` should be two dimensional matrix."
        
        pos_incidence = self.get_incidence(inp2targ_idx.shape[1], inp2targ_idx.flatten(), n_pinp2targ, pinp2targ_idx)
        pos_incidence = pos_incidence.view(inp2targ_idx.shape)
        
        if self.n_negatives is not None:
            scores, idx = self.masked_inclusive_topk(scores, pos_incidence, k=self.n_negatives)
            pos_incidence = pos_incidence.gather(1, idx)
            
        _, col_idx = torch.where(pos_incidence)
        row_idx = torch.arange(len(col_idx), device=col_idx.device)

        pos_scores = scores[pos_incidence]
        scores[pos_incidence] = torch.finfo(scores.dtype).min
        n_inp2targ = pos_incidence.sum(dim=1)
        scores = scores.repeat_interleave(n_inp2targ, dim=0)
        scores[row_idx, col_idx] = pos_scores
        
        loss = -F.log_softmax(scores/self.tau, dim=1)
        loss = loss[row_idx, col_idx]
        loss /= n_inp2targ.repeat_interleave(n_inp2targ)
        loss = loss.sum()

        if self.reduction == 'mean': return loss/len(n_inp2targ)
        elif self.reduction == 'sum': return loss
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
        

In [79]:
#| export
MultiSoupConWithNegatives = mix_classes(BaseWithNegatives, MultiSoupConFromScores)

### Examples

In [83]:
loss_fn = MultiSoupConWithNegatives(tau=1.0, n_negatives=5, reduce='mean')

In [88]:
loss = loss_fn(inp, pos_targ, n_pos, pos_idx,  neg_targ, n_neg, neg_idx, n_ppos, ppos_idx)

In [89]:
loss

tensor(1.3742, grad_fn=<DivBackward0>)

## `MarginMSEWithNegatives`

In [80]:
#| export
class MarginMSEWithNegatives(BaseLoss):

    def forward(
        self, 
        inp:torch.FloatTensor,
        
        pos_targ:torch.FloatTensor,
        pos_scores:torch.FloatTensor,

        neg_targ:torch.FloatTensor,
        neg_scores:torch.FloatTensor,
        **kwargs
    ):  
        bsz = len(inp)
        
        assert len(pos_scores) % bsz == 0, "Number of elements in `pos_scores` should be divisible by batch size."
        assert len(neg_scores) % bsz == 0, "Number of elements in `neg_scores` should be divisible by batch size."
        
        assert len(pos_targ) == len(pos_scores), "`pos_targ` and `pos_scores` should have same number of elements."
        assert len(neg_targ) == len(neg_scores), "`neg_targ` and `neg_scores` should have same number of elements."
        
        n = len(pos_targ) // bsz
        pos_targ, pos_scores = pos_targ.view(bsz, n, -1), pos_scores.view(bsz, n, 1)
        n = len(neg_targ) // bsz
        neg_targ, neg_scores = neg_targ.view(bsz, n, -1), neg_scores.view(bsz, 1, n)

        labels = pos_scores - neg_scores
        
        inp = inp.unsqueeze(1)
        pos_scores = inp @ pos_targ.transpose(1, 2)
        neg_scores = inp @ neg_targ.transpose(1, 2)
        margins = pos_scores.transpose(1, 2) - neg_scores

        return F.mse_loss(margins.flatten(), labels.flatten())
        

### Example

In [44]:
import types
from xcai.models.PPP0XX import DBT009

In [45]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [46]:
model = DBT009.from_pretrained('distilbert-base-uncased', use_encoder_parallel=False)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [47]:
def forward(self, *args, **kwargs): return 1.0
model.loss_fn.forward = types.MethodType(forward, model.loss_fn)

In [69]:
output = model(**batch.to(model.device))
inp, pos_targ = output.data_repr, output.lbl2data_repr

In [70]:
output = model(data_input_ids=batch['data_input_ids'], data_attention_mask=batch['data_attention_mask'], 
               lbl2data_data2ptr=batch['lnk2data_data2ptr'], lbl2data_idx=batch['lnk2data_idx'], 
               lbl2data_input_ids=batch['lnk2data_input_ids'], lbl2data_attention_mask=batch['lnk2data_attention_mask'], 
               plbl2data_data2ptr=batch['plnk2data_data2ptr'], plbl2data_idx=batch['plbl2data_idx'])
_, neg_targ = output.data_repr, output.lbl2data_repr

In [77]:
loss_fn = MarginMSEWithNegatives(reduce='mean')

In [72]:
inp.shape, pos_targ.shape, neg_targ.shape

(torch.Size([4, 768]), torch.Size([8, 768]), torch.Size([12, 768]))

In [73]:
pos_scores, neg_scores = torch.rand(pos_targ.shape[0]), torch.rand(neg_targ.shape[0])

In [81]:
loss = loss_fn(inp, pos_targ, pos_scores, neg_targ, neg_scores)

> [0;32m/tmp/ipykernel_23124/135438691.py[0m(17)[0;36mforward[0;34m()[0m
[0;32m     15 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m[0;34m[0m[0m
[0m[0;32m---> 17 [0;31m        [0mbsz[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0minp[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m[0;34m[0m[0m
[0m[0;32m     19 [0;31m        [0;32massert[0m [0mlen[0m[0;34m([0m[0mpos_scores[0m[0;34m)[0m [0;34m%[0m [0mbsz[0m [0;34m==[0m [0;36m0[0m[0;34m,[0m [0;34m"Number of elements in `pos_scores` should be divisible by batch size."[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  q


In [79]:
loss

tensor(0.0968, grad_fn=<MseLossBackward0>)

## Cosine

In [None]:
#| export
class Cosine(BaseLoss):

    def __init__(self, 
                 **kwargs):
        super().__init__(**kwargs)
        

In [None]:
#| export
@patch
def forward(cls:Cosine, 
            inp:torch.FloatTensor,
            inp_mask:torch.FloatTensor,
            targ:torch.LongTensor,
            targ_mask:torch.LongTensor,
            **kwargs):
    seq_len = min(inp.shape[1], targ.shape[1])
    
    inp_mask = inp_mask.unsqueeze(2).expand(inp.size()).float()
    targ_mask = targ_mask.unsqueeze(2).expand(targ.size()).float()

    inp, targ = F.normalize(inp, dim=-1),F.normalize(targ, dim=-1)
    
    inp,targ = inp*inp_mask,targ*targ_mask
    inp,targ = inp[:,:seq_len],targ[:,:seq_len]

    loss = 1.0 - torch.sum(inp*targ, dim=-1)
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### Example

In [None]:
loss_fn = Cosine(reduce='mean')

In [None]:
loss = loss_fn(data_embed, data_attention_mask, lbl2data_embed, lbl2data_attention_mask)

In [None]:
loss

tensor(0.6899, device='cuda:0', grad_fn=<MeanBackward0>)

## Entropy

In [None]:
#| export
class Entropy(BaseLoss):

    def __init__(self, 
                 margin:Optional[float]=0.8,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=True,
                 n_negatives:Optional[int]=5,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')
        

In [None]:
#| export
@patch
def forward(cls:Entropy, 
            inp:torch.FloatTensor,
            targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    
    sc = targ.exp()@inp.T
    
    sc_p =  sc.diagonal().unsqueeze(1)
    _, ne_idx = torch.topk(torch.where(ne == 0, torch.finfo(sc.dtype).min, sc), min(cls.n_negatives, sc.shape[0]-1), dim=1, largest=True)
    sc_n = sc.gather(1, ne_idx)
    
    loss = torch.relu(sc_n - sc_p + cls.margin)
    
    if cls.apply_softmax:
        m = loss != 0
        p = torch.softmax(sc_n/cls.tau * m, dim=1)
        loss = loss*p
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### Example

In [None]:
el_fn = Entropy(margin=1e-2, tau=0.1, apply_softmax=True, n_negatives=5, reduce='mean')

In [None]:
loss = el_fn( F.log_softmax(data_repr, dim=-1), F.log_softmax(lbl2data_repr, dim=-1), lbl2data_idx, 
             kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx']); loss

tensor(0.0019, grad_fn=<MeanBackward0>)

In [None]:
loss

tensor(0.0019, grad_fn=<MeanBackward0>)

## Triplet

In [None]:
#| export
class Triplet(BaseLoss):

    def __init__(self, 
                 margin:Optional[float]=0.8,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=True,
                 n_negatives:Optional[int]=5,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')


In [None]:
#| export
@patch
def forward(cls:Triplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    
    sc = inp@targ.T
    sc_p =  sc.diagonal().unsqueeze(1)
    _, ne_idx = torch.topk(torch.where(ne == 0, -10, sc), min(cls.n_negatives, sc.shape[0]-1), dim=1, largest=True)
    sc_n = sc.gather(1, ne_idx)
    
    loss = torch.relu(sc_n - sc_p + cls.margin)
    
    if cls.apply_softmax:
        m = loss != 0
        p = torch.softmax(sc_n/cls.tau * m, dim=1)
        loss = loss*p
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
@patch
def forward(cls:Triplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    
    sc = inp@targ.T
    loss = torch.relu((sc - sc.diagonal().unsqueeze(1) + cls.margin) * ne)
    
    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[0]-1), dim=1, largest=True)
        sc, ne = sc.gather(1, idx), ne.gather(1, idx)
        
    if cls.apply_softmax:
        m = loss != 0
        p = sc/cls.tau * m
        p[ne == 0] = torch.finfo(p.dtype).min
        p = torch.softmax(p, dim=1)
        loss = loss*p
        
    loss /= (ne.sum(dim=1, keepdim=True) + 1e-9)
    
    if cls.reduction == 'mean': return loss.sum(dim=1).mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
tl_fn = Triplet(0.8, tau=0.1, apply_softmax=True, n_negatives=5, reduce='mean')

In [None]:
loss = tl_fn(data_repr, lbl2data_repr, lbl2data_idx, kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx']); loss

tensor(0.9848, grad_fn=<MeanBackward0>)