# Losses

In [None]:
#| default_exp losses

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import functools, torch, torch.nn as nn, torch.nn.functional as F
from typing import MutableSequence, Union
from fastcore.utils import *
from fastcore.meta import *

from xcai.torch_core import *

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [None]:
from xcai.test_utils import *
from xcai.models.BT000X import *

In [None]:
block = Test.from_cfg('train')

  self._set_arrayXarray(i, j, x)


In [None]:
batch = block.train.one_batch()

In [None]:
m = BT0001.from_pretrained('bert-base-cased')

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [None]:
data_logits, data_logits, lbl2data_data2ptr, data_embed, lbl2data_embed = m(**batch)

In [None]:
data_logits.shape, data_logits.shape, lbl2data_data2ptr.shape, data_embed.shape, lbl2data_embed.shape

(torch.Size([24, 12]),
 torch.Size([24, 12]),
 torch.Size([10]),
 torch.Size([10, 768]),
 torch.Size([24, 768]))

## BaseLoss

In [None]:
#| export
class BaseLoss(nn.Module):
    "Base class for the losses in this module; it only stores the reduction mode."

    def __init__(self, reduce:Optional[str]=None, **kwargs):
        # Any extra keyword arguments are accepted and ignored.
        super().__init__()
        self.reduce = reduce

    @property
    def reduction(self) -> str:
        # Read-through alias of `reduce`.
        return self.reduce

    @reduction.setter
    def reduction(self, v:str):
        "Sets the reduction style (typically 'mean', 'sum', or 'none')"
        self.reduce = v
        

## MultiCrossEntropy

In [None]:
#| export
class MultiCrossEntropy(BaseLoss):
    """Cross-entropy loss configured with a target-row capacity and an ignored token id.

    `tn_targ` sizes a cached all-ones int64 vector `o` (left as `None` when
    `tn_targ` is `None`); `ig_tok` is the token id stored for exclusion.
    Remaining keyword arguments are forwarded to `BaseLoss`.
    """

    def __init__(self,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0,
                 **kwargs):
        super().__init__(**kwargs)
        self.tn_targ = tn_targ
        self.ig_tok = ig_tok
        self.o = None if tn_targ is None else torch.ones(tn_targ, dtype=torch.int64)
        # NOTE(review): `o` is a plain tensor stored in `_parameters`, presumably so that
        # `Module.to()`/`.cuda()` moves it with the module -- confirm; `register_buffer`
        # is the usual API for non-trainable state.
        self._parameters = {'o': self.o}
        

In [None]:
bsz = 10
batch = block.train.one_batch(bsz)
data_logits, lbl2data_input_ids, lbl2data_data2ptr, data_embed, lbl2data_embed = m(**batch)

In [None]:
mce_fn = MultiCrossEntropy(10_000, reduce='mean')

In [None]:
#| export
@patch
def __call__(cls:MultiCrossEntropy,
             inp:torch.FloatTensor,
             targ:torch.LongTensor,
             n_inp2targ:torch.LongTensor):
    """Vectorised cross-entropy between each input and all of its target sequences.

    `inp` is 3-D (log-softmax is taken over dim 2), `targ` is 2-D with one row per
    target sequence, and `n_inp2targ[i]` is the number of consecutive `targ` rows
    that belong to `inp[i]`. Positions whose target equals `cls.ig_tok` are skipped.

    Returns the summed loss (`reduction == 'sum'`) or that sum divided by the number
    of non-ignored target tokens (`reduction == 'mean'`); any other reduction raises
    `ValueError`.
    """
    n_rows, targ_len = targ.shape
    n_inputs, inp_len = inp.shape[0], inp.shape[1]
    max_per_inp = n_inp2targ.max()
    seq_len = min(targ_len, inp_len)

    # Negative log-probabilities, truncated to the common length and laid out as
    # (input, class, position) so the class axis can be gathered on dim 1.
    nll = -F.log_softmax(inp, dim=2)[:, :seq_len].transpose(1,2)
    targ = targ[:, :seq_len]

    # Repeat the last target row of every input so that each input owns exactly
    # `max_per_inp` rows; this makes the targets reshapeable to (n_inputs, -1, seq_len).
    # NOTE(review): appears to assume every input has at least one target -- confirm.
    last_row = n_inp2targ.cumsum(dim=0)-1
    last_row_reps = max_per_inp-n_inp2targ+1
    if cls.tn_targ is None or n_rows > cls.tn_targ:
        ones = torch.ones(n_rows, dtype=torch.int64, device=inp.device)
    else:
        ones = cls.o[:n_rows]
    reps = ones.scatter(0, last_row, last_row_reps)  # out-of-place: the cache is not modified
    padded_targ = targ.repeat_interleave(reps, dim=0)

    tok_nll = nll.gather(1, padded_targ.view(n_inputs, -1, seq_len)).view(-1, seq_len)
    # A row duplicated r times contributes 1/r per copy, i.e. once in total.
    tok_nll /= reps.repeat_interleave(reps, dim=0).view(-1, 1)
    rows, cols = torch.where(padded_targ != cls.ig_tok)
    loss = tok_nll[rows, cols]

    if cls.reduction == 'mean':
        n_valid = len(torch.where(targ != cls.ig_tok)[0])
        return (loss/n_valid).sum()
    if cls.reduction == 'sum':
        return loss.sum()
    raise ValueError(f'`reduction` cannot be `{cls.reduction}`')


In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(14.8057, grad_fn=<SumBackward0>)

In [None]:
@patch
def __call__(cls:MultiCrossEntropy, 
             inp:torch.FloatTensor, 
             targ:torch.LongTensor, 
             n_inp2targ:torch.LongTensor):
    """Reference cross-entropy that repeats each input once per target sequence.

    Each `inp[i]` is repeated `n_inp2targ[i]` times along dim 0 so it lines up
    row-for-row with `targ`; positions equal to `cls.ig_tok` are dropped before
    the 'mean' or 'sum' reduction. Any other reduction raises `ValueError`.
    """
    seq_len = min(inp.shape[1], targ.shape[1])
    nll = -F.log_softmax(inp, dim=2)[:, :seq_len]
    targ = targ[:, :seq_len].unsqueeze(2)

    # One copy of the input's negative log-probabilities per target row.
    nll = nll.repeat_interleave(n_inp2targ, dim=0)
    tok_nll = nll.gather(2, targ)

    keep = torch.where(targ != cls.ig_tok)
    loss = tok_nll[keep[0], keep[1]]

    if cls.reduction == 'mean':
        return loss.mean()
    if cls.reduction == 'sum':
        return loss.sum()
    raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(14.8057, grad_fn=<MeanBackward0>)

In [None]:
@patch
def __call__(cls:MultiCrossEntropy, 
             inp:torch.FloatTensor, 
             targ:torch.LongTensor, 
             n_inp2targ:torch.LongTensor):
    """Loop-based reference cross-entropy: one gather per (input, target row) pair.

    Walks the inputs in order, pairing `inp[i]` with the next `n_inp2targ[i]` rows
    of `targ`. Positions equal to `cls.ig_tok` are dropped before the 'mean' or
    'sum' reduction; any other reduction raises `ValueError`.
    """
    seq_len = min(inp.shape[1], targ.shape[1])
    nll = -F.log_softmax(inp, dim=2)[:, :seq_len]
    targ = targ[:, :seq_len]

    # Input index owning each consecutive target row.
    owners = [i for i, n in enumerate(n_inp2targ) for _ in range(n)]
    per_row = [
        nll[i].gather(1, targ[row].view(-1, 1)).view(1, -1)
        for row, i in enumerate(owners)
    ]
    tok_nll = torch.vstack(per_row)

    keep = torch.where(targ != cls.ig_tok)
    loss = tok_nll[keep[0], keep[1]]

    if cls.reduction == 'mean':
        return loss.mean()
    if cls.reduction == 'sum':
        return loss.sum()
    raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
    

In [None]:
loss = mce_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr); loss

tensor(14.8057, grad_fn=<MeanBackward0>)

## MultiTriplet

In [None]:
#| export
class MultiTriplet(BaseLoss):
    """Margin-based loss configured with batch-size and target-row capacities.

    When `bsz` is given, an upper-triangular ones matrix `t` (bsz x bsz) and an
    index vector `u` (0..bsz-1) are cached; when `tn_targ` is given, an all-ones
    vector `v` of that length is cached. Each stays `None` if its size is `None`.
    Remaining keyword arguments are forwarded to `BaseLoss`.
    """

    def __init__(self,
                 bsz:Optional[int]=None, 
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.8,
                 ig_tok:Optional[int]=0,
                 **kwargs):
        super().__init__(**kwargs)
        self.bsz, self.tn_targ = bsz, tn_targ
        self.margin, self.ig_tok = margin, ig_tok

        if bsz is None:
            self.t = None
            self.u = None
        else:
            self.t = torch.ones((bsz, bsz), dtype=torch.int64).triu()
            self.u = torch.arange(bsz, dtype=torch.int64)
        self.v = None if tn_targ is None else torch.ones(tn_targ, dtype=torch.int64)

        # NOTE(review): plain tensors stored in `_parameters`, presumably so that
        # `Module.to()`/`.cuda()` moves them with the module -- confirm; `register_buffer`
        # is the usual API for non-trainable state.
        self._parameters = {'t':self.t, 'u':self.u, 'v':self.v}
        

In [None]:
bsz = 10
batch = block.train.one_batch(bsz)
data_logits, lbl2data_input_ids, lbl2data_data2ptr, data_embed, lbl2data_embed = m(**batch)

In [None]:
mtl_fn = MultiTriplet(bsz, 10_000, 0.8, reduce='mean')

In [None]:
#| export
@patch
def __call__(cls:MultiTriplet, 
             inp:torch.FloatTensor, 
             targ:torch.LongTensor, 
             n_inp2targ:torch.LongTensor,
             margin:Optional[float]=None):
    """Vectorised multi-positive triplet (hinge) loss between inputs and targets.

    `inp` and `targ` are 2-D embeddings compared with a dot product; rows of `targ`
    are grouped by input, `n_inp2targ[i]` consecutive rows belonging to `inp[i]`.
    For every (input, own target, other input's target) triple the term is
    `max(0, score(input, other target) - score(input, own target) + margin)`.

    `margin`, when given, replaces `cls.margin` and stays in effect for later calls.

    Returns the summed loss (`reduction == 'sum'`) or the sum divided by the number
    of (positive, negative) pairs (`reduction == 'mean'`); any other reduction
    raises `ValueError`.
    """
    cls.margin = cls.margin if margin is None else margin
    bsz, tn_targ, mn_targ = inp.shape[0], targ.shape[0], n_inp2targ.max()

    # The cached helpers exist only if sizes were passed to the constructor and are
    # large enough for this batch; otherwise build them on the fly instead of
    # failing on `None` or on a truncated slice.
    if cls.bsz is None or bsz > cls.bsz:
        t = torch.ones((bsz, bsz), dtype=torch.int64, device=inp.device).triu()
        u = torch.arange(bsz, dtype=torch.int64, device=inp.device)
    else:
        t, u = cls.t[:bsz,:bsz], cls.u[:bsz]
    v = (
        torch.ones(tn_targ, dtype=torch.int64, device=targ.device)
        if cls.tn_targ is None or tn_targ > cls.tn_targ else cls.v[:tn_targ]
    )

    # Owning input index of every target row, and each target's score with its owner.
    targ2inp_ptr = u.repeat_interleave(n_inp2targ)
    s = targ@inp.T
    ps = s.gather(1, targ2inp_ptr.view(-1,1))
    
    # Multiplying by the upper-triangular ones matrix yields the running sum of
    # `n_inp2targ`, so this is the index of the last target row of each input.
    # NOTE(review): appears to assume every input has at least one target -- confirm.
    inp2targ_ptr = CUDALongTensor.matmul(n_inp2targ[None], t).squeeze(0)-1
    xn_inp2targ = mn_targ-n_inp2targ+1
    
    # Repeat counts that pad every input to `mn_targ` rows by duplicating its last
    # target; `scatter` is out-of-place so the cached ones-vector is not modified.
    r_targ = v.scatter(0, inp2targ_ptr, xn_inp2targ)
    
    targ2inp_ptrx = targ2inp_ptr.repeat_interleave(r_targ)
    mask, maskx = F.one_hot(targ2inp_ptr), F.one_hot(targ2inp_ptrx)
    # fmask[p, j] is non-zero when padded positive row p and target j share an input.
    fmask = CUDALongTensor.matmul(maskx,mask.T)
    psx = ps.repeat_interleave(r_targ).view(bsz, -1, 1)
    s = s.T.view(bsz, 1, -1)
    fs = (s - psx + cls.margin).view(-1, tn_targ)
    # A positive duplicated r times contributes 1/r per copy, i.e. once in total.
    fs /= r_targ.repeat_interleave(r_targ).view(-1, 1)
    
    # Keep only pairs whose target belongs to a different input (the negatives).
    idx = torch.where(fmask == 0)
    loss = fs[idx[0], idx[1]]
    # Hinge, and pair count: (sum n)^2 - sum n^2.
    loss, n = torch.where(loss > 0, loss, 0), (n_inp2targ.sum())**2 - (n_inp2targ**2).sum()
    if cls.reduction == 'mean': return (loss/n).sum()
    elif cls.reduction == 'sum': return loss.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = mtl_fn(data_embed, lbl2data_embed, lbl2data_data2ptr); loss

tensor(1.1562, grad_fn=<SumBackward0>)

In [None]:
@patch
def __call__(cls:MultiTriplet, 
             inp:torch.FloatTensor, 
             targ:torch.LongTensor, 
             n_inp2targ:torch.LongTensor, 
             margin:Optional[float]=None):
    """Loop-based reference triplet loss: one pass per input.

    For input `i`, its `n_inp2targ[i]` consecutive target rows are the positives
    and every other target is a negative. `margin`, when given, replaces
    `cls.margin` and stays in effect for later calls. Supports 'mean' and 'sum'
    reductions; anything else raises `ValueError`.
    """
    if margin is not None:
        cls.margin = margin
    score = inp@targ.T

    start, parts = 0, []
    for i, n in enumerate(n_inp2targ):
        row = score[i]
        pos = row[start:start+n].view(-1, 1)
        # Rotate so this input's own targets occupy the first `n` columns,
        # then drop them: what is left are the negatives.
        viol = (row - pos + cls.margin).roll(-start, 1)
        parts.append(viol[:, n:].flatten())
        start += n.item()

    loss = torch.hstack(parts)
    loss = torch.where(loss > 0, loss, 0)

    if cls.reduction == 'mean':
        return loss.mean()
    if cls.reduction == 'sum':
        return loss.sum()
    raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
             

In [None]:
loss = mtl_fn(data_embed, lbl2data_embed, lbl2data_data2ptr); loss

tensor(1.1562, grad_fn=<MeanBackward0>)

## SoupCon

In [None]:
#| export
class SoupCon(BaseLoss):
    """Softmax-based loss over input/target scores, configured with a batch-size capacity.

    When `bsz` is given, an index vector `t` (0..bsz-1) is cached; otherwise `t`
    is `None`. Remaining keyword arguments are forwarded to `BaseLoss`.
    """

    @delegates(BaseLoss.__init__)
    def __init__(self,
                 bsz:Optional[int]=None, 
                 **kwargs):
        super().__init__(**kwargs)
        self.t = None if bsz is None else torch.arange(bsz, dtype=torch.int64)
        # NOTE(review): plain tensor stored in `_parameters`, presumably so that
        # `Module.to()`/`.cuda()` moves it with the module -- confirm; `register_buffer`
        # is the usual API for non-trainable state.
        self._parameters = {'t':self.t}
        

In [None]:
bsz = 100
batch = block.train.one_batch(bsz)
# The second return value is the label token ids (named as in the earlier cells);
# repeating `data_logits` twice silently discarded the first return value.
data_logits, lbl2data_input_ids, lbl2data_data2ptr, data_embed, lbl2data_embed = m(**batch)

In [None]:
scn_fn = SoupCon(bsz, reduce='mean')

In [None]:
#| export
@patch
def __call__(cls:SoupCon,
             inp:torch.FloatTensor,
             targ:torch.LongTensor,
             n_inp2targ:torch.LongTensor):
    """Vectorised softmax loss of every target against its owning input.

    `inp` and `targ` are 2-D embeddings scored with a dot product; rows of `targ`
    are grouped by input, `n_inp2targ[i]` consecutive rows belonging to `inp[i]`.
    For each input the scores are soft-maxed over all targets and the negative
    log-probability of each of its own targets is collected.

    With `reduce == 'mean'` each term is divided by its input's target count and by
    the batch size before summing; with `reduce == 'sum'` the terms are summed
    as-is. Any other value raises `ValueError`.
    """
    bsz = inp.shape[0]
    # The cached index vector exists only if `bsz` was passed to the constructor and
    # may be shorter than this batch; build it on the fly in those cases instead of
    # failing on `None` or on a truncated slice.
    t = (
        torch.arange(bsz, dtype=torch.int64, device=inp.device)
        if cls.t is None or bsz > cls.t.shape[0] else cls.t[:bsz]
    )
    # Owning input index of every target row.
    targ2inp_ptr = t.repeat_interleave(n_inp2targ)
    # Rows are targets, columns are inputs; dim=0 normalises over targets per input.
    s = -F.log_softmax(targ@inp.T, dim=0)
    ps = s.gather(1, targ2inp_ptr.unsqueeze(1)).squeeze(1)
    if cls.reduce == 'mean':
        ps /= n_inp2targ.repeat_interleave(n_inp2targ)
        ps /= bsz
        return ps.sum()
    elif cls.reduce == 'sum': return ps.sum()
    else: raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = scn_fn(data_embed, lbl2data_embed, lbl2data_data2ptr); loss

tensor(11.1438, grad_fn=<SumBackward0>)

In [None]:
@patch
def __call__(cls:SoupCon, 
             inp:torch.FloatTensor, 
             targ:torch.LongTensor, 
             n_inp2targ:torch.LongTensor):
    """Loop-based reference softmax loss: one slice per input.

    Scores are soft-maxed over all targets for each input; input `i` contributes
    the negative log-probabilities of its `n_inp2targ[i]` consecutive target rows.
    With `reduce == 'mean'` each input's terms are divided by its target count and
    the total by the batch size; `reduce == 'sum'` sums the raw terms. Any other
    value raises `ValueError`.
    """
    bsz = inp.shape[0]
    nll = -F.log_softmax(inp@targ.T, dim=1)
    average = cls.reduce == 'mean'

    start, parts = 0, []
    for row, n in zip(nll, n_inp2targ):
        own = row[start:start+n]
        start += n
        parts.append(own/n if average else own)
    loss = torch.hstack(parts)

    if cls.reduce == 'mean':
        return (loss/bsz).sum()
    if cls.reduce == 'sum':
        return loss.sum()
    raise ValueError(f'`reduction` cannot be `{cls.reduction}`')
        

In [None]:
loss = scn_fn(data_embed, lbl2data_embed, lbl2data_data2ptr); loss

tensor(11.1438, grad_fn=<SumBackward0>)