In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from torch.utils.data import DataLoader 
from PIL import Image
from matplotlib import pyplot as plt
import seaborn as sns
import wandb
import math
import argparse
import datetime
import math
from time import time

In [2]:

def generate_twomode_GMM(num_samples, dim, p, r, sigma=1.0, device='cpu', dtype=torch.float32, seed=None):
    """
    Generate samples from a bimodal Gaussian mixture model in d dimensions using PyTorch.
    
    The mixture is defined as:
    ρ*(x) = p*N(x; r, I_d) + (1-p)*N(x; -r, I_d)
    
    Parameters:
    -----------
    num_samples : int
        Number of samples to generate
    dim : int
        Dimensionality of the data
    p : float
        Mixing probability for the first component (0 ≤ p ≤ 1)
    r : array-like, tensor, or float
        Fixed vector for the means. If float, creates vector (r, r, ..., r)
        If array/tensor, should have length dim and satisfy |r|_2 = sqrt(d)
    sigma : float, default=1.0
        Standard deviation for each component (assumes spherical covariance σ²I_d)
    device : str or torch.device, default='cpu'
        Device to place tensors on ('cpu', 'cuda', etc.)
    dtype : torch.dtype, default=torch.float32
        Data type for the tensors
    seed : int, optional
        Random seed for reproducibility
        
    Returns:
    --------
    samples : torch.Tensor of shape (num_samples, dim)
        Generated samples from the mixture
    labels : torch.Tensor of shape (num_samples,)
        Component labels (0 for first component, 1 for second component)
    """
    
    if seed is not None:
        torch.manual_seed(seed)
    
    # Handle r parameter
    if torch.is_tensor(r):
        r_vec = r.to(device=device, dtype=dtype)
    elif hasattr(r, '__len__'):  # array-like
        r_vec = torch.tensor(r, device=device, dtype=dtype)
    else:  # scalar
        r_vec = torch.full((dim,), r, device=device, dtype=dtype)
    
    if r_vec.numel() != dim:
        raise ValueError(f"r must have length {dim}, got {r_vec.numel()}")
    
    # Verify the constraint |r|_2 = sqrt(d) as mentioned in the example
    r_norm = torch.norm(r_vec)
    expected_norm = torch.sqrt(torch.tensor(dim, dtype=dtype))
    if not torch.allclose(r_norm, expected_norm, rtol=1e-6):
        print(f"Warning: |r|_2 = {r_norm.item():.4f}, but sqrt(d) = {expected_norm.item():.4f}")
        print(f"Consider using r with norm sqrt({dim}) = {expected_norm.item():.4f}")
    
    # Generate component assignments using Bernoulli distribution
    # torch.bernoulli with p gives 1 with probability p, so we use (1-p) to match the formula
    component_probs = torch.full((num_samples,), 1-p, device=device, dtype=dtype)
    component_labels = torch.bernoulli(component_probs).long()  # 0 for first, 1 for second
    
    # Initialize samples tensor
    samples = torch.zeros(num_samples, dim, device=device, dtype=dtype)
    
    # Generate samples for first component: N(x; -r, σ²I_d)
    first_mask = (component_labels == 0)
    n_first = first_mask.sum().item()
    if n_first > 0:
        samples[first_mask] = torch.normal(
            mean=r_vec.unsqueeze(0).expand(n_first, -1),
            std=sigma
        )
    
    # Generate samples for second component: N(x; r, σ²I_d)
    second_mask = (component_labels == 1)
    n_second = second_mask.sum().item()
    if n_second > 0:
        samples[second_mask] = torch.normal(
            mean=-r_vec.unsqueeze(0).expand(n_second, -1),
            std=sigma
        )
    
    return samples, component_labels

from sklearn.decomposition import PCA

def visualize_gmm_pca(samples, labels=None, n_components=2):
    """
    Visualize high-dimensional GMM samples using PCA.
    
    Parameters:
    -----------
    samples : torch.Tensor of shape (num_samples, dim)
        High-dimensional samples from GMM
    labels : torch.Tensor of shape (num_samples,), optional
        Component labels for coloring
    n_components : int, default=2
        Number of PCA components (2 or 3)
    """
    
    # Convert to numpy
    if torch.is_tensor(samples):
        samples = samples.detach().cpu().numpy()
    if labels is not None and torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    
    # Apply PCA
    pca = PCA(n_components=n_components)
    samples_pca = pca.fit_transform(samples)
    
    # Plot
    if n_components == 2:
        plt.figure(figsize=(8, 6))
        if labels is not None:
            colors = ['red', 'blue', 'green', 'orange', 'purple']
            for i, label in enumerate(np.unique(labels)):
                mask = labels == label
                plt.scatter(samples_pca[mask, 0], samples_pca[mask, 1], 
                          c=colors[i % len(colors)], label=f'Component {label}', alpha=0.7)
            plt.legend()
        else:
            plt.scatter(samples_pca[:, 0], samples_pca[:, 1], alpha=0.7)
        
        plt.grid(True, alpha=0.3)
        plt.show()
    
    elif n_components == 3:
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        
        if labels is not None:
            colors = ['red', 'blue', 'green', 'orange', 'purple']
            for i, label in enumerate(np.unique(labels)):
                mask = labels == label
                ax.scatter(samples_pca[mask, 0], samples_pca[mask, 1], samples_pca[mask, 2],
                          c=colors[i % len(colors)], label=f'Component {label}', alpha=0.7)
            ax.legend()
        else:
            ax.scatter(samples_pca[:, 0], samples_pca[:, 1], samples_pca[:, 2], alpha=0.7)
        
        plt.show()
    
    return samples_pca, pca




In [139]:
#### standard interpolants
    
def beta(t):
    return t

def beta_dot(t):
    return 1.0 

#### designed interpolants

# dim = 1000
# d = dim
# M = np.sqrt(d)
# def beta(t):
#     """
#     Compute β_t = (1/M) * √(-log(1 + (e^(-M²) - 1)t))
    
#     Parameters:
#     -----------
#     t : float or torch.Tensor
#         Time parameter
#     M : float or torch.Tensor
#         Parameter M
        
#     Returns:
#     --------
#     torch.Tensor
#         β_t(t) value
#     """
#     t = torch.as_tensor(t, dtype=torch.float32)
#     M_squared = torch.as_tensor(M*M, dtype=torch.float32)
    
#     # For numerical stability when M² is large
#     if M_squared > 20:
#         # e^(-M²) ≈ 0, so (e^(-M²) - 1) ≈ -1
#         inner = 1 - t
#     else:
#         exp_neg_M2 = torch.exp(-M_squared)
#         inner = 1 + (exp_neg_M2 - 1) * t
    
#     # Clamp to avoid log(0) or negative values
#     inner = torch.clamp(inner, min=1e-10)
#     log_term = -torch.log(inner)
    
#     beta = torch.sqrt(torch.clamp(log_term, min=1e-10)) / M
    
#     return beta

# def beta_dot(t):
#     """
#     Compute dβ_t/dt = -(1/M) * (e^(-M²) - 1) / (2 * √(-log(1 + (e^(-M²) - 1)t)) * (1 + (e^(-M²) - 1)t))
    
#     Mathematical derivation:
#     β_t = (1/M) * √(-log(1 + (e^(-M²) - 1)t))
    
#     Let u = 1 + (e^(-M²) - 1)t
#     Then β_t = (1/M) * √(-log(u))
    
#     dβ_t/dt = (1/M) * (1/2) * (1/√(-log(u))) * d/dt[-log(u)]
#             = (1/M) * (1/2) * (1/√(-log(u))) * (-1/u) * du/dt
#             = (1/M) * (1/2) * (1/√(-log(u))) * (-1/u) * (e^(-M²) - 1)
#             = -(1/M) * (e^(-M²) - 1) / (2 * √(-log(u)) * u)
    
#     Parameters:
#     -----------
#     t : float or torch.Tensor
#         Time parameter
#     M : float or torch.Tensor
#         Parameter M
        
#     Returns:
#     --------
#     torch.Tensor
#         dβ_t/dt value
#     """
#     t = torch.as_tensor(t, dtype=torch.float32)
#     M_squared = torch.as_tensor(M*M, dtype=torch.float32)
    
#     # For numerical stability when M² is large
#     if M_squared > 20:
#         # e^(-M²) ≈ 0, so (e^(-M²) - 1) ≈ -1
#         exp_term = -1.0
#         inner = 1 - t
#     else:
#         exp_neg_M2 = torch.exp(-M_squared)
#         exp_term = exp_neg_M2 - 1
#         inner = 1 + exp_term * t
    
#     # Clamp to avoid numerical issues
#     inner = torch.clamp(inner, min=1e-10)
#     log_term = -torch.log(inner)
#     sqrt_log_term = torch.sqrt(torch.clamp(log_term, min=1e-10))
    
#     # Compute derivative - the negative sign is crucial!
#     derivative = -(1/M) * exp_term / (2 * sqrt_log_term * inner)
    
#     return derivative


In [140]:
import torch

def drift_b(x, t, r, h):
    """
    Compute b_t(x) = dotβ_t(t) * tanh(h + β_t(t) * ⟨r, x⟩)
    
    Parameters:
    -----------
    x : torch.Tensor of shape (num_samples, d)
        Input points 
    r : torch.Tensor of shape (d,)
        Fixed vector satisfying |r|_2 = sqrt(d)
    beta_t : callable
        Function that takes time t and returns β_t(t)
    h : float or torch.Tensor
        Scalar parameter h in the tanh function
    t : float or torch.Tensor
        Current time value
        
    Returns:
    --------
    torch.Tensor of shape (num_samples, d)
        b_t(x) evaluated at input points x
    """
    
    # Evaluate β_t at time t
    beta_val = beta(t)
    dot_beta = beta_dot(t)
    
    # Compute inner product ⟨r, x⟩ for each sample
    inner_product = torch.sum(x * r, dim=1, keepdim=True)  # (num_samples, 1)
    
    # Compute β_t * ⟨r, x⟩
    beta_inner = beta_val * inner_product  # (num_samples, 1)
    
    # Compute tanh(h + β_t * ⟨r, x⟩)
    tanh_term = torch.tanh(h + beta_inner)  # (num_samples, 1)
    
    # Compute b_t(x) = dotβ_t * tanh(h + β_t * ⟨r, x⟩) * r
    bt_x = dot_beta * tanh_term * r.unsqueeze(0)  # (num_samples, d)
    
    return bt_x

In [141]:
num_samples = 5000
dim = 1000
r = torch.ones(dim)
p = 0.2
h = math.log(p/(1-p))/2
z0 = torch.randn(num_samples, dim)
z1, _ = generate_twomode_GMM(num_samples, dim, p, r, sigma=1.0)

D = {'z0': z0, 'z1': z1}

t_min_sample = 1e-3
t_max_sample = 1-1e-3

# t_min_sample = 0
# t_max_sample =1

In [142]:
## RK integration

from torchdiffeq import odeint
class PFlowRHS(nn.Module):
    def __init__(self, drift_b, r, h):
        super(PFlowRHS, self).__init__()
        self.drift_b = drift_b
        self.r = r
        self.h = h
        
    def forward(self, t, states):
        (zt,) = states
        dzt = self.drift_b(zt, t, self.r, self.h)
        return (dzt,)
             
class PFlowIntegrator:
        
    def __init__(self):
        return        

    def __call__(self, drift_b, z0, r, h, T_min, T_max, steps, method='dopri5', return_last = True):

        rhs = PFlowRHS(drift_b, r, h)

        t = torch.linspace(
            T_min, T_max, steps
        ).type_as(z0)

        int_args = {
            'method': method, 
            # 'atol': c.integration_atol, 
            # 'rtol': c.integration_rtol,
        }

        (z,) = odeint(rhs, (z0,), t, **int_args)
        if return_last:
            return z[-1].clone()
        else:
            return z

pflow = PFlowIntegrator()

out = pflow(
    drift_b = drift_b,
    z0 = D['z0'],
    r = r, 
    h = h, 
    T_min = t_min_sample,
    T_max = t_max_sample,
    steps = 5, 
    method = 'rk4',
    return_last = True,)

In [143]:
# visualize_gmm_pca(z1, None, n_components=2)
# visualize_gmm_pca(out, None, n_components=2)

In [144]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA

def fit_gmm_density(samples, n_components=2, use_pca=True):
    """
    Fit GMM to samples and return parameters.
    
    Parameters:
    -----------
    samples : torch.Tensor of shape (num_samples, dim)
        Input samples
    n_components : int, default=2
        Number of GMM components
    use_pca : bool, default=True
        Apply PCA if dim > 2
        
    Returns:
    --------
    weights : np.ndarray
        Component weights
    means : np.ndarray
        Component means
    covariances : np.ndarray
        Component covariances
    """
    
    # Convert to numpy
    if torch.is_tensor(samples):
        samples = samples.detach().cpu().numpy()
    
    # Apply PCA if high-dimensional
    if use_pca and samples.shape[1] > 2:
        pca = PCA(n_components=2)
        samples = pca.fit_transform(samples)
        print(f"PCA: {pca.explained_variance_ratio_.sum():.1%} variance explained")
    
    # Fit GMM
    gmm = GaussianMixture(n_components=n_components, random_state=42)
    gmm.fit(samples)
    
    # Visualize density
    # x_min, x_max = samples[:, 0].min() - 1, samples[:, 0].max() + 1
    # y_min, y_max = samples[:, 1].min() - 1, samples[:, 1].max() + 1
    # xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
    #                      np.linspace(y_min, y_max, 100))
    
    # grid_points = np.c_[xx.ravel(), yy.ravel()]
    # log_prob = gmm.score_samples(grid_points)
    # prob = np.exp(log_prob).reshape(xx.shape)
    
    # plt.figure(figsize=(8, 6))
    # plt.contourf(xx, yy, prob, levels=15, alpha=0.6, cmap='viridis')
    # plt.scatter(samples[:, 0], samples[:, 1], alpha=0.7, s=20, c='white', edgecolors='black')
    # plt.colorbar(label='Density')
    # plt.title('GMM Density Estimation')
    # plt.show()
    
    # Print parameters
    print(f"Weights: {gmm.weights_}")
    print(f"Means:\n{gmm.means_}")
    print(f"Covariances:\n{gmm.covariances_}")
    
    return gmm.weights_, gmm.means_, gmm.covariances_


In [145]:
pca = PCA(n_components=1)
z1_pca = pca.fit_transform(z1)
out_pca = pca.fit_transform(out)

print('truth')
weights, means, covariances = fit_gmm_density(z1_pca, n_components=2, use_pca=False)
print('---------------------------------')
print('generation')
weights, means, covariances = fit_gmm_density(out_pca, n_components=2, use_pca=False)


truth
Weights: [0.8024 0.1976]
Means:
[[-12.49770977]
 [ 50.74980932]]
Covariances:
[[[1.02485029]]

 [[0.97934768]]]
---------------------------------
generation
Weights: [0.9442 0.0558]
Means:
[[-3.39617267]
 [57.46713684]]
Covariances:
[[[1.06909726]]

 [[1.30990296]]]
