In [13]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import imutils
from PIL import Image, ImageDraw, ImageFont
import pandas as pd
import function as fn

import torch
from torchvision import transforms

import mediapipe as mp
import numpy as np
import os
import time
from PIL import Image
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mp_face_mesh = mp.solutions.face_mesh


## FUNCTIONS

### PREPROCESS

In [14]:
def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def init(self):
            super(PreprocessInput, self).init()
        def forward(self, x):
            x = x.to(torch.float32)
            x = torch.flip(x, dims=(0,))
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img):
        ttransform = transforms.Compose([
            transforms.PILToTensor(),
            PreprocessInput()
        ])
        img = img.resize((224, 224), Image.Resampling.NEAREST)
        img = ttransform(img)
        img = torch.unsqueeze(img, 0).to(device)
        return img
    return get_img_torch(fp)

### ANNOTATIONS

In [15]:
def resize_and_center(image, target_shape):
    h, w = image.shape[:2]
    target_w, target_h = target_shape['shape']
    target_x, target_y = target_shape['position']

    if not target_shape['target'] == 'face':
        orientation = 'vertical' if h > w else 'horizontal'

        if orientation == 'horizontal':
            image = imutils.resize(image, width=target_w)
            y = target_y + (target_h - image.shape[0]) // 2
            x = target_x
        
        else:
            image = imutils.resize(image, height=target_h)
            x = target_x + (target_w - image.shape[1]) // 2
            y = target_y
        
        return image, (x, y)
    else:
        image = cv2.resize(image, (target_w, target_h))
        x = target_x
        y = target_y
        return image, (x, y)

In [16]:
def draw_text_on_screen(screen, text, position, font_size=20, color=(0, 0, 0)):
    screen_pil = Image.fromarray(cv2.cvtColor(screen, cv2.COLOR_BGR2RGB))
    font = ImageFont.truetype("src/Bebas.ttf", font_size)
    draw = ImageDraw.Draw(screen_pil)
    draw.text(position, text, font=font, fill=color)
    return cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)

In [17]:
def place_image_on_screen(screen, cv2_image, target_shape):
    resized_image, (x, y) = resize_and_center(cv2_image, target_shape)
    screen[y:y + resized_image.shape[0], x:x + resized_image.shape[1]] = resized_image
    return screen

In [18]:
screen_shape = (1280, 800)

right_offset = (0,70)
left_offset = (0,-20)

video_shape = {'shape': (537, 570), 'position': (85, 155 + left_offset[1]), 'target': 'video'}
graphic_shape = {'shape': (537, 398), 'position': (658, 290 + right_offset[1]), 'target': 'graphic'}
face_shape = {'shape': (200, 200), 'position': (728, 105 + right_offset[1]), 'target': 'face'}
predominant_shape = {'shape': (252, 298), 'position': (955, 85 + right_offset[1]), 'target': 'predominant'}

def update_screen_info(screen, predominant_emotion, emotion_duration, current_emotion, current_emotion_prob, backbone_model, LSTM_model, device, fps):
    
    text_predominant_emotion = f'Predominant Emotion: {predominant_emotion}'
    text_duration = f'Duration: {emotion_duration:.2f} seconds'
    text_current_emotion = f'Emotion: {current_emotion} ({current_emotion_prob*100:.2f}%)'
    
    screen = draw_text_on_screen(screen, text_current_emotion, (predominant_shape['position'][0] + 10, predominant_shape['position'][1] + 140), font_size=20, color=(0, 0, 0))
    screen = draw_text_on_screen(screen, text_duration, (predominant_shape['position'][0] + 10, predominant_shape['position'][1] + 170), font_size=20, color=(0, 0, 0))
    screen = draw_text_on_screen(screen, text_predominant_emotion, (predominant_shape['position'][0] + 10, predominant_shape['position'][1] + 200), font_size=20, color=(0, 0, 0))
    
    screen = draw_text_on_screen(screen, 'EMOTION CLASSIFICATION', (450, 30), 40, (0, 0, 0))
    screen = draw_text_on_screen(screen, f'Backbone: {backbone_model}', (video_shape['position'][0], video_shape['position'][1] - 60), font_size=20, color=(0, 0, 0))
    screen = draw_text_on_screen(screen, f'LSTM: {LSTM_model}', (video_shape['position'][0], video_shape['position'][1] - 35), font_size=20, color=(0, 0, 0))

    screen = draw_text_on_screen(screen, f'Using {device}', (video_shape['position'][0], video_shape['position'][1] - 85), font_size=20, color=(0, 0, 0))
    screen = draw_text_on_screen(screen, f'FPS: {fps:.1f}', (video_shape['position'][0], video_shape['position'][1] - 110), font_size=20, color=(0, 0, 0))
    
    return screen

In [19]:
def update_emotion_graph(ax, emotion_probs):
    ax.clear()
    for emotion, probs in emotion_probs.items():
        if probs:
            ax.plot(probs, label=emotion)
            ax.annotate(emotion, 
                        xy=(len(probs) - 1, probs[-1]), 
                        xytext=(5, 0), 
                        textcoords='offset points',
                        color=ax.get_lines()[-1].get_color(),
                        fontsize=8,
                        fontweight='regular')
    
    ax.set_xlabel('Frames')
    ax.set_ylabel('Probability')
    ax.set_title('Emotions Probabilities')
    ax.legend(fontsize='x-small', loc='upper left')
    ax.grid(True)
    
    plt.tight_layout()
    ax.set_ylim(-0.05, 1.05)
    ax.xaxis.set_major_locator(plt.MaxNLocator(5))

### PROCESS VIDEO

In [20]:
def process_video(video_path, output_path, backbone_model='0_66_49_wo_gl', LSTM_model='RAVDESS', device=device):

    pth_backbone_model = torch.jit.load(f'model/Torch/torchscript_model_{backbone_model}.pth').to(device)
    pth_backbone_model.eval()

    pth_LSTM_model = torch.jit.load(f'model/Torch/{LSTM_model}.pth').to(device)
    pth_LSTM_model.eval()

    DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}
    emotion_probs = {emotion: [] for emotion in DICT_EMO.values()}
    
    output_path_basename = os.path.basename(output_path)
    output_log_path = f'logs/{output_path_basename}.csv'

    data_list = []
    emotion_history = []

    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    combined_writer = cv2.VideoWriter(output_path, fourcc, fps, (screen_shape[0], screen_shape[1]))
    
    blank_screen = np.ones((screen_shape[1], screen_shape[0], 3), dtype=np.uint8) * 255
    
    lstm_features = []
    emotion_probs = {emotion: [] for emotion in DICT_EMO.values()}
    
    fig, ax = plt.subplots(1, 1, figsize=(5.37, 2.95))
    
    face_mesh = mp.solutions.face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5
    )
    
    current_emotion = None
    emotion_start_time = None
    
    try:
        for frame_count in tqdm(range(total_frames), desc="Processing frames"):
            t1 = time.time()
            success, frame = cap.read()
            if not success:
                break
            
            screen = blank_screen.copy()
            
            frame_copy = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_copy)
            
            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = fn.get_box(fl, w, h)
                    face_copy = frame_copy[startY:endY, startX:endX]
                    face_copy = cv2.cvtColor(face_copy, cv2.COLOR_RGB2BGR)
                    
                    cur_face = pth_processing(Image.fromarray(frame_copy[startY:endY, startX:endX]))
                    features = torch.nn.functional.relu(pth_backbone_model.extract_features(cur_face)).cpu().detach().numpy()
                    
                    if not lstm_features:
                        lstm_features = [features] * 10
                    else:
                        lstm_features.pop(0)
                        lstm_features.append(features)
                    
                    lstm_f = torch.unsqueeze(torch.from_numpy(np.vstack(lstm_features)), 0).to(device)
                    output = pth_LSTM_model(lstm_f).cpu().detach().numpy()
                    
                    for i, emotion in DICT_EMO.items():
                        emotion_probs[emotion].append(output[0, i])
                    
                    cl = np.argmax(output)
                    label = DICT_EMO[cl]
                    # frame = fn.annotate(frame, (startX, startY, endX, endY), label)
                    
                    face_region = face_copy
                    screen = place_image_on_screen(screen, face_region, face_shape)
                    
                    if label != current_emotion:
                        current_emotion = label
                        emotion_start_time = time.time()
                    
                    if emotion_start_time is not None:
                        emotion_duration = time.time() - emotion_start_time
                    else:
                        emotion_duration = 0
                    
                    emotion_history.append(label)
                    
                    predominant_emotion = max(set(emotion_history), key=emotion_history.count)
                    current_emotion_prob = output[0, cl]
                    
                    data_list.append({
                        'Frame': frame_count,
                        'Emotion': current_emotion,
                        'Time': emotion_duration,
                        'NE_PROB': output[0, 0],
                        'HA_PROB': output[0, 1],
                        'SA_PROB': output[0, 2],
                        'SU_PROB': output[0, 3],
                        'FE_PROB': output[0, 4],
                        'DI_PROB': output[0, 5],
                        'AN_PROB': output[0, 6]
                    })

            screen = place_image_on_screen(screen, frame, video_shape)
            
            update_emotion_graph(ax, emotion_probs)
            
            fig.canvas.draw()
            plot_frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            plot_frame = plot_frame.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            plot_frame = cv2.cvtColor(plot_frame, cv2.COLOR_RGB2BGR)
            
            screen = place_image_on_screen(screen, plot_frame, graphic_shape)
            
            fps = 1 / (time.time() - t1)
            screen = update_screen_info(screen, predominant_emotion, emotion_duration, current_emotion, current_emotion_prob, backbone_model, LSTM_model, device, fps)
            
            combined_writer.write(screen)
    
    finally:
        cap.release()
        combined_writer.release()
        face_mesh.close()
        plt.close(fig)

        df = pd.DataFrame(data_list)
        df.to_csv(output_log_path, index=False)
    
    return screen

## MAIN

In [21]:
models_available = {
    'backbone': ['0_66_37_wo_gl', '0_66_49_wo_gl'],
    'LSTM': ['RAVDESS', 'CREMA-D', 'Aff-Wild2', 'SAVEE', 'RAMAS', 'IEMOCAP']
}

backbone_model = '0_66_49_wo_gl'
LSTM_model = 'SAVEE'

os.makedirs('output', exist_ok=True)
os.makedirs('logs', exist_ok=True)

video_src = 'src/will.mp4'
video_name = os.path.basename(video_src).split('.')[0]
output_path = f'output/{backbone_model}_{LSTM_model}_{video_name}.mp4'
final_screen = process_video(
    video_path=video_src, 
    output_path=output_path, 
    backbone_model=backbone_model, 
    LSTM_model=LSTM_model, 
    device=device
)

plt.figure(figsize=(12, 7))
plt.imshow(cv2.cvtColor(final_screen, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()

print(f'Output video saved at {output_path}')

I0000 00:00:1723818468.651249 14616658 gl_context.cc:357] GL version: 2.1 (2.1 Metal - 88.1), renderer: Apple M1
Processing frames:   0%|          | 0/528 [00:00<?, ?it/s]W0000 00:00:1723818468.654101 14630069 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1723818468.657114 14630070 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


torch.Size([1, 3, 224, 224])


  plot_frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
Processing frames:   0%|          | 1/528 [00:00<05:24,  1.62it/s]

torch.Size([1, 3, 224, 224])


Processing frames:   1%|          | 3/528 [00:01<03:01,  2.89it/s]

torch.Size([1, 3, 224, 224])


Processing frames:   1%|          | 4/528 [00:01<02:23,  3.66it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   1%|          | 6/528 [00:01<01:57,  4.46it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   2%|▏         | 8/528 [00:02<01:39,  5.25it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   2%|▏         | 10/528 [00:02<01:34,  5.47it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   2%|▏         | 12/528 [00:02<01:29,  5.78it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   3%|▎         | 14/528 [00:03<01:28,  5.84it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   3%|▎         | 16/528 [00:03<01:25,  6.01it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   3%|▎         | 17/528 [00:03<01:27,  5.83it/s]

torch.Size([1, 3, 224, 224])


Processing frames:   4%|▎         | 19/528 [00:03<01:30,  5.64it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   4%|▍         | 21/528 [00:04<01:26,  5.86it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   4%|▍         | 23/528 [00:04<01:28,  5.73it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   5%|▍         | 25/528 [00:04<01:26,  5.83it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   5%|▌         | 27/528 [00:05<01:22,  6.04it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   5%|▌         | 29/528 [00:05<01:22,  6.09it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   6%|▌         | 31/528 [00:06<01:26,  5.77it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   6%|▋         | 33/528 [00:06<01:23,  5.90it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   7%|▋         | 35/528 [00:06<01:23,  5.90it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   7%|▋         | 37/528 [00:07<01:23,  5.86it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   7%|▋         | 39/528 [00:07<01:20,  6.06it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   8%|▊         | 41/528 [00:07<01:16,  6.38it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   8%|▊         | 43/528 [00:07<01:21,  5.96it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   9%|▊         | 45/528 [00:08<01:20,  5.98it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   9%|▉         | 47/528 [00:08<01:19,  6.02it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:   9%|▉         | 49/528 [00:08<01:16,  6.29it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:  10%|▉         | 51/528 [00:09<01:14,  6.44it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


Processing frames:  10%|▉         | 51/528 [00:09<01:27,  5.46it/s]


KeyboardInterrupt: 