### Data

In [1]:
import sys
import os

# Put the notebook's parent directory on the import path (only once), so that
# project-local packages such as ``data_loaders`` can be imported by later cells.
path = os.path.abspath(os.pardir)
if path not in sys.path:
    sys.path.append(path)

In [2]:
# Project-local dataset class; importable because the parent directory was
# appended to sys.path in the first cell.
from data_loaders.dataset import MujocoMotionDataset
# Build the dataset from "../data/motions" (relative to the notebook's working
# directory). Loading semantics live in data_loaders.dataset — not shown here.
dataset = MujocoMotionDataset("../data/motions")
# Bare final expression so the notebook displays the dataset length
# (15 in the saved output below).
len(dataset)

15

In [3]:
# Inspect the first item; it unpacks into three values (position tensor,
# velocity tensor, string label, per the saved output below).
first_motion = dataset[0]
pos, vel, label = first_motion
# Bare final expression so the notebook displays shapes and label.
# Saved output: pos [121, 35], vel [121, 34], label 'roll'.
# NOTE(review): pos and vel have different widths (35 vs 34), while
# MotionTransformer projects the first-frame pose and the velocities from the
# same per-frame size — confirm how these tensors are packed before training.
pos.shape, vel.shape, label

(torch.Size([121, 35]), torch.Size([121, 34]), 'roll')

### Model

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first tensors.

    Builds a fixed ``pe`` table of shape [max_len, 1, d_model]. Even feature
    columns hold ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)``, with
    frequencies ``w_k = 10000 ** (-2k / d_model)``. The table is registered as a
    buffer, so it follows the module across devices and is stored in the state
    dict, but it is not a trainable parameter.

    Args:
        d_model: Feature size of the encoding. May be odd; the last column is
            then a sine term without a cosine partner.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # ceil(d_model / 2) frequencies: one per sine (even) column.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) cosine (odd) columns. Drop the extra
        # frequency when d_model is odd; otherwise the assignment shape-mismatches.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for positions ``0 .. x.shape[0]-1`` to ``x``, then apply dropout.

        Args:
            x: Sequence-first tensor, expected [seq_len, batch, d_model]. The
                [seq_len, 1, d_model] slice of the table broadcasts over batch.
                ``seq_len`` must not exceed ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    """Embed integer timesteps by looking them up in a shared sinusoidal table.

    The rows ``sequence_pos_encoder.pe[timesteps]`` pass through a two-layer MLP
    (Linear -> SiLU -> Linear, every layer ``latent_dim`` wide). The first two
    axes of the result are then swapped.

    Args:
        latent_dim: Width of the embedding and of both MLP layers.
        sequence_pos_encoder: Module exposing a ``pe`` tensor whose last
            dimension is ``latent_dim``. It is stored by reference (shared),
            not copied.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        width = latent_dim
        layers = [nn.Linear(width, width), nn.SiLU(), nn.Linear(width, width)]
        self.time_embed = nn.Sequential(*layers)

    def forward(self, timesteps):
        """Project the table rows selected by ``timesteps`` and swap axes 0 and 1.

        With a [max_len, 1, latent_dim] table (the layout PositionalEncoding
        builds), a [batch] index tensor yields [1, batch, latent_dim].
        """
        rows = self.sequence_pos_encoder.pe[timesteps]
        hidden = self.time_embed(rows)
        return hidden.permute(1, 0, 2)
    
class MotionTransformer(nn.Module):
    """Transformer encoder over a motion sequence, conditioned on an integer timestep.

    Each frame is flattened to ``input_dim = njoints * nfeats`` features. The
    first frame is treated as a pose and the remaining frames as velocities;
    each group has its own input and output projection. A timestep-embedding
    token is prepended before the encoder and dropped from its output.

    Args:
        njoints: Number of joints.
        nfeats: Features per joint.
        latent_dim: Transformer model width ``d``.
        ff_size: Feed-forward hidden size of each encoder layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout for the positional encoding and the encoder layers.
        activation: Activation passed to ``nn.TransformerEncoderLayer``.
    """

    def __init__(self, njoints, nfeats, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, activation="gelu"):
        super().__init__()

        self.njoints = njoints
        self.nfeats = nfeats
        self.input_dim = njoints * nfeats  # per-frame size after flattening joints
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.dropout = dropout

        # forward() flattens every frame to input_dim features, so the projections
        # must consume input_dim. nfeats only coincides with it when njoints == 1.
        self.poseEmbedding = nn.Linear(self.input_dim, self.latent_dim)
        self.velEmbedding = nn.Linear(self.input_dim, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        # Reads the same positional-encoding table to embed timesteps.
        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        # Sequence-first encoder (batch_first left at its default, False),
        # matching the [seq, batch, d] layout used in forward().
        encoder_layers = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=num_heads,
                                                    dim_feedforward=ff_size, dropout=dropout, activation=activation)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        # Output projections back to the flattened per-frame size.
        self.poseFinal = nn.Linear(self.latent_dim, self.input_dim)
        self.velFinal = nn.Linear(self.latent_dim, self.input_dim)

    def forward(self, x, timesteps, y=None):
        """Encode ``x`` conditioned on ``timesteps``.

        Args:
            x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper.
            timesteps: [batch_size] (int) indices into the positional-encoding table.
            y: Unused; accepted for interface compatibility.

        Returns:
            Tensor shaped like ``x``: [batch_size, njoints, nfeats, n_frames].
        """
        bs, _, _, n_frames = x.shape
        emb = self.embed_timestep(timesteps)  # [1, bs, d]

        # Input process: sequence-first, joints*features flattened per frame.
        x = x.permute(3, 0, 1, 2).reshape(n_frames, bs, self.input_dim)  # [n_frames, bs, input_dim]
        first_pose = self.poseEmbedding(x[[0]])  # [1, bs, d]
        vel = self.velEmbedding(x[1:])  # [n_frames-1, bs, d]
        x = torch.cat((first_pose, vel), dim=0)  # [n_frames, bs, d]

        # Prepend the timestep token, add positional encoding, encode, then drop
        # the token again. No padding mask is applied.
        xseq = torch.cat((emb, x), dim=0)  # [n_frames+1, bs, d]
        xseq = self.sequence_pos_encoder(xseq)  # [n_frames+1, bs, d]
        output = self.transformer_encoder(xseq)[1:]  # [n_frames, bs, d]

        # Output process: project back to the flattened per-frame size.
        first_pose = self.poseFinal(output[[0]])  # [1, bs, input_dim]
        vel = self.velFinal(output[1:])  # [n_frames-1, bs, input_dim]
        output = torch.cat((first_pose, vel), dim=0)  # [n_frames, bs, input_dim]

        # Undo the input flattening on the *output* (not the latent-space `x`).
        return output.permute(1, 2, 0).reshape(bs, self.njoints, self.nfeats, n_frames)