In [27]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
import torchvision
from torchdiffeq import odeint

import numpy as np

from models.unet.unet import UNetModelWrapper
from models.unet import UNetModel
from flow_matching.models import OTFM
from torchdyn.core import NeuralODE
import sys

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.device(device)

device(type='cpu')

# 1 Experiments

As in Lipman et al. 2023, we use a MLP with 5-layers of 512 neurons for the 2D examples and the UNet architecture from Dhariwal & Nichol (2021) for the images. 

We generate results and compare the resutls for the models FM-OT and FM-Diffusion.

## 1.1 make_moons

## 1.2  MNIST

## 1.3 Checkboards

## 1.4 CIFAR10

We start by implemeting the flow matching and applying it to the CIFAR10 (Krizhevsky et al., 2009) dataset. As in Lipman et al., 2023, we evaluate likelihood and samples from the model using dopri5 (Dormand & Prince, 1980).

In [4]:
batch_size = 128
num_channel = 128
num_workers = 4
lr = 2e-4

In [16]:
dataset = datasets.CIFAR10(
    root="./data_cifar10",
    train=True, 
    download=True, 
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
)

subset_indices = list(range(256))  # Use the first 256 samples
dataset_subset = torch.utils.data.Subset(dataset, subset_indices)

dataloader = torch.utils.data.DataLoader(
    dataset_subset, 
    batch_size=batch_size,
    shuffle=True,
    drop_last=True, 
)

Files already downloaded and verified


In [17]:
net_model = UNetModelWrapper(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=num_channel,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
).to(
    device
)

In [36]:
# training
optmizer = torch.optim.Adam(net_model.parameters(), lr=lr)
loss_fn = nn.MSELoss()
otfm = OTFM()
sigma_min = 0.001
n_epochs = 10

for epoch in range(n_epochs):
    net_model.train()
    running_loss = 0.0
    
    print("epoch: ", epoch)
    
    for batch, _ in dataloader:
        
        batch = batch.to(device)
        
        x_0 = torch.randn_like(batch).to(device)

        x_1 = batch

        t = torch.rand(len(batch), 3, 32, 32).to(device)
        
        x_t = otfm.compute_x_t(x_0, x_1, sigma_min)
        dx_t = otfm.compute_dx_t(x_0, x_1, sigma_min, t)
        
        optmizer.zero_grad()
        
        loss = loss_fn(net_model(t, x_t), dx_t)
        
        loss.backward()        
        optmizer.step()
        
        running_loss += loss.item()
        
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running_loss/len(dataloader)}")

epoch:  0
Epoch 1/10, Loss: 0.17949417233467102
epoch:  1
Epoch 2/10, Loss: 0.17611891776323318
epoch:  2
Epoch 3/10, Loss: 0.1601923257112503
epoch:  3
Epoch 4/10, Loss: 0.16158664971590042
epoch:  4
Epoch 5/10, Loss: 0.15941228717565536
epoch:  5
Epoch 6/10, Loss: 0.15786418318748474
epoch:  6
Epoch 7/10, Loss: 0.1569773033261299
epoch:  7
Epoch 8/10, Loss: 0.15630418807268143
epoch:  8
Epoch 9/10, Loss: 0.15595542639493942
epoch:  9
Epoch 10/10, Loss: 0.1557096466422081


In [37]:

class FlowMatchingSampler:
    def __init__(self, net_model, device='cuda'):
        self.net_model = net_model
        self.device = device
        
    def vector_field(self, t, x):
        """Vector field for ODE solver"""
        # Reshape t to match the expected input shape of the model
        t_shaped = t.reshape(1, 1, 1, 1).expand(x.shape[0], 3, 32, 32).to(self.device)
        return self.net_model(t_shaped, x)
    
    def generate_samples(self, num_samples, rtol=1e-5, atol=1e-5):
        """
        Generate samples using the trained flow matching model
        
        Args:
            num_samples: Number of samples to generate
            rtol: Relative tolerance for ODE solver
            atol: Absolute tolerance for ODE solver
        """
        # Initial noise
        x_0 = torch.randn(num_samples, 3, 32, 32).to(self.device)
        
        # Time points for integration
        t = torch.linspace(0, 1, 2).to(self.device)
        
        # Solve ODE
        self.net_model.eval()
        with torch.no_grad():
            samples = odeint(
                self.vector_field,
                x_0,
                t,
                method='dopri5',
                rtol=rtol,
                atol=atol
            )
        
        # Return final samples (at t=1)
        return samples[-1]

def save_samples(samples, filename):
    """Helper function to save generated samples"""
    # Ensure samples are in correct range [0, 1]
    samples = torch.clamp(samples, 0, 1)
    # Save as grid of images
    torchvision.utils.save_image(
        samples,
        filename,
        nrow=int(np.sqrt(len(samples))),
        normalize=True
    )



In [40]:
# Initialize sampler
sampler = FlowMatchingSampler(net_model, device=device)

# Generate samples
num_samples = 4
samples = sampler.generate_samples(num_samples)

# Save samples
save_samples(samples, "flow_matching_samples.png")


RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[4, 3, 32, 32] to have 1 channels, but got 3 channels instead

In [44]:
batch_size = 128
num_channel = 64  # Reduced from 128 since MNIST is simpler
num_workers = 4
lr = 2e-4

dataset = datasets.MNIST(
    root="./data_mnist",
    train=True, 
    download=True, 
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
)

# Filter the dataset to include only classes 1 and 7
subset_indices = [i for i, (img, label) in enumerate(dataset) if label in [1, 7]]
dataset_subset = torch.utils.data.Subset(dataset, subset_indices)

dataloader = torch.utils.data.DataLoader(
    dataset_subset, 
    batch_size=batch_size,
    shuffle=True,
    drop_last=True, 
)

# Modified architecture for MNIST
net_model = UNetModelWrapper(
    dim=(1, 28, 28),
    num_res_blocks=1,  # Reduced from 2
    num_channels=num_channel,
    channel_mult=[1, 2, 4],  # Simplified multiplier sequence
    num_heads=2,  # Reduced from 4
    num_head_channels=32,  # Reduced from 64
    attention_resolutions="14",
    dropout=0.1,
).to(device)

# Rest of the training loop
optmizer = torch.optim.Adam(net_model.parameters(), lr=lr)
loss_fn = nn.MSELoss()
otfm = OTFM()
sigma_min = 0.01
n_epochs = 10

for epoch in range(n_epochs):
    net_model.train()
    running_loss = 0.0
    
    print("epoch: ", epoch)
    
    for batch, _ in dataloader:
        batch = batch.to(device)
        
        x_0 = torch.randn_like(batch).to(device)
        x_1 = batch
        t = torch.rand(len(batch), 1, 28, 28).to(device)
        
        x_t = otfm.compute_x_t(x_0, x_1, sigma_min)
        dx_t = otfm.compute_dx_t(x_0, x_1, sigma_min, t)
        
        optmizer.zero_grad()
        loss = loss_fn(net_model(t, x_t), dx_t)
        loss.backward()        
        optmizer.step()
        
        running_loss += loss.item()
        
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running_loss/len(dataloader)}")

epoch:  0


KeyboardInterrupt: 

In [43]:
class FlowMatchingSampler:
    def __init__(self, net_model, device='cuda'):
        self.net_model = net_model
        self.device = device
        
    def vector_field(self, t, x):
        """Vector field for ODE solver"""
        # Reshape t to match MNIST dimensions (1 channel, 28x28)
        t_shaped = t.reshape(1, 1, 1, 1).expand(x.shape[0], 1, 28, 28).to(self.device)
        return self.net_model(t_shaped, x)
    
    def generate_samples(self, num_samples, rtol=1e-5, atol=1e-5):
        """
        Generate samples using the trained flow matching model
        
        Args:
            num_samples: Number of samples to generate
            rtol: Relative tolerance for ODE solver
            atol: Absolute tolerance for ODE solver
        """
        # Initial noise - adapted for MNIST dimensions
        x_0 = torch.randn(num_samples, 1, 28, 28).to(self.device)
        
        # Time points for integration
        t = torch.linspace(0, 1, 2).to(self.device)
        
        # Solve ODE
        self.net_model.eval()
        with torch.no_grad():
            samples = odeint(
                self.vector_field,
                x_0,
                t,
                method='dopri5',
                rtol=rtol,
                atol=atol
            )
        
        # Return final samples (at t=1)
        return samples[-1]

def save_samples(samples, filename):
    """Helper function to save generated samples"""
    # Ensure samples are in correct range [0, 1]
    samples = torch.clamp(samples, 0, 1)
    # Save as grid of images
    torchvision.utils.save_image(
        samples,
        filename,
        nrow=int(np.sqrt(len(samples))),
        normalize=True
    )

# Usage example:
sampler = FlowMatchingSampler(net_model, device=device)
num_samples = 16  # Generate a 4x4 grid of samples
samples = sampler.generate_samples(num_samples)
save_samples(samples, "mnist_samples.png")