# Discrete Diffusion on 2‑D Trajectories `(a, r)`

This notebook is a **drop‑in replacement** for your original *discrete.ipynb*.
It keeps the same cell order and variable names as much as possible, but now
treats `a ∈ {0,…,19}` and `r ∈ {0,1}` as **two coupled chains** instead of
packing them into one token.

**Expected data** : a NumPy file `traj.npy` of shape `(N, L, 2)` where the last
dimension holds `(a, r)` pairs.
Run the notebook top‑to‑bottom to train and sample new trajectories; outputs are
saved as `generated_trajs.npy` with shape `(B, L, 2)`.

In [1]:
import numpy as np
import torch, math, random
from torch import nn
from torch.utils.data import Dataset, DataLoader

# ---- Config ----
K = 20
input_path   = f'traj_{K}.npy'
output_path  = f'traj_{K}_generated_pair.npy'
BATCH_SIZE   = 64
EPOCHS       = 30
SEQ_LEN      = None         # will infer after loading data
T_STEPS      = 12           # diffusion steps
LR           = 1e-3
device       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
raw = np.load(input_path)
SEQ_LEN = raw.shape[1]
print('Loaded', raw.shape[0], 'trajectories of length', SEQ_LEN)

# Split into two integer tensors
a_arr = raw[:, :, 0].astype(np.int64)   # (N, L)
r_arr = raw[:, :, 1].astype(np.int64)   # (N, L)

NUM_A = K   # classes for a
NUM_R = 2    # classes for r

Loaded 100 trajectories of length 50


In [3]:
class TrajDataset(Dataset):
    def __init__(self, a_data, r_data):
        self.a = torch.from_numpy(a_data).long()
        self.r = torch.from_numpy(r_data).long()
    def __len__(self):
        return self.a.size(0)
    def __getitem__(self, idx):
        return {
            'a': self.a[idx],   # (L,)
            'r': self.r[idx],   # (L,)
        }

dataset = TrajDataset(a_arr, r_arr)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [4]:
def forward_diffusion(x0_flat, betas, num_classes):
    """Return list [x_0, x_1, ..., x_T]"""
    traj = [x0_flat]
    x_prev = x0_flat
    for beta in betas:
        mask = torch.rand_like(x_prev.float()) < beta
        noise = torch.randint(0, num_classes, x_prev.shape, device=x_prev.device)
        x_next = torch.where(mask, noise, x_prev)
        traj.append(x_next)
        x_prev = x_next
    return traj

In [5]:
class DiscreteDiffusion(nn.Module):
    def __init__(self, hidden_dim=128, time_emb_dim=32):
        super().__init__()
        self.time_emb = nn.Embedding(1000, time_emb_dim)
        in_dim = NUM_A + NUM_R + time_emb_dim
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.head_a = nn.Linear(hidden_dim, NUM_A)
        self.head_r = nn.Linear(hidden_dim, NUM_R)

    def forward(self, a_t, r_t, t):
        # a_t, r_t: [B*L] LongTensor; t: [B*L]
        a_one = torch.zeros(a_t.size(0), NUM_A, device=a_t.device)
        a_one.scatter_(1, a_t.unsqueeze(1), 1.)
        r_one = torch.zeros(r_t.size(0), NUM_R, device=r_t.device)
        r_one.scatter_(1, r_t.unsqueeze(1), 1.)
        h = torch.cat([a_one, r_one, self.time_emb(t)], dim=1)
        h = self.backbone(h)
        return self.head_a(h), self.head_r(h)

In [6]:
betas = [0.1] * T_STEPS
model = DiscreteDiffusion().to(device)
optim = torch.optim.Adam(model.parameters(), lr=LR)
ce = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    model.train()
    for batch in dataloader:
        a0 = batch['a'].to(device)    # (B,L)
        r0 = batch['r'].to(device)
        B, L = a0.shape
        a0_flat = a0.reshape(-1)
        r0_flat = r0.reshape(-1)

        traj_a = forward_diffusion(a0_flat, betas, NUM_A)
        traj_r = forward_diffusion(r0_flat, betas, NUM_R)

        t_bar = torch.randint(1, T_STEPS + 1, (B*L,), device=device)
        a_t = torch.stack([traj_a[t][i] for i, t in enumerate(t_bar)])
        r_t = torch.stack([traj_r[t][i] for i, t in enumerate(t_bar)])
        a_prev = torch.stack([traj_a[t-1][i] for i, t in enumerate(t_bar)])
        r_prev = torch.stack([traj_r[t-1][i] for i, t in enumerate(t_bar)])

        logits_a, logits_r = model(a_t, r_t, t_bar)
        loss = ce(logits_a, a_prev) + ce(logits_r, r_prev)

        optim.zero_grad()
        loss.backward()
        optim.step()

    print(f'Epoch {epoch+1}/{EPOCHS}  loss={loss.item():.4f}')

Epoch 1/30  loss=3.6819
Epoch 2/30  loss=3.6559
Epoch 3/30  loss=3.6343
Epoch 4/30  loss=3.6175
Epoch 5/30  loss=3.6046
Epoch 6/30  loss=3.5825
Epoch 7/30  loss=3.5665
Epoch 8/30  loss=3.5610
Epoch 9/30  loss=3.5356
Epoch 10/30  loss=3.5427
Epoch 11/30  loss=3.5406
Epoch 12/30  loss=3.5126
Epoch 13/30  loss=3.4952
Epoch 14/30  loss=3.5001
Epoch 15/30  loss=3.4783
Epoch 16/30  loss=3.4675
Epoch 17/30  loss=3.4519
Epoch 18/30  loss=3.4442
Epoch 19/30  loss=3.4396
Epoch 20/30  loss=3.4043
Epoch 21/30  loss=3.4149
Epoch 22/30  loss=3.3905
Epoch 23/30  loss=3.3748
Epoch 24/30  loss=3.3581
Epoch 25/30  loss=3.3448
Epoch 26/30  loss=3.3198
Epoch 27/30  loss=3.2995
Epoch 28/30  loss=3.2838
Epoch 29/30  loss=3.2490
Epoch 30/30  loss=3.2453


In [7]:
NUM_SAMPLES = 100   # <‑‑‑ set this to any N you like (e.g. 1000)

def sample(model, n_samples=NUM_SAMPLES):
    """Return tensor (N, SEQ_LEN, 2) on CPU."""
    model.eval()
    BATCH = n_samples
    with torch.no_grad():
        a_t = torch.randint(0, NUM_A, (BATCH, SEQ_LEN), device=device)
        r_t = torch.randint(0, NUM_R, (BATCH, SEQ_LEN), device=device)
        for t in reversed(range(1, T_STEPS + 1)):
            t_vec = torch.full((BATCH * SEQ_LEN,), t, device=device)
            logits_a, logits_r = model(a_t.reshape(-1), r_t.reshape(-1), t_vec)
            probs_a = torch.softmax(logits_a, dim=-1)
            probs_r = torch.softmax(logits_r, dim=-1)
            a_t = torch.multinomial(probs_a, 1).squeeze(-1).reshape(BATCH, SEQ_LEN)
            r_t = torch.multinomial(probs_r, 1).squeeze(-1).reshape(BATCH, SEQ_LEN)
        trajs = torch.stack([a_t.cpu(), r_t.cpu()], dim=-1)
        return trajs  # (N, L, 2)

trajs = sample(model, NUM_SAMPLES)
# concat the training data and the generated data
trajs = torch.cat([torch.from_numpy(raw), trajs], dim=0)
np.save(f'{output_path}', trajs.numpy())
print('Saved generated_trajs', trajs.shape)

Saved generated_trajs torch.Size([200, 50, 2])
