# Losses

In [1]:
#| default_exp losses

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import functools, torch, torch.nn as nn, torch.nn.functional as F, pickle
from typing import MutableSequence, Union

from fastcore.utils import *
from fastcore.meta import *

from xcai.torch_core import *
from xcai.core import *

In [4]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [5]:
import pickle, torch.autograd.profiler as profiler
from xcai.block import *
from xcai.models.MMM0XX import *
from xcai.main import *

In [6]:
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
config_file = 'wikiseealsotitles'
config_key = 'data_lnk'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'
pkl_file = f'{pkl_dir}/mogicX/wikiseealsotitles_data-lnk_distilbert-base-uncased_sxc.joblib'

In [7]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_slbl_samples=2, 
                    n_sdata_meta_samples=3, do_build=True)



In [82]:
block.train.dset.data.main_oversample = False

In [83]:
batch = block.train.one_batch(4)

In [84]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'plnk2lbl_idx', 'plnk2lbl_lbl2ptr', 'lnk2lbl_idx', 'lnk2lbl_lbl2ptr', 'lnk2lbl_identifier', 'lnk2lbl_input_text', 'lnk2lbl_input_ids', 'lnk2lbl_attention_mask', 'lnk2lbl_data2ptr', 'plnk2lbl_data2ptr'])

In [85]:
batch['lbl2data_data2ptr'], batch['plbl2data_data2ptr']

(tensor([2, 1, 2, 2]), tensor([3, 1, 2, 4]))

In [86]:
batch['lnk2data_data2ptr'], batch['plnk2data_data2ptr']

(tensor([3, 3, 3, 3]), tensor([20, 20, 20, 20]))

## Helper

In [87]:
#|export
def get_sparse_matrix(data_idx:torch.Tensor, n_data:torch.Tensor, scores:Optional[torch.Tensor]=None):
    data_ptr = torch.cat([torch.zeros(1, device=n_data.device, dtype=n_data.dtype), n_data.cumsum(0)])
    if scores is None: scores = torch.ones_like(data_idx)
    if data_idx.shape != scores.shape: raise ValueError(f'`data_idx` and `scores` should have same shape.')
    return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)
    

## BaseLoss

In [12]:
#| export
class BaseLoss(nn.Module):

    def __init__(self, 
                 reduce:Optional[str]=None, 
                 **kwargs):
        super().__init__()
        self.reduce = reduce

    @property
    def reduction(self) -> str: return self.reduce
    
    @reduction.setter
    def reduction(self, v:str):
        "Sets the reduction style (typically 'mean', 'sum', or 'none')" 
        self.reduce = v
        

## MultiCrossEntropy

In [None]:
#| export
class MultiCrossEntropy(BaseLoss):

    def __init__(self,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0,
                 vocab_weights:Optional[torch.Tensor]=None,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('tn_targ,ig_tok,vocab_weights')
        self.o = torch.ones(tn_targ, dtype=torch.int64) if tn_targ is not None else None
        

In [None]:
vocab_weights = torch.rand(m.config.vocab_size)
mce_fn = MultiCrossEntropy(1000, vocab_weights=vocab_weights, reduce='mean')

In [None]:
#| export
@patch
def forward(cls:MultiCrossEntropy,
            inp:torch.FloatTensor,
            targ:torch.LongTensor,
            n_inp2targ:Optional[torch.LongTensor]=None,
            tn_targ:Optional[int]=None, 
            ig_tok:Optional[int]=None,
            vocab_weights:Optional[torch.Tensor]=None,
            **kwargs):
    store_attr('tn_targ,ig_tok,vocab_weights', is_none=False)
    
    cls.o = cls.o.to(inp.device) if cls.o is not None else None
    cls.vocab_weights = cls.vocab_weights.to(inp.device) if cls.vocab_weights is not None else None
    
    tn_targ, targ_len = targ.shape
    bsz, inp_len, vocab_sz = inp.shape
    
    if cls.vocab_weights is not None and cls.vocab_weights.shape[0] != vocab_sz: 
        raise ValueError(f"`vocab_weights` should have {vocab_sz} elements.")
    
    seq_len = min(targ_len, inp_len)
    inp, targ = -F.log_softmax(inp, dim=2)[:, :seq_len].transpose(1,2), targ[:, :seq_len]
    if cls.vocab_weights is not None: inp *= cls.vocab_weights.unsqueeze(1)
    
    if n_inp2targ is not None:
        mn_targ = n_inp2targ.max()
    
        inp2targ_ptr = n_inp2targ.cumsum(dim=0)-1
        xn_inp2targ = mn_targ-n_inp2targ+1
        r_targ = (
            torch.ones(tn_targ, dtype=torch.int64, device=inp.device).scatter(0, inp2targ_ptr, xn_inp2targ)
            if cls.tn_targ is None or tn_targ > cls.tn_targ else
            cls.o[:tn_targ].scatter(0, inp2targ_ptr, xn_inp2targ)
        )
        xtarg = targ.repeat_interleave(r_targ, dim=0)
        s = inp.gather(1, xtarg.view(bsz, -1, seq_len)).view(-1, seq_len)
        s /= r_targ.repeat_interleave(r_targ, dim=0).view(-1, 1)
    else:
        if bsz != tn_targ: raise ValueError("`inp` and `targ` should have same number of elements as `n_inp2targ` is empty.")
        s = inp.gather(1, targ.view(bsz, -1, seq_len)).view(-1, seq_len); xtarg = targ
    
    idx = torch.where(xtarg != cls.ig_tok)
    loss = s[idx[0], idx[1]]
    
    if cls.reduction == 'mean': return (loss/len(torch.where(targ != cls.ig_tok)[0])).sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')


In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(7.8086, device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
loss = mce_fn(lbl2data_logits, data_input_ids); loss

tensor(9.0191, device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
    print(loss)

tensor(16.0431, device='cuda:0', grad_fn=<SumBackward0>)


STAGE:2024-04-26 08:09:57 4294:4294 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-26 08:09:57 4294:4294 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-26 08:09:57 4294:4294 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::max        15.76%     398.000us        17.19%     434.000us     434.000us           0 b           0 b         512 b           0 b             1  
                                            aten::empty         0.67%      17.000us         0.67%      17.000us      17.000us           0 b           0 b         512 b         512 b             1  
         

In [None]:
@patch
def forward(cls:MultiCrossEntropy, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor, 
            **kwargs):
    inp_len, targ_len = inp.shape[1], targ.shape[1]
    seq_len = min(inp_len, targ_len)
    
    inp, targ = -F.log_softmax(inp, dim=2)[:, :seq_len], targ[:, :seq_len].unsqueeze(2)
    if cls.vocab_weights is not None: inp *= cls.vocab_weights
    inp = inp.repeat_interleave(n_inp2targ, dim=0)
    
    s = inp.gather(2, targ)
    idx = torch.where(targ != cls.ig_tok)
    loss = s[idx[0], idx[1]]
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(7.8086, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
loss = mce_fn(lbl2data_logits, data_input_ids, torch.ones(len(data_input_ids), dtype=data_input_ids.dtype, device=data_input_ids.device)); loss

tensor(9.0191, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
    print(loss)

STAGE:2024-05-05 16:20:26 189949:189949 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


tensor(14.2102, grad_fn=<MeanBackward0>)


STAGE:2024-05-05 16:20:28 189949:189949 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-05-05 16:20:28 189949:189949 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::log_softmax         0.18%       5.000us        15.61%     441.000us     441.000us           0 b           0 b     112.00 Mb           0 b             1  
                                     aten::_log_softmax        14.86%     420.000us        15.43%     436.000us     436.000us           0 b           0 b     112.00 Mb     112.00 Mb             1  
         

In [None]:
@patch
def forward(cls:MultiCrossEntropy, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor, 
            **kwargs):
    inp_len, targ_len = inp.shape[1], targ.shape[1]
    seq_len = min(inp_len, targ_len)
    inp, targ = -F.log_softmax(inp, dim=2)[:, :seq_len], targ[:, :seq_len]
    num, s = 0, []
    for i,n in zip(inp, n_inp2targ):
        for _ in range(n):
            s.append(i.gather(1, targ[num].view(-1, 1)).view(1, -1))
            num += 1
    s = torch.vstack(s)
    idx = torch.where(targ != cls.ig_tok)
    loss = s[idx[0], idx[1]]
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
    

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
    print(loss)

In [None]:
print(prof)

<unfinished torch.autograd.profile>


## Calibration

In [19]:
#| export
class Calibration(BaseLoss):

    def __init__(self,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 n_negatives:Optional[int]=10,
                 apply_softmax:Optional[bool]=True,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('margin,tau,n_negatives,apply_softmax')
        

In [20]:
#| export
@patch
def forward(cls:Calibration,
            einp:torch.FloatTensor,
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            n_negatives:Optional[int]=None,
            apply_softmax:Optional[bool]=None,
            **kwargs):
    store_attr('margin', is_none=False)

    einp, inp, targ = einp.float(), inp.float(), targ.float()
    esc,sc = einp@targ.T,inp@targ.T
    
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    pos = get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]

    mul = 2*pos - 1
    loss = F.relu((sc-esc)*mul + cls.margin)

    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[1]), dim=1, largest=True)
        esc,sc,mul = esc.gather(1, idx), sc.gather(1, idx), mul.gather(1, idx)
    
    if cls.apply_softmax:
        m = loss != 0
        s = torch.where(mul == 1, sc, esc)
        p = s/cls.tau * m
        p = torch.softmax(p, dim=1)
        loss = loss*p
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### Example

In [None]:
loss_fn = Calibration(0.3, reduce='mean')

In [None]:
loss = loss_fn(data_repr+torch.randn(data_repr.shape), data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
               kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx']); loss

tensor(0.8689, grad_fn=<MeanBackward0>)

## `Old` MultiTriplet

In [None]:
class MultiTriplet(BaseLoss):

    def __init__(self,
                 bsz:Optional[int]=None, 
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.8,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('bsz,tn_targ,margin,tau,apply_softmax,n_negatives')
        self.u = torch.arange(bsz, dtype=torch.int64) if bsz is not None else None
        self.v = torch.ones(tn_targ, dtype=torch.int64) if tn_targ is not None else None
        

In [None]:
mtl_fn = MultiTriplet(bsz, 10_000, 0.8, tau=0.1, n_negatives=5, apply_softmax=True, reduce='mean')

In [None]:
@patch
def forward(cls:MultiTriplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    
    cls.u = cls.u.to(inp.device) if cls.u is not None else None
    cls.v = cls.v.to(inp.device) if cls.v is not None else None
    
    bsz, tn_targ, mn_targ = inp.shape[0], targ.shape[0], n_inp2targ.max()
    u = torch.arange(bsz, dtype=torch.int64, device=inp.device) if cls.u is None or cls.bsz < bsz else cls.u[:bsz]
    v = (
        torch.ones(tn_targ, dtype=torch.int64, device=targ.device)
        if cls.tn_targ is None or tn_targ > cls.tn_targ else cls.v[:tn_targ]
    )
    
    targ2inp_ptr = u.repeat_interleave(n_inp2targ)
    sc = targ@inp.T
    ps = sc.gather(1, targ2inp_ptr.view(-1,1))
    
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    ne = ne.unsqueeze(1)
    
    inp2targ_ptr = n_inp2targ.cumsum(dim=0)-1
    xn_inp2targ = mn_targ-n_inp2targ+1
    r_targ = v.scatter(0, inp2targ_ptr, xn_inp2targ)

    psx = ps.repeat_interleave(r_targ).view(bsz, -1, 1)
    sc = sc.T.view(bsz, 1, -1)
    loss = F.relu((sc - psx + cls.margin)*ne)
    
    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[2]-n_inp2targ.max()), dim=2, largest=True)
        sc, ne = sc.expand(-1, mn_targ, -1).gather(2, idx), ne.expand(-1, mn_targ, -1).gather(2, idx)
    else: ne = ne.expand(-1, mn_targ, -1)
    
    if cls.apply_softmax:
        m = loss != 0
        p = sc/cls.tau * m
        p[ne == 0] = torch.finfo(p.dtype).min
        p = torch.softmax(p, dim=2)
        loss = loss*p
        
    loss /= (ne.sum(dim=2, keepdim=True) + 1e-9)
    
    xr_targ = r_targ.repeat_interleave(r_targ).view(bsz, -1, 1)
    loss /= xr_targ
    
    if cls.reduction == 'mean': return loss.sum()/n_inp2targ.sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, kwargs['plbl2data_data2ptr'],
              kwargs['plbl2data_idx'], n_negatives=5); loss

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


tensor(1.2115, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx)
    print(loss)

In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::max         0.95%     519.000us         1.11%     610.000us     610.000us           0 b           0 b         512 b           0 b             1  
                                            aten::empty         0.08%      43.000us         0.08%      43.000us      43.000us           0 b           0 b         512 b         512 b             1  
         

This the below function is not upto date, it has errors in it.

In [None]:
@patch
def forward(cls:MultiTriplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    bsz, tn_targ, mn_targ = inp.shape[0], targ.shape[0], n_inp2targ.max()
    u = torch.arange(bsz, dtype=torch.int64, device=inp.device) if cls.u is None or cls.bsz < bsz else cls.u[:bsz]
    v = (
        torch.ones(tn_targ, dtype=torch.int64, device=targ.device)
        if cls.tn_targ is None or tn_targ > cls.tn_targ else cls.v[:tn_targ]
    )
    
    targ2inp_ptr = u.repeat_interleave(n_inp2targ)
    sc = targ@inp.T
    ps = sc.gather(1, targ2inp_ptr.view(-1,1))
    
    _, idx = torch.unique(inp2targ_idx, return_inverse=True)
    ne = 1 - get_sparse_matrix(idx, n_inp2targ).to_dense()[:, idx]
    ne = ne.unsqueeze(1)
    
    inp2targ_ptr = n_inp2targ.cumsum(dim=0)-1
    xn_inp2targ = mn_targ-n_inp2targ+1
    r_targ = v.scatter(0, inp2targ_ptr, xn_inp2targ)

    psx = ps.repeat_interleave(r_targ).view(bsz, -1, 1)
    sc = sc.T.view(bsz, 1, -1)
    loss = torch.clamp((sc - psx + cls.margin)*ne, 0)
    
    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[2]-n_inp2targ.max()), dim=2, largest=True)
        sc, ne = sc.expand(-1, mn_targ, -1).gather(2, idx), ne.expand(-1, mn_targ, -1).gather(2, idx)
    else: ne = ne.expand(-1, mn_targ, -1)
    
    if cls.apply_softmax:
        m = loss != 0
        p = sc/cls.tau * m
        p[ne == 0] = -1e9
        p = torch.softmax(p, dim=2)
        loss = loss*p
        
    loss /= (ne.sum(dim=2, keepdim=True) + 1e-9)
    
    xr_targ = r_targ.repeat_interleave(r_targ).view(bsz, -1, 1)
    loss /= xr_targ
    
    if cls.reduction == 'mean': return loss.sum()/n_inp2targ.sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
@patch
def forward(cls:MultiTriplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            n_inp2targ:torch.LongTensor, 
            inp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=0.1,
            apply_softmax=False,
            **kwargs):
    store_attr('margin,tau,apply_softmax', is_none=False)
    score = inp@targ.T
    ptr, loss, num_ne = 0, [], 0
    
    _, idx = torch.unique(inp2targ_idx, return_inverse=True)
    ne = 1 - get_sparse_matrix(idx, n_inp2targ).to_dense()[:, idx]
    
    for i, n in enumerate(n_inp2targ):
        ps = score[i, ptr:ptr+n].view(-1, 1)
        fs = torch.clamp((score[i] - ps + cls.margin)*ne[i], 0)
        if cls.apply_softmax: 
            m = fs != 0
            p = torch.softmax(score[i]/self.tau * m, dim=1)
            fs = fs*p
        loss.append(fs)
        ptr += n.item()
        num_ne += n*ne[i].sum()
    loss = torch.vstack(loss)
    if cls.reduction == 'mean': return loss.sum()/num_ne
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
             

In [None]:
loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx); loss

tensor(0.7512, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = mtl_fn(data_repr, lbl2data_repr, lbl2data_data2ptr)
    print(loss)

STAGE:2024-04-19 01:54:10 13866:13866 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-19 01:54:10 13866:13866 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-19 01:54:10 13866:13866 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


tensor(1.4957, device='cuda:0', grad_fn=<MeanBackward0>)


In [None]:
print(prof)

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                    aten::numpy_T         0.00%       5.000us         0.03%      53.000us      53.000us           0 b           0 b           0 b           0 b             1  
                                    aten::permute         0.03%      38.000us         0.03%      48.000us      48.000us           0 b           0 b           0 b           0 b             1  
                                 aten::

## `New` MultiTriplet

In [16]:
#| export
class BaseMultiTriplet(BaseLoss):

    def __init__(
        self,
        margin:Optional[float]=0.8,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')

    def align_indices(self, indices:torch.Tensor, group_lengths:torch.Tensor):
        n, num_groups, max_len = len(indices), len(group_lengths), group_lengths.max()
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=indices.device), group_lengths)
    
        row_indices = torch.arange(n, device=indices.device)
    
        group_start = torch.cat([torch.zeros(1, dtype=group_lengths.dtype, device=group_lengths.device), group_lengths.cumsum(0)[:-1]], dim=0)
    
        within_idx = row_indices - group_start[group_ids]
    
        output = torch.zeros((num_groups, max_len), dtype=indices.dtype, device=indices.device)
        mask = torch.zeros((num_groups, max_len), device=indices.device)
        output[group_ids, within_idx] = indices
        mask[group_ids, within_idx] = 1.0
    
        return output, mask

    def remove_redundant_indices(self, inp2targ_idx:torch.Tensor, n_inp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor):
        mask = torch.isin(pinp2targ_idx, inp2targ_idx)
        new_pinp2targ_idx = pinp2targ_idx[mask]
    
        num_groups = len(n_pinp2targ)
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=n_pinp2targ.device), n_pinp2targ)
        new_n_pinp2targ = torch.bincount(group_ids[mask], minlength=num_groups)
    
        return new_pinp2targ_idx, new_n_pinp2targ

    def reset_indices(self, inp2targ_idx:torch.Tensor, n_inp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor):
        _, reset_indices, counts = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True, return_counts=True)
    
        _, idx_sorted = torch.sort(reset_indices, stable=True)
        cum_sum = torch.cat((torch.zeros((1,), dtype=counts.dtype, device=counts.device), counts.cumsum(0)[:-1]))
        indices = idx_sorted[cum_sum]
    
        inp2targ_idx = reset_indices[:len(inp2targ_idx)]
        pinp2targ_idx = reset_indices[len(inp2targ_idx):]
    
        return inp2targ_idx, pinp2targ_idx, indices

    def compute_scores(self, inp, targ, indices=None):
        if indices is not None: targ = targ[indices]
        return inp@targ.T

    def forward(
        self, 
        
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,

        inp:Optional[torch.FloatTensor]=None, 
        targ:Optional[torch.FloatTensor]=None,
        scores:Optional[torch.FloatTensor]=None,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)

        inp, targ = inp.float(), targ.float()
        scores = scores if scores is None else scores.float()
        
        pinp2targ_idx, n_pinp2targ = self.remove_redundant_indices(inp2targ_idx, n_inp2targ, pinp2targ_idx, n_pinp2targ)
        inp2targ_idx, pinp2targ_idx, indices = self.reset_indices(inp2targ_idx, n_inp2targ, pinp2targ_idx, n_pinp2targ)

        scores = self.compute_scores(inp, targ, indices=indices) if scores is None else scores[:, indices]

        pos_indices, pos_mask = self.align_indices(inp2targ_idx, n_inp2targ)
        pos_scores = scores.gather(1, pos_indices)

        pos_incidence = torch.zeros_like(scores)
        
        ppos_indices, ppos_mask = self.align_indices(pinp2targ_idx, n_pinp2targ)
        pos_incidence = pos_incidence.scatter(1, ppos_indices, 1)

        ppos_indices[~ppos_mask.bool()] = -1
        row_idx = torch.where(torch.all(ppos_indices != 0, dim=1))[0]
        pos_incidence[row_idx, 0] = 0
        
        neg_incidence = 1 - pos_incidence

        loss = scores.unsqueeze(1) - pos_scores.unsqueeze(2) + self.margin
        loss = F.relu(loss * neg_incidence.unsqueeze(1))

        scores = scores.unsqueeze(1).expand_as(loss)
        neg_incidence = neg_incidence.unsqueeze(1).expand_as(loss)

        if self.n_negatives is not None:
            loss, idx = torch.topk(loss, min(self.n_negatives, loss.shape[2]), dim=2, largest=True)
            scores, neg_incidence = scores.gather(2, idx), neg_incidence.gather(2, idx)

        if self.apply_softmax:
            mask = loss != 0
            penalty = scores / self.tau * mask
            penalty[neg_incidence == 0] = torch.finfo(penalty.dtype).min
            penalty = torch.softmax(penalty, dim=2)
            loss = loss * penalty
            
        loss /= (neg_incidence.sum(dim=2, keepdim=True) + 1e-6)
        loss /= (n_inp2targ.unsqueeze(1).unsqueeze(1) + 1e-6)
        loss = loss[pos_mask.bool()].sum()
        
        if self.reduction == 'mean': return loss/len(n_inp2targ)
        elif self.reduction == 'sum': return loss
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
    

In [65]:
#| export
class MultiTriplet(BaseMultiTriplet):

    def forward(
        self, 
        inp:torch.FloatTensor, # bs x dim
        targ:torch.FloatTensor, # total labels in batch (t) x dim
        n_inp2targ:torch.LongTensor, # bs x dim (like indptr in sp.csr_matrix)
        inp2targ_idx:torch.LongTensor, # t x dim (index of label)
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
                               apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        

In [66]:
#| export
class MultiTripletFromInBatchScores(BaseMultiTriplet):

    def forward(
        self, 
        scores:torch.FloatTensor,  
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, scores=scores, margin=margin, tau=tau, 
                               apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        

### Example

In [77]:
import types
from xcai.models.PPP0XX import DBT009

In [20]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [39]:
model = DBT009.from_pretrained('distilbert-base-uncased', use_encoder_parallel=False)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
def forward(self, *args, **kwargs): return 1.0
model.loss_fn.forward = types.MethodType(forward, model.loss_fn)

In [91]:
output = model(**batch.to(model.device))

In [92]:
inp, targ = output.data_repr, output.lbl2data_repr

inp2targ_idx, n_inp2targ = batch['lbl2data_idx'], batch['lbl2data_data2ptr']
pinp2targ_idx, n_pinp2targ = batch['plbl2data_idx'], batch['plbl2data_data2ptr']

In [109]:
loss_fn = MultiTriplet(margin, tau, apply_softmax, n_negatives, reduce='mean')

In [110]:
loss = loss_fn(inp, targ, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)

In [111]:
loss

tensor(0.0246, grad_fn=<DivBackward0>)

## `MultiTripletFromScores`

In [97]:
#| export
class MultiTripletFromScores(BaseMultiTriplet):

    @staticmethod
    def get_incidence(n_inp2targ:int, inp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor):
        row_idx = torch.arange(len(n_pinp2targ), device=n_pinp2targ.device)
        inp2targ_row_idx = row_idx.repeat_interleave(n_inp2targ)
        pinp2targ_row_idx = row_idx.repeat_interleave(n_pinp2targ)
        
        max_col_idx = max(int(inp2targ_idx.max()), int(pinp2targ_idx.max()))
        offset = max_col_idx + 1
    
        inp2targ_keys = inp2targ_row_idx * offset + inp2targ_idx
        pinp2targ_keys = pinp2targ_row_idx * offset + pinp2targ_idx
    
        return torch.isin(inp2targ_keys, pinp2targ_keys)

    @staticmethod
    def get_pos_scores(scores:torch.FloatTensor, n_inp2targ:torch.FloatTensor):
        row_idx = torch.arange(len(n_inp2targ), device=n_inp2targ.device)
        inp2targ_row_idx = row_idx.repeat_interleave(n_inp2targ)
        inp2targ_col_idx = torch.arange(n_inp2targ.sum(), device=n_inp2targ.device)
        return scores[inp2targ_row_idx, inp2targ_col_idx]
    
    def forward(
        self, 
        scores:torch.FloatTensor,  
        inp2targ_idx:torch.LongTensor,
        
        n_inp2targ:torch.LongTensor,
        
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
        
        assert scores.dim() == 2, "`scores` should be two dimensional matrix."
        assert inp2targ_idx.dim() == 2, "`inp2targ_idx` should be two dimensional matrix."
        
        pos_incidence = self.get_incidence(inp2targ_idx.shape[1], inp2targ_idx.flatten(), n_pinp2targ, pinp2targ_idx)
        pos_incidence = pos_incidence.view(inp2targ_idx.shape)

        pos_scores = self.get_pos_scores(scores, n_inp2targ)
        pos_scores, pos_mask = self.align_indices(pos_scores, n_inp2targ)
        neg_incidence = ~pos_incidence

        loss = scores.unsqueeze(1) - pos_scores.unsqueeze(2) + self.margin
        loss = F.relu(loss * neg_incidence.unsqueeze(1))
        
        scores = scores.unsqueeze(1).expand_as(loss)
        neg_incidence = neg_incidence.unsqueeze(1).expand_as(loss)

        if self.n_negatives is not None:
            loss, idx = torch.topk(loss, min(self.n_negatives, loss.shape[2]), dim=2, largest=True)
            scores, neg_incidence = scores.gather(2, idx), neg_incidence.gather(2, idx)

        if self.apply_softmax:
            mask = loss != 0
            penalty = scores / self.tau * mask
            penalty[neg_incidence == 0] = torch.finfo(penalty.dtype).min
            penalty = torch.softmax(penalty, dim=2)
            loss = loss * penalty

        loss /= (neg_incidence.sum(dim=2, keepdim=True) + 1e-6)
        loss /= (n_inp2targ.unsqueeze(1).unsqueeze(1) + 1e-6)
        loss = loss[pos_mask.bool()].sum()
        
        if self.reduction == 'mean': return loss/len(n_inp2targ)
        elif self.reduction == 'sum': return loss
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
        

## `MultiTripletWithNegatives`

In [120]:
#| export
class MultiTripletWithNegatives(MultiTripletFromScores):

    def get_scores(self, inp:torch.Tensor, pos_targ:torch.Tensor, neg_targ:Optional[torch.Tensor]=None, n_neg:Optional[int]=None):
        scores = inp @ pos_targ.T
        if neg_targ is not None:
            neg_scores = inp.unsqueeze(1) @ neg_targ.view(len(inp), n_neg, -1).transpose(1, 2)
            neg_scores = neg_scores.squeeze(1)
            scores = torch.hstack([scores, neg_scores])
        return scores

    def get_indices(self, pos_idx:torch.Tensor, bsz:int, neg_idx:Optional[torch.Tensor]=None, n_neg:Optional[int]=None):
        indices = torch.repeat_interleave(pos_idx.unsqueeze(0), bsz, 0)
        if neg_idx is not None:
            neg_idx = neg_idx.view(bsz, n_neg)
            indices = torch.hstack([indices, neg_idx])
        return indices

    def forward(
        self, 
        inp:torch.FloatTensor,
        
        pos_targ:torch.FloatTensor,
        n_pos:torch.LongTensor,
        pos_idx:torch.LongTensor,

        neg_targ:torch.FloatTensor,
        n_neg:torch.LongTensor,
        neg_idx:torch.LongTensor,
        
        n_ppos:torch.LongTensor,
        ppos_idx:torch.LongTensor,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):  
        assert torch.all(n_neg == n_neg.max()), "All datapoints should same number of negatives"
        scores = self.get_scores(inp, pos_targ, neg_targ, n_neg.max())
        indices = self.get_indices(pos_idx, len(inp), neg_idx, n_neg.max())
        return super().forward(scores, indices, n_pos, n_pinp2targ=n_ppos, pinp2targ_idx=ppos_idx)
    

### Example

In [121]:
import types
from xcai.models.PPP0XX import DBT009

In [122]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [99]:
model = DBT009.from_pretrained('distilbert-base-uncased', use_encoder_parallel=False)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [100]:
def forward(self, *args, **kwargs): return 1.0
model.loss_fn.forward = types.MethodType(forward, model.loss_fn)

In [101]:
output = model(**batch.to(model.device))
inp, pos_targ = output.data_repr, output.lbl2data_repr

In [110]:
output = model(data_input_ids=batch['data_input_ids'], data_attention_mask=batch['data_attention_mask'], 
               lbl2data_data2ptr=batch['lnk2data_data2ptr'], lbl2data_idx=batch['lnk2data_idx'], 
               lbl2data_input_ids=batch['lnk2data_input_ids'], lbl2data_attention_mask=batch['lnk2data_attention_mask'], 
               plbl2data_data2ptr=batch['plnk2data_data2ptr'], plbl2data_idx=batch['plbl2data_idx'])
_, neg_targ = output.data_repr, output.lbl2data_repr

In [123]:
pos_idx, n_pos = batch['lbl2data_idx'], batch['lbl2data_data2ptr']
ppos_idx, n_ppos = batch['plbl2data_idx'], batch['plbl2data_data2ptr']

neg_idx, n_neg = batch['lnk2data_idx'], batch['lnk2data_data2ptr']

In [124]:
loss_fn = MultiTripletWithNegatives(margin, tau, apply_softmax, n_negatives, reduce='mean')

In [125]:
loss = loss_fn(inp, pos_targ, n_pos, pos_idx,  neg_targ, n_neg, neg_idx, n_ppos, ppos_idx)

In [126]:
loss

tensor(0.0263, grad_fn=<DivBackward0>)

## `MarginMSEWithNegatives`

In [80]:
#| export
class MarginMSEWithNegatives(BaseLoss):

    def forward(
        self, 
        inp:torch.FloatTensor,
        
        pos_targ:torch.FloatTensor,
        pos_scores:torch.FloatTensor,

        neg_targ:torch.FloatTensor,
        neg_scores:torch.FloatTensor,
        **kwargs
    ):  
        bsz = len(inp)
        
        assert len(pos_scores) % bsz == 0, "Number of elements in `pos_scores` should be divisible by batch size."
        assert len(neg_scores) % bsz == 0, "Number of elements in `neg_scores` should be divisible by batch size."
        
        assert len(pos_targ) == len(pos_scores), "`pos_targ` and `pos_scores` should have same number of elements."
        assert len(neg_targ) == len(neg_scores), "`neg_targ` and `neg_scores` should have same number of elements."
        
        n = len(pos_targ) // bsz
        pos_targ, pos_scores = pos_targ.view(bsz, n, -1), pos_scores.view(bsz, n, 1)
        n = len(neg_targ) // bsz
        neg_targ, neg_scores = neg_targ.view(bsz, n, -1), neg_scores.view(bsz, 1, n)

        labels = pos_scores - neg_scores
        
        inp = inp.unsqueeze(1)
        pos_scores = inp @ pos_targ.transpose(1, 2)
        neg_scores = inp @ neg_targ.transpose(1, 2)
        margins = pos_scores.transpose(1, 2) - neg_scores

        return F.mse_loss(margins.flatten(), labels.flatten())
        

### Example

In [44]:
import types
from xcai.models.PPP0XX import DBT009

In [45]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [46]:
model = DBT009.from_pretrained('distilbert-base-uncased', use_encoder_parallel=False)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [47]:
def forward(self, *args, **kwargs): return 1.0
model.loss_fn.forward = types.MethodType(forward, model.loss_fn)

In [69]:
output = model(**batch.to(model.device))
inp, pos_targ = output.data_repr, output.lbl2data_repr

In [70]:
output = model(data_input_ids=batch['data_input_ids'], data_attention_mask=batch['data_attention_mask'], 
               lbl2data_data2ptr=batch['lnk2data_data2ptr'], lbl2data_idx=batch['lnk2data_idx'], 
               lbl2data_input_ids=batch['lnk2data_input_ids'], lbl2data_attention_mask=batch['lnk2data_attention_mask'], 
               plbl2data_data2ptr=batch['plnk2data_data2ptr'], plbl2data_idx=batch['plbl2data_idx'])
_, neg_targ = output.data_repr, output.lbl2data_repr

In [77]:
loss_fn = MarginMSEWithNegatives(reduce='mean')

In [72]:
inp.shape, pos_targ.shape, neg_targ.shape

(torch.Size([4, 768]), torch.Size([8, 768]), torch.Size([12, 768]))

In [73]:
pos_scores, neg_scores = torch.rand(pos_targ.shape[0]), torch.rand(neg_targ.shape[0])

In [81]:
loss = loss_fn(inp, pos_targ, pos_scores, neg_targ, neg_scores)

> [0;32m/tmp/ipykernel_23124/135438691.py[0m(17)[0;36mforward[0;34m()[0m
[0;32m     15 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m[0;34m[0m[0m
[0m[0;32m---> 17 [0;31m        [0mbsz[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0minp[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m[0;34m[0m[0m
[0m[0;32m     19 [0;31m        [0;32massert[0m [0mlen[0m[0;34m([0m[0mpos_scores[0m[0;34m)[0m [0;34m%[0m [0mbsz[0m [0;34m==[0m [0;36m0[0m[0;34m,[0m [0;34m"Number of elements in `pos_scores` should be divisible by batch size."[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  q


In [79]:
loss

tensor(0.0968, grad_fn=<MseLossBackward0>)

## Cosine

In [None]:
#| export
class Cosine(BaseLoss):

    def __init__(self, 
                 **kwargs):
        super().__init__(**kwargs)
        

In [None]:
#| export
@patch
def forward(cls:Cosine, 
            inp:torch.FloatTensor,
            inp_mask:torch.FloatTensor,
            targ:torch.LongTensor,
            targ_mask:torch.LongTensor,
            **kwargs):
    seq_len = min(inp.shape[1], targ.shape[1])
    
    inp_mask = inp_mask.unsqueeze(2).expand(inp.size()).float()
    targ_mask = targ_mask.unsqueeze(2).expand(targ.size()).float()

    inp, targ = F.normalize(inp, dim=-1),F.normalize(targ, dim=-1)
    
    inp,targ = inp*inp_mask,targ*targ_mask
    inp,targ = inp[:,:seq_len],targ[:,:seq_len]

    loss = 1.0 - torch.sum(inp*targ, dim=-1)
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### Example

In [None]:
loss_fn = Cosine(reduce='mean')

In [None]:
loss = loss_fn(data_embed, data_attention_mask, lbl2data_embed, lbl2data_attention_mask)

In [None]:
loss

tensor(0.6899, device='cuda:0', grad_fn=<MeanBackward0>)

## Entropy

In [None]:
#| export
class Entropy(BaseLoss):

    def __init__(self, 
                 margin:Optional[float]=0.8,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=True,
                 n_negatives:Optional[int]=5,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')
        

In [None]:
#| export
@patch
def forward(cls:Entropy, 
            inp:torch.FloatTensor,
            targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    
    sc = targ.exp()@inp.T
    
    sc_p =  sc.diagonal().unsqueeze(1)
    _, ne_idx = torch.topk(torch.where(ne == 0, torch.finfo(sc.dtype).min, sc), min(cls.n_negatives, sc.shape[0]-1), dim=1, largest=True)
    sc_n = sc.gather(1, ne_idx)
    
    loss = torch.relu(sc_n - sc_p + cls.margin)
    
    if cls.apply_softmax:
        m = loss != 0
        p = torch.softmax(sc_n/cls.tau * m, dim=1)
        loss = loss*p
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

### Example

In [None]:
el_fn = Entropy(margin=1e-2, tau=0.1, apply_softmax=True, n_negatives=5, reduce='mean')

In [None]:
loss = el_fn( F.log_softmax(data_repr, dim=-1), F.log_softmax(lbl2data_repr, dim=-1), lbl2data_idx, 
             kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx']); loss

tensor(0.0019, grad_fn=<MeanBackward0>)

In [None]:
loss

tensor(0.0019, grad_fn=<MeanBackward0>)

## Triplet

In [None]:
#| export
class Triplet(BaseLoss):

    def __init__(self, 
                 margin:Optional[float]=0.8,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=True,
                 n_negatives:Optional[int]=5,
                 **kwargs):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')


In [None]:
#| export
@patch
def forward(cls:Triplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    
    sc = inp@targ.T
    sc_p =  sc.diagonal().unsqueeze(1)
    _, ne_idx = torch.topk(torch.where(ne == 0, -10, sc), min(cls.n_negatives, sc.shape[0]-1), dim=1, largest=True)
    sc_n = sc.gather(1, ne_idx)
    
    loss = torch.relu(sc_n - sc_p + cls.margin)
    
    if cls.apply_softmax:
        m = loss != 0
        p = torch.softmax(sc_n/cls.tau * m, dim=1)
        loss = loss*p
    
    if cls.reduction == 'mean': return loss.mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
@patch
def forward(cls:Triplet, 
            inp:torch.FloatTensor, 
            targ:torch.LongTensor, 
            inp2targ_idx:torch.LongTensor,
            n_pinp2targ:torch.LongTensor,
            pinp2targ_idx:torch.LongTensor,
            margin:Optional[float]=None,
            tau:Optional[float]=None,
            apply_softmax:Optional[bool]=None,
            n_negatives:Optional[int]=None,
            **kwargs):
    store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)
    _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
    ne = 1 - get_sparse_matrix(idx[len(inp2targ_idx):], n_pinp2targ).to_dense()[:, idx[:len(inp2targ_idx)]]
    
    sc = inp@targ.T
    loss = torch.relu((sc - sc.diagonal().unsqueeze(1) + cls.margin) * ne)
    
    if cls.n_negatives is not None:
        loss, idx = torch.topk(loss, min(cls.n_negatives, loss.shape[0]-1), dim=1, largest=True)
        sc, ne = sc.gather(1, idx), ne.gather(1, idx)
        
    if cls.apply_softmax:
        m = loss != 0
        p = sc/cls.tau * m
        p[ne == 0] = torch.finfo(p.dtype).min
        p = torch.softmax(p, dim=1)
        loss = loss*p
        
    loss /= (ne.sum(dim=1, keepdim=True) + 1e-9)
    
    if cls.reduction == 'mean': return loss.sum(dim=1).mean()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
tl_fn = Triplet(0.8, tau=0.1, apply_softmax=True, n_negatives=5, reduce='mean')

In [None]:
loss = tl_fn(data_repr, lbl2data_repr, lbl2data_idx, kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx']); loss

tensor(0.9848, grad_fn=<MeanBackward0>)

## SoupCon

In [None]:
#| export
class SoupCon(BaseLoss):

    @delegates(BaseLoss.__init__)
    def __init__(self,
                 tau:Optional[float]=1.0, 
                 **kwargs):
        super().__init__(**kwargs)
        self.tau = nn.Parameter(torch.tensor(tau, dtype=torch.float32))
        

In [None]:
scn_fn = SoupCon(reduce='mean')

In [None]:
#| export
@patch
def forward(cls:SoupCon,
            inp:torch.FloatTensor,
            targ:torch.LongTensor,
            n_inp2targ:torch.LongTensor,
            inp2targ_idx:torch.LongTensor,
            **kwargs):
    _, idx, cnt = torch.unique(inp2targ_idx, return_inverse=True, return_counts=True)
    _, idx_sorted = torch.sort(idx)
    targ_idx = idx_sorted[torch.cat([torch.zeros(1, dtype=inp2targ_idx.dtype, device=inp2targ_idx.device), cnt.cumsum(0)[:-1]])]

    pe = get_sparse_matrix(idx, n_inp2targ).to_dense()
    sc = -F.log_softmax(inp@targ[targ_idx].T/cls.tau, dim=1)
    
    loss = sc*pe
    loss /= pe.sum(dim=1, keepdim=True)
    
    if cls.reduce == 'mean': return loss.sum(dim=1).mean()
    elif cls.reduce == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = scn_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx); loss

tensor(6.8475, grad_fn=<MeanBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = scn_fn(data_repr, lbl2data_repr, lbl2data_data2ptr)
    print(loss)

tensor(13.0039, device='cuda:0', grad_fn=<SumBackward0>)


STAGE:2024-04-14 01:53:44 12684:12684 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-14 01:53:44 12684:12684 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-14 01:53:44 12684:12684 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [None]:
print(prof)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::slice         0.90%      22.000us         1.18%      29.000us      29.000us           0 b           0 b           0 b           0 b             1  
                                       aten::as_strided         0.29%       7.000us         0.29%       7.000us       7.000us           0 b           0 b           0 b           0 b             1  
         

This is also not update to date with the above implementation

In [None]:
@patch
def __call__(cls:SoupCon, 
             inp:torch.FloatTensor, 
             targ:torch.LongTensor, 
             n_inp2targ:torch.LongTensor,
             inp2targ_idx:torch.LongTensor,
             **kwargs):
    _, idx, cnt = torch.unique(inp2targ_idx, return_inverse=True, return_counts=True)
    _, idx_sorted = torch.sort(idx)
    targ_idx = idx_sorted[torch.cat([torch.zeros(1, dtype=inp2targ_idx.dtype, device=inp2targ_idx.device), cnt.cumsum(0)[:-1]])]
    
    pe = get_sparse_matrix(idx, n_inp2targ).to_dense()
    
    s = -F.log_softmax(inp@targ[targ_idx].T, dim=1)
    loss = []
    for i,p in zip(s, pe):
        ps = i*p
        loss.append(ps)
    loss = torch.hstack(loss)
    if cls.reduce == 'mean': return loss.sum()/pe.sum()
    elif cls.reduce == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = scn_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx); loss

tensor(10.6172, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
with profiler.profile(with_stack=True, profile_memory=True) as prof:
    loss = scn_fn(data_repr, lbl2data_repr, lbl2data_data2ptr)
    print(loss)

STAGE:2024-04-14 01:54:12 12684:12684 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-14 01:54:12 12684:12684 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-14 01:54:12 12684:12684 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


tensor(13.0039, device='cuda:0', grad_fn=<SumBackward0>)


In [None]:
print(prof)

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::numpy_T         0.02%      10.000us         0.11%      59.000us      59.000us           0 b           0 b           0 b           0 b             1  
                aten::permute         0.07%      40.000us         0.09%      49.000us      49.000us           0 b           0 b           0 b           0 b             1  
             aten::as_strided         0.02%       9.000us         0.02%       9.000us       9.000us           0 b           0 b           0