Copyright 2024 Thibaut Issenhuth, Ludovic Dos Santos, Jean-Yves Franceschi, Alain Rakotomamonjy

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import torch
import matplotlib.pyplot as plt
import math

import numpy as np
from matplotlib.ticker import ScalarFormatter
from cycler import cycler

from matplotlib import rc
# Render all figure text through LaTeX, in an 11 pt Computer Modern serif font.
rc('font', family='serif', serif='Computer Modern Roman', size=11)
rc('text', usetex=True)

In [None]:
def return_gaussians(size, means=((1, -1), (1, 1)), std=0.01):
    """Sample ``size`` 2-D points from an equal-weight mixture of Gaussians.

    Points are returned grouped by component (all samples of ``means[0]``
    first, then ``means[1]``, ...); they are not shuffled.

    Parameters
    ----------
    size : int
        Total number of points. When ``size`` is not divisible by
        ``len(means)``, the first ``size % len(means)`` components receive one
        extra point, so exactly ``size`` points are always returned.
    means : sequence of (float, float), default=((1, -1), (1, 1))
        Component centres.
    std : float or (float, float), default=0.01
        Standard deviation of every component (per axis when a pair is given).

    Returns
    -------
    Tensor of shape (size, 2)

    Raises
    ------
    ValueError
        If ``means`` is empty.
    """
    if len(means) == 0:
        raise ValueError("means must contain at least one component")
    base, extra = divmod(size, len(means))
    scale = torch.as_tensor(std, dtype=torch.float32)
    chunks = []
    for k, mean in enumerate(means):
        # Spread the remainder over the first components instead of dropping it.
        n_k = base + (1 if k < extra else 0)
        offset = torch.as_tensor(mean, dtype=torch.float32)
        chunks.append(torch.randn((n_k, 2)) * scale + offset)
    return torch.cat(chunks, dim=0)

def return_input_distrib(size, means=((-1, 0),), std=(0.01, 0.5)):
    """Sample ``size`` 2-D source points from equal-weight axis-aligned Gaussians.

    Points are returned grouped by component, in the order of ``means``.

    Parameters
    ----------
    size : int
        Total number of points. When ``size`` is not divisible by
        ``len(means)``, the first ``size % len(means)`` components receive one
        extra point, so exactly ``size`` points are always returned.
    means : sequence of (float, float), default=((-1, 0),)
        Component centres.
    std : (float, float) or float, default=(0.01, 0.5)
        Per-axis standard deviation (x, y) shared by all components.

    Returns
    -------
    Tensor of shape (size, 2)

    Raises
    ------
    ValueError
        If ``means`` is empty.
    """
    if len(means) == 0:
        raise ValueError("means must contain at least one component")
    base, extra = divmod(size, len(means))
    scale = torch.as_tensor(std, dtype=torch.float32)
    chunks = []
    for k, mean in enumerate(means):
        # Spread the remainder over the first components instead of dropping it.
        n_k = base + (1 if k < extra else 0)
        offset = torch.as_tensor(mean, dtype=torch.float32)
        chunks.append(torch.randn((n_k, 2)) * scale + offset)
    return torch.cat(chunks, dim=0)

def construct_test_batch(sigmas, device):
    """Build a deterministic batch covering every (data, noise) pairing at each noise level.

    Two data anchors ``x_A = (1, 1)``, ``x_B = (1, -1)`` and two noise anchors
    ``z_A = (-1, 1)``, ``z_B = (-1, -1)`` give four pairings. For the ``i``-th
    entry of ``sigmas``, rows ``4*i .. 4*i+3`` hold, in order, the pairings
    (x_A, z_A), (x_A, z_B), (x_B, z_A), (x_B, z_B).

    Parameters
    ----------
    sigmas : sequence of float or 0-d tensors
        Noise levels to evaluate.
    device : torch.device or str
        Device on which all returned tensors are allocated.

    Returns
    -------
    x_test : Tensor of shape (4 * len(sigmas), 2)
        Data endpoint of each row.
    x_i_test : Tensor of shape (4 * len(sigmas), 2)
        Interpolated point ``(1 - sigma) * x + sigma * z``.
    z_test : Tensor of shape (4 * len(sigmas), 2)
        Noise endpoint of each row.
    sigmas_test : Tensor of shape (4 * len(sigmas),)
        Noise level of each row.
    """
    x_A = torch.tensor([1., 1.], dtype=torch.float32, device=device)
    x_B = torch.tensor([1., -1.], dtype=torch.float32, device=device)
    z_A = torch.tensor([-1., 1.], dtype=torch.float32, device=device)
    z_B = torch.tensor([-1., -1.], dtype=torch.float32, device=device)
    pairs = ((x_A, z_A), (x_A, z_B), (x_B, z_A), (x_B, z_B))

    # Exactly one row per (sigma, pairing); no unused padding rows.
    n_points = len(sigmas) * len(pairs)
    x_test = torch.zeros((n_points, 2), device=device)
    z_test = torch.zeros((n_points, 2), device=device)
    x_i_test = torch.zeros((n_points, 2), device=device)
    sigmas_test = torch.zeros((n_points,), device=device)

    row = 0
    for sigma in sigmas:
        sigma_i = torch.as_tensor(sigma, dtype=torch.float32, device=device)
        for x, z in pairs:
            x_test[row] = x
            z_test[row] = z
            x_i_test[row] = (1 - sigma_i) * x + sigma_i * z
            # Every row gets its own sigma (the old slice left most entries at 0).
            sigmas_test[row] = sigma_i
            row += 1
    return x_test, x_i_test, z_test, sigmas_test


In [None]:
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalars.

    Each scalar is multiplied by ``num_channels // 2`` frequencies
    ``(1 / max_positions) ** (k / denom)`` for ``k = 0 .. num_channels//2 - 1``.
    Here ``denom`` is ``num_channels // 2``, or one less when ``endpoint`` is
    true, so that the last frequency is then exactly ``1 / max_positions``.
    The output concatenates the cosines and then the sines of these angles
    along dim 1.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        half = self.num_channels // 2
        denom = half - (1 if self.endpoint else 0)
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device) / denom
        freqs = (1 / self.max_positions) ** exponents
        # Outer product: one row of angles per input scalar.
        angles = torch.outer(x, freqs.to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)

class MLP(torch.nn.Module):
    """Noise-conditioned MLP mapping ``(x, t)`` to a ``data_dim`` output.

    ``x`` is projected to ``hidden_dim`` features (Linear + GELU). ``t`` is
    embedded with a ``PositionalEmbedding(hidden_dim)``. The two feature sets
    are concatenated (``2 * hidden_dim`` wide) and passed through a
    3-layer GELU MLP back to ``data_dim``.
    """

    def __init__(self, data_dim=2, hidden_dim=256):
        super().__init__()
        # Submodules are built in this fixed order; the Linear layers draw their
        # initial weights from the global RNG at construction time.
        self.map_noise = PositionalEmbedding(hidden_dim)
        self.net_0 = torch.nn.Sequential(
            torch.nn.Linear(data_dim, hidden_dim),
            torch.nn.GELU(),
        )
        self.net_1 = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, data_dim),
        )

    def forward(self, x, t):
        # assumes t is 1-D with one entry per row of x — TODO confirm at call sites
        features = torch.cat((self.net_0(x), self.map_noise(t)), dim=1)
        return self.net_1(features)

class ConsistencyModel(torch.nn.Module):
    """Consistency-model denoiser built around an ``MLP`` backbone.

    ``forward`` returns ``c_skip * x + c_out * F(c_in * x, log(sigma) / 4)``,
    where

    * ``c_skip = sigma_data**2 / ((sigma - sigma_min)**2 + sigma_data**2)``
    * ``c_out  = sigma_data * (sigma - sigma_min) / sqrt(sigma_data**2 + sigma**2)``
    * ``c_in   = 1 / sqrt(sigma_data**2 + sigma**2)``

    At ``sigma == sigma_min`` this gives ``c_skip = 1`` and ``c_out = 0``, so the
    output equals ``x``. ``sigma_max`` is stored but not read by ``forward``.
    """

    def __init__(self, sigma_min, sigma_max, sigma_data=0.5, hidden_dim=256):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = MLP(hidden_dim=hidden_dim)

    def forward(self, x, sigma):
        # sigma: 1-D tensor, one noise level per row; reshaped to (N, 1) to broadcast over x.
        s = sigma.unsqueeze(1)
        sd2 = self.sigma_data ** 2
        shifted = s - self.sigma_min
        c_skip = sd2 / (shifted ** 2 + sd2)
        c_out = (self.sigma_data * shifted) / (sd2 + s ** 2) ** 0.5
        c_in = 1 / (sd2 + s ** 2).sqrt()
        c_noise = s.log() / 4

        net_out = self.model(c_in * x, c_noise.flatten())
        return c_skip * x + c_out * net_out.to(torch.float32)


In [None]:
def get_sigmas_karras(num_timesteps, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Return ``num_timesteps`` noise levels following the Karras et al. (2022) schedule.

    Levels are spaced uniformly in ``sigma ** (1 / rho)`` and returned in
    increasing order, from ``sigma_min`` (first) to ``sigma_max`` (last).
    """
    inv_rho = 1.0 / rho
    lo = sigma_min ** inv_rho
    hi = sigma_max ** inv_rho
    # max(..., 1) keeps the denominator non-zero when a single level is requested.
    ramp = torch.arange(num_timesteps, device=device) / max(num_timesteps - 1, 1)
    return (lo + ramp * (hi - lo)) ** rho

def improved_timesteps_schedule(current_training_step, total_training_steps, initial_timesteps = 10, final_timesteps = 1280):
    """Implements the improved timestep discretization schedule.

    The discretization size starts at ``initial_timesteps``. It doubles every
    ``K' = floor(total_training_steps / (log2(floor(final_timesteps / initial_timesteps)) + 1))``
    training steps until it is capped at ``final_timesteps``. One is then added
    to the capped value, presumably so that the result counts noise levels
    (interval endpoints) rather than intervals.

    Parameters
    ----------
    current_training_step : int
        Current step in the training loop.
    total_training_steps : int
        Total number of steps the model will be trained for.
    initial_timesteps : int, default=10
        Timesteps at the start of training.
    final_timesteps : int, default=1280
        Timesteps at the end of training.

    Returns
    -------
    int
        ``min(initial_timesteps * 2 ** (current_training_step // K'), final_timesteps) + 1``.

    Notes
    -----
    ``K'`` is clamped to at least 1, so short runs do not divide by zero.
    ``floor(final_timesteps / initial_timesteps)`` is also clamped to at least 1,
    so ``final_timesteps < initial_timesteps`` yields a single stage instead of a
    math domain error.

    References
    ----------
    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)
    """
    n_stages = math.log2(max(math.floor(final_timesteps / initial_timesteps), 1)) + 1
    total_training_steps_prime = max(math.floor(total_training_steps / n_stages), 1)
    # Integer power avoids float results (and float overflow) from math.pow.
    num_timesteps = initial_timesteps * 2 ** (current_training_step // total_training_steps_prime)
    return int(min(num_timesteps, final_timesteps)) + 1

def lognormal_timestep_distribution(num_samples, sigmas, mean = -1.1, std = 2.0):
    """Draw interval indices with probability proportional to lognormal mass.

    Interval ``i`` is ``[sigmas[i], sigmas[i+1]]``. Its weight is
    ``erf((log sigmas[i+1] - mean) / (std*sqrt(2))) - erf((log sigmas[i] - mean) / (std*sqrt(2)))``,
    which is proportional to the probability that a lognormal variable with
    log-mean ``mean`` and log-std ``std`` falls in that interval. ``sigmas`` is
    expected to be increasing, so that all weights are non-negative, as
    ``torch.multinomial`` requires.

    Parameters
    ----------
    num_samples : int
        Number of indices to draw (with replacement).
    sigmas : Tensor
        Noise levels delimiting the intervals.
    mean : float, default=-1.1
        Mean of log(sigma).
    std : float, default=2.0
        Standard deviation of log(sigma).

    Returns
    -------
    Tensor
        ``num_samples`` indices in ``[0, len(sigmas) - 2]``.

    References
    ----------
    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)
    """
    scale = std * math.sqrt(2)

    def _standardised_erf(levels):
        return torch.erf((torch.log(levels) - mean) / scale)

    interval_mass = _standardised_erf(sigmas[1:]) - _standardised_erf(sigmas[:-1])
    return torch.multinomial(interval_mass, num_samples, replacement=True)

def improved_loss_weighting(sigmas):
    """Return the per-interval consistency-loss weights ``1 / (sigmas[i+1] - sigmas[i])``.

    Parameters
    ----------
    sigmas : Tensor
        Noise levels; the result has one fewer element.

    Returns
    -------
    Tensor
        Reciprocal width of each consecutive interval.

    References
    ----------
    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)
    """
    interval_widths = sigmas[1:] - sigmas[:-1]
    return 1 / interval_widths

def draw_transport_cost_per_timestep(batch_data_preds, batch_data, batch_x_t, sigmas_i, sigmas, steps, training_step):
    """Plot pointwise and mean quadratic transport costs against the noise level and save the figure.

    For each sample ``i`` two squared Euclidean costs to ``batch_x_t[i]`` are computed:

    * GC: ``||batch_data_preds[i] - batch_x_t[i]||^2``
    * IC: ``||batch_data[i] - batch_x_t[i]||^2``

    Every cost is scattered at ``sigmas_i[i]``. Costs are also averaged per
    index ``steps[i]`` into ``sigmas`` and drawn as one line per coupling. The
    last entry of ``sigmas`` is left out of the lines. The figure is saved to
    ``viz/transport_cost_1to2_<training_step>.pdf`` (creating ``viz/`` if
    needed), shown, and closed.

    Parameters
    ----------
    batch_data_preds, batch_data, batch_x_t : Tensor with one row per sample
        Predicted data, real data and intermediate points. Only the first
        ``len(batch_data)`` rows of each argument are used.
    sigmas_i : Tensor
        Noise level of each sample (x coordinate of the scatter points).
    sigmas : Tensor
        Noise grid over which the mean-cost lines are drawn.
    steps : Tensor of integer indices
        Index into ``sigmas`` for each sample.
    training_step : int
        Only used in the output file name.
    """
    import os

    color_std = 'red'
    color_GI = 'dodgerblue'
    n = len(batch_data)
    n_levels = len(sigmas)
    level_idx = torch.as_tensor(steps[:n]).long().cpu()
    sigma_points = sigmas_i[:n].detach().cpu().numpy()

    def _pointwise_cost(points):
        # Squared Euclidean distance of each row to its intermediate point.
        diff = (points[:n] - batch_x_t[:n]).detach().float()
        return (diff ** 2).flatten(start_dim=1).sum(dim=1).cpu()

    def _mean_per_level(costs):
        # Vectorised per-level average; the 1e-5 keeps empty levels at 0 instead of dividing by zero.
        totals = torch.zeros(n_levels).index_add_(0, level_idx, costs)
        counts = torch.zeros(n_levels).index_add_(0, level_idx, torch.ones(n))
        return (totals / (counts + 1e-5)).tolist()

    cost_gc = _pointwise_cost(batch_data_preds)
    cost_ic = _pointwise_cost(batch_data)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 3.5))
    # One scatter call per coupling (rather than one per point): same markers, colours and legend entries.
    ax.scatter(sigma_points, cost_gc.numpy(), alpha=1., marker='x', color=color_GI, label="Pointwise cost for GC")
    ax.scatter(sigma_points, cost_ic.numpy(), alpha=1., marker='+', color=color_std, label="Pointwise cost for IC")

    sigma_array = [float(s) for s in sigmas]
    ax.plot(sigma_array[:-1], _mean_per_level(cost_gc)[:-1], color=color_GI, label="Mean cost for GC")
    ax.plot(sigma_array[:-1], _mean_per_level(cost_ic)[:-1], color=color_std, label="Mean cost for IC")

    ax.grid(linestyle='--')
    ax.set_axisbelow(True)
    ax.tick_params(axis="x", direction="in")
    ax.tick_params(axis="y", direction="in")

    ax.set_xlabel(r'Timestep')
    ax.set_ylabel(r'Mean quadratic transport cost')
    ax.set_title(r'Gaussians 1m-2m')

    for axis in [ax.xaxis, ax.yaxis]:
        formatter = ScalarFormatter()
        formatter.set_scientific(False)
        axis.set_major_formatter(formatter)

    handles, legend_labels = ax.get_legend_handles_labels()
    fig.legend(handles, legend_labels, ncol=1, bbox_to_anchor=(0.6,0.9))
    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('viz', exist_ok=True)
    fig.savefig('viz/transport_cost_1to2_'+str(training_step)+'.pdf', bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls during training do not accumulate open figures.
    plt.close(fig)

def draw_arrows(batch_z, batch_preds, color):
    """Draw an arrow from each ``batch_z[i]`` to ``batch_preds[i]`` on the current matplotlib axes.

    Only the first two coordinates of each row are used. Arrows are drawn with
    ``color`` as both face and edge colour, alpha 0.4, width 0.005, and the
    head included in the arrow length.
    """
    arrow_kwargs = dict(alpha=0.4, length_includes_head=True, facecolor=color, edgecolor=color, width=0.005)
    for row in range(batch_z.shape[0]):
        start_x, start_y = batch_z[row, 0], batch_z[row, 1]
        delta_x = batch_preds[row, 0] - start_x
        delta_y = batch_preds[row, 1] - start_y
        plt.arrow(
            start_x.cpu().numpy(),
            start_y.cpu().numpy(),
            delta_x.cpu().numpy(),
            delta_y.cpu().numpy(),
            **arrow_kwargs,
        )


In [None]:
from ema_pytorch import EMA

import os

# ---- Experiment configuration ----
len_data = 10000
batch_size = 256
training_steps = 10000
lr = 0.00005
s0 = 30  # initial_timesteps passed to improved_timesteps_schedule
s1 = 30  # final_timesteps passed to improved_timesteps_schedule
rho = 3
sigma_min = 0.001
sigma_max = 1
hidden_dim = 256
# True: train on generator-induced (GC) points built from the EMA model's prediction;
# False: train on independent-coupling (IC) interpolations between data and noise.
generator_induced_trajectory = True
# Fall back to CPU when CUDA is not available.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print_freq = 1000
model = ConsistencyModel(sigma_min, sigma_max, hidden_dim=hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_ema = EMA(
    model,
    beta = 0.9999,              # exponential moving average factor
    update_after_step = 0,      # only after this number of .update() calls will it start updating
    update_every = 2,           # how often to actually update, to save on compute (every 2nd .update() call)
)
datapoints = return_gaussians(len_data)
dataset = torch.utils.data.TensorDataset(datapoints)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Figures are saved under viz/; savefig does not create the directory itself.
os.makedirs('viz', exist_ok=True)


def _interpolate(endpoint, noise, sigma):
    """Return ``(1 - sigma) * endpoint + sigma * noise`` with one ``sigma`` per row."""
    s = sigma.view(sigma.shape[0], 1)
    return (1 - s) * endpoint + s * noise


def _plot_generations(ema_model, data, batch_z, x_i_test, sigmas_test, generator_induced, training_step):
    """Plot data, noise, intermediate points and one-step EMA generations, then save to viz/.

    ``x_i_test`` must be the independent-coupling intermediates
    ``(1 - sigma) * data + sigma * batch_z`` for the same rows as ``data``.
    Call under ``torch.no_grad()``, since tensors are converted with ``.numpy()``.
    """
    x_test, z_test = data, batch_z
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 3.5))
    data_pred_test = ema_model(x_i_test, sigmas_test)
    if generator_induced:
        s = sigmas_test.unsqueeze(1)
        tilde_x_test = (1 - s) * data_pred_test + s * z_test

        draw_arrows(data_pred_test, tilde_x_test, 'dodgerblue')
        draw_arrows(x_i_test, data_pred_test, 'purple')
        draw_arrows(x_test, x_i_test, 'red')

        ax.scatter(x_i_test[:,0].cpu().numpy(), x_i_test[:,1].cpu().numpy(), marker='*', alpha=0.4, label='intermediate (IC)', color='red')
        ax.scatter(tilde_x_test[:,0].cpu().numpy(), tilde_x_test[:,1].cpu().numpy(), marker='<', alpha=0.4, label='intermediate (GC)', color='dodgerblue')
        ax.scatter(data_pred_test[:,0].cpu().numpy(), data_pred_test[:,1].cpu().numpy(), marker='>', alpha=0.4, label='generated (IC)', color='purple')
    else:
        ax.scatter(x_i_test[:,0].cpu().numpy(), x_i_test[:,1].cpu().numpy(), alpha=0.4, label='intermediate from independent coupling', color='tab:blue')
        ax.scatter(data_pred_test[:,0].cpu().numpy(), data_pred_test[:,1].cpu().numpy(), alpha=0.4, label='generated from independent coupling', color='tab:green')
        draw_arrows(x_test, x_i_test, 'tab:gray')
        draw_arrows(x_i_test, data_pred_test, 'tab:pink')
    ax.scatter(data[:32,0].cpu().numpy(), data[:32,1].cpu().numpy(), marker='o', alpha=0.4, label='data', color='darkcyan')
    ax.scatter(batch_z[:16,0].cpu().numpy(), batch_z[:16,1].cpu().numpy(), marker='s', alpha=0.4, label='noise', color='blue')
    ax.scatter(batch_z[-16:,0].cpu().numpy(), batch_z[-16:,1].cpu().numpy(), marker='s', alpha=0.4, color='blue')

    ax.grid(linestyle='--')
    ax.set_axisbelow(True)
    ax.tick_params(axis="x", direction="in")
    ax.tick_params(axis="y", direction="in")
    ax.set_ylim([-1.6, 1.2])
    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$y$')
    for axis in [ax.xaxis, ax.yaxis]:
        formatter = ScalarFormatter()
        formatter.set_scientific(False)
        axis.set_major_formatter(formatter)

    handles, legend_labels = ax.get_legend_handles_labels()
    # Move the last legend entry ('noise') to the front.
    handles = [handles[-1]] + handles[:-1]
    legend_labels = [legend_labels[-1]] + legend_labels[:-1]
    fig.legend(handles, legend_labels, ncol=2, bbox_to_anchor=(0.9,0.05), handletextpad=0.1)
    ax.set_title(r'Gaussians 1m-2m')
    fig.tight_layout()
    fig.savefig('viz/generations_'+str(training_step)+'.pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)


###### Training Loop #########
current_training_step = 0
# `<` (instead of waiting for `==`) also terminates when training_steps <= 0.
while current_training_step < training_steps:
    for batch in loader:
        data = batch[0].to(device)
        batch_z = return_input_distrib(data.shape[0]).to(device)
        current_n_step = improved_timesteps_schedule(current_training_step, training_steps,
                                                     initial_timesteps=s0, final_timesteps=s1)
        sigmas = get_sigmas_karras(current_n_step, sigma_min, sigma_max, rho=rho)
        steps = lognormal_timestep_distribution(len(data), sigmas)
        loss_weights = improved_loss_weighting(sigmas)[steps].to(device)
        sigmas_i = sigmas[steps].to(device)
        sigmas_ip1 = sigmas[steps + 1].to(device)

        # Independent-coupling (IC) points between real data and noise.
        # The IC point at sigma_i is kept separately for visualisation.
        batch_z_i_ic = _interpolate(data, batch_z, sigmas_i)
        batch_z_i = batch_z_i_ic
        batch_z_ip1 = _interpolate(data, batch_z, sigmas_ip1)
        if generator_induced_trajectory:
            # Generator-induced (GC) points: replace the data endpoint by the EMA prediction at the IC point.
            with torch.no_grad():
                data_pred = model_ema(batch_z_i_ic, sigmas_i)
            batch_z_i = _interpolate(data_pred, batch_z, sigmas_i)
            batch_z_ip1 = _interpolate(data_pred, batch_z, sigmas_ip1)

        optimizer.zero_grad()
        # The prediction at the lower noise level is the (gradient-free) target.
        with torch.no_grad():
            pred_z_i = model(batch_z_i, sigmas_i)
        pred_z_ip1 = model(batch_z_ip1, sigmas_ip1)
        loss = ((pred_z_ip1 - pred_z_i) ** 2).sum(dim=1)
        loss = (loss_weights * loss).mean()
        loss.backward()
        optimizer.step()
        model_ema.update()

        if current_training_step % print_freq == 0:
            print('step : ', current_training_step)
            print('loss : ', loss.item())
            with torch.no_grad():
                if generator_induced_trajectory:
                    draw_transport_cost_per_timestep(data_pred, data, batch_z_ip1, sigmas_i, sigmas, steps, current_training_step)
                _plot_generations(model_ema, data, batch_z, batch_z_i_ic, sigmas_i,
                                  generator_induced_trajectory, current_training_step)

        current_training_step += 1
        if current_training_step >= training_steps:
            break