In [None]:
from dataloaders.beat import CustomDataset
from dataloaders.build_vocab import Vocab
import pickle
import numpy as np
from utils import config
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scripts.MulticontextNet import GestureGen, ConvDiscriminator
from scripts.Logger import Logger
import wandb
import random
import uuid
from tqdm import tqdm
import trimesh
from blendshapes import BLENDSHAPE_NAMES
from time import time

# Load the pickled experiment configuration. The context manager closes the
# file handle as soon as the object is read (it used to stay open).
# NOTE: pickle.load can execute arbitrary code stored in the file, so only
# load configuration files produced by this project.
with open("gesturegen_config.obj", 'rb') as config_file:
    args = pickle.load(config_file)

# Notebook-level overrides applied on top of the pickled configuration.
args.batch_size = 16
args.continue_training = False
args.no_adv_epochs = 1

In [None]:
class Trainer():
    def __init__(self, args, device, train_data, val_data, model, d_model, logger):
        # Set up data loading
        self.mean_facial = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy")).float()
        self.std_facial = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy")).float()
        self.mean_audio = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.audio_rep}/npy_mean.npy")).float()
        self.std_audio = torch.from_numpy(np.load(args.root_path+args.mean_pose_path+f"{args.audio_rep}/npy_std.npy")).float()

        self.batch_size = args.batch_size
        self.train_data = train_data
        self.train_loader = torch.utils.data.DataLoader(
            train_data, 
            batch_size=self.batch_size,  
            shuffle=True,  
            drop_last=True,
        )
        self.val_data = val_data

        # Set up model and loss functions
        self.no_text = args.no_text
        self.model = model.to(device) # generative model
        self.d_model = d_model.to(device) # discriminator
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.d_optimizer = torch.optim.Adam(self.d_model.parameters(), lr=1e-4)
        self.target_loss_function = torch.nn.HuberLoss()
        self.smooth_loss_function = torch.nn.CosineSimilarity(dim=2)
        self.mse_loss_function = torch.nn.MSELoss()

        # Set up blendshape to flame parameters
        self.bs_to_flame = torch.from_numpy(np.load('mat_final.npy')).to(device)
        self.flame_to_bs = self.bs_to_flame.pinverse()
        self.predict_flame = args.predict_flame
        self.bs2vertices = args.bs2vertices
        self.normalize_face = args.normalize_face
        self.pre_frames = args.pre_frames

        # Set up blendshape to vertices
        if self.bs2vertices:
            self.V_factor = 100
            self.V_basis = torch.tensor(trimesh.load('bs/Basis.obj').vertices, dtype=torch.float32) * self.V_factor
            self.V_bs = torch.stack([torch.tensor(trimesh.load(f'bs/exp/{bs_name}.obj').vertices, dtype=torch.float32) for bs_name in BLENDSHAPE_NAMES[:51]]) * self.V_factor
            self.V_deltas = (self.V_bs - self.V_basis.unsqueeze(0)).unsqueeze(0).unsqueeze(0).to(device)

        # Set up training/validation parameters
        self.epochs = args.epochs
        self.target_weight = args.target_weight
        self.smooth_weight = args.smooth_weight
        self.expressive_weight = args.expressive_weight
        self.adv_weight = args.adv_weight
        self.val_size = args.val_size
        self.log_period = args.log_period
        self.val_period = args.val_period
        self.no_adv_epochs = args.no_adv_epochs

        # Set up logging
        self.logger = logger
        self._iter = 0
        self._ep_idx = 0
        self._start_time = time()

        # Checkpoint
        self.ckpt_exp_dir = f'{args.wandb_project}-{args.wandb_group}-{str(args.random_seed)}'
        self.save_period = args.save_period
        self.save_ckpt = args.save_ckpt
        self.ckpt_dir = args.ckpt_dir
        self.ckpt_path = os.path.join(self.ckpt_dir, self.ckpt_exp_dir)
        if self.save_ckpt:
            if not os.path.exists(self.ckpt_dir):
                os.makedirs(self.ckpt_dir)
            if not os.path.exists(self.ckpt_path):
                os.makedirs(self.ckpt_path)

    def common_metrics(self):
        """Return a dictionary of current metrics."""
        return dict(
            iteration=self._iter,
            epoch=self._ep_idx,
            total_time=time() - self._start_time,
        )

    def save_checkpoint(self, name):
        if self.save_ckpt:
            pth_path = os.path.join(self.ckpt_path, f'multicontextnet-{name}.pth')
            torch.save(self.model.state_dict(), pth_path)
            pth_path = os.path.join(self.ckpt_path, f'd-{name}.pth')
            torch.save(self.model.state_dict(), pth_path)
    
    def expressive_loss_function(self, output, target): # max squared error over blendshape for each frame, then take the mean
        loss = torch.mean(torch.max((output - target) ** 2, dim=-1).values)
        return loss

    def val(self):
        self.model.eval()
        val_target_loss_st = []
        if self.bs2vertices:
            val_V_smooth_loss_st = []
        if self.predict_flame:
            val_flame_smooth_loss_st = []
        val_bs_smooth_loss_st = []
        val_bs_expressive_loss_st = []
        val_cnt = 0

        val_loader = torch.utils.data.DataLoader(
            self.val_data, 
            batch_size=self.batch_size,  
            shuffle=True,  
            drop_last=True,
        )
        
        for _, data in enumerate(val_loader):
            in_audio = data['audio']
            if self.normalize_face:
                bs_facial = data['facial']
            else:
                bs_facial = data['facial'] * self.std_facial + self.mean_facial
            
            in_id = data["id"]
            if self.no_text == False:
                in_word = data["word"]
            in_emo = data["emo"]

            in_audio = in_audio.cuda()
            bs_facial = bs_facial.cuda()
            in_id = in_id.cuda()
            if self.no_text == False:
                in_word = in_word.cuda()
            in_emo = in_emo.cuda()
            
            if self.predict_flame:
                flame_facial = torch.cat((bs_facial @ self.bs_to_flame, bs_facial[:,:,6:14]), dim=-1)
                in_pre_face = flame_facial.new_zeros((flame_facial.shape[0], flame_facial.shape[1], flame_facial.shape[2])).cuda()
                in_pre_face[:, 0:self.pre_frames] = flame_facial[:, 0:self.pre_frames]

                flame_out_face = self.model(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)
                target_loss = self.target_loss_function(flame_out_face, flame_facial)
                flame_smooth_loss = 1 - self.smooth_loss_function(flame_out_face[:,:-1,:], flame_out_face[:,1:,:]).mean()

                val_target_loss_st.append(target_loss.item())
                val_flame_smooth_loss_st.append(flame_smooth_loss.item()) 

            else: 
                in_pre_face = bs_facial.new_zeros((bs_facial.shape[0], bs_facial.shape[1], bs_facial.shape[2] + 1)).cuda()
                in_pre_face[:, 0:self.pre_frames, :-1] = bs_facial[:, 0:self.pre_frames]
                in_pre_face[:, 0:self.pre_frames, -1] = 1 

                if self.no_text:
                    bs_pred_face = self.model(in_pre_face,in_audio=in_audio, in_id=in_id, in_emo=in_emo)
                else:
                    bs_pred_face = self.model(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)

                if self.bs2vertices:
                    V_pred_face = torch.sum(bs_pred_face.unsqueeze(3).unsqueeze(4) * self.V_deltas, axis=2)
                    V_gt_face = torch.sum(bs_facial.unsqueeze(3).unsqueeze(4) * self.V_deltas, axis=2)
                    
                    target_loss = self.target_loss_function(V_pred_face, V_gt_face)
                    V_smooth_loss = 1 - self.smooth_loss_function(V_pred_face[:,:-1,:], V_pred_face[:,1:,:]).mean()

                    val_target_loss_st.append(target_loss.item())
                    val_V_smooth_loss_st.append(V_smooth_loss.item())
                else:
                    target_loss = self.target_loss_function(bs_pred_face, bs_facial)
                    bs_smooth_loss = 1 - self.smooth_loss_function(bs_pred_face[:,:-1,:], bs_pred_face[:,1:,:]).mean()
                    bs_expressive_loss = self.expressive_loss_function(bs_pred_face, bs_facial)

                    val_target_loss_st.append(target_loss.item())
                    val_bs_smooth_loss_st.append(bs_smooth_loss.item())
                    val_bs_expressive_loss_st.append(bs_expressive_loss.item())
            
            val_cnt += 1
            if val_cnt >= self.val_size:
                break
        if self.predict_flame:
            return {
                "target_loss": float(np.average(val_target_loss_st)),
                "flame_smooth_loss": float(np.average(val_flame_smooth_loss_st)),
            }
        else:
            if self.bs2vertices:
                return {
                    "target_loss": float(np.average(val_target_loss_st)),
                    "V_smooth_loss": float(np.average(val_V_smooth_loss_st)),
                }
            else:
                return {
                    "target_loss": float(np.average(val_target_loss_st)),
                    "smooth_loss": float(np.average(val_bs_smooth_loss_st)),
                    "expressive_loss": float(np.average(val_bs_expressive_loss_st)),
                }

    def train(self):
        train_metrics = {}
        for self._ep_idx in range(self.epochs):
            use_adv = bool(self._ep_idx>=self.no_adv_epochs)
            for it, data in enumerate(self.train_loader):
                self.model.train()
                in_audio = data['audio']
                if self.normalize_face:
                    bs_facial = data['facial']
                else:
                    bs_facial = data['facial'] * self.std_facial + self.mean_facial
                in_id = data["id"]
                if self.no_text == False:
                    in_word = data["word"]
                in_emo = data["emo"]

                in_audio = in_audio.cuda()
                bs_facial = bs_facial.cuda()
                in_id = in_id.cuda()
                if self.no_text == False:
                    in_word = in_word.cuda()
                in_emo = in_emo.cuda()
                
                if self.predict_flame:
                    flame_facial = torch.cat((bs_facial @ self.bs_to_flame, bs_facial[:,:,6:14]), dim=-1)
                    in_pre_face = flame_facial.new_zeros((flame_facial.shape[0], flame_facial.shape[1], flame_facial.shape[2])).cuda()
                    in_pre_face[:, 0:self.pre_frames] = flame_facial[:, 0:self.pre_frames]

                    self.optimizer.zero_grad()
                    flame_out_face = self.model(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)
                    target_loss = self.target_loss_function(flame_out_face, flame_facial)
                    flame_smooth_loss = 1 - self.smooth_loss_function(flame_out_face[:,:-1,:], flame_out_face[:,1:,:]).mean()
                    loss = self.target_weight * target_loss  + self.smooth_weight * flame_smooth_loss
                    loss.backward()
                    self.optimizer.step()

                    if it % self.log_period == 0:
                        train_metrics = {
                            "target_loss": float(target_loss.item()),
                            "flame_smooth_loss": float(flame_smooth_loss.item()),
                        }
                        train_metrics.update(self.common_metrics())
                        self.logger.log(train_metrics, 'train')
                        print(f'[{self._ep_idx}][{it}/{len(self.train_loader)}]: [train] [target loss]: {train_metrics["target_loss"]} flame smooth loss]: {train_metrics["flame_smooth_loss"]}')

                else: 
                    in_pre_face = bs_facial.new_zeros((bs_facial.shape[0], bs_facial.shape[1], bs_facial.shape[2] + 1)).cuda()
                    in_pre_face[:, 0:self.pre_frames, :-1] = bs_facial[:, 0:self.pre_frames]
                    in_pre_face[:, 0:self.pre_frames, -1] = 1 

                    # discriminator training
                    if use_adv:
                        self.d_optimizer.zero_grad()
                        if self.no_text:
                            bs_pred_face = self.model(in_pre_face,in_audio=in_audio, in_id=in_id, in_emo=in_emo)
                        else:
                            bs_pred_face = self.model(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)
                        out_d_fake = self.d_model(bs_pred_face)
                        out_d_real = self.d_model(bs_facial)
                        d_loss = torch.sum(-torch.mean(torch.log(out_d_real + 1e-8) + torch.log(1 - out_d_fake + 1e-8)))
                        d_loss.backward()
                        self.d_optimizer.step()

                    self.optimizer.zero_grad()
                    if self.no_text:
                        bs_pred_face = self.model(in_pre_face,in_audio=in_audio, in_id=in_id, in_emo=in_emo)
                    else:
                        bs_pred_face = self.model(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)

                    if self.bs2vertices:
                        V_pred_face = torch.sum(bs_pred_face.unsqueeze(3).unsqueeze(4) * self.V_deltas, axis=2)
                        V_gt_face = torch.sum(bs_facial.unsqueeze(3).unsqueeze(4) * self.V_deltas, axis=2)

                        target_loss = self.target_loss_function(V_pred_face, V_gt_face)
                        V_smooth_loss = 1 - self.smooth_loss_function(V_pred_face[:,:-1,:], V_pred_face[:,1:,:]).mean()
                        loss = self.target_weight * target_loss  + self.smooth_weight * V_smooth_loss
                    
                        loss.backward()
                        self.optimizer.step()

                        if it % self.log_period == 0:
                            train_metrics = {
                                "target_loss": float(target_loss.item()),
                                "V_smooth_loss": float(V_smooth_loss.item())
                            }
                            train_metrics.update(self.common_metrics())
                            self.logger.log(train_metrics, 'train')
                            print(f'[{self._ep_idx}][{it}/{len(self.train_loader)}]: [train] [target loss]: {train_metrics["target_loss"]} [vertices smooth loss]: {train_metrics["V_smooth_loss"]}')
                    else:
                        # generator training
                        adv_loss = None
                        target_loss = self.target_loss_function(bs_pred_face, bs_facial)
                        bs_expressive_loss = self.expressive_loss_function(bs_pred_face, bs_facial)
                        bs_smooth_loss = 1 - self.smooth_loss_function(bs_pred_face[:,:-1,:], bs_pred_face[:,1:,:]).mean()
                        loss = self.target_weight * target_loss + self.smooth_weight * bs_smooth_loss + self.expressive_weight * bs_expressive_loss

                        if use_adv:
                            dis_out = self.d_model(bs_pred_face)
                            adv_loss = -torch.mean(torch.log(dis_out + 1e-8)) # self.adv_loss(out_d_fake, real_gt) # here 1 is real
                            loss += self.adv_weight * adv_loss
                        
                        loss.backward()
                        self.optimizer.step()

                        if it % self.log_period == 0:
                            train_metrics.update({
                                "target_loss": float(target_loss.item()),
                                "smooth_loss": float(bs_smooth_loss.item()),
                                "expressive_loss": float(bs_expressive_loss.item()),
                            })
                            if use_adv:
                                train_metrics["dis_loss"] = float(d_loss.item())
                                train_metrics["adversarial_loss"] = float(adv_loss.item())
                            else:
                                train_metrics["dis_loss"] = None
                                train_metrics["adversarial_loss"] = None
                            train_metrics.update(self.common_metrics())
                            self.logger.log(train_metrics, 'train')
                            print(f'[{self._ep_idx}][{it}/{len(self.train_loader)}]: [train] [target loss]: {train_metrics["target_loss"]} [adv loss]: {train_metrics["adversarial_loss"]} [smooth loss]: {train_metrics["smooth_loss"]} [exp loss]: {train_metrics["expressive_loss"]} [dis loss]: {train_metrics["dis_loss"]}')
                if it % self.val_period == 0:
                    val_metrics = self.val()
                    val_metrics.update(self.common_metrics())
                    self.logger.log(val_metrics,'val')
                    if self.predict_flame:
                        print(f'[{self._ep_idx}][{it}/{len(self.train_loader)}]: [val] [target loss]: {val_metrics["target_loss"]} [flame smooth loss]: {val_metrics["flame_smooth_loss"]}')
                    else:
                        if self.bs2vertices:
                            print(f'[{self._ep_idx}][{it}/{len(self.train_loader)}]: [val] [target loss]: {val_metrics["target_loss"]} [vertices smooth loss]: {val_metrics["V_smooth_loss"]}')
                        else:
                            print(f'[{self._ep_idx}][{it}/{len(self.train_loader)}]: [val] [target loss]: {val_metrics["target_loss"]} [bs smooth loss]: {val_metrics["smooth_loss"]} [bs expressive loss]: {val_metrics["expressive_loss"]}')

                self._iter += 1
            if (self._ep_idx+1) % self.save_period == 0:
                self.save_checkpoint(str(self._ep_idx+1))
        self.logger.finish()

In [None]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

set_seed(args.random_seed)
logger = Logger(args)
train_data = CustomDataset(args, "train")
val_data = CustomDataset(args, "val")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# This notebook is written for GPU training (several later cells call .cuda()
# directly), so fail early with an explicit message when no GPU is visible.
assert torch.cuda.is_available(), "A CUDA device is required to run this notebook"

model = GestureGen(args)
d_model = ConvDiscriminator(args)
if args.continue_training:
    # `model` is still on the CPU at this point (Trainer moves it to `device`),
    # so map the checkpoint tensors to the CPU as well; this also lets a
    # checkpoint saved on a different GPU index be loaded.
    model.load_state_dict(torch.load(args.pretrained_model, map_location="cpu"))

trainer = Trainer(args, device, train_data, val_data, model, d_model, logger)

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scripts.Dataset import a2bsDataset
from scripts.MulticontextNet import GestureGen
#import wandb
import uuid

In [None]:
# Stand-alone generator and optimizer used by the hand-written training loop
# further down (independent of the Trainer class).
# To resume instead of starting fresh, load a checkpoint such as
# 'ckpt_model/multicontextnet-all-vertex-2024/multicontextnet-300.pth' into
# `net` with net.load_state_dict(torch.load(model_path)).
# weight_decay=1e-5 was tried for the optimizer and is currently disabled.
net = GestureGen(args).cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Loss histories: the train_* lists grow per training iteration, the val_*
# lists per validation run. plot_train_val_loss() reads all eight.
train_target_loss, train_expressive_loss, train_smooth_loss, train_mse_loss = [], [], [], []
val_target_loss, val_expressive_loss, val_smooth_loss, val_mse_loss = [], [], [], []

def plot_train_val_loss():
    """Show two 1x4 figures with the recorded training and validation losses.

    Reads the module-level ``train_*`` / ``val_*`` history lists. Panels are,
    left to right: target, expressive, smooth and MSE loss.
    """
    # (panel title, matplotlib format string). Note that 'p' is matplotlib's
    # pentagon marker, not a colour, so that panel uses the default colour.
    panels = (
        ('Target Loss', 'r-'),
        ('Expressive Loss', 'p-'),
        ('Smooth Loss', 'g-'),
        ('MSE Loss', 'b-'),
    )
    figures = (
        ('Training Iterations',
         (train_target_loss, train_expressive_loss, train_smooth_loss, train_mse_loss)),
        ('Validation Iterations',
         (val_target_loss, val_expressive_loss, val_smooth_loss, val_mse_loss)),
    )
    for heading, histories in figures:
        fig, axs = plt.subplots(1, 4, figsize=(10, 3))
        for ax, (title, fmt), history in zip(axs, panels, histories):
            ax.plot(history, fmt)
            ax.set_title(title)
        fig.suptitle(heading, fontsize = 16)
        plt.tight_layout()
        plt.show()

In [None]:
# `train_loader` is not created at notebook scope by any earlier cell shown
# here (only Trainer holds one internally), so build it with the same settings
# Trainer uses for its own loader.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
)

print(len(train_data))
# Pull a single batch to inspect the fields the dataset provides.
data = next(iter(train_loader))
in_audio = data['audio']
facial = data['facial']
in_id = data["id"]
in_word = data["word"]
in_emo = data["emo"]

In [None]:
def expressive_loss_function(output, target):
    """Average, over all frames, of each frame's largest squared error.

    The element-wise squared difference is reduced with ``max`` over the last
    dimension; the remaining values are averaged into a scalar.

    Example:
        a = torch.tensor([[[1,2,3],[1,2,3]],[[4,5,6],[4,5,6]]]).float()
        b = torch.tensor([[[3,3,4],[2,3,4]],[[5,6,7],[5,6,7]]]).float()
        print(a.shape)
        expressive_loss_function(a, b)
    """
    squared_error = torch.pow(output - target, 2)
    worst_per_frame, _ = torch.max(squared_error, dim=-1)
    return worst_per_frame.mean()

In [None]:
from tqdm import tqdm
# Hand-written training loop for `net` (separate from the Trainer class).
#
# Relies on notebook-level names defined by other cells: `net`, `optimizer`,
# `train_loader`, `val_data`, `args`, the loss-history lists and
# `plot_train_val_loss`, plus `std_facial` / `mean_facial`.
# NOTE(review): in this file `std_facial` and `mean_facial` are assigned only
# in the later "Testing" cell, so this cell needs that cell to have been run
# first - confirm the intended execution order.

# Loop schedule.
num_epochs = 700
log_period = 200   # only referenced by the commented-out logging below
val_period = 600   # run validation every `val_period` training iterations
val_size = 25      # maximum number of validation batches per validation run
# NOTE(review): unlike the Testing cell, the matrix is not cast with .float()
# here - confirm the dtype stored in mat_final.npy.
bs_to_flame = torch.from_numpy(np.load('mat_final.npy'))
flame_to_bs = bs_to_flame.pinverse()
target_loss_function = torch.nn.HuberLoss()
smooth_loss_function = torch.nn.CosineSimilarity(dim=2)
mse_loss_function = torch.nn.MSELoss()
# Loss weights (the expressive term is currently commented out of `loss`).
target_weight = 1.5
expressive_weight = 0.5
smooth_weight = 0.5

for epoch in range(num_epochs):
    for it, data in enumerate(tqdm(train_loader)):
        net.train()
        in_audio = data['audio']
        # Undo the dataset normalisation of the facial tensor.
        facial = data['facial'] * std_facial + mean_facial
        in_id = data["id"]
        in_word = data["word"]
        in_emo = data["emo"]

        in_audio = in_audio.cuda()
        # Target: projection through bs_to_flame concatenated with columns
        # 6:14 of the facial tensor.
        in_facial = torch.cat((facial @ bs_to_flame, facial[:,:,6:14]), dim=-1)
        in_facial = in_facial.cuda()
        in_id = in_id.cuda()
        in_word = in_word.cuda()
        in_emo = in_emo.cuda()

        # Number of leading ground-truth frames given to the model as a seed.
        pre_frames = 4
        if args.predict_flame:
            in_pre_face = in_facial.new_zeros((in_facial.shape[0], in_facial.shape[1], in_facial.shape[2])).cuda()
            in_pre_face[:, 0:pre_frames] = in_facial[:, 0:pre_frames]
        else:
            # NOTE(review): `in_facial` is the flame-space tensor built above
            # even in this non-flame branch - confirm this is intended.
            in_pre_face = in_facial.new_zeros((in_facial.shape[0], in_facial.shape[1], in_facial.shape[2] + 1)).cuda()
            in_pre_face[:, 0:pre_frames, :-1] = in_facial[:, 0:pre_frames]
            in_pre_face[:, 0:pre_frames, -1] = 1 

        optimizer.zero_grad()
        out_face = net(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)
        target_loss = target_loss_function(out_face,in_facial)# + target_loss_function(out_face[:,:,6:14], in_facial[:,:,6:14]) # to account for eye movement
        #expressive_loss = expressive_loss_function(out_face, in_facial)
        # 1 - mean cosine similarity between consecutive output frames.
        smooth_loss = 1 - smooth_loss_function(out_face[:,:-1,:], out_face[:,1:,:]).mean()
        loss = target_weight * target_loss  + smooth_weight * smooth_loss# + expressive_weight * expressive_loss
        loss.backward()
        optimizer.step()

        train_target_loss.append(target_loss.item())
        #train_expressive_loss.append(expressive_loss.item())
        train_smooth_loss.append(smooth_loss.item())
        #train_mse_loss.append(mse_loss_function(out_face.cpu()*std_facial+mean_facial, facial.cpu()*std_facial+mean_facial).item())

        #logging
        #if it % log_period == 0:
        #    print(f'[{epoch}][{it}/{len(train_loader)}]: [train] [target loss]: {train_target_loss[-1]} [exp loss]: {train_expressive_loss[-1]} [smooth loss]: {train_smooth_loss[-1]} [mse]: {train_mse_loss[-1]}')

        # Periodic validation on up to `val_size` shuffled batches.
        # NOTE(review): the forward passes below run with autograd enabled
        # (no torch.no_grad()), and the loop reuses the names `data`,
        # `in_audio`, `facial`, ... of the training batch.
        if it % val_period == 0:
            net.eval()
            val_target_loss_st = []
            #val_expressive_loss_st = []
            val_smooth_loss_st = []
            val_mse_loss_st = []
            val_cnt = 0

            val_loader = torch.utils.data.DataLoader(
                val_data, 
                batch_size=args.batch_size,  
                shuffle=True,  
                drop_last=True,
            )

            for _, data in enumerate(val_loader):
                in_audio = data['audio']
                facial = data['facial'] * std_facial + mean_facial
                in_id = data["id"]
                in_word = data["word"]
                in_emo = data["emo"]

                in_audio = in_audio.cuda()
                in_facial = torch.cat((facial @ bs_to_flame, facial[:,:,6:14]), dim=-1)
                in_facial = in_facial.cuda()
                in_id = in_id.cuda()
                in_word = in_word.cuda()
                in_emo = in_emo.cuda()

                pre_frames = 4
                if args.predict_flame:
                    in_pre_face = in_facial.new_zeros((in_facial.shape[0], in_facial.shape[1], in_facial.shape[2])).cuda()
                    in_pre_face[:, 0:pre_frames] = in_facial[:, 0:pre_frames]
                else:
                    in_pre_face = in_facial.new_zeros((in_facial.shape[0], in_facial.shape[1], in_facial.shape[2] + 1)).cuda()
                    in_pre_face[:, 0:pre_frames, :-1] = in_facial[:, 0:pre_frames]
                    in_pre_face[:, 0:pre_frames, -1] = 1 

                out_face = net(in_pre_face,in_audio=in_audio,in_text=in_word, in_id=in_id, in_emo=in_emo)
                target_loss = target_loss_function(out_face, in_facial)# + target_loss_function(out_face[:,:,6:14], facial[:,:,6:14])
                #expressive_loss = expressive_loss_function(out_face, facial)
                smooth_loss = 1 - smooth_loss_function(out_face[:,:-1,:], out_face[:,1:,:]).mean()

                val_target_loss_st.append(target_loss.item())
                #val_expressive_loss_st.append(expressive_loss.item())
                val_smooth_loss_st.append(smooth_loss.item())
                #val_mse_loss_st.append(mse_loss_function(out_face.cpu()*std_facial+mean_facial, facial.cpu()*std_facial+mean_facial).item())


                val_cnt += 1
                if val_cnt >= val_size:
                    break

            # One averaged entry per validation run.
            val_target_loss.append(np.average(val_target_loss_st))
            #val_expressive_loss.append(np.average(val_expressive_loss_st))
            val_smooth_loss.append(np.average(val_smooth_loss_st))
            #val_mse_loss.append(np.average(val_mse_loss_st))
            #print(f'[{epoch}][{it}/{len(train_loader)}]: [val] [target loss]: {val_target_loss[-1]} [exp loss]: {val_expressive_loss[-1]} [smooth loss]: {val_smooth_loss[-1]} [mse]: {val_mse_loss[-1]}')
    # Redraw the loss curves after every epoch.
    plot_train_val_loss()
        

In [None]:
# torch.save does not create missing directories, so make sure the target
# folder exists before writing the weights of `net`.
os.makedirs('ckpt_model', exist_ok=True)
torch.save(net.state_dict(), 'ckpt_model/multicontextnet-flame-7.pth')

In [None]:
# Draw the training and validation loss curves recorded so far.
plot_train_val_loss()

### Testing

In [None]:
from pythonosc import udp_client
import time
import sounddevice as sd
import torch
from dataloaders.beat import CustomDataset
from dataloaders.build_vocab import Vocab
import pickle
import numpy as np

# Load the pickled experiment configuration. The context manager closes the
# file handle as soon as the object is read (it used to stay open).
# NOTE: pickle.load can execute arbitrary code stored in the file, so only
# load configuration files produced by this project.
with open("gesturegen_config.obj", 'rb') as config_file:
    args = pickle.load(config_file)
args.batch_size = 16

# Normalisation statistics for the facial and audio representations, and the
# blendshape-to-flame matrix with its pseudo-inverse, all as float32 tensors.
stats_root = args.root_path + args.mean_pose_path
mean_facial = torch.from_numpy(np.load(stats_root+f"{args.facial_rep}/json_mean.npy")).float()
std_facial = torch.from_numpy(np.load(stats_root+f"{args.facial_rep}/json_std.npy")).float()
mean_audio = torch.from_numpy(np.load(stats_root+f"{args.audio_rep}/npy_mean.npy")).float()
std_audio = torch.from_numpy(np.load(stats_root+f"{args.audio_rep}/npy_std.npy")).float()
bs_to_flame = torch.from_numpy(np.load('mat_final.npy')).float()
flame_to_bs = bs_to_flame.pinverse()

In [None]:
# Test split, served one shuffled sample per batch; no batch is dropped.
test_data = CustomDataset(args, "test")
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True, drop_last=False)

In [None]:
# Draw one batch from the shuffled test loader; re-run this cell for another.
data = next(iter(test_loader))

In [None]:
# Unpack the fields of the drawn test batch.
audio = data['audio']
facial = data['facial']
# NOTE(review): `id` shadows the Python builtin for the rest of the notebook;
# a later cell reads this name, so renaming it means updating that cell too.
id = data["id"]
word = data["word"]
emo = data["emo"]

In [None]:
# Undo the dataset normalisation (value * std + mean) for playback.
out_facial = facial * std_facial + mean_facial
out_audio = audio * std_audio + mean_audio

In [None]:
# Value range and spread of the de-normalised facial tensor
# (the audio is played back at 16 kHz by play_audio further down).
print(out_facial.min(), out_facial.max())
print(out_facial.std(), out_facial.mean())

In [None]:
# Display the shape of the de-normalised audio tensor.
out_audio.shape

In [None]:
# Blendshape names (52 entries). send_udp sends column j of every frame to the
# OSC address '/' + blend[j], so the order here must match the column order of
# the facial tensors passed to it.
blend =  [
        "browDownLeft",
        "browDownRight",
        "browInnerUp",
        "browOuterUpLeft",
        "browOuterUpRight",
        "cheekPuff",
        "cheekSquintLeft",
        "cheekSquintRight",
        "eyeBlinkLeft",
        "eyeBlinkRight",
        "eyeLookDownLeft",
        "eyeLookDownRight",
        "eyeLookInLeft",
        "eyeLookInRight",
        "eyeLookOutLeft",
        "eyeLookOutRight",
        "eyeLookUpLeft",
        "eyeLookUpRight",
        "eyeSquintLeft",
        "eyeSquintRight",
        "eyeWideLeft",
        "eyeWideRight",
        "jawForward",
        "jawLeft",
        "jawOpen",
        "jawRight",
        "mouthClose",
        "mouthDimpleLeft",
        "mouthDimpleRight",
        "mouthFrownLeft",
        "mouthFrownRight",
        "mouthFunnel",
        "mouthLeft",
        "mouthLowerDownLeft",
        "mouthLowerDownRight",
        "mouthPressLeft",
        "mouthPressRight",
        "mouthPucker",
        "mouthRight",
        "mouthRollLower",
        "mouthRollUpper",
        "mouthShrugLower",
        "mouthShrugUpper",
        "mouthSmileLeft",
        "mouthSmileRight",
        "mouthStretchLeft",
        "mouthStretchRight",
        "mouthUpperUpLeft",
        "mouthUpperUpRight",
        "noseSneerLeft",
        "noseSneerRight",
        "tongueOut"
    ]

In [None]:
def play_audio(out_audio, init_time):
    """Wait until ``init_time``, play ``out_audio`` at 16 kHz and block until done.

    Args:
        out_audio: Samples handed to ``sd.play``.
        init_time: ``time.time()`` timestamp at which playback should start.
            If it already lies in the past, playback starts immediately.
    """
    # time.sleep raises ValueError for a negative delay, which used to happen
    # whenever the requested start time had already passed.
    delay = init_time - time.time()
    if delay > 0:
        time.sleep(delay)
    sd.play(out_audio, 16000)
    sd.wait()
    print("Audio finished:", time.time())

In [None]:

def send_udp(out_face, init_time, fps=15, client=None):
    """Stream per-frame blendshape weights as OSC messages, paced to ``fps``.

    For every frame (row) of ``out_face``, one message per column is sent to
    the address ``'/' + blend[j]``. Negative weights are zeroed first. Frame
    ``i`` is followed by a sleep until ``init_time + (i + 1) / fps``; when the
    sender is already behind schedule the sleep is skipped.

    Args:
        out_face: 2-D array/tensor of weights supporting ``>=``, ``*`` and
            ``.tolist()``; column order must match ``blend``.
        init_time: ``time.time()`` timestamp at which streaming should start.
            If it already lies in the past, streaming starts immediately.
        fps: Frames sent per second (default 15, the previous fixed value).
        client: Object with ``send_message(address, value)``. Defaults to a
            ``udp_client.SimpleUDPClient`` for 127.0.0.1:5008.
    """
    # Zero out negative weights by masking.
    weights = out_face * (out_face >= 0)

    if client is None:
        client = udp_client.SimpleUDPClient('127.0.0.1', 5008)
    frames = weights.tolist()

    # time.sleep raises ValueError for a negative delay, which used to happen
    # whenever the requested start time had already passed.
    delay = init_time - time.time()
    if delay > 0:
        time.sleep(delay)

    for frame_idx, frame in enumerate(frames):
        for j, weight in enumerate(frame):
            client.send_message('/' + str(blend[j]), weight)

        # Pace against the absolute schedule so timing errors do not add up.
        elapsed_time = time.time() - init_time
        sleep_time = 1.0/fps * (frame_idx+1) - elapsed_time
        if sleep_time > 0:
            time.sleep(sleep_time)
    print("Facial finished:", time.time())

In [None]:
import threading

# Play the ground-truth face and its audio side by side. Both threads wait for
# a shared start time one second from now; the audio thread is scheduled
# 0.3 s ahead of the blendshape stream.
init_time = time.time() + 1
limit_sec = 20  # play at most this many seconds (15 face frames/s, 16000 samples/s)

udp_thread = threading.Thread(
    target=send_udp,
    args=(out_facial[0, 0:limit_sec*15], init_time),
    daemon=True,  # daemon thread: allowed to exit when the main program exits
)
audio_thread = threading.Thread(
    target=play_audio,
    args=(out_audio[0, 0:limit_sec*16000], init_time-0.3),
    daemon=True,
)

for worker in (udp_thread, audio_thread):
    worker.start()
for worker in (udp_thread, audio_thread):
    worker.join()

In [None]:
# Clip length in seconds from both streams, using the 16000 samples/s and
# 15 frames/s rates assumed by the playback cells.
print(len(out_audio[0])/16000, len(out_facial[0])/15)

In [None]:
 # load in model
from scripts.MulticontextNet import GestureGen
# Checkpoint to evaluate. Other checkpoints used with this cell:
#   'ckpt_model/multicontextnet-flame-7.pth'
#   'ckpt_model/multicontextnet-bs2vertices-2024/multicontextnet-500.pth'
#   'ckpt_model/multicontextnet100.pth'
model_path = 'tmp/multicontextnet-no-text.pth'

# Config override for this run (args.predict_flame = False is a further
# override that is currently disabled).
args.normalize_face = False

# Build the generator, restore its weights, then move it to the GPU and
# switch it to eval mode.
net = GestureGen(args)
net.load_state_dict(torch.load(model_path))
net.cuda()
net.eval()

In [None]:
# Move the conditioning inputs of the drawn test batch to the GPU.
in_audio, in_id, in_word, in_emo = (tensor.cuda() for tensor in (audio, id, word, emo))

# De-normalised blendshape target (stays on the CPU).
bs_facial = facial * std_facial.float() + mean_facial.float()

# Seed input: the first `pre_frames` frames carry ground truth and the extra
# last channel is set to 1 on exactly those frames; everything else is zero.
# (The flame-space variant would seed with
# torch.cat((bs_facial @ bs_to_flame, bs_facial[:,:,6:14]), dim=-1) and no
# extra channel.)
pre_frames = 4
batch_size, num_frames, num_channels = bs_facial.shape
in_pre_facial = bs_facial.new_zeros((batch_size, num_frames, num_channels + 1)).cuda()
in_pre_facial[:, 0:pre_frames, :-1] = bs_facial[:, 0:pre_frames]
in_pre_facial[:, 0:pre_frames, -1] = 1

In [None]:
# Run the generator on the seed frames and conditioning inputs, then bring the
# prediction back to the CPU, detached from the autograd graph.
pred_facial = net(in_pre_facial, in_audio=in_audio, in_text=in_word, in_id=in_id, in_emo=in_emo).cpu().detach()
# De-normalisation of the prediction is currently disabled:
#pred_facial = pred_facial * std_facial + mean_facial

In [None]:
# Compare value range and spread: ground truth first, then the prediction.
print(bs_facial.cpu().min(), bs_facial.cpu().max())
print(bs_facial.cpu().std(), bs_facial.cpu().mean())
print(pred_facial.min(), pred_facial.max())
print(pred_facial.std(), pred_facial.mean())

In [None]:
# pred_bs_facial = pred_facial[:,:,:103] @ flame_to_bs
# pred_bs_facial[:,:,6:14] = pred_facial[:,:,103:111]

In [None]:
import threading

# Play the predicted face together with the original audio. Both threads wait
# for a shared start time one second from now; the audio thread is scheduled
# 0.3 s ahead of the blendshape stream.
init_time = time.time() + 1
limit_sec = 70  # play at most this many seconds (15 face frames/s, 16000 samples/s)

udp_thread = threading.Thread(
    target=send_udp,
    args=(pred_facial[0,0:limit_sec*15], init_time),
    daemon=True,  # daemon thread: allowed to exit when the main program exits
)
audio_thread = threading.Thread(
    target=play_audio,
    args=(out_audio[0,0:limit_sec*16000], init_time-0.3),
    daemon=True,
)

for worker in (udp_thread, audio_thread):
    worker.start()
for worker in (udp_thread, audio_thread):
    worker.join()

In [None]:
# `pred_bs_facial`, `flame_facial` and `pred_flame_facial` were only defined in
# commented-out code, so this cell raised NameError. The inference cells above
# take the blendshape path (seed input with the extra flag channel), so the
# prediction is used as the blendshape prediction directly.
# NOTE(review): confirm the loaded checkpoint predicts blendshapes, not flame.
pred_bs_facial = pred_facial
# Flame-space versions, built the same way Trainer builds its flame target:
# projection through bs_to_flame concatenated with columns 6:14.
flame_facial = torch.cat((bs_facial @ bs_to_flame, bs_facial[:, :, 6:14]), dim=-1)
pred_flame_facial = torch.cat((pred_bs_facial @ bs_to_flame, pred_bs_facial[:, :, 6:14]), dim=-1)

print(expressive_loss_function(pred_bs_facial, bs_facial))
print(torch.nn.functional.mse_loss(pred_bs_facial, bs_facial))
print(torch.nn.functional.mse_loss(pred_flame_facial, flame_facial))

In [None]:
# Repeat of the metric report for the current `pred_facial` / `bs_facial`.
# The three names below were only defined in commented-out code, so they are
# derived here to keep the cell runnable on its own.
# NOTE(review): confirm the loaded checkpoint predicts blendshapes, not flame.
pred_bs_facial = pred_facial
# Flame-space versions, built the same way Trainer builds its flame target:
# projection through bs_to_flame concatenated with columns 6:14.
flame_facial = torch.cat((bs_facial @ bs_to_flame, bs_facial[:, :, 6:14]), dim=-1)
pred_flame_facial = torch.cat((pred_bs_facial @ bs_to_flame, pred_bs_facial[:, :, 6:14]), dim=-1)

print(expressive_loss_function(pred_bs_facial, bs_facial))
print(torch.nn.functional.mse_loss(pred_bs_facial, bs_facial))
print(torch.nn.functional.mse_loss(pred_flame_facial, flame_facial))