In [35]:
import torch
from dataloaders.beat import CustomDataset
from dataloaders.build_vocab import Vocab
import pickle
import numpy as np

# Load the pickled experiment configuration. The context manager closes the
# file handle as soon as the object has been deserialised (the bare open()
# used previously left it open for the lifetime of the kernel).
# NOTE(review): pickle.load can execute arbitrary code -- only load config
# files that come from a trusted source.
with open("gesturegen_config.obj", 'rb') as config_file:
    args = pickle.load(config_file)
args.batch_size = 16  # override the batch size stored in the pickled config

# Mean / std arrays for the facial and audio representations, read from .npy
# files under args.root_path + args.mean_pose_path. Later cells use them to
# undo the normalisation (x * std + mean).
mean_facial = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy"))
std_facial = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy"))
mean_audio = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.audio_rep}/npy_mean.npy"))
std_audio = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.audio_rep}/npy_std.npy"))

In [36]:
# Sanity check: the training and inference cells below call .cuda()
# unconditionally, so this needs to display True.
torch.cuda.is_available()

True

In [37]:
# Training split: shuffled mini-batches of args.batch_size samples; the last
# incomplete batch is dropped so every batch has the same size.
train_data = CustomDataset(args, "train")
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, drop_last=True)

In [38]:
# Number of mini-batches per training epoch (incomplete last batch excluded).
len(train_loader)

14857

In [39]:
# Validation split, batched exactly like the training split. Shuffling is kept
# on: the training cell only evaluates the first few batches of each pass, so
# shuffling makes that subset a fresh random sample every time.
val_data = CustomDataset(args, "val")
val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=True, drop_last=True)

In [40]:
# Number of mini-batches in one full pass over the validation split.
len(val_loader)

2972

In [56]:
# Draw one training batch and look at the value ranges of the de-normalised
# facial data, both as-is and after multiplying by the `mat_final.npy` matrix.
# (presumably a blendshape -> FLAME vertex mapping, judging by the variable
# names -- TODO confirm)
data = next(iter(train_loader))
to_flame = torch.from_numpy(np.load('mat_final.npy'))
facial = data['facial'] * std_facial.float() + mean_facial.float()
facial_vertex = facial @ to_flame
for stats_source in (facial, facial_vertex):
    print(stats_source.min(), stats_source.max(), stats_source.mean(), stats_source.std())

tensor(-7.4506e-09) tensor(0.9524) tensor(0.1426) tensor(0.1631)
tensor(-6.0584) tensor(7.1772) tensor(-0.0174) tensor(1.1267)


### Train A2BS MulticontextNet (GestureGen)

In [57]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scripts.Dataset import a2bsDataset
from scripts.MulticontextNet import GestureGen
#import wandb
import uuid

In [58]:
# project = "testing"
# group = "facegenerator"
# name = "test"
# dataset_name = "beat"
# entity = "hm_gesture"
# run = wandb.init(
#         project="testing",
#         group="facegenerator",
#         name= f"{name}-{dataset_name}-{str(uuid.uuid4())[:8]}",
#         id=str(uuid.uuid4()),
#         entity=entity,
#     )
# wandb.run.save()

In [60]:
# Resume from an existing checkpoint and continue training it on the GPU.
model_path = 'ckpt_model/multicontextnet-all-vertex-2024/multicontextnet-300.pth'
net = GestureGen(args).cuda()
checkpoint_state = torch.load(model_path)
net.load_state_dict(checkpoint_state)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)  # weight_decay=1e-5 left disabled

# Loss histories. The train_* lists receive one value per training iteration,
# the val_* lists one averaged value per validation pass. They are module-level
# on purpose: plot_train_val_loss() reads them directly.
train_target_loss, train_expressive_loss, train_smooth_loss, train_mse_loss = [], [], [], []
val_target_loss, val_expressive_loss, val_smooth_loss, val_mse_loss = [], [], [], []

def _plot_loss_row(target, expressive, smooth, mse, suptitle):
    """Draw one 1x4 figure holding the four loss curves and show it.

    Args:
        target, expressive, smooth, mse: sequences of loss values to plot.
        suptitle: title placed above the four panels.
    """
    fig, axs = plt.subplots(1, 4, figsize=(10, 3))
    panels = (
        (target, 'r-', 'Target Loss'),
        # 'p' is matplotlib's pentagon *marker*, not a colour, so the former
        # 'p-' drew a marker on every point in the default colour. 'm-'
        # (magenta) gives a plain coloured line like the other panels.
        (expressive, 'm-', 'Expressive Loss'),
        (smooth, 'g-', 'Smooth Loss'),
        (mse, 'b-', 'MSE Loss'),
    )
    for ax, (values, fmt, title) in zip(axs, panels):
        ax.plot(values, fmt)
        ax.set_title(title)
    fig.suptitle(suptitle, fontsize=16)
    plt.tight_layout()
    plt.show()


def plot_train_val_loss():
    """Plot the recorded training and validation loss histories.

    Reads the module-level train_* / val_* loss lists and shows two figures
    (training first, then validation), each with four panels: target,
    expressive, smooth and MSE loss.
    """
    _plot_loss_row(train_target_loss, train_expressive_loss, train_smooth_loss,
                   train_mse_loss, 'Training Iterations')
    _plot_loss_row(val_target_loss, val_expressive_loss, val_smooth_loss,
                   val_mse_loss, 'Validation Iterations')

In [61]:
# Number of training samples, followed by one example batch split into its
# individual fields for interactive inspection.
print(len(train_data))
data = next(iter(train_loader))
in_audio, facial, in_id, in_word, in_emo = (
    data[field] for field in ('audio', 'facial', 'id', 'word', 'emo')
)

237714


In [62]:
def expressive_loss_function(output, target):
    """Mean over frames of the largest squared error along the last axis.

    For every position along the leading axes, take the maximum of
    (output - target) ** 2 over the last axis (the blendshape channels), then
    average those maxima into a scalar.

    Args:
        output: predicted tensor.
        target: reference tensor, broadcastable against `output`.

    Returns:
        Scalar tensor holding the loss.
    """
    squared_error = (output - target).pow(2)
    worst_per_frame = squared_error.max(dim=-1).values
    return worst_per_frame.mean()
# a = torch.tensor([[[1,2,3],[1,2,3]],[[4,5,6],[4,5,6]]]).float()
# b = torch.tensor([[[3,3,4],[2,3,4]],[[5,6,7],[5,6,7]]]).float()
# print(a.shape)
# expressive_loss_function(a, b)

In [None]:
from tqdm import tqdm

# ---- configuration --------------------------------------------------------
num_epochs = 700
log_period = 200         # not used by the loop below (per-iteration logging is disabled)
val_period = 600         # run a validation pass every `val_period` training iterations
val_size = 25            # number of validation batches averaged per validation pass
pre_frames = 4           # ground-truth frames handed to the network as a seed
to_flame = torch.from_numpy(np.load('mat_final.npy')).cuda()
target_loss_function = torch.nn.HuberLoss()
smooth_loss_function = torch.nn.CosineSimilarity(dim=2)
mse_loss_function = torch.nn.MSELoss()
target_weight = 1.5
expressive_weight = 0.5  # currently unused: the expressive term is left out of `loss`
smooth_weight = 0.5

# Built once instead of on every validation pass; because shuffle=True, each
# new iteration over the loader still yields a freshly shuffled order.
val_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
)


def _forward_batch(batch):
    """Move one batch to the GPU, build the seed input and run `net`.

    The seed input has one channel more than `facial`: the first `pre_frames`
    frames carry the ground-truth values plus a 1 in the extra last channel;
    every other entry is zero.

    Returns:
        (out_face, facial): network output and the ground truth on the GPU.
    """
    in_audio = batch['audio'].cuda()
    facial = batch['facial'].cuda()
    in_id = batch["id"].cuda()
    in_word = batch["word"].cuda()
    in_emo = batch["emo"].cuda()

    # new_zeros inherits device and dtype from `facial`, so no extra .cuda().
    in_pre_face = facial.new_zeros((facial.shape[0], facial.shape[1], facial.shape[2] + 1))
    in_pre_face[:, 0:pre_frames, :-1] = facial[:, 0:pre_frames]
    in_pre_face[:, 0:pre_frames, -1] = 1

    out_face = net(in_pre_face, in_audio=in_audio, in_text=in_word, in_id=in_id, in_emo=in_emo)
    return out_face, facial


def _batch_losses(out_face, facial):
    """Return (target_loss, expressive_loss, smooth_loss) for one batch."""
    # Huber loss after projecting through `to_flame`, plus an extra Huber term
    # on channels 6:14 to account for eye movement.
    # NOTE(review): if the channels follow the order of the `blend` list defined
    # further down, 6:14 spans cheekSquintLeft..eyeLookInRight -- confirm that
    # this is the intended channel range.
    target_loss = (
        target_loss_function(out_face @ to_flame, facial @ to_flame)
        + target_loss_function(out_face[:, :, 6:14], facial[:, :, 6:14])
    )
    expressive_loss = expressive_loss_function(out_face, facial)
    # 1 - mean cosine similarity between consecutive frames of the prediction.
    smooth_loss = 1 - smooth_loss_function(out_face[:, :-1, :], out_face[:, 1:, :]).mean()
    return target_loss, expressive_loss, smooth_loss


def _denormalized_mse(out_face, facial):
    """MSE between prediction and ground truth after x * std + mean (metric only)."""
    # detach(): this value is only logged, so it must not hold on to the graph.
    pred = out_face.detach().cpu() * std_facial + mean_facial
    truth = facial.cpu() * std_facial + mean_facial
    return mse_loss_function(pred, truth).item()


for epoch in range(num_epochs):
    for it, data in enumerate(tqdm(train_loader)):
        # The validation pass below switches to eval mode, so training mode is
        # re-enabled at the start of every iteration.
        net.train()

        optimizer.zero_grad()
        out_face, facial = _forward_batch(data)
        target_loss, expressive_loss, smooth_loss = _batch_losses(out_face, facial)
        # The expressive term is tracked but deliberately not optimised.
        loss = target_weight * target_loss + smooth_weight * smooth_loss  # + expressive_weight * expressive_loss
        loss.backward()
        optimizer.step()

        train_target_loss.append(target_loss.item())
        train_expressive_loss.append(expressive_loss.item())
        train_smooth_loss.append(smooth_loss.item())
        train_mse_loss.append(_denormalized_mse(out_face, facial))

        if it % val_period == 0:
            net.eval()
            val_target_loss_st = []
            val_expressive_loss_st = []
            val_smooth_loss_st = []
            val_mse_loss_st = []

            # No gradients are needed for validation; without no_grad every
            # forward pass would build and keep an autograd graph on the GPU.
            with torch.no_grad():
                for val_cnt, val_batch in enumerate(val_loader, start=1):
                    out_face, facial = _forward_batch(val_batch)
                    target_loss, expressive_loss, smooth_loss = _batch_losses(out_face, facial)

                    val_target_loss_st.append(target_loss.item())
                    val_expressive_loss_st.append(expressive_loss.item())
                    val_smooth_loss_st.append(smooth_loss.item())
                    val_mse_loss_st.append(_denormalized_mse(out_face, facial))

                    # Only a random subset of `val_size` batches is evaluated.
                    if val_cnt >= val_size:
                        break

            val_target_loss.append(np.average(val_target_loss_st))
            val_expressive_loss.append(np.average(val_expressive_loss_st))
            val_smooth_loss.append(np.average(val_smooth_loss_st))
            val_mse_loss.append(np.average(val_mse_loss_st))
    # Refresh the loss curves once per epoch.
    plot_train_val_loss()
        

In [None]:
# Persist the current weights of `net`. Note that this path differs from the
# checkpoint the model was initialised from, so the original is not overwritten.
torch.save(net.state_dict(), 'ckpt_model/multicontextnet100.pth')

In [None]:
# Show the training / validation loss curves recorded so far.
plot_train_val_loss()

### Testing

In [123]:
from pythonosc import udp_client
import time
import sounddevice as sd
import torch
from dataloaders.beat import CustomDataset
from dataloaders.build_vocab import Vocab
import pickle
import numpy as np

# Reload the pickled experiment configuration for the testing section. The
# context manager closes the file handle right after deserialisation (the bare
# open() used previously never closed it).
# NOTE(review): pickle.load can execute arbitrary code -- only load config
# files that come from a trusted source.
with open("gesturegen_config.obj", 'rb') as config_file:
    args = pickle.load(config_file)
args.batch_size = 16  # override the batch size stored in the pickled config

# Normalisation statistics and the `mat_final.npy` projection matrix, all cast
# to float32 here so later cells can combine them without further casts.
mean_facial = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy")).float()
std_facial = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy")).float()
mean_audio = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.audio_rep}/npy_mean.npy")).float()
std_audio = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.audio_rep}/npy_std.npy")).float()
to_flame = torch.from_numpy(np.load('mat_final.npy')).float()

In [124]:
# Test split, served one sample at a time in random order; nothing is dropped.
test_data = CustomDataset(args, "test")
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True, drop_last=False)

In [125]:
# Draw one test sample. The loader shuffles, so re-running this cell picks a
# different sample each time.
data = next(iter(test_loader))

In [126]:
# Split the sampled test clip into its individual fields.
# NOTE(review): `id` shadows the Python builtin of the same name; the name is
# kept because the inference cell further down reads it.
audio, facial, id, word, emo = (
    data[field] for field in ('audio', 'facial', 'id', 'word', 'emo')
)

In [127]:
# Undo the dataset normalisation (x * std + mean) to get the ground-truth
# facial sequence and audio used for playback and as metric reference.
out_facial = facial * std_facial + mean_facial
out_audio = audio * std_audio + mean_audio

In [128]:
# Value range and spread of the de-normalised ground-truth facial sequence.
# (The audio is handled further down and is played back at 16 kHz.)
print(out_facial.min(), out_facial.max())
print(out_facial.std(), out_facial.mean())

tensor(-7.4506e-09) tensor(0.9209)
tensor(0.1671) tensor(0.1659)


In [129]:
# Shape of the de-normalised audio tensor for the sampled clip.
out_audio.shape

torch.Size([1, 1072000])

In [130]:
# OSC address names, one per facial output channel: send_udp() sends channel j
# of every frame to the address '/' + blend[j], so the order of this list must
# match the channel order of the facial data.
blend = [
    # brows
    "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft", "browOuterUpRight",
    # cheeks
    "cheekPuff", "cheekSquintLeft", "cheekSquintRight",
    # eyes
    "eyeBlinkLeft", "eyeBlinkRight",
    "eyeLookDownLeft", "eyeLookDownRight",
    "eyeLookInLeft", "eyeLookInRight",
    "eyeLookOutLeft", "eyeLookOutRight",
    "eyeLookUpLeft", "eyeLookUpRight",
    "eyeSquintLeft", "eyeSquintRight",
    "eyeWideLeft", "eyeWideRight",
    # jaw
    "jawForward", "jawLeft", "jawOpen", "jawRight",
    # mouth
    "mouthClose",
    "mouthDimpleLeft", "mouthDimpleRight",
    "mouthFrownLeft", "mouthFrownRight",
    "mouthFunnel", "mouthLeft",
    "mouthLowerDownLeft", "mouthLowerDownRight",
    "mouthPressLeft", "mouthPressRight",
    "mouthPucker", "mouthRight",
    "mouthRollLower", "mouthRollUpper",
    "mouthShrugLower", "mouthShrugUpper",
    "mouthSmileLeft", "mouthSmileRight",
    "mouthStretchLeft", "mouthStretchRight",
    "mouthUpperUpLeft", "mouthUpperUpRight",
    # nose and tongue
    "noseSneerLeft", "noseSneerRight",
    "tongueOut",
]

In [131]:
def play_audio(out_audio, init_time, sample_rate=16000):
    """Wait until `init_time`, then play `out_audio` and block until it ends.

    Args:
        out_audio: audio samples handed to `sd.play`.
        init_time: wall-clock time (as returned by time.time()) at which
            playback should start.
        sample_rate: playback rate in Hz passed to `sd.play`; defaults to the
            16000 that was previously hard-coded.
    """
    # time.sleep raises ValueError for a negative duration, which happened
    # whenever the thread started after `init_time`; start immediately instead.
    time.sleep(max(0.0, init_time - time.time()))
    sd.play(out_audio, sample_rate)
    sd.wait()
    print("Audio finished:", time.time())

In [132]:

def send_udp(out_face, init_time, fps=15, host='127.0.0.1', port=5008):
    """Stream a facial sequence frame by frame as OSC messages over UDP.

    Every frame is sent as one message per channel, addressed '/' + blend[j]
    for channel j. Frames are paced against `init_time`, so a slow frame does
    not delay the ones after it.

    Args:
        out_face: 2-D array/tensor indexed as [frame][channel]; must support
            `>=`, `*` and `.tolist()`.
        init_time: wall-clock time (as returned by time.time()) at which the
            first frame should be sent.
        fps: frames sent per second (default 15, the former hard-coded value).
        host: destination address (default '127.0.0.1').
        port: destination UDP port (default 5008).
    """
    # Zero out negative weights; values above zero are passed through as-is.
    outWeight = out_face * (out_face >= 0)

    client = udp_client.SimpleUDPClient(host, port)
    osc_array = outWeight.tolist()

    # time.sleep raises ValueError for a negative duration, which happened
    # whenever the thread started after `init_time`; start immediately instead.
    time.sleep(max(0.0, init_time - time.time()))
    for i, frame in enumerate(osc_array):
        for j, weight in enumerate(frame):
            client.send_message('/' + str(blend[j]), weight)

        # Sleep until frame i+1 is due, measured from init_time rather than
        # from the previous frame, so timing errors do not accumulate.
        elapsed_time = time.time() - init_time
        sleep_time = 1.0 / fps * (i + 1) - elapsed_time
        if sleep_time > 0:
            time.sleep(sleep_time)
    print("Facial finished:", time.time())

In [133]:
import threading

# Play the first `limit_sec` seconds of the ground-truth clip: the facial
# frames are streamed over UDP (15 frames per second) while the audio is
# played (16000 samples per second).
limit_sec = 20
init_time = time.time() + 1  # common start time, one second from now

# Daemon threads, so they do not keep the process alive if the kernel exits.
udp_thread = threading.Thread(
    target=send_udp,
    args=(out_facial[0, 0:limit_sec*15], init_time),
    daemon=True,
)
# The audio is started 0.3 s ahead of the facial stream.
audio_thread = threading.Thread(
    target=play_audio,
    args=(out_audio[0, 0:limit_sec*16000], init_time-0.3),
    daemon=True,
)

for worker in (udp_thread, audio_thread):
    worker.start()
for worker in (udp_thread, audio_thread):
    worker.join()

Audio finished: 1717109722.657436
Facial finished: 1717109722.8734372


In [134]:
# Clip length in seconds from each stream (audio at 16000 samples/s, facial at
# 15 frames/s); the two numbers should agree.
print(len(out_audio[0])/16000, len(out_facial[0])/15)

67.0 67.0


In [160]:
 # load in model
from scripts.MulticontextNet import GestureGen

# Checkpoint to evaluate; swap in the commented path to test the weights saved
# by the training section instead.
model_path = 'ckpt_model/multicontextnet-all-vertex-2024/multicontextnet-300.pth'
#model_path = 'ckpt_model/multicontextnet100.pth'
checkpoint_state = torch.load(model_path)
net = GestureGen(args)
net.load_state_dict(checkpoint_state)
net = net.cuda().eval()  # move to the GPU and switch to inference mode

In [204]:
# Move the sampled test clip to the GPU.
in_audio = audio.cuda()
in_id = id.cuda()
in_word = word.cuda()
#in_emo = emo.cuda()
in_emo = torch.ones_like(emo).cuda()  # replace the clip's emotion label with a constant 1

# De-normalised facial sequence, used below for the seed frames.
# NOTE(review): the training cell seeds the network with `facial` exactly as the
# loader returns it, whereas this cell seeds it with the de-normalised sequence
# -- confirm which of the two the loaded checkpoint expects.
in_facial = facial.cuda() * std_facial.float().cuda() + mean_facial.float().cuda()

# Seed input with one extra channel: the first `pre_frames` frames carry the
# facial values plus a 1 in the last channel; everything else stays zero.
pre_frames = 4
seed_shape = (in_facial.shape[0], in_facial.shape[1], in_facial.shape[2] + 1)
in_pre_facial = in_facial.new_zeros(seed_shape)  # same device and dtype as in_facial
in_pre_facial[:, :pre_frames, :-1] = in_facial[:, :pre_frames]
in_pre_facial[:, :pre_frames, -1] = 1

In [205]:
# Inference only: under no_grad the forward pass does not build an autograd
# graph, which would otherwise be kept in GPU memory for the whole clip.
with torch.no_grad():
    pred_facial = net(in_pre_facial, in_audio=in_audio, in_text=in_word, in_id=in_id, in_emo=in_emo).cpu().detach()
# The prediction is used as-is; the de-normalisation below is disabled.
#pred_facial = pred_facial * std_facial + mean_facial

In [173]:
# Compare value range and spread of the ground truth (first two lines of
# output) with those of the prediction (last two lines).
in_facial_cpu = in_facial.cpu()
print(in_facial_cpu.min(), in_facial_cpu.max())
print(in_facial_cpu.std(), in_facial_cpu.mean())
print(pred_facial.min(), pred_facial.max())
print(pred_facial.std(), pred_facial.mean())

tensor(-7.4506e-09) tensor(0.9209)
tensor(0.1671) tensor(0.1659)
tensor(-0.7328) tensor(1.2204)
tensor(0.1933) tensor(0.1581)


In [207]:
import threading

# Play the first `limit_sec` seconds of the *predicted* facial sequence over
# UDP (15 frames per second) together with the clip's audio (16000 samples per
# second).
limit_sec = 20
init_time = time.time() + 1  # common start time, one second from now

# Daemon threads, so they do not keep the process alive if the kernel exits.
udp_thread = threading.Thread(
    target=send_udp,
    args=(pred_facial[0,0:limit_sec*15], init_time),
    daemon=True,
)
# The audio is started 0.3 s ahead of the facial stream.
audio_thread = threading.Thread(
    target=play_audio,
    args=(out_audio[0,0:limit_sec*16000], init_time-0.3),
    daemon=True,
)

for worker in (udp_thread, audio_thread):
    worker.start()
for worker in (udp_thread, audio_thread):
    worker.join()

Audio finished: 1717110581.9901443
Facial finished: 1717110582.2231467


In [206]:
# Error of the prediction against the de-normalised ground truth:
# expressive loss, plain MSE, and MSE after multiplying both by `to_flame`.
print(expressive_loss_function(pred_facial, out_facial))
print(torch.nn.functional.mse_loss(pred_facial, out_facial))
print(torch.nn.functional.mse_loss(pred_facial@to_flame, out_facial@to_flame))

tensor(0.1283)
tensor(0.0159)
tensor(0.3069)


In [147]:
# NOTE(review): this cell repeats the metrics cell directly above. Its lower
# execution count and different recorded output show it was run at another
# time, so its numbers do not necessarily describe the current `pred_facial`.
print(expressive_loss_function(pred_facial, out_facial))
print(torch.nn.functional.mse_loss(pred_facial, out_facial))
print(torch.nn.functional.mse_loss(pred_facial@to_flame, out_facial@to_flame))

tensor(0.1502)
tensor(0.0188)
tensor(0.3199)


In [None]:
# Shapes of prediction and ground truth; they must match for the metrics above.
print(pred_facial.shape)
print(out_facial.shape)

In [None]:
# First frame of the prediction (index 0 on the second axis), for eyeballing.
print(pred_facial[:,0,:])

In [None]:
# First frame of the de-normalised ground truth, for comparison with the
# first predicted frame.
print(out_facial[:,0,:])