In [1]:
# =========================================================
# 1. Imports & Config
# =========================================================
import numpy as np
import torch, math, random
from torch import nn
from torch.utils.data import Dataset, DataLoader
K = 20
input_path   = f'traj_{K}.npy'
output_path  = f'traj_{K}_generated_seq_transformer.npy'
BATCH_SIZE   = 10
EPOCHS       = 400
T_STEPS      = 50          # diffusion steps
LR           = 2e-4
EMBED_DIM    = 64
N_LAYERS     = 4
N_HEADS      = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_R   = 2
NUM_TOK = K * NUM_R        # 每个 time-step 的组合 token 种数
# =========================================================
# 2. Data : (N, L, 2) → (N, L) integer tokens
# =========================================================
raw = np.load(input_path)               # (N, L, 2)
SEQ_LEN = raw.shape[1]
print('Loaded', raw.shape[0], 'trajectories of length', SEQ_LEN)
a_arr = raw[:, :, 0].astype(np.int64)
r_arr = raw[:, :, 1].astype(np.int64)
tok_arr = (a_arr * NUM_R + r_arr).astype(np.int64)   # (N, L)
class TrajDataset(Dataset):
    def __init__(self, tokens):
        self.tok = torch.from_numpy(tokens).long()
    def __len__(self):
        return self.tok.size(0)
    def __getitem__(self, idx):
        return self.tok[idx]          # (L,)
dataloader = DataLoader(TrajDataset(tok_arr), batch_size=BATCH_SIZE,
                        shuffle=True, drop_last=True)
# =========================================================
# 3. Discrete forward diffusion helper (token → token)
# =========================================================
def forward_diffusion(x0, betas, num_classes):
    """x0 : (B, L) LongTensor → list len T+1, each (B, L)"""
    traj = [x0]
    x_prev = x0
    for beta in betas:
        mask  = (torch.rand_like(x_prev.float()) < beta)
        noise = torch.randint(0, num_classes, x_prev.shape, device=x_prev.device)
        x_next = torch.where(mask, noise, x_prev)
        traj.append(x_next)
        x_prev = x_next
    return traj
# =========================================================
# 4. Model : sequence‐aware discrete diffusion network
# =========================================================
class SeqDiscreteDiffusion(nn.Module):
    def __init__(self, num_tok=NUM_TOK, seq_len=SEQ_LEN, embed_dim=EMBED_DIM, n_layers=N_LAYERS, n_heads=N_HEADS):
        super().__init__()
        self.token_emb = nn.Embedding(num_tok, embed_dim)
        self.pos_emb   = nn.Embedding(seq_len, embed_dim)
        self.time_emb  = nn.Embedding(1000, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads, dim_feedforward=embed_dim*4, activation='relu', batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(embed_dim, num_tok)
    def forward(self, x_t, t):
        """
        x_t : (B, L)  LongTensor
        t   : (B,)    LongTensor  (same t for the whole seq)
        """
        B, L = x_t.shape
        tok  = self.token_emb(x_t)                         # (B,L,E)
        pos  = self.pos_emb(torch.arange(L, device=x_t.device))  # (L,E)
        pos  = pos.unsqueeze(0).expand(B, -1, -1)          # (B,L,E)
        time = self.time_emb(t).unsqueeze(1).expand(-1, L, -1)   # (B,L,E)
        h = tok + pos + time                               # (B,L,E)
        h = self.transformer(h)                            # (B,L,E)
        logits = self.head(h)                              # (B,L,NUM_TOK)
        return logits
# =========================================================
# 5. Training loop
# =========================================================
betas = [0.1] * T_STEPS
model = SeqDiscreteDiffusion().to(device)
optim = torch.optim.Adam(model.parameters(), lr=LR)
ce = nn.CrossEntropyLoss()
for epoch in range(EPOCHS):
    model.train()
    for x0 in dataloader:          # x0 : (B,L)
        x0 = x0.to(device)
        B, L = x0.shape
        # forward diffusion once per batch (global t for the whole seq)
        traj  = forward_diffusion(x0, betas, NUM_TOK)      # list len T+1
        t_bar = torch.randint(1, T_STEPS + 1, (B,), device=device)  # (B,)
        x_t   = torch.stack([traj[t][i]   for i,t in enumerate(t_bar)])  # (B,L)
        x_prev= torch.stack([traj[t-1][i] for i,t in enumerate(t_bar)])  # (B,L)
        logits = model(x_t, t_bar)                         # (B,L,NUM_TOK)
        loss   = ce(logits.reshape(-1, NUM_TOK), x_prev.reshape(-1))
        optim.zero_grad()
        loss.backward()
        optim.step()
    print(f'Epoch {epoch+1}/{EPOCHS}  loss={loss.item():.4f}')
# =========================================================
# 6. Sampling (reverse diffusion)
# =========================================================
NUM_SAMPLES = 100   # 可自行调整
def sample(model, n_samples=NUM_SAMPLES):
    model.eval()
    B = n_samples
    with torch.no_grad():
        x_t = torch.randint(0, NUM_TOK, (B, SEQ_LEN), device=device)
        for t in reversed(range(1, T_STEPS + 1)):
            t_vec = torch.full((B,), t, device=device)
            logits = model(x_t, t_vec)                     # (B,L,NUM_TOK)
            probs  = torch.softmax(logits, dim=-1)
            x_t    = torch.multinomial(probs.view(-1, NUM_TOK), 1).squeeze(-1).view(B, SEQ_LEN)      # (B,L)
        return x_t.cpu()                                   # (B,L)
gen_tok = sample(model, NUM_SAMPLES)                       # (N,L)
# 解码回 (a,r) 二元组
gen_a = (gen_tok // NUM_R).numpy()
gen_r = (gen_tok %  NUM_R).numpy()
gen_traj = np.stack([gen_a, gen_r], axis=-1)               # (N,L,2)
# concat the training data `raw`` and the generated `data gen_traj`
trajs = np.concatenate([raw, gen_traj], axis=0)          # (N+200, L, 2)
np.save(output_path, trajs)
print('Saved generated trajectories:', trajs.shape)

Loaded 100 trajectories of length 50
Epoch 1/400  loss=3.7122
Epoch 2/400  loss=3.5170
Epoch 3/400  loss=3.3124
Epoch 4/400  loss=3.1706
Epoch 5/400  loss=3.0624
Epoch 6/400  loss=2.9313
Epoch 7/400  loss=2.6812
Epoch 8/400  loss=2.4647
Epoch 9/400  loss=2.2802
Epoch 10/400  loss=2.1173
Epoch 11/400  loss=1.9006
Epoch 12/400  loss=1.8155
Epoch 13/400  loss=1.6987
Epoch 14/400  loss=1.5736
Epoch 15/400  loss=1.5050
Epoch 16/400  loss=1.3765
Epoch 17/400  loss=1.1983
Epoch 18/400  loss=1.2376
Epoch 19/400  loss=1.1423
Epoch 20/400  loss=1.1418
Epoch 21/400  loss=0.9692
Epoch 22/400  loss=0.9411
Epoch 23/400  loss=1.0018
Epoch 24/400  loss=0.9949
Epoch 25/400  loss=0.9223
Epoch 26/400  loss=0.8308
Epoch 27/400  loss=0.8692
Epoch 28/400  loss=0.9680
Epoch 29/400  loss=0.8575
Epoch 30/400  loss=0.8404
Epoch 31/400  loss=0.8392
Epoch 32/400  loss=0.8902
Epoch 33/400  loss=0.8407
Epoch 34/400  loss=0.8412
Epoch 35/400  loss=0.7769
Epoch 36/400  loss=0.8063
Epoch 37/400  loss=0.7505
Epoch 38/4