## Autotrim
> Generate suggested trims based on the elimination information extracted from a given video.

To achieve this, we'll need to implement some form of provenance tracking functionality to figure out whether the last n kills of a clip were made by the same player.

In [50]:
#| default_exp autotrim
#| export
import enum
import dataclasses
import os
import pathlib

from typing import Any, List, Optional, Tuple, TypeAlias

from rapidfuzz import fuzz

from csgo_clips_autotrim.experiment_utils.utils import getLogger
from csgo_clips_autotrim.segmentation.elimination import EliminationSegmentationResult, EliminationEvent, FrameInfo


# Module-level logger used by the functions in this notebook/module.
logger = getLogger('autotrim')

### Clutch detection
> Classify whether the given clip is a clutch and, if it is, find the cut points to trim the clip down to the clutch.

In [51]:
#| export
class EventType(enum.Enum):
    """Kinds of events that can be placed on a clip timeline."""

    # NOTE: no trailing comma here. A trailing comma would turn the member
    # value into the one-element tuple ('ELIMINATION',) instead of a string,
    # making it inconsistent with ROUND_END and breaking lookup by value.
    ELIMINATION = 'ELIMINATION'
    ROUND_END = 'ROUND_END'

def similarity(event1: EliminationEvent, event2: EliminationEvent) -> List[float]:
    """Score how alike two elimination events are, based on their OCR'd names.

    Used to decide whether two events refer to the same elimination, so that
    new events can be found in the timeline.

    Args:
        event1 (EliminationEvent): First event to compare.
        event2 (EliminationEvent): Second event to compare.

    Returns:
        List[float]: Two ``fuzz.WRatio`` scores, in this order: similarity of
        the eliminated players' OCR text, then similarity of the eliminators'
        OCR text.

    Raises:
        ValueError: If any of the four players involved has no OCR information.
    """
    # All four OCR results are read below, so all four must be present;
    # checking only a subset would surface later as an AttributeError.
    players = (event1.eliminated, event1.eliminator, event2.eliminated, event2.eliminator)
    if any(player.ocr is None for player in players):
        raise ValueError('Need ocr information in given events to find out similarity.')

    return [
        fuzz.WRatio(event1.eliminated.ocr.text, event2.eliminated.ocr.text),
        fuzz.WRatio(event1.eliminator.ocr.text, event2.eliminator.ocr.text),
    ]


In [52]:
#| export

@dataclasses.dataclass
class TimelineEvent:
    """An elimination event together with the frame it was recorded at."""
    # Frame information of the segmentation result the event was taken from.
    frame_info: FrameInfo
    # The elimination event itself.
    event: EliminationEvent

# A timeline is a list of timeline events; get_timeline appends them in
# ascending frame-index order.
Timeline: TypeAlias = List[TimelineEvent]

@dataclasses.dataclass
class ClutchDetectionResult:
    """Outcome of clutch detection on a timeline."""
    # Number of kills at the end of the timeline attributed to the same
    # eliminator as the final kill (as counted by detect_clutch).
    num_kills: int

def find_closest_event_from_result(event: EliminationEvent, result: EliminationSegmentationResult, threshold: int = 85) -> Optional[EliminationEvent]:
    """Return the first event in ``result`` that matches ``event``.

    A candidate matches when every score returned by ``similarity`` for the
    pair is strictly greater than ``threshold``.

    Args:
        event (EliminationEvent): Event to look for.
        result (EliminationSegmentationResult): Segmentation result whose
            ``elimination_events`` are searched in order.
        threshold (int, optional): Minimum score (exclusive) each similarity
            component must exceed. Defaults to 85.

    Returns:
        Optional[EliminationEvent]: The first matching event, or ``None`` when
        no event in ``result`` matches.
    """
    matching_events = (
        candidate
        for candidate in result.elimination_events
        if all(score > threshold for score in similarity(event, candidate))
    )
    return next(matching_events, None)

def get_timeline(segmentation_result_path: os.PathLike, threshold: int = 85) -> Timeline:
    """Get timeline of elimination events from the segmentation results.

    Every ``*.json`` file directly inside ``segmentation_result_path`` is
    loaded as an ``EliminationSegmentationResult``; files that fail to parse
    are logged and skipped. The results are ordered by ``frame_info.idx`` and
    each frame is compared with the frame before it. An event that matches no
    event of the previous frame is treated as new and appended to the
    timeline, after its eliminator/eliminated entries have been replaced by
    the most confident OCR readings found in the next few frames.

    Args:
        segmentation_result_path (os.PathLike): Directory containing the
            per-frame segmentation result JSON files.
        threshold (int, optional): Threshold used for similarity scanning.
            Defaults to 85.

    Returns:
        Timeline: The newly appearing elimination events in ascending frame
        order, each paired with the ``frame_info`` of the frame in which it
        was first detected.
    """
    # Maximum number of following frames scanned for better OCR readings.
    max_lookahead_frames = 5

    segmentation_result_path = pathlib.Path(segmentation_result_path)

    elimination_results = []
    for path in segmentation_result_path.glob('*.json'):
        with open(path, 'r', encoding='utf-8') as f:
            try:
                result = EliminationSegmentationResult.schema().loads(f.read())
            except Exception:
                # Best effort: one unreadable file must not abort the whole
                # run. `Exception` (rather than a bare except) lets
                # KeyboardInterrupt / SystemExit propagate.
                logger.exception('Failed to load file: %s', path)
            else:
                elimination_results.append(result)

    elimination_results_by_time: List[EliminationSegmentationResult] = sorted(elimination_results, key=lambda x: x.frame_info.idx)
    timeline: Timeline = []

    # NOTE(review): events already present in the very first frame are never
    # emitted, because every frame is only compared against its predecessor.
    # Confirm this is intended.
    for idx, result in enumerate(elimination_results_by_time[1:], start=1):
        r0 = elimination_results_by_time[idx - 1]
        r1 = result

        for e1 in r1.elimination_events:
            # NOTE(review): `any` marks an event as already seen when EITHER
            # name matches an event of the previous frame, so a second kill by
            # the same eliminator while the first is still shown is treated as
            # a duplicate. Confirm whether `all` was intended.
            seen = any(
                any(score > threshold for score in similarity(e1, e0))
                for e0 in r0.elimination_events
            )
            if seen:
                continue

            # Lookahead next few frames to figure out the most confident OCR preds.
            best_eliminator = e1.eliminator
            best_eliminated = e1.eliminated

            # Start at idx + 1: the current frame only contains e1 itself, so
            # scanning it cannot improve on e1's own confidence. Slicing
            # clamps the window at the end of the list.
            lookahead_results = elimination_results_by_time[idx + 1:idx + 1 + max_lookahead_frames]
            for lookahead_result in lookahead_results:
                closest_event = find_closest_event_from_result(e1, lookahead_result, threshold)

                if not closest_event:
                    logger.warning('Lookahead did not find any matching event in next frames.')
                    continue

                if closest_event.eliminator.ocr.confidence > best_eliminator.ocr.confidence:
                    best_eliminator = closest_event.eliminator

                if closest_event.eliminated.ocr.confidence > best_eliminated.ocr.confidence:
                    best_eliminated = closest_event.eliminated

            best_event = dataclasses.replace(e1, eliminator=best_eliminator, eliminated=best_eliminated)
            timeline.append(TimelineEvent(r1.frame_info, best_event))

    return timeline


def detect_clutch(timeline: Timeline, max_kills: int = 5, threshold: int = 60) -> ClutchDetectionResult:
    """Detect if the given timeline is a clutch.

    Counts how many of the final kills in the timeline were made by the same
    eliminator as the very last kill, walking backwards from the end and
    stopping at the first kill whose eliminator name does not match.

    Args:
        timeline (Timeline): Elimination events in chronological order.
        max_kills (int, optional): Maximum number of trailing kills to
            consider, including the last one. Defaults to 5.
        threshold (int, optional): Minimum ``fuzz.WRatio`` score between an
            eliminator's OCR text and the last eliminator's OCR text for the
            kill to count as the same player. Defaults to 60.

    Returns:
        ClutchDetectionResult: ``num_kills`` is the length of the trailing
        streak by the last eliminator; 0 when the timeline is empty.

    Raises:
        ValueError: If ``max_kills`` is smaller than 1.
    """
    # Approach: check if the last n kills have the same eliminator.
    # This will miss cases when the eliminator dies last but the round is
    # still won.

    # TODO: Need to find the following things to improve accuracy.
    # 1. Which team won
    # 2. Round end
    # 3. Number of enemies remaining.
    if max_kills < 1:
        raise ValueError('max_kills must be at least 1.')

    # Without any kill there is no last eliminator to anchor the streak on.
    if not timeline:
        return ClutchDetectionResult(num_kills=0)

    last_eliminator = timeline[-1].event.eliminator

    num_events = min(max_kills, len(timeline))
    # The last event is excluded here because it is already counted by the
    # initial num_kills = 1 below.
    potential_clutch_events = timeline[-num_events:-1]
    num_kills = 1

    for timeline_event in reversed(potential_clutch_events):
        if fuzz.WRatio(timeline_event.event.eliminator.ocr.text, last_eliminator.ocr.text) < threshold:
            break
        num_kills += 1

    return ClutchDetectionResult(num_kills=num_kills)

In [53]:
# Manual check cell (no `#| export` directive): build a timeline from the
# JSON files under TEST_DIR and run clutch detection on it.
# NOTE(review): presumably TEST_DIR holds previously generated segmentation
# results — confirm before relying on the output shown below.
from csgo_clips_autotrim.experiment_utils.constants import BASE_DIR

TEST_DIR = BASE_DIR / 'nbs' / 'out' / 'work-dir' / 'provenance'
timeline = get_timeline(TEST_DIR)
detect_clutch(timeline)

ClutchDetectionResult(num_kills=4)