In [None]:
"""
Minimal PyTorch implementation of Flow Matching with the Optimal Transport (OT) conditional probability path for generative modeling, based on "Flow Matching for Generative Modeling" by Lipman et al. (2023).

This code provides a simplified example that trains on the half-moons dataset and draws samples with the Euler method.
"""


import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_moons
from tqdm import tqdm
from typing import Tuple

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Hyperparameters
data_size: int = 10_000  # Number of points in the toy dataset.
batch_size: int = 256  # Points per training batch.
input_dim: int = 2  # The half-moons live in the plane.
hidden_dim: int = 64  # Width of the hidden layers of the vector-field MLP.
n_iterations: int = 10_000  # Number of optimizer steps.
sigma_min: float = 1e-5  # Scale of the noise left at t=1 on the OT path (Eq. 22); keep it small.
lr: float = 1e-3  # Adam learning rate.

# Data #######################################################################
# Toy dataset: two interleaving half-moons with a little Gaussian noise,
# stored as a float32 tensor on the chosen device.
data, _ = make_moons(data_size, noise=0.05)
data = torch.as_tensor(data, dtype=torch.float32, device=device)
data_loader = torch.utils.data.DataLoader(
    dataset=data, batch_size=batch_size, shuffle=True
)


# Model ######################################################################
# Define the vector field network v_t
class VectorFieldNetwork(nn.Module):
    """
    Parametric model for the time-dependent vector field v_t(x; θ) (Section 3 in Lipman et al.).

    This network represents the function v_t(x; θ), which is the Continuous Normalizing Flow (CNF) vector field.
    It is used in the Flow Matching objective (Eq. 5 in Lipman et al.).
    The network outputs the *value* of the vector field at a given time t and point x.

    Args:
        input_dim: Dimensionality of the data points x; the output has the same size.
        hidden_dim: Width of the two hidden layers of the MLP.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # MLP architecture (no specific architecture defined in the paper)
        self.net: nn.Sequential = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),  # +1 for the scalar time feature
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def time_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """
        Turn time values into one input feature by appending a trailing axis of size 1.

        No frequency/sinusoidal encoding is applied: the raw t value is fed to the MLP.
        """
        return t.unsqueeze(-1)

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the vector field v_t(x) at time t and point x.

        Args:
            t: Time values, either one per point (shape ``x.shape[:-1]``) or a single 0-dim
                value shared by every point (the ``f(t, x)`` convention of ODE solvers).
                It is cast to the dtype and device of ``x`` before use.
            x: Points of shape (..., input_dim).

        Returns:
            v_t(x; θ), with the same shape as ``x``.
        """
        # Align dtype/device with x; a no-op (same tensor returned) when they already match.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() == 0 and x.dim() > 1:
            # A single time for the whole batch: give every point its own copy.
            t = t.expand(x.shape[:-1])
        t_emb: torch.Tensor = self.time_embedding(t)
        tx: torch.Tensor = torch.cat([t_emb, x], dim=-1)
        return self.net(tx)  # Output is v_t(x; θ) (Eq. 5 in Lipman et al.)


def ot_path(
    x_0: torch.Tensor,
    x_1: torch.Tensor,
    t: torch.Tensor,
    sigma_min: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the OT conditional flow ψ_t and its target velocity u_t.

    ψ_t (Eq. 22 in Lipman et al.) moves a prior sample x_0 along a straight line towards
    the data point x_1, shrinking the x_0 component from 1 at t=0 to sigma_min at t=1.
    u_t (as used in Eq. 23 in Lipman et al.) is the velocity of that path, i.e. the vector
    field the learned v_t is trained to reproduce.

    Args:
        x_0: Samples from the prior distribution (typically a standard Gaussian).
        x_1: Target data points.
        t: One time in [0, 1] per sample; a trailing axis is added so it broadcasts
            over the feature dimension.
        sigma_min: Small constant setting the spread that remains at t=1 (see Eq. 22).

    Returns:
        ``(psi_t, target_v)``: the point on the OT path and the target vector field u_t.
    """
    t_col = t.unsqueeze(-1)
    shrink = 1 - sigma_min
    # Coefficient of x_0 along the path: 1 at t=0, sigma_min at t=1.
    sigma_t = 1 - shrink * t_col
    psi_t: torch.Tensor = sigma_t * x_0 + t_col * x_1
    # d/dt psi_t, which is constant in t for this path.
    target_v: torch.Tensor = x_1 - shrink * x_0
    return psi_t, target_v


def compute_cfm_loss(
    v_net: VectorFieldNetwork,
    x_0: torch.Tensor,
    x_1: torch.Tensor,
    t: torch.Tensor,
    sigma_min: float,
) -> torch.Tensor:
    """
    Conditional Flow Matching (CFM) loss for the OT path (Eq. 23 in Lipman et al.).

    The samples are pushed along the OT path with ``ot_path``. The network's prediction
    v_t(ψ_t) at those points is then compared with the path's target velocity u_t.
    The squared error is averaged over every sample and every coordinate (plain MSE).

    Args:
        v_net: The vector field network v_t(x; θ), called as ``v_net(t, x)``.
        x_0: Samples from the prior distribution.
        x_1: Target data points.
        t: Time points along the OT path.
        sigma_min: Small constant controlling the spread at t=1.

    Returns:
        The CFM loss as a scalar tensor.
    """
    psi_t, target_v = ot_path(x_0, x_1, t, sigma_min)
    prediction = v_net(t, psi_t)  # v_t(ψ_t(x_0, x_1, t))
    residual = prediction - target_v
    return residual.pow(2).mean()


# Training ###################################################################
# Initialize the vector field network v_t and optimizer
v_net = VectorFieldNetwork(input_dim, hidden_dim).to(device)
optimizer = optim.Adam(v_net.parameters(), lr=lr)


def _infinite_batches(loader):
    """
    Yield batches from ``loader`` forever, starting a new pass each time one is exhausted.

    Every pass re-iterates ``loader``, so a DataLoader built with ``shuffle=True`` draws a
    new permutation per epoch. ``itertools.cycle`` would instead replay the cached batches
    of the first epoch in the same order.

    Raises:
        ValueError: if a full pass over ``loader`` yields no batches; this guard prevents
            an infinite busy loop.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("data loader yielded no batches")


# Training loop.
# The loop keeps one persistent batch iterator instead of calling `next(iter(data_loader))`
# every step. That call builds a brand-new DataLoader iterator, and with shuffle=True a
# permutation of the whole dataset, only to take its first batch.
# With the persistent iterator, batches walk through each shuffled epoch. The final batch of
# an epoch can be smaller than `batch_size` when `data_size` is not a multiple of it.
batches = _infinite_batches(data_loader)
progress_bar = tqdm(range(n_iterations), desc="Training", unit="iteration")
losses = []  # Per-iteration training loss values
for _ in progress_bar:
    x_1 = next(batches)  # Sample data x_1 ~ q(x_1)
    t = torch.rand(x_1.shape[0], device=device)  # Sample time t ~ U[0, 1]
    x_0 = torch.randn_like(x_1)  # Sample prior noise x_0 ~ N(0, I)

    loss = compute_cfm_loss(v_net, x_0, x_1, t, sigma_min)  # CFM loss

    # Optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar once per step; on CUDA every .item() forces a device-to-host sync.
    loss_value = loss.item()
    losses.append(loss_value)
    progress_bar.set_postfix({"Loss": f"{loss_value:.4f}"})

# Plot loss curve after training
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses)
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Curve")
ax.grid(True)  # Add grid
plt.show()


# Sampling ###################################################################
@torch.no_grad()
def sample(
    n_samples: int,
    vector_field: VectorFieldNetwork,
    dt: float = 0.01,  # Requested step size for Euler integration
) -> torch.Tensor:
    """
    Generates samples by integrating the learned vector field.

    The method starts from x_0 ~ N(0, I) and integrates the ODE dx/dt = v_t(x)
    (Eq. 1 in Lipman et al.) from t=0 to t=1 with the forward Euler method.

    The number of steps is ``round(1 / dt)`` (at least one), and the step actually taken
    is ``1 / n_steps``. Integration therefore always ends exactly at t=1, even when 1 / dt
    is not an integer; a float-step ``torch.arange`` loop would overshoot in that case
    (e.g. dt=0.03 reaches t=1.02). The default dt=0.01 gives 100 steps of 0.01.

    Gradients are disabled because sampling is inference only. Recording the autograd
    graph through every step for all ``n_samples`` points would only waste memory.

    Args:
        n_samples: The number of samples to generate.
        vector_field: The learned vector field network v_t(x; θ), called as ``vector_field(t, x)``.
        dt: The requested integration step size; must be positive.

    Returns:
        A tensor of shape (n_samples, input_dim) on ``device`` holding the generated samples.

    Raises:
        ValueError: if ``dt`` is not a positive number.
    """
    if not dt > 0:  # also rejects NaN
        raise ValueError(f"dt must be positive, got {dt}")
    n_steps = max(1, round(1.0 / dt))
    step = 1.0 / n_steps
    x = torch.randn(n_samples, input_dim, device=device)  # Sample x_0 ~ N(0, I)
    for k in range(n_steps):
        # Every sample shares the same time value at this step.
        t_batch = torch.full((n_samples,), k * step, device=device)
        x = x + vector_field(t_batch, x) * step  # Forward Euler update
    return x


# Sample and plot generated data
# Draw as many model samples as there are real points, then bring them to the host as NumPy.
generated_samples = sample(data_size, v_net).detach().cpu().numpy()


# Visualization ###############################################################
# Overlay the real points and the generated points on a single set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(
    data[:, 0].cpu().numpy(),
    data[:, 1].cpu().numpy(),
    alpha=0.1,
    label="Real Data",
)
ax.scatter(
    generated_samples[:, 0],
    generated_samples[:, 1],
    alpha=0.1,
    label="Generated Data",
)
ax.set_title("Flow Matching with Optimal Transport - Generated Samples")
ax.legend()
plt.show()