In [1]:
import cv2
import torch
from ultralytics import YOLO
import numpy as np
import math
from numpy import random
from IPython.display import HTML
import torchvision.models as torch_models
from base64 import b64encode
import os
from IPython.display import Video
from utils import *
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from cls_detection import *
# Pick the compute device: Apple MPS takes precedence when available,
# otherwise CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Classifier input pipeline: resize to 112x112 and convert to a tensor.
_preprocess_steps = [
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
]
preprocess = transforms.Compose(_preprocess_steps)

# Detector weights and the detector instance built from them.
yolo_model_path = "weights/detect_large.pt"
model = YOLO(yolo_model_path)
classNames = ['basketball', 'hoop', 'person']

cuda


In [4]:
# Evaluation clips; the file names suggest one holds only made shots and the
# other only misses (presumably -- confirm against the dataset).
all_pos_vid = "video_test_dataset/all_made.mp4"
all_neg_vid = "video_test_dataset/all_miss.mp4"

# Classifier checkpoints to compare.
all_models = [
    "cls_chkpoint_resnet18/checkpoint_2023-12-24-15-10_lr_0.0001_batch_64/best_model.pth",
    "cls_chkpoint_resnet18/checkpoint_2023-12-24-15-02_lr_0.001_batch_64/best_model.pth",
    "cls_chkpoint_resnet18/checkpoint_2023-12-24-15-27_lr_0.0005_batch_64/best_model.pth",
]

batch_size = 128
conf = 0.4
device = "cuda:1"

# Options common to every inference run in this cell.
shared_kwargs = dict(
    save_result_vid=False,
    display_result=False,
    batch_size=batch_size,
    show_progress=True,
    cls_conf_threshold=conf,
    device=device,
    model_type="resnet18",
)

for model_path in all_models:
    # Number of score events reported on each clip by this checkpoint.
    pos_results = inference_by_batch(yolo_model_path, model_path, video_path=all_pos_vid, **shared_kwargs)
    pos_score = len(pos_results)

    neg_results = inference_by_batch(yolo_model_path, model_path, video_path=all_neg_vid, **shared_kwargs)
    neg_score = len(neg_results)

    print(f"{model_path}: pos_score: {pos_score}, neg_score: {neg_score}")

Loading models...


In [3]:
# Classifier checkpoints; only the first one is used in this cell.
all_models = [
    "cls_chkpoint_resnet18/checkpoint_2023-12-24-15-10_lr_0.0001_batch_64/best_model.pth",
    "cls_chkpoint_resnet18/checkpoint_2023-12-24-15-02_lr_0.001_batch_64/best_model.pth",
    "cls_chkpoint_resnet18/checkpoint_2023-12-24-15-27_lr_0.0005_batch_64/best_model.pth",
]

batch_size = 128
conf = 0.4

# Run one clip end to end, saving and displaying the annotated video.
# The detector batch is twice the default batch size (256 frames).
inference_by_batch(
    yolo_model_path,
    all_models[0],
    video_path="video_test_dataset/0/RPReplay_Final1702697672 4.MOV",
    save_result_vid=True,
    display_result=True,
    batch_size=batch_size * 2,
    show_progress=True,
    cls_conf_threshold=conf,
    device="cuda:1",
    model_type="resnet18",
)

Loading models...
Models loaded!
Initializing video capture...


  0%|          | 0/1 [01:14<?, ?it/s]


KeyboardInterrupt: 

In [5]:
def cls_predict_image(cls_model, img, preprocess, device, threshold = 0.5):
    """Run the classifier on one image and threshold the class-1 score.

    Args:
        cls_model: Callable that maps a batched input tensor to raw outputs.
        img: Image accepted by ``preprocess``.
        preprocess: Callable turning ``img`` into a tensor (no batch axis).
        device: Device the input tensor is moved to before the forward pass.
        threshold: Cut-off applied to the sigmoid score at index 1.

    Returns:
        ``(is_positive, scores)`` where ``scores`` is the sigmoid of the
        squeezed model output and ``is_positive`` is the boolean tensor
        ``scores[1] > threshold``.
    """
    # Add a leading batch axis of size one before moving to the device.
    batch = preprocess(img)[None].to(device)

    with torch.no_grad():
        logits = cls_model(batch)

    scores = logits.squeeze().sigmoid()
    is_positive = scores[1] > threshold
    return is_positive, scores

def predict_hoop_box(img, cls_model, x1, y1, x2, y2, preprocess, device, threshold = 0.5):
    """Classify the region ``img[y1:y2, x1:x2]`` of a single frame.

    The region is sliced out of ``img``, passed through ``cv2.cvtColor`` and
    wrapped in a PIL image, which is then handed to ``cls_predict_batch`` as
    a single (non-list) input.

    Returns:
        ``(crop, predicted_class, prob)``: the PIL crop followed by the two
        values returned by ``cls_predict_batch``.
    """
    region = img[y1:y2, x1:x2]
    # NOTE(review): OpenCV frames are normally BGR, so this channel swap
    # presumably produces the RGB order PIL expects -- confirm with the caller.
    swapped = cv2.cvtColor(region, cv2.COLOR_RGB2BGR)
    crop_pil = Image.fromarray(swapped)

    predicted_class, prob = cls_predict_batch(cls_model, crop_pil, preprocess, device, threshold)
    return crop_pil, predicted_class, prob

def cls_predict_batch(cls_model, batch_imgs, preprocess, device, threshold = 0.5):
    """Run the classifier on a list of images (or one image) in one pass.

    Args:
        cls_model: Callable that maps a batched input tensor to raw outputs.
        batch_imgs: Either a ``list`` of images or a single image; anything
            that is not a ``list`` is treated as one image.
        preprocess: Callable turning one image into a tensor (no batch axis).
        device: Device the batch is moved to before the forward pass.
        threshold: Cut-off applied element-wise to the sigmoid scores.

    Returns:
        ``(mask, positive_probs)``. ``mask`` is the boolean tensor
        ``sigmoid(output) > threshold`` and therefore covers every output
        column, not only column 1. ``positive_probs`` is a Python list with
        the column-1 score of each image.
    """
    if not isinstance(batch_imgs, list):
        # A lone image becomes a batch of one.
        inputs = preprocess(batch_imgs)[None]
    else:
        inputs = torch.stack(list(map(preprocess, batch_imgs)))
    inputs = inputs.to(device)

    # Single forward pass over the whole batch, without gradient tracking.
    with torch.no_grad():
        logits = cls_model(inputs)

    scores = logits.sigmoid()
    above_threshold = scores > threshold
    positive_probs = scores[:, 1].tolist()
    return above_threshold, positive_probs

def predict_hoop_box_batch(img, cls_model,  preprocess, device, threshold = 0.5):
    """Classify a collection of already-cropped image arrays in one batch.

    Args:
        img: Iterable of image arrays; despite the singular name, each
            element is converted and classified separately.
        cls_model, preprocess, device, threshold: Forwarded unchanged to
            ``cls_predict_batch``.

    Returns:
        Whatever ``cls_predict_batch`` returns for the list of PIL images.
    """
    # Each array goes through the same channel swap before becoming a PIL image.
    pil_crops = [
        Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
        for crop in img
    ]
    return cls_predict_batch(cls_model, pil_crops, preprocess, device, threshold)


def inference_by_batch(yolo_model_path,
                       cls_model_path,
                       video_path, 
                       cls_conf_threshold = 0.6,
                       detect_conf_threshold = 0.4,
                       save_result_vid = False, 
                       output_dir = None, 
                       saved_video_name = None,
                       batch_size=128,
                       display_result = False,
                       show_progress = True,
                       skip_to_sec = 0,
                       show_score_prob = False,
                       cls_img_size = 112,
                       device = device,
                       model_type = "resnet50"
                       ):
    """Detect hoops in a video batch by batch and count classifier-confirmed scores.

    Frames are read ``batch_size`` at a time and passed to the YOLO detector.
    Every "hoop" box is cropped and classified; a score is counted when at
    least one crop's class-1 probability exceeds ``cls_conf_threshold`` and
    more than 60 frames have passed since the previous counted score.

    Args:
        yolo_model_path: Path of the YOLO detector weights.
        cls_model_path: Path of the classifier checkpoint.
        video_path: Video to process.
        cls_conf_threshold: Threshold on the classifier's class-1 probability.
        detect_conf_threshold: Confidence threshold passed to the detector.
        save_result_vid: Write an annotated copy of the video when true.
        output_dir: Directory for the annotated video (created if missing);
            ``None`` writes to the working directory.
        saved_video_name: Output file name; defaults to ``"inferenced_"``
            plus the input name with an ``.mp4`` extension.
        batch_size: Number of frames sent to the detector per call.
        display_result: Show the saved video inline and also return its path.
        show_progress: Wrap the batch loop in a tqdm progress bar.
        skip_to_sec: Seek this many seconds into the video before starting.
        show_score_prob: Overlay the probability of the last counted score.
        cls_img_size: Side length the crops are resized to for the classifier.
        device: Device used for the detector and the classifier.
        model_type: ``"resnet50"`` loads a ResNet-50 classifier; any other
            value loads a ResNet-18.

    Returns:
        ``score_timestamps`` -- a list of ``(seconds, probs)`` tuples, one per
        counted score, where ``seconds`` is measured from the first processed
        frame (i.e. after ``skip_to_sec``) and ``probs`` are the class-1
        probabilities of that frame's hoop crops. When ``display_result`` is
        true the return value is ``(score_timestamps, output_path)``, with
        ``output_path`` set to ``None`` if no video was saved.
    """
    # Overlay colours per detector class; other classes are not drawn.
    box_colours = {"hoop": (255, 0, 0), "basketball": (0, 255, 0)}
    # Minimum number of frames between two counted scores.
    score_cooldown_frames = 60

    preprocess = transforms.Compose([
                transforms.Resize((cls_img_size, cls_img_size)),
                transforms.ToTensor(),
            ])

    print("Loading models...")
    model = YOLO(yolo_model_path)
    if model_type == "resnet50":
        cls_model = load_resnet50(cls_model_path, device=device)
    else:
        cls_model = load_resnet18(cls_model_path, device=device)
    cls_model.to(device)
    cls_model.eval()
    print("Models loaded!")

    cap, fps, frame_width, frame_height = get_video_info(video_path)
    out = None
    output_path = None
    score_timestamps = []

    # try/finally so the capture and writer are released even when the run is
    # interrupted (e.g. KeyboardInterrupt in a notebook).
    try:
        if skip_to_sec > 0:
            cap.set(cv2.CAP_PROP_POS_MSEC, skip_to_sec * 1000)

        num_skiped_frames = int(skip_to_sec * fps)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - num_skiped_frames

        print("Initializing video capture...")
        if save_result_vid:
            video_name = video_path.split("/")[-1]
            video_name = video_name.split(".")[0] + ".mp4"

            file_name = saved_video_name if saved_video_name is not None else "inferenced_" + video_name
            if output_dir is None:
                output_path = file_name
            else:
                # Make sure the target directory exists before opening the writer.
                os.makedirs(output_dir, exist_ok=True)
                output_path = os.path.join(output_dir, file_name)

            codec = cv2.VideoWriter_fourcc(*'vp09')
            out = cv2.VideoWriter(output_path, codec, fps, (frame_width,frame_height))

        num_batches = math.ceil(total_frames / batch_size)
        batch_range = tqdm(range(num_batches)) if show_progress else range(num_batches)

        count = 0  # frames seen since the last counted score
        score = 0
        display_prob = [0.0]

        for batch_idx in batch_range:
            # Read up to batch_size frames. The inner loop uses its own
            # variable so the batch index stays valid for the timestamps.
            frames = []
            stream_ended = False
            for _ in range(batch_size):
                ret, frame = cap.read()
                if not ret:
                    stream_ended = True
                    break
                frames.append(frame)

            if not frames:
                break

            results = model(frames,
                            stream=False,
                            verbose = False,
                            conf=detect_conf_threshold,
                            device=device)

            print("Finished detecting objects in batch", batch_idx + 1, "out of", num_batches)

            for frame_idx, r in enumerate(results):
                img = r.orig_img
                count += 1

                detections = []
                for box in r.boxes:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])  # convert to int values
                    detections.append((model.names[int(box.cls)], box.conf.item(), x1, y1, x2, y2))

                # Crop hoops BEFORE drawing, and copy them, so the classifier
                # never sees overlay pixels. Zero-area boxes are skipped
                # because they cannot be converted or resized.
                cropped_images = [
                    img[y1:y2, x1:x2].copy()
                    for name, _, x1, y1, x2, y2 in detections
                    if name == "hoop" and x2 > x1 and y2 > y1
                ]

                for name, confidence, x1, y1, x2, y2 in detections:
                    colour = box_colours.get(name)
                    if colour is None:
                        continue
                    cv2.rectangle(img, (x1, y1), (x2, y2), colour, 2)
                    cv2.putText(img, f'{name}: {confidence:.3f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, colour, 2)

                if cropped_images:
                    _, prob = predict_hoop_box_batch(cropped_images, cls_model,  preprocess, device, threshold=cls_conf_threshold)
                    # Decide on the class-1 probabilities only, matching
                    # cls_predict_image and the probabilities recorded below.
                    scored = any(p > cls_conf_threshold for p in prob)
                    if scored and count > score_cooldown_frames:
                        score += 1
                        count = 0
                        current_frame = batch_idx * batch_size + frame_idx
                        time_stamp = current_frame / fps
                        score_timestamps.append((time_stamp, prob))
                        display_prob = prob

                # Every frame gets the overlay and is written, including
                # frames without a hoop, so the output keeps its length.
                cv2.putText(img, f'Score: {score}', (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 2)
                if show_score_prob:
                    cv2.putText(img, f'Prob: {max(display_prob):.3f}', (10, 140), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 2)
                if out is not None:
                    out.write(img)

            print("finished inferencing with cls")
            if stream_ended:
                break
    finally:
        if out is not None:
            print("releasing video writer")
            out.release()
        print("releasing video capture")
        cap.release()

    if display_result:
        # There is only something to display when a video was written.
        if output_path is not None:
            display(Video(output_path, embed=True))
        return score_timestamps, output_path
    else:
        return score_timestamps

In [11]:
# Check the pixel dimensions of one sample image from the classification dataset.
sample_image_path = "classification_dataset_groupby_env/mima/0/2023_06_21_-_Game_2-xJLCPqvNo00_mp4-1_jpg.rf.9a5945913f60f3140f47637565fc1c4e.jpg"
Image.open(sample_image_path).size

(100, 97)