In [29]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage

from models.unet.unet import UNetModelWrapper
from models.unet import UNetModel
from flow_matching.models import OTFM
from torchdyn.core import NeuralODE
import sys

In [30]:
# Select the compute device name: CUDA when torch reports it as available,
# otherwise the CPU. `device` stays a plain string for the later `.to(device)`
# calls.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Trailing expression so the notebook displays the resolved torch.device.
torch.device(device)

device(type='cpu')

# 1 Experiments

As in Lipman et al. (2023), we use an MLP with 5 layers of 512 neurons each for the 2D examples and the UNet architecture from Dhariwal & Nichol (2021) for the images.

We generate results for the FM-OT and FM-Diffusion models and compare them.

## 1.1 make_moons

## 1.2  MNIST

## 1.3 Checkerboards

## 1.4 CIFAR10

We start by implementing flow matching and applying it to the CIFAR10 dataset (Krizhevsky et al., 2009). As in Lipman et al. (2023), we evaluate the likelihood and draw samples from the model using the dopri5 solver (Dormand & Prince, 1980).

In [31]:
# Hyper-parameters for the CIFAR10 experiment.
lr = 2e-4          # learning rate handed to the Adam optimiser
num_workers = 4    # presumably the DataLoader worker count -- confirm where it is consumed
num_channel = 128  # forwarded as `num_channels` when the UNet is built
batch_size = 128   # number of images per DataLoader batch

In [32]:
# Per-image preprocessing for the CIFAR10 training split: random horizontal
# flip for augmentation, conversion to a tensor, then per-channel
# normalisation with mean 0.5 and std 0.5.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR10 training set, downloaded into ./data_cifar10 when not already there.
dataset = datasets.CIFAR10(
    root="./data_cifar10",
    train=True,
    download=True,
    transform=train_transform,
)

# Shuffled mini-batches; drop_last=True discards the final incomplete batch so
# every batch holds exactly `batch_size` images.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    # `num_workers` is set in the config cell but was never passed on, so
    # loading silently ran in the main process only.
    num_workers=num_workers,
)

Files already downloaded and verified


In [33]:
# Keyword arguments for the UNet wrapper, gathered in one place so they can be
# read (and tweaked) at a glance. `dim` is (3, 32, 32), i.e. the shape of one
# CIFAR10 image tensor; the remaining entries are forwarded unchanged --
# their exact semantics live in UNetModelWrapper (not visible here).
unet_kwargs = dict(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=num_channel,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
)

# Build the network and move it to the selected device.
net_model = UNetModelWrapper(**unet_kwargs).to(device)

In [None]:
# training
optmizer = torch.optim.Adam(net_model.parameters(), lr=lr)
loss_fn = nn.MSELoss()
otfm = OTFM()
sigma_min = 0.1
n_epochs = 2

for epoch in range(n_epochs):
    net_model.train()
    running_loss = 0.0
    
    for batch, _ in dataloader:
        
        batch = batch.to(device)
        
        x_0 = torch.randn_like(batch).to(device)

        x_1 = batch

        t = torch.rand(len(batch), 3, 32, 32).to(device)
        
        x_t = otfm.compute_x_t(x_0, x_1, sigma_min)
        dx_t = otfm.compute_dx_t(x_0, x_1, sigma_min, t)
        
        optmizer.zero_grad()
        
        loss = loss_fn(net_model(t, x_t), dx_t)
        
        loss.backward()        
        optmizer.step()
        
        running_loss += loss.item()
        
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running_loss/len(dataloader)}")

torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
torch.Size([128, 3, 32, 32])
torch.Size([128, 32, 32])
torch.Size([128, 32])
