In [1]:
from model import MODEL
from ultralytics import YOLO
import cv2
from video_utils import *
import torch
import math
import numpy as np
import time
from obj_detection_utils import *
from get_yt_vids import *
# Run inference on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

cuda:0


In [2]:
# Load the detector through the project's MODEL wrapper on the selected device.
# Its predictions are read below as dicts with 'boxes' and 'labels' entries.
model = MODEL("weights/detect_large.pt", device)
# Alternative (native ultralytics API, results expose `.boxes` instead):
#model = YOLO("weights/detect_large.pt")

In [8]:
# ---- Run configuration -------------------------------------------------
video_path = "full_videos/NM9_fvYsWME.mp4"
conf = 0.4            # confidence threshold; NOTE: not passed to model.predict() in this cell
skip_to_sec = 5       # start offset handed to initialize_video_capture
end_at_sec = 8        # stop once the computed video time reaches this value
batch_size = 32       # number of frames sent to model.predict() at once
show_progress = True
write_video= True


def _draw_score_overlay(frame, score):
    """Draw the running score in the top-left corner of `frame` (in place)."""
    cv2.putText(frame, f"Score: {score}", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 255, 255), 2)


cap, fps, frame_width, frame_height, total_frames = initialize_video_capture(video_path=video_path, skip_to_sec = skip_to_sec)
out, output_path = initialize_video_writer(fps = fps,
                                           video_dimension= (frame_width, frame_height),
                                           video_path=video_path,
                                           saved_video_name="output.mp4"
                                           )

num_batches = math.ceil(total_frames / batch_size)
reached_stopping_time = False
box_containing_ball_prev = None
# NOTE(review): this is only reassigned when the ball is found inside a focus
# box, so it keeps its previous value on frames where no box matches.
# get_scoring_timestamps resets it every frame - confirm which is intended.
box_containing_ball_cur = None
score = 0
no_relevant_ball = True
last_scored_time = -1

# Initialize the progress bar if needed
if show_progress:
    batch_range = tqdm(range(num_batches), desc='Processing Batches')
else:
    batch_range = range(num_batches)

timestamps = []

# try/finally guarantees the capture and the writer are released (and the
# output file finalized) even when the cell is interrupted or raises.
try:
    for i in batch_range:
        start_time = time.time()  # Start time for throughput calculation
        if reached_stopping_time:
            break

        # Read up to `batch_size` frames; stop early when the stream runs dry.
        frames = []
        for _ in range(batch_size):
            ret, img = cap.read()
            if not ret:
                break
            frames.append(img)

        if not frames:
            continue
        results = model.predict(frames)

        for idx, (frame, r) in enumerate(zip(frames, results)):
            current_frame_num = idx + i * batch_size
            # `fps` must stay the video's frame rate here: the progress-bar
            # throughput below is deliberately kept in a separate variable.
            current_time = skip_to_sec + current_frame_num / fps
            if current_time >= end_at_sec:
                reached_stopping_time = True
                break

            #### for custom YOLO wrapper ####
            bounding_boxes = r['boxes']
            cls = r['labels']
            #### for the native ultralytics model use instead: ####
            # boxes = r.boxes
            # bounding_boxes = boxes.xyxy.cpu().int().numpy()
            # cls = boxes.cls.int().cpu().numpy()

            # Group the detected boxes by class label.
            labels = [model.model.names[c] for c in cls]
            objects = {label: [] for label in labels}
            for box, label in zip(bounding_boxes, labels):
                objects[label].append(box)

            if "basketball" not in objects or "hoop" not in objects:
                if write_video:
                    _draw_score_overlay(frame, score)
                    out.write(frame)
                continue

            hoop_boxes = objects["hoop"]
            detection_areas = [get_detection_box(*box) for box in hoop_boxes]
            entry_boxes = [get_entry_box(*box) for box in hoop_boxes]
            exit_boxes = [get_exit_box(*box) for box in hoop_boxes]

            # Keep only balls that fall inside some hoop's detection area.
            relevant_ball_boxes = [box for box in objects["basketball"]
                                   for det_area in detection_areas
                                   if is_in_box(*box, *det_area)]
            no_relevant_ball = not relevant_ball_boxes
            if no_relevant_ball:
                if write_video:
                    _draw_score_overlay(frame, score)
                    out.write(frame)
                continue

            if write_video:
                for ball_box in objects["basketball"]:
                    cv2.circle(frame, get_center(*ball_box), 5, COLORS["basketball"], -1)

            focus_areas = {
                "hoop_box": hoop_boxes,
                "entry_box": entry_boxes,
                "exit_box": exit_boxes
            }

            # Determine which box the ball is in; the last matching entry wins.
            for box_name, all_boxes in focus_areas.items():
                for box in all_boxes:
                    ball_inside = any(is_in_box(*ball_box, *box, threshold=0.55)
                                      for ball_box in relevant_ball_boxes)
                    if ball_inside:
                        box_containing_ball_cur = box_name
                    if write_video:
                        # Highlight the occupied box in its own color, others in black.
                        color = COLORS[box_name] if ball_inside else (0, 0, 0)
                        cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), color, 2)
                        cv2.putText(frame, box_name, (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

            ball_in_interested_area = box_containing_ball_cur in ("hoop_box", "exit_box")

            # A basket is an entry_box -> hoop_box/exit_box transition, with a
            # one-second cooldown so the same shot is not counted twice.
            time_since_last_scored = current_time - last_scored_time
            if box_containing_ball_prev == "entry_box" and ball_in_interested_area and time_since_last_scored > 1:
                score += 1
                last_scored_time = current_time
                timestamps.append(current_time)

            box_containing_ball_prev = box_containing_ball_cur
            if write_video:
                # The score is drawn once, after it may have been updated, so
                # the old and new values never overlap on the scoring frame.
                _draw_score_overlay(frame, score)
                cv2.putText(frame, f"ball in: {box_containing_ball_cur}", (20, 150), cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 255, 255), 2)
                out.write(frame)

        if show_progress:
            elapsed_time = time.time() - start_time  # Elapsed time for batch
            # Kept separate from `fps` (the video frame rate used for timing).
            processing_fps = len(frames) / elapsed_time
            batch_range.set_postfix(fps=f"{processing_fps:.2f} fps", refresh=True)
finally:
    cap.release()
    out.release()

display_video(output_path, width = 1000)

Processing Batches:   1%|▏         | 6/424 [00:03<03:47,  1.84it/s, fps=71.79 fps]


In [7]:
# Inspect the class-index -> label mapping used to name detections.
model.model.names

{0: 'basketball', 1: 'hoop', 2: 'person'}

In [25]:
def get_scoring_timestamps(video_path = None,
                           url = None, 
                           model = model,
                           skip_to_sec = 0,
                           batch_size = 16,
                           show_progress = True,
                           conf = 0.4,
                           min_seconds_between_scores = 2,
                            ):
    """Scan a video and return the times (in seconds) at which a basket is detected.

    A basket is registered when the ball was last seen in a hoop's entry box and
    is now inside the hoop box or the exit box, provided more than
    `min_seconds_between_scores` seconds of video have passed since the
    previously registered basket.

    Args:
        video_path: Path of a local video. Takes precedence over `url`.
        url: Video URL; downloaded into "testing_videos" via `download_video`
            when `video_path` is not given.
        model: Detector called as `model(frames, stream=..., verbose=...,
            conf=..., device=...)`. Each result must expose `.boxes.xyxy` and
            `.boxes.cls`, and the model must expose `.names` (this is the
            ultralytics-style interface, not the dict results of MODEL.predict).
        skip_to_sec: Start offset passed to `initialize_video_capture`.
        batch_size: Number of frames sent to the model per call.
        show_progress: Wrap the batch loop in a tqdm progress bar.
        conf: Detection confidence threshold forwarded to the model.
        min_seconds_between_scores: Cooldown, in seconds of video time.

    Returns:
        Tuple `(video_path, timestamps)`; `timestamps` is a list of floats
        computed as `skip_to_sec + frame_index / fps`.

    Raises:
        ValueError: If neither `video_path` nor `url` is provided.

    Note:
        The inference device is read from the notebook-level `device` variable.
    """
    if not video_path and not url:
        raise ValueError("Either video_path or url must be provided")
    elif url and not video_path:
        video_path = download_video(url, "testing_videos")
    cap, fps, frame_width, frame_height, total_frames = initialize_video_capture(video_path=video_path, skip_to_sec = skip_to_sec)

    num_batches = math.ceil(total_frames / batch_size)
    batch_range = tqdm(range(num_batches)) if show_progress else range(num_batches)

    box_containing_ball_prev = None
    last_scored_time = -np.inf   # no basket yet, so the first one is never blocked
    frames_read = 0              # frames consumed from `cap` so far
    timestamps = []

    try:
        for _batch in batch_range:
            frames = []
            for _ in range(batch_size):
                ret, img = cap.read()
                if not ret:
                    break
                frames.append(img)

            if not frames:
                continue

            results = model(frames,
                            stream=False,
                            verbose=False,
                            conf=conf,
                            device=device)
            batch_start = frames_read
            frames_read += len(frames)

            for offset, r in enumerate(results):
                # Time of THIS frame, not of the capture position after the
                # whole batch was read. Assumes the capture starts at
                # `skip_to_sec`, the same convention the interactive scoring
                # cell uses - TODO confirm against initialize_video_capture.
                current_time = skip_to_sec + (batch_start + offset) / fps

                boxes = r.boxes
                bounding_boxes = boxes.xyxy.cpu().numpy().astype(int)
                # Class ids come back as floats; cast before the name lookup.
                labels = [model.names[int(c)] for c in boxes.cls.cpu().numpy()]

                # Group the detected boxes by class label.
                objects = {label: [] for label in labels}
                for box, label in zip(bounding_boxes, labels):
                    objects[label].append(box)

                if "basketball" not in objects or "hoop" not in objects:
                    continue

                hoop_box = objects["hoop"]
                detection_area = [get_detection_box(*box) for box in hoop_box]
                entry_box = [get_entry_box(*box) for box in hoop_box]
                exit_box = [get_exit_box(*box) for box in hoop_box]
                ball_center = [get_center(*box) for box in objects["basketball"]]

                # Keep only balls whose center lies in some hoop's detection area.
                relevant_ball_centers = [center for center in ball_center
                                         for det_area in detection_area
                                         if is_in_box(*center, *det_area)]
                if not relevant_ball_centers:
                    continue

                focus_areas = {
                    "detection_area": detection_area,
                    "hoop_box": hoop_box,
                    "entry_box": entry_box,
                    "exit_box": exit_box
                }

                # Determine which box the ball is in; the last matching entry wins.
                box_containing_ball_cur = None
                for box_name, all_boxes in focus_areas.items():
                    for box in all_boxes:
                        if any(is_in_box(*center, *box) for center in relevant_ball_centers):
                            box_containing_ball_cur = box_name

                ball_in_interested_area = box_containing_ball_cur in ("hoop_box", "exit_box")
                # The cooldown is measured in video time, so frames skipped by
                # the `continue` statements above still count towards it.
                time_since_last_score = current_time - last_scored_time
                if (box_containing_ball_prev == "entry_box"
                        and ball_in_interested_area
                        and time_since_last_score > min_seconds_between_scores):
                    last_scored_time = current_time
                    timestamps.append(current_time)

                box_containing_ball_prev = box_containing_ball_cur
    finally:
        # Release the capture even when the scan is interrupted.
        cap.release()

    return video_path, timestamps

def trim_highlights_from_timestamps(video_path,
                      score_timestamps, 
                      clip_start_offset = 6, # number of seconds before scoring
                      clip_end_offset = 2,   # number of seconds after scoring
                      output_path = ".",
                      ffmpeg_path = "ffmpeg-git-20240203-amd64-static/ffmpeg"):
    """Cut one clip per scoring timestamp out of `video_path` with ffmpeg.

    Clip `i` spans `[timestamp - clip_start_offset, timestamp + clip_end_offset]`
    (the start is clamped at 0) and is written, using stream copy, to
    `<output_path>/<video name>_highlight_<i>.mp4`. Existing clips with the
    same name are overwritten.

    Args:
        video_path: Source video file.
        score_timestamps: Iterable of timestamps in seconds.
        clip_start_offset: Seconds kept before each timestamp.
        clip_end_offset: Seconds kept after each timestamp.
        output_path: Directory for the clips; created when missing.
        ffmpeg_path: Path of the ffmpeg executable.

    A failing ffmpeg invocation is reported with a printed warning and does not
    stop the remaining clips from being produced.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    video_name = os.path.splitext(os.path.basename(video_path))[0]

    for i, timestamp in enumerate(score_timestamps):
        start_time = max(0, timestamp - clip_start_offset)
        end_time = timestamp + clip_end_offset
        clip_output_path = os.path.join(output_path, f"{video_name}_highlight_{i}.mp4")

        # '-y' overwrites an existing clip. Without it ffmpeg asks for
        # confirmation and, with its output discarded, a re-run would silently
        # keep the stale file.
        ffmpeg_command = [ffmpeg_path, '-y', '-i', video_path,
                          '-ss', str(start_time), '-to', str(end_time),
                          '-c', 'copy', clip_output_path]
        result = subprocess.run(ffmpeg_command,
                                stdin=subprocess.DEVNULL,
                                stdout=subprocess.DEVNULL,
                                stderr=subprocess.DEVNULL)
        if result.returncode != 0:
            print(f"ffmpeg failed (exit code {result.returncode}) for {clip_output_path}")
        

    
    
def generate_highlights(video_path = None,
                           url = None, 
                           model = model,
                           skip_to_sec = 0,
                           batch_size = 16,
                           show_progress = True,
                           highlight_output_path = "highlights"
                           ):
    """Detect scoring moments in a video and cut one highlight clip per score.

    All arguments except `highlight_output_path` are forwarded unchanged to
    `get_scoring_timestamps`; the resulting timestamps are then handed to
    `trim_highlights_from_timestamps`, which writes the clips into
    `highlight_output_path`. Progress messages are printed; nothing is returned.
    """
    # Step 1: locate the scoring moments (this also resolves/downloads the video).
    detection_kwargs = dict(
        video_path=video_path,
        url=url,
        model=model,
        skip_to_sec=skip_to_sec,
        batch_size=batch_size,
        show_progress=show_progress,
    )
    source_video, score_times = get_scoring_timestamps(**detection_kwargs)
    print(f"found {len(score_times)} highlights")

    # Step 2: cut a clip around every scoring moment.
    print("starting to trim highlights")
    trim_highlights_from_timestamps(source_video, score_times, output_path=highlight_output_path)
    print("finished trimming highlights")

                           
    

In [24]:
def download_video(url, save_path, resolution=None, ffmpeg_path = "ffmpeg-git-20240203-amd64-static/ffmpeg"):
    """Download the video-only stream of `url` with yt-dlp and return its path.

    The file is stored as `<save_path>/<video id>.mp4`, where the id comes from
    `get_yt_video_id(url)`. When that file already exists nothing is downloaded.

    Args:
        url: Video URL.
        save_path: Target directory; created when missing.
        resolution: Optional maximum height; when given, the best video stream
            with `height <= resolution` is selected, otherwise the best mp4
            video stream.
        ffmpeg_path: ffmpeg location handed to yt-dlp for the mp4 conversion.

    Returns:
        Path of the (existing or freshly downloaded) mp4 file.
    """
    # Pick the yt-dlp format selector: height-capped when a resolution is
    # requested, otherwise restricted to mp4 for consistency and compatibility.
    if resolution:
        format_selector = f'bestvideo[height<={resolution}]'
    else:
        format_selector = 'bestvideo' + '[ext=mp4]'

    mp4_converter = {
        'key': 'FFmpegVideoConvertor',
        'preferedformat': 'mp4',  # Convert to mp4 if necessary
    }
    ydl_opts = {
        'outtmpl': os.path.join(save_path, '%(id)s.%(ext)s'),
        'format': format_selector,
        'postprocessors': [mp4_converter],
        "ffmpeg_location": f"{ffmpeg_path}",
    }

    # Ensure the save directory exists
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    video_id = get_yt_video_id(url)
    video_path = os.path.join(save_path, f"{video_id}.mp4")

    # Reuse a previously downloaded file instead of fetching it again.
    if os.path.isfile(video_path):
        print(f'Video {video_id} already exists.')
        return video_path

    # Download the video
    with YoutubeDL(ydl_opts) as ydl:
        ydl.extract_info(url, download=True)

    return video_path

In [6]:
def download_video(url, save_path, resolution=None):
    """Download an mp4 stream of `url` with pytube and return the local path.

    The file is stored as `<save_path>/<video id>.mp4`, where the id comes from
    `get_yt_video_id(url)`. When that file already exists it is reused and no
    request is made.

    NOTE(review): this notebook defines `download_video` twice; when the cells
    run top to bottom this later, pytube-based definition replaces the earlier
    yt-dlp-based one. Keep only one of them.

    Args:
        url: Video URL.
        save_path: Target directory.
        resolution: Optional pytube resolution filter value (e.g. "720p");
            the highest-resolution mp4 stream is used when omitted.

    Returns:
        Path of the (existing or freshly downloaded) mp4 file.

    Raises:
        ValueError: If no mp4 stream matches the request.
    """
    video_id = get_yt_video_id(url)
    video_path = os.path.join(save_path, f"{video_id}.mp4")

    # Check the cache first: listing streams costs a network round trip that
    # is pointless when the file is already on disk.
    if os.path.isfile(video_path):
        print(f'Video {video_id} already exists.')
        return video_path

    yt = YouTube(url)
    if resolution:
        video = yt.streams.filter(mime_type="video/mp4", res = resolution).first()
    else:
        video = yt.streams.filter(mime_type="video/mp4").order_by("resolution").desc().first()

    # `.first()` returns None when nothing matches; fail with a clear message
    # instead of an AttributeError on `.download`.
    if video is None:
        wanted = f" at resolution {resolution}" if resolution else ""
        raise ValueError(f"No mp4 stream available for video {video_id}{wanted}")

    print(f'Downloading video {video_id}...')
    video.download(output_path=save_path, filename=video_id+".mp4")
    return video_path

In [7]:
# Fetch the sample video into testing_videos/ (reused if already present).
download_video( "https://www.youtube.com/watch?v=NM9_fvYsWME&t=2s", "testing_videos")

Downloading video NM9_fvYsWME...


'testing_videos/NM9_fvYsWME.mp4'

In [27]:
# Smoke test of the trimming step with hand-picked timestamps (10 s, 20 s, 30 s).
trim_highlights_from_timestamps("testing_videos/NM9_fvYsWME.mp4", [10, 20, 30], output_path = "highlights")

In [29]:
# Videos to run the full highlight pipeline on.
yt_links = [
    "https://www.youtube.com/watch?v=NM9_fvYsWME&t=2s",
    "https://www.youtube.com/watch?v=4PLIiY_sJTo",
    "https://www.youtube.com/watch?v=5qDqxZhOtlM",
    "https://www.youtube.com/watch?v=w2wkz62PJeY",
]

# Each video is processed from the 240-second mark in batches of 16 frames.
for link in yt_links:
    generate_highlights(url=link, batch_size=16, skip_to_sec=240)

Video NM9_fvYsWME already exists.


  8%|▊         | 31/407 [00:07<01:32,  4.08it/s]


KeyboardInterrupt: 