In [24]:
from pythonosc import udp_client
import time
import sounddevice as sd
import torch
from dataloaders.beat import CustomDataset
from dataloaders.build_vocab import Vocab
import pickle
import numpy as np
import os
from utils.other_tools import load_checkpoints
from models.camn import CaMN

# Load the pickled argument objects for both models. `with` closes the files
# once they are read (previously the handles were left open).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load these config objects from a trusted source.
with open("camn_config.obj", 'rb') as camn_config_file:
    camn_args = pickle.load(camn_config_file)
with open("gesturegen_config.obj", 'rb') as gesturegen_config_file:
    gesturegen_args = pickle.load(gesturegen_config_file)


def _load_norm_stat(rep, file_name):
    """Load one saved normalization statistic as a float32 tensor.

    Reads ``camn_args.root_path + camn_args.mean_pose_path + f"{rep}/{file_name}"``
    with ``np.load`` and converts it via ``torch.from_numpy(...).float()``.
    """
    stat_path = camn_args.root_path + camn_args.mean_pose_path + f"{rep}/{file_name}"
    return torch.from_numpy(np.load(stat_path)).float()


# Mean/std statistics for each modality (facial, audio, pose).
mean_facial = _load_norm_stat(camn_args.facial_rep, "json_mean.npy")
std_facial = _load_norm_stat(camn_args.facial_rep, "json_std.npy")
mean_audio = _load_norm_stat(camn_args.audio_rep, "npy_mean.npy")
std_audio = _load_norm_stat(camn_args.audio_rep, "npy_std.npy")
mean_pose = _load_norm_stat(camn_args.pose_rep, "bvh_mean.npy")
std_pose = _load_norm_stat(camn_args.pose_rep, "bvh_std.npy")

In [25]:
test_data = CustomDataset(camn_args, "test")
# batch_size must stay 1 here: the loop index `its` is reused below as an
# index into the sorted list of visualization files.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    drop_last=False,
)

batch_size = 1  # inference batch size used by the model cells below
solo_speaker = 17

# Take the first test sample of `solo_speaker` as the template whose leading
# frames seed the facial/pose sequences later on. Fail loudly instead of
# silently falling back to the last sample when the speaker is absent.
for its, template in enumerate(test_loader):
    if template['id'][0] == solo_speaker:
        break
else:
    raise ValueError(f"no test sample found for speaker id {solo_speaker}")

# NOTE(review): assumes the unshuffled loader yields samples in the same order
# as this sorted directory listing — confirm against CustomDataset.
test_demo = camn_args.root_path + camn_args.test_data_path + f"{camn_args.pose_rep}_vis/"
test_seq_list = sorted(os.listdir(test_demo))

template_file = test_seq_list[its]
print(template_file)

18_daiki_0_103_a.bvh


In [26]:
 # load in facial model
from scripts.MulticontextNet import GestureGen

# Build the facial model from its saved args, restore the trained weights,
# then move it to the GPU and switch to inference mode (both in place).
model_path = 'tmp/multicontextnet-no-text.pth'
net = GestureGen(gesturegen_args)
net.load_state_dict(torch.load(model_path))
net.cuda()
net.eval()

In [27]:
# load in test_audio
import librosa

test_audio_file = 'test_audio/gandhi-speech.wav'
# Load at the file's native rate, then resample explicitly to the 16 kHz
# rate used by the rest of this notebook.
test_audio_raw, sr = librosa.load(test_audio_file, sr=None)
test_audio = librosa.resample(test_audio_raw, orig_sr=sr, target_sr=16000)
print('Original sample rate:', sr)

# Add a leading batch axis, then normalize with the dataset audio statistics.
out_audio = torch.from_numpy(test_audio)[None, :]
audio = (out_audio - mean_audio) / std_audio

Original sample rate: 22050


In [28]:
limit_sec = 10
# Audition only the first `limit_sec` seconds of the 16 kHz clip.
sd.play(test_audio[:16000 * limit_sec], samplerate=16000)
sd.wait()  # block until playback has finished
print("Audio finished:", time.time())

Audio finished: 1718228446.8094735


In [30]:
# Inputs for the facial model, covering the whole clip.
in_audio = audio.expand(batch_size, -1).cuda()
# Whole seconds of 16 kHz audio times 15 frames per second.
n_frames = in_audio.shape[1] // 16000 * 15

# Every row of the batch is conditioned on the same speaker (previously only
# row 0 was set, leaving the others at id 0 when batch_size > 1).
in_id = torch.full((batch_size, 1), solo_speaker, dtype=torch.int32).cuda()
# Emotion label 0 for every frame.
in_emo = torch.zeros((batch_size, n_frames), dtype=torch.int32).cuda()

# Seed the first `pre_frames` frames with the template's facial values; the
# extra last channel is set to 1 on seeded frames (presumably a "frame given"
# mask — TODO confirm against GestureGen).
pre_frames = 4
in_pre_facial = torch.zeros((batch_size, n_frames, 52)).float().cuda()
in_pre_facial[:, 0:pre_frames, :-1] = template['facial'][:, 0:pre_frames]
in_pre_facial[:, 0:pre_frames, -1] = 1

In [31]:
# Inference only: no_grad avoids building an autograd graph over the whole clip.
with torch.no_grad():
    pred_facial = net(in_pre_facial, in_audio=in_audio, in_id=in_id, in_emo=in_emo).cpu().detach()
print(pred_facial.shape)

torch.Size([1, 5580, 51])


In [32]:
 # load in model
# Build CaMN from its saved args and restore weights from
# camn_args.root_path + camn_args.test_ckpt via the project's load_checkpoints
# helper (camn_args.g_name is passed through to it). No separate weights path
# variable is needed, so the GestureGen `model_path` is left untouched.
camn_model = CaMN(camn_args)
load_checkpoints(camn_model, camn_args.root_path + camn_args.test_ckpt, camn_args.g_name)
camn_model = camn_model.cuda().eval()

2024-06-12 14:40:57.897 | INFO     | utils.other_tools:load_checkpoints:96 - load self-pretrained checkpoints for CaMN


In [33]:
in_facial = pred_facial.cuda()

# Seed the first `pre_frames` pose frames from the template. The extra last
# channel is set to 1 on seeded frames (presumably a mask — TODO confirm).
# Sized from camn_args because this tensor feeds CaMN, whose output is later
# reshaped with camn_args.pose_dims.
pre_frames = 4
pre_pose = torch.zeros((batch_size, in_facial.shape[1], camn_args.pose_dims + 1)).cuda()
pre_pose[:, 0:pre_frames, :-1] = template['pose'][:, 0:pre_frames]
pre_pose[:, 0:pre_frames, -1] = 1

# One audio row per batch element; reshape(1, -1) would have merged a batch
# larger than 1 into a single row.
in_audio = in_audio.reshape(batch_size, -1)

print(pre_pose.shape, in_facial.shape, in_audio.shape, in_id.shape, in_emo.shape)

torch.Size([1, 5580, 142]) torch.Size([1, 5580, 51]) torch.Size([1, 5954075]) torch.Size([1, 1]) torch.Size([1, 5580])


In [34]:
# Inference only: skip autograd graph construction.
with torch.no_grad():
    out_dir_vec = camn_model(pre_seq=pre_pose, in_audio=in_audio, in_facial=in_facial, in_id=in_id, in_emo=in_emo)
# Flatten to (frames, pose_dims) and undo the dataset normalization.
# NOTE: reshape(-1, ...) concatenates batch rows along time; fine for batch_size == 1.
out_final = ((out_dir_vec.detach().cpu().reshape(-1, camn_args.pose_dims) * std_pose) + mean_pose).numpy()

In [35]:
np.shape(out_final)  # display (frames, pose channels) of the denormalized prediction

(5580, 141)

In [46]:
from joints_list import JOINTS_LIST

def result2target_vis(template_file, bvh_file, res_frames, save_path):
    ori_list = JOINTS_LIST["beat_joints"]
    target_list = JOINTS_LIST["spine_neck_141"]
    file_content_length = 431

    template_file_path = f"{camn_args.root_path}/datasets/beat_cache/beat_4english_15_141/test/bvh_rot_vis/{template_file}"
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    short_name = bvh_file.split("\\")[-1][11:]
    save_file_path = os.path.join(save_path, f'res_{short_name}')
    
    write_file = open(save_file_path,'w+')
    with open(template_file_path,'r') as pose_data_pre:
        pose_data_pre_file = pose_data_pre.readlines()
        ori_lines = pose_data_pre_file[:file_content_length]
        offset_data = np.fromstring(pose_data_pre_file[file_content_length], dtype=float, sep=' ')
    write_file.close()

    ori_lines[file_content_length-2] = 'Frames: ' + str(res_frames) + '\n'

    write_file = open(os.path.join(save_path, f'res_{short_name}'),'w+')
    write_file.writelines(i for i in ori_lines[:file_content_length])    
    write_file.close() 

    with open(save_file_path,'a+') as write_file: 
        with open(bvh_file, 'r') as pose_data:
            data_each_file = []
            pose_data_file = pose_data.readlines()
            for j, line in enumerate(pose_data_file):
                if not j:
                    pass
                else:          
                    data = np.fromstring(line, dtype=float, sep=' ')
                    data_rotation = offset_data.copy()   
                    for iii, (k, v) in enumerate(target_list.items()): # here is 147 rotations by 3
                        data_rotation[ori_list[k][1]-v:ori_list[k][1]] = data[iii*3:iii*3+3]
                    data_each_file.append(data_rotation)
    
        for line_data in data_each_file:
            line_data = np.array2string(line_data, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
            write_file.write(line_data[1:-2]+'\n')


In [48]:
# Dump the raw denormalized frames (one per line), then splice them into a
# full BVH built from the template skeleton.
os.makedirs("result_pose", exist_ok=True)  # the raw file is written before result2target_vis creates it
res_file = os.path.join("result_pose", "result_raw_ghandhi_speech.bvh")

with open(res_file, 'w') as f_real:
    for frame in out_final:
        line_data = np.array2string(frame, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
        # Strip the "[...]" brackets and any right padding; a fixed [1:-2] slice
        # drops the final digit whenever the last value is unpadded.
        f_real.write(line_data[1:-1].rstrip() + '\n')

# result2target_vis skips the first line of the raw file, so the BVH holds one frame fewer.
res_frames = out_final.shape[0] - 1
result2target_vis(template_file, res_file, res_frames, 'result_pose/')