# Here we use the [MELD dataset](https://affective-meld.github.io/) to extract videos and faces.

Download [this small_dataset json](https://raw.githubusercontent.com/cltl/ma-communicative-robots/master/multimodal/small_dataset.json) as well.

In [16]:
import json
import csv
import cv2
import numpy as np
from tqdm.notebook import tqdm
import shutil
import os
import random

# Hard-coded locations of the MELD videos, the MELD annotation CSV, the
# small_dataset json, and the directory the face crops are written to.
raw_train_videos_path = '/home/tk/datasets/MELD/MELD.Raw/train/train_splits'
train_sent_emo_path = '/home/tk/repos/MELD/data/MELD/train_sent_emo.csv'
small_dataset_path = '/home/tk/datasets/MELD/MELD.Raw/train/small_dataset.json'
save_dir = '/home/tk/repos/cltl-face-all/your-faces'

with open(small_dataset_path, 'r') as stream:
    small_dataset = json.load(stream)

small_dataset = sorted(small_dataset['train'])
# Set copy of the dialogue ids so the membership test below is O(1) per row.
small_dataset_ids = set(small_dataset)

# newline='' is what the csv module asks for so quoted fields containing
# line breaks are parsed correctly.
with open(train_sent_emo_path, newline='') as f:
    reader = csv.reader(f)
    data = list(reader)


# Collect (video file name, speaker name) for every row whose dialogue is in
# the small dataset. Columns 2, 5 and 6 are used as speaker name, dialogue id
# and utterance id. Rows whose "dia<id>" is unknown are simply skipped.
dia_utts = []
for row in data:
    speaker, dialogue_id, utterance_id = row[2], row[5], row[6]
    if f"dia{dialogue_id}" in small_dataset_ids:
        dia_utts.append((f"dia{dialogue_id}_utt{utterance_id}.mp4", speaker))

# Random sample of at most 50 utterances.
# NOTE(review): no seed is set, so the sample differs between runs.
random.shuffle(dia_utts)
dia_utts = dia_utts[:50]

names_mentioned = [name for _, name in dia_utts]
unique_names = sorted(set(names_mentioned))
important_names = []

# Every speaker is currently kept. A minimum-occurrence filter
# (e.g. more than 5 videos) used to live here and is disabled.
for un in unique_names:
    num_occurrences = names_mentioned.count(un)
    print(un, num_occurrences)
    important_names.append(un)
print()

print(f"The number of videos BEFORE removing not important name videos: {len(dia_utts)}")

# Drop the videos of speakers that are not in important_names.
dia_utts = [(du, name) for du, name in dia_utts if name in important_names]

dia_utts.sort(key=lambda x:x[1])

print(f"The number of videos AFTER removing not important name videos: {len(dia_utts)}")

# Start from an empty directory per speaker. makedirs (unlike mkdir) also
# creates save_dir itself when it does not exist yet; it still raises
# FileExistsError if the old directory could not be removed.
for name in important_names:
    person_dir = os.path.join(save_dir, name)
    shutil.rmtree(person_dir, ignore_errors=True)
    os.makedirs(person_dir)

Ben 1
Chandler 5
Eric 1
Janice 1
Joey 7
Man 1
Monica 14
Mr. Franklin 1
Mr. Treeger 1
Phoebe 5
Rachel 5
Rick 1
Ross 6
Terry 1

The number of videos BEFORE removing not important name videos: 50
The number of videos AFTER removing not important name videos: 50


In [17]:
import av
from cltl_face_all.agegender import AgeGender
from cltl_face_all.arcface import ArcFace
from cltl_face_all.arcface import calc_angle_distance
from cltl_face_all.face_alignment import FaceDetection
import uuid


ag = AgeGender(device='cpu')
af = ArcFace(device='cpu')
fd = FaceDetection(device='cpu', face_detector='blazeface')


# Number of frames sent through the models at once, and the minimum detection
# confidence a face needs in order to be saved.
batch_size = 64
face_threshold = 0.9


def _save_confident_faces(frames, name):
    """Detect, crop and embed the faces in `frames` and save the confident ones.

    For every detected face whose detection confidence exceeds
    `face_threshold`, the aligned crop is written as `<uuid>.jpg` and its
    embedding as `<uuid>.npy` (same uuid) under `save_dir/name`.

    Args:
        frames: non-empty list of RGB frames as numpy arrays of equal shape.
        name: speaker name; selects the sub-directory of `save_dir`.

    Raises:
        ValueError: if the number of bounding boxes, face crops and
            embeddings disagree, since they could not be paired reliably.
    """
    print(len(frames))

    batch = np.stack(frames, axis=0)
    print(batch.shape)
    bboxes = fd.detect_faces(batch)
    print(len(bboxes))
    landmarks = fd.detect_landmarks(batch, bboxes)
    faces = fd.crop_and_align(batch, bboxes, landmarks)
    print(len(faces))
    faces = np.concatenate(faces, axis=0)
    print(faces.shape)

    if len(faces) == 0:
        # No face in any frame of this batch: nothing to embed or save.
        return

    embeddings = af.predict(faces)
    print(embeddings.shape)

    # `bboxes` holds one entry per frame, while `faces` and `embeddings` are
    # flattened over all faces of all frames. Flatten the per-box confidence
    # (index 4 of a box) the same way so every face is paired with its own
    # score instead of with the score of some unrelated frame.
    # NOTE(review): assumes crop_and_align returns the crops of a frame in
    # the same order as that frame's bounding boxes -- confirm.
    probs = [bb[4] for frame_bboxes in bboxes for bb in frame_bboxes]

    if not (len(probs) == len(faces) == len(embeddings)):
        raise ValueError(
            f"Mismatch between bounding boxes ({len(probs)}), "
            f"faces ({len(faces)}) and embeddings ({len(embeddings)})"
        )

    for prob, fa, em in zip(probs, faces, embeddings):
        if prob > face_threshold:
            # One uuid for both files: the image and its embedding must share
            # a base name so they can be matched up again later.
            base_path = os.path.join(save_dir, name, str(uuid.uuid4()))
            cv2.imwrite(base_path + '.jpg', cv2.cvtColor(fa, cv2.COLOR_RGB2BGR))
            with open(base_path + '.npy', 'wb') as stream:
                np.save(stream, em)


for du, name in tqdm(dia_utts):
    full_video_path = os.path.join(raw_train_videos_path, du)

    batch = []
    # The context manager closes the container (and its file handle) even if
    # decoding or face processing fails.
    with av.open(full_video_path) as container:
        print(full_video_path)
        for frame in container.decode(video=0):
            batch.append(np.array(frame.to_image()))

            if len(batch) == batch_size:
                _save_confident_faces(batch, name)
                batch = []

    # Remaining frames that did not fill a complete batch.
    if batch:
        _save_confident_faces(batch, name)
        batch = []

    print()

[*] load ckpt from /home/tk/.virtualenvs/dev-python-3.7/lib/python3.7/site-packages/cltl_face_all/arcface/./pretrained_models/arc_res50/e_8_b_40000.ckpt


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))

/home/tk/datasets/MELD/MELD.Raw/train/train_splits/dia500_utt1.mp4
37
(37, 720, 1280, 3)
37
37
(59, 112, 112, 3)
(59, 512)

/home/tk/datasets/MELD/MELD.Raw/train/train_splits/dia664_utt8.mp4
64
(64, 720, 1280, 3)
64
64
(173, 112, 112, 3)
(173, 512)
11
(11, 720, 1280, 3)
11
11
(23, 112, 112, 3)
(23, 512)

/home/tk/datasets/MELD/MELD.Raw/train/train_splits/dia332_utt2.mp4
43
(43, 720, 1280, 3)
43
43
(102, 112, 112, 3)
(102, 512)

/home/tk/datasets/MELD/MELD.Raw/train/train_splits/dia274_utt0.mp4
64
(64, 720, 1280, 3)
64
64
(32, 112, 112, 3)
(32, 512)
64
(64, 720, 1280, 3)
64
64
(17, 112, 112, 3)
(17, 512)
32
(32, 720, 1280, 3)
32
32
(20, 112, 112, 3)
(20, 512)

/home/tk/datasets/MELD/MELD.Raw/train/train_splits/dia822_utt7.mp4
23
(23, 720, 1280, 3)
23
23
(48, 112, 112, 3)
(48, 512)

/home/tk/datasets/MELD/MELD.Raw/train/train_splits/dia616_utt2.mp4
64
(64, 720, 1280, 3)
64
64
(78, 112, 112, 3)
(78, 512)
19
(19, 720, 1280, 3)
19
19
(39, 112, 112, 3)
(39, 512)

/home/tk/datasets/MELD/MELD.

## Now go through the images and delete the wrong images.

## Let's clean the data

In [18]:
from glob import glob
import os

# After the wrong images were deleted by hand, remove every embedding whose
# image no longer exists, so only the kept faces still have an embedding.
important_names = ['Chandler', 'Joey', 'Monica', 'Phoebe', 'Rachel', 'Ross']
for name in important_names:
    npys = glob(os.path.join(save_dir, name, "*.npy"))

    for npy in npys:
        # Swap only the file extension; str.replace would also rewrite a
        # '.npy' that happens to occur elsewhere in the path.
        jpg = os.path.splitext(npy)[0] + '.jpg'
        if not os.path.isfile(jpg):
            os.remove(npy)

In [19]:
from cltl_face_all.arcface import calc_angle_distance

# Build one reference embedding per person: load all of the person's face
# embeddings, keep the half that is on average closest to the others, average
# those and normalise the result to unit length.
for name in important_names:
    print(name)
    npys = glob(os.path.join(save_dir, name, "*.npy"))

    if not npys:
        # np.stack raises on an empty list; there is nothing to average.
        print(f"No embeddings found for {name}, skipping.")
        print()
        continue

    embs = []
    for npy in npys:
        with open(npy, 'rb') as stream:
            embs.append(np.load(stream))
    embs = np.stack(embs)
    print(embs.shape)

    # Pairwise distances between all embeddings of this person.
    dists = calc_angle_distance(embs, embs)
    print(dists.round(2))
    print()

    # Keep the half with the smallest mean distance to the rest, but always
    # at least one: with a single embedding `1 // 2 == 0` would select
    # nothing and the mean below would be NaN.
    mean_dists = dists.mean(axis=1)
    num_kept = max(1, len(mean_dists) // 2)
    indexes = mean_dists.argsort()[:num_kept]

    print(indexes)
    embs = embs[indexes]
    print(len(embs))


    emb_mean = embs.mean(axis=0)
    emb_final = emb_mean / np.linalg.norm(emb_mean)
    print(emb_final.shape)

    # Saved next to (not inside) the person's directory, as `<name>.npy`.
    with open(os.path.join(save_dir, name) + '.npy', 'wb') as stream:
        np.save(stream, emb_final)


Joey
(11, 512)
[[0.   0.72 0.42 0.41 0.56 0.66 0.47 0.56 0.42 0.41 0.48]
 [0.72 0.   0.6  0.62 0.66 0.24 0.61 0.31 0.52 0.55 0.67]
 [0.42 0.6  0.   0.44 0.51 0.55 0.47 0.47 0.26 0.4  0.24]
 [0.41 0.62 0.44 0.   0.46 0.55 0.19 0.47 0.46 0.47 0.49]
 [0.56 0.66 0.51 0.46 0.   0.61 0.5  0.58 0.52 0.56 0.59]
 [0.66 0.24 0.55 0.55 0.61 0.   0.53 0.23 0.45 0.48 0.61]
 [0.47 0.61 0.47 0.19 0.5  0.53 0.   0.48 0.5  0.52 0.53]
 [0.56 0.31 0.47 0.47 0.58 0.23 0.48 0.   0.35 0.39 0.53]
 [0.42 0.52 0.26 0.46 0.52 0.45 0.5  0.35 0.   0.28 0.38]
 [0.41 0.55 0.4  0.47 0.56 0.48 0.52 0.39 0.28 0.   0.47]
 [0.48 0.67 0.24 0.49 0.59 0.61 0.53 0.53 0.38 0.47 0.  ]]

[8 2 7 9 3]
5
(512,)
