In [None]:
# default_exp utils.distances

# Distances
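> Utilities for measuring distances between diagonal Gaussian distributions, each given by a mean vector and a vector of per-dimension variances (2-Wasserstein distance and KL divergence).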

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
import torch

## wasserstein_distance

In [None]:
#export
def wasserstein_distance(mean1, cov1, mean2, cov2):
    """Element-wise squared 2-Wasserstein distance between diagonal Gaussians.

    Computes ``||mean1 - mean2||^2 + ||sqrt(cov1) - sqrt(cov2)||^2`` reduced over
    the last axis, with ``cov1``/``cov2`` taken as per-dimension variances
    (diagonal covariances). The inputs are combined element-wise (with
    broadcasting), so row ``i`` is only compared with row ``i``; use
    ``wasserstein_distance_matmul`` for all-pairs distances.

    Variances are clamped to at least 1e-24 before the square root, presumably so
    the square root's gradient stays finite for zero variances.

    Returns:
        Tensor shaped like the broadcast inputs with the last axis removed.
    """
    def _sq_norm(delta):
        # squared Euclidean norm over the last axis
        return torch.sum(delta * delta, -1)

    std1, std2 = (torch.sqrt(torch.clamp(c, min=1e-24)) for c in (cov1, cov2))
    return _sq_norm(mean1 - mean2) + _sq_norm(std1 - std2)

In [None]:
# 1-D inputs, so the distance comes back as a 0-d tensor
mean1, cov1 = torch.tensor([0.4, 0.8, 0.6]), torch.tensor([0.4, 0.8, 0.6])
mean2, cov2 = torch.tensor([0.6, 0.3, 0.8]), torch.tensor([0.3, 0.2, 0.6])

w_dist = wasserstein_distance(mean1, cov1, mean2, cov2)
test_eq(torch.round(w_dist * 1e4) / 1e4, torch.tensor(0.5372))

## wasserstein_distance_matmul

In [None]:
#export
def wasserstein_distance_matmul(mean1, cov1, mean2, cov2):
    """All-pairs squared 2-Wasserstein distance between two sets of diagonal Gaussians.

    Entry ``[..., i, j]`` is the squared distance between row ``i`` of
    (``mean1``, ``cov1``) and row ``j`` of (``mean2``, ``cov2``). Each squared norm
    is expanded as ``|x|^2 + |y|^2 - 2 <x, y>`` so one matmul covers every pair.

    Expects inputs shaped ``(..., n, d)`` and ``(..., m, d)`` and returns
    ``(..., n, m)``. Because of the expansion, entries that are mathematically
    zero can come out as tiny negative numbers (floating-point cancellation).
    """
    def _expanded_sq_dist(x, y, x_sq_norm, y_sq_norm):
        # -2 <x_i, y_j> + |x_i|^2 + |y_j|^2 for every (i, j) pair
        return -2 * torch.matmul(x, y.transpose(-1, -2)) + x_sq_norm + y_sq_norm.transpose(-1, -2)

    mean_term = _expanded_sq_dist(
        mean1,
        mean2,
        torch.sum(mean1**2, -1, keepdim=True),
        torch.sum(mean2**2, -1, keepdim=True),
    )

    std1 = torch.sqrt(torch.clamp(cov1, min=1e-24))
    std2 = torch.sqrt(torch.clamp(cov2, min=1e-24))
    # |sqrt(cov)|^2 is just sum(cov), so the (unclamped) variances supply the squared norms
    cov_term = _expanded_sq_dist(
        std1,
        std2,
        torch.sum(cov1, -1, keepdim=True),
        torch.sum(cov2, -1, keepdim=True),
    )
    return mean_term + cov_term

In [None]:
mean1 = mean2 = cov1 = cov2 = torch.tensor([[[0.1376, 0.2219], [0.2287, 0.3205]],
                                            [[0.4656, 0.5470], [0.2581, 0.0454]]])
w_pairs = wasserstein_distance_matmul(mean1, cov1, mean2, cov2)
# coarse check: distances rounded to the nearest integer
expected = torch.tensor([[[-0., 0.], [0., 0.]],
                         [[0., 1.], [1., 0.]]])
test_eq(torch.round(w_pairs), expected)
# precise check: entry [b, i, j] must equal the element-wise `wasserstein_distance`
# of row i of the first set and row j of the second, paired up via broadcasting
w_reference = wasserstein_distance(mean1.unsqueeze(-2), cov1.unsqueeze(-2),
                                   mean2.unsqueeze(-3), cov2.unsqueeze(-3))
test_close(w_pairs, w_reference, eps=1e-5)

## kl_distance

In [None]:
#export
def kl_distance(mean1, cov1, mean2, cov2):
    """Element-wise KL divergence ``KL(N(mean1, diag(cov1)) || N(mean2, diag(cov2)))``.

    ``cov1``/``cov2`` hold per-dimension variances. The inputs are combined
    element-wise (with broadcasting) and reduced over the last axis, which is
    treated as the Gaussian's dimensionality ``k``:

        0.5 * [ sum(cov1 / cov2) + sum((mean2 - mean1)^2 / cov2) - k
                + log(det(cov2) / det(cov1)) ]

    Note that KL is asymmetric: swapping the two distributions changes the result.

    Returns:
        Tensor shaped like the broadcast inputs with the last axis removed.
    """
    diff = mean2 - mean1
    trace_part = torch.sum(cov1 / cov2, -1)
    mean_cov_part = torch.sum(diff * diff / cov2, -1)
    # log(det(cov2) / det(cov1)) as a difference of log-sums: multiplying many small
    # variances underflows to 0 (and log(0 / 0) = nan) once the dimension grows
    determinant_part = torch.sum(torch.log(cov2), -1) - torch.sum(torch.log(cov1), -1)
    # k is the size of the reduced (last) axis, whatever the number of leading dims
    dim = mean1.shape[-1]

    return (trace_part + mean_cov_part - dim + determinant_part) / 2

In [None]:
# two independent pairs of 2-D diagonal Gaussians, one pair per row
mean1 = torch.tensor([[0.0240, 0.3383], [0.5015, 0.9207]])
cov1 = torch.tensor([[0.1346, 0.5232], [0.9208, 0.1602]])
mean2 = torch.tensor([[0.4716, 0.7865], [0.7942, 0.1391]])
cov2 = torch.tensor([[0.9033, 0.0117], [0.4091, 0.6434]])

kl_rounded = torch.round(kl_distance(mean1, cov1, mean2, cov2) * 1e4) / 1e4
test_eq(kl_rounded, torch.tensor([29.1808, 1.1189]))

## kl_distance_matmul

In [None]:
#export
def kl_distance_matmul(mean1, cov1, mean2, cov2):
    cov1_det = 1 / torch.prod(cov1, -1, keepdim=True)
    cov2_det = torch.prod(cov2, -1, keepdim=True)
    log_det = torch.log(torch.matmul(cov1_det, cov2_det.transpose(-1, -2)))

    trace_sum = torch.matmul(1 / cov2, cov1.transpose(-1, -2))

    #mean_cov_part1 = torch.matmul(mean1 / cov2, mean1.transpose(-1, -2))
    #mean_cov_part1 = torch.matmul(mean1 * mean1, (1 / cov2).transpose(-1, -2))
    #mean_cov_part2 = -torch.matmul(mean1 / cov2, mean2.transpose(-1, -2))
    #mean_cov_part2 = -torch.matmul(mean1 * mean2, (1 / cov2).transpose(-1, -2))
    #mean_cov_part3 = -torch.matmul(mean2 / cov2, mean1.transpose(-1, -2))
    #mean_cov_part4 = torch.matmul(mean2 / cov2, mean2.transpose(-1, -2))
    #mean_cov_part4 = torch.matmul(mean2 * mean2, (1 / cov2).transpose(-1, -2))

    #mean_cov_part = mean_cov_part1 + mean_cov_part2 + mean_cov_part3 + mean_cov_part4
    mean_cov_part = torch.matmul((mean1 - mean2) ** 2, (1/cov2).transpose(-1, -2))

    return (log_det + mean_cov_part + trace_sum - mean1.shape[-1]) / 2

In [None]:
mean1 = mean2 = cov1 = cov2 = torch.tensor([[[0.1376, 0.2219], [0.2287, 0.3205]],
                                            [[0.4656, 0.5470], [0.2581, 0.0454]]])
expected = torch.tensor([[[0., 1.], [-1., 0.]],
                         [[0., -2.], [7., 0.]]])
test_eq(kl_distance_matmul(mean1, cov1, mean2, cov2).round()*1e4/1e4, expected)

In [None]:
#hide
# record author, machine details, package versions and update time of this run
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2022-01-22 13:28:44

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

torch  : 1.10.0+cu111
IPython: 5.5.0

