In [1]:
import torch
from rectified_flow.rectified_flow import AffineInterp, match_time_dim_with_data

device = torch.device("cuda")

# Test AffineInterp Class

In [2]:
batch_size = 500

def test_affine_interp():
    interp = AffineInterp('ddim') # Change this into ["ddim", "straight", "sin"]
    t = torch.rand((batch_size,), device=device)
    X_0 = torch.rand((batch_size, 3, 4, 2, 3), device=device)
    X_1 = torch.rand((batch_size, 3, 4, 2, 3), device=device)
    X_t, dot_X_t = interp(X_0, X_1, t)

    interp.solve(t, xt=X_t, dot_xt=dot_X_t)
    print(torch.max(torch.abs(interp.x0 - X_0)), torch.mean(torch.abs(interp.x1 - X_1)))
    
    interp.solve(t, x0=X_0, x1=X_1)
    print(torch.max(torch.abs(interp.xt - X_t)), torch.mean(torch.abs(interp.dot_xt - dot_X_t)))
    
    interp.solve(t, xt=X_t, x1=X_1)
    print(torch.max(torch.abs(interp.x0 - X_0)), torch.mean(torch.abs(interp.dot_xt - dot_X_t)))
    
    interp.solve(t, x0=X_0, xt=X_t)
    print(torch.max(torch.abs(interp.x1 - X_1)), torch.mean(torch.abs(interp.dot_xt - dot_X_t)))
    
    interp.solve(t, x0=X_0, dot_xt=dot_X_t)
    print(torch.max(torch.abs(interp.x1 - X_1)), torch.mean(torch.abs(interp.xt - X_t)))
    
    interp.solve(t, x1=X_1, dot_xt=dot_X_t)
    print(torch.max(torch.abs(interp.x0 - X_0)), torch.mean(torch.abs(interp.xt - X_t)))
    
    print(type(interp.x0), type(interp.x1), type(interp.xt), type(interp.dot_xt))

test_affine_interp()

tensor(2.3842e-07, device='cuda:0') tensor(2.3848e-08, device='cuda:0')
tensor(0., device='cuda:0') tensor(0., device='cuda:0')
tensor(3.9153e-06, device='cuda:0') tensor(2.0821e-07, device='cuda:0')
tensor(7.2122e-06, device='cuda:0') tensor(1.4193e-07, device='cuda:0')
tensor(4.1127e-06, device='cuda:0') tensor(5.3170e-08, device='cuda:0')
tensor(7.6294e-06, device='cuda:0') tensor(4.0390e-07, device='cuda:0')
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>


# Test Rectified Flow Functions

In [3]:
from models.dit import DiT, DiTConfig
from rectified_flow.rectified_flow import RectifiedFlow

DiT_reshaper_config = DiTConfig(
        input_size = 32,
        patch_size = 2,
        in_channels = 3,
        out_channels = 3,
        hidden_size = 128,
        depth = 5,
        num_heads = 4,
        mlp_ratio = 2,
        num_classes = 0,
        use_long_skip = False,
        final_conv = False,
    )

In [4]:
dit_toy = DiT(DiT_reshaper_config).to(device)

rf_func = RectifiedFlow(
    flow_model=dit_toy,
    interp_func="straight",
    time_dist="uniform",
    time_weight="uniform",
    criterion="mse",
    device=device,
    dtype=torch.float32,
)

In [5]:
batch_size = 64

X_0 = torch.zeros((batch_size, 3, 32, 32), device=device)
X_1 = torch.ones((batch_size, 3, 32, 32), device=device)
t = torch.rand((batch_size,), device=device)

with torch.no_grad():
	X_t, dot_X_t = rf_func.get_interpolation(X_0, X_1, t)
	# print(f"X_t: {X_t}, dot_X_t: {dot_X_t}")
	print(f"X_t device: {X_t.device}, dot_X_t device: {dot_X_t.device}, t device: {t.device}")
	velocity = rf_func.get_velocity(X_t, t)
	print(f"velocity: {velocity.shape}")

	loss = rf_func.get_loss(X_0, X_1)
	print(f"loss: {loss}")



X_t device: cuda:0, dot_X_t device: cuda:0, t device: cuda:0
X device: cuda:0, t device: cuda:0
velocity: torch.Size([64, 3, 32, 32])
X device: cuda:0, t device: cuda:0
X_t shape: torch.Size([64, 3, 32, 32]), dot_Xt shape: torch.Size([64, 3, 32, 32]), v_t shape: torch.Size([64, 3, 32, 32]), wts shape: torch.Size([64])
loss: 1.0
