# Homework 6: Flow models

## Task 1: Theory (5pt)

### Problem 1: KFP theorem (1pt)

We have encountered two different formulations of the Kolmogorov-Fokker-Planck (KFP) theorem.

1) continuity equation in continuous-in-time NF:
$$
\frac{d \log p(\mathbf{x}(t), t)}{d t} = - \text{tr} \left( \frac{\partial \mathbf{f}(\mathbf{x}, t)}{\partial \mathbf{x}} \right);
$$

2) the general form of the KFP equation in SDEs:
$$
\frac{\partial p(\mathbf{x}, t)}{\partial t} = - \text{div}\left(\mathbf{f}(\mathbf{x}, t) p(\mathbf{x}, t)\right) + \frac{1}{2} g^2(t) \Delta p(\mathbf{x}, t).
$$

In this task your goal is to prove that the first formulation is a special case of the more general second formulation.

**Note:** The derivative in the first formulation is a total derivative (not a partial one).

```
your solution
```

### Problem 2: DDPM as SDE discretization (2pt)

We have proved that DDPM is a discretization of the SDE
$$
	d \mathbf{x} = - \frac{1}{2} \beta(t) \mathbf{x}(t) dt + \sqrt{\beta(t)} \cdot d \mathbf{w}.
$$
Here $\mathbf{f}(\mathbf{x}, t) = - \frac{1}{2} \beta(t) \mathbf{x}(t)$, $g(t) = \sqrt{\beta(t)}$.
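
To build some intuition for this forward SDE, here is a minimal Euler-Maruyama sketch in plain Python (standard library only). It assumes a constant $\beta(t) = \beta$ and a one-dimensional state, both chosen only for illustration; the function name `simulate_forward_sde` is hypothetical. Under this assumption the SDE is an Ornstein-Uhlenbeck process, so $\mathbf{x}(1)$ given $\mathbf{x}(0) = x_0$ should have mean $x_0 e^{-\beta/2}$ and variance $1 - e^{-\beta}$.

```python
import math
import random


def simulate_forward_sde(x0, rng, beta=2.0, n_steps=100):
    """Euler-Maruyama for dx = -0.5 * beta * x dt + sqrt(beta) dw on [0, 1].

    Assumption for this sketch: beta(t) is a constant and x is a scalar.
    """
    dt = 1.0 / n_steps
    x = x0
    for _ in range(n_steps):
        drift = -0.5 * beta * x        # f(x, t)
        diffusion = math.sqrt(beta)    # g(t)
        # dw is approximated by sqrt(dt) * eps with eps ~ N(0, 1)
        x = x + drift * dt + diffusion * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


rng = random.Random(0)
samples = [simulate_forward_sde(2.0, rng) for _ in range(4000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)

print(f"empirical mean {mean:.3f}, theory {2.0 * math.exp(-1.0):.3f}")
print(f"empirical var  {var:.3f}, theory {1.0 - math.exp(-2.0):.3f}")
```

The starting point is gradually forgotten while the variance grows towards 1, which is the usual variance-preserving behaviour.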

Recall the reverse SDE
$$
    d\mathbf{x} = \left(\mathbf{f}(\mathbf{x}, t) - g^2(t) \frac{\partial \log p_t(\mathbf{x})}{\partial \mathbf{x}}\right) dt + g(t) d \mathbf{w}.
$$

The reverse SDE of the DDPM model will be
$$
    d\mathbf{x}(t) = -\beta(t)\left[\frac{\mathbf{x}(t)}{2} + \nabla_{\mathbf{x}}\log p_t(\mathbf{x}(t))\right]dt + \sqrt{\beta(t)}d\mathbf{w}.
$$

The DDPM uses the following form of ancestral sampling
$$
\mathbf{x}_{t-1} = \frac{1}{\sqrt{1 - \beta_t}} \cdot \mathbf{x}_t + \frac{\beta_t}{\sqrt{1 - \beta_t}} \cdot \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t | \boldsymbol{\theta}) +  \sqrt{\beta_t} \cdot \boldsymbol{\epsilon}.
$$
(Here we assumed that $p(\mathbf{x}_{t - 1} | \mathbf{x}_t, \boldsymbol{\theta}) = \mathcal{N} \bigl(\boldsymbol{\mu}_{\boldsymbol{\theta}, t}(\mathbf{x}_t), \beta_t \cdot \mathbf{I}\bigr)$).

To validate that the DDPM iterative update scheme is indeed a discretization of this SDE, let $t \in \{0,\ldots,\frac{N-1}{N}\}$, $\Delta t = 1/N$, $\mathbf{x}(t-\Delta t) = \mathbf{x}_{s-1}$, $\mathbf{x}(t) = \mathbf{x}_s$, and $\beta(t)\Delta t = \beta_s$.

In this task your goal is to show that the ancestral sampling is a discretization of the DDPM reverse SDE.

**Hints**:
1. use $dt < 0$;
2. $\beta_t = - \beta(t) dt$;
3. $d\mathbf{w} = \boldsymbol{\epsilon} \cdot \sqrt{-dt}$;
4. drop the terms with the order of $o(dt)$.

```
your solution
```

### Problem 3: Flow matching distribution (2pt)

Consider a flow matching model between two identical distributions:
$$
    p_0(x) = \mathcal{N}(0, \sigma^2), \quad p_1(x) = \mathcal{N}(0, \sigma^2).
$$

Your goal is to find the analytical expression for the distribution $p_t(x_t)$.

**Note:** you should obtain a variance that is nonlinear in $t$; try to understand this effect.

```
your solution
```

In [None]:
COMMIT_HASH = "11668881e2da2ea7938417bdabda0397660508c8"
!if [ -d dgm_utils ]; then rm -Rf dgm_utils; fi
!git clone https://github.com/r-isachenko/dgm_utils.git
%cd dgm_utils
!git checkout {COMMIT_HASH}
!pip install ./
!pip install torchdiffeq
%cd ./..
!rm -Rf dgm_utils

In [None]:
from dgm_utils import load_dataset, BaseModel, train_model
from dgm_utils import visualize_images, visualize_2d_data

In [None]:
import numpy as np

from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.distributions.normal import Normal
from torchdiffeq import odeint, odeint_adjoint

if torch.cuda.is_available():
    DEVICE = "cuda"
    print('GPU found :)') 
else: 
    DEVICE = "cpu"
    print('GPU not found :(')

## Task 2: Continuous-time Normalizing Flows (4 pt)

In this part you will implement a continuous-time normalizing flow and apply it to a 2D dataset.

In [None]:
COUNT = 5_000

train_data, train_labels, test_data, test_labels = load_dataset('moons', size=COUNT, with_targets=True)
visualize_2d_data(train_data, test_data, train_labels, test_labels)

Let's revisit continuous normalizing flows (CNFs).

In CNFs, a central task is efficiently computing derivatives, particularly the trace of the Jacobian of the dynamics function $f(\mathbf{x}(t), t)$. As we saw in Lecture 11, the change in log-probability over time is given by:

$$
\frac{d \log p(\mathbf{x}(t))}{dt} = -\text{Tr}\left( \frac{\partial f(\mathbf{x}(t), t)}{\partial \mathbf{x}(t)} \right).
$$
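
As a small sanity check of this formula, the sketch below integrates a one-dimensional linear flow $\frac{dx}{dt} = a x$ together with its log-density using explicit Euler steps. The linear dynamics, the value of `a`, and the helper names are assumptions made only for this illustration; they are not the model used in this task. For this flow the trace equals `a`, and a standard normal at $t = 0$ becomes $\mathcal{N}(0, e^{2a})$ at $t = 1$, so the integrated log-density can be compared with the exact one.

```python
import math


def log_standard_normal(x):
    """Log-density of N(0, 1) at a scalar point x."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * x * x


def integrate_linear_flow(x0, a=0.5, n_steps=1000):
    """Euler integration of dx/dt = a * x and d log p / dt = -a on [0, 1]."""
    dt = 1.0 / n_steps
    x, log_p = x0, log_standard_normal(x0)
    for _ in range(n_steps):
        trace = a  # in 1D the Jacobian trace is just df/dx = a
        x, log_p = x + a * x * dt, log_p - trace * dt
    return x, log_p


x1, log_p1 = integrate_linear_flow(1.0)

# Exact density at t = 1 for a = 0.5: N(0, exp(2a)) = N(0, e).
exact = -0.5 * math.log(2.0 * math.pi * math.e) - 0.5 * x1 * x1 / math.e

print(f"integrated log p: {log_p1:.5f}, exact log p: {exact:.5f}")
```

The two values agree up to the Euler discretization error, which shrinks as `n_steps` grows.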

In high-dimensional spaces, computing the exact trace of the Jacobian $\frac{\partial f}{\partial \mathbf{x}}$ can be computationally expensive. To overcome this challenge, one can use **Hutchinson's trace estimator**, which provides an **efficient** and **unbiased** estimate:

$$
\text{Tr}\left( \frac{\partial f}{\partial \mathbf{x}} \right) = \mathbb{E}_{\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})} \left[ \boldsymbol{\epsilon}^\top \frac{\partial f}{\partial \mathbf{x}} \boldsymbol{\epsilon} \right],
$$

where $\boldsymbol{\epsilon}$ is a random vector sampled from a standard normal distribution $\mathcal{N}(0, \mathbf{I})$.

**Note:** In practice, we approximate this expectation using a single sample of $\boldsymbol{\epsilon}$ to efficiently estimate the trace.
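
The estimator can be illustrated on a fixed matrix standing in for the Jacobian. The sketch below uses plain Python (standard library only); the matrix values and the function name `hutchinson_trace` are made up for this example. It averages many samples to show unbiasedness, whereas a CNF would typically use a single sample per step and obtain the product with $\boldsymbol{\epsilon}$ through automatic differentiation instead of forming the Jacobian explicitly.

```python
import random


def hutchinson_trace(matrix, n_samples=20000, seed=0):
    """Monte Carlo estimate of trace(matrix) as the mean of eps^T @ matrix @ eps."""
    rng = random.Random(seed)
    dim = len(matrix)
    total = 0.0
    for _ in range(n_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # matrix @ eps
        m_eps = [sum(matrix[i][j] * eps[j] for j in range(dim)) for i in range(dim)]
        # eps^T @ (matrix @ eps)
        total += sum(eps[i] * m_eps[i] for i in range(dim))
    return total / n_samples


# Hypothetical 3x3 "Jacobian" used only for this illustration.
jacobian = [[1.0, 2.0, 0.0], [0.0, 3.0, 1.0], [1.0, 0.0, -1.0]]
exact_trace = sum(jacobian[i][i] for i in range(len(jacobian)))
estimate = hutchinson_trace(jacobian)

print(f"exact trace: {exact_trace}, Hutchinson estimate: {estimate:.3f}")
```

A single-sample estimate is noisy, but its expectation is still the exact trace, which is what makes it usable inside a stochastic training loop.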

However, since the data we consider in this task is only two-dimensional, we can easily compute the entire Jacobian. For this task we will use the same conditional model as in previous homeworks to parametrize $\frac{d \mathbf{x}(t)}{d t}$.

First, let's define a time embedding layer that works with values in the range $[0, 1]$.

**Note:** we can't use `nn.Embedding` here because it accepts only integer indices.

In [None]:
class TimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.register_buffer('freqs', torch.arange(1, dim // 2 + 1) * torch.pi)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        emb = self.freqs * t.unsqueeze(-1)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
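
To see what this layer produces for a single time value, here is a plain-Python mirror of the same computation (frequencies $k\pi$ for $k = 1, \ldots, \text{dim}/2$, sines first, then cosines). The function `time_embedding` is a hypothetical helper written only for this illustration and is not part of the homework code.

```python
import math


def time_embedding(t, dim):
    """Sinusoidal features of a scalar t: [sin(k*pi*t)...] + [cos(k*pi*t)...]."""
    assert dim % 2 == 0
    freqs = [k * math.pi for k in range(1, dim // 2 + 1)]
    angles = [f * t for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb_start = time_embedding(0.0, 4)   # sines vanish, cosines equal 1
emb_mid = time_embedding(0.5, 4)
emb_end = time_embedding(1.0, 4)

print(emb_start)
print([round(v, 3) for v in emb_mid])
print([round(v, 3) for v in emb_end])
```

Different times in $[0, 1]$ map to different feature vectors, which is what lets the network condition on a continuous time value.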

In [None]:
class ConditionalMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.x_proj = nn.Linear(input_dim, self.hidden_dim)
        self.t_proj = TimeEmbedding(self.hidden_dim)
        self.backbone = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Tanh(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Tanh(),
            nn.Linear(self.hidden_dim, input_dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        x = self.x_proj(x)
        t = self.t_proj(t)
        x = x + t
        x = torch.tanh(x)
        return self.backbone(x)


def test_conditional_mlp():
    SHAPE = [2, 20]
    x = torch.ones(SHAPE)
    t = torch.ones((2,)).long() * 5
    model = ConditionalMLP(input_dim=20)
    output = model(x, t)
    assert list(output.shape) == SHAPE


test_conditional_mlp()

In [None]:
class CNFModel(BaseModel):
    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.input_dim = input_dim
        self.model = ConditionalMLP(input_dim, hidden_dim)
        self.prior = Normal(torch.tensor(0.0), torch.tensor(1.0))
    
    def odefunc(self, t: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        z, _ = states
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            # ====
            # your code
            # 1) apply model to get first order derivatives
            # 2) get second order derivative using torch.autograd
            # Do not forget to use epsilon
            
            # ====
        return dz_dt, -trace
    
    def forward(self, x: torch.Tensor, reverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        x = x.to(self.device)
        dz_dt = torch.zeros([x.shape[0], 1], device=self.device)

        # ====
        # your code
        # use odeint_adjoint to simulate self.odefunc
        # use reverse to simulate from 1 to 0 timesteps
        
        # ====
        
        return z, dz_dt
    
    def loss(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        z, dz_dt = self(x)
        # ====
        # your code
        # use the change-of-variables (CoV) formula to get the log-likelihood of x
        
        # ====
        return {'total_loss': loss}
    
    @torch.no_grad()
    def sample(self, n: int) -> np.ndarray:
        # read carefully the sampling process
        z = self.prior.sample([n, self.input_dim]).to(self.device)
        x, _ = self(z, reverse=True)
        return x.cpu().numpy()

Now let's train the model; it takes some time :)

In [None]:
# ====
# your code
# choose these parameters
BATCH_SIZE = 
LR = 
EPOCHS = 
HIDDEN_DIM = 
# ====

model = CNFModel(input_dim=2, hidden_dim=HIDDEN_DIM)

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE)

# try your own optimizer/scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

train_model(
    model,
    train_loader,
    test_loader,
    epochs=EPOCHS,
    optimizer=optimizer,
    device=DEVICE,
    n_samples=1024,
    visualize_samples=True
)

## Task 3: Flow matching on MNIST (5 pt) 

Finally, your task is to train a flow matching model!

In [None]:
train_data, test_data = load_dataset("mnist", flatten=False, binarize=True)
visualize_images(train_data, "MNIST samples")

The model is written for you. We will use a conditioned ResNet architecture, but you may change it if you want.

In [None]:
class ConditionedResnetBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        # you could experiment with this architecture
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=1),
        )
        self.dim = dim
        self.embedding = TimeEmbedding(dim)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        time_embed = self.embedding(t).view(-1, self.dim, 1, 1)
        return x + self.block(x + time_embed)


class ConditionedSimpleResnet(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, n_filters: int, n_blocks: int
    ) -> None:
        super().__init__()
        # you could experiment with this architecture
        self.first_block = nn.Sequential(
            nn.Conv2d(in_channels, n_filters, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.layers = nn.Sequential(*[ConditionedResnetBlock(n_filters) for _ in range(n_blocks)])
        self.last_block = nn.Sequential(
            nn.ReLU(), nn.Conv2d(n_filters, out_channels, kernel_size=3, padding=1)
        )
        self.n_filters = n_filters

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        x = self.first_block(x)
        for layer in self.layers:
            x = layer(x, t)
        x = self.last_block(x)
        return x


def test_conditioned_resnet():
    model = ConditionedSimpleResnet(in_channels=1, out_channels=1, n_filters=16, n_blocks=1)
    x = torch.rand((1, 1, 28, 28))
    t = torch.zeros(size=(1,), dtype=torch.long)
    out1 = model(x, t)
    t = torch.ones(size=(1,), dtype=torch.long)
    out2 = model(x, t)
    assert not np.allclose(out1.detach().numpy(), out2.detach().numpy())


test_conditioned_resnet()

In conditional flow matching, our objective is to learn a vector field $f_\theta(\mathbf{x}, t)$, parameterized by a neural network, that matches a known target vector field $f(\mathbf{x}, \mathbf{x}_1, t)$ at each point along a path connecting the base distribution and the data distribution. The training objective is defined as:

$$
\min_\theta\, \mathbb{E}_{t \sim U[0, 1]}\, \mathbb{E}_{\mathbf{x}_1 \sim p(\mathbf{x}_1)} \mathbb{E}_{\mathbf{x} \sim p_t(\mathbf{x} | \mathbf{x}_1)} \left[ \left\| f(\mathbf{x}, \mathbf{x}_1, t) - f_\theta(\mathbf{x}, t) \right\|^2 \right].
$$

In this task, we consider the **optimal transport conditional vector field**, defined by:
$$
f(\mathbf{x}, \mathbf{x}_1, t) = \frac{d\mathbf{x}}{dt} = \frac{\mathbf{x}_1 - (1 - \sigma_{\text{min}})\mathbf{x}}{1 - (1 - \sigma_{\text{min}})t},
$$
which means that $\mathbf{x}_t$ interpolates linearly between the noise sample $\mathbf{x}_0$ (at $t = 0$) and the slightly noised data point $\mathbf{x}_1$ (at $t = 1$):
$$
\mathbf{x}_t = t \mathbf{x}_1 + (1 - (1 - \sigma_{\text{min}})t) \mathbf{x}_0.
$$
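
The two formulas above can be checked against each other numerically. The sketch below works with scalars in plain Python; the helper names `ot_point` and `ot_vector_field` and the sample values are assumptions made for this illustration only. It compares a finite-difference derivative of $\mathbf{x}_t$ with the conditional vector field evaluated at $\mathbf{x}_t$.

```python
def ot_point(x0, x1, t, sigma_min=1e-4):
    """Point on the conditional path: x_t = t * x1 + (1 - (1 - sigma_min) * t) * x0."""
    return t * x1 + (1.0 - (1.0 - sigma_min) * t) * x0


def ot_vector_field(x, x1, t, sigma_min=1e-4):
    """Conditional vector field f(x, x1, t) from the formula above."""
    return (x1 - (1.0 - sigma_min) * x) / (1.0 - (1.0 - sigma_min) * t)


x0, x1, t, h = -0.7, 1.3, 0.4, 1e-6  # arbitrary scalar noise, data, time, step

x_t = ot_point(x0, x1, t)
# Central finite difference of the path with respect to t.
finite_diff = (ot_point(x0, x1, t + h) - ot_point(x0, x1, t - h)) / (2.0 * h)
field = ot_vector_field(x_t, x1, t)

print(f"d x_t / dt (finite difference): {finite_diff:.6f}")
print(f"f(x_t, x1, t):                  {field:.6f}")
```

Both numbers coincide and do not depend on $t$: along a fixed pair $(\mathbf{x}_0, \mathbf{x}_1)$ the path is a straight line traversed at constant velocity.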

Now, let's define the architecture of the Flow Matching model.

In [None]:
class FlowMatchingModel(BaseModel):
    def __init__(self, in_channels: int, out_channels: int, n_filters: int, n_blocks: int):
        super().__init__()
        self.model = ConditionedSimpleResnet(
            in_channels, out_channels, n_filters, n_blocks
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        return self.model(x, t)

    def loss(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        sigma_min = 1e-4
        # ====
        # your code
        # 1) sample time uniformly from 0 to 1
        # 2) calculate noised data and optimal flow
        # 3) predict flow using model
        # 4) calculate loss
        
        # ====
        return {'total_loss': loss}

    def odefunc(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return self(x, torch.full(x.shape[:1], t, device=self.device))

    @torch.no_grad()
    def sample(self, n: int) -> np.ndarray:
        # read carefully the sampling process
        z = torch.randn(n, 1, 28, 28, device=self.device)  # Start with noise

        # ====
        # your code
        # use odeint to sample from model
        # here we don't need to use adjoint because we use odeint only for sampling!
        
        # ====
        samples = states[1]
        return samples.cpu().numpy()

In [None]:
# ====
# your code
# choose these parameters
BATCH_SIZE = 
LR = 
EPOCHS = 
N_FILTERS = 
N_BLOCKS = 
# ====

model = FlowMatchingModel(
    in_channels=1, 
    out_channels=1, 
    n_filters=N_FILTERS, 
    n_blocks=N_BLOCKS
)

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE)

# choose any optimizer/scheduler you want
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

train_model(
    model,
    train_loader,
    test_loader,
    epochs=EPOCHS,
    optimizer=optimizer,
    device=DEVICE,
    n_samples=16,
    visualize_samples=True
)