In [1]:
import os
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import clip

In [2]:
mean = torch.from_numpy(np.load("KIT-ML/Mean.npy")).float()
std  = torch.from_numpy(np.load("KIT-ML/Std.npy")).float()

In [None]:
class KITMotionDataset(Dataset):
    def __init__(
        self,
        root,
        split="train",
        max_len=None,
        normalize=True,
        use_cache=True
    ):
        self.motion_dir = os.path.join(root, "new_joint_vecs")
        self.text_dir = os.path.join(root, "texts")
        split_file = os.path.join(root, f"{split}.txt")
        
        # Cache file path
        cache_file = os.path.join(root, f"cached_{split}_maxlen{max_len}_norm{normalize}.pt")

        # Try to load from cache
        if use_cache and os.path.exists(cache_file):
            print(f"Loading preprocessed dataset from {cache_file}...")
            cached_data = torch.load(cache_file)
            self.ids = cached_data['ids']
            self.motions = cached_data['motions']
            self.texts = cached_data['texts']
            self.max_len = cached_data['max_len']
            self.normalize = cached_data['normalize']
            self.mean = cached_data.get('mean', None)
            self.std = cached_data.get('std', None)
            print(f"Loaded {len(self.ids)} preprocessed motions from cache")
            return

        # Otherwise, process from scratch
        with open(split_file, "r") as f:
            all_ids = [line.strip() for line in f]

        # Filter out missing files
        self.ids = []
        missing_count = 0
        for mid in all_ids:
            path = os.path.join(self.motion_dir, f"{mid}.npy")
            if os.path.exists(path):
                self.ids.append(mid)
            else:
                missing_count += 1
        
        if missing_count > 0:
            print(f"Warning: {missing_count}/{len(all_ids)} motion files missing from {split} split")
        print(f"Loaded {len(self.ids)} valid motions from {split} split")

        self.max_len = max_len
        self.normalize = normalize

        
        # PRELOAD all motions and texts into memory
        self.motions = {}
        self.texts = {}

        print("Preloading motions and texts into memory...")
        
        self.mean = mean
        self.std = std
        
        for mid in self.ids:
            # Load motion
            motion_path = os.path.join(self.motion_dir, f"{mid}.npy")
            motion = np.load(motion_path)  # (T, F)
            
            # Truncate
            if self.max_len is not None:
                motion = motion[:self.max_len]
            
            motion = torch.from_numpy(motion).float()  # (T, F)
            
            # Normalize
            motion = (motion - self.mean) / (self.std + 1e-8)
            
            self.motions[mid] = motion
            
            # Load text description
            text_path = os.path.join(self.text_dir, f"{mid}.txt")
            if os.path.exists(text_path):
                with open(text_path, "r") as f:
                    text_line = f.read().strip()
                    # Format: "description#tokens#..." - extract first part
                    text = text_line.split("#")[0]
                    self.texts[mid] = text
            else:
                self.texts[mid] = "a person is moving"
        print(f"Preloaded {len(self.motions)} motions and {len(self.texts)} texts")
            
        # Save to cache
        if use_cache:
            print(f"Saving preprocessed dataset to {cache_file}...")
            torch.save({
                'ids': self.ids,
                'motions': self.motions,
                'texts': self.texts,
                'max_len': self.max_len,
                'normalize': self.normalize,
                'mean': self.mean,
                'std': self.std
            }, cache_file)
            print("Cache saved!")

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        mid = self.ids[idx]
        motion = self.motions[mid]  # Already preprocessed!
        text = self.texts[mid]

        return {
            "motion": motion,
            "length": motion.shape[0],
            "text": text,
            "id": mid
        }

In [None]:
def collate_motion(batch):
    motions = [b["motion"] for b in batch]     # (T,F)
    lengths = torch.tensor([b["length"] for b in batch])
    texts = [b["text"] for b in batch]
    ids = [b["id"] for b in batch]

    B = len(motions)
    T_max = int(lengths.max())
    Fdim = motions[0].shape[1]

    padded = torch.zeros(B, T_max, Fdim)
    mask = torch.zeros(B, T_max, dtype=torch.bool)

    for i, m in enumerate(motions):
        T = m.shape[0]
        padded[i, :T] = m
        mask[i, :T] = 1

    return {"motion": padded, "mask": mask, "lengths": lengths, "texts": texts, "ids": ids}

In [5]:
dataset = KITMotionDataset(
    root="KIT-ML",
    split="train",
    max_len=196,
)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_motion,
)


Loading preprocessed dataset from KIT-ML\cached_train_maxlen196_normTrue.pt...
Loaded 4886 preprocessed motions from cache


In [6]:
def timestep_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, device=timesteps.device).float() / (half - 1)
    )
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb

In [None]:
class DiffusionSchedule:
    def __init__(self, T: int = 1000, beta_start=1e-4, beta_end=2e-2, device="cpu"):
        self.T = T
        betas = torch.linspace(beta_start, beta_end, T, device=device)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.alpha_bar = alpha_bar
        self.sqrt_alpha_bar = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - alpha_bar)

    def q_sample(self, x0, t, noise):
        B = x0.shape[0]
        # make (B, 1, 1, ..., 1) to broadcast over all non-batch dims
        broadcast_shape = [B] + [1] * (x0.ndim - 1)
        s1 = self.sqrt_alpha_bar[t].view(*broadcast_shape)
        s2 = self.sqrt_one_minus_alpha_bar[t].view(*broadcast_shape)
        return s1 * x0 + s2 * noise


In [8]:
class MotionDenoiserTransformer(nn.Module):
    """
    Drop-in replacement that fixes the key issue: adds FRAME positional embeddings
    (sequence-time / frame index), in addition to your existing diffusion timestep embedding.

    Same call signature:
        eps_hat = model(x_t, t, text_emb, mask=mask)

    Assumes mask is (B, T) with True=valid, False=padding.
    """

    def __init__(
        self,
        num_joints: int = 21,
        coord_dim: int = 3,
        d_model: int = 512,
        n_layers: int = 8,
        n_heads: int = 8,
        dropout: float = 0.1,
        time_embed_dim: int = 512,
        clip_dim: int = 512,
        max_frames: int = 256,   # <-- IMPORTANT: set >= your dataset max sequence length
    ):
        super().__init__()
        self.num_joints = num_joints
        self.coord_dim = coord_dim
        self.d_model = d_model
        self.max_frames = max_frames

        in_dim = num_joints * coord_dim

        # Frame encoder: (J*3) -> d_model
        self.frame_in = nn.Sequential(
            nn.Linear(in_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        # Diffusion timestep embedding -> d_model
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        # Text embedding projection -> d_model (single token)
        self.text_proj = nn.Sequential(
            nn.Linear(clip_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        # --- NEW: learnable FRAME positional embedding (sequence-time) ---
        # We prepend 1 text token, so allocate max_frames + 1 positions.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_frames + 1, d_model))
        nn.init.normal_(self.pos_emb, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.tr = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Decode back to (J*3)
        self.frame_out = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, in_dim),
        )

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        text_emb: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        x_t: (B, T, J, C)
        t:   (B,) diffusion timestep indices
        text_emb: (B, clip_dim)
        mask: (B, T) True=valid, False=padding
        """
        B, Tm, J, C = x_t.shape
        assert J == self.num_joints and C == self.coord_dim

        if Tm > self.max_frames:
            raise ValueError(
                f"Sequence length T={Tm} exceeds max_frames={self.max_frames}. "
                f"Increase max_frames in MotionDenoiserTransformer."
            )

        # Frame tokens
        x = x_t.reshape(B, Tm, J * C)
        h = self.frame_in(x)  # (B, T, d_model)

        # Diffusion timestep embedding (same added to every frame token)
        te = timestep_embedding(t, dim=self.time_mlp[0].in_features)  # (B, time_embed_dim)
        te = self.time_mlp(te).unsqueeze(1)                           # (B, 1, d_model)
        h = h + te                                                    # broadcast over frames

        # Text token
        text_tok = self.text_proj(text_emb).unsqueeze(1)              # (B, 1, d_model)

        # Concatenate [text, frames]
        h = torch.cat([text_tok, h], dim=1)                           # (B, 1+T, d_model)

        # --- NEW: add positional embedding for (text + frames) ---
        h = h + self.pos_emb[:, : (1 + Tm), :]

        # Build key padding mask: PyTorch expects True where padding
        if mask is not None:
            mask = mask.to(device=x_t.device)
            text_valid = torch.ones(B, 1, dtype=torch.bool, device=x_t.device)
            full_valid = torch.cat([text_valid, mask], dim=1)         # (B, 1+T)
            src_key_padding_mask = ~full_valid                        # True where padding
        else:
            src_key_padding_mask = None

        # Transformer encode
        h = self.tr(src=h, src_key_padding_mask=src_key_padding_mask)  # (B, 1+T, d_model)

        # Drop text token
        h = h[:, 1:, :]                                                # (B, T, d_model)

        # Predict noise (or x0) in original motion shape
        y = self.frame_out(h)                                          # (B, T, J*C)
        return y.reshape(B, Tm, J, C)

In [9]:
class MotionDDPM(nn.Module):
    def __init__(self, denoiser: nn.Module, schedule):
        super().__init__()
        self.denoiser = denoiser
        self.schedule = schedule

    def training_step(self, x0: torch.Tensor, text_emb: torch.Tensor, mask: torch.Tensor | None = None):
        B = x0.shape[0]
        device = x0.device

        t = torch.randint(0, self.schedule.T, (B,), device=device)
        noise = torch.randn_like(x0)  # DO NOT mask noise here

        x_t = self.schedule.q_sample(x0, t, noise)
        eps_hat = self.denoiser(x_t, t, text_emb, mask=mask)

        if mask is None:
            return F.mse_loss(eps_hat, noise)

        # mask is (B,T) with True=valid
        m = mask.to(device=device).float()  # (B,T)
        m = m.unsqueeze(-1).unsqueeze(-1)   # (B,T,1,1)

        se = (eps_hat - noise) ** 2
        se = se * m

        denom = m.sum() * x0.shape[2] * x0.shape[3]  # valid frames * (J*C)
        loss = se.sum() / denom.clamp(min=1e-8)
        return loss

In [10]:
@torch.no_grad()
def ddpm_sample(
    model,
    shape,
    text_emb: torch.Tensor,
    mask: torch.Tensor | None = None,
    device: str | torch.device = "cpu",
):
    B, Tm, J, C = shape
    sched = model.schedule
    device = torch.device(device)

    # Move schedule tensors to the right device/dtype once
    betas = sched.betas.to(device=device)
    alphas = sched.alphas.to(device=device)
    alpha_bar = sched.alpha_bar.to(device=device)

    # Start from pure noise x_T
    x = torch.randn(shape, device=device)

    # Optional: keep padding untouched during the chain (recommended: don't hard-zero every step)
    # We'll apply masking only at the end. Still pass mask into denoiser for attention masking.
    mask_broadcast = None
    if mask is not None:
        # Expect mask True=valid; convert to float broadcast (B,T,1,1)
        mask_broadcast = mask.to(device=device).unsqueeze(-1).unsqueeze(-1).float()

    for ti in reversed(range(sched.T)):
        t = torch.full((B,), ti, device=device, dtype=torch.long)

        # Predict noise epsilon_hat
        eps_hat = model.denoiser(x, t, text_emb.to(device=device), mask=mask.to(device=device) if mask is not None else None)

        beta_t = betas[ti]
        alpha_t = alphas[ti]
        alpha_bar_t = alpha_bar[ti]

        # DDPM mean: μθ(x_t,t) = 1/sqrt(α_t) * (x_t - (β_t / sqrt(1-ᾱ_t)) * εθ )
        coef1 = torch.rsqrt(alpha_t)  # 1/sqrt(alpha_t)
        coef2 = beta_t / torch.sqrt(torch.clamp(1.0 - alpha_bar_t, min=1e-12))
        mean = coef1 * (x - coef2 * eps_hat)

        if ti > 0:
            # Posterior variance (tilde beta): β̃_t = β_t * (1-ᾱ_{t-1})/(1-ᾱ_t)
            alpha_bar_prev = alpha_bar[ti - 1]
            beta_tilde = beta_t * (1.0 - alpha_bar_prev) / torch.clamp(1.0 - alpha_bar_t, min=1e-12)
            beta_tilde = torch.clamp(beta_tilde, min=1e-20)  # numerical safety

            z = torch.randn_like(x)
            x = mean + torch.sqrt(beta_tilde) * z
        else:
            x = mean

    # Crop to actual lengths instead of masking
    # Return list of tensors with different lengths
    if mask is not None:
        lengths = mask.sum(dim=1).cpu().long()  # (B,)
        cropped = []
        for i in range(B):
            cropped.append(x[i, :lengths[i]])  # (T_i, J, C)
        return cropped
    else:
        return x

In [11]:
import ssl

from matplotlib.pylab import std
from numpy import mean
ssl._create_default_https_context = ssl._create_unverified_context

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_model.eval()  # Freeze CLIP
for param in clip_model.parameters():
    param.requires_grad = False

print("CLIP model loaded")

num_joints = 21
denoiser = MotionDenoiserTransformer(
    num_joints=num_joints, 
    d_model=512, 
    n_layers=8, 
    n_heads=8,
    clip_dim=512  # CLIP ViT-B/32 embedding dimension
).to(device)

schedule = DiffusionSchedule(T=1000, device=device)
ddpm = MotionDDPM(denoiser, schedule).to(device)

opt = torch.optim.AdamW(ddpm.parameters(), lr=2e-4, weight_decay=1e-4)

num_epochs = 100

sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs * len(loader), eta_min=1e-6)

# Create output directory for samples
output_dir = "kit-ml-diffusion"
os.makedirs(output_dir, exist_ok=True)

global_step = 0
for epoch in range(num_epochs):
    for batch in loader:
        x0 = batch["motion"].to(device)     # (B,T,J,3)
        mask = batch["mask"].to(device)     # (B,T)
        texts = batch["texts"]              # List of strings

        # Encode text with CLIP
        text_tokens = clip.tokenize(texts, truncate=True).to(device)
        with torch.no_grad():
            text_emb = clip_model.encode_text(text_tokens).float()  # (B, 512)
            text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        loss = ddpm.training_step(x0, text_emb, mask=mask)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(ddpm.parameters(), 1.0)
        opt.step()
        sched.step()
        
        if global_step % 100 == 0:
            print(f"Epoch {epoch}, Step {global_step}, Loss: {float(loss):.4f}, LR: {sched.get_last_lr()[0]:.6e}")
            
        # Sample occasionally and save
        if global_step % 1000 == 0 and global_step > 0:
            # Use text from first sample in batch for generation
            samp_list = ddpm_sample(ddpm, x0.shape, text_emb, mask=mask, device=device)
            
            # Get first sample (now a cropped tensor)
            samp = samp_list[0]  # (T_actual, J, C)
            
            # Denormalize using dataset statistics
            if dataset.mean is not None and dataset.std is not None:
                mean_ = dataset.mean.to(samp.device, samp.dtype)
                std_ = dataset.std.to(samp.device, samp.dtype)
                samp = samp * std_ + mean_

            print(f"  Generated sample shape: {samp.shape}")
            print(f"  Conditioned on: {texts[0]}")
            
            # Save first sample from batch
            motion_np = samp.cpu().numpy()  # (T_actual, J, 3) - no padding!
            save_path = os.path.join(output_dir, f"sample_step{global_step:06d}.npy")
            np.save(save_path, motion_np)
            
            # Save text prompt
            text_path = os.path.join(output_dir, f"sample_step{global_step:06d}.txt")
            with open(text_path, "w") as f:
                f.write(texts[0])
            
            print(f"  Saved to {save_path}")
            
        global_step += 1
    
    print(f"Completed epoch {epoch+1}/{num_epochs}")

print("Training complete!")

Using device: cuda
CLIP model loaded


Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods.cpp:837.)
  print(f"Epoch {epoch}, Step {global_step}, Loss: {float(loss):.4f}, LR: {sched.get_last_lr()[0]:.6e}")


Epoch 0, Step 0, Loss: 1.3912, LR: 2.000000e-04
Epoch 0, Step 100, Loss: 0.3408, LR: 1.999786e-04
Completed epoch 1/100
Epoch 1, Step 200, Loss: 0.1287, LR: 1.999153e-04
Epoch 1, Step 300, Loss: 0.1877, LR: 1.998100e-04
Completed epoch 2/100
Epoch 2, Step 400, Loss: 0.1249, LR: 1.996629e-04
Completed epoch 3/100
Epoch 3, Step 500, Loss: 0.0801, LR: 1.994740e-04
Epoch 3, Step 600, Loss: 0.1950, LR: 1.992433e-04
Completed epoch 4/100
Epoch 4, Step 700, Loss: 0.1515, LR: 1.989710e-04
Completed epoch 5/100
Epoch 5, Step 800, Loss: 0.1019, LR: 1.986572e-04
Epoch 5, Step 900, Loss: 0.0963, LR: 1.983021e-04
Completed epoch 6/100
Epoch 6, Step 1000, Loss: 0.1681, LR: 1.979057e-04
  Generated sample shape: torch.Size([119, 21, 3])
  Conditioned on: a person break dances
  Saved to kit-ml-diffusion\sample_step001000.npy
Completed epoch 7/100
Epoch 7, Step 1100, Loss: 0.0673, LR: 1.974682e-04
Epoch 7, Step 1200, Loss: 0.1310, LR: 1.969898e-04
Completed epoch 8/100
Epoch 8, Step 1300, Loss: 0.0982