In this notebook, we will investigate the ability of Wasserstein-GAN (W-GAN) to approximate the 1-Wasserstein distance. We will begin by formulating the problem and selecting ground truth distributions for which the true optimal transport (OT) distances are known.

# Problem Formulation

**Objective:** Given two probability distributions P and Q, we want to compute the 1-Wasserstein distance between them.

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt
import seaborn as sns

# Set the random seed
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# 1. Ground Truth Distributions : 
We will select a set of ground truth distributions for which the true optimal transport (OT) distances are known. In this case, we will consider 1D and 2D Gaussians.

First, let's define a function to generate ground truth Gaussian distributions:

In [2]:
def generate_gaussians(num_distributions, dim):
    means = [torch.randn(dim) for _ in range(num_distributions)]
    covariances = [torch.randn(dim, dim) for _ in range(num_distributions)]
    covariances = [A @ A.T for A in covariances]  # Ensure positive semi-definite
    
    return means, covariances

We can generate a set of ground truth Gaussian distributions as follows:

In [3]:
num_distributions = 10
dim = 2

means, covariances = generate_gaussians(num_distributions, dim)

Next, we need to compute the true OT distances between these Gaussian distributions. For simplicity, we will compute the distances only between consecutive pairs of distributions:

In [7]:
def true_ot_distance(mean1, cov1, mean2, cov2):
    mean_diff = mean2 - mean1
    cov_sum = cov1 + cov2
    ot_distance = (torch.sqrt(torch.dot(mean_diff.t(), mean_diff)) 
                   + torch.trace(cov_sum - 2 * torch.sqrt(torch.sqrt(cov1) @ (torch.sqrt(cov2) @ torch.sqrt(cov1)))))
    return ot_distance

true_ot_distances = [true_ot_distance(means[i], covariances[i], means[i+1], covariances[i+1]) for i in range(num_distributions - 1)]


# 2. W-GAN

In part 2 of the plan, we will implement the W-GAN model and train the critic on the selected ground truth distributions. Then, we will evaluate the performance of the W-GAN by comparing its approximated Wasserstein distances with the known true OT distances.

In [9]:
# Define the critic neural network
class Critic(nn.Module):
    def __init__(self, input_dim):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
    def forward(self, x):
        return self.net(x)

In [17]:
# Create the critic
input_dim = dim
critic = Critic(input_dim).to(device)

# Set up training parameters
epochs = 1000
batch_size = 256
lr = 1e-4
optimizer = optim.Adam(critic.parameters(), lr=lr)
clip_value = 0.01
n_critic_updates = 5

# Define a function to sample from Gaussian distributions
def sample_from_gaussian(mean, cov, n_samples):
    L = torch.linalg.cholesky(cov)
    samples = torch.randn(n_samples, cov.shape[0])
    return samples @ L + mean

# W-GAN training loop
for epoch in range(epochs):
    for i in range(num_distributions - 1):
        for _ in range(n_critic_updates):
            # Sample data from both distributions
            real_data = sample_from_gaussian(means[i], covariances[i], batch_size).to(device)
            fake_data = sample_from_gaussian(means[i+1], covariances[i+1], batch_size).to(device)
            
            # Calculate the Wasserstein distance approximation
            real_scores = critic(real_data)
            fake_scores = critic(fake_data)
            wasserstein_approx = real_scores.mean() - fake_scores.mean()
            
            # Update the critic
            critic_loss = -wasserstein_approx
            optimizer.zero_grad()
            critic_loss.backward()
            optimizer.step()
            
            # Clip the critic's weights
            for param in critic.parameters():
                param.data.clamp_(-clip_value, clip_value)
                
    if epoch % 100 == 0:
        print(f'Epoch: {epoch}, Critic Loss: {critic_loss.item()}')


Epoch: 0, Critic Loss: -4.875846207141876e-05


KeyboardInterrupt: 

Once the W-GAN is trained, we can evaluate its performance by comparing the approximated Wasserstein distances with the known true OT distances:

In [18]:
def wgan_approximated_distance(mean1, cov1, mean2, cov2, n_samples=1000):
    real_data = sample_from_gaussian(mean1, cov1, n_samples).to(device)
    fake_data = sample_from_gaussian(mean2, cov2, n_samples).to(device)
    
    real_scores = critic(real_data)
    fake_scores = critic(fake_data)
    
    return real_scores.mean() - fake_scores.mean()

wgan_distances = [wgan_approximated_distance(means[i], covariances[i], means[i+1], covariances[i+1]) for i in range(num_distributions - 1)]

# Compare W-GAN approximated distances with true OT distances
for i in range(num_distributions - 1):
    print(f'Pair {i + 1}: True OT Distance = {true_ot_distances[i].item()}, W-GAN Approximated Distance = {wgan_distances[i].item()}')

Pair 1: True OT Distance = 2.977085828781128, W-GAN Approximated Distance = 0.00043230969458818436
Pair 2: True OT Distance = nan, W-GAN Approximated Distance = 0.0001508370041847229
Pair 3: True OT Distance = nan, W-GAN Approximated Distance = -0.00026264041662216187
Pair 4: True OT Distance = nan, W-GAN Approximated Distance = 0.00019355490803718567
Pair 5: True OT Distance = 0.5751767158508301, W-GAN Approximated Distance = -2.5820918381214142e-05
Pair 6: True OT Distance = -0.6723374128341675, W-GAN Approximated Distance = 2.878718078136444e-06
Pair 7: True OT Distance = nan, W-GAN Approximated Distance = 1.5721656382083893e-05
Pair 8: True OT Distance = nan, W-GAN Approximated Distance = -0.0004567587748169899
Pair 9: True OT Distance = 4.366995811462402, W-GAN Approximated Distance = 0.0004621315747499466


# 3. Sinkhorn divergence

First, let's define a function to compute the Sinkhorn divergence between two distributions:

In [20]:
def sinkhorn_divergence(mu, nu, C, epsilon, n_iter=100):
    K = torch.exp(-C / epsilon)
    u = torch.ones_like(mu)
    
    for _ in range(n_iter):
        v = nu / (K.T @ u)
        u = mu / (K @ v)
        
    P = torch.diag(u) @ K @ torch.diag(v)
    return torch.sum(P * C)

In [21]:
# cost matrix between two sets of samples
def compute_cost_matrix(X, Y):
    X_sq = torch.sum(X ** 2, dim=1, keepdim=True)
    Y_sq = torch.sum(Y ** 2, dim=1, keepdim=True).T
    XY = X @ Y.T
    return X_sq - 2 * XY + Y_sq

In [23]:
epsilon = 0.1
n_samples = 1000

sinkhorn_distances = []

for i in range(num_distributions - 1):
    X = sample_from_gaussian(means[i], covariances[i], n_samples).to(device)
    Y = sample_from_gaussian(means[i+1], covariances[i+1], n_samples).to(device)
    
    mu = torch.ones(X.shape[0], device=device) / X.shape[0]
    nu = torch.ones(Y.shape[0], device=device) / Y.shape[0]
    
    C = compute_cost_matrix(X, Y)
    
    sinkhorn_dist = sinkhorn_divergence(mu, nu, C, epsilon)
    sinkhorn_distances.append(sinkhorn_dist.item())

In [24]:
for i in range(num_distributions - 1):
    print(f'Pair {i + 1}: True OT Distance = {true_ot_distances[i].item()}, W-GAN Approximated Distance = {wgan_distances[i].item()}, Sinkhorn Distance = {sinkhorn_distances[i]}')

Pair 1: True OT Distance = 2.977085828781128, W-GAN Approximated Distance = 0.00043230969458818436, Sinkhorn Distance = nan
Pair 2: True OT Distance = nan, W-GAN Approximated Distance = 0.0001508370041847229, Sinkhorn Distance = nan
Pair 3: True OT Distance = nan, W-GAN Approximated Distance = -0.00026264041662216187, Sinkhorn Distance = nan
Pair 4: True OT Distance = nan, W-GAN Approximated Distance = 0.00019355490803718567, Sinkhorn Distance = nan
Pair 5: True OT Distance = 0.5751767158508301, W-GAN Approximated Distance = -2.5820918381214142e-05, Sinkhorn Distance = nan
Pair 6: True OT Distance = -0.6723374128341675, W-GAN Approximated Distance = 2.878718078136444e-06, Sinkhorn Distance = nan
Pair 7: True OT Distance = nan, W-GAN Approximated Distance = 1.5721656382083893e-05, Sinkhorn Distance = nan
Pair 8: True OT Distance = nan, W-GAN Approximated Distance = -0.0004567587748169899, Sinkhorn Distance = nan
Pair 9: True OT Distance = 4.366995811462402, W-GAN Approximated Distance =