TODO

Flow matching

Linear Gaussian Flow

In [53]:
import lightning as l
from torch import optim, nn
import torch
import torch.nn.functional as F
from lightning.pytorch.loggers import TensorBoardLogger
import ot
import numpy as np


class Flow(l.LightningModule):
    """Flow-matching model on the straight path between noise and data.

    The wrapped network takes a point concatenated with a time value and is
    regressed (MSE) onto ``x - noise``, the velocity of the interpolation
    ``xt = (1 - t) * noise + t * x``.

    Args:
        net: Optional network mapping ``(batch, dim + 1)`` inputs to
            ``(batch, dim)`` outputs. When ``None`` a SiLU MLP is built.
        dim: Dimensionality of the data. Also used for the noise drawn in
            :meth:`generate`.
        h: Hidden width of the default MLP.
    """

    def __init__(self, net=None, dim: int = 2, h: int = 64):
        super().__init__()
        if net is None:
            net = nn.Sequential(
                nn.Linear(dim + 1, h), nn.SiLU(),  # +1 input feature for time
                nn.Linear(h, h), nn.SiLU(),
                nn.Linear(h, h), nn.SiLU(),
                nn.Linear(h, dim),
            )
        # The network is trained to predict the flow vector field.
        self.net = net
        # Remembered so that generate() samples noise of the right width
        # instead of assuming 2-D data.
        self.dim = dim
        self.criterion = F.mse_loss

    def compute_noisy_sample(self, x, t, noise):
        """Interpolate linearly between ``noise`` (t=0) and data ``x`` (t=1)."""
        return (1 - t) * noise + t * x

    def training_step(self, x):
        """Compute and log the flow-matching loss for one batch of data.

        Args:
            x: Batch of data points; the first dimension is the batch size.

        Returns:
            The MSE between the predicted flow and ``x - noise``.
        """
        # One uniformly sampled time per batch element, created on x's device.
        ts = torch.rand(x.size(0), 1, device=x.device)
        noise = torch.randn_like(x)
        xt = self.compute_noisy_sample(x, ts, noise)
        flow_pred = self.net(torch.cat([xt, ts], dim=-1))
        # Target velocity of the straight path from noise to data.
        flow = x - noise
        loss = self.criterion(flow_pred, flow)

        self.log('train_loss', loss)
        return loss

    @torch.no_grad()
    def validation_step(self, x):
        """Log the earth mover's distance between ``x`` and generated samples.

        As many samples as there are rows in ``x`` are generated, so both
        point clouds share the same uniform weights.
        """
        n = x.size(0)
        generated_samples = self.generate(sample_steps=100, num_samples=n)

        # Pairwise Euclidean costs between real and generated points.
        cost_matrix = ot.dist(
            x.cpu().numpy(), generated_samples.cpu().numpy(), metric='euclidean'
        )

        # Uniform weights for both distributions.
        weights = np.ones(n) / n

        emd = ot.emd2(weights, weights, cost_matrix)
        self.log('val_emd', emd)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-2."""
        return optim.Adam(self.parameters(), lr=1e-2)

    @torch.no_grad()
    def generate(self, sample_steps: int = 100, num_samples: int = 1):
        """Sample by integrating the learned flow from t=0 to t=1.

        Standard-normal noise is pushed through the vector field using the
        explicit midpoint method on a uniform grid of ``sample_steps`` time
        points (i.e. ``sample_steps - 1`` integration steps).

        Args:
            sample_steps: Number of grid points in ``[0, 1]``; must be >= 2.
            num_samples: Number of samples to generate.

        Returns:
            Tensor of shape ``(num_samples, dim)`` on the module's device.

        Raises:
            ValueError: If ``sample_steps`` is smaller than 2.
        """
        if sample_steps < 2:
            raise ValueError("sample_steps must be at least 2")

        xt = torch.randn(num_samples, self.dim, device=self.device)
        ts = torch.linspace(0, 1, sample_steps, device=self.device)
        dt = ts[1] - ts[0]
        # Column of ones reused to broadcast the scalar time over the batch.
        ones = torch.ones_like(xt[:, 0:1])
        for t in ts[:-1]:
            flow = self.net(torch.cat([xt, t * ones], dim=-1))
            # Half step to the midpoint, then a full step using its slope.
            x_mid = xt + flow * dt / 2
            mid_t = t + dt / 2
            xt = xt + dt * self.net(torch.cat([x_mid, mid_t * ones], dim=-1))
        return xt

In [None]:
# 2d data generation. with make moon
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

# Draw 1000 noisy two-moons points and their labels, converting both arrays
# to float32 tensors in one pass.
X, y = (
    torch.as_tensor(arr, dtype=torch.float32)
    for arr in make_moons(n_samples=1000, noise=0.1)
)

# Scatter the points, coloured by moon label.
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()


In [None]:
# Load the TensorBoard notebook extension and open it on the tb_logs directory
%load_ext tensorboard
%tensorboard --logdir tb_logs

In [None]:
import os
import matplotlib.pyplot as plt
from lightning.pytorch.callbacks import Callback
import numpy as np
from torch.utils.data import DataLoader, TensorDataset


class PlotSamplesCallback(Callback):
    """Periodically plot, save and log samples generated by the model."""

    def __init__(self, interval=5, save_dir="generated_samples"):
        """
        Callback to plot and save generated samples during training.

        Args:
            interval (int): Frequency of epochs at which to generate plots.
            save_dir (str): Directory where plots and sample data will be saved.
        """
        self.interval = interval
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)  # Ensure the save directory exists

    def on_train_epoch_end(self, trainer, pl_module):
        """Generate 500 samples every ``interval`` epochs and persist them.

        A scatter plot (PNG) and the raw samples (``.npy``) are written to
        ``save_dir``. The plot is also sent to the logger when the logger's
        experiment object offers ``add_image``.

        Args:
            trainer: The running trainer; supplies epoch, step and logger.
            pl_module: The model being trained; must provide ``generate``.
        """
        epoch = trainer.current_epoch
        if epoch % self.interval != 0:
            return

        samples = pl_module.generate(num_samples=500).detach().cpu().numpy()

        # Use an explicit figure so only this figure is saved and closed.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.scatter(samples[:, 0], samples[:, 1], c='r', s=10)
        ax.set_title(f"Generated Samples - Epoch {epoch}")
        ax.set_xlabel("X-axis")
        ax.set_ylabel("Y-axis")

        plot_path = os.path.join(self.save_dir, f"epoch_{epoch}.png")
        fig.savefig(plot_path)
        plt.close(fig)  # Close the plot to free memory

        samples_path = os.path.join(self.save_dir, f"samples_epoch_{epoch}.npy")
        np.save(samples_path, samples)

        # The trainer may run without a logger, or with one whose experiment
        # has no image support; skip logging then instead of crashing.
        experiment = getattr(trainer.logger, "experiment", None)
        if hasattr(experiment, "add_image"):
            experiment.add_image(
                "generated samples",
                plt.imread(plot_path),
                dataformats='HWC',
                global_step=trainer.global_step,
            )



class MoonDataLoader(l.LightningDataModule):
    """Data module serving freshly sampled two-moons points.

    Args:
        batch_size: Number of points per batch for both loaders.
    """

    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Sample 5000 training and 1000 test points as float32 tensors."""
        # Labels are discarded; only the coordinates are used.
        train_points, _ = make_moons(n_samples=5000, noise=0.1)
        test_points, _ = make_moons(n_samples=1000, noise=0.1)

        self.X_train = torch.as_tensor(train_points, dtype=torch.float32)
        self.X_test = torch.as_tensor(test_points, dtype=torch.float32)

    def train_dataloader(self):
        """Return a shuffled loader over the training points."""
        return DataLoader(
            dataset=self.X_train,
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the test points."""
        return DataLoader(
            dataset=self.X_test,
            batch_size=self.batch_size,
            shuffle=False,
        )


# TensorBoard logger writing runs under tb_logs/simple_flow.
logger = TensorBoardLogger(save_dir="tb_logs", name="simple_flow")

# Model, sample-plotting callback (every 5 epochs) and trainer; validation
# also runs every 5 epochs.
model = Flow()
plot_callback = PlotSamplesCallback(interval=5)
trainer = l.Trainer(
    logger=logger,
    callbacks=[plot_callback],
    max_epochs=10000,
    check_val_every_n_epoch=5,
)
trainer.fit(model, datamodule=MoonDataLoader(batch_size=516))

# Draw 1000 samples from the trained model and scatter them.
samples = model.generate(sample_steps=100, num_samples=1000)
plt.scatter(samples[:, 0], samples[:, 1], c="r")
plt.show()

In [30]:
# Sample the training split (5000 points) and then the test split (1000
# points); coordinates and labels are converted to float32 tensors.
X_train, y_train = (
    torch.as_tensor(arr, dtype=torch.float32)
    for arr in make_moons(n_samples=5000, noise=0.1)
)
X_test, y_test = (
    torch.as_tensor(arr, dtype=torch.float32)
    for arr in make_moons(n_samples=1000, noise=0.1)
)

In [None]:
# Shuffled loader over the training points, 64 rows per batch.
dataloader = DataLoader(dataset=X_train, batch_size=64, shuffle=True)

# Pull a single batch from the loader and display it.
x = next(iter(dataloader))
print(x)