<a href="https://colab.research.google.com/github/yonatantussa/gaze-estimation/blob/main/gaze_estimation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install retina-face

In [2]:
import torch
import cv2
import numpy as np
from retinaface import RetinaFace
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import os
from tqdm import tqdm

In [4]:
class VideoGazeAnalyzer:
  def __init__(self, use_cuda=True, model_name='gazelle_dinov2_vitl14_inout'):
    """Load a pretrained Gaze-LLE model from torch.hub and prepare it for inference.

    Args:
      use_cuda: request GPU inference. If CUDA is not available the analyzer
        falls back to the CPU instead of failing when the model is moved.
      model_name: torch.hub entry point in the ``fkryan/gazelle`` repo. The
        default is the ViT-L/14 model with an in/out-of-frame head (the
        previous hard-coded name, ``..._input``, is not a valid entry point).
    """
    # Use CUDA only if it was requested AND is present; otherwise
    # ``self.model.to("cuda")`` below would raise on CPU-only machines.
    self.use_cuda = bool(use_cuda) and torch.cuda.is_available()
    self.device = torch.device("cuda" if self.use_cuda else "cpu")
    print(self.device)

    # torch.hub.load for this repo returns (model, preprocessing transform).
    self.model, self.transform = torch.hub.load(
        'fkryan/gazelle', model_name, pretrained=True)
    self.model.to(self.device)
    self.model.eval()  # inference mode: disables dropout / batch-norm updates

    # Drawing colors; not used in this method (presumably read by the
    # visualization code — confirm).
    self.colors = ['yellow', 'cyan', 'lime', 'red']

  def process_frame(self, frame):
    """Detect faces in one video frame and predict each face's gaze target.

    Args:
      frame: one frame as a BGR numpy array (the OpenCV channel order; it is
        converted with ``cv2.COLOR_BGR2RGB`` below).

    Returns:
      A BGR numpy array. If RetinaFace finds no faces, ``frame`` is returned
      unchanged. Otherwise the frame annotated by ``self.visualize_all`` is
      converted back to BGR, so both paths return the same kind of image.
    """
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Pillow's ndarray constructor is ``fromarray`` (``from_array`` does not exist).
    frame_pil = Image.fromarray(frame_rgb)
    width, height = frame_pil.size

    # RetinaFace takes a file path or a numpy array (not a PIL image) and,
    # like cv2.imread output, expects BGR order -- so pass the original frame.
    results = RetinaFace.detect_faces(frame)
    # Some retina-face versions may return a non-dict (e.g. a tuple) when
    # nothing is detected; treat anything other than a non-empty dict as
    # "no faces".
    if not isinstance(results, dict) or not results:
      return frame

    bboxes = [face['facial_area'] for face in results.values()]
    # Pixel (x1, y1, x2, y2) boxes -> coordinates normalized to [0, 1].
    norm_bboxes = [
        [bbox[0] / width, bbox[1] / height, bbox[2] / width, bbox[3] / height]
        for bbox in bboxes
    ]

    # The model consumes a batch: an image tensor with a leading batch
    # dimension, and one list of boxes PER IMAGE in that batch (here, one image).
    img_tensor = self.transform(frame_pil).unsqueeze(0).to(self.device)
    input_data = {'images': img_tensor, 'bboxes': [norm_bboxes]}

    with torch.no_grad():
      output = self.model(input_data)

    # [0] selects the single image in the batch. 'inout' may be None for
    # models without an in/out-of-frame head.
    heatmaps = output['heatmap'][0]
    inout = output.get('inout')
    inout_scores = inout[0] if inout is not None else None

    annotated = self.visualize_all(frame_pil, heatmaps, norm_bboxes, inout_scores)
    # NOTE(review): assumes visualize_all returns a PIL image (it starts from
    # an RGBA copy of pil_image) -- confirm. Drop alpha and return BGR to
    # match the no-face path.
    annotated_rgb = np.asarray(annotated.convert('RGB'))
    return cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)

  def visualize_all(self, pil_image, heatmaps, bboxes, input_scores, inpput_thresh=0.5):
    """ Visualize all detected faces and their gaze directions"""

    overlay_image = pil_image.convert("RGBA")