## Autotrim
> Generate suggested trims based on the elimination information extracted from a given video.

To achieve this, we need some form of provenance tracking for eliminations, so we can figure out whether the last n kills of a round were made by the same player.

In [90]:
#| default_exp autotrim
#| export
import enum
import dataclasses
import os
import pathlib

from typing import List, Optional, TypeAlias

import dataclasses_json
import numpy as np
import numpy.typing as nptypes
from PIL import Image
from rapidfuzz import fuzz
import scipy.cluster
import scipy.spatial
import scipy.misc

from csgo_clips_autotrim.experiment_utils.utils import getLogger
from csgo_clips_autotrim.experiment_utils.config import InferenceConfig
from csgo_clips_autotrim.segmentation.elimination import EliminationSegmentationResult, EliminationEvent, FrameInfo, get_inference_result, preprocess_image, crop_img_to_bbox, XYXYBBox


logger = getLogger('autotrim')

### Clutch detection
> Classify if the given clip was a clutch or not, and if it is, find the cutpoints to trim the clip to the clutch.

In [91]:
#| export
class EventType(enum.Enum):
    """Kinds of events recognized in a clip."""

    # NOTE: no trailing comma -- `'X',` would make the member's value the tuple ('X',).
    ELIMINATION = 'ELIMINATION'
    ROUND_END = 'ROUND_END'

def similarity(event1: EliminationEvent, event2: EliminationEvent) -> List[float]:
    """Compute fuzzy name similarities between two elimination events.

    Used to decide whether an event in one frame is the same event as one
    in another frame, so that new events in the timeline can be found.

    Args:
        event1 (EliminationEvent)
        event2 (EliminationEvent)

    Returns:
        List[float]: ``[eliminated_similarity, eliminator_similarity]``, each a
        ``rapidfuzz.fuzz.WRatio`` score between the corresponding OCR texts.

    Raises:
        ValueError: If the eliminator or the eliminated player of either event
            has no OCR information.
    """
    # Every one of the four players is read below, so all four need OCR data
    # (previously only two of them were checked).
    players = (event1.eliminated, event1.eliminator, event2.eliminated, event2.eliminator)
    if any(player.ocr is None for player in players):
        raise ValueError('Need ocr information in given events to find out similarity.')

    return [
        fuzz.WRatio(event1.eliminated.ocr.text, event2.eliminated.ocr.text),
        fuzz.WRatio(event1.eliminator.ocr.text, event2.eliminator.ocr.text),
    ]


In [102]:
#| export

@dataclasses.dataclass()
class TimelineEvent:
    """An elimination event paired with the frame in which it was detected as new."""

    # Frame in which the event was observed.
    frame_info: FrameInfo
    # The elimination event itself.
    event: EliminationEvent

class GameStateLabel(enum.Enum):
    """Labels of the game-state elements returned by the game-state detector.

    The integer values are matched against the detector's ``label`` output
    (``GameStateLabel(result.label)``); presumably these are the model's class
    ids -- confirm against the training label map.
    """

    CT_WIN = 0
    ELIMINATOR_MARK = 1
    T_WIN = 2

@dataclasses.dataclass()
class GameStateElement:
    """A single game-state element detected in a frame."""

    # Which kind of element was detected.
    label: GameStateLabel
    # Location of the element in the frame.
    bbox: XYXYBBox

# A timeline is the ordered list of newly observed elimination events of a clip.
Timeline: TypeAlias = List[TimelineEvent]

@dataclasses_json.dataclass_json
@dataclasses.dataclass()
class ClutchDetectionResult:
    """Result of a positive clutch detection; JSON-serializable via ``dataclasses_json``."""

    # Number of consecutive eliminations attributed to the clutching player.
    num_eliminations: int
    # OCR text of the clutching player's name (the last eliminator).
    player: str
    # Frame index of the first elimination of the clutch.
    start_frame_idx: int
    # Frame index of the last elimination of the clutch.
    end_frame_idx: int


def get_dominant_color(img_ar: nptypes.ArrayLike, num_clusters: int = 5) -> nptypes.ArrayLike:
    """Estimate the dominant color of an image with k-means color clustering.

    The pixels are clustered with ``scipy.cluster.vq.kmeans``. The centroid
    that the largest number of pixels is assigned to is returned.

    Args:
        img_ar (nptypes.ArrayLike): Image of shape ``(H, W, C)``. A 2-D
            ``(H, W)`` array is treated as a single-channel image.
        num_clusters (int, optional): Number of k-means clusters. Defaults to 5.

    Returns:
        nptypes.ArrayLike: Float array of length ``C`` with the dominant color.

    Raises:
        ValueError: If the image contains no pixels.
    """
    arr = np.asarray(img_ar)
    if arr.ndim == 2:
        arr = arr[..., np.newaxis]

    # One row per pixel, one column per channel (np.product was removed in NumPy 2.0).
    pixels = arr.reshape(-1, arr.shape[-1]).astype(float)
    if pixels.shape[0] == 0:
        raise ValueError('Cannot compute the dominant color of an empty image.')

    # kmeans may return fewer than `num_clusters` codes because empty clusters are dropped.
    codes, _ = scipy.cluster.vq.kmeans(pixels, num_clusters)
    labels, _ = scipy.cluster.vq.vq(pixels, codes)

    # Count pixels per code index directly. np.histogram would stretch its bins over
    # the observed label range and mis-attribute counts when some codes are unused.
    counts = np.bincount(labels, minlength=len(codes))
    return codes[np.argmax(counts)]

def is_eliminator_win(last_frame_img_np: nptypes.ArrayLike, last_event: TimelineEvent, win_element_bbox: XYXYBBox) -> bool:
    """Decide whether the team of the last eliminator won the round.

    The decision compares dominant colors. The eliminator wins when the
    eliminator's color is closer (Euclidean) to the win element's color than
    the eliminated player's color is.

    The eliminator/eliminated bboxes are cropped from the top-right quadrant of
    the frame, while ``win_element_bbox`` is cropped from the full frame.
    Presumably the elimination bboxes were predicted on that quadrant; confirm
    against the segmentation step.

    Args:
        last_frame_img_np (nptypes.ArrayLike): Full frame of shape ``(H, W, C)``.
        last_event (TimelineEvent): Event whose eliminator/eliminated are compared.
        win_element_bbox (XYXYBBox): Box of the detected win element.

    Returns:
        bool
    """
    height, width, _channels = last_frame_img_np.shape
    top_right_quadrant = last_frame_img_np[:height // 2, width // 2:]
    players = last_event.event

    # Keep this call order: get_dominant_color uses k-means, which consumes random state.
    eliminator_color = get_dominant_color(crop_img_to_bbox(top_right_quadrant, players.eliminator.bbox))
    win_color = get_dominant_color(crop_img_to_bbox(last_frame_img_np, win_element_bbox))
    eliminated_color = get_dominant_color(crop_img_to_bbox(top_right_quadrant, players.eliminated.bbox))

    distance = scipy.spatial.distance.euclidean
    return distance(eliminator_color, win_color) < distance(eliminated_color, win_color)

def find_closest_event_from_result(event: EliminationEvent, result: EliminationSegmentationResult, threshold: int = 85) -> Optional[EliminationEvent]:
    """Find the first event in ``result`` that matches ``event``.

    An event matches when both of its name similarities (see ``similarity``)
    are strictly greater than ``threshold``.

    Args:
        event (EliminationEvent): Event to look for.
        result (EliminationSegmentationResult): Frame result to search in.
        threshold (int, optional): Minimum similarity (exclusive). Defaults to 85.

    Returns:
        Optional[EliminationEvent]: The first matching event, or None.
    """
    candidates = (
        candidate
        for candidate in result.elimination_events
        if all(score > threshold for score in similarity(event, candidate))
    )
    return next(candidates, None)

def _load_segmentation_results(segmentation_result_dir: pathlib.Path) -> List[EliminationSegmentationResult]:
    """Load every ``*.json`` segmentation result in a directory, logging and skipping failures."""
    results = []
    for path in segmentation_result_dir.glob('*.json'):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                results.append(EliminationSegmentationResult.schema().loads(f.read()))
        except Exception:  # best effort: one bad file must not abort the whole timeline
            logger.exception('Failed to load file: %s', path)
    return results


def _refine_event_with_lookahead(event: EliminationEvent, lookahead_results: List[EliminationSegmentationResult], threshold: int) -> EliminationEvent:
    """Return ``event`` with each player's reading replaced by the most OCR-confident match in later frames."""
    best_eliminator = event.eliminator
    best_eliminated = event.eliminated

    for lookahead_result in lookahead_results:
        closest_event = find_closest_event_from_result(event, lookahead_result, threshold)

        if closest_event is None:
            logger.warning('Lookahead did not find any matching event in next frames.')
            continue

        if closest_event.eliminator.ocr.confidence > best_eliminator.ocr.confidence:
            best_eliminator = closest_event.eliminator

        if closest_event.eliminated.ocr.confidence > best_eliminated.ocr.confidence:
            best_eliminated = closest_event.eliminated

    return dataclasses.replace(event, eliminator=best_eliminator, eliminated=best_eliminated)


def get_timeline(segmentation_result_path: os.PathLike, threshold: int = 85, lookahead: int = 5) -> Timeline:
    """Get the timeline of new elimination events from the segmentation results.

    Frame results are loaded from ``*.json`` files in
    ``segmentation_result_path`` and sorted by ``frame_info.idx``. An event is
    new when it has no similar event in the preceding frame. Events already
    present in the first frame are therefore never added. The OCR of each new
    event is refined with the most confident matching readings from the next
    ``lookahead`` frames.

    Args:
        segmentation_result_path (os.PathLike): Directory of segmentation result JSON files.
        threshold (int, optional): Threshold used for similarity scanning. Defaults to 85.
        lookahead (int, optional): Number of following frames used to refine OCR. Defaults to 5.

    Returns:
        Timeline: New events in frame order, each paired with the frame it was detected in.
    """
    results = sorted(
        _load_segmentation_results(pathlib.Path(segmentation_result_path)),
        key=lambda result: result.frame_info.idx,
    )
    timeline: Timeline = []

    for idx, (previous, current) in enumerate(zip(results, results[1:]), start=1):
        for event in current.elimination_events:
            # An event counts as already seen when ANY of its name similarities to a
            # previous-frame event exceeds the threshold.
            # NOTE(review): a new kill by an eliminator who also appears in a
            # previous-frame event is treated as seen. Confirm this is intended (it
            # tolerates OCR noise on one name). find_closest_event_from_result
            # requires both names to match.
            seen = any(
                any(score > threshold for score in similarity(event, prev_event))
                for prev_event in previous.elimination_events
            )
            if seen:
                continue

            # Look strictly ahead: frames idx+1 .. idx+lookahead (the slice clamps at the end).
            lookahead_results = results[idx + 1:idx + 1 + lookahead]
            refined_event = _refine_event_with_lookahead(event, lookahead_results, threshold)
            timeline.append(TimelineEvent(current.frame_info, refined_event))

    return timeline


def get_frame(frame_info: FrameInfo, image_dir: os.PathLike) -> nptypes.ArrayLike:
    """Load the frame ``<image_dir>/<frame_info.name>.png`` as a numpy array.

    Args:
        frame_info (FrameInfo): Frame whose ``name`` selects the PNG file.
        image_dir (os.PathLike): Directory containing the frame images.

    Returns:
        nptypes.ArrayLike: Pixel data as returned by ``np.asarray`` on the PIL image.
    """
    image_path = pathlib.Path(image_dir).joinpath(f'{frame_info.name}.png')
    with Image.open(image_path, 'r') as image:
        pixels = np.asarray(image)
    return pixels

def detect_game_state_elements(frame_info: FrameInfo, image_dir: os.PathLike, inference_config: InferenceConfig) -> List[GameStateElement]:
    """Run the game-state detector on a single frame.

    Args:
        frame_info (FrameInfo): Frame to load via ``get_frame``.
        image_dir (os.PathLike): Directory containing the frame images.
        inference_config (InferenceConfig): Detector configuration. Its
            ``mlflow_artifact_run_id`` is also passed to ``preprocess_image``.

    Returns:
        List[GameStateElement]: One element per inference result, with the raw
        label converted to ``GameStateLabel``.
    """
    preprocessed = preprocess_image(get_frame(frame_info, image_dir), inference_config.mlflow_artifact_run_id)

    elements: List[GameStateElement] = []
    for detection in get_inference_result(preprocessed, inference_config):
        elements.append(GameStateElement(bbox=detection.bbox, label=GameStateLabel(detection.label)))
    return elements

def _count_clutch_eliminations(timeline: Timeline, player_match_threshold: int, max_eliminations: int = 5) -> int:
    """Count the trailing timeline events whose eliminator fuzzily matches the last event's eliminator."""
    last_eliminator_text = timeline[-1].event.eliminator.ocr.text
    # Look at most `max_eliminations` events back, the last one included
    # (presumably capped at the number of opponents; confirm).
    num_events = min(max_eliminations, len(timeline))

    num_eliminations = 1
    for timeline_event in reversed(timeline[-num_events:-1]):
        if fuzz.WRatio(timeline_event.event.eliminator.ocr.text, last_eliminator_text) < player_match_threshold:
            break
        num_eliminations += 1
    return num_eliminations


def detect_clutch(timeline: Timeline, image_dir: os.PathLike, game_state_inference_config: InferenceConfig,
                  player_match_threshold: int = 60) -> Optional[ClutchDetectionResult]:
    """Detect whether the given timeline ends in a clutch.

    The game-state elements of the last timeline event's frame must satisfy
    all of the following:

    - at least one CT/T win element is present (the first one is used);
    - ``is_eliminator_win`` says the last eliminator's team won;
    - exactly four ``ELIMINATOR_MARK`` elements are present (treated as the
      eliminator's teammates having been eliminated).

    The clutch eliminations are then counted backwards from the last event,
    for as long as the eliminator's name keeps matching.

    Args:
        timeline (Timeline): Timeline from ``get_timeline``, assumed to be in frame order.
        image_dir (os.PathLike): Directory containing the frame images.
        game_state_inference_config (InferenceConfig): Config for the game-state detector.
        player_match_threshold (int, optional): Minimum ``fuzz.WRatio`` between
            eliminator names for an earlier kill to count towards the clutch. Defaults to 60.

    Returns:
        Optional[ClutchDetectionResult]: The clutch, or None if none was detected.
    """
    if not timeline:
        logger.info('Timeline is empty, no clutch detected.')
        return None

    # Check if the last event in the timeline has game state elements that make this a clutch.
    last_event = timeline[-1]
    last_event_game_state_elements = detect_game_state_elements(last_event.frame_info, image_dir, game_state_inference_config)

    if not last_event_game_state_elements:
        logger.info('Did not find any game state elements in the last timeline event, no clutch detected.')
        return None

    win_elements = [element for element in last_event_game_state_elements
                    if element.label in (GameStateLabel.CT_WIN, GameStateLabel.T_WIN)]

    if not win_elements:
        logger.info('Did not find any win elements in the last timline event, no clutch detected.')
        return None

    if len(win_elements) > 1:
        logger.warning('Found multiple win elements in the last timeline event. Using first for further inference.')

    win_element = win_elements[0]

    # Check if the team of the last eliminator won.
    last_frame_img_np = get_frame(last_event.frame_info, image_dir)

    if not is_eliminator_win(last_frame_img_np, last_event, win_element.bbox):
        logger.info('Eliminator did not win the last engagement.')
        return None

    # Check if the eliminator was the last alive.
    teammates_eliminated_marks = [element for element in last_event_game_state_elements
                                  if element.label == GameStateLabel.ELIMINATOR_MARK]

    if len(teammates_eliminated_marks) != 4:
        logger.info('The last eliminator had %d teammates eliminated, no clutch detected.', len(teammates_eliminated_marks))
        # Previously `ClutchDetectionResult()`, which raises TypeError (all fields are required).
        return None

    num_eliminations = _count_clutch_eliminations(timeline, player_match_threshold)
    first_clutch_elimination = timeline[-num_eliminations]

    return ClutchDetectionResult(num_eliminations=num_eliminations,
                                 player=last_event.event.eliminator.ocr.text,
                                 start_frame_idx=first_clutch_elimination.frame_info.idx,
                                 end_frame_idx=last_event.frame_info.idx)

In [103]:
from csgo_clips_autotrim.experiment_utils.constants import BASE_DIR

# Demo: build the timeline from a sample work dir and run clutch detection on it.
# The same directory holds both the segmentation JSON results and the frame PNGs.
TEST_DIR = BASE_DIR / 'nbs' / 'out' / 'work-dir' / 'full-frame'
game_state_inference_config = InferenceConfig(
    mlflow_artifact_run_id='2fe893e46e554b1e8b1ae44176677fb3',
    triton_model_name='csgo-game-state-segmentation-yolov8',
    triton_url='localhost:8000',
    score_threshold=0.5,
)
timeline = get_timeline(TEST_DIR)
detect_clutch(timeline, TEST_DIR, game_state_inference_config)

ClutchDetectionResult(num_eliminations=4, player='Oblue', start_frame_idx=48, end_frame_idx=82)