In [5]:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClipLoss2D(torch.nn.Module):
    """CLIP-style (see OpenAI CLIP) contrastive loss for [B, C, T] inputs.

    Every estimate is compared with every candidate through a cosine
    similarity taken over the flattened (C, T) axes. Those similarities,
    divided by the learnable temperature ``temp_tau``, are the logits of a
    cross-entropy whose target for estimate ``i`` is candidate ``i``.

    Args:
        linear: if not None, output size of a ``LazyLinear`` projection
            applied to the last axis of both inputs before scoring.
        twin: when ``linear`` is set, share one projection between estimates
            and candidates (True) or use two independent ones (False).
        pool: average over axis 2 (the last axis) before scoring.
        center: subtract each sample's mean over axes (1, 2) before scoring.
        temp_tau: initial value of the learnable temperature.
    """

    def __init__(self, linear=None, twin=True, pool=False,
                 center=False, temp_tau=1.0):
        super().__init__()
        # Output size of the optional projection; None disables it.
        self.linear = linear
        self.pool = pool
        self.center = center
        if linear is not None:
            self.linear_est = torch.nn.LazyLinear(linear)
            if twin:
                self.linear_gt = self.linear_est
            else:
                self.linear_gt = torch.nn.LazyLinear(linear)
        # float() so that e.g. temp_tau=1 still yields a floating-point
        # Parameter (integer tensors cannot require gradients).
        self.temp_tau = nn.Parameter(torch.tensor(float(temp_tau)))

    def get_scores(self, estimates: torch.Tensor, candidates: torch.Tensor):
        """Return the [B, B'] matrix of cosine-similarity scores.

        Args:
            estimates: tensor of shape [B, C, T].
            candidates: tensor of shape [B', C, T].

        Returns:
            Tensor of shape [B, B'] whose entry (b, o) is the cosine
            similarity between ``estimates[b]`` and ``candidates[o]``, each
            flattened over its last two axes after the optional projection,
            pooling and centering steps.
        """
        if self.linear is not None:
            estimates = self.linear_est(estimates)
            candidates = self.linear_gt(candidates)
        if self.pool:
            estimates = estimates.mean(dim=2, keepdim=True)
            candidates = candidates.mean(dim=2, keepdim=True)
        if self.center:
            estimates = estimates - estimates.mean(dim=(1, 2), keepdim=True)
            candidates = candidates - candidates.mean(dim=(1, 2), keepdim=True)

        # Inverse L2 norms over axes (1, 2); the epsilon guards all-zero samples.
        inv_norms_est = 1 / (1e-8 + estimates.norm(dim=(1, 2), p=2))
        inv_norms_cand = 1 / (1e-8 + candidates.norm(dim=(1, 2), p=2))
        # Normalize inside the einsum to avoid creating normalized copies of
        # the inputs, which can be pretty big.
        return torch.einsum("bct,oct,b,o->bo", estimates, candidates,
                            inv_norms_est, inv_norms_cand)

    def get_probabilities(self, estimates, candidates):
        """Return the [B, B'] matrix of matching probabilities.

        Args:
            estimates: tensor of shape [B, C, T].
            candidates: tensor of shape [B', C, T].

        Returns:
            Row-wise softmax of ``get_scores(...) / temp_tau``; each row sums
            to 1 over the B' candidates.
        """
        scores = self.get_scores(estimates, candidates)
        scores = scores / self.temp_tau
        return F.softmax(scores, dim=1)

    def forward(self, estimate, candidate, mask=None):
        """Compute the contrastive cross-entropy loss.

        Warning: estimate and candidate are not necessarily symmetrical.
        If estimate has shape [B, C, T] and candidate has shape [B', C, T]
        with B' >= B, the first B samples of candidate are the targets, while
        the remaining B' - B samples are only used as negatives.

        Args:
            estimate: tensor of shape [B, C, T].
            candidate: tensor of shape [B', C, T], B' >= B.
            mask: accepted for API compatibility; currently ignored.

        Returns:
            Scalar mean cross-entropy loss.

        Raises:
            AssertionError: if B > B'.
        """
        assert estimate.size(0) <= candidate.size(0), "need at least as many targets as estimates"
        # F.cross_entropy applies log-softmax itself, so it must receive the
        # temperature-scaled scores (logits), not softmax probabilities.
        logits = self.get_scores(estimate, candidate) / self.temp_tau
        target = torch.arange(len(logits), device=estimate.device)
        return F.cross_entropy(logits, target)
    
# Quick sanity check on two random batches shaped [B=4, C=8, T=82].
clip_loss = ClipLoss2D()  # defaults: no projection, no pooling/centering, tau=1.0
estimates, candidates = torch.randn(4, 8, 82), torch.randn(4, 8, 82)
scores = clip_loss.get_scores(estimates, candidates)
print(scores)

tensor([[-0.0048,  0.0127,  0.0413,  0.0038],
        [ 0.0319,  0.0068,  0.0119, -0.0046],
        [-0.0079, -0.0146,  0.0433,  0.0049],
        [ 0.0060,  0.0371,  0.0362, -0.0088]])


In [4]:

class ClipLoss1D(torch.nn.Module):
    """CLIP-style (see OpenAI CLIP) contrastive loss for [B, N] inputs.

    Every estimate is compared with every candidate through a cosine
    similarity over the feature axis. Those similarities, divided by the
    learnable temperature ``temp_tau``, are the logits of a cross-entropy
    whose target for estimate ``i`` is candidate ``i``.

    Args:
        linear: if not None, output size of a ``LazyLinear`` projection
            applied to both inputs before scoring.
        twin: when ``linear`` is set, share one projection between estimates
            and candidates (True) or use two independent ones (False).
        center: subtract each sample's mean over axis 1 before scoring.
        temp_tau: initial value of the learnable temperature.
    """

    def __init__(self, linear=None, twin=True, center=False, temp_tau=1.0):
        super().__init__()
        # Output size of the optional projection; None disables it.
        self.linear = linear
        self.center = center
        if linear is not None:
            self.linear_est = torch.nn.LazyLinear(linear)
            if twin:
                self.linear_gt = self.linear_est
            else:
                self.linear_gt = torch.nn.LazyLinear(linear)
        # float() so that e.g. temp_tau=1 still yields a floating-point
        # Parameter (integer tensors cannot require gradients).
        self.temp_tau = nn.Parameter(torch.tensor(float(temp_tau)))

    def get_scores(self, estimates: torch.Tensor, candidates: torch.Tensor):
        """Return the [B, B'] matrix of cosine-similarity scores.

        Args:
            estimates: tensor of shape [B, N].
            candidates: tensor of shape [B', N].

        Returns:
            Tensor of shape [B, B'] whose entry (b, o) is the cosine
            similarity between ``estimates[b]`` and ``candidates[o]`` after
            the optional projection and centering steps.
        """
        if self.linear is not None:
            estimates = self.linear_est(estimates)
            candidates = self.linear_gt(candidates)
        if self.center:
            estimates = estimates - estimates.mean(dim=1, keepdim=True)
            candidates = candidates - candidates.mean(dim=1, keepdim=True)

        # Inverse L2 norms over axis 1; the epsilon guards all-zero samples.
        inv_norms_est = 1 / (1e-8 + estimates.norm(dim=1, p=2))
        inv_norms_cand = 1 / (1e-8 + candidates.norm(dim=1, p=2))
        # Normalize inside the einsum to avoid creating normalized copies.
        return torch.einsum("bn,on,b,o->bo", estimates, candidates,
                            inv_norms_est, inv_norms_cand)

    def get_probabilities(self, estimates, candidates):
        """Return the [B, B'] matrix of matching probabilities.

        Args:
            estimates: tensor of shape [B, N].
            candidates: tensor of shape [B', N].

        Returns:
            Row-wise softmax of ``get_scores(...) / temp_tau``.
        """
        scores = self.get_scores(estimates, candidates)
        scores = scores / self.temp_tau
        return F.softmax(scores, dim=1)

    def forward(self, estimate, candidate):
        """Compute the contrastive cross-entropy loss.

        The first B rows of ``candidate`` are the targets of the B estimates;
        any remaining B' - B rows are only used as negatives.

        Args:
            estimate: tensor of shape [B, N].
            candidate: tensor of shape [B', N], B' >= B.

        Returns:
            Scalar mean cross-entropy loss.

        Raises:
            AssertionError: if B > B'.
        """
        assert estimate.size(0) <= candidate.size(0), "need at least as many targets as estimates"
        # F.cross_entropy applies log-softmax itself, so it must receive the
        # temperature-scaled scores (logits), not softmax probabilities.
        logits = self.get_scores(estimate, candidate) / self.temp_tau
        target = torch.arange(len(logits), device=estimate.device)
        return F.cross_entropy(logits, target)
    

# Quick sanity check on two random batches shaped [B=4, N=8].
clip_loss = ClipLoss1D()  # defaults: no projection, no centering, tau=1.0
estimates, candidates = torch.randn(4, 8), torch.randn(4, 8)
scores = clip_loss.get_scores(estimates, candidates)
print(scores)

tensor([[-0.2552,  0.6079, -0.1632,  0.2943],
        [-0.2991, -0.3183,  0.4228, -0.5013],
        [-0.1480, -0.0494,  0.0910, -0.1085],
        [-0.2797, -0.1010, -0.3935,  0.1049]])
