In [2]:
import numpy as np
import torch

In [159]:
def compute_mmd(x, y, x_weights=None, y_weights=None, gamma=None):
    """
    Compute the (biased, V-statistic) squared MMD between samples ``x`` and ``y``.

    Kernel matrices come from ``compute_kernel`` (Gaussian RBF). Without weights
    the estimate is ``mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)``, diagonal terms
    included. When at least one weight vector is given, every mean becomes a
    weighted mean. The weight of the pair (i, j) is ``w_a[i] * w_b[j]``, and the
    result is normalized by the total pair weight. A missing side gets uniform
    weights.

    Parameters
    ----------
    x : torch.Tensor
        First sample set, assumed 2-D ``(n, n_features)``. ``compute_kernel``
        reads ``size(1)`` as the feature dimension.
    y : torch.Tensor
        Second sample set, assumed 2-D ``(m, n_features)``.
    x_weights : array-like of length n, optional
        Per-sample weights for ``x``. They need not be normalized but must have
        a non-zero sum (otherwise the result is NaN). Tensors and numpy arrays
        are both accepted.
    y_weights : array-like of length m, optional
        Per-sample weights for ``y``, with the same conventions.
    gamma : float, optional
        RBF scale forwarded to ``compute_kernel``. ``None`` keeps that
        function's default.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the squared-MMD estimate.
    """
    x_kernel = compute_kernel(x, x, gamma=gamma)
    y_kernel = compute_kernel(y, y, gamma=gamma)
    xy_kernel = compute_kernel(x, y, gamma=gamma)

    if x_weights is None and y_weights is None:
        return x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()

    # Weights must live on the same device as the data. A bare torch.ones(n)
    # would be a CPU float32 tensor and break for CUDA inputs.
    x_weights = _as_weights(x_weights, x)
    y_weights = _as_weights(y_weights, y)

    return (
        _weighted_kernel_mean(x_kernel, x_weights, x_weights)
        + _weighted_kernel_mean(y_kernel, y_weights, y_weights)
        - 2 * _weighted_kernel_mean(xy_kernel, x_weights, y_weights)
    )


def _as_weights(weights, samples):
    """Return ``weights`` as a tensor on ``samples``' device; uniform ones (matching dtype) if None."""
    if weights is None:
        return torch.ones(samples.shape[0], dtype=samples.dtype, device=samples.device)
    return torch.as_tensor(weights, device=samples.device)


def _weighted_kernel_mean(kernel, row_weights, col_weights):
    """Weighted mean of ``kernel`` where entry (i, j) has weight ``row_weights[i] * col_weights[j]``."""
    pair_weights = row_weights.unsqueeze(1) * col_weights.unsqueeze(0)
    return (kernel * pair_weights).sum() / pair_weights.sum()

def compute_kernel(x, y, gamma=None):
    """
    Pairwise Gaussian RBF kernel matrix, used by the MMD estimator.

    Entry (i, j) is ``exp(-gamma * ||x[i] - y[j]||^2)``, with the squared
    distance summed over the feature axis (axis 1 of the inputs).

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(x_size, n_features)``.
    y : torch.Tensor
        Shape ``(y_size, n_features)``. Must have the same feature count as
        ``x``; this is checked with an ``assert``.
    gamma : float, optional
        Scale applied to the squared distance. Defaults to ``n_features``.
        NOTE(review): a larger dimension therefore gives a *narrower* kernel.
        Confirm this is intended rather than e.g. ``1 / n_features``.

    Returns
    -------
    torch.Tensor
        Kernel matrix of shape ``(x_size, y_size)``.
    """
    n_features = x.size(1)
    assert n_features == y.size(1)
    scale = n_features if gamma is None else gamma

    # Broadcast to (x_size, y_size, n_features), then reduce over features.
    diffs = x[:, None] - y[None, :]
    sq_dist = (diffs ** 2).sum(dim=2)
    return torch.exp(-scale * sq_dist)

In [None]:
# x[:2] = x[]

In [198]:
# Fix the RNG so Restart & Run All reproduces the same numbers.
SEED = 0
rng = np.random.default_rng(SEED)

batch_size = 256
pivot_idx = batch_size // 4      # first quarter -> x1, remainder -> x2
x = rng.random(batch_size)       # uniform on [0, 1)
x1 = x[:pivot_idx] * 10 + 1      # stretched and shifted to [1, 11)
x2 = x[pivot_idx:]               # left on [0, 1)
assert x1.shape[0] + x2.shape[0] == x.shape[0]

# Importance weights proportional to |value|, normalized to sum to 1 so they
# can also serve as sampling probabilities for the resampling check below.
x1_weights = np.abs(x1) / np.abs(x1).sum()
x2_weights = np.abs(x2) / np.abs(x2).sum()

x1_weights_torch = torch.from_numpy(x1_weights)
x2_weights_torch = torch.from_numpy(x2_weights)

# compute_mmd expects 2-D (n_samples, n_features) tensors.
x1_torch = torch.from_numpy(x1).unsqueeze(1)
x2_torch = torch.from_numpy(x2).unsqueeze(1)

# Monte-Carlo reference: resample each set according to its weights and take
# the unweighted MMD. The average is expected to roughly track the weighted MMD,
# up to sampling noise and the biased estimator's duplicate-pair terms.
n_resamples = 10
resample_size = len(x)
result = []
for _ in range(n_resamples):
    x1_resample = torch.from_numpy(rng.choice(x1, size=resample_size, p=x1_weights)).unsqueeze(1)
    x2_resample = torch.from_numpy(rng.choice(x2, size=resample_size, p=x2_weights)).unsqueeze(1)
    result.append(compute_mmd(x1_resample, x2_resample))

print('Standard MMD: {}'.format(compute_mmd(x1_torch, x2_torch).item()))
print('Resampled MMD: {}'.format(torch.stack(result).mean().item()))
print('Weighted MMD: {}'.format(compute_mmd(x1_torch, x2_torch, x_weights=x1_weights_torch, y_weights=x2_weights_torch).item()))


Standard MMD: 0.9683579053154063
Resampled MMD: 1.0669778022588927
Weighted MMD: 1.0696495423789159
