In [17]:
# Shell escape: assigns PYTHONPATH inside the shell command spawned for this line.
# NOTE(review): presumably meant to make the project-local packages (utils, train,
# data_loaders, diffusion) importable. An assignment made in a `!` subshell likely does not
# reach the running kernel -- confirm, and consider `%env PYTHONPATH=.` or editing `sys.path`
# if the imports below really depend on it.
!PYTHONPATH=.

In [18]:
import os
import json
from torch.utils.data import DataLoader
from utils.fixseed import fixseed
from utils.parser_util import train_args
from utils import dist_util
from train.training_loop import TrainLoop
from data_loaders.get_data import get_dataset_loader
from utils.model_util import create_model_and_diffusion
from train.train_platforms import ClearmlPlatform, TensorboardPlatform, NoPlatform  # required for the eval operation

## Creating Dataloader

In [19]:
from torch.utils.data import TensorDataset, DataLoader
import os
import numpy as np
import torch
import torch.nn.functional as F

# Joint indices kept from the motion features; 8 joints are reshaped to 48 values per frame
# below (presumably 6 values per joint -- TODO confirm against the feature extractor).
keep_joints = [13, 14, 16, 17, 18, 19, 20, 21]

# --- Motion (diffusion target) --------------------------------------------------------------
pose_tensor = torch.tensor(
    np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Motion_feature.npy")
)
# Select the kept joints, flatten them to 48 values per frame, then zero-pad the last axis by
# 31 so the motion has 79 channels -- the same width as the audio+beat condition (64 + 15).
filtered = F.pad(
    pose_tensor[:, :, keep_joints, :].reshape(-1, 120, 48),
    (0, 31),
    mode="constant",
    value=0,
)
# Final layout: [N, 79, 1, 120].
pose_tensor = filtered.reshape(-1, 1, 120, 79).permute(0, 3, 1, 2)

# --- Conditioning (audio + beat) ------------------------------------------------------------
audio_tensor = torch.tensor(
    np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Audio_feature.npy")
).reshape(-1, 1, 120, 64)

beat_tensor = torch.tensor(
    np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Beat_feature.npy")
).reshape(-1, 1, 120, 15)

# Concatenate along the feature axis (64 + 15 = 79) and move it forward: [N, 79, 1, 120].
full_tensor = torch.cat((audio_tensor, beat_tensor), dim=-1).permute(0, 3, 1, 2)

# Each sample is a (condition, motion) pair. No held-out split is built here: every sample
# goes into the training loader.
dataset = TensorDataset(full_tensor, pose_tensor)
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=5)

## Model Framework

In [20]:
from argparse import ArgumentParser
import argparse

def _str2bool(value):
    """Convert a command-line token to a bool.

    argparse's ``type=bool`` calls ``bool(token)``, which is True for every non-empty string,
    so ``--cuda False`` used to leave CUDA enabled. This parser understands the usual
    spellings instead.

    Args:
        value: The raw token (or an already-parsed bool).

    Returns:
        The parsed boolean. The empty string maps to False, as it did under ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def add_base_options(parser):
    """Register the 'base' argument group (device, seed, batch size, debug switches) on ``parser``."""
    group = parser.add_argument_group('base')
    group.add_argument("--cuda", default=True, type=_str2bool, help="Use cuda device, otherwise use CPU.")
    group.add_argument("--device", default=0, type=int, help="Device id to use.")
    group.add_argument("--seed", default=10, type=int, help="For fixing random seed.")
    group.add_argument("--batch_size", default=64, type=int, help="Batch size during training.")
    group.add_argument("--short_db", action='store_true', help="Load short babel for debug.")
    # NOTE(review): this help text is identical to --short_db's and looks copy-pasted.
    group.add_argument("--cropping_sampler", action='store_true', help="Load short babel for debug.")


def add_data_options(parser):
    """Register the 'dataset' argument group (dataset name and data directory) on ``parser``."""
    group = parser.add_argument_group('dataset')
    group.add_argument("--dataset", default='humanml', choices=['humanml', 'amass', 'babel'], type=str,
                       help="Dataset name (choose from list).")
    group.add_argument("--data_dir", default="", type=str,
                       help="If empty, will use defaults according to the specified dataset.")


def add_model_options(parser):
    """Register the 'model' argument group (architecture, sizes, loss weights) on ``parser``."""
    group = parser.add_argument_group('model')
    group.add_argument("--arch", default='trans_enc',
                       choices=['trans_enc', 'trans_dec', 'gru'], type=str,
                       help="Architecture types as reported in the paper.")
    group.add_argument("--emb_trans_dec", default=False, type=_str2bool,
                       help="For trans_dec architecture only, if true, will inject condition as a class token"
                            " (in addition to cross-attention).")
    group.add_argument("--layers", default=8, type=int,
                       help="Number of layers.")
    group.add_argument("--latent_dim", default=512, type=int,
                       help="Transformer/GRU width.")
    group.add_argument("--cond_mask_prob", default=.1, type=float,
                       help="The probability of masking the condition during training."
                            " For classifier-free guidance learning.")
    group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.")
    group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.")
    group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.")
    group.add_argument("--use_tta", action='store_true', help="Time To Arrival position encoding")  # FIXME REMOVE?
    group.add_argument("--concat_trans_emb", action='store_true', help="Concat transition emb, else append after linear")  # FIXME REMOVE?
    group.add_argument("--trans_emb", action='store_true', help="Allow transition embedding")  # FIXME REMOVE?

    group.add_argument("--context_len", default=0, type=int, help="If larger than 0, will do prefix completion.")
    group.add_argument("--pred_len", default=0, type=int, help="If context_len larger than 0, will do prefix completion. If pred_len will not be specified - will use the same length as context_len")


def add_diffusion_options(parser):
    """Register the 'diffusion' argument group (noise schedule, step count, sigma mode) on ``parser``."""
    group = parser.add_argument_group('diffusion')
    group.add_argument("--noise_schedule", default='cosine', choices=['linear', 'cosine'], type=str,
                       help="Noise schedule type")
    group.add_argument("--diffusion_steps", default=1000, type=int,
                       help="Number of diffusion steps (denoted T in the paper)")
    group.add_argument("--sigma_small", default=True, type=_str2bool, help="Use smaller sigma values.")
    
def add_training_options(parser):
    """Register the 'training' argument group on ``parser``.

    The options cover where checkpoints are written, the logging platform, optimizer settings,
    the evaluation schedule, and the step counts that drive logging, saving and stopping.
    They are declared as a table and registered in order, so ``--help`` lists them in the
    sequence shown here.
    """
    training_group = parser.add_argument_group('training')

    option_table = (
        ("--save_dir", dict(required=False,
                            default="/workspace/priorMD/temporary_folder/test_our_chatgptData",
                            type=str,
                            help="Path to save checkpoints and results.")),
        ("--overwrite", dict(action='store_true',
                             help="If True, will enable to use an already existing save_dir.")),
        ("--train_platform_type", dict(default='NoPlatform',
                                       choices=['NoPlatform', 'ClearmlPlatform', 'TensorboardPlatform'],
                                       type=str,
                                       help="Choose platform to log results. NoPlatform means no logging.")),
        ("--lr", dict(default=1e-4, type=float, help="Learning rate.")),
        ("--weight_decay", dict(default=0.0, type=float, help="Optimizer weight decay.")),
        ("--lr_anneal_steps", dict(default=0, type=int, help="Number of learning rate anneal steps.")),
        ("--eval_batch_size", dict(default=32, type=int,
                                   help="Batch size during evaluation loop. Do not change this unless you know what you are doing. "
                                        "T2m precision calculation is based on fixed batch size 32.")),
        ("--eval_split", dict(default='test', choices=['val', 'test'], type=str,
                              help="Which split to evaluate on during training.")),
        ("--eval_during_training", dict(action='store_true',
                                        help="If True, will run evaluation during training.")),
        ("--eval_rep_times", dict(default=3, type=int,
                                  help="Number of repetitions for evaluation loop during training.")),
        ("--eval_num_samples", dict(default=1_000, type=int,
                                    help="If -1, will use all samples in the specified split.")),
        ("--log_interval", dict(default=1_000, type=int,
                                help="Log losses each N steps")),
        ("--save_interval", dict(default=6_000, type=int,
                                 help="Save checkpoints and run evaluation each N steps")),
        ("--num_steps", dict(default=300_000, type=int,
                             help="Training will stop after the specified number of steps.")),
        ("--num_frames", dict(default=120, type=int,
                              help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.")),
        ("--resume_checkpoint", dict(default="", type=str,
                                     help="If not empty, will start from the specified checkpoint (path to model###.pt file).")),
    )

    for flag, settings in option_table:
        training_group.add_argument(flag, **settings)
    
def add_frame_sampler_options(parser):
    """Register the 'framesampler' argument group (minimum and maximum sequence length) on ``parser``."""
    sampler_group = parser.add_argument_group('framesampler')

    # flag -> (default length, help text); registered in insertion order.
    length_options = {
        "--min_seq_len": (45, "babel dataset FrameSampler minimum length"),
        "--max_seq_len": (250, "babel dataset FrameSampler maximum length"),
    }
    for flag, (default_length, help_text) in length_options.items():
        sampler_group.add_argument(flag, default=default_length, type=int, help=help_text)

# Assemble one parser from every option group, registering the groups in a fixed order.
parser = ArgumentParser()
for register_group in (
    add_base_options,
    add_data_options,
    add_model_options,
    add_diffusion_options,
    add_training_options,
    add_frame_sampler_options,
):
    register_group(parser)

# parse_known_args() does not fail on arguments this parser does not define; they are
# returned separately in `unknown`, while `args` carries the recognised options/defaults.
args, unknown = parser.parse_known_args()

#args

In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import os
#from priorMD.model.rotation2xyz import Rotation2xyz

class DiffusionModel(nn.Module):
    """Two-stream transformer denoiser plugged into the diffusion process.

    ``forward`` embeds the noisy sample ``x`` with ``Cam_input_process`` and the condition ``y``
    with ``input_process``, prepends the timestep token to each stream, runs each stream through
    its own transformer encoder, fuses the two with ``CameraPersonBlock`` and projects the
    ``x`` stream back to ``[bs, njoints, nfeats, nframes]``.

    NOTE(review): the "Cam"/"camera" prefix only identifies the stream that carries ``x``; the
    naming looks inherited from another model -- confirm before relying on it.

    Several submodules created in ``__init__`` are not used by ``forward``. They are kept
    because they own parameters: dropping them would change ``state_dict`` keys and the
    result of ``parameters()``.
    """

    def __init__(self, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
                 latent_dim=512, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
                 ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
                 arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
        """Build every submodule.

        Args:
            njoints, nfeats: Sizes of the joint and per-joint feature axes of ``x``; their
                product is the per-frame input width.
            num_actions: Number of rows of the action embedding (only built when the
                condition mode contains 'action').
            translation, pose_rep, glob, glob_rot, ablation, legacy, dataset: Stored as
                attributes; ``dataset`` additionally selects the text length in ``encode_text``.
            latent_dim: Transformer width.
            ff_size, num_layers, num_heads, dropout, activation: Transformer encoder settings.
            data_rep: Data representation understood by the input/output processors
                ('rot6d', 'xyz', 'hml_vec' or 'rot_vel').
            clip_dim: Input width of the text embedding layer.
            arch: Only 'trans_enc' creates the encoders that ``forward`` needs.
            emb_trans_dec: Stored as an attribute.
            clip_version: Accepted for interface compatibility; CLIP is not loaded here.
            **kargs: Optional 'action_emb', 'normalize_encoder_output' and 'cond_mask_prob'.
        """
        super().__init__()

        self.legacy = legacy
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_actions = num_actions
        self.data_rep = data_rep
        self.dataset = dataset

        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim

        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout

        self.ablation = ablation
        self.activation = activation
        self.clip_dim = clip_dim
        self.action_emb = kargs.get('action_emb', None)

        self.input_feats = self.njoints * self.nfeats

        self.normalize_output = kargs.get('normalize_encoder_output', False)

        # Hard-coded: the kargs-driven lookup (kargs.get('cond_mode', 'no_cond')) is disabled.
        self.cond_mode = 'text'
        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
        self.arch = arch
        self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0
        self.Cam_input_process = CamInputProcess(self.data_rep, self.input_feats+self.gru_emb_dim, self.latent_dim)
        self.input_process = InputProcess(self.data_rep, self.input_feats+self.gru_emb_dim, self.latent_dim)

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        self.cond_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        self.cam_sequence_pos_encoder = CamPositionalEncoding(self.latent_dim, self.dropout)
        self.emb_trans_dec = emb_trans_dec

        self.Camera_input_process = nn.Linear(6, self.latent_dim)
        self.Camera_output_process = nn.Linear(self.latent_dim, 6)

        self.encode_img = nn.Linear(768, 512)
        self.encode_output = nn.Linear(1536, 512)
        self.cond_proj = nn.Linear(512, 512)
        self.ada_mlp = nn.Linear(1024, 1024)
        self.ln = nn.LayerNorm(latent_dim)
        self.mlp = nn.Linear(512, 512)

        self.camera_person = CameraPersonBlock(arch="trans_enc",
                                               fn_type='in_both_out_cur',
                                               num_layers=2,
                                               latent_dim=self.latent_dim,
                                               input_feats=self.input_feats,
                                               predict_6dof=True)

        # NOTE: forward() needs CamTransEncoder and seqTransEncoder_start, which only exist
        # for arch == 'trans_enc'.
        if self.arch == 'trans_enc':
            print('CUTTING BACKBONE AT LAYER 8')
            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=self.num_heads,
                                                              dim_feedforward=self.ff_size,
                                                              dropout=self.dropout,
                                                              activation=self.activation)

            CamTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=self.num_heads,
                                                              dim_feedforward=self.ff_size,
                                                              dropout=self.dropout,
                                                              activation=self.activation)

            self.CamTransEncoder = nn.TransformerEncoder(CamTransEncoderLayer,
                                                         num_layers=self.num_layers)

            # The split point is fixed at 8 layers; with the default num_layers=8 the "end"
            # half therefore has no layers.
            self.seqTransEncoder_start = nn.TransformerEncoder(seqTransEncoderLayer,
                                                               num_layers=8)
            self.seqTransEncoder_end = nn.TransformerEncoder(seqTransEncoderLayer,
                                                             num_layers=self.num_layers - 8)

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        if self.cond_mode != 'no_cond':
            if 'text' in self.cond_mode:
                self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
                print('EMBED TEXT')
                # The message is kept for output compatibility, but CLIP loading is disabled:
                # self.clip_model is never assigned by this constructor.
                print('Loading CLIP...')
            if 'action' in self.cond_mode:
                self.embed_action = EmbedAction(self.num_actions, self.latent_dim)
                print('EMBED ACTION')

        self.Cam_output_process = CamOutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
                                                   self.nfeats)

        self.series_model = TimeSeriesTransformer(input_dim=150, embed_dim=latent_dim, num_heads=8, num_layers=4)

        self.motion_cross_attn = nn.MultiheadAttention(latent_dim, num_heads=4, batch_first=True)
        self.cross_attn_x_to_y = nn.MultiheadAttention(latent_dim, num_heads=4, batch_first=True)
        self.motion_norm = nn.LayerNorm(latent_dim)

    def parameters_wo_clip(self):
        """Return the parameters whose qualified name does not start with 'clip_model.'."""
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        """Load the CLIP model ``clip_version`` on CPU, put it in eval mode and freeze its weights."""
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        clip.model.convert_weights(
            clip_model)  # Actually this line is unnecessary since clip by default already on float16

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def mask_cond(self, cond, force_mask=False):
        """Zero out whole condition rows for classifier-free guidance.

        Args:
            cond: 2-D tensor ``[bs, d]``.
            force_mask: If True, return an all-zero tensor shaped like ``cond``.

        Returns:
            ``cond`` with each row independently zeroed with probability ``cond_mask_prob``
            while training; ``cond`` unchanged otherwise.
        """
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1)  # 1-> use null_cond, 0-> use real cond
            return cond * (1. - mask)
        else:
            return cond

    def encode_text(self, raw_text):
        """Tokenize ``raw_text`` (a list of prompt strings) and encode it with ``self.clip_model``.

        NOTE: ``self.clip_model`` is not created by ``__init__`` (the CLIP loading lines are
        commented out), so it must be attached before this method is called.
        """
        device = next(self.parameters()).device
        max_text_len = 20 if self.dataset in ['humanml', 'kit'] else None  # Specific hardcoding for humanml dataset
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2 # start_token + 20 + end_token
            assert context_length < default_context_length
            texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate
            # Pad back to the default context length with zeros.
            zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device)
            texts = torch.cat([texts, zero_pad], dim=1)
        else:
            texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate
        return self.clip_model.encode_text(texts).float()

    def forward(self, x, timesteps, y=None):
        """Run the denoiser.

        Args:
            x: ``[batch_size, njoints, nfeats, max_frames]``, denoted x_t in the paper.
            timesteps: ``[batch_size]`` (int) diffusion step indices.
            y: Conditioning tensor with the same 4-D layout as ``x``. Required despite the
                ``None`` default, which is kept only for signature compatibility.

        Returns:
            Tensor ``[batch_size, njoints, nfeats, max_frames]``.

        Raises:
            ValueError: If ``y`` is None.
        """
        if y is None:
            raise ValueError("DiffusionModel.forward requires the conditioning tensor `y`.")

        # Unpacking doubles as a rank check: x must be 4-D.
        bs, njoints, nfeats, nframes = x.shape
        emb = self.embed_timestep(timesteps)  # [1, bs, d]

        x_cur = self.input_process(y)  # [seqlen, bs, d]
        x_camera = self.Cam_input_process(x)  # [seqlen, bs, d]

        # Keep the raw embeddings (before the timestep token and positional encoding) for the
        # fusion block.
        low_x = x_cur
        low_x_camera = x_camera

        # Prepend the timestep token, then add positional encodings.
        xseq = torch.cat((emb, x_cur), axis=0)  # [seqlen+1, bs, d]
        xseq = self.sequence_pos_encoder(xseq)  # [seqlen+1, bs, d]
        x_camera = torch.cat((emb, x_camera), axis=0)
        x_camera = self.cam_sequence_pos_encoder(x_camera)

        # Encode each stream and drop the timestep token again.
        mid = self.seqTransEncoder_start(xseq)[1:]
        mid_Camera = self.CamTransEncoder(x_camera)[1:]

        # Only the first output of the fusion block (the correction for the x stream) is used.
        delta_camera_1, _ = self.camera_person(low_cur=low_x_camera,
                                               low_other=low_x,
                                               cur=mid_Camera,
                                               other=mid)

        # Out-of-place residual add: avoids mutating a slice of the encoder output, which
        # autograd may still need.
        output_camera = mid_Camera + delta_camera_1

        return self.Cam_output_process(output_camera)

    def train(self, *args, **kwargs):
        """Set the training mode and return ``self``.

        ``nn.Module.eval()`` returns whatever ``train()`` returns, so the result of the parent
        call must be passed through for ``model = model.eval()`` / ``model.train()`` chaining.
        """
        return super().train(*args, **kwargs)

        
# ---- Learnable Positional Encoding ----
class LearnablePositionalEncoding(nn.Module):
    """Adds a trainable per-position embedding to a batch-first sequence.

    The table has shape ``[1, max_len, d_model]`` and is initialised from a truncated normal
    with std 0.02.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        table = torch.zeros(1, max_len, d_model)
        nn.init.trunc_normal_(table, std=0.02)
        self.pos_embed = nn.Parameter(table)

    def forward(self, x):
        """Return ``x`` plus the embeddings of its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        positions = self.pos_embed[:, :seq_len]
        return x + positions

# ---- Transformer Encoder for Time Series ----
class TimeSeriesTransformer(nn.Module):
    """Batch-first transformer encoder over a feature sequence.

    Pipeline: linear projection to ``embed_dim`` -> learnable positional encoding ->
    ``num_layers`` encoder layers (feed-forward width 1024, dropout 0.1) -> LayerNorm.
    """

    def __init__(self, input_dim, embed_dim=512, num_heads=8, num_layers=4):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.pos_encoding = LearnablePositionalEncoding(embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=1024,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode ``x`` (last axis of size ``input_dim``) into ``embed_dim`` features per step."""
        tokens = self.pos_encoding(self.embedding(x))
        encoded = self.transformer(tokens)
        return self.norm(encoded)
    
    
class CamPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first input, followed by dropout.

    The table is registered as the buffer ``pe`` with shape ``[max_len, 1, d_model]``: even
    channels hold sines, odd channels hold cosines.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # angle[p, i] = p * exp(-ln(10000) * 2i / d_model)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = positions * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # The singleton middle axis lets the table broadcast over the batch axis.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Add the encodings of the first ``x.shape[0]`` positions to ``x`` and apply dropout."""
        seq_len = x.shape[0]
        encoded = x + self.pe[:seq_len, :]
        return self.dropout(encoded)
        

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (sequence-first) with dropout.

    ``pe`` is a non-trainable buffer of shape ``[max_len, 1, d_model]``; besides ``forward``,
    it can be indexed directly to look up the encoding of arbitrary positions.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        step = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # One frequency per (sin, cos) channel pair, decaying geometrically.
        frequency = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        phase = step * frequency  # [max_len, ceil(d_model / 2)]

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(phase)  # even channels
        encoding[:, 1::2] = torch.cos(phase)  # odd channels
        encoding = encoding.unsqueeze(1)  # [max_len, 1, d_model]

        self.register_buffer('pe', encoding)

    def forward(self, x):
        """Return ``dropout(x + pe[:len(x)])`` for ``x`` whose first axis is the sequence."""
        return self.dropout(x + self.pe[:x.shape[0], :])


class TimestepEmbedder(nn.Module):
    """Maps diffusion step indices to a learned embedding token.

    The step index selects a row of ``sequence_pos_encoder.pe``; that row is passed through a
    Linear -> SiLU -> Linear stack of width ``latent_dim``.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        width = self.latent_dim
        stages = [
            nn.Linear(self.latent_dim, width),
            nn.SiLU(),
            nn.Linear(width, width),
        ]
        self.time_embed = nn.Sequential(*stages)

    def forward(self, timesteps):
        """Embed ``timesteps`` and swap the first two axes of the result."""
        step_encoding = self.sequence_pos_encoder.pe[timesteps]
        embedded = self.time_embed(step_encoding)
        return embedded.permute(1, 0, 2)
    
class CameraPersonBlock(nn.Module):
    """Fuses two feature streams with a shared linear aggregation and a shared transformer encoder.

    ``forward`` relies on ``self.aggregation`` (created only when ``fn_type`` contains
    'in_both') and ``self.model`` (created only when ``arch == 'trans_enc'``). The remaining
    submodules (``canon_agg``, ``canon_out``, ``cross_attention``, ``con_global`` and the two
    pooling layers) are constructed but not used by ``forward``.
    """

    def __init__(self, arch, fn_type, num_layers, latent_dim, input_feats, predict_6dof):
        super().__init__()
        self.arch = arch
        self.fn_type = fn_type
        self.predict_6dof = predict_6dof
        self.num_layers = num_layers
        self.latent_dim = latent_dim
        # Fixed encoder hyper-parameters for this block.
        self.num_heads = 4
        self.ff_size = 1024
        self.dropout = 0.1
        self.activation = 'gelu'
        self.input_feats = input_feats

        if self.predict_6dof:
            self.canon_agg = nn.Linear(9*2, self.latent_dim)
            self.canon_out = nn.Linear(self.latent_dim, 9)

        if 'in_both' in self.fn_type:
            # Maps a concatenated pair of streams back to latent_dim.
            self.aggregation = nn.Linear(self.latent_dim*2, self.latent_dim)

        if self.arch == 'trans_enc':
            encoder_layer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                       nhead=self.num_heads,
                                                       dim_feedforward=self.ff_size,
                                                       dropout=self.dropout,
                                                       activation=self.activation)
            self.model = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers)

        self.cross_attention = CrossAttention(embed_size=512, heads=4)
        self.con_global = nn.Linear(self.latent_dim, 9)
        self.avg_pooling = nn.AdaptiveAvgPool1d(output_size=1)
        self.max_pooling = nn.AdaptiveMaxPool1d(output_size=1)

    def forward(self, low_cur=None, low_other=None, other=None, cur=None):
        """Return the fused features for the 'cur' stream and for the 'other' stream.

        Each stream's low-level and encoded features are summed first. The two sums are then
        concatenated in both orders (swapping the concat order yields one vector per stream),
        aggregated, and encoded by the shared transformer.
        The original author's debug prints reported shape ``[120, 64, 512]`` for the
        aggregated input and for both outputs.
        """
        summed_cur = low_cur + cur
        summed_other = low_other + other

        cur_first = self.aggregation(torch.cat((summed_cur, summed_other), dim=-1))
        other_first = self.aggregation(torch.cat((summed_other, summed_cur), dim=-1))

        cur_out = self.model(cur_first)
        other_out = self.model(other_first)
        return cur_out, other_out


class CamInputProcess(nn.Module):
    """Projects a ``[bs, njoints, nfeats, nframes]`` tensor to ``[nframes, bs, latent_dim]``.

    For 'rot6d', 'xyz' and 'hml_vec' every frame goes through ``poseEmbedding``. For
    'rot_vel' the first frame uses ``poseEmbedding`` and the remaining frames use
    ``velEmbedding``. Any other ``data_rep`` makes ``forward`` raise ``ValueError``.
    """

    def __init__(self, data_rep, input_feats, latent_dim):
        super().__init__()
        self.data_rep, self.input_feats, self.latent_dim = data_rep, input_feats, latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.data_rep == 'rot_vel':
            self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        bs, njoints, nfeats, nframes = x.shape
        # Frames first, joints and features flattened: [nframes, bs, njoints * nfeats].
        frames = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats)

        if self.data_rep == 'rot_vel':
            head = self.poseEmbedding(frames[[0]])  # [1, bs, d]
            tail = self.velEmbedding(frames[1:])  # [seqlen-1, bs, d]
            return torch.cat((head, tail), axis=0)  # [seqlen, bs, d]
        if self.data_rep in ('rot6d', 'xyz', 'hml_vec'):
            return self.poseEmbedding(frames)  # [seqlen, bs, d]
        raise ValueError


class CamOutputProcess(nn.Module):
    """Projects ``[nframes, bs, latent_dim]`` features back to ``[bs, njoints, nfeats, nframes]``.

    Mirrors the input processor: 'rot6d', 'xyz' and 'hml_vec' use ``poseFinal`` for every
    frame; 'rot_vel' uses ``poseFinal`` for the first frame and ``velFinal`` for the rest.
    Any other ``data_rep`` makes ``forward`` raise ``ValueError``.
    """

    def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
        super().__init__()
        self.data_rep, self.input_feats, self.latent_dim = data_rep, input_feats, latent_dim
        self.njoints, self.nfeats = njoints, nfeats
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
        if self.data_rep == 'rot_vel':
            self.velFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        nframes, bs, d = output.shape

        if self.data_rep == 'rot_vel':
            head = self.poseFinal(output[[0]])  # [1, bs, input_feats]
            tail = self.velFinal(output[1:])  # [seqlen-1, bs, input_feats]
            projected = torch.cat((head, tail), axis=0)  # [seqlen, bs, input_feats]
        elif self.data_rep in ('rot6d', 'xyz', 'hml_vec'):
            projected = self.poseFinal(output)  # [seqlen, bs, input_feats]
        else:
            raise ValueError

        per_joint = projected.reshape(nframes, bs, self.njoints, self.nfeats)
        return per_joint.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]
    
class InputProcess(nn.Module):
    """Embeds ``[bs, njoints, nfeats, nframes]`` input as a ``[nframes, bs, latent_dim]`` sequence.

    ``data_rep`` selects the embedding scheme: a single linear layer for 'rot6d', 'xyz' and
    'hml_vec'; separate layers for the first frame and the later frames for 'rot_vel'.
    ``forward`` raises ``ValueError`` for anything else.
    """

    _SINGLE_EMBEDDING_REPS = ('rot6d', 'xyz', 'hml_vec')

    def __init__(self, data_rep, input_feats, latent_dim):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.data_rep == 'rot_vel':
            self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        bs, njoints, nfeats, nframes = x.shape
        # Put the frame axis first and merge joints with features.
        sequence = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats)

        if self.data_rep in self._SINGLE_EMBEDDING_REPS:
            return self.poseEmbedding(sequence)  # [seqlen, bs, d]

        if self.data_rep != 'rot_vel':
            raise ValueError

        first_frame = self.poseEmbedding(sequence[[0]])  # [1, bs, d]
        later_frames = self.velEmbedding(sequence[1:])  # [seqlen-1, bs, d]
        return torch.cat((first_frame, later_frames), axis=0)  # [seqlen, bs, d]


class OutputProcess(nn.Module):
    """Maps a ``[nframes, bs, latent_dim]`` sequence to ``[bs, njoints, nfeats, nframes]``.

    ``data_rep`` selects the projection scheme: one linear layer for 'rot6d', 'xyz' and
    'hml_vec'; separate layers for the first frame and the later frames for 'rot_vel'.
    ``forward`` raises ``ValueError`` for anything else.
    """

    _SINGLE_PROJECTION_REPS = ('rot6d', 'xyz', 'hml_vec')

    def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.njoints = njoints
        self.nfeats = nfeats
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
        if self.data_rep == 'rot_vel':
            self.velFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        nframes, bs, d = output.shape

        if self.data_rep in self._SINGLE_PROJECTION_REPS:
            flat = self.poseFinal(output)  # [seqlen, bs, input_feats]
        elif self.data_rep == 'rot_vel':
            first_frame = self.poseFinal(output[[0]])  # [1, bs, input_feats]
            later_frames = self.velFinal(output[1:])  # [seqlen-1, bs, input_feats]
            flat = torch.cat((first_frame, later_frames), axis=0)  # [seqlen, bs, input_feats]
        else:
            raise ValueError

        # Split the flat feature axis into (njoints, nfeats) and move frames last.
        unflattened = flat.reshape(nframes, bs, self.njoints, self.nfeats)
        return unflattened.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]


class EmbedAction(nn.Module):
    """Learned lookup table with one ``latent_dim`` vector per action id."""

    def __init__(self, num_actions, latent_dim):
        super().__init__()
        # Rows are drawn from a standard normal distribution.
        self.action_embedding = nn.Parameter(torch.randn(num_actions, latent_dim))

    def forward(self, input):
        """Return the embedding rows selected by the first column of ``input``."""
        # Index tensors must be integer typed, hence the cast to long.
        action_ids = input[:, 0].long()
        return self.action_embedding[action_ids]
    
class CrossAttention(nn.Module):
    """Multi-head attention with a residual connection to ``values``.

    Inputs are sequence-first (``[len, batch, embed_size]``). The per-head linear maps act on
    ``head_dim`` features and are shared across heads. The query is repeated along its
    sequence axis once per value position before attention, and the projected output is
    added to the original ``values``, so the shapes must broadcast.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        residual = values

        # Switch to batch-first layout: [batch, len, embed].
        values_bf = values.permute(1, 0, 2)
        keys_bf = keys.permute(1, 0, 2)
        query_bf = query.permute(1, 0, 2)
        # Repeat every query position once per value position.
        query_bf = torch.repeat_interleave(query_bf, repeats=values_bf.shape[1], dim=1)

        batch = query_bf.shape[0]
        value_len = values_bf.shape[1]
        key_len = keys_bf.shape[1]
        query_len = query_bf.shape[1]

        # Split the embedding axis into (heads, head_dim).
        value_heads = values_bf.reshape(batch, value_len, self.heads, self.head_dim)
        key_heads = keys_bf.reshape(batch, key_len, self.heads, self.head_dim)
        query_heads = query_bf.reshape(batch, query_len, self.heads, self.head_dim)

        value_heads = self.values(value_heads)
        key_heads = self.keys(key_heads)
        query_heads = self.queries(query_heads)

        # Scores per head, scaled by sqrt(embed_size) and normalised over the key axis.
        scores = torch.einsum("nqhd,nkhd->nhqk", [query_heads, key_heads])
        weights = torch.softmax(scores / (self.embed_size ** (1/2)), dim=3)

        # Weighted sum of values, heads merged back into one embedding axis.
        attended = torch.einsum("nhql,nlhd->nqhd", [weights, value_heads]).reshape(
            batch, query_len, self.heads * self.head_dim
        )

        projected = self.fc_out(attended)
        # Back to sequence-first, plus the residual.
        return projected.permute(1, 0, 2) + residual

## Initializing Model

In [22]:
from diffusion.respace import SpacedDiffusion, space_timesteps
from utils.model_util import create_gaussian_diffusion
from diffusion.gaussian_diffusion import (
    GaussianDiffusion,
    get_named_beta_schedule,
    create_named_schedule_sampler,
    ModelMeanType,
    ModelVarType,
    LossType
)

# Build the denoiser, the Gaussian diffusion process, the timestep sampler and
# the optimizer used by the training cell below.

# Defined first so the model is placed on the same device the training loop uses.
device = "cuda"

# njoints=79 matches the feature size of pose_tensor built at the top of the
# notebook (48 pose values zero-padded by 31).
model = DiffusionModel(njoints=79, nfeats=1, num_actions=1, translation=True, pose_rep="rot6d", glob=True, glob_rot=True).to(device)

sampler_name = 'uniform'
beta_scheduler = 'linear'
# The number of diffusion steps comes from args.diffusion_steps (args is
# defined in an earlier cell), not from n_T below.
betas = get_named_beta_schedule(beta_scheduler, args.diffusion_steps)
diffusion = GaussianDiffusion(
            betas=betas,
            model_mean_type=ModelMeanType.EPSILON,
            model_var_type=ModelVarType.FIXED_SMALL,
            loss_type=LossType.MSE
        )
sampler = create_named_schedule_sampler(sampler_name, diffusion)

n_epoch = 10000
# NOTE(review): batch_size and n_T are not read anywhere in this cell;
# train_dataloader was created with batch_size=64. Kept in case later cells
# rely on the names.
batch_size = 32
n_T = 1000 # 500
n_feature = 5
n_textemb = 512
lrate = 0.0001
save_model = True
save_dir = './weight/'
# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(save_dir, exist_ok=True)

optim = torch.optim.Adam(model.parameters(), lr=lrate)

CUTTING BACKBONE AT LAYER 8
EMBED TEXT
Loading CLIP...


## Model Training

In [None]:
from tqdm import tqdm

# Train the diffusion model: one pass over train_dataloader per epoch, with a
# linearly decaying learning rate and a checkpoint written after every epoch.
for ep in range(n_epoch):
    print(f'epoch {ep}')
    model.train()

    # Linear learning-rate decay from lrate towards 0 over n_epoch epochs,
    # applied to every parameter group of the optimizer.
    current_lr = lrate * (1 - ep / n_epoch)
    for param_group in optim.param_groups:
        param_group['lr'] = current_lr

    pbar = tqdm(train_dataloader)
    loss_ema = None
    for x in pbar:
        optim.zero_grad()

        # Batches follow TensorDataset(full_tensor, pose_tensor):
        # x[0] is the audio+beat conditioning, x[1] is the motion target.
        audio = x[0].to(device).to(torch.float32)
        motion = x[1].to(device).to(torch.float32)

        # The conditioning is handed to the model under the "y" key.
        img_dic = {"y": audio}

        # One diffusion timestep per sample in the batch.
        t, _ = sampler.sample(motion.shape[0], device)

        loss = diffusion.training_losses(
            model=model,
            x_start=motion,
            t=t,
            model_kwargs=img_dic,
        )["mse"].to(torch.float32)
        loss = torch.mean(loss)
        loss.backward()
        optim.step()

        # Exponential moving average of the loss, shown on the progress bar.
        step_loss = loss.item()
        if loss_ema is None:
            loss_ema = step_loss
        else:
            loss_ema = 0.95*loss_ema + 0.05*step_loss
        pbar.set_description(f"loss: {loss_ema:.4f}")

    # Always refresh the rolling checkpoint; keep a numbered one every 100 epochs.
    torch.save(model.state_dict(), save_dir + "conductor_latest.pth")
    if save_model and ep % 100 == 0:
        checkpoint_path = save_dir + f"conductor_model_{ep}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print('saved model at ' + checkpoint_path)

In [82]:
# Take every 1000th entry of audio_tensor as float32 on the GPU and show its shape.
# NOTE(review): the saved output below reports 138 channels in a channels-first
# layout, while audio_tensor at the top of the notebook is reshaped to
# (-1, 1, 120, 64); the output looks stale - re-run on a fresh kernel to confirm.
test_emb = audio_tensor[::1000].to(device="cuda", dtype=torch.float32)
test_emb.shape

torch.Size([20, 138, 1, 120])

In [24]:
# Load the clips in Mel_Clips/liu/ whose second "_"-separated filename field is
# "1.npy", standardise them with the stored audio mean/std, and arrange them as
# (clips, 64, 1, 120).
A_Mean = np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Audio_Mean.npy")
A_Std = np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Audio_Std.npy")

mel_clip_dir = "../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Mel_Clips/liu/"
test_audio_list = []
# os.listdir returns names in arbitrary order. Sort them so the row order is
# reproducible; rows are later paired by index with the beat clips, which
# presumably use the same file names (verify against the data folders).
for audio_file in sorted(os.listdir(mel_clip_dir)):
    name_fields = audio_file.split("_")
    # Names without a second field are skipped instead of raising IndexError.
    if len(name_fields) > 1 and name_fields[1] == "1.npy":
        clip = np.load(mel_clip_dir + audio_file)
        # Small epsilon guards against division by a zero standard deviation.
        test_audio_list.append((clip - A_Mean) / (A_Std + 0.000000001))

test_audio_tensor = torch.tensor(np.array(test_audio_list))
test_audio_tensor = test_audio_tensor.reshape(-1,1,120,64).permute(0,3,1,2)

In [25]:
# Load the clips in Beat_Clips/liu/ whose second "_"-separated filename field is
# "1.npy", standardise them with the stored beat mean/std, and arrange them as
# (clips, 15, 1, 120).
B_Mean = np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Beat_Mean.npy")
B_Std = np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Beat_Std.npy")

beat_clip_dir = "../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Beat_Clips/liu/"
test_beat_list = []
# os.listdir returns names in arbitrary order. Sort them so the row order is
# reproducible; rows are later paired by index with the mel clips, which
# presumably use the same file names (verify against the data folders).
for audio_file in sorted(os.listdir(beat_clip_dir)):
    name_fields = audio_file.split("_")
    # Names without a second field are skipped instead of raising IndexError.
    if len(name_fields) > 1 and name_fields[1] == "1.npy":
        clip = np.load(beat_clip_dir + audio_file)
        # Small epsilon guards against division by a zero standard deviation.
        test_beat_list.append((clip - B_Mean) / (B_Std + 0.000000001))

test_beat_tensor = torch.tensor(np.array(test_beat_list))
test_beat_tensor = test_beat_tensor.reshape(-1,1,120,15).permute(0,3,1,2)

In [26]:
# Join the 64 audio channels and the 15 beat channels along axis 1, giving the
# 79-channel conditioning tensor (audio first, then beat).
test_full_tensor = torch.concatenate((test_audio_tensor, test_beat_tensor), dim=1)

In [27]:
# Restore a trained checkpoint and sample motion for the first batch of test
# conditioning clips.
model.load_state_dict(torch.load("weight/conductor_beat_mel_model_500.pth", map_location=torch.device('cuda')))
# A module stays in training mode unless told otherwise; switch to eval so
# layers such as dropout are disabled while sampling.
model.eval()

for i in range(10):
    # Up to 20 conditioning clips per batch.
    test_emb = test_full_tensor[20*i:20*i+20].to("cuda").to(torch.float32)
    test_dic = {"y": test_emb}

    # Sampling needs no gradients. The sample batch size follows the
    # conditioning batch, so a final short batch still lines up.
    with torch.no_grad():
        test_array = diffusion.p_sample_loop(
            model,
            (test_emb.shape[0], 79, 1, 120),
            clip_denoised=False,
            progress=True,
            model_kwargs=test_dic,
        )
    # Only the first batch is sampled; a later cell reads test_array from it.
    # To export every batch, drop this break and save inside the loop, e.g.
    # np.save("Evaluation/beat_"+str(i)+".npy", np.array(test_array.cpu()))
    break

  0%|          | 0/1000 [00:00<?, ?it/s]

In [28]:
from scipy.ndimage import gaussian_filter1d

# De-normalise one generated sample and write it out, once per "p0"/"p1" file.
Mean = np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Motion_Mean.npy")
Std = np.load("../PoseEstimation/SMPLest-X/SMPLest-X-main/demo/CONDUCTOR_DATA/Motion_Std.npy")

# Alternative statistics (disabled):
#Mean = np.load("../Camera/Mix_Mean.npy")
#Std = np.load("../Camera/Mix_Std.npy")
#D_mean = np.load("../Camera/Mix_D_Mean.npy")
#D_std = np.load("../Camera/Mix_D_Std.npy")

# Reorder the sampled tensor's axes to (2, 0, 3, 1); for the (20, 79, 1, 120)
# sample shape used above this gives (1, 20, 120, 79).
test_data = np.transpose(test_array.cpu().detach().numpy(), (2, 0, 3, 1))
# Keep the first 48 features (the rest is padding) and view them as 8 joints x 6 values.
motion1 = test_data[0, :, :, :48].reshape(-1, 120, 8, 6)

# Sample number 4 of the batch, shape (120, 8, 6).
m = motion1[4]

# Joint slots (out of 23) that the 8 generated joints map to.
target_idx = [13, 14, 16, 17, 18, 19, 20, 21]

# 1) Place the (120, 8, 6) sample into a zeroed (120, 23, 6) array.
arr_full = np.zeros((120, 23, 6))
arr_full[:, target_idx, :] = m

# Undo the standardisation.
arr_full = arr_full * Std + Mean

# 2) Every other joint is filled with its (23, 6) mean value.
other_idx = [j for j in range(23) if j not in target_idx]
arr_full[:, other_idx, :] = Mean[other_idx]

# 9-element vector with ones at positions 0 and 4, zeros elsewhere.
d = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])

# The same arrays are written for both p0 and p1.
for joint_path in (
    "../Camera/generated_data/new_joint_vecs/train/0_p0.npy",
    "../Camera/generated_data/new_joint_vecs/train/0_p1.npy",
):
    np.save(joint_path, arr_full)
for canon_path in (
    "../Camera/generated_data/canon_data/train/0_p0.npy",
    "../Camera/generated_data/canon_data/train/0_p1.npy",
):
    np.save(canon_path, d)