In [None]:
# Work from the parent directory (presumably the repository root) so that the
# `rectified_flow` package imported in later cells resolves.
# Guarded so re-running this cell does not keep climbing the directory tree.
import os
from pathlib import Path

if not Path("rectified_flow").is_dir():
    os.chdir("..")
print(os.getcwd())

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

# Prefer the GPU but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG so the random test tensors and model initialisation below are reproducible.
torch.manual_seed(0)

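# Test AffineInterp Class

Given any two of $x_0$, $x_1$, $x_t$, $\dot{x}_t$ at the same times $t$, `AffineInterp.solve` should recover the remaining two. The cell below checks every such pair against the forward interpolation `interp(X_0, X_1, t)`.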

In [None]:
from rectified_flow.flow_components.interpolation_solver import AffineInterp

batch_size = 500

# Every pair of quantities that can be handed to `AffineInterp.solve`, together
# with the two quantities it is expected to reconstruct from that pair.
_AFFINE_SOLVE_CASES = (
    (("xt", "dot_xt"), ("x0", "x1")),
    (("x0", "x1"), ("xt", "dot_xt")),
    (("xt", "x1"), ("x0", "dot_xt")),
    (("x0", "xt"), ("x1", "dot_xt")),
    (("x0", "dot_xt"), ("x1", "xt")),
    (("x1", "dot_xt"), ("x0", "xt")),
)


def test_affine_interp(interp_name="ddim", batch_size=batch_size, atol=1e-4):
    """Round-trip check of ``AffineInterp.solve`` for one interpolation scheme.

    Random ``X_0`` / ``X_1`` are pushed through the forward interpolation
    ``interp(X_0, X_1, t)`` to obtain ``X_t`` / ``dot_X_t``. Then, for each pair
    in ``_AFFINE_SOLVE_CASES``, ``solve`` is given that pair and must reproduce
    the remaining two tensors within ``atol``.

    Args:
        interp_name: scheme name passed to ``AffineInterp`` (e.g. "ddim",
            "straight", "sin").
        batch_size: number of samples and time points drawn (default: the
            cell-level ``batch_size``, 500).
        atol: absolute tolerance used by ``torch.allclose``.

    Raises:
        AssertionError: if any reconstructed tensor differs from its reference.
    """
    interp = AffineInterp(interp_name)
    t = torch.rand((batch_size,), device=device)
    X_0 = torch.rand((batch_size, 3, 4, 2, 3), device=device)
    X_1 = torch.rand((batch_size, 3, 4, 2, 3), device=device)
    X_t, dot_X_t = interp(X_0, X_1, t)
    reference = {"x0": X_0, "x1": X_1, "xt": X_t, "dot_xt": dot_X_t}

    for given, targets in _AFFINE_SOLVE_CASES:
        # Fresh solver per case: `solve` leaves its results as attributes, so
        # reusing one instance could let a value stored by an earlier case
        # satisfy this case's check even if this case never computed it.
        solver = AffineInterp(interp_name)
        solver.solve(t, **{key: reference[key] for key in given})
        for key in targets:
            recovered = getattr(solver, key)
            expected = reference[key]
            max_err = torch.max(torch.abs(recovered - expected)).item()
            print(f"[{interp_name}] solve({', '.join(given)}) -> {key}: max abs err {max_err:.3e}")
            assert torch.allclose(recovered, expected, atol=atol), (
                f"{interp_name}: {key} recovered from {given} is off by {max_err:.3e}"
            )


# Exercise each interpolation scheme this test is meant to cover.
for _interp_name in ("ddim", "straight", "sin"):
    test_affine_interp(_interp_name)

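# Test Rectified Flow Functions

Smoke test of the `RectifiedFlow` wrapper around a small DiT. It covers source and time sampling, interpolation, velocity, loss, and agreement between the two score-function paths.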

In [None]:
from rectified_flow.models.dit import DiT, DiTConfig
from rectified_flow.rectified_flow import RectifiedFlow

# Small toy DiT for 3x32x32 inputs; num_classes=0 presumably means unconditional — confirm in DiTConfig.
DiT_reshaper_config = DiTConfig(
    # Input geometry / patching.
    input_size=32,
    patch_size=2,
    in_channels=3,
    out_channels=3,
    # Transformer size.
    hidden_size=128,
    depth=5,
    num_heads=4,
    mlp_ratio=2,
    # Conditioning and architectural switches.
    num_classes=0,
    use_long_skip=False,
    final_conv=False,
)

In [None]:
# Build the toy DiT backbone and move it to the selected device.
dit_toy = DiT(DiT_reshaper_config)
dit_toy = dit_toy.to(device)

# RectifiedFlow wrapper around the toy model.
rf_func = RectifiedFlow(
    # Model, sample shape and placement. data_shape presumably has to match the
    # DiT config above (in_channels=3, input_size=32) — confirm.
    model=dit_toy,
    data_shape=(3, 32, 32),
    device=device,
    dtype=torch.float32,
    # Interpolation, source distribution and coupling.
    interp="straight",
    source_distribution="normal",
    is_independent_coupling=True,
    # Training-time sampling, weighting and loss.
    train_time_distribution="uniform",
    train_time_weight="uniform",
    criterion="mse",
)

In [None]:
batch_size = 64

# Source batch drawn by rf_func; the "data" batch is a constant all-ones stand-in.
X_0 = rf_func.sample_source_distribution(batch_size)
X_1 = torch.ones(batch_size, *rf_func.data_shape, device=device)
print(f"X_0: {X_0.shape}, X_1: {X_1.shape}")

t = rf_func.sample_train_time(batch_size)
print(f"t: {t.shape}")

with torch.no_grad():
    # Interpolated point and its time derivative at the sampled times.
    X_t, dot_X_t = rf_func.get_interpolation(X_0, X_1, t)
    print(f"X_t device: {X_t.device}, dot_X_t device: {dot_X_t.device}, t device: {t.device}")

    velocity = rf_func.get_velocity(X_t, t)
    print(f"velocity: {velocity.shape}")

    loss = rf_func.get_loss(X_0, X_1)
    print(f"loss: {loss}")

    # Score computed from the explicit velocity vs. rf_func's own helper; the
    # printed gap is presumably expected to be ~0.
    # NOTE(review): dit_toy is never put in eval() mode; if DiT uses dropout,
    # the extra forward pass inside get_score_function may differ — confirm.
    score1 = rf_func.get_score_function_from_velocity(X_t, velocity, t)
    score2 = rf_func.get_score_function(X_t, t)
    print(f"Max diff: {(score1 - score2).abs().max()}")

# Test Solvers

# Test Coupling Dataset