In [1]:
import os

# Expose only the first GPU to this process. CUDA reads CUDA_VISIBLE_DEVICES;
# the previous name (VISIBLE_CUDA_DEVICES) is not a CUDA variable and was
# silently ignored. It must be set before CUDA is initialised, which is why
# this cell runs before `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import torch
from torch import nn, Tensor
from sklearn.datasets import make_circles
import matplotlib.pyplot as plt

In [3]:
import torch

# Report the PyTorch build, one value per line: library version, the CUDA
# version PyTorch was built with, and whether cuDNN is enabled.
print(torch.__version__, torch.version.cuda, torch.backends.cudnn.enabled, sep="\n")
# To switch to a CUDA 12.9 build:
# pip uninstall torch torchvision torchaudio -y
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129


2.9.1+cu128
12.8
True


In [4]:

class Flow(nn.Module):
    """Model for discrete flow matching over `dim` tokens from a vocabulary of `v` ids.

    Each token id is embedded into an `h`-dimensional vector. The `dim`
    embeddings are flattened and concatenated with the scalar time `t`, and an
    MLP maps the result to one row of `v` logits per token position.

    Args:
        dim: number of token positions per sample.
        h: embedding size and hidden width of the MLP.
        v: vocabulary size (number of distinct token ids).
    """

    def __init__(self, dim = 2, h = 256, v = 128):
        super().__init__()
        self.v = v  # vocabulary size; also the number of logits per position

        # Creation order (embedding first, then the linear layers) determines
        # how seeded random initialisation consumes the RNG -- keep it fixed.
        self.embed = nn.Embedding(v, h)
        in_features = dim * h + 1  # flattened token embeddings + one time feature
        self.net = nn.Sequential(
            nn.Linear(in_features, h), nn.SELU(),
            nn.Linear(h, h), nn.SELU(),
            nn.Linear(h, dim * v),
        )

    def forward(self, x_t: Tensor, t: Tensor):
        """Return per-position logits for tokens `x_t` at times `t`.

        Args:
            x_t: integer token ids; the flatten/concat below require shape (batch, dim).
            t: one time value per sample, shape (batch,).

        Returns:
            Tensor of shape (*x_t.shape, v) holding unnormalised logits.
        """
        token_features = self.embed(x_t).flatten(start_dim=1, end_dim=2)  # (batch, dim * h)
        time_feature = t.unsqueeze(1)  # (batch, 1)
        hidden = torch.cat((time_feature, token_features), dim=-1)
        flat_logits = self.net(hidden)  # (batch, dim * v)
        return flat_logits.reshape(*x_t.shape, self.v)

In [10]:
# Reproducibility: seed PyTorch's RNG (all devices) before the model is built
# so weight initialisation and the torch-side random draws repeat across runs.
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 256   # samples per optimisation step
vocab_size = 1280  # number of discrete bins (token ids) per coordinate
dim = 2            # coordinates per sample; make_circles produces 2-D points

flow = Flow(dim=dim, h=128, v=vocab_size).to(device)  # h: embedding / hidden width
optim = torch.optim.Adam(flow.parameters(), lr=1e-3)


In [11]:
# Parameters of the *target* distribution X_1, passed to sklearn's
# make_circles in the training loop. (The source distribution X_0 is uniform
# noise over the vocabulary and is sampled there as well.)
#   factor: scale of the inner circle relative to the outer one
#   noise:  standard deviation of the Gaussian noise added to each point
factor, noise = 0.5, 0.02


In [1]:
# Discrete flow-matching training loop. Each step:
#   1. draws a batch from the two-circles target and quantises it into
#      `vocab_size` integer bins per coordinate (X_1),
#   2. draws uniform-random tokens as the source (X_0),
#   3. builds X_t by taking each coordinate from X_1 with probability t,
#      otherwise from X_0,
#   4. trains the model with cross-entropy to predict X_1 from (X_t, t).
epoch = 100000  # number of optimisation steps (one fresh batch per step)

# make_circles always yields 2-D points, so X_1 only matches the model if dim == 2.
assert dim == 2, "make_circles produces 2-D points; dim must be 2"

for step in range(epoch):
    # Seed with the step index: a fixed random_state=0 regenerated the exact
    # same batch on every step, while `step` gives fresh noise each time and
    # keeps the run reproducible. Labels ([1]) are not needed.
    points = make_circles(n_samples=batch_size, factor=factor, noise=noise, random_state=step)[0]

    # Map coordinates from roughly [-1, 1] onto integer bins in [0, vocab_size - 1].
    X_1 = torch.as_tensor(points, dtype=torch.float32, device=device)
    X_1 = X_1 * (vocab_size / 2) + vocab_size / 2
    X_1 = torch.clamp(X_1, min=0.0, max=vocab_size - 1)
    X_1 = torch.round(X_1).long()

    # Source distribution: uniform over the vocabulary for every coordinate.
    X_0 = torch.randint(low=0, high=vocab_size, size=(batch_size, dim), device=device)
    t = torch.rand(batch_size, device=device)
    # Mixture path: each coordinate independently comes from X_1 with probability t.
    X_t = torch.where(torch.rand(batch_size, dim, device=device) < t[:, None], X_1, X_0)

    logits = flow(X_t, t)  # (batch, dim, vocab)
    # cross_entropy already averages over all batch * dim positions.
    loss = nn.functional.cross_entropy(logits.flatten(0, 1), X_1.flatten(0, 1))

    optim.zero_grad()
    loss.backward()
    optim.step()

    if step % 1000 == 0:
        print(f'Epoch {step}: {loss.item()}')

NameError: name 'make_circles' is not defined

In [19]:
# Generate samples by Euler-integrating the discrete flow from t=0 to t=1.
# At each step every coordinate either keeps its current token or jumps to a
# token drawn from the model's predicted distribution p1 over X_1.
sample_batch = 200
n_steps = 1000  # number of Euler steps; step size h = 1 / n_steps

x_t = torch.randint(0, high=vocab_size, size=(sample_batch, dim), device=device)
t = 0.0
results = [(x_t.clone(), t)]

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for i in range(n_steps):
        # Derive t from the integer step so the path ends exactly at t = 1.0;
        # accumulating `t += h` drifts and could stop one step early (t ~ 0.999).
        t = i / n_steps
        h = 1.0 / n_steps
        tt = torch.full((sample_batch,), t, device=device)
        p1 = nn.functional.softmax(flow(x_t, tt), dim=-1)  # (batch, dim, vocab)

        one_hot_x_t = nn.functional.one_hot(x_t, num_classes=vocab_size).float()  # (batch, dim, vocab)
        # Inside the loop 1 - t >= h > 0, so the guard never triggers in practice.
        u = (p1 - one_hot_x_t) / max(1e-8, (1.0 - t))

        # Euler step on each coordinate's categorical distribution. Because
        # h / (1 - t) <= 1 the result is non-negative up to rounding; the clamp
        # and renormalisation absorb that rounding.
        probs = one_hot_x_t + h * u
        probs = torch.clamp(probs, min=1e-9)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample every (sample, coordinate) pair at once: Categorical expects
        # (..., num_categories), so flatten batch and dim, then restore them.
        b, D, V = probs.shape
        cat = torch.distributions.Categorical(probs=probs.reshape(b * D, V))
        x_t = cat.sample().reshape(b, D)  # (b, D)

        t = (i + 1) / n_steps
        results.append((x_t.clone(), t))

print(len(results[0][0]))  # number of samples per snapshot

200


In [20]:
# Render each sampling snapshot to a PNG frame, then stitch the frames into a GIF.
import imageio.v2 as imageio  # v2 API: avoids the imageio.imread deprecation warning

frame_dir = "tmp_frames"
os.makedirs(frame_dir, exist_ok=True)
gif_images = []

# Fixed axis limits (full token range plus a small margin) so every frame
# shares the same view; with autoscaling the animation would jump around.
margin = 0.05 * vocab_size
axis_limits = (-margin, vocab_size - 1 + margin)

for i, (frame_x, frame_t) in enumerate(results):
    # Draw the frame.
    fig, ax = plt.subplots(figsize=(4,4))
    points = frame_x.cpu().numpy()
    ax.scatter(points[:, 0], points[:, 1], c='blue', alpha=0.6)
    ax.set_title(f"t={frame_t}")
    ax.set_xlabel("x_1")
    ax.set_ylabel("x_2")
    ax.set_xlim(*axis_limits)
    ax.set_ylim(*axis_limits)
    ax.grid(True)

    # Save each frame (via the figure, not the pyplot state machine).
    frame_path = f"{frame_dir}/frame_{i}.png"
    fig.savefig(frame_path)
    gif_images.append(frame_path)
    plt.close(fig)

# Export the GIF.
# NOTE(review): depending on the imageio version/backend, `duration` is read
# as seconds or milliseconds -- confirm the intended frame rate.
with imageio.get_writer('scatter_animation.gif', mode='I', duration=1) as writer:
    for filename in gif_images:
        writer.append_data(imageio.imread(filename))

# Remove the temporary frames only after the GIF was written successfully.
for filename in gif_images:
    os.remove(filename)



  image = imageio.imread(filename)
