#### Music-conditioned human motion generation using AIST++ dataset

The AIST++ dataset contains 1408 sequences of 3D human dance motion, each 7-48 seconds long, along with the corresponding music.

Task: Flow matching to generate realistic human dance motion conditioned on a music segment, and potentially on a seed motion sequence.

Input: A segment of music, embedded with a pretrained model (MusicFM, see below).

Output: Sequence of 3D human body keypoints of shape (N frames x 17 joints x 3 coordinates), following the COCO keypoint format at 60 fps (sample visualization below).

<img src="https://github.com/thelmn/flow-diffusion/blob/master/aistpp-dataset-dance-screen.png?raw=1" width="350">

In [None]:
import os
import sys

from pathlib import Path
import glob

In [None]:
import numpy as np
import pickle

import jax
import jax.numpy as jnp

import torch


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# download AIST++ keypoints3d data from https://google.github.io/aistplusplus_dataset/download.html
# download AIST++ music dataset from https://github.com/Garfield-kh/TM2D

In [None]:
# Root folder of the AIST++ files. The path is relative, so it resolves against
# the current working directory.
# Local alternative:
# HOME_PATH = Path("data/aist_plusplus")
HOME_PATH = Path("drive/MyDrive/datasets/aist_plusplus")


In [None]:
from google.colab import drive

# Mount Google Drive under /content/drive (Colab only); force_remount=True is
# passed so the mount is requested again even if one already exists.
drive.mount('/content/drive', force_remount=True)

In [None]:
# Folders (under HOME_PATH) holding the raw music and the 3D keypoint pickles.
MUSIC_FOLDER = HOME_PATH / 'all_music'
KEYPOINTS_FOLDER = HOME_PATH / 'keypoints3d'
# MOTIONS_FOLDER = HOME_PATH / 'motions'

# Collect the .wav / .pkl files of each folder in sorted order.
music_list = glob.glob(f'{MUSIC_FOLDER}/*.wav')
music_list.sort()
keypoints_list = glob.glob(f'{KEYPOINTS_FOLDER}/*.pkl')
keypoints_list.sort()

# Report how many files were found, plus the first three of each list.
for _found_files in (music_list, keypoints_list):
  print(len(_found_files))
  print(_found_files[:3])

In [None]:
# Import the audio libraries, installing each one with pip on first use if the
# import fails.

# librosa: used below to load and resample the audio files
try:
  import librosa
except ImportError:
  !pip install librosa
  import librosa

# PyWavelets (module name `pywt`); not referenced in the visible part of this notebook
try:
  import pywt
except ImportError:
  !pip install PyWavelets
  import pywt

# soundfile: used below to write the resampled sample to disk
try:
  import soundfile as sf
except ImportError:
  !pip install soundfile
  import soundfile as sf

In [None]:
# Peek at the first keypoints pickle: print its keys and the shapes of its
# first two entries.
# NOTE: pickle.load can execute code embedded in the file; only load dataset
# files from a trusted source.
with open(keypoints_list[0], 'rb') as f:
  sample_keypoints = pickle.load(f)
print(sample_keypoints.keys())
print(sample_keypoints[list(sample_keypoints.keys())[0]].shape)
print(sample_keypoints[list(sample_keypoints.keys())[1]].shape)


##### Embed (and save) music data with pretrained model (MusicFM)

In [None]:
EMBED_MODEL_SR = 24000  # 24kHz: sample rate the audio is resampled to before embedding
KEYPOINT_FRAME_RATE = 60  # 60 fps: frame rate the music embeddings are pooled to

In [None]:
# Load the first music file. No `sr` is passed, so the returned rate is whatever
# librosa.load picks by default (printed below).
y, sr = librosa.load(music_list[0])
y_len_s = y.shape[0] / sr  # clip duration in seconds

print('loaded: ', music_list[0], y.shape, 'sr:', sr, 'len (sec):', y_len_s)

# resample to the rate used for the embedding model
y_resampled = librosa.resample(y, orig_sr=sr, target_sr=EMBED_MODEL_SR)
print(y.shape, y_resampled.shape)
# save
sf.write(HOME_PATH/f'sample_resampled_{EMBED_MODEL_SR}.wav', y_resampled, EMBED_MODEL_SR)

In [None]:
# plot the first 2 seconds of the resampled sample against time
n_plot = min(2 * EMBED_MODEL_SR, y_resampled.shape[0])
# convert sample indices to seconds so the x-axis matches its 'Time (s)' label
t_plot = np.arange(n_plot) / EMBED_MODEL_SR
plt.figure(figsize=(10, 4))
plt.plot(t_plot, y_resampled[:n_plot])
plt.title('Audio sample')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.show()

In [None]:
# load musicfm model

In [None]:
# Download the MusicFM statistics file and pretrained checkpoint into
# ./models/musicfm/data/ (created if missing).
!mkdir -p $(pwd)/models/musicfm/data/
!wget -P $(pwd)/models/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json
!wget -P $(pwd)/models/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt

In [None]:
# Clone the MusicFM source into ./models/musicfm/musicfm, next to the downloaded data folder
!git clone https://github.com/minzwon/musicfm.git $(pwd)/models/musicfm/musicfm

In [None]:
# einops is presumably a dependency of the MusicFM code — confirm against its requirements
!pip install einops

In [None]:
MUSICFM_PATH = './models/musicfm/'
sys.path.append(MUSICFM_PATH)  # make the cloned `musicfm` package importable

from musicfm.model.musicfm_25hz import MusicFM25Hz

# Build the model from the downloaded statistics file and checkpoint.
musicfm = MusicFM25Hz(
    is_flash=False,
    stat_path=os.path.join(MUSICFM_PATH, "data", "msd_stats.json"),
    model_path=os.path.join(MUSICFM_PATH, "data", "pretrained_msd.pt"),
)
musicfm.cuda()  # move the model to the GPU
musicfm.eval()  # switch to evaluation mode

In [None]:
# embed sample audio file
y_resampled_t = torch.from_numpy(y_resampled.reshape(1, -1))
y_resampled_t = y_resampled_t.cuda()
# inference only: no_grad avoids building an autograd graph for the forward pass
with torch.no_grad():
  emb = musicfm.get_latent(y_resampled_t)
print(emb.shape)

# embedding as returned by the model (time on the x-axis)
plt.figure(figsize=(6, 4))
plt.imshow(emb[0].detach().cpu().numpy().T, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
plt.title('Token embedding')
plt.xlabel('Time')
plt.ylabel('Embedding dimension')
plt.colorbar()
plt.show()

# resample embedding frame rate from 25Hz to 60Hz to match keypoint data
n_frame = int(y_len_s * KEYPOINT_FRAME_RATE)
# pooling acts on the last axis, so move time last, pool to n_frame steps, move it back
token_emb = torch.nn.AdaptiveAvgPool1d(n_frame)(emb.transpose(1, 2)).transpose(1, 2)
print(token_emb.shape, token_emb.shape[1] / KEYPOINT_FRAME_RATE)

# same embedding after resampling to the keypoint frame rate
plt.figure(figsize=(6, 4))
plt.imshow(token_emb[0].detach().cpu().numpy().T, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
plt.title('Token embedding')
plt.xlabel('Time')
plt.ylabel('Embedding dimension')
plt.colorbar()
plt.show()

In [None]:
# Embed every music file and save the embedding, pooled to the keypoint frame
# rate, next to the audio as `<name>_musicfm_emb_60Hz.npy`.
for i, music_file in enumerate(music_list):
  y, sr = librosa.load(music_file)
  y_resampled = librosa.resample(y, orig_sr=sr, target_sr=EMBED_MODEL_SR)
  y_len_s = y.shape[0] / sr

  y_resampled_t = torch.from_numpy(y_resampled.reshape(1, -1))
  y_resampled_t = y_resampled_t.cuda()

  # inference only: no_grad avoids building an autograd graph for every file,
  # which would otherwise hold on to GPU memory for no benefit
  with torch.no_grad():
    emb = musicfm.get_latent(y_resampled_t)
    n_frame = int(y_len_s * KEYPOINT_FRAME_RATE)
    # pooling acts on the last axis: move time last, pool to n_frame steps, move it back
    token_emb = torch.nn.AdaptiveAvgPool1d(n_frame)(emb.transpose(1, 2)).transpose(1, 2)

  # save token embedding
  music_name = os.path.basename(music_file).split('.')[0]
  embeds_file = os.path.join(MUSIC_FOLDER, f'{music_name}_musicfm_emb_60Hz.npy')
  np.save(embeds_file, token_emb.detach().cpu().numpy())

  print(f'{i+1}/{len(music_list)}: loaded: {music_file}, {y.shape}, sr: {sr}, len (sec): {y_len_s}. \n'
        f'saved: {embeds_file}, token emb shape: {token_emb.shape}, {token_emb.shape[1] / KEYPOINT_FRAME_RATE} sec')

#### To tensorflow dataset

In [None]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm

def load_dataset_arrs(music_folder, keypoints_folder, load_n):
  """Load keypoint sequences together with the music embeddings they use.

  Args:
    music_folder: folder containing `<music>_musicfm_emb_60Hz.npy` files.
    keypoints_folder: folder containing the keypoint `.pkl` files.
    load_n: number of keypoint files to consider (the first `load_n` in
      sorted filename order); `None` considers all of them.

  Returns:
    (music_embeds, keypoints) where `music_embeds` maps an integer music id to
    its squeezed embedding array, and `keypoints` is a list of
    `(music_id, keypoints3d_arr)` tuples. Keypoint files whose music embedding
    file does not exist are skipped.
  """
  # glob returns files in arbitrary, filesystem-dependent order; sort so the
  # music-id numbering and the `load_n` subset are the same on every run.
  music_files = sorted(glob.glob(os.path.join(music_folder, '*_musicfm_emb_60Hz.npy')))
  keypoints_files = sorted(glob.glob(os.path.join(keypoints_folder, '*.pkl')))
  print(f'Found {len(music_files)} music files and {len(keypoints_files)} keypoints files, loading {load_n} files...')

  # music name (filename prefix before the first '_') -> integer id
  music_ids = {os.path.basename(music_file).split('_')[0]: idx
               for idx, music_file in enumerate(music_files)}

  keypoints = []
  music_embeds = {}
  for keypoints_file in tqdm(keypoints_files[:load_n]):
    filename = os.path.basename(keypoints_file).split('.')[0]
    # the music name is the second-to-last '_'-separated field of the filename
    music_key = filename.split('_')[-2]
    music_file = os.path.join(music_folder, f'{music_key}_musicfm_emb_60Hz.npy')

    if not os.path.exists(music_file):
      continue

    # NOTE: pickle.load can execute code embedded in the file; only load
    # dataset files from a trusted source.
    with open(keypoints_file, 'rb') as f:
      keypoints_data = pickle.load(f)
    keypoints3d_arr = keypoints_data['keypoints3d_optim']

    music_id = music_ids[music_key]
    # each embedding is loaded once and shared by every sequence using that music
    if music_id not in music_embeds:
      music_embed_arr = np.load(music_file).squeeze()
      music_embeds[music_id] = music_embed_arr
      if np.isnan(music_embed_arr).any():
        print('nan in music_embed_arr')

    keypoints.append((music_id, keypoints3d_arr))
    if np.isnan(keypoints3d_arr).any():
      print('nan in keypoints3d_arr')
  return music_embeds, keypoints

In [None]:
# Smoke test on a small subset (load_n=100).
# NOTE: this rebinds `sample_keypoints` to the list of (music_id, keypoints) tuples.
sample_music_embeds, sample_keypoints = load_dataset_arrs(str(MUSIC_FOLDER), str(KEYPOINTS_FOLDER), 100)
print()
print(len(sample_music_embeds), len(sample_keypoints))
# first sequence: its music id, the shape of that music's embedding, and its keypoints shape
print(sample_keypoints[0][0], sample_music_embeds[sample_keypoints[0][0]].shape, sample_keypoints[0][1].shape)

In [None]:
from functools import partial

def embed_keypoints_gen(music_embeds, keypoints):
  """Yield one (flattened keypoints, music embedding) pair per sequence.

  Args:
    music_embeds: mapping from music id to a (frames, features) embedding array.
    keypoints: iterable of `(music_id, keypoints3d_arr)` tuples, where
      `keypoints3d_arr` has shape (frames, 17, 3).

  Yields:
    `(keypoints, music)` with shapes (n, 51) and (n, features), where n is the
    number of frames both arrays have in common.
  """
  for music_id, keypoints3d_arr in keypoints:
    music_embed_arr = music_embeds[music_id]
    # Clip BOTH streams to their common length. Clipping only the music left
    # the two outputs with different lengths whenever the music embedding was
    # shorter than the motion.
    n_frames = min(keypoints3d_arr.shape[0], music_embed_arr.shape[0])
    yield (
        keypoints3d_arr[:n_frames].reshape(n_frames, 17*3),
        music_embed_arr[:n_frames],
    )

# Element spec of the generator: (frames, 51) keypoints and (frames, 1024) music embeddings.
_sample_signature = (
    tf.TensorSpec(shape=(None, 17*3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 1024), dtype=tf.float32),
)
_sample_generator = partial(embed_keypoints_gen, sample_music_embeds, sample_keypoints)
sample_dataset = tf.data.Dataset.from_generator(
    _sample_generator, output_signature=_sample_signature)

# Shapes of a single, unbatched element.
for sample_keypoints_arr, sample_music_embed_arr in sample_dataset.take(1):
  print("Music Embedding shape:", sample_music_embed_arr.shape)
  print("Keypoints shape:", sample_keypoints_arr.shape)

# Batch the sequences, padding each batch along the time axis.
sample_batch_size = 64
_sample_padded_shapes = ([None, 17*3], [None, 1024])
sample_dataset = sample_dataset.padded_batch(
    sample_batch_size, padded_shapes=_sample_padded_shapes)

# Keep the first padded batch as numpy arrays for the plot below.
for _kp_batch, _music_batch in sample_dataset.take(1):
  sample_batch_keypoints_arr = _kp_batch.numpy()
  sample_batch_music_embed_arr = _music_batch.numpy()
  print("Batch Keypoints shape:", sample_batch_keypoints_arr.shape)
  print("Batch Music Embedding shape:", sample_batch_music_embed_arr.shape)

In [None]:
# First feature channel of the padded batch: music embeddings (left) and keypoints (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), sharex=True, sharey=True)

# left panel: music embedding, fixed colour range [-1, 1]
_music_feat0 = sample_batch_music_embed_arr[:, :, 0]
ax1_plot = ax1.imshow(_music_feat0, cmap='coolwarm', vmin=-1,
                      vmax=1, aspect='auto')
plt.colorbar(ax1_plot, ax=ax1)
ax1.set_title('feat1 of batch music embeddings')
ax1.set_xlabel('Time')
ax1.set_ylabel('batch dim')

# right panel: keypoints, colour range clipped to the 1%-99% quantiles of the whole batch
_kp_feat0 = sample_batch_keypoints_arr[:, :, 0]
_kp_vmin = np.quantile(sample_batch_keypoints_arr, 0.01)
_kp_vmax = np.quantile(sample_batch_keypoints_arr, 0.99)
ax2_plot = ax2.matshow(_kp_feat0, cmap='coolwarm', norm='linear',
                       vmin=_kp_vmin, vmax=_kp_vmax, aspect='auto')
plt.colorbar(ax2_plot, ax=ax2)
ax2.set_xlabel('Time')
ax2.set_title('feat1 of batch keypoints')

plt.tight_layout()
plt.show()

In [None]:
# load music embeddings + keypoints
# NOTE(review): load_n=200 considers at most 200 keypoint files, not the whole
# dataset — confirm whether "all" is intended here.
music_embeds, keypoints = load_dataset_arrs(str(MUSIC_FOLDER),
                                            str(KEYPOINTS_FOLDER), 200)
len(music_embeds), len(keypoints)

In [None]:
# Split train/test by unique songs, so no song appears in both splits.
music_ids = list(music_embeds.keys())
n_music = len(music_ids)
_n_train_music = int(n_music*0.8)

# Draw 80% of the song ids without replacement; the remaining ids form the test set.
key = jax.random.PRNGKey(42)
train_music_ids = np.asarray(
    jax.random.choice(key, jnp.array(music_ids), shape=[_n_train_music], replace=False))
test_music_ids = np.setdiff1d(music_ids, train_music_ids)

# Route every keypoint sequence to the split that owns its song.
train_keypoints = []
test_keypoints = []
for _m_id, _kp_arr in keypoints:
  if _m_id in train_music_ids:
    train_keypoints.append((_m_id, _kp_arr))
  if _m_id in test_music_ids:
    test_keypoints.append((_m_id, _kp_arr))

print(f'train: music_ids: {len(train_music_ids)}/{n_music}, n_keypoints: {len(train_keypoints)}')
print(f'test: music_ids: {len(test_music_ids)}/{n_music}, n_keypoints: {len(test_keypoints)}')

In [None]:
from functools import partial

tf.random.set_seed(42)  # seed TensorFlow's global random generator

train_epochs = 40  # number of training epochs
epoch_steps = 100  # batches per epoch
batch_size = 128  # sequences per padded batch

def embed_keypoints_gen(music_embeds, keypoints):
  """Yield one (flattened keypoints, music embedding) pair per sequence.

  Args:
    music_embeds: mapping from music id to a (frames, features) embedding array.
    keypoints: iterable of `(music_id, keypoints3d_arr)` tuples, where
      `keypoints3d_arr` has shape (frames, 17, 3).

  Yields:
    `(keypoints, music)` with shapes (n, 51) and (n, features), where n is the
    number of frames both arrays have in common.
  """
  for music_id, keypoints3d_arr in keypoints:
    music_embed_arr = music_embeds[music_id]
    # Clip BOTH streams to their common length. Clipping only the music left
    # the two outputs with different lengths whenever the music embedding was
    # shorter than the motion.
    n_frames = min(keypoints3d_arr.shape[0], music_embed_arr.shape[0])
    yield (
        keypoints3d_arr[:n_frames].reshape(n_frames, 17*3),
        music_embed_arr[:n_frames],
    )

# Element spec shared by both splits: (frames, 51) keypoints and (frames, 1024) music embeddings.
_Xy_signature = (
    tf.TensorSpec(shape=(None, 17*3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 1024), dtype=tf.float32),
)
# Per-batch padding: the time axis is padded, feature sizes are fixed.
_Xy_padded_shapes = ([None, 17*3], [None, 1024])

train_Xy_ds = tf.data.Dataset.from_generator(
    partial(embed_keypoints_gen, music_embeds, train_keypoints),
    output_signature=_Xy_signature)
test_Xy_ds = tf.data.Dataset.from_generator(
    partial(embed_keypoints_gen, music_embeds, test_keypoints),
    output_signature=_Xy_signature)

# Training stream: repeat, shuffle with a 500-element buffer, pad into batches,
# then cap the stream at the total number of training steps.
train_Xy_ds = train_Xy_ds.repeat().shuffle(500)
train_Xy_ds = train_Xy_ds.padded_batch(batch_size, padded_shapes=_Xy_padded_shapes)
train_Xy_ds = train_Xy_ds.take(epoch_steps * train_epochs).prefetch(tf.data.AUTOTUNE)

# Test stream: a single ordered pass, padded into batches.
test_Xy_ds = test_Xy_ds.padded_batch(batch_size, padded_shapes=_Xy_padded_shapes)
test_Xy_ds = test_Xy_ds.prefetch(tf.data.AUTOTUNE)

In [None]:
# Pull one padded training batch; its shapes are reused below to size the model.
batch_X, batch_y = next(iter(train_Xy_ds))
print(batch_X.shape, batch_y.shape)

##### Diffusion transformer

In [None]:
# Guided Diffusion encoder-only transformer with inputs (X, y, t) where X is the noisy input as a sequence, y is a conditioning sequence of vectors, t is the time step.
# Implementation based on the DiT described in https://arxiv.org/abs/2212.09748 in jax/flax.nnx with a lot of modifications

# Components:
# FeedForward: Dense -> ReLU -> Dropout -> Dense -> Dropout
# TransformerEncoderBlock: in -> LayerNorm -> MultiHeadSelfAttention -> Dropout -> Add(input) -> LayerNorm -> FeedForward -> Add(input) -> out
# ConditionalTransformerEncoderBlock:
#   (in_X, in_y_emb, t_sin_enc) -> \
#     - t_sin_enc -> Dense -> (t_u, t_s)
#     - in_X -> LayerNorm -> ScaleShift(_, t_u, t_s) -> MultiHeadSelfAttention -> Dropout -> Add(_, in_X) -> \
#       -> LayerNorm -> ScaleShift(_, t_u, t_s) -> MultiHeadCrossAttention(_, in_y_emb) -> Dropout -> Add(_, in_X) -> \
#         -> LayerNorm -> ScaleShift(_, t_u, t_s) -> FeedForward -> Add(_, in_X) -> out_X
#     -> out_X
# TransformerEncoder:
#   (X, y, t) -> \
#     - t -> get_sinusoidal_encodings(t) -> t_sin_enc
#     - y -> Dense -> PositionalEncoding -> [ TransformerEncoderBlock x n ] -> LayerNorm -> y_emb
#     - X -> Dense -> PositionalEncoding -> [ ConditionalTransformerEncoderBlock(_, y_emb, t_sin_enc) x n ] -> LayerNorm -> Dense -> out_X
#     -> out_X


import jax
import jax.numpy as jnp
import flax
from flax import nnx
from functools import partial


class FeedForward(nnx.Module):
  """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear -> Dropout.

  `out_feats` and `hidden_feats` fall back to `in_feats` when they are left
  unset (or passed as any falsy value). One Dropout layer is reused after
  both linear layers.
  """

  def __init__(self, *, in_feats, out_feats=None, hidden_feats=None, dropout=0.0, rngs: nnx.Rngs):
    width_out = out_feats if out_feats else in_feats
    width_hidden = hidden_feats if hidden_feats else in_feats
    self.dense1 = nnx.Linear(in_feats, width_hidden, rngs=rngs)
    self.dense2 = nnx.Linear(width_hidden, width_out, rngs=rngs)
    self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

  def __call__(self, x, train: bool = False):
    """Apply the MLP to `x`; dropout is only active when `train` is True."""
    deterministic = not train
    hidden = self.dropout(nnx.relu(self.dense1(x)), deterministic=deterministic)
    return self.dropout(self.dense2(hidden), deterministic=deterministic)


def get_sinusoidal_encodings(t, embedding_dim):
  """Encode scalar positions/times as sinusoidal features.

  Args:
    t: array of shape (n, 1) holding the values to encode.
    embedding_dim: requested feature size. The output has
      `2 * (embedding_dim // 2)` features, i.e. one fewer than requested when
      `embedding_dim` is odd.

  Returns:
    Array of shape (n, 2 * (embedding_dim // 2)): the sines of all frequencies
    followed by the cosines of all frequencies.
  """
  half_dim = embedding_dim // 2
  # Frequencies decay geometrically from 1 down to 1 / embedding_dim. Guard the
  # denominator: with half_dim == 1 it would be 0 and the result would be NaN.
  log_step = jnp.log(embedding_dim) / max(half_dim - 1, 1)
  freqs = jnp.exp(jnp.arange(half_dim) * -log_step)
  # (n, 1) * (1, half_dim) -> (n, half_dim)
  angles = half_dim * t * freqs[None, :]
  return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


class TransformerEncoderBlock(nnx.Module):
  """Pre-LayerNorm transformer encoder block.

  in -> LayerNorm -> self-attention -> Dropout -> Add(residual)
     -> LayerNorm -> FeedForward -> Add(residual) -> out

  The attention layers are constructed with `deterministic=True`; dropout is
  applied by this block's own Dropout layer, controlled by `train`.
  """

  def __init__(self, *, in_feats, num_heads=8, dropout=0.0, ff_feats_x=4, rngs: nnx.Rngs):
    self.attention = nnx.MultiHeadAttention(
      num_heads=num_heads,
      in_features=in_feats,
      out_features=in_feats,
      dropout_rate=dropout,
      decode=False,
      deterministic=True,
      rngs=rngs)
    # hidden width of the feed-forward is `ff_feats_x` times the model width
    self.ff = FeedForward(in_feats=in_feats, hidden_feats=in_feats * ff_feats_x, dropout=dropout, rngs=rngs)
    self.layer_norm1 = nnx.LayerNorm(in_feats, rngs=rngs)
    self.layer_norm2 = nnx.LayerNorm(in_feats, rngs=rngs)
    self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

  def __call__(self, x, mask=None, train: bool = False):
    """Apply the block to `x` (batch, seq_len, in_feats); `mask` is forwarded to the attention."""
    # Self-attention sub-layer. The sub-layer output is kept in `h` so the
    # untouched input `x` is still available for the residual connection;
    # overwriting `x` and then computing `x + x` only doubled the sub-layer
    # output and dropped the skip path.
    h = self.layer_norm1(x)
    h = self.attention(h, mask=mask)
    h = self.dropout(h, deterministic=not train)
    x = x + h
    # Feed-forward sub-layer, same residual pattern.
    h = self.layer_norm2(x)
    h = self.ff(h, train=train)
    x = x + h
    return x


class ConditionalTransformerEncoderBlock(nnx.Module):
  """Encoder block conditioned on a time encoding and a conditioning sequence.

  Three pre-LayerNorm sub-layers, each followed by a residual connection:
    1. self-attention over x
    2. cross-attention from x (queries) to y_enc (keys/values)
    3. feed-forward
  Before each sub-layer the normalised input is modulated with the pair
  (t_u, t_s) obtained from a dense projection of the time encoding:
  `(h + t_u) * t_s`.

  The attention layers are constructed with `deterministic=True`; dropout is
  applied by this block's own Dropout layer, controlled by `train`.
  """

  def __init__(self, *, in_feats, t_feats, num_heads=8, dropout=0.0, ff_feats_x=4, rngs: nnx.Rngs):
    self.attention = nnx.MultiHeadAttention(
      num_heads=num_heads,
      in_features=in_feats,
      out_features=in_feats,
      dropout_rate=dropout,
      decode=False,
      deterministic=True,
      rngs=rngs)
    self.cross_attention = nnx.MultiHeadAttention(
      num_heads=num_heads,
      in_features=in_feats,
      out_features=in_feats,
      dropout_rate=dropout,
      decode=False,
      deterministic=True,
      rngs=rngs)
    self.ff = FeedForward(in_feats=in_feats, hidden_feats=in_feats * ff_feats_x, dropout=dropout, rngs=rngs)
    self.layer_norm1 = nnx.LayerNorm(in_feats, rngs=rngs)
    self.layer_norm2 = nnx.LayerNorm(in_feats, rngs=rngs)
    self.layer_norm3 = nnx.LayerNorm(in_feats, rngs=rngs)
    self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)
    # one projection produces both modulation vectors, split in __call__
    self.dense_t_u_s = nnx.Linear(t_feats, in_feats * 2, rngs=rngs)

  def __call__(self, x, y_enc, t_sin_enc, mask=None, train: bool = False):
    """Apply the block and return the updated x with the same shape as the input x."""
    # x: (batch_size, seq_len, in_feats)
    # y_enc: (batch_size, seq_len, in_feats)
    # t_sin_enc: (batch_size, t_feats)
    # mask: (batch_size, seq_len, seq_len)

    t_u_s = self.dense_t_u_s(t_sin_enc)[:, None, :]  # (batch_size, 1, in_feats * 2)
    t_u, t_s = jnp.split(t_u_s, 2, axis=-1)

    # Each sub-layer output is kept in `h` so the running stream `x` is still
    # available for the residual connection; overwriting `x` and then
    # computing `x + x` only doubled the sub-layer output and dropped the
    # skip path.

    # 1. self-attention
    h = self.layer_norm1(x)
    h = (h + t_u) * t_s
    h = self.attention(h, mask=mask)
    h = self.dropout(h, deterministic=not train)
    x = x + h
    # 2. cross-attention: queries from x, keys/values from y_enc
    h = self.layer_norm2(x)
    h = (h + t_u) * t_s
    h = self.cross_attention(h, y_enc, mask=mask)
    h = self.dropout(h, deterministic=not train)
    x = x + h
    # 3. feed-forward
    h = self.layer_norm3(x)
    h = (h + t_u) * t_s
    h = self.ff(h, train=train)
    x = x + h
    return x


class AddPositionalEncoding(nnx.Module):
  """Adds a fixed sinusoidal positional table to a sequence of features.

  The table is built once in the constructor from the positions
  0 .. max_len - 1 divided by `max_len`, encoded with
  `get_sinusoidal_encodings`.
  """

  def __init__(self, *, max_len=512, in_feats=64, rngs: nnx.Rngs):
    self.max_len = max_len
    self.in_feats = in_feats
    self.rngs = rngs
    positions = jnp.arange(max_len).reshape(-1, 1) / max_len  # (max_len, 1)
    table = get_sinusoidal_encodings(positions, in_feats)
    self.pos_enc = table[None, ...]  # (1, max_len, in_feats)

  def __call__(self, x):
    """Return `x` plus the first `x.shape[1]` rows of the positional table."""
    # x: (batch_size, seq_len, in_feats)
    # self.pos_enc: (1, max_len, in_feats)
    seq_len = x.shape[1]
    return x + self.pos_enc[:, :seq_len, :]


class GuidedDiff_TransfomerEncoder(nnx.Module):
  """Encoder-only transformer over a noisy sequence X, conditioned on (y, t).

  - y is projected, position-encoded and passed through `num_blocks_y`
    `TransformerEncoderBlock`s, then layer-normalised.
  - t is turned into a sinusoidal encoding of size `t_feats`.
  - X is projected, position-encoded and passed through `num_blocks_X`
    `ConditionalTransformerEncoderBlock`s (conditioned on the encoded y and
    t), then layer-normalised and projected to `out_feats`.

  Timesteps whose y features all equal `padding_max_value` are treated as
  padding: they are masked in the attention layers and zeroed in the output.
  """

  def __init__(self, *,
               in_feats_X, in_feats_y,
               Xy_feats, out_feats, t_feats=64,
               num_heads=8, num_blocks_X=4, num_blocks_y=2,
               dropout=0.0, ff_feats_x=4, max_len=1024,
               padding_max_value=0,
               rngs: nnx.Rngs):
    # True for timesteps that have at least one non-padding feature
    self.get_masked_timesteps = lambda y: (y != padding_max_value).sum(axis=-1) > 0
    self.dense_X = nnx.Linear(in_feats_X, Xy_feats, rngs=rngs)
    self.dense_y = nnx.Linear(in_feats_y, Xy_feats, rngs=rngs)
    self.layer_norm_out_X = nnx.LayerNorm(Xy_feats, rngs=rngs)
    self.layer_norm_out_y = nnx.LayerNorm(Xy_feats, rngs=rngs)
    self.get_sin_enc = partial(get_sinusoidal_encodings, embedding_dim=t_feats)
    self.pos_enc_X = AddPositionalEncoding(max_len=max_len, in_feats=Xy_feats, rngs=rngs)
    self.pos_enc_y = AddPositionalEncoding(max_len=max_len, in_feats=Xy_feats, rngs=rngs)
    # Kept as a plain list (like `encoders_X`) and called in an explicit loop:
    # nnx.Sequential forwards keyword arguments only to its first layer, so
    # `mask` and `train` never reached the second and later blocks.
    self.encoders_y = [TransformerEncoderBlock(
      in_feats=Xy_feats, num_heads=num_heads,
      dropout=dropout, ff_feats_x=ff_feats_x, rngs=rngs)
      for _ in range(num_blocks_y)]
    self.encoders_X = [ConditionalTransformerEncoderBlock(
      in_feats=Xy_feats, t_feats=t_feats, num_heads=num_heads,
      dropout=dropout, ff_feats_x=ff_feats_x, rngs=rngs)
      for _ in range(num_blocks_X)]
    self.dense_out = nnx.Linear(Xy_feats, out_feats, rngs=rngs)

  def __call__(self, x, y, t, train: bool = False):
    """Run the model and return an array of shape (batch_size, seq_len, out_feats)."""
    # x: (batch_size, seq_len_X, in_feats_X)
    # y: (batch_size, seq_len_y, in_feats_y)
    # t: (batch_size,)
    # mask inputs where all y values are padding_max_value

    y_masked = self.get_masked_timesteps(y)
    mask = flax.linen.make_attention_mask(
      y_masked, y_masked, dtype=jnp.float32
    )

    x = self.dense_X(x)
    x = self.pos_enc_X(x)

    # encode the conditioning sequence; every block receives mask and train
    y = self.dense_y(y)
    y = self.pos_enc_y(y)
    for encoder in self.encoders_y:
      y = encoder(y, mask=mask, train=train)
    y = self.layer_norm_out_y(y)

    # any t with batch_size elements is accepted; it is flattened to (batch_size, 1)
    t = t.reshape(-1, 1)
    t_sin_enc = self.get_sin_enc(t)

    for encoder in self.encoders_X:
      x = encoder(x, y, t_sin_enc, mask=mask, train=train)

    x = self.layer_norm_out_X(x)

    # output -> Dense
    x = self.dense_out(x)

    # mask output timesteps
    x = x * y_masked[:, :, None]

    return x



In [None]:
# test the model: build small random inputs for a smoke test

sample_batch_size = 3
sample_max_sequence_length = 4
sample_feats_X = 6
sample_feats_y = 4
sample_feats_Xy = 10
sample_feats_out = 2
sample_t_feats = 16

key = jax.random.PRNGKey(2)
# One independent subkey per random draw: reusing a single key for several
# draws makes the resulting samples statistically dependent.
key_X, key_y, key_t, key_len = jax.random.split(key, 4)
X_sample = jax.random.normal(key_X, (sample_batch_size, sample_max_sequence_length, sample_feats_X))
y_sample = jax.random.normal(key_y, (sample_batch_size, sample_max_sequence_length, sample_feats_y))
t_sample = jax.random.uniform(key_t, (sample_batch_size,))

# mask some y time steps: each sequence keeps a random number (>= 1) of leading
# timesteps and the rest are set to 0, the padding value
y_sample_len = jax.random.randint(key_len, (sample_batch_size,), minval=1, maxval=sample_max_sequence_length + 1)
y_sample_mask = jnp.arange(sample_max_sequence_length)[None, :] < y_sample_len[:, None]
y_sample = jnp.where(y_sample_mask[:, :, None], y_sample, 0.0)

X_sample.shape, y_sample.shape, t_sample.shape


In [None]:
# Recompute, outside the model, the padding flags and attention mask that the
# model derives from y (with padding value 0).
y = y_sample
y_masked = (y != 0).sum(axis=-1) > 0  # True where a timestep has any non-zero feature
mask = flax.linen.make_attention_mask(
  y_masked, y_masked, dtype=jnp.float32
)
mask

In [None]:
# display the per-timestep validity flags
y_masked

In [None]:
# Toy model for the smoke test, sized by the sample_* constants. The time
# encoding width now uses `sample_t_feats`, which was defined for this purpose
# but never used (the value 64 was hard-coded instead).
sample_model = GuidedDiff_TransfomerEncoder(
  in_feats_X=sample_feats_X, in_feats_y=sample_feats_y,
  Xy_feats=sample_feats_Xy, out_feats=sample_feats_out, t_feats=sample_t_feats,
  num_heads=2, num_blocks_X=2, num_blocks_y=2,
  dropout=0.0, ff_feats_x=4, max_len=sample_max_sequence_length,
  padding_max_value=0,
  rngs=nnx.Rngs(42),
)
nnx.display(sample_model)

In [None]:
# forward pass on the toy inputs (train defaults to False)
X_sample_out = sample_model(X_sample, y_sample, t_sample)
X_sample_out.shape

In [None]:
# display the toy output; the model multiplies its output by the y-derived timestep flags
X_sample_out

In [None]:
# shapes of the real training batch used to size the model below
batch_X.shape, batch_y.shape

In [None]:
# init actual model

hidden_Xy_dim = 128
num_heads = 8
num_blocks_X = 4
num_blocks_y = 2
dropout = 0.1
ff_feats_x = 4
# Size the positional table for the longest sequence in the loaded data, not
# just for the first batch: padded_batch pads every batch to its own longest
# sequence, and the positional table cannot be added to inputs longer than
# max_len.
max_len = max([batch_X.shape[1]] + [kp_arr.shape[0] for _, kp_arr in keypoints])

model = GuidedDiff_TransfomerEncoder(
  in_feats_X=batch_X.shape[-1], in_feats_y=batch_y.shape[-1],
  Xy_feats=hidden_Xy_dim, out_feats=batch_X.shape[-1], t_feats=64,
  num_heads=num_heads, num_blocks_X=num_blocks_X, num_blocks_y=num_blocks_y,
  dropout=dropout, ff_feats_x=ff_feats_x, max_len=max_len,
  padding_max_value=0,
  rngs=nnx.Rngs(42),
)
nnx.display(model)

In [None]:
# smoke test: forward pass on one real batch with t = 1 for every sample (train defaults to False)
batch_X_pred = model(batch_X.numpy(), batch_y.numpy(), jnp.ones(batch_X.shape[0]))
batch_X_pred.shape

##### Guided conditional flow matching loss

In [None]:
import optax

def guided_cfm_loss_fn(model, batch_z, batch_y, batch_t, batch_e, train):
  """Compute the conditional flow matching loss for one batch.

  Builds the interpolant x_t = t * z + (1 - t) * e between the noise
  `batch_e` (t = 0) and the data `batch_z` (t = 1), evaluates the model at
  (x_t, y, t) and returns the mean squared error against the target z - e.

  NOTE(review): the mean runs over every element, padded timesteps included —
  confirm whether those should be excluded from the loss.
  """
  noisy_x = batch_t * batch_z + (1 - batch_t) * batch_e
  velocity_target = batch_z - batch_e
  velocity_pred = model(noisy_x, batch_y, batch_t, train=train)
  return optax.losses.squared_error(velocity_pred, velocity_target).mean()

@nnx.jit
def train_step_guided(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch_z, batch_y, key):
  """Train for a single step.

  Samples a time t ~ Unif[0, 1] per sequence and noise e ~ N(0, 1) with the
  shape of `batch_z`, computes the flow matching loss and its gradients, and
  updates `metrics` and `optimizer` in place.
  """
  # Use a separate subkey for each draw: sampling t and e from the same key
  # makes the two samples statistically dependent.
  key_t, key_e = jax.random.split(key)
  # t: [batch_size, 1, ..., 1] (one trailing 1 per remaining axis of batch_z), e: batch_z.shape
  batch_t = jax.random.uniform(key_t, [batch_z.shape[0]] + ([1]*(batch_z.ndim-1)))
  batch_e = jax.random.normal(key_e, batch_z.shape)

  loss, grads = nnx.value_and_grad(guided_cfm_loss_fn, argnums=0)(model, batch_z, batch_y, batch_t, batch_e, train=True)
  metrics.update(loss=loss)  # In-place updates.
  optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step_guided(model, metrics: nnx.MultiMetric, batch_z, batch_y, key):
  """Evaluate a single batch and record its loss in `metrics` (in place).

  Samples a time t ~ Unif[0, 1] per sequence and noise e ~ N(0, 1) with the
  shape of `batch_z`, then computes the flow matching loss with train=False.
  """
  # Use a separate subkey for each draw: sampling t and e from the same key
  # makes the two samples statistically dependent.
  key_t, key_e = jax.random.split(key)
  # t: [batch_size, 1, ..., 1] (one trailing 1 per remaining axis of batch_z), e: batch_z.shape
  batch_t = jax.random.uniform(key_t, [batch_z.shape[0]] + ([1]*(batch_z.ndim-1)))
  batch_e = jax.random.normal(key_e, batch_z.shape)

  loss = guided_cfm_loss_fn(model, batch_z, batch_y, batch_t, batch_e, train=False)
  metrics.update(loss=loss)  # In-place updates.

##### Train

In [None]:
learning_rate = 1e-2  # peak learning rate of the one-cycle schedule below
momentum = 0.9  # passed positionally as the second argument of optax.adamw below

# optax.schedules.cosine_onecycle_schedule(transition_steps: int, peak_value: float, pct_start: float = 0.3, div_factor: float = 25.0, final_div_factor: float = 10000.0)
# NOTE(review): the schedule spans 25 * epoch_steps steps, while pct_start is
# derived from train_epochs and the training stream is capped at
# train_epochs * epoch_steps batches — confirm these are meant to differ.
lr_schedule = optax.schedules.cosine_onecycle_schedule(
  (25)*epoch_steps, peak_value=learning_rate,
  pct_start=5/train_epochs, div_factor=1000, final_div_factor=10.0)

# AdamW wrapped with inject_hyperparams, driven by the schedule above.
optimizer = nnx.Optimizer(
  model,
  optax.inject_hyperparams(optax.adamw)(lr_schedule, momentum, weight_decay=1e-7),
)
# running average of the values passed as `loss=` to metrics.update
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)
