In [None]:
# Work from the directory above the notebook's folder so that the local
# `rectified_flow` package is importable. The guard keeps this cell idempotent:
# re-running it no longer walks further up the directory tree each time.
# NOTE(review): assumes the target directory is the one containing the
# `rectified_flow/` package folder — confirm against the repo layout.
import os

if not os.path.isdir("rectified_flow"):
    os.chdir("..")
os.getcwd()

In [None]:
import os

import torch
import torchvision
from torchvision import transforms

from rectified_flow.models.gauss_analytic import AnalyticGaussianVelocity
from rectified_flow.flow_components.interpolation_solver import AffineInterp
from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import plot_cifar_results

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the whole CIFAR-10 training split as one flattened tensor on `device`.
# ToTensor yields pixels in [0, 1]; Normalize with mean=std=0.5 maps them to
# [-1, 1]. The data must already be on disk (download=False). Set the
# CIFAR10_ROOT environment variable to point at a different location.
DATA_ROOT = os.environ.get("CIFAR10_ROOT", "/root/autodl-tmp/cifar10")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT, transform=transform, download=False)
# A single batch spanning the whole dataset (rather than a hard-coded 50000),
# so the one `next()` below always returns every image instead of silently
# truncating if the split size ever differs.
trainloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
data_iter = iter(trainloader)
images, labels = next(data_iter)

# Flatten each image to a 3*32*32 vector; the flow below is built over flat vectors.
images = images.to(device)
images = images.reshape(images.shape[0], -1)
assert images.shape == (len(dataset), 3 * 32 * 32), images.shape
images.shape

In [None]:
# Assemble the rectified flow. The velocity model is constructed directly from
# the flattened training images; no training loop appears in this notebook.
interp = AffineInterp("ddim")

model = AnalyticGaussianVelocity(images, interp)

# Take the per-sample shape from the data itself instead of repeating 3*32*32,
# so the flow's declared shape and the actual data cannot drift apart.
rf_func = RectifiedFlow(
    data_shape=tuple(images.shape[1:]),
    model=model,
    interp=interp,
    device=device,
)

In [None]:
# Euler sampler over `rf_func`, using 100 integration steps.
sampler = EulerSampler(
    rectified_flow=rf_func,
    num_steps=100,
)

In [None]:
# How many images to generate. seed=0 is passed so that reruns draw the same
# samples (assuming sample_loop derives all of its randomness from it).
NUM_SAMPLES = 130
sampler.sample_loop(NUM_SAMPLES, seed=0)

In [None]:
# Take the last entry of the sampling trajectory and restore the CIFAR image
# layout (N, 3, 32, 32) before handing it to the plotting helper.
X_1 = sampler.trajectories[-1].reshape(-1, 3, 32, 32)
plot_cifar_results(X_1)