In [39]:
# Create a birkhoff polytope of size 3 in numpy 
import numpy as np
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(11)
n = 3
# create a uniform matrix of size n
Gamma = np.ones((n, n)) / n

# sample a permutation matrix by adding Gumbel noise to Gamma
# and then normalizing
def sample_perm_matrix(Gamma, Gumbel, tau, n_iter):
    T = np.exp((Gumbel + Gamma) / tau)
    for _ in range(n_iter):
        T = T / np.sum(T, axis=1, keepdims=True)
        T = T / np.sum(T, axis=0, keepdims=True)
    return T

# sample a permutation matrix
n_sample = 10000
all_noises = []
for _ in range(n_sample):
    Gumbel = np.random.gumbel(size=(n, n))
    all_noises.append(Gumbel)


In [40]:
import os

def save_image(tau, Gamma, all_noises, root_dir, index = None, pca_model = None):
    all_samples = []
    for i in range(n_sample):
        all_samples.append(sample_perm_matrix(Gamma, all_noises[i], tau, 30).reshape(-1))

    all_samples = np.vstack(all_samples)
    # plot all_samples using dimensionality reduction on 2D
    # dimension reduction on all_samples using PCA
    from sklearn.decomposition import PCA
    if pca_model is None:
        pca = PCA(n_components=2)
        pca.fit(all_samples)
    else:
        pca = pca_model

    # scatterplot all_samples_2d with small point width

    if index is not None:
        all_samples_2d = pca.transform(all_samples)
        fig, ax = plt.subplots()
        ax.scatter(all_samples_2d[:, 0], all_samples_2d[:, 1], s=1)
        # save the plot in a file
        fig.savefig(os.path.join(root_dir, f"birkhoff_{index}.png"))
        plt.close(fig)
        
    return pca_model

In [42]:
all_samples = []
pca_model = save_image(0.05, Gamma, all_noises, "birkhoff")
t = 1
F = 15
for tau in np.linspace(0.2, 1, F):
    save_image(tau**2, Gamma, all_noises, "birkhoff", F - t, pca_model)
    t += 1

# save all images in the folder birkhoff into a gif
import imageio
images = []
for i in range(F):
    images.append(imageio.imread(os.path.join("birkhoff", f"birkhoff_{i}.png")))
    if i == F - 1:
        for _ in range(5):
            images.append(imageio.imread(os.path.join("birkhoff", f"birkhoff_{i}.png")))
            
imageio.mimsave(os.path.join("birkhoff", "birkhoff.gif"), images, duration=6)

# delete all images in the folder birkhoff
import os
for i in range(F):
    os.remove(os.path.join("birkhoff", f"birkhoff_{i}.png"))

  images.append(imageio.imread(os.path.join("birkhoff", f"birkhoff_{i}.png")))
  images.append(imageio.imread(os.path.join("birkhoff", f"birkhoff_{i}.png")))
