Skip to content

Commit

Permalink
feat(playback): Add RewindTreeMap to keep rewind history
Browse files Browse the repository at this point in the history
It was originally introduced in xymaxim/mpv-ytpb#1.
  • Loading branch information
xymaxim committed Apr 10, 2024
1 parent a94f8d1 commit 91fd078
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 12 deletions.
105 changes: 94 additions & 11 deletions src/ytpb/playback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Iterable, Literal
from typing import Iterable, Literal, Self, TypeGuard
from urllib.parse import parse_qs, urlparse

import requests
Expand Down Expand Up @@ -34,6 +34,7 @@
RelativePointInStream,
RelativeSegmentSequence,
SegmentSequence,
Timestamp,
)
from ytpb.utils.other import resolve_relativity_in_interval
from ytpb.utils.url import (
Expand All @@ -48,6 +49,71 @@
SEGMENT_URL_PATTERN = r"https://.+\.googlevideo\.com/videoplayback/.+"


@dataclass
class RewindTreeNode:
key: Timestamp
value: SegmentSequence
left: Self | None = None
right: Self | None = None


@dataclass
class RewindTreeMap:
"""A binary search tree implementation to store key-value pairs.
Keys represent timestamps of segments, while values are sequence numbers.
"""

root: RewindTreeNode | None = None

@staticmethod
def _is_tree_node(node: RewindTreeNode | None) -> TypeGuard[RewindTreeNode]:
return node is not None

@staticmethod
def _insert(
node: RewindTreeNode | None, key: Timestamp, value: SegmentSequence
) -> RewindTreeNode:
if not RewindTreeMap._is_tree_node(node):
return RewindTreeNode(key, value, None, None)
else:
if key < node.key:
left = RewindTreeMap._insert(node.left, key, value)
return RewindTreeNode(node.key, node.value, left, node.right)
elif key > node.key:
right = RewindTreeMap._insert(node.right, key, value)
return RewindTreeNode(node.key, node.value, node.left, right)
else:
return RewindTreeNode(node.key, value, node.left, node.right)

def insert(self, key: Timestamp, value: SegmentSequence) -> None:
"""Inserts a pair of timestamp and sequence number into the tree."""
self.root = RewindTreeMap._insert(self.root, key, value)

@staticmethod
def _closest(
node: RewindTreeNode | None, target: Timestamp, closest: RewindTreeNode
) -> RewindTreeNode | None:
if not RewindTreeMap._is_tree_node(node):
return closest
else:
result = closest
if abs(target - closest.key) > abs(target - node.key):
result = node
if target < node.key:
return RewindTreeMap._closest(node.left, target, result)
elif target > node.key:
return RewindTreeMap._closest(node.right, target, result)
else:
return result

def closest(self, target: Timestamp) -> RewindTreeNode | None:
"""Finds the node closest to the target timestamp."""
if self.root is None:
return None
return RewindTreeMap._closest(self.root, target, self.root)


@dataclass(frozen=True)
class RewindMoment:
"""Represents a moment that has been rewound."""
Expand Down Expand Up @@ -197,6 +263,8 @@ def __init__(
self._temp_directory: Path | None = None
self._cache_directory: Path | None = None

self.rewind_history = RewindTreeMap()

@classmethod
def from_url(cls, video_url: str, **kwargs) -> "Playback":
"""Creates a playback for the given video URL.
Expand Down Expand Up @@ -457,31 +525,40 @@ def locate_moment(
itag = itag or next(iter(self.streams)).itag
base_url = self._get_reference_base_url(itag)

def _get_non_located_date(segment: Segment) -> datetime:
if is_end:
date = segment.ingestion_end_date
else:
date = segment.ingestion_start_date
return date

segment: Segment | None = None

match point:
case SegmentSequence() as sequence:
self.download_segment(sequence, base_url)
segment = self.get_segment(sequence, base_url)
if is_end:
date = segment.ingestion_end_date
else:
date = segment.ingestion_start_date
date = _get_non_located_date(segment)
moment = RewindMoment(date, sequence, 0, is_end)
case datetime() as date:
reference_sequence: SegmentSequence | None = None
if reference := self.rewind_history.closest(date.timestamp()):
reference_sequence = reference.value
sl = SegmentLocator(
base_url,
reference_sequence=reference_sequence,
temp_directory=self.get_temp_directory(),
session=self.session,
)
locate_result = sl.find_sequence_by_time(point.timestamp(), end=is_end)
segment = self.get_segment(locate_result.sequence, base_url)
locate_result = sl.find_sequence_by_time(date.timestamp(), end=is_end)

if locate_result.falls_in_gap:
if is_end:
date = segment.ingestion_end_date
else:
date = segment.ingestion_start_date
segment = self.get_segment(locate_result.sequence, base_url)
date = _get_non_located_date(segment)
cut_at = 0
else:
cut_at = locate_result.time_difference

moment = RewindMoment(
date=date,
sequence=locate_result.sequence,
Expand All @@ -490,6 +567,12 @@ def locate_moment(
falls_in_gap=locate_result.falls_in_gap,
)

if segment is None:
segment = self.get_segment(moment.sequence, base_url)
self.rewind_history.insert(
segment.ingestion_start_date.timestamp(), moment.sequence
)

return moment

def locate_interval(
Expand Down
24 changes: 23 additions & 1 deletion tests/test_playback.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def test_end_date(self, add_responses_callback_for_reference_base_url: Callable)
(timedelta(seconds=2), datetime.fromisoformat("2023-03-25T23:33:57Z")),
(RelativeSegmentSequence(1), 7959121),
(RelativeSegmentSequence(1), datetime.fromisoformat("2023-03-25T23:33:57Z")),
(RelativeSegmentSequence(1), 7959121),
],
)
def test_locate_interval(
Expand Down Expand Up @@ -205,6 +204,29 @@ def test_locate_interval_with_swapped_start_and_end(
playback.locate_interval(start, end, "140")


def test_insert_to_rewind_history(
fake_info_fetcher: "FakeInfoFetcher",
add_responses_callback_for_reference_base_url: Callable,
add_responses_callback_for_segment_urls: Callable,
mocked_responses: responses.RequestsMock,
stream_url: str,
audio_base_url: str,
tmp_path: Path,
) -> None:
# Given:
add_responses_callback_for_segment_urls(urljoin(audio_base_url, r"sq/\w+"))

# When:
playback = Playback(stream_url, fetcher=fake_info_fetcher)
playback.fetch_and_set_essential()
playback.locate_moment(7959120, "140")
playback.locate_moment(7959122, "140")

# Then:
assert playback.rewind_history.closest(1679787234.491176).value == 7959120
assert playback.rewind_history.closest(1679787238.491916).value == 7959122


def test_create_playback_from_url(
fake_info_fetcher: "FakeInfoFetcher",
active_live_video_info: YouTubeVideoInfo,
Expand Down

0 comments on commit 91fd078

Please sign in to comment.