## **Flow Matching** 

This notebook trains a simple (vanilla) flow matching model on the CIFAR-10 dataset. 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [None]:
class FlowUNet(nn.Module):
    """Two-level U-Net that predicts a velocity field for flow matching.

    The time value is broadcast to a constant feature map and concatenated
    to the input as one extra channel, so the first convolution receives
    ``in_channels + 1`` channels. Two max-pool stages each halve the spatial
    resolution and two stride-2 transposed convolutions restore it, with a
    skip connection from the matching encoder stage. Because of the two
    pool/upsample pairs, height and width should be divisible by 4 for the
    skip concatenations to line up.

    Args:
        in_channels: Number of channels of the input and of the output.
        base_channels: Width of the first encoder stage; the deeper stages
            use twice and four times this value.
    """

    @staticmethod
    def _conv_pair(c_in, c_out):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(),
            nn.Conv2d(c_out, c_out, 3, padding=1), nn.ReLU(),
        )

    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # Encoder: the "+ 1" accounts for the time channel added in forward().
        self.enc1 = self._conv_pair(in_channels + 1, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self._conv_pair(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck at a quarter of the input resolution.
        self.middle = self._conv_pair(c2, c4)

        # Decoder: each dec* stage takes the upsampled features concatenated
        # with the encoder skip, hence twice the channels of its output.
        self.up1 = nn.ConvTranspose2d(c4, c2, 2, stride=2)
        self.dec1 = self._conv_pair(c4, c2)
        self.up2 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec2 = self._conv_pair(c2, c1)

        # 1x1 projection back to the image channel count.
        self.out = nn.Conv2d(c1, in_channels, 1)

    def forward(self, x, t):
        """Predict the velocity at state ``x`` and time ``t``.

        Args:
            x: Batch indexed as (batch, channel, height, width).
            t: 1-D tensor holding one time value per batch element.

        Returns:
            Tensor with ``in_channels`` channels and the spatial size of ``x``.
        """
        rows, cols = x.size(2), x.size(3)
        # Turn the per-sample scalar into a (batch, 1, H, W) constant plane.
        time_plane = t[:, None, None, None].float().expand(-1, 1, rows, cols)
        h = torch.cat([x, time_plane], dim=1)

        skip1 = self.enc1(h)
        skip2 = self.enc2(self.pool1(skip1))
        bottleneck = self.middle(self.pool2(skip2))

        up = self.dec1(torch.cat([self.up1(bottleneck), skip2], dim=1))
        up = self.dec2(torch.cat([self.up2(up), skip1], dim=1))
        return self.out(up)


In [None]:
def sample_flow_data(x0):
    """Draw a random point on the straight path from data to noise.

    For every batch element a noise sample ``x1 ~ N(0, I)`` and a time
    ``t ~ U[0, 1)`` are drawn and the linear interpolant
    ``xt = (1 - t) * x0 + t * x1`` is formed. The regression target is the
    time derivative of that path, ``d xt / dt = x1 - x0``, which is constant
    in ``t`` and stays bounded at both endpoints.

    Args:
        x0: Batch of data samples with the batch on dimension 0.

    Returns:
        Tuple ``(xt, t, v_target)`` where ``xt`` and ``v_target`` have the
        shape of ``x0`` and ``t`` is always 1-D with ``x0.size(0)`` entries
        (also when the batch holds a single sample).
    """
    B = x0.size(0)
    x1 = torch.randn_like(x0)
    t = torch.rand(B, device=x0.device)
    # Reshape to (B, 1, ..., 1) so t broadcasts over every non-batch dim.
    t_b = t.view(B, *([1] * (x0.dim() - 1)))
    xt = (1 - t_b) * x0 + t_b * x1
    # Velocity of the linear path; no 1 / (t * (1 - t)) factor, which is not
    # the derivative of xt and diverges as t approaches 0 or 1.
    v_target = x1 - x0
    return xt, t, v_target

In [None]:
def train_flow_matching(model, loader, optimizer, epochs=10, device='cuda'):
    """Train ``model`` to regress the flow-matching velocity target.

    Each batch is moved to ``device``, turned into an interpolated state,
    time and target by ``sample_flow_data``, and the model is updated on the
    mean squared error between its prediction and that target. The model is
    put into training mode and left in it.

    Args:
        model: Callable as ``model(xt, t)`` returning a tensor shaped like
            the target.
        loader: Iterable of ``(inputs, labels)`` batches that exposes
            ``loader.dataset``; the labels are ignored.
        optimizer: Optimizer over the parameters of ``model``.
        epochs: Number of passes over ``loader``.
        device: Device the input batches are moved to.

    Returns:
        List with one entry per epoch: the summed per-sample loss divided by
        ``len(loader.dataset)``.
    """
    model.train()
    epoch_means = []
    for epoch in range(epochs):
        running = 0
        progress = tqdm(loader, desc=f"Epoch {epoch+1}")
        for images, _ in progress:
            images = images.to(device)
            xt, t, v_target = sample_flow_data(images)
            loss = F.mse_loss(model(xt, t), v_target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by the batch size so the
            # (possibly smaller) last batch is not over-counted.
            running += loss.item() * images.size(0)

        avg_loss = running / len(loader.dataset)
        epoch_means.append(avg_loss)
        print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
    return epoch_means

In [None]:
@torch.no_grad()
def sample_flow(model, steps=50, return_trajectory=False, num_samples=1,
                image_shape=(3, 32, 32)):
    """Generate samples by Euler-integrating the learned velocity field.

    Integration starts from Gaussian noise at ``t = 1`` and takes ``steps``
    equal steps of size ``1 / steps`` towards ``t = 0``, updating
    ``x <- x - v(x, t) * dt``. The model is switched to eval mode and left
    in it.

    Args:
        model: Module callable as ``model(x, t)``; must own at least one
            parameter, which determines the device used.
        steps: Number of Euler steps; must be at least 1.
        return_trajectory: When true, also return every intermediate state.
        num_samples: Number of samples integrated in parallel.
        image_shape: Per-sample shape, ``(channels, height, width)``.

    Returns:
        Tuple ``(x, trajectory)``. ``x`` is a CPU tensor of shape
        ``(num_samples, *image_shape)`` clamped to ``[-1, 1]``.
        ``trajectory`` is ``None`` unless requested; otherwise a list of
        ``steps + 1`` unclamped CPU tensors, the first being the initial
        noise.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    model.eval()
    device = next(model.parameters()).device
    x = torch.randn(num_samples, *image_shape).to(device)
    dt = 1.0 / steps

    # Only pay for the per-step CPU copies when the caller wants them.
    trajectory = [x.clone().cpu()] if return_trajectory else None
    for i in range(steps):
        # Time of the current state: 1, 1 - dt, ..., dt.
        t = torch.full((num_samples,), 1 - i / steps, device=device)
        v = model(x, t)
        x = x - v * dt
        if return_trajectory:
            trajectory.append(x.clone().cpu())
    return x.clamp(-1, 1).cpu(), trajectory

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ToTensor yields values in [0, 1]; the Lambda rescales them to [-1, 1],
# the range the sampler clamps its output to.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 2. - 1.),
])
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
loader = DataLoader(dataset, shuffle=True, batch_size=128)

model = FlowUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

losses = train_flow_matching(model, loader, optimizer, device=device, epochs=10)

# Plot training loss
plt.plot(losses, marker='o')
plt.title("Training Loss (Flow Matching)")
plt.xlabel("Epoch")
plt.ylabel("Flow MSE Loss")
plt.grid(True)
plt.show()

In [None]:
# Draw 16 independent samples (one sampler call each) and tile them 4 x 4.
model.eval()
samples = torch.stack(
    [sample_flow(model, steps=50)[0][0] for _ in range(16)]
)
grid = make_grid(samples, nrow=4, normalize=True)

plt.figure(figsize=(6, 6))
# make_grid returns channels-first; imshow expects channels-last.
plt.imshow(grid.permute(1, 2, 0))
plt.title("Flow Matching Samples")
plt.axis('off')
plt.show()

In [None]:
# Integrate a single sample and keep every intermediate state for plotting.
traj_steps = 20    # Euler steps taken by the sampler
frame_stride = 4   # plot every frame_stride-th state
_, trajectory = sample_flow(model, steps=traj_steps, return_trajectory=True)
frames = trajectory[::frame_stride]

plt.figure(figsize=(15, 3))
for i, img in enumerate(frames):
    plt.subplot(1, len(frames), i + 1)
    # Trajectory states are unclamped (the early ones are close to Gaussian
    # noise), so after mapping [-1, 1] -> [0, 1] clamp explicitly instead of
    # handing imshow out-of-range floats.
    plt.imshow(((img[0].permute(1, 2, 0) + 1) / 2).clamp(0, 1))
    plt.axis("off")
    # State k of the trajectory sits at time 1 - k / traj_steps; deriving the
    # label keeps it correct if the step count or stride is changed.
    plt.title(f"t={1 - i * frame_stride / traj_steps:.1f}")
plt.suptitle("Trajectory of 1 Sample Through Flow Field")
plt.show()