# face extractor

In [1]:
# !pip install facenet-pytorch #../input/package4/facenet_pytorch-2.0.1-py3-none-any.whl

import sys, os
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import cv2
import glob
import time
import torch
import random
import time
from PIL import Image
from facenet_pytorch import MTCNN, InceptionResnetV1, extract_face
from torchvision.transforms import Normalize, RandomHorizontalFlip, ToTensor, ToPILImage, Compose, Resize
from sklearn.metrics import log_loss
import pathlib
from tqdm import tqdm
import dlib

In [10]:
def isotropically_resize_image(img, size, resample=cv2.INTER_AREA):
    h, w = img.shape[:2]
    if w > h:
        h = h * size // w
        w = size
    else:
        w = w * size // h
        h = size

    resized = cv2.resize(img, (w, h), interpolation=resample)
    make_square_image(resized)
    return resized


def make_square_image(img):
    h, w = img.shape[:2]
    size = max(h, w)
    top = 0
    bottom = size - h
    left = 0
    right = size - w
    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

class Video_reader:
    def extract_video(self, video_path):
        cap = cv2.VideoCapture(video_path)
        frames = []
        while(cap.isOpened()):
            ret, frame = cap.read()
            if ret==True:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap.release()
        assert len(frames) != 0
        return np.array(frames)
        
        
    def extract_one_frame(self, video_path, frame_index):
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)  #设置要获取的帧号
        _, frame=cap.read()
        cap.release()
        if _:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            return frame
        else:
            return None


class Cache_loader:    
    def extract_video(self, video_path):
        filename = video_path.split('/')[-1].split('.')[0]
        cache_path = '/data1/data/deepfake/faces/'+filename
        if os.path.exists(cache_path):
            ret = {}
            for root,subdirs,files in os.walk(cache_path):
                for file in files:
                    face = cv2.cvtColor(np.load(os.path.join(root, file)), cv2.COLOR_BGR2RGB)
                    ret[int(file.split(".")[0])] = isotropically_resize_image(face, 224)
            return ret
        else:
            raise "cache not found"
            
    def get_faces(self, cache_path):
        faces = [cv2.cvtColor(np.load(fn), cv2.COLOR_BGR2RGB) for fn in cache_path]
        return [isotropically_resize_image(face, 224) for face in faces]
        
        

class Face_extractor:
    def __init__(self):
        pass
        
    def _get_boundingbox(self, bbox, width, height, scale=1.2, minsize=None):
        x1, y1, x2, y2 = bbox[:4]
        if not 0.33 < (x2-x1)/(y2-y1) < 3:
            return np.array([0,0,0,0])
        size_bb = int(max(x2 - x1, y2 - y1) * scale)
        if minsize:
            if size_bb < minsize:
                size_bb = minsize
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

        x1 = max(int(center_x - size_bb / 2), 0)
        y1 = max(int(center_y - size_bb / 2), 0)
        size_bb = min(width - x1, size_bb)
        size_bb = min(height - y1, size_bb)

        return np.array([x1,y1,x1+size_bb,y1+size_bb]).astype(int)
    
    
    def _rectang_crop(self, image, bbox):
        height, width = image.shape[:2]
        l,t,r,b = self._get_boundingbox(bbox, width, height) 
        return image[t:b, l:r]
    
    def _get(images):
        pass
    
    
    def get_faces(self, images, with_person_num = False, only_one = True):
        faces, nums = self._get(images)
        if only_one:
            faces = [face[0] for face in faces if len(face)>0]
            nums = [num for num, face in zip(nums,faces) if len(face)>0]
        if with_person_num:
            faces = (faces, nums)
        return faces
    
    
    def get_face(self, image, with_person_num = False, only_one = True):
        faces, nums = self.get_faces(np.array([image]), with_person_num=True, only_one=False)
        faces, nums = faces[0], nums[0]
        if only_one:
            if len(faces)>0:
                faces = faces[0]
            else:
                faces = None
        if with_person_num:
            faces = (faces, nums)
        return faces
    
    
class MTCNN_extractor(Face_extractor):
    def __init__(self, device = 'cuda:0' if torch.cuda.is_available() else 'cpu', down_sample = 2):
        self.extractor = MTCNN(keep_all=True, device=device, min_face_size=80//down_sample).eval()
        self.down_sample = down_sample
            
    def _get(self, images):
        h, w = images.shape[1:3]
        pils = [Image.fromarray(img).resize((w//self.down_sample, h//self.down_sample)) for img in images]
        bboxes, probs = self.extractor.detect(pils)
        facelist = [[self._rectang_crop(img, box) for box in boxes*self.down_sample] for boxes, img in zip(bboxes,images) if boxes is not None]
        person_nums = [np.sum(prob>0.9) for prob,fss in zip(probs, bboxes) if len(fss)>0]
        
        assert len(person_nums) == len(facelist)
        return facelist, person_nums


class Dlib_extractor(Face_extractor):
    def __init__(self, device = 'cuda:0' if torch.cuda.is_available() else 'cpu'):
        self.extractor = dlib.get_frontal_face_detector()
        
    def _get(self, images):
        rets = [self.dlib_get_one_face(image) for image in images]
        person_nums = [p for f,p in rets]
        faces = [f for f,p in rets]
        return faces, person_nums
    
    def dlib_get_one_face(self, image):
        height, width = image.shape[:2]
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        faces = self.extractor(gray, 0)
        bboxes = [[face.left(), face.top(), face.right(), face.bottom()] for face in faces]
        facelist = [self._rectang_crop(image, box) for box in bboxes]
        
        return facelist, len(facelist)
        
    
class Inference_model:
    def __init__(self):
        pass
        
    
    def data_transform(self):
        # transform to Tensor
        pre_trained_mean, pre_trained_std = [0.439, 0.328, 0.304], [0.232, 0.206, 0.201]
        return Compose([Resize(224), ToTensor(), Normalize(pre_trained_mean, pre_trained_std)])
    
    
    def TTA(self, pil_img):
        return [pil_img, RandomHorizontalFlip(p=1)(pil_img)]
    
    
    def predict(self, batch):
        print(batch[0])
        return 0.5
    
    def getx(sel, faoa):
        l = np.float64(0)
        r = np.float64(1)
        while r-l>5e-8:
            mid = (l+r)/2
            if mid**faoa > 1-mid:
                r = mid
            else:
                l=mid
        return (r+l)/2

    def give_predict(self, y):
        y = y.clip(5e-8, 1-(5e-8))
        faoa = np.sum(np.log(1-y))/np.sum(np.log(y))
        ret = self.getx(faoa)
        if ret > 0.7 and len(y[y<0.5]) > 0:
            return self.give_predict(y[y>0.5])
        return ret
    
    def test(self, shape = (1,3,224,224)):
        return self.predict(torch.rand(shape))
    

class Model1(Inference_model):
    def __init__(self, model_path = "/home/kailu/best_model.pth"):
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model = checkpoint['model']
        self.model.eval()
        
    
    def TTA(self, pil_img):
        return [pil_img]
    
    
    def predict(self, batch):
        with torch.no_grad():
            batch = batch.to(self.device)
            y_pred = self.model(batch)
            ret = self.give_predict(y_pred.cpu()[:,1].squeeze().numpy())
        return ret
                       
def show(images):
    %matplotlib inline
    import matplotlib.pyplot as plt
    rows = int(np.sqrt(len(images)))
    col = int(np.ceil(len(images)/ rows))
    fig, axes = plt.subplots(rows, col)
    ax = np.array(axes).reshape(-1)
    for i, img in enumerate(images):
        if img is not None:
            ax[i].imshow(img)
    plt.grid(False)
    plt.show()

In [3]:
def predict_on_all(video_paths, video_lables = [], video_reader=None, 
                   face_extractor=None, models=None, sample_number = 13, use_cache = False):
    if video_reader is None and not use_cache:
        video_reader = Video_reader()
    if face_extractor is None and not use_cache:
        face_extractor = MTCNN_extractor()
    if models is None:
        models = [Model1()]
    if use_cache:
        loader = Cache_loader()

    def predict_one_video(file_path):
        with torch.no_grad():
            try:
                if not use_cache:
                    frames = video_reader.extract_video(file_path)
                    sample = np.linspace(0, len(frames) - 1, sample_number*2).round().astype(int)
                    faces = face_extractor.get_faces(frames[sample])
                    np.random.shuffle(faces)
                else:
                    faces = list(loader.extract_video(file_path).values())
                assert len(faces) != 0
                pils = [Image.fromarray(face)  for face in faces[:sample_number] ]

                answers = []
                for model in models:
                    tr = model.data_transform()
                    batch = torch.stack([tr(p) for img in pils for p in model.TTA(img)])
                    answers.append(model.predict(batch))
                return np.mean(answers)
            except TypeError as e:
                print("Error with ", file_path, e)
                return 0.5

    predicts = [predict_one_video(i) for i in tqdm(video_paths)]

#     from concurrent.futures import ThreadPoolExecutor

#     with ThreadPoolExecutor(max_workers=5) as ex:
#         predicts = list(ex.map(predict_one_video, video_paths))

    if len(video_lables) == len(video_paths):
        print(f"loss = {log_loss(video_lables, predicts, labels=[0,1])}")
    return predicts

def gen_data(datadf):
    return  [base + fn for fn in datadf.index], [0 if la=='REAL' else 1 for la in datadf.label]

In [4]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
kaggle = False
speed_test = True
print(f"using {'cuda:0' if torch.cuda.is_available() else 'cpu'}")

if kaggle:
    filenames = glob.glob('/kaggle/input/deepfake-detection-challenge/test_videos/*.mp4')
    labels = []
else:
    base = '/data/deepfake/dfdc_train/'
    metadata = pd.read_json(base + 'metadata_kailu.json').T
    df = metadata[(metadata['split_kailu'] == 'test')]
    filenames, labels = gen_data(df)
      
models = [Model1()]
face_extractor = MTCNN_extractor()

using cuda:0


In [12]:
if speed_test:
    testnum = 1000
    start = time.time()
    if kaggle:
        test_filenames, test_labels = filenames[:testnum], []
    else:
        test_filenames, test_labels = gen_data(df.sample(testnum))
    ret = predict_on_all(test_filenames, video_lables=test_labels, models=models, face_extractor=face_extractor)
    time_dur = time.time()-start
    print(f"totally {time_dur} s used, {time_dur/testnum} s per video, mean = {np.mean(ret)}")


  0%|          | 0/1000 [00:00<?, ?it/s][A
  0%|          | 1/1000 [00:35<9:43:40, 35.06s/it][A
  0%|          | 2/1000 [01:32<11:37:14, 41.92s/it][A
  0%|          | 3/1000 [02:04<10:45:17, 38.83s/it][A
  0%|          | 4/1000 [02:36<10:08:16, 36.64s/it][A
  0%|          | 5/1000 [03:04<9:25:34, 34.11s/it] [A
  1%|          | 6/1000 [03:34<9:07:09, 33.03s/it][A
  1%|          | 7/1000 [05:23<15:19:39, 55.57s/it][A
  1%|          | 8/1000 [05:40<12:08:25, 44.06s/it][A
  1%|          | 9/1000 [05:58<10:01:47, 36.44s/it][A
  1%|          | 10/1000 [06:24<9:07:22, 33.17s/it][A
  1%|          | 11/1000 [06:48<8:22:16, 30.47s/it][A
  1%|          | 12/1000 [07:09<7:32:53, 27.50s/it][A
  1%|▏         | 13/1000 [07:30<7:01:42, 25.64s/it][A
  1%|▏         | 14/1000 [07:42<5:52:59, 21.48s/it][A
  2%|▏         | 15/1000 [08:03<5:53:21, 21.52s/it][A
  2%|▏         | 16/1000 [08:26<5:58:55, 21.89s/it][A
  2%|▏         | 17/1000 [08:46<5:50:27, 21.39s/it][A
  2%|▏         | 18/10

Error with  /data1/data/deepfake/dfdc_train/igrjzpduve.mp4 object of type 'NoneType' has no len()



  2%|▏         | 21/1000 [10:03<5:36:28, 20.62s/it][A
  2%|▏         | 22/1000 [10:32<6:15:16, 23.02s/it][A
  2%|▏         | 23/1000 [10:55<6:18:27, 23.24s/it][A
  2%|▏         | 24/1000 [11:30<7:16:33, 26.84s/it][A
  2%|▎         | 25/1000 [11:56<7:07:56, 26.33s/it][A
  3%|▎         | 26/1000 [12:19<6:51:19, 25.34s/it][A
  3%|▎         | 27/1000 [12:41<6:35:03, 24.36s/it][A
  3%|▎         | 28/1000 [13:12<7:10:32, 26.58s/it][A
  3%|▎         | 29/1000 [13:33<6:42:24, 24.87s/it][A
  3%|▎         | 30/1000 [14:08<7:27:25, 27.68s/it][A
  3%|▎         | 31/1000 [14:29<6:56:05, 25.76s/it][A

Error with  /data1/data/deepfake/dfdc_train/kykfgjhirn.mp4 object of type 'NoneType' has no len()



  3%|▎         | 32/1000 [14:50<6:35:00, 24.48s/it][A
  3%|▎         | 33/1000 [15:11<6:15:41, 23.31s/it][A
  3%|▎         | 34/1000 [15:32<6:02:44, 22.53s/it][A
  4%|▎         | 35/1000 [15:57<6:15:37, 23.35s/it][A
  4%|▎         | 36/1000 [16:17<5:59:13, 22.36s/it][A
  4%|▎         | 37/1000 [16:42<6:10:22, 23.08s/it][A
  4%|▍         | 38/1000 [17:03<6:01:48, 22.57s/it][A
  4%|▍         | 39/1000 [17:24<5:50:59, 21.91s/it][A
  4%|▍         | 40/1000 [17:49<6:05:47, 22.86s/it][A
  4%|▍         | 41/1000 [18:13<6:13:01, 23.34s/it][A
  4%|▍         | 42/1000 [18:40<6:28:30, 24.33s/it][A
  4%|▍         | 43/1000 [19:03<6:22:00, 23.95s/it][A
  4%|▍         | 44/1000 [19:25<6:12:16, 23.37s/it][A
  4%|▍         | 45/1000 [19:46<6:01:41, 22.72s/it][A
  5%|▍         | 46/1000 [20:08<5:56:39, 22.43s/it][A
  5%|▍         | 47/1000 [20:31<5:59:15, 22.62s/it][A
  5%|▍         | 48/1000 [20:51<5:48:56, 21.99s/it][A
  5%|▍         | 49/1000 [21:14<5:50:19, 22.10s/it][A
  5%|▌   

Error with  /data1/data/deepfake/dfdc_train/utdqvtzcwx.mp4 object of type 'NoneType' has no len()



  7%|▋         | 73/1000 [27:43<1:41:39,  6.58s/it][A
  7%|▋         | 74/1000 [27:48<1:36:53,  6.28s/it][A
  8%|▊         | 75/1000 [27:57<1:48:23,  7.03s/it][A
  8%|▊         | 76/1000 [28:05<1:53:20,  7.36s/it][A
  8%|▊         | 77/1000 [28:11<1:46:54,  6.95s/it][A
  8%|▊         | 78/1000 [28:23<2:09:45,  8.44s/it][A
  8%|▊         | 79/1000 [28:31<2:06:34,  8.25s/it][A

KeyboardInterrupt: 

In [7]:
predictions = predict_on_all(filenames, video_lables=labels, models=models, face_extractor=face_extractor)

submission_df = pd.DataFrame({"filename": [fn.split('/')[-1] for fn in filenames], "label": predictions})
submission_df.to_csv("submission.csv", index=False)

  0%|          | 5/12640 [07:34<339:27:37, 96.72s/it] 

Error with  /data1/data/deepfake/dfdc_train/aoydktojny.mp4 object of type 'NoneType' has no len()


  0%|          | 21/12640 [34:11<320:17:37, 91.37s/it] 

KeyboardInterrupt: 