In [2]:
import torch

from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from pytorch_metric_learning.distances.base_distance import BaseDistance


class LpDistance(BaseDistance):
    """Minkowski (L-p) distance between embeddings.

    The norm order is read from ``self.p``, which is not set in this class
    (presumably stored by ``BaseDistance`` from the keyword arguments --
    verify against the base class).
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseDistance``.

        Raises:
            AssertionError: if the resulting distance is flagged as inverted.
        """
        super().__init__(**kwargs)
        assert not self.is_inverted

    def compute_mat(self, query_emb, ref_emb):
        """Return the matrix of L-p distances between query and reference rows.

        Args:
            query_emb: tensor of query embeddings.
            ref_emb: tensor of reference embeddings, or ``None`` to compare
                ``query_emb`` against itself.

        Returns:
            Tensor of pairwise distances; ``torch.cdist`` output for every
            dtype except float16, a manually filled matrix for float16.
        """
        if ref_emb is None:
            ref_emb = query_emb

        if query_emb.dtype != torch.float16:
            return torch.cdist(query_emb, ref_emb, p=self.p)

        # cdist doesn't work for float16, so enumerate every (row, col) index
        # pair and fill the matrix from element-wise pairwise distances.
        row_idx, col_idx = lmu.meshgrid_from_sizes(query_emb, ref_emb, dim=0)
        dist_mat = torch.zeros(
            row_idx.size(), dtype=query_emb.dtype, device=query_emb.device
        )
        row_idx = row_idx.flatten()
        col_idx = col_idx.flatten()
        dist_mat[row_idx, col_idx] = self.pairwise_distance(
            query_emb[row_idx], ref_emb[col_idx]
        )
        return dist_mat

    def pairwise_distance(self, query_emb, ref_emb):
        """Return the L-p distance between corresponding rows of the two inputs."""
        return torch.nn.functional.pairwise_distance(query_emb, ref_emb, p=self.p)

In [3]:
# Toy embeddings for the distance experiments below: two rows of three
# features each. The first rows of z1 and z2 are identical; the second differ.
z1 = torch.tensor(
    [
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
    ]
)
z2 = torch.tensor(
    [
        [1.0, 2.0, 3.0],
        [6.0, 6.0, 6.0],
    ]
)

In [4]:
# Squared Euclidean distances, obtained by squaring the raw L2 results:
# first the full query-by-reference matrix, then the row-wise pairs.
dist = LpDistance(p=2)
print(dist.compute_mat(z1, z2).pow(2))
# NOTE: the identical first rows do not print an exact 0 here; this looks
# like the small eps that torch's pairwise_distance adds -- confirm in docs.
print(dist.pairwise_distance(z1, z2).pow(2))

tensor([[ 0.0000, 50.0000],
        [27.0000,  5.0000]])
tensor([3.0000e-12, 5.0000e+00])


In [9]:
# Compare a plain L2 distance with one configured with power=5. Calling
# compute_mat / pairwise_distance directly prints the same numbers for both
# (see the recorded output), i.e. `power` is not applied by these methods.
for dist_kwargs in ({"p": 2}, {"power": 5}):
    dist = LpDistance(**dist_kwargs)
    print(dist.compute_mat(z1, z2))
    print(dist.pairwise_distance(z1, z2))

tensor([[0.0000, 7.0711],
        [5.1962, 2.2361]])
tensor([1.7321e-06, 2.2361e+00])
tensor([[0.0000, 7.0711],
        [5.1962, 2.2361]])
tensor([1.7321e-06, 2.2361e+00])


In [27]:
def compute_expected_l2_dist(self, z1_mean, z1_var, z2_mean, z2_var):
    """Expected squared L2 distance between every pair of Gaussian embeddings.

    For independent z1 ~ N(μ1, diag(σ1²)) and z2 ~ N(μ2, diag(σ2²)):

        E[||z1 - z2||^2] = ||μ1 - μ2||^2 + Tr[Σ1] + Tr[Σ2]

    Args:
        self: unused; kept so the function can be attached to a class as a
            method (pass ``None`` when calling it as a plain function).
        z1_mean: (n, d) tensor of means for the first batch.
        z1_var: (n, d) tensor of per-feature variances for the first batch.
        z2_mean: (m, d) tensor of means for the second batch.
        z2_var: (m, d) tensor of per-feature variances for the second batch.

    Returns:
        (n, m) tensor whose entry (i, j) is the expected squared distance
        between sample i of the first batch and sample j of the second.
    """
    # (n, 1, d) - (m, d) broadcasts to (n, m, d).
    mean_diff = z1_mean.unsqueeze(1) - z2_mean
    square_difference_of_means = torch.einsum("nmd,nmd->nm", mean_diff, mean_diff)

    # The trace of each diagonal covariance is the sum over the feature axis
    # only. Summing over the whole tensor would add the total variance of the
    # entire batch to every pair.
    z1_trace = z1_var.sum(dim=-1).unsqueeze(1)  # (n, 1)
    z2_trace = z2_var.sum(dim=-1).unsqueeze(0)  # (1, m)

    return square_difference_of_means + z1_trace + z2_trace


tensor([1.7321e-06, 2.2361e+00])

In [41]:
import torch
from pytorch_metric_learning.distances import BaseDistance, LpDistance
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu


class ExpectedSquareL2Distance(BaseDistance):
    """Expected squared L2 distance between stochastic embeddings.

    Each embedding is summarised by a mean vector and per-feature variances
    (diagonal covariance). For independent embeddings:

        E[||z1 - z2||^2] = ||μ1 - μ2||^2 + Tr[Σ1] + Tr[Σ2]

    Embeddings may be passed in either of two forms:

    * a mapping with ``"mean"`` and ``"var"`` entries, each a
      (num_embeddings, num_features) tensor; or
    * a 3-D tensor of samples, from which the mean and variance are estimated
      along ``sample_dim``.
    """

    def __init__(self, sample_dim=0, feature_dim=-1, **kwargs):
        """
        Args:
            sample_dim: axis of a 3-D sample tensor that indexes the samples.
            feature_dim: axis that indexes the features (of the 3-D sample
                tensor, or of the ``"var"`` tensor for mapping inputs).
            **kwargs: forwarded to ``BaseDistance``.
        """
        super().__init__(**kwargs)

        # Only compute_mat / pairwise_distance are called on this object, and
        # they return the plain (unsquared) L2 distance regardless of `power`;
        # the square is therefore applied explicitly in this class.
        self.l2_square_dist = LpDistance(p=2, power=2)
        self.sample_dim = sample_dim
        self.feature_dim = feature_dim

    def _mean_and_trace(self, emb):
        """Return ``(mean, trace)`` for one batch of embeddings.

        ``mean`` is (num_embeddings, num_features) and ``trace`` is
        (num_embeddings,), the per-embedding sum of feature variances.
        """
        if isinstance(emb, torch.Tensor):
            assert len(emb.shape) == 3
            # Canonicalise to (samples, embeddings, features) so the
            # reductions below do not depend on the configured axes.
            samples = emb.movedim((self.sample_dim, self.feature_dim), (0, -1))
            # NOTE: torch's default variance estimate is NaN for a single sample.
            return samples.mean(dim=0), samples.var(dim=0).sum(dim=-1)

        return emb["mean"], emb["var"].sum(dim=self.feature_dim)

    def compute_mat(self, query_emb, ref_emb):
        """Return the (num_query, num_ref) matrix of expected squared distances.

        Args:
            query_emb: query embeddings (mapping or 3-D sample tensor).
            ref_emb: reference embeddings in the same form, or ``None`` to
                compare ``query_emb`` against itself.
        """
        if ref_emb is None:
            ref_emb = query_emb

        query_mean, query_trace = self._mean_and_trace(query_emb)
        ref_mean, ref_trace = self._mean_and_trace(ref_emb)

        squared_mean_dist = self.l2_square_dist.compute_mat(query_mean, ref_mean) ** 2
        # Broadcast the traces along rows (query) and columns (reference).
        return squared_mean_dist + query_trace.unsqueeze(1) + ref_trace.unsqueeze(0)

    def pairwise_distance(self, query_emb, ref_emb):
        """Return the expected squared distance between corresponding embeddings.

        Args:
            query_emb: query embeddings (mapping or 3-D sample tensor).
            ref_emb: reference embeddings in the same form and of the same
                length as ``query_emb``.
        """
        query_mean, query_trace = self._mean_and_trace(query_emb)
        ref_mean, ref_trace = self._mean_and_trace(ref_emb)

        squared_mean_dist = (
            self.l2_square_dist.pairwise_distance(query_mean, ref_mean) ** 2
        )
        return squared_mean_dist + query_trace + ref_trace