In [None]:
# Imports
import os
import numpy as np
import torch
import supervision as sv
from google.colab import drive
from inference import get_model
from tqdm import tqdm
from PIL import Image
from transformers import AutoProcessor, SiglipVisionModel
import umap.umap_ as umap
from sklearn.cluster import KMeans
import configparser

In [None]:
# Install required packages
!pip install -q gdown inference-gpu supervision umap-learn torch transformers

In [None]:
# Set up environment
# Print the GPU / driver status of this runtime (shell command).
!nvidia-smi
# Mount Google Drive at /content/drive; the input video path used later in this
# notebook lives under that mount point.
drive.mount('/content/drive')

Fri Oct  4 20:58:14 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8              14W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Set CUDA execution provider
# Sets the ONNXRUNTIME_EXECUTION_PROVIDERS environment variable for this process.
# NOTE(review): `inference` is already imported in the imports cell before this
# assignment runs; if the library reads the variable at import time this has no
# effect — confirm, and move the assignment ahead of the import if so.
os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CUDAExecutionProvider]"

In [None]:
# Initialize model

# Read configuration
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open, which would only show
# up below as a confusing KeyError. Opening the file explicitly makes a missing
# config.ini fail immediately with FileNotFoundError.
with open('config.ini') as config_file:
    config.read_file(config_file)

# Credentials and model id come from the config file, not from the notebook.
roboflow_api_key = config['Roboflow']['API_KEY']
player_tracking_model_id = config['Project']['PROJECT_NAME']

# Detection model shared by every cell below.
model = get_model(model_id=player_tracking_model_id, api_key=roboflow_api_key)



In [None]:
# Class ids used to split the model's detections by object type.
BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3

In [None]:
# Annotators used to draw on frames: a triangle marker for the ball, and an
# ellipse plus a text label for every other detection (coloured by class id).
triangle_annotator = sv.TriangleAnnotator(
    color=sv.Color.from_hex('#fff41f'), base=20, height=17)

label_annotator = sv.LabelAnnotator(
    color=sv.ColorPalette.from_hex(['#fc2d87', '#57befa', '#fce42d']),
    text_color=sv.Color.from_hex('#ff4d52'),
    text_position=sv.Position.TOP_CENTER,
)

ellipse_annotator = sv.EllipseAnnotator(
    color=sv.ColorPalette.from_hex(['#fc2d87', '#57befa', '#fce42d']),
    thickness=3,
)

In [None]:
# Initialize tracker
# Single module-level ByteTrack instance; process_video() feeds it the non-ball
# detections of each frame and reads back tracker ids.
tracker = sv.ByteTrack()

In [None]:
# Helper functions
def extract_crops(video_path, player_id=PLAYER_ID, stride=25):
    """
    Collect image crops of one detection class from sampled video frames.

    Frames are read from `video_path` using the given `stride`, passed to the
    module-level `model` (confidence 0.3), filtered with class-agnostic NMS
    (threshold 0.5), and every remaining box whose class id equals `player_id`
    is cropped out of its frame.

    Returns a flat list with the crops of all processed frames.
    """
    collected = []
    frames = sv.get_video_frames_generator(source_path=video_path, stride=stride)
    for image in tqdm(frames, desc='Extracting crops'):
        inference_result = model.infer(image, confidence=0.3)[0]
        found = sv.Detections.from_inference(inference_result).with_nms(
            threshold=0.5, class_agnostic=True)
        wanted = found[found.class_id == player_id]
        collected.extend(sv.crop_image(image, box) for box in wanted.xyxy)
    return collected

_SIGLIP_CHECKPOINT = "google/siglip-base-patch16-224"
_siglip_cache = {}


def _load_siglip():
    """Return the SigLIP vision model and processor, loading them only once."""
    if "model" not in _siglip_cache or "processor" not in _siglip_cache:
        loaded_model = SiglipVisionModel.from_pretrained(_SIGLIP_CHECKPOINT)
        loaded_processor = AutoProcessor.from_pretrained(_SIGLIP_CHECKPOINT)
        _siglip_cache["model"] = loaded_model
        _siglip_cache["processor"] = loaded_processor
    return _siglip_cache["model"], _siglip_cache["processor"]


def team_classifier(crops, random_state=None):
    """
    Split player crops into two teams by clustering SigLIP image embeddings.

    Each crop is embedded with the SigLIP vision model, the embeddings are
    reduced to 3 dimensions with UMAP and then clustered into 2 groups with
    KMeans.

    Args:
        crops: sequence of H x W x 3 image arrays in BGR channel order (as
            produced by cropping frames read through OpenCV-based readers).
        random_state: optional seed forwarded to UMAP and KMeans to make the
            clustering reproducible. The default (None) keeps both unseeded.

    Returns:
        Integer array with one cluster label (0 or 1) per crop. The labels are
        fitted on this call's crops only, so which team receives label 0 is
        arbitrary and not comparable between separate calls. With fewer than
        two crops no clustering is possible and all-zero labels are returned.
    """
    n_clusters = 2
    # KMeans cannot form 2 clusters from fewer than 2 samples and
    # np.concatenate fails on an empty list, so answer those cases directly
    # (and without loading the model).
    if len(crops) < n_clusters:
        return np.zeros(len(crops), dtype=int)

    # Loading the weights is expensive; reuse them across calls.
    siglip_model, processor = _load_siglip()

    feature_vectors = []
    for crop in crops:
        # Reverse the channel axis (BGR -> RGB) so PIL and the pretrained model
        # see the colours they expect.
        crop_image = Image.fromarray(np.ascontiguousarray(crop[:, :, ::-1]))
        inputs = processor(images=crop_image, return_tensors="pt", padding="max_length")
        with torch.no_grad():
            outputs = siglip_model(**inputs)
        feature_vectors.append(outputs.pooler_output.cpu().numpy())

    feature_vectors = np.concatenate(feature_vectors)

    reducer = umap.UMAP(n_components=3, random_state=random_state)
    clustering_model = KMeans(n_clusters=n_clusters, random_state=random_state)

    projections = reducer.fit_transform(feature_vectors)
    clusters = clustering_model.fit_predict(projections)

    return clusters

def resolve_goalkeepers_team_id(players, goalkeepers):
    """
    Assign each goalkeeper to the team whose players' centroid is closest.

    Args:
        players: detections whose `class_id` already holds the team id (0 or 1).
        goalkeepers: detections of the goalkeepers to resolve.

    Returns:
        Integer array with one team id (0 or 1) per goalkeeper; empty when
        there are no goalkeepers. Ties go to team 1. A team with no players is
        treated as infinitely far away, so goalkeepers fall to the other team
        (or to team 1 when both teams are empty).
    """
    goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
    players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

    team_distances = []
    for team_id in (0, 1):
        team_xy = players_xy[players.class_id == team_id]
        if len(team_xy) == 0:
            # The mean of an empty selection would be NaN, and any comparison
            # with NaN is False, which would silently force team 1.
            team_distances.append(np.full(len(goalkeepers_xy), np.inf))
        else:
            centroid = team_xy.mean(axis=0)
            team_distances.append(np.linalg.norm(goalkeepers_xy - centroid, axis=1))

    dist_0, dist_1 = team_distances
    return np.where(dist_0 < dist_1, 0, 1).astype(int)

In [None]:
# Main processing
def _align_team_labels(team_ids, tracker_ids, team_by_tracker):
    """
    Make one frame's 0/1 team labels consistent with earlier frames.

    The per-frame clustering decides arbitrarily which team is called 0. For
    tracker ids seen before, count how many players keep their remembered team
    under the current labelling versus the flipped one, flip the whole frame if
    that agrees better, then store the result in `team_by_tracker`.
    """
    team_ids = np.asarray(team_ids, dtype=int)
    tracker_ids = [] if tracker_ids is None else [int(tid) for tid in tracker_ids]

    votes_keep = votes_flip = 0
    for team, tid in zip(team_ids.tolist(), tracker_ids):
        remembered = team_by_tracker.get(tid)
        if remembered is None:
            continue
        if remembered == team:
            votes_keep += 1
        else:
            votes_flip += 1

    if votes_flip > votes_keep:
        team_ids = 1 - team_ids

    team_by_tracker.update(zip(tracker_ids, team_ids.tolist()))
    return team_ids


def process_video(input_video_path, output_video_path):
    """
    Detect, track and team-colour everyone in a video and write the result.

    For every frame of `input_video_path` the module-level `model` is run, the
    ball is marked with a triangle, and all other detections are tracked with
    the module-level `tracker`. Players are split into two teams with
    `team_classifier`, goalkeepers are attached to the nearer team, and
    referees get the third palette colour. The annotated frames are written to
    `output_video_path`.
    """
    video_info = sv.VideoInfo.from_video_path(input_video_path)
    video_sink = sv.VideoSink(output_video_path, video_info=video_info)
    frame_generator = sv.get_video_frames_generator(input_video_path)

    # The tracker is shared module state; start every run from a clean slate so
    # re-running this function does not continue tracks of a previous video.
    tracker.reset()
    # tracker id -> last team id assigned, used to keep team colours stable.
    team_by_tracker = {}

    with video_sink:
        for frame in tqdm(frame_generator, total=video_info.total_frames):
            results = model.infer(frame, confidence=0.3)[0]
            detections = sv.Detections.from_inference(results)

            # The ball is handled separately: padded box, no tracking.
            ball_detections = detections[detections.class_id == BALL_ID]
            ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

            all_detections = detections[detections.class_id != BALL_ID]
            all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
            all_detections = tracker.update_with_detections(detections=all_detections)

            goalkeepers_detections = all_detections[all_detections.class_id == GOALKEEPER_ID]
            players_detections = all_detections[all_detections.class_id == PLAYER_ID]
            referees_detections = all_detections[all_detections.class_id == REFEREE_ID]

            players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
            # Two clusters need at least two samples; with fewer players skip
            # the classifier instead of letting it raise.
            if len(players_crops) >= 2:
                team_ids = team_classifier(players_crops)
            else:
                team_ids = np.zeros(len(players_crops), dtype=int)
            # Cluster labels are arbitrary per frame; align them with the teams
            # already assigned to the same tracker ids.
            players_detections.class_id = _align_team_labels(
                team_ids, players_detections.tracker_id, team_by_tracker)

            goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
                players_detections, goalkeepers_detections)

            # REFEREE_ID (3) -> 2, the third colour of the 3-colour palette;
            # 0 and 1 are the two teams.
            referees_detections.class_id -= 1

            all_detections = sv.Detections.merge([
                players_detections, goalkeepers_detections, referees_detections])

            # NOTE(review): when every merged group is empty the result may
            # carry no tracker ids at all — guard instead of iterating None.
            merged_tracker_ids = all_detections.tracker_id
            if merged_tracker_ids is None:
                labels = []
            else:
                labels = [f"#{tracker_id}" for tracker_id in merged_tracker_ids]

            all_detections.class_id = all_detections.class_id.astype(int)

            annotated_frame = frame.copy()
            annotated_frame = ellipse_annotator.annotate(
                scene=annotated_frame,
                detections=all_detections)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame,
                detections=all_detections,
                labels=labels)
            annotated_frame = triangle_annotator.annotate(
                scene=annotated_frame,
                detections=ball_detections)

            video_sink.write_frame(annotated_frame)

In [None]:
# Run the main processing
# The input video is read from the mounted Google Drive; the annotated video is
# written to /content on the runtime's filesystem.
input_video_path = '/content/drive/MyDrive/08fd33_4_short.mp4'
output_video_path = '/content/clusteredresult.mp4'
process_video(input_video_path, output_video_path)

In [None]:
# Display a sample frame (optional)
# Runs the same per-frame steps as process_video() on the first frame only,
# without tracking, and shows the result inline.
sample_frame = next(sv.get_video_frames_generator(input_video_path))
results = model.infer(sample_frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(results)

ball_detections = detections[detections.class_id == BALL_ID]
ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

all_detections = detections[detections.class_id != BALL_ID]
all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)

goalkeepers_detections = all_detections[all_detections.class_id == GOALKEEPER_ID]
players_detections = all_detections[all_detections.class_id == PLAYER_ID]
referees_detections = all_detections[all_detections.class_id == REFEREE_ID]

players_crops = [sv.crop_image(sample_frame, xyxy) for xyxy in players_detections.xyxy]
# Two clusters need at least two samples; skip the classifier otherwise.
if len(players_crops) >= 2:
    players_detections.class_id = team_classifier(players_crops)
else:
    players_detections.class_id = np.zeros(len(players_crops), dtype=int)

goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
    players_detections, goalkeepers_detections)

# REFEREE_ID (3) -> 2, the third colour of the 3-colour palette.
referees_detections.class_id -= 1

# Assigning class_id on the filtered subsets does not change `all_detections`,
# so the subsets have to be merged back before annotating; otherwise the team
# classification is computed and then thrown away.
all_detections = sv.Detections.merge([
    players_detections, goalkeepers_detections, referees_detections])
all_detections.class_id = all_detections.class_id.astype(int)

annotated_frame = sample_frame.copy()
annotated_frame = ellipse_annotator.annotate(scene=annotated_frame, detections=all_detections)
annotated_frame = triangle_annotator.annotate(scene=annotated_frame, detections=ball_detections)

sv.plot_image(annotated_frame)