In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from feat import Detector
import librosa
from scipy.signal import resample as sci_resample
from mtrf.model import TRF
import seaborn as sns

#### Helper functions

In [None]:
def get_rms_envelope(audio_file, reqd_sr):
    """Compute the RMS amplitude envelope of an audio file.

    The audio is loaded with ``librosa.load`` at librosa's default sampling
    rate. RMS is computed over 10 ms frames with a hop of one period of
    ``reqd_sr``, so the envelope comes out at approximately ``reqd_sr`` Hz.

    Parameters
    ----------
    audio_file : str or path-like
        Path to the audio file.
    reqd_sr : float
        Desired envelope sampling rate in Hz.

    Returns
    -------
    numpy.ndarray
        1-D RMS envelope, one value per analysis frame.
    """
    stim, sr = librosa.load(audio_file)

    # 10 ms analysis window, in audio samples.
    frame_length = int(sr * 0.01)
    # One period of the target rate, in audio samples. Rounding sr / reqd_sr
    # avoids truncating a float product such as sr * (1 / reqd_sr) that lands
    # just below an integer; max(1, ...) keeps the hop valid if reqd_sr > sr.
    hop_length = max(1, int(round(sr / reqd_sr)))
    # NOTE(review): at 22050 Hz and reqd_sr=30 the window (220 samples) is much
    # shorter than the hop (735 samples), so frames do not overlap and part of
    # the audio never enters the RMS -- confirm this is intended.

    rms = librosa.feature.rms(y=stim, frame_length=frame_length, hop_length=hop_length)
    # librosa returns shape (1, n_frames); keep the single row.
    return rms[0]


def resample_signal(signal, duration, reqd_sr, num_samples, pad_before):
    """Resample a signal to ``reqd_sr`` and place it in a fixed-length window.

    The signal is Fourier-resampled (``scipy.signal.resample``) so that it spans
    ``duration`` seconds at ``reqd_sr`` Hz. It is then written into an output of
    exactly ``num_samples`` samples, after ``pad_before`` leading zeros, with
    trailing zeros filling any remainder.

    Parameters
    ----------
    signal : array_like
        1-D signal assumed to cover ``duration`` seconds.
    duration : float
        Duration of ``signal`` in seconds.
    reqd_sr : float
        Target sampling rate in Hz.
    num_samples : int
        Length of the returned array.
    pad_before : int
        Number of zeros placed before the resampled signal.

    Returns
    -------
    numpy.ndarray
        Array of length ``num_samples``. If the resampled signal does not fit
        after the ``pad_before`` zeros, its tail is truncated. Before this
        change, ``np.pad`` raised on the negative trailing pad.
    """
    # Number of samples in [0, duration) at reqd_sr. This matches
    # len(np.arange(0, duration, 1/reqd_sr)) without a float-step arange, whose
    # length can be off by one. round() absorbs float noise in the product
    # before the ceiling.
    n_resampled = int(np.ceil(round(duration * reqd_sr, 6)))
    signal_resampled = sci_resample(signal, n_resampled)

    pad_before = int(pad_before)
    # A clip longer than the analysis window would need a negative trailing
    # pad. Keep only the part that fits.
    room = max(num_samples - pad_before, 0)
    signal_resampled = signal_resampled[:room]

    pad_after = num_samples - (len(signal_resampled) + pad_before)
    return np.pad(signal_resampled, pad_width=(pad_before, pad_after))


def analyse_videos( input_file, 
                    target_file, 
                    skip_frames=10, 
                    num_workers=16, 
                    pin_memory=False, 
                    n_jobs = 12,
                    face_model = "retinaface",
                    landmark_model = "mobilefacenet",
                    au_model = 'xgb',
                    emotion_model = "resmasknet",
                    facepose_model = "img2pose",
                    device = "cuda",
                    detector = None):
    """Run py-feat face analysis on a video and save the per-frame results as CSV.

    Parameters
    ----------
    input_file : str or path-like
        Video to analyse.
    target_file : str or path-like
        Output CSV path. Missing parent directories are created.
    skip_frames, num_workers, pin_memory, n_jobs
        Forwarded unchanged to ``Detector.detect_video``.
    face_model, landmark_model, au_model, emotion_model, facepose_model, device
        Used to build a new ``Detector`` when ``detector`` is None. They are
        ignored when a detector is supplied.
    detector : feat.Detector, optional
        A pre-built detector to reuse. Building a ``Detector`` loads all of its
        models, so pass one in when processing many videos in a loop.

    Returns
    -------
    The prediction object returned by ``detect_video``, which is also the
    object written to ``target_file``.
    """
    if detector is None:
        detector = Detector(
            face_model=face_model,
            landmark_model=landmark_model,
            au_model=au_model,
            emotion_model=emotion_model,
            facepose_model=facepose_model,
            device=device,
        )

    # batch_size is intentionally not forwarded (it was disabled in this call).
    video_prediction = detector.detect_video(
        input_file,
        skip_frames=skip_frames,
        num_workers=num_workers,
        pin_memory=pin_memory,
        n_jobs=n_jobs,
    )

    # Output folders such as ./data/aus_pure/<condition>/ may not exist yet,
    # and to_csv does not create them.
    target_dir = os.path.dirname(target_file)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    video_prediction.to_csv(target_file)

    return video_prediction


def get_aus(df, condition):
    """Extract facial features for every video listed in ``df``.

    Each row's ``VideoPath`` is passed to ``analyse_videos``. The output goes to
    ``./data/aus_pure/<condition>/<DisplayedDyad>_<stem>_aus.csv``, where
    ``<stem>`` is the file name with its last four characters (the extension,
    e.g. ``.mp4``) removed. A running index and the video path are printed
    before each video is processed.
    """
    rows = zip(df['VideoPath'], df['DisplayedDyad'])
    for idx, (video_path, dyad) in enumerate(rows):
        stem = video_path.split(os.sep)[-1][:-4]
        out_path = f'./data/aus_pure/{condition}/{dyad}_{stem}_aus.csv'
        print(f'{idx}. ', video_path)
        analyse_videos(video_path, out_path)

In [None]:
df = pd.read_csv('./stim/all_trials_dispDyad.csv')

# Audio-visual trials only.
df_va = df[df['Modality'] == 'va']

# If the Condition column holds only TRUE/FALSE strings, read_csv parses it as
# bool, and `== 'TRUE'` then silently matches nothing. Compare on upper-cased
# text so the string encoding and the parsed-bool encoding select the same rows.
# Anything that is not TRUE (including missing values) counts as fake, as before.
is_true = df_va['Condition'].astype(str).str.upper() == 'TRUE'
df_va_trues = df_va[is_true]
df_va_fakes = df_va[~is_true]

#### Format training data for TRF

In [None]:
# Sampling rate (Hz) used for both the audio RMS envelope and the head-pitch series.
sr = 30
min_time_lag = -1                        # in seconds; must be <= 0 (see pad_before)
max_time_lag = 26                        # in seconds
assert min_time_lag <= 0 < max_time_lag, "expected min_time_lag <= 0 < max_time_lag"

# Zeros prepended to each trial so that it starts at t = min_time_lag.
# Cast to int: np.pad and np.repeat need integer counts, and sr * lag is not an
# integer for fractional lags (e.g. -0.5 s).
pad_before = int(round(abs(sr * min_time_lag)))
# Total samples per trial: max_time_lag seconds at sr, plus the leading pad.
num_samples = int(round(sr * max_time_lag)) + pad_before

In [None]:
rms_all = []
nods_all = []

for i in range(len(df_va_trues)):
    trial = df_va_trues.iloc[i]
    filepath = trial['VideoPath']
    disp_dyad = trial['DisplayedDyad']
    duration = trial['Duration']
    aus_filepath = f'./data/aus_pure/true/{disp_dyad}_{filepath.split(os.sep)[-1][:-4]}_aus.csv'

    # Input: audio RMS envelope on the fixed-length padded grid.
    rms = get_rms_envelope(trial['AudioPath'], sr)
    rms = resample_signal(rms, duration, sr, num_samples, pad_before)

    # Output: head pitch from the AU file, on the same grid. A separate name
    # keeps the trial table `df` from being overwritten by this frame.
    # NOTE(review): if Pitch contains NaN (e.g. frames with no detected face),
    # the FFT resample will spread NaN over the whole trial -- check the data.
    aus_df = pd.read_csv(aus_filepath)
    nods = resample_signal(aus_df['Pitch'].to_numpy(), duration, sr, num_samples, pad_before)

    rms_all.append(rms)
    nods_all.append(nods)

# Stack to (n_trials, num_samples), then give mtrf one (num_samples, 1) array per trial.
nods_all = np.asarray(nods_all)
nods_all_reshaped = [nod.reshape(-1, 1) for nod in nods_all]

rms_all = np.asarray(rms_all)
rms_all_reshaped = [el.reshape(-1, 1) for el in rms_all]

#### Train TRF

In [None]:
# Map the audio RMS envelope (input) to head pitch (output) across lags
# [min_time_lag, max_time_lag] seconds. direction=1 is mtrf's forward direction.
# The ridge parameter is fixed here. To tune it, pass an array of candidates
# (e.g. np.logspace(-1, 6, 20)) together with k=-1 to trf.train, which then
# reports correlation and error per candidate (check against your mtrf version),
# and keep the lowest-error value.
ridge_lambda = 1

trf = TRF(direction=1)
trf.train(
    rms_all_reshaped,
    nods_all_reshaped,
    sr,
    tmin=min_time_lag,
    tmax=max_time_lag,
    regularization=ridge_lambda,
)

trf.plot(channel='avg', kind='line');

#### Prediction

In [None]:
dyad_col = []
stim_col = []
time_col = []
actual_col = []
pred_col = []
corr_col = []
error_col = []
skipped_aus = []   # AU files not found; those trials are left out of the results

# Time axis shared by every trial: num_samples points starting at min_time_lag,
# spaced 1/sr apart. Built from an integer index so its length always equals
# num_samples, matching the other per-trial columns.
time_axis = min_time_lag + np.arange(num_samples) / sr

for j in range(len(df_va_fakes)):
    trial = df_va_fakes.iloc[j]
    filepath = trial['VideoPath']
    disp_dyad = trial['DisplayedDyad']
    duration = trial['Duration']
    aus_filepath = f'./data/aus_pure/fake/{disp_dyad}_{filepath.split(os.sep)[-1][:-4]}_aus.csv'

    if not os.path.exists(aus_filepath):
        skipped_aus.append(aus_filepath)
        continue

    # Input: this trial's audio envelope, resampled using THIS trial's duration.
    # Previously `duration` was used before being set in the loop, so each
    # envelope was stretched to the previous trial's length.
    input_rms = get_rms_envelope(trial['AudioPath'], sr)
    input_rms = resample_signal(input_rms, duration, sr, num_samples, pad_before)
    input_rms_reshaped = input_rms.reshape(-1, 1)

    # Actual head pitch (separate name so the trial table `df` is not overwritten).
    aus_df = pd.read_csv(aus_filepath)
    actual_nod = resample_signal(aus_df['Pitch'].to_numpy(), duration, sr, num_samples, pad_before)
    nod_reshaped = actual_nod.reshape(-1, 1)

    # NOTE(review): [:-6] strips two more characters than the [:-4] used for
    # the AU file name -- presumably a shared suffix; confirm.
    stim = str(disp_dyad) + '_' + filepath.split(os.sep)[-1][:-6]

    dyad_col.append(np.repeat(disp_dyad, num_samples))
    stim_col.append(np.repeat(stim, num_samples))
    time_col.append(time_axis)
    actual_col.append(actual_nod)

    # Predicted head pitch, plus the fit metrics returned by trf.predict.
    [prediction, correlation, error] = trf.predict(input_rms_reshaped, nod_reshaped, average=True)
    pred_col.append(np.asanyarray(prediction).flatten())
    corr_col.append(np.repeat(correlation, num_samples))
    error_col.append(np.repeat(error, num_samples))

In [None]:
# Long-format results: one row per (trial, sample). Each *_col list holds one
# length-num_samples array per trial, so flattening aligns the columns row by row.
_trf_column_sources = (
    ('Dyad', dyad_col),
    ('Stim', stim_col),
    ('Time', time_col),
    ('Actual', actual_col),
    ('Predicted', pred_col),
    ('Correlation', corr_col),
    ('Error', error_col),
)
df_trf = pd.DataFrame(
    {name: np.asanyarray(values).flatten() for name, values in _trf_column_sources}
)
df_trf

In [None]:
# Mean actual vs TRF-predicted head pitch over all fake trials, with CI bands.
# Labels are needed: without them the two overlaid lines are indistinguishable.
fig, ax = plt.subplots()
sns.lineplot(data=df_trf, x='Time', y='Actual', errorbar='ci', label='Actual', ax=ax)
sns.lineplot(data=df_trf, x='Time', y='Predicted', errorbar='ci', label='Predicted', ax=ax)
ax.set(
    xlabel='Time (s)',
    ylabel='Head Pitch',
    title='Actual vs TRF-predicted head pitch (fake trials)',
)
plt.show()

In [None]:
# Per-dyad summary of prediction quality. Correlations are averaged in Fisher-z
# space (arctanh -> mean -> tanh), the same way dyad-level values are pooled
# afterwards; averaging raw r is biased toward zero. Each stimulus's r is
# repeated once per sample, so the row-wise mean weights stimuli equally.
df_trf_pred = (
    df_trf
    .assign(_fisher_z=np.arctanh(df_trf['Correlation']))
    .groupby('Dyad')
    .agg(Correlation=('_fisher_z', 'mean'), Error=('Error', 'mean'))
)
df_trf_pred['Correlation'] = np.tanh(df_trf_pred['Correlation'])
df_trf_pred

In [None]:
# Pool dyad-level correlations: Fisher-z transform, average, map back to r.
fisher_z = np.arctanh(df_trf_pred['Correlation'].to_numpy())
r = np.tanh(fisher_z.mean())
print(r)