# Library

All functions needed for this script.


In [1]:
from deepface import DeepFace
import pandas as pd
import os
from config import get_config,update_config
import numpy as np
import torch
import os
import cv2
import torchvision.transforms as transforms
try:
    from apex import amp
except ImportError:
    print("AMP is not available")
from datasets.predict_loader import test_loader
from datasets.extract_depth_gazefollow import extract_depth
from datasets.extract_depth_dino import extract_depth_dino
from models import get_model, load_pretrained
from utils import get_memory_format
import matplotlib.pyplot as plt
import os
import cv2
from tqdm import tqdm
import re
from PIL import Image


AMP is not available


In [2]:

def face_extract(img1_path, img2_path):
    """Verify two face images and return the face box found in the first one.

    Calls ``DeepFace.verify`` with ``enforce_detection=False`` and reads the
    ``facial_areas["img1"]`` entry of its result.

    Parameters:
    - img1_path: first image; its facial area is the one returned.
    - img2_path: second image to verify against.

    Returns:
    - list ``[x_min, y_min, x_max, y_max, distance]`` where the corners are
      derived from the reported ``x``/``y``/``w``/``h`` and ``distance`` is
      the ``"distance"`` value reported by ``DeepFace.verify``.
    """
    result = DeepFace.verify(img1_path=img1_path, img2_path=img2_path, enforce_detection=False)
    # Deliberately not called ``dict``: that name would shadow the builtin.
    area = result["facial_areas"]["img1"]
    left, top = area["x"], area["y"]
    return [left, top, left + area["w"], top + area["h"], result["distance"]]

def generate_facemask(image_list, start_box, dataset_dir, coefficient=1):
    """
    Track one head through a sequence of frames and write ``head_information.txt``.

    For every selected image, faces are detected with DeepFace (RetinaFace
    backend). Among the detections, the one whose top-left corner is closest
    to the currently tracked box is chosen, provided it moved less than
    ``coefficient`` box-widths horizontally and ``coefficient`` box-heights
    vertically. When no detection qualifies, the previous box is carried over.

    Parameters:
    - image_list: a list of substrings to select images by path, or "ALL" to
      indicate all images.
    - start_box: ``[x_min, y_min, x_max, y_max]`` of the head in the first frame.
    - dataset_dir: the path to the dataset directory; images are read from
      ``<dataset_dir>/image_original``.
    - coefficient: multiplier on the tracked box width/height that bounds how
      far the head may move between consecutive frames.

    Output:
    ``<dataset_dir>/head_information.txt`` (CSV, no header, no index) with one
    row per image: relative image path, x_min, y_min, x_max, y_max and a flag
    that is 0 when a face was matched in that frame and 0.99 when the previous
    box was reused.
    """

    # Step 1: Detect all images under the specified directory
    image_dir = os.path.join(dataset_dir, "image_original")
    all_images = [
        os.path.join(image_dir, img_name)
        for img_name in os.listdir(image_dir)
        if img_name.endswith(('jpg', 'png', 'jpeg'))
    ]

    # Step 2: Update image_list based on the input
    if image_list == "ALL":
        image_list = all_images
    else:
        image_list = [img for img in all_images if any(s in img for s in image_list)]

    # Step 3: Order frames numerically. Names that do not follow the
    # ``frame_<n>.<ext>`` convention used to crash the sort (``re.search``
    # returned None); they are now placed after the numbered frames, by name.
    frame_pattern = re.compile(r'frame_(\d+)\.(?:jpg|jpeg|png)$')

    def _frame_order(img_path):
        name = os.path.basename(img_path)
        match = frame_pattern.search(name)
        if match is None:
            return (1, 0, name)
        return (0, int(match.group(1)), name)

    image_list.sort(key=_frame_order)

    # Tracked box as top-left corner plus size
    x0, y0 = start_box[0], start_box[1]
    w0, h0 = start_box[2] - start_box[0], start_box[3] - start_box[1]

    # Step 4: Follow the head from frame to frame
    rows = []
    for img in tqdm(image_list):
        try:
            face_objs = DeepFace.extract_faces(img_path=img, target_size=(224, 224), detector_backend='retinaface')
        except ValueError:
            # DeepFace raises ValueError when no face is detected (detection is
            # enforced by default). Treat that as "no candidates" so the
            # previous box is carried over instead of aborting the whole run.
            face_objs = []

        best_box = None
        best_err = float("inf")
        for face_obj in face_objs:
            facebox = face_obj["facial_area"]
            x, y, w, h = facebox["x"], facebox["y"], facebox["w"], facebox["h"]
            err = (x - x0) ** 2 + (y - y0) ** 2
            within_reach = abs(x - x0) < w0 * coefficient and abs(y - y0) < h0 * coefficient
            if within_reach and err < best_err:
                best_box = (x, y, w, h)
                best_err = err

        if best_box is None:
            conf = 0.99  # nothing matched: reuse the previous box
        else:
            conf = 0  # matched in this frame: move the tracked box
            x0, y0, w0, h0 = best_box

        path = os.path.join("image_original", os.path.basename(img))
        rows.append([path, x0, y0, x0 + w0, y0 + h0, conf])

    # Convert list of rows to a DataFrame
    df = pd.DataFrame(rows)

    # Step 5: Save the dataframe to a file
    output_path = os.path.join(dataset_dir, "head_information.txt")
    df.to_csv(output_path, index=False, header=False)


def composition_image(image_path, image, gaze_heatmap_pred, depth, face, head_channel,demo_dir):
    """Write a 2x2 composite demo picture for one frame into ``demo_dir``.

    Every tile has the size of the image stored at ``image_path``:

    - top-left: that image with the predicted heatmap added to its red and
      green channels;
    - top-right: the face tensor after undoing the mean/std normalisation;
    - bottom-left: the depth tensor;
    - bottom-right: the predicted heatmap as a grey picture.

    ``image`` and ``head_channel`` are accepted but not used here.
    ``gaze_heatmap_pred`` must offer a PIL-style ``resize``. The file is saved
    as ``composite_<basename of image_path>``.
    """
    frame = Image.open(image_path).convert("RGB")
    w, h = frame.size
    canvas = np.zeros((2 * h, 2 * w, 3), dtype=np.uint8)

    # Constants used to undo the mean/std normalisation of the face tensor
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    # The heatmap at frame resolution is needed by two tiles; build it once
    heat = np.array(gaze_heatmap_pred.resize((w, h)))

    # LEFT UP: frame with the heatmap blended into R and G (blue untouched)
    overlay = np.stack([heat, heat, heat * 0], axis=2)
    canvas[:h, :w, :] = cv2.addWeighted(np.array(frame), 1, overlay, 2, 0)

    # LEFT DOWN: depth
    depth_tile = transforms.ToPILImage()(depth.cpu()).convert("RGB")
    canvas[h:, :w, :] = np.array(depth_tile.resize((w, h)))

    # RIGHT UP: face crop, de-normalised
    face_tile = face.cpu() * norm_std + norm_mean
    face_tile = face_tile[0][[0, 1, 2], :, :]
    face_tile = transforms.ToPILImage()(face_tile).convert("RGB")
    canvas[:h, w:, :] = np.array(face_tile.resize((w, h)))

    # RIGHT DOWN: heatmap replicated over the three colour channels
    for channel in range(3):
        canvas[h:, w:, channel] = heat

    # The canvas is RGB; OpenCV writes BGR
    target = os.path.join(demo_dir, f"composite_{os.path.basename(image_path)}")
    cv2.imwrite(target, cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))


def plot_image(image, gaze_heatmap_pred, depth, face, head_channel):
    """Display four 224x224 panels in one row for a single sample.

    Panels, left to right: the de-normalised image with the heatmap added to
    its red and green channels, the depth tensor, the de-normalised face
    tensor, and the head-channel tensor in grey. ``gaze_heatmap_pred`` must
    offer a PIL-style ``resize``; the other arguments are tensors.
    """
    panel_size = (224, 224)
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def _denormalised_rgb(tensor):
        # Undo the mean/std normalisation and return an RGB array at panel size
        restored = tensor.cpu() * norm_std + norm_mean
        restored = restored[0][[0, 1, 2], :, :]
        as_pil = transforms.ToPILImage()(restored).convert("RGB")
        return np.array(as_pil.resize(panel_size))

    def _tensor_panel(tensor, mode):
        # Convert a tensor straight to a PIL image of the given mode
        as_pil = transforms.ToPILImage()(tensor.cpu()).convert(mode)
        return np.array(as_pil.resize(panel_size))

    def _show(position, pixels, title, **imshow_kwargs):
        plt.subplot(1, 4, position)
        plt.imshow(pixels, **imshow_kwargs)
        plt.title(title)
        plt.axis("off")

    plt.figure(figsize=(16, 4))

    # Original Image with Heatmap
    base = _denormalised_rgb(image)
    heat = np.array(gaze_heatmap_pred.resize(panel_size))
    overlay = np.stack([heat, heat, 0 * heat], axis=2)
    _show(1, cv2.addWeighted(base, 1, overlay, 2, 0), "Image with Heatmap")

    # Depth Image
    _show(2, _tensor_panel(depth, "RGB"), "Depth")

    # Faces Image
    _show(3, _denormalised_rgb(face), "Faces")

    # Head Channels Image
    _show(4, _tensor_panel(head_channel, "L"), "Head Channels", cmap='gray')

    plt.show()


def predict_image_without_annotation(dataset_dir, image_list,startbox, device="cuda",plot=True, heatmap=True, Demo=True,coefficient=2, depth_mode="midas"):
    """Generate head boxes for the images, then run gaze prediction on them.

    Parameters:
    - dataset_dir: dataset directory passed to both steps.
    - image_list: list of image names, or "ALL".
    - startbox: ``[x_min, y_min, x_max, y_max]`` of the head in the first
      frame, used as the tracking seed by ``generate_facemask``.
    - device: device string forwarded to ``predict_image_with_annotation``.
    - plot, heatmap, Demo, depth_mode: forwarded unchanged to
      ``predict_image_with_annotation``.
    - coefficient: tracking tolerance forwarded to ``generate_facemask``.
    """
    # Generate the annotation file under dataset_dir
    print("======Generate Head Box======")
    generate_facemask(image_list, startbox, dataset_dir, coefficient=coefficient)
    # Forward the caller's device; it used to be pinned to "cuda", which
    # silently ignored the ``device`` argument.
    predict_image_with_annotation(dataset_dir, image_list, device=device, plot=plot, heatmap=heatmap, Demo=Demo, depth_mode=depth_mode)

def predict_image_with_annotation(dataset_dir, image_list, device="cuda",plot=True, heatmap=True, Demo=True, depth_mode="midas"):
    """Run gaze prediction on images that already have head annotations.

    Parameters:
    - dataset_dir: dataset directory. Per the original notes it should follow
      the GazeFollow layout: it contains ``head_information.txt`` and a folder
      ``image_original`` holding all images.
    - image_list: a list of image names, or "ALL" to process every image; it
      is handed to the data loader and to the depth extractor.
    - device: device string passed to ``update_config``; the device actually
      used is whatever ``config.device`` ends up being.
    - plot: if true, display each sample with ``plot_image``.
    - heatmap: if true, save each predicted heatmap into ``predict_heatmap``
      under ``config.dataset_dir``.
    - Demo: if true, save a 2x2 composite per image into ``predict_demo``
      under ``config.dataset_dir``.
    - depth_mode: "midas" runs ``extract_depth``; any other value runs
      ``extract_depth_dino``.
    """
    # Get config
    print("======Loading Config======")
    config = get_config()
    config = update_config(config, dataset_dir, device)
    device = torch.device(config.device)
    print(f"Running on {device}")

    # Make datasets
    print("======Loading dataset======")
    target_test_loader = test_loader(config, image_list)

    # Load model; the checkpoint may store its weights under either key
    print("======Loading model======")
    model = get_model(config, device=device)
    pretrained_dict = torch.load(config.eval_weights, map_location=device)
    pretrained_dict = pretrained_dict.get("model_state_dict") or pretrained_dict.get("model")
    model = load_pretrained(model, pretrained_dict)

    # Process depth files
    print("======Process Depth======")
    input_path = os.path.join(dataset_dir, "image_original")
    output_path = os.path.join(dataset_dir, "depth_intermediate")
    # Built with os.path.join so the path also resolves on non-Windows systems
    model_weights = os.path.join("weights", "dpt_large-midas-2f21e586.pt")

    if depth_mode == "midas":
        extract_depth(input_path, output_path, model_weights, "dpt_large", "no-optimize", image_list)
    else:
        # NOTE(review): the DINO extractor receives the MiDaS weight path and
        # model name as well -- confirm that this is what it expects.
        extract_depth_dino(input_path, output_path, model_weights, "dpt_large", "no-optimize", image_list)

    # Output directories are created once up front rather than on every batch
    heatmap_dir = os.path.join(config.dataset_dir, "predict_heatmap")
    demo_dir = os.path.join(config.dataset_dir, "predict_demo")
    os.makedirs(heatmap_dir, exist_ok=True)
    os.makedirs(demo_dir, exist_ok=True)

    # Prediction (dataloader, model, config --> saved image prediction)
    print("======Prediction======")
    model.eval()
    memory_format = get_memory_format(config)
    to_pil = transforms.ToPILImage()
    with torch.no_grad():
        for data in tqdm(target_test_loader):
            images, depths, faces, head_channels, img_size, path = data
            images = images.to(device, non_blocking=True, memory_format=memory_format)
            depths = depths.to(device, non_blocking=True, memory_format=memory_format)
            faces = faces.to(device, non_blocking=True, memory_format=memory_format)
            head = head_channels.to(device, non_blocking=True, memory_format=memory_format)

            # Only the heatmap output is used; the other model outputs are discarded
            gaze_heatmap_pred, _, _, _, _ = model(images, depths, head, faces)
            gaze_heatmap_pred = gaze_heatmap_pred.squeeze(1).cpu()

            # Save / show the result for every image of the batch
            for i, image_path in enumerate(path):
                heatmap_image = to_pil(gaze_heatmap_pred[i])

                if heatmap:
                    heatmap_image.save(os.path.join(heatmap_dir, os.path.basename(image_path)))

                if plot:
                    plot_image(images[i], heatmap_image, depths[i], faces[i], head_channels[i])

                if Demo:
                    composition_image(os.path.join(dataset_dir, image_path), images[i], heatmap_image, depths[i], faces[i], head_channels[i], demo_dir)
    print("ALl finished")


def predict_video(vid_dir, startbox, sampling_fps, device="cuda",annotated=False, depth_mode="midas"):
    """
    Predict gaze on frames sampled from a video and assemble a demo video.

    Frames are written to ``<video folder>/<video name>/image_original`` as
    ``frame_<n>.jpg``, prediction is run on all of them, and the resulting
    ``predict_demo/composite_frame_<n>.jpg`` pictures are joined into
    ``predict_demo/<video name>_output.avi``.

    Parameters:
    - vid_dir: The path to the video.
    - startbox: ``[x_min, y_min, x_max, y_max]`` of the head to follow in the
      first frame. Only used when ``annotated`` is false.
    - sampling_fps: FPS to sample frames from the video; also the FPS of the
      output video.
    - device: Device to use for prediction (default is "cuda").
    - annotated: if true, head boxes are not generated and
      ``predict_image_with_annotation`` is called directly; otherwise
      ``predict_image_without_annotation`` generates them from ``startbox``.
    - depth_mode: forwarded to the prediction function.

    Raises:
    - ValueError: if the video cannot be opened, or ``sampling_fps`` is not
      positive or exceeds the video's FPS.
    - FileNotFoundError: if no composite frame could be read back.
    """

    print("======Load Video======")
    if sampling_fps <= 0:
        raise ValueError("The sampling FPS must be positive.")

    # Use OpenCV to get video properties
    cap = cv2.VideoCapture(vid_dir)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {vid_dir}")
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = int(cap.get(cv2.CAP_PROP_FPS))
    finally:
        cap.release()

    print(f"Length: {video_length} frames, Frequency: {video_fps} FPS")

    if sampling_fps > video_fps:
        raise ValueError("The sampling FPS cannot be greater than the video's FPS.")

    # Keep every ``freq``-th frame; forced to an int >= 1 so it is always a
    # valid modulus even for a fractional sampling_fps
    freq = max(1, int(video_fps // sampling_fps))

    print("======Extract Image from Video======")

    # Create a directory to save extracted images
    basename_without_ext = os.path.splitext(os.path.basename(vid_dir))[0]
    save_dir = os.path.join(os.path.dirname(vid_dir), basename_without_ext)
    img_dir = os.path.join(save_dir, "image_original")
    os.makedirs(img_dir, exist_ok=True)

    # Extract frames from the video, remembering which ones were really
    # written: the frame count reported by the container can be larger than
    # the number of frames that can actually be read.
    saved_frames = []
    cap = cv2.VideoCapture(vid_dir)
    try:
        for frame_number in tqdm(range(video_length), desc="Extracting Frames"):
            ret, frame = cap.read()
            if not ret:
                break
            if frame_number % freq == 0:
                cv2.imwrite(os.path.join(img_dir, f"frame_{frame_number}.jpg"), frame)
                saved_frames.append(frame_number)
    finally:
        cap.release()

    # Run the prediction on every extracted frame
    if annotated:
        predict_image_with_annotation(save_dir, "ALL", plot=False, device=device, depth_mode=depth_mode)
    else:
        predict_image_without_annotation(save_dir, "ALL", startbox=startbox, plot=False, device=device, coefficient=2, depth_mode=depth_mode)

    print("======Import Image to Video======")
    demo_dir = os.path.join(save_dir, "predict_demo")
    output_video_path = os.path.join(demo_dir, f"{basename_without_ext}_output.avi")
    fourcc = cv2.VideoWriter_fourcc(*'XVID')  # Can be changed based on codecs available on the system

    # Read the composites back and write them to the video. The writer is
    # opened lazily with the size of the first composite that can be read.
    out = None
    try:
        for i in tqdm(saved_frames, desc="Importing Frames"):
            frame_path = os.path.join(demo_dir, f"composite_frame_{i}.jpg")
            img = cv2.imread(frame_path)
            if img is None:
                # cv2.imread returns None for a missing/unreadable file
                tqdm.write(f"Skipping unreadable composite: {frame_path}")
                continue
            if out is None:
                height, width = img.shape[:2]
                out = cv2.VideoWriter(output_video_path, fourcc, sampling_fps, (width, height))
            out.write(img)
    finally:
        if out is not None:
            out.release()

    if out is None:
        raise FileNotFoundError(f"No composite frames found in {demo_dir}")

    print(f"Processed video saved to {output_video_path}")


# Prediction
 

## Method1: With given teacher start position

**Given**: 

a. Video path

b. The tracking information needed for the teacher.
(e.g. open the video ---> pause at the first frame ---> note the rough bounding box of the teacher's head.)

**Wanted**:

Predicted Gaze Image.

In [3]:
# Input video; predict_video writes its outputs into a folder named after the
# video, next to the video file.
video_path=r"C:\Datasets\Engagement\VIDEO000\00000.avi"
# Rough head box of the teacher in the first frame: [x_min, y_min, x_max, y_max]
start_bbox_of_teacher_in_Video=[267,200,314,248]
# Number of frames per second to sample from the video
sampling_fps=2

# NOTE: with annotated=True, predict_video calls predict_image_with_annotation
# directly, so no head boxes are generated and startbox is not used.
# Pass annotated=False to track the head starting from the box above.
predict_video(video_path,startbox=start_bbox_of_teacher_in_Video,sampling_fps=sampling_fps,annotated=True,depth_mode="midas")
# Same run with the non-MiDaS depth extractor (disabled):
#predict_video(video_path,startbox=start_bbox_of_teacher_in_Video,sampling_fps=sampling_fps,annotated=True,depth_mode="dino")

Length: 751 frames, Frequency: 25 FPS


Extracting Frames: 100%|████████████████████████████████████████████████████████████▉| 750/751 [00:07<00:00, 99.24it/s]


Running on cuda
Total params: 92183098
Total trainable params: 92183098
<All keys matched successfully>
initialize
device: cuda
start processing

 Working on folder C:\Datasets\Engagement\VIDEO000\00000\image_original
63


63it [00:30,  2.09it/s]


finished


16it [00:29,  1.87s/it]


ALl finished


Importing Frames: 100%|████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 15.56it/s]

Processed video saved to C:\Datasets\Engagement\VIDEO000\00000\predict_demo/00000_output.avi





## Method2: With given teacher segmentation

**Given**: 

a. Video path

b. Mask of the teacher.

**Wanted**:

Predicted Gaze Image.

In [4]:
# Input video for the segmentation-based method
video_path=r"D:\Datasets\engagement_follow\Video\00000.avi"
# Rough head box of the teacher in the first frame: [x_min, y_min, x_max, y_max]
start_bbox_of_teacher_in_Video=[267,200,314,248]
# Number of frames per second to sample from the video
sampling_fps=2

# NOTE(review): predict_video_with_segmentation is not defined in the visible
# part of this notebook, so the call is left disabled.
#predict_video_with_segmentation(video_path,startbox=start_bbox_of_teacher_in_Video,sampling_fps=sampling_fps,annotated=False)
