diff --git a/src/ytpb/playback.py b/src/ytpb/playback.py index b71105d..49c9479 100644 --- a/src/ytpb/playback.py +++ b/src/ytpb/playback.py @@ -6,7 +6,7 @@ from dataclasses import asdict, dataclass from datetime import datetime, timedelta from pathlib import Path -from typing import Iterable, Literal +from typing import Iterable, Literal, Self, TypeGuard from urllib.parse import parse_qs, urlparse import requests @@ -34,6 +34,7 @@ RelativePointInStream, RelativeSegmentSequence, SegmentSequence, + Timestamp, ) from ytpb.utils.other import resolve_relativity_in_interval from ytpb.utils.url import ( @@ -48,6 +49,71 @@ SEGMENT_URL_PATTERN = r"https://.+\.googlevideo\.com/videoplayback/.+" +@dataclass +class RewindTreeNode: + key: Timestamp + value: SegmentSequence + left: Self | None = None + right: Self | None = None + + +@dataclass +class RewindTreeMap: + """A binary search tree implementation to store key-value pairs. + + Keys represent timestamps of segments, while values are sequence numbers. 
+ """ + + root: RewindTreeNode | None = None + + @staticmethod + def _is_tree_node(node: RewindTreeNode | None) -> TypeGuard[RewindTreeNode]: + return node is not None + + @staticmethod + def _insert( + node: RewindTreeNode | None, key: Timestamp, value: SegmentSequence + ) -> RewindTreeNode: + if not RewindTreeMap._is_tree_node(node): + return RewindTreeNode(key, value, None, None) + else: + if key < node.key: + left = RewindTreeMap._insert(node.left, key, value) + return RewindTreeNode(node.key, node.value, left, node.right) + elif key > node.key: + right = RewindTreeMap._insert(node.right, key, value) + return RewindTreeNode(node.key, node.value, node.left, right) + else: + return RewindTreeNode(node.key, value, node.left, node.right) + + def insert(self, key: Timestamp, value: SegmentSequence) -> None: + """Inserts a pair of timestamp and sequence number into the tree.""" + self.root = RewindTreeMap._insert(self.root, key, value) + + @staticmethod + def _closest( + node: RewindTreeNode | None, target: Timestamp, closest: RewindTreeNode + ) -> RewindTreeNode | None: + if not RewindTreeMap._is_tree_node(node): + return closest + else: + result = closest + if abs(target - closest.key) > abs(target - node.key): + result = node + if target < node.key: + return RewindTreeMap._closest(node.left, target, result) + elif target > node.key: + return RewindTreeMap._closest(node.right, target, result) + else: + return result + + def closest(self, target: Timestamp) -> RewindTreeNode | None: + """Finds the node closest to the target timestamp.""" + if self.root is None: + return None + return RewindTreeMap._closest(self.root, target, self.root) + + @dataclass(frozen=True) class RewindMoment: """Represents a moment that has been rewound.""" @@ -197,6 +263,8 @@ def __init__( self._temp_directory: Path | None = None self._cache_directory: Path | None = None + self.rewind_history = RewindTreeMap() + @classmethod def from_url(cls, video_url: str, **kwargs) -> "Playback": 
"""Creates a playback for the given video URL. @@ -457,31 +525,40 @@ def locate_moment( itag = itag or next(iter(self.streams)).itag base_url = self._get_reference_base_url(itag) + def _get_non_located_date(segment: Segment) -> datetime: + if is_end: + date = segment.ingestion_end_date + else: + date = segment.ingestion_start_date + return date + + segment: Segment | None = None + match point: case SegmentSequence() as sequence: self.download_segment(sequence, base_url) segment = self.get_segment(sequence, base_url) - if is_end: - date = segment.ingestion_end_date - else: - date = segment.ingestion_start_date + date = _get_non_located_date(segment) moment = RewindMoment(date, sequence, 0, is_end) case datetime() as date: + reference_sequence: SegmentSequence | None = None + if reference := self.rewind_history.closest(date.timestamp()): + reference_sequence = reference.value sl = SegmentLocator( base_url, + reference_sequence=reference_sequence, temp_directory=self.get_temp_directory(), session=self.session, ) - locate_result = sl.find_sequence_by_time(point.timestamp(), end=is_end) - segment = self.get_segment(locate_result.sequence, base_url) + locate_result = sl.find_sequence_by_time(date.timestamp(), end=is_end) + if locate_result.falls_in_gap: - if is_end: - date = segment.ingestion_end_date - else: - date = segment.ingestion_start_date + segment = self.get_segment(locate_result.sequence, base_url) + date = _get_non_located_date(segment) cut_at = 0 else: cut_at = locate_result.time_difference + moment = RewindMoment( date=date, sequence=locate_result.sequence, @@ -490,6 +567,12 @@ def locate_moment( falls_in_gap=locate_result.falls_in_gap, ) + if segment is None: + segment = self.get_segment(moment.sequence, base_url) + self.rewind_history.insert( + segment.ingestion_start_date.timestamp(), moment.sequence + ) + return moment def locate_interval( diff --git a/tests/test_playback.py b/tests/test_playback.py index c002ebc..3236952 100644 --- 
def test_insert_to_rewind_history(
    fake_info_fetcher: "FakeInfoFetcher",
    add_responses_callback_for_reference_base_url: Callable,
    add_responses_callback_for_segment_urls: Callable,
    mocked_responses: responses.RequestsMock,
    stream_url: str,
    audio_base_url: str,
    tmp_path: Path,
) -> None:
    """Checks that located moments can be looked up in the rewind history."""
    # Given: the response callback is registered for segment URLs that are
    # built from the audio base URL.
    segment_url_pattern = urljoin(audio_base_url, r"sq/\w+")
    add_responses_callback_for_segment_urls(segment_url_pattern)

    # When: two moments are located by their sequence numbers.
    playback = Playback(stream_url, fetcher=fake_info_fetcher)
    playback.fetch_and_set_essential()
    for sequence in (7959120, 7959122):
        playback.locate_moment(sequence, "140")

    # Then: the closest history entry to each timestamp holds the sequence
    # number that was located.
    history = playback.rewind_history
    expected_by_timestamp = {
        1679787234.491176: 7959120,
        1679787238.491916: 7959122,
    }
    for timestamp, expected_sequence in expected_by_timestamp.items():
        assert history.closest(timestamp).value == expected_sequence