Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions gateway/sds_gateway/api_methods/helpers/temporal_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import re

from django.db.models import QuerySet

from opensearchpy.exceptions import NotFoundError as OpenSearchNotFoundError
from sds_gateway.api_methods.models import CaptureType, Capture, File
from sds_gateway.api_methods.utils.opensearch_client import get_opensearch_client
from sds_gateway.api_methods.utils.relationship_utils import get_capture_files
from loguru import logger as log

# Digital RF spec: rf@SECONDS.MILLISECONDS.h5 (e.g. rf@1396379502.000.h5)
# https://github.com/MITHaystack/digital_rf
DRF_RF_FILENAME_PATTERN = re.compile(
r"^rf@(\d+)\.(\d+)\.h5$",
re.IGNORECASE,
)
DRF_RF_FILENAME_REGEX_STR = r"^rf@\d+\.\d+\.h5$"


def drf_rf_filename_from_ms(ms: int) -> str:
    """Format a millisecond timestamp as a DRF rf data filename.

    Canonical form used for lexicographic range queries: rf@SECONDS.MMM.h5.
    """
    seconds, millis = divmod(ms, 1000)
    return f"rf@{seconds}.{millis:03d}.h5"


def drf_rf_filename_to_ms(file_name: str) -> int | None:
"""
Parse DRF rf data filename to milliseconds.
Handles rf@SECONDS.MILLISECONDS.h5; fractional part padded to 3 digits.
"""
name = file_name.strip()
match = DRF_RF_FILENAME_PATTERN.match(name)
if not match:
return None
try:
seconds = int(match.group(1))
frac = match.group(2).ljust(3, "0")[:3]
return seconds * 1000 + int(frac)
except (ValueError, TypeError):
return None


def _catch_capture_type_error(capture_type: CaptureType) -> None:
    """Raise ValueError (after logging) unless capture_type is DigitalRF."""
    if capture_type == CaptureType.DigitalRF:
        return
    msg = "Only DigitalRF captures are supported for temporal filtering."
    log.error(msg)
    raise ValueError(msg)


def get_capture_bounds(capture_type: CaptureType, capture_uuid: str) -> tuple[int, int]:
    """Get start and end bounds for capture from opensearch.

    Raises ValueError if the capture type is unsupported or the document is
    missing from the OpenSearch index. Missing bounds default to 0.
    """
    _catch_capture_type_error(capture_type)

    client = get_opensearch_client()
    index = f"captures-{capture_type}"
    not_found_msg = f"Capture {capture_uuid} not found in OpenSearch index {index}"

    try:
        doc = client.get(index=index, id=capture_uuid)
    except OpenSearchNotFoundError as err:
        raise ValueError(not_found_msg) from err

    if not doc.get("found"):
        raise ValueError(not_found_msg)

    search_props = doc.get("_source", {}).get("search_props", {})
    return search_props.get("start_time", 0), search_props.get("end_time", 0)


def get_data_files(capture_type: CaptureType, capture: Capture) -> QuerySet[File]:
    """Return the DRF rf data files in the capture (names matching rf@S.MMM.h5)."""
    _catch_capture_type_error(capture_type)
    all_capture_files = get_capture_files(capture)
    return all_capture_files.filter(name__regex=DRF_RF_FILENAME_REGEX_STR)
Comment on lines +75 to +80
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a hot path involving a DB call: it's called once per channel, then for each capture being serialized, and from time range computations, file cadence, and capture bounds.

See if we can refactor this call stack first, then cache results (lru_cache if a function, or django.utils.functional.cached_property if you make it a method of Capture).

The refactoring suggestion is because this and other functions in this file might make more sense as properties or methods of a Capture instance.



def get_file_cadence(capture_type: CaptureType, capture: Capture) -> int:
    """Get the file cadence in milliseconds. OpenSearch bounds are in seconds.

    Returns 0 when there are fewer than two data files (no spacing to measure).
    Raises ValueError when the capture is not DigitalRF or is not indexed.
    """
    _catch_capture_type_error(capture_type)

    # Let ValueError from get_capture_bounds propagate; it already carries a
    # descriptive message (redundant log-and-reraise removed).
    start_time, end_time = get_capture_bounds(capture_type, str(capture.uuid))

    file_count = get_data_files(capture_type, capture).count()

    # The first file marks the beginning of the capture; the number of
    # "spaces" between files is file_count - 1.
    gaps = file_count - 1
    if gaps <= 0:
        # 0 or 1 data files: previously a count of 0 produced gaps == -1 and a
        # spurious cadence of 1; treat both cases as "no cadence".
        return 0

    duration_ms = (end_time - start_time) * 1000
    return max(1, int(duration_ms / gaps))


def filter_capture_data_files_selection_bounds(
    capture_type: CaptureType,
    capture: Capture,
    start_time: int,  # relative ms from start of capture (from UI)
    end_time: int,  # relative ms from start of capture (from UI)
) -> QuerySet[File]:
    """Filter the capture file selection bounds to the given start and end times."""
    _catch_capture_type_error(capture_type)

    # Convert UI-relative ms offsets to absolute epoch ms using the capture's
    # OpenSearch start bound (seconds).
    bounds_start_sec, _ = get_capture_bounds(capture_type, str(capture.uuid))
    base_ms = bounds_start_sec * 1000
    lower_name = drf_rf_filename_from_ms(base_ms + start_time)
    upper_name = drf_rf_filename_from_ms(base_ms + end_time)

    # NOTE(review): string-range filtering relies on lexicographic order of
    # rf@SECONDS.MMM.h5 names, which assumes equal-width SECONDS fields —
    # confirm this holds for the deployed epoch range.
    data_files = get_data_files(capture_type, capture)
    bounded = data_files.filter(name__gte=lower_name, name__lte=upper_name)
    return bounded.order_by("name")

def get_capture_files_with_temporal_filter(
    capture_type: CaptureType,
    capture: Capture,
    start_time: int | None = None,  # milliseconds since start of capture
    end_time: int | None = None,
) -> QuerySet[File]:
    """Get the capture files with temporal filtering.

    Non-data files are always included; data files are restricted to the
    [start_time, end_time] window when both bounds are given.
    """
    _catch_capture_type_error(capture_type)

    if start_time is None or end_time is None:
        log.warning("Start or end time is None, returning all capture files without temporal filtering")
        return get_capture_files(capture)

    # Metadata / non-rf files are kept regardless of the time window.
    non_data_files = get_capture_files(capture).exclude(name__regex=DRF_RF_FILENAME_REGEX_STR)

    # Data files limited to the requested window.
    windowed_data_files = filter_capture_data_files_selection_bounds(
        capture_type, capture, start_time, end_time
    )

    return non_data_files.union(windowed_data_files)
168 changes: 158 additions & 10 deletions gateway/sds_gateway/api_methods/serializers/capture_serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Capture serializers for the SDS Gateway API methods."""

import logging
from typing import Any
from typing import cast

Expand All @@ -9,6 +10,9 @@
from rest_framework.utils.serializer_helpers import ReturnList

from sds_gateway.api_methods.helpers.index_handling import retrieve_indexed_metadata
from sds_gateway.api_methods.helpers.temporal_filtering import get_capture_bounds
from sds_gateway.api_methods.helpers.temporal_filtering import get_file_cadence
from sds_gateway.api_methods.helpers.temporal_filtering import get_data_files
from sds_gateway.api_methods.models import Capture
from sds_gateway.api_methods.models import CaptureType
from sds_gateway.api_methods.models import DEPRECATEDPostProcessedData
Expand Down Expand Up @@ -70,7 +74,12 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]):
files = serializers.SerializerMethodField()
center_frequency_ghz = serializers.SerializerMethodField()
sample_rate_mhz = serializers.SerializerMethodField()
files_count = serializers.SerializerMethodField()
length_of_capture_ms = serializers.SerializerMethodField()
file_cadence_ms = serializers.SerializerMethodField()
capture_start_epoch_sec = serializers.SerializerMethodField()
data_files_count = serializers.SerializerMethodField()
data_files_total_size = serializers.SerializerMethodField()
per_data_file_size = serializers.SerializerMethodField()
Comment on lines +77 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change gateway/sds_gateway/static/js/components.js after this files_count removal

total_file_size = serializers.SerializerMethodField()
formatted_created_at = serializers.SerializerMethodField()
capture_type_display = serializers.SerializerMethodField()
Expand All @@ -94,23 +103,89 @@ def get_files(self, capture: Capture) -> ReturnList[File]:
def get_center_frequency_ghz(self, capture: Capture) -> float | None:
"""Get the center frequency in GHz from the capture model property."""
return capture.center_frequency_ghz

@extend_schema_field(serializers.FloatField)
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_sample_rate_mhz(self, capture: Capture) -> float | None:
"""Get the sample rate in MHz from the capture model property."""
"""Get the sample rate in MHz from the capture model property. None if not indexed in OpenSearch."""
return capture.sample_rate_mhz

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_length_of_capture_ms(self, capture: Capture) -> int | None:
    """Get the length of the capture in milliseconds. OpenSearch bounds are in seconds."""
    try:
        bounds = get_capture_bounds(capture.capture_type, str(capture.uuid))
    except (ValueError, IndexError, KeyError):
        # Not indexed in OpenSearch (or unsupported type): no duration available.
        return None
    start_sec, end_sec = bounds
    return (end_sec - start_sec) * 1000

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_file_cadence_ms(self, capture: Capture) -> int | None:
    """Get the file cadence in milliseconds. None if not indexed in OpenSearch."""
    try:
        cadence = get_file_cadence(capture.capture_type, capture)
    except (ValueError, IndexError, KeyError):
        return None
    return cadence

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_capture_start_epoch_sec(self, capture: Capture) -> int | None:
    """Get the capture start time as Unix epoch seconds. None if not indexed in OpenSearch."""
    try:
        start_sec, _end_sec = get_capture_bounds(
            capture.capture_type, str(capture.uuid)
        )
    except (ValueError, IndexError, KeyError):
        return None
    return start_sec

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_data_files_count(self, capture: Capture) -> int | None:
    """Count of DRF data files for this capture; None for non-DigitalRF captures.

    Schema fix: declares allow_null since the method returns None for
    unsupported capture types.
    """
    if capture.capture_type != CaptureType.DigitalRF:
        return None
    return get_data_files(capture.capture_type, capture).count()

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_data_files_total_size(self, capture: Capture) -> int | None:
    """Exact sum of data file sizes; use this for consistent totals with total_file_size.

    Returns None for non-DigitalRF captures (schema fix: allow_null declared),
    0 when there are no data files.
    """
    if capture.capture_type != CaptureType.DigitalRF:
        return None
    data_files = get_data_files(capture.capture_type, capture)
    result = data_files.aggregate(total_size=Sum("size"))
    return result.get("total_size") or 0

@extend_schema_field(serializers.FloatField(allow_null=True))
def get_per_data_file_size(self, capture: Capture) -> float | None:
    """Average size in bytes of each DRF data file in the capture.

    None for non-DigitalRF captures or when there are no data files / sizes.
    Uses a single aggregate query (Sum + Count) instead of the previous
    two count() calls plus one aggregate (three round-trips).
    """
    if capture.capture_type != CaptureType.DigitalRF:
        return None

    # Local import: the file-level import block is outside this view.
    from django.db.models import Count

    stats = get_data_files(capture.capture_type, capture).aggregate(
        total_size=Sum("size"),
        file_count=Count("id"),
    )
    if not stats["file_count"] or not stats["total_size"]:
        return None
    return float(stats["total_size"]) / stats["file_count"]

Comment on lines +161 to +173
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling 3 queries (2x count + 1x agg); see if a single query works here

stats = data_files.aggregate(
    total_size=Sum("size"),
    count=Count("id")
)

if stats["count"] == 0:
    return None

return float(stats["total_size"]) / stats["count"]

@extend_schema_field(serializers.IntegerField)
def get_total_file_size(self, capture: Capture) -> int:
    """Get the total file size of all files associated with this capture.

    For DigitalRF captures the result is clamped up to the exact data-file
    total to keep the two serialized size fields consistent; a discrepancy
    is logged because it indicates stale or partial file records.
    """
    all_files = get_capture_files(capture, include_deleted=False)
    result = all_files.aggregate(total_size=Sum("size"))
    total = result["total_size"] or 0
    if capture.capture_type == CaptureType.DigitalRF:
        data_total = self.get_data_files_total_size(capture) or 0
        if total < data_total:
            logging.getLogger(__name__).warning(
                "Capture %s: total_file_size (%s) < data_files_total_size (%s); using data total.",
                str(capture.uuid), total, data_total,
            )
            total = data_total
    return total

@extend_schema_field(serializers.DictField)
def get_capture_props(self, capture: Capture) -> dict[str, Any]:
Expand Down Expand Up @@ -301,9 +376,13 @@ class CompositeCaptureSerializer(serializers.Serializer):

# Computed fields
files = serializers.SerializerMethodField()
files_count = serializers.SerializerMethodField()
data_files_count = serializers.SerializerMethodField()
data_files_total_size = serializers.SerializerMethodField()
total_file_size = serializers.SerializerMethodField()
formatted_created_at = serializers.SerializerMethodField()
length_of_capture_ms = serializers.SerializerMethodField()
file_cadence_ms = serializers.SerializerMethodField()
capture_start_epoch_sec = serializers.SerializerMethodField()

def get_files(self, obj: dict[str, Any]) -> ReturnList[File]:
"""Get all files from all channels in the composite capture."""
Expand All @@ -321,25 +400,52 @@ def get_files(self, obj: dict[str, Any]) -> ReturnList[File]:
return cast("ReturnList[File]", all_files)

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_data_files_count(self, obj: dict[str, Any]) -> int | None:
    """Total count of DRF data files across all channels; None for non-DigitalRF.

    Schema fix: declares allow_null since the method returns None for
    unsupported capture types.
    """
    if obj["capture_type"] != CaptureType.DigitalRF:
        return None

    total_count = 0
    for channel_data in obj["channels"]:
        capture = Capture.objects.get(uuid=channel_data["uuid"])
        total_count += get_data_files(capture.capture_type, capture).count()
    return total_count

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_data_files_total_size(self, obj: dict[str, Any]) -> int | None:
    """Exact sum of data file sizes across all channels; None for non-DigitalRF.

    Schema fix: declares allow_null since the method returns None for
    unsupported capture types.
    """
    if obj["capture_type"] != CaptureType.DigitalRF:
        return None
    total = 0
    for channel_data in obj["channels"]:
        capture = Capture.objects.get(uuid=channel_data["uuid"])
        data_files = get_data_files(capture.capture_type, capture)
        result = data_files.aggregate(total_size=Sum("size"))
        total += result.get("total_size") or 0
    return total

Comment on lines 402 to +428
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're doing some rework here, maybe there's a way to return both count and size in one pass for both serialized fields

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_total_file_size(self, obj: dict[str, Any]) -> int | None:
    """Get the total file size across all channels; None for non-DigitalRF.

    Fixes the return annotation (the body returns None for unsupported
    capture types but was annotated -> int) and declares allow_null in the
    schema. As with the per-capture serializer, the total is clamped up to
    the exact data-file total and a discrepancy is logged.
    """
    if obj["capture_type"] != CaptureType.DigitalRF:
        return None

    total_size = 0
    for channel_data in obj["channels"]:
        capture = Capture.objects.get(uuid=channel_data["uuid"])
        all_files = get_capture_files(capture, include_deleted=False)
        result = all_files.aggregate(total_size=Sum("size"))
        total_size += result["total_size"] or 0
    data_total = self.get_data_files_total_size(obj) or 0
    if total_size < data_total:
        logging.getLogger(__name__).warning(
            "Composite capture: total_file_size (%s) < data_files_total_size (%s); using data total.",
            total_size, data_total,
        )
        total_size = data_total
    return total_size

@extend_schema_field(serializers.CharField)
Expand All @@ -350,6 +456,48 @@ def get_formatted_created_at(self, obj: dict[str, Any]) -> str:
return created_at.strftime("%m/%d/%Y %I:%M:%S %p")
return ""

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_length_of_capture_ms(self, obj: dict[str, Any]) -> int | None:
    """Use first channel's bounds for composite capture duration."""
    channels = obj.get("channels") or []
    if not channels:
        return None
    try:
        first_channel = Capture.objects.get(uuid=channels[0]["uuid"])
        start_sec, end_sec = get_capture_bounds(
            first_channel.capture_type, str(first_channel.uuid)
        )
    except (ValueError, IndexError, KeyError):
        return None
    return (end_sec - start_sec) * 1000

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_file_cadence_ms(self, obj: dict[str, Any]) -> int | None:
    """Use first channel's file cadence for composite capture."""
    channels = obj.get("channels") or []
    if not channels:
        return None
    try:
        first_channel = Capture.objects.get(uuid=channels[0]["uuid"])
        cadence = get_file_cadence(first_channel.capture_type, first_channel)
    except (ValueError, IndexError, KeyError):
        return None
    return cadence

@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_capture_start_epoch_sec(self, obj: dict[str, Any]) -> int | None:
    """Use first channel's start time for composite capture."""
    channels = obj.get("channels") or []
    if not channels:
        return None
    try:
        first_channel = Capture.objects.get(uuid=channels[0]["uuid"])
        start_sec, _end_sec = get_capture_bounds(
            first_channel.capture_type, str(first_channel.uuid)
        )
    except (ValueError, IndexError, KeyError):
        return None
    return start_sec


def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]:
"""Build composite capture data from a list of captures with the same top_level_dir.
Expand Down
Loading
Loading