In [2]:
import argparse
import logging
import os

import cv2
import torch
import yt_dlp
from mivolo.data.data_reader import InputType, get_all_files, get_input_type
from mivolo.predictor import Predictor
from timm.utils import setup_default_logging
import csv
# Named logger used by this script for progress messages (e.g. where results are saved).
_logger = logging.getLogger("inference")


def get_direct_video_url(video_url):
    """Resolve a video page URL to a direct stream URL via yt-dlp.

    Only metadata is extracted (``download=False``); the ``bestvideo`` format
    selector decides which stream's URL is reported.

    Args:
        video_url: Page URL understood by yt-dlp (e.g. a YouTube link).

    Returns:
        ``(direct_url, (width, height), fps, video_id)`` when yt-dlp reports a
        ``url`` field, otherwise ``(None, None, None, None)``.
    """
    options = {
        "format": "bestvideo",
        "quiet": True,  # keep yt-dlp silent; drop this entry to see its log
    }

    with yt_dlp.YoutubeDL(options) as downloader:
        meta = downloader.extract_info(video_url, download=False)
        if "url" not in meta:
            return None, None, None, None
        # Fields are read in the same order as before: url, width, height, fps, id.
        return meta["url"], (meta["width"], meta["height"]), meta["fps"], meta["id"]


def get_local_video_info(vid_uri):
    """Read the frame size and frame rate of a video source with OpenCV.

    Args:
        vid_uri: Path or URI accepted by ``cv2.VideoCapture``.

    Returns:
        ``((width, height), fps)``: width/height as ints and fps as the value
        reported by ``CAP_PROP_FPS``.

    Raises:
        ValueError: If OpenCV cannot open the source.
    """
    cap = cv2.VideoCapture(vid_uri)
    try:
        if not cap.isOpened():
            raise ValueError(f"Failed to open video source {vid_uri}")
        res = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always release the capture: this function only probes metadata, and the
        # source is presumably re-opened later for inference; a leaked capture
        # holds its file/device handle until it is garbage-collected.
        cap.release()
    return res, fps


def get_parser():
    """Build the command-line parser for MiVOLO inference.

    Returns:
        An ``argparse.ArgumentParser`` with the input/output locations, model
        weight paths, boolean feature switches and the target device.
    """
    parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference")

    # Mandatory string options, registered in order: (flag, default, help).
    required_options = (
        ("--input", None, "image file or folder with images"),
        ("--output", None, "folder for output results"),
        ("--detector-weights", None, "Detector weights (YOLOv8)."),
        ("--checkpoint", "", "path to mivolo checkpoint"),
    )
    for flag, default_value, help_text in required_options:
        parser.add_argument(flag, type=str, default=default_value, required=True, help=help_text)

    # Boolean switches that are off unless passed explicitly.
    switches = (
        ("--with-persons", "If set model will run with persons, if available"),
        ("--disable-faces", "If set model will use only persons if available"),
        ("--draw", "If set, resulted images will be drawn"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", default=False, help=help_text)

    parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.")
    return parser

def _keyframe_name(img_p):
    """Return ``"<parent folder>\\<file name>"`` for an image path.

    Both ``\\`` and ``/`` are treated as separators so the result is the same
    for Windows- and POSIX-style paths; the output keeps the backslash-joined
    ``folder\\file`` format used in the CSV.
    """
    return "\\".join(img_p.replace("/", "\\").split("\\")[-2:])


def _summarize_people(face_genders, face_ages, person_genders, person_ages):
    """Count people and males using whichever detection set (faces/persons) is larger.

    Persons are used only when they strictly outnumber faces; ties go to faces.

    Returns:
        ``(person, male, ages)``: number of detections in the chosen set, how
        many of them have gender ``"male"``, and the chosen set's ages as a list.
    """
    if len(person_genders) > len(face_genders):
        genders, ages = person_genders, person_ages
    else:
        genders, ages = face_genders, face_ages
    male = sum(1 for gender in genders if gender == "male")
    return len(genders), male, list(ages)


def _process_video(args, predictor):
    """Annotate every frame of a local video / stream / YouTube URL and save it as an XVID AVI.

    Raises:
        ValueError: If ``--draw`` is not set or a YouTube URL cannot be resolved.
    """
    if not args.draw:
        raise ValueError("Video processing is only supported with --draw flag. No other way to visualize results.")

    if "youtube" in args.input:
        page_url = args.input
        direct_url, res, fps, yid = get_direct_video_url(page_url)
        if not direct_url:
            # Report the URL the user supplied, not the None returned on failure.
            raise ValueError(f"Failed to get direct video url {page_url}")
        # Kept from the original: args.input is replaced by the resolved stream URL.
        args.input = direct_url
        outfilename = os.path.join(args.output, f"out_{yid}.avi")
    else:
        bname = os.path.splitext(os.path.basename(args.input))[0]
        outfilename = os.path.join(args.output, f"out_{bname}.avi")
        res, fps = get_local_video_info(args.input)

    fourcc = cv2.VideoWriter_fourcc(*"XVID")
    out = cv2.VideoWriter(outfilename, fourcc, fps, res)
    _logger.info(f"Saving result to {outfilename}..")
    try:
        for _, frame in predictor.recognize_video(args.input):
            out.write(frame)
    finally:
        # Without release() the AVI container may be left unfinalised/incomplete.
        out.release()


def _process_images(args, predictor):
    """Run the predictor on one image or every image in a folder.

    Saves annotated copies to ``<output>/out_<name>.jpg`` when ``--draw`` is set.

    Returns:
        One dict per readable image with keys keyframe, person, male, female, age.
    """
    image_files = get_all_files(args.input) if os.path.isdir(args.input) else [args.input]
    rows = []
    for img_p in image_files:
        img = cv2.imread(img_p)
        if img is None:
            # cv2.imread signals unreadable / non-image files by returning None.
            _logger.warning(f"Skipping unreadable image {img_p}")
            continue

        detected_objects, out_im = predictor.recognize(img)
        face_inds = detected_objects.get_bboxes_inds("face")
        person_inds = detected_objects.get_bboxes_inds("person")
        person, male, ages = _summarize_people(
            [detected_objects.genders[ind] for ind in face_inds],
            [detected_objects.ages[ind] for ind in face_inds],
            [detected_objects.genders[ind] for ind in person_inds],
            [detected_objects.ages[ind] for ind in person_inds],
        )

        rows.append({
            "keyframe": _keyframe_name(img_p),
            "person": person,
            "male": male,
            # NOTE(review): any non-"male" gender (including a missing prediction,
            # if that can occur) is counted as female here.
            "female": person - male,
            "age": ages,
        })

        if args.draw:
            bname = os.path.splitext(os.path.basename(img_p))[0]
            filename = os.path.join(args.output, f"out_{bname}.jpg")
            cv2.imwrite(filename, out_im)
            _logger.info(f"Saved result to {filename}")
    return rows


def main(argv=None):
    """Run MiVOLO age/gender inference on a video or on images.

    Video input (local file, stream or YouTube URL) requires ``--draw``. Every
    annotated frame is written to ``<output>/out_<name>.avi``.

    Image input (a file or a folder) produces one CSV row per image
    (keyframe, person, male, female, age) in ``output.csv``. Annotated images
    are saved when ``--draw`` is set.

    Args:
        argv: Argument list to parse. ``None`` (the default) uses the
            hard-coded notebook configuration below, matching the previous
            behaviour of ``main()``. Pass e.g. ``sys.argv[1:]`` to use the
            real command line.

    Raises:
        ValueError: If a video is given without ``--draw`` or a YouTube URL
            cannot be resolved.
    """
    if argv is None:
        # Hard-coded because this cell runs inside a notebook kernel, whose
        # sys.argv is not the user's command line.
        argv = [
            "--input", "../../data",
            "--output", "output",
            "--detector-weights", "models/yolov8x_person_face.pt",
            "--checkpoint", "models/mivolo_imbd.pth.tar",
            "--device", "cuda:0",
            "--with-persons", "--draw",
        ]
    parser = get_parser()
    setup_default_logging()
    args = parser.parse_args(args=argv)

    if torch.cuda.is_available():
        # Speed-oriented CUDA settings: TF32 matmuls and cuDNN kernel autotuning.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    os.makedirs(args.output, exist_ok=True)

    predictor = Predictor(args, verbose=True)

    input_type = get_input_type(args.input)

    if input_type in (InputType.Video, InputType.VideoStream):
        _process_video(args, predictor)
    elif input_type == InputType.Image:
        rows = _process_images(args, predictor)
        # CSV is produced only for image input; previously video input crashed here
        # because the row list was never created on that path.
        # NOTE(review): written to the current working directory rather than
        # args.output — confirm this is intended.
        with open("output.csv", "w", newline="", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=["keyframe", "person", "male", "female", "age"])
            writer.writeheader()
            writer.writerows(rows)


# Script entry point. Note that main() parses a hard-coded argument list rather than sys.argv.
if __name__ == "__main__":
    main()


Model summary (fused): 268 layers, 68125494 parameters, 0 gradients, 257.4 GFLOPs


Model meta:
min_age: 1, max_age: 95, avg_age: 48.0, num_classes: 3, in_chans: 6, with_persons_model: True, disable_faces: False, use_persons: True, only_age: False, num_classes_gender: 2, input_size: 224, use_person_crops: True, use_face_crops: True
Model meta:
min_age: 1, max_age: 95, avg_age: 48.0, num_classes: 3, in_chans: 6, with_persons_model: True, disable_faces: False, use_persons: True, only_age: False, num_classes_gender: 2, input_size: 224, use_person_crops: True, use_face_crops: True
Loaded state_dict from checkpoint 'models/mivolo_imbd.pth.tar'
Loaded state_dict from checkpoint 'models/mivolo_imbd.pth.tar'
Model mivolo_d1_224 created, param count: 27432414
Model mivolo_d1_224 created, param count: 27432414
Data processing configuration for current model + dataset:
Data processing configuration for current model + dataset:
	input_size: (3, 224, 224)
	input_size: (3, 224, 224)
	interpolation: bicubic
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	mean: (0.485, 0.456, 0

ValueError: Unknown input data