In [1]:
import io
import logging
from typing import Optional

import h5py
import hermes.quiver as qv
import torch

import os

from collections.abc import Sequence

#from export.mm_snapshotter import mm_add_streaming_input_preprocessor
from utils.s3 import open_file

from export.mm_modules import concatenation_layer

def scale_model(model, instances):
    """
    Scale the model to the number of instances per GPU desired
    at inference time
    """
    # TODO: should quiver handle this under the hood?
    try:
        model.config.scale_instance_group(instances)
    except ValueError:
        model.config.add_instance_group(count=instances)

from hermes.quiver import Platform
from hermes.quiver.streaming import utils as streaming_utils

#from utils.preprocessing import mm_BackgroundSnapshotter, mm_BatchWhitener

from typing import Callable, Tuple

from ml4gw.transforms import SpectralDensity, Whiten
from ml4gw.utils.slicing import unfold_windows
import numpy as np

Tensor = torch.Tensor

import torch.nn.functional as F
import torchaudio.transforms as T

class PsdEstimator(torch.nn.Module):
    """
    Module that takes a sample of data, splits it into
    two unequal-length segments, calculates the PSD of
    the first section, then returns this PSD along with
    the second section.

    Args:
        length:
            The length, in seconds, of timeseries data
            to be returned for whitening. Note that the
            length of time used for the PSD will then be
            whatever remains along first part of the time
            axis of the input.
        sample_rate:
            Rate at which input data has been sampled in Hz
        fftlength:
            Length of FFTs to use when computing the PSD
        overlap:
            Amount of overlap between FFT windows when
            computing the PSD. Default value of `None`
            uses `fftlength / 2`
        average:
            Method for aggregating spectra from FFT
            windows, either `"mean"` or `"median"`
        fast:
            If `True`, use a slightly faster PSD algorithm
            that is inaccurate for the lowest two frequency
            bins. If you plan on highpassing later, this
            should be fine.
    """

    def __init__(
        self,
        length: float,
        sample_rate: float,
        fftlength: float,
        window: Optional[torch.Tensor] = None,
        overlap: Optional[float] = None,
        average: str = "median",
        fast: bool = True,
    ) -> None:
        super().__init__()
        self.size = int(length * sample_rate)
        self.spectral_density = SpectralDensity(
            sample_rate, fftlength, overlap, average, window=window, fast=fast
        )

    def forward(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        splits = [X.size(-1) - self.size, self.size]
        background, X = torch.split(X, splits, dim=-1)

        # if we have 2 batch elements in our input data,
        # it will be assumed that the 0th element is data
        # being used to calculate the psd to whiten the
        # 1st element. Used when we want to use raw background
        # data to calculate the PSDs to whiten data with injected signals
        if X.ndim == 3 and X.size(0) == 2:
            # 0th background element is used to calculate PSDs
            background = background[0]
            # 1st element is the data to be whitened
            X = X[1]

        psds = self.spectral_density(background.double())
        return X, psds

class BackgroundSnapshotter(torch.nn.Module):
    """Update a kernel with a new piece of streaming data"""

    def __init__(
        self,
        psd_length,
        kernel_length,
        fduration,
        sample_rate,
        inference_sampling_rate,
    ) -> None:
        super().__init__()
        state_length = kernel_length + fduration + psd_length
        state_length -= 1 / inference_sampling_rate
        self.state_size = int(state_length * sample_rate)

    def forward(self, update: Tensor, snapshot: Tensor) -> Tuple[Tensor, ...]:
        x = torch.cat([snapshot, update], axis=-1)
        snapshot = x[:, :, -self.state_size :]
        return x, snapshot

class mm_BatchWhitener(torch.nn.Module):
    """Calculate the PSDs and whiten an entire batch of kernels at once"""

    def __init__(
        self,
        resample_rates: Sequence[float], 
        kernel_lengths: Sequence[float], 
        high_passes: Sequence[float], 
        low_passes: Sequence[float],
        inference_sampling_rates: Sequence[float],
        starting_offsets: Sequence[int],
        num_ifos: int,
        kernel_length: float,
        sample_rate: float,
        batch_size: int,
        fduration: float,
        fftlength: float,
    ) -> None:
        super().__init__()
        self.resample_rates = resample_rates
        self.stride_sizes = [int(sample_rate / isr) for isr in inference_sampling_rates]
        self.kernel_sizes = [int(kl * sample_rate) for kl in kernel_lengths]
        self.num_timeseries = len(kernel_lengths)-1
        self.num_ifos = num_ifos
        # do foreground length calculation in units of samples,
        # then convert back to length to guard for intification
        self.starting_offsets = [int(kernel_length*sample_rate-so*min(self.stride_sizes)-self.kernel_sizes[i]) 
                                 for i, so in enumerate(starting_offsets)]
        self.ending_offsets = [None if int(so*min(self.stride_sizes)) == 0 else -int(so*min(self.stride_sizes)) for so in starting_offsets]
        stride_size = sample_rate / max(inference_sampling_rates)
        self.kernel_size = int(kernel_length * sample_rate)
        strides = (batch_size - 1) * stride_size
        fsize = int(fduration * sample_rate)
        size = strides + self.kernel_size + fsize
        length = size / sample_rate
        self.psd_estimator = PsdEstimator(
            length,
            sample_rate,
            fftlength=fftlength,
            overlap=None,
            average="median",
            fast=highpass is not None,
        )
        self.whiteners = torch.nn.ModuleList([Whiten(fduration, sample_rate, highpass, lowpass) 
                                              for highpass, lowpass in zip(high_passes, low_passes)])
        self.resamplers = torch.nn.ModuleList([T.Resample(sample_rate, rr) for rr in resample_rates])
        self.resample_rate = [sample_rate//rr for rr in resample_rates]
        self.fft_highpass = high_passes[-1]
        self.fft_lowpass = low_passes[-1]

    def forward(self, x: Tensor) -> Tensor:
        # Get the number of channels so we know how to
        # reshape `x` appropriately after unfolding to
        # ensure we have (batch, channels, time) shape
        out_x = tuple()
        x, psd = self.psd_estimator(x)
        print(f'psd = {psd.shape}')
        for i in range(self.num_timeseries):
            whitened = self.whiteners[i](x.double(), psd)
            print(f'self.whiteners[i](x.double(), psd) = {whitened.shape}')
            sliced_x = whitened[..., self.starting_offsets[i]:self.ending_offsets[i]]
            print(f'sliced_x[..., self.starting_offsets[i]:self.ending_offsets[i]] = {sliced_x.shape}')
            sliced_x = unfold_windows(sliced_x, self.kernel_sizes[i], self.stride_sizes[i])
            print(f'unfold_windows(whitened, self.kernel_sizes[i], self.stride_sizes[i]) = {sliced_x.shape}')
            sliced_x = sliced_x.reshape(-1, self.num_ifos, self.kernel_sizes[i])
            print(f'sliced_x.reshape(-1, self.num_ifos, self.kernel_sizes[i]) = {sliced_x.shape}')
            bs = sliced_x.shape[0]
            sliced_x = sliced_x.reshape((self.num_ifos*bs, 1, self.kernel_sizes[i])).squeeze(-2)
            print(f'sliced_x.reshape((self.num_ifos*bs, 1, self.kernel_sizes[i])).squeeze(-2) = {sliced_x.shape}')
            sliced_x = self.resamplers[i](sliced_x)
            print(f'self.resamplers[i](sliced_x) = {sliced_x.shape}')
            sliced_x = sliced_x.reshape((bs, self.num_ifos, self.kernel_sizes[i]//self.resample_rate[i]))
            print(f'sliced_x.reshape((bs, self.num_ifos, self.kernel_sizes[i]//self.resample_rate[i])) = {sliced_x.shape}')
            out_x = out_x + (sliced_x,)
        
        whitened = self.whiteners[-1](x.double(), psd)
        sliced_x = whitened[..., self.starting_offsets[-1]:self.ending_offsets[-1]]
        sliced_x = unfold_windows(sliced_x, self.kernel_sizes[-1], self.stride_sizes[-1])
        sliced_x = sliced_x.reshape(-1, self.num_ifos, self.kernel_sizes[-1])
        bs = sliced_x.shape[0]
        sliced_x = sliced_x.reshape((self.num_ifos*bs, 1, self.kernel_sizes[-1])).squeeze(-2)
        sliced_x = self.resamplers[-1](sliced_x)
        sliced_x = sliced_x.reshape((bs, self.num_ifos, self.kernel_sizes[-1]//self.resample_rate[-1]))
        freqs = torch.fft.rfftfreq(
            sliced_x.shape[-1], d=1 / self.resample_rates[-1]
        )
        sliced_x = torch.fft.rfft(sliced_x)
        mask = freqs >= self.fft_highpass
        mask *= freqs <= self.fft_lowpass
        sliced_x = sliced_x[:, :, mask]
        freqs = np.linspace(0, self.resample_rates[-1]/2, psd.shape[-1])
        mask = freqs >= self.fft_highpass
        mask *= freqs <= self.fft_lowpass
        asds = (psd[..., mask]**0.5 * 1e23).float()
        asds = asds.unsqueeze(dim = 0)
        asds = F.interpolate(asds, size=(sliced_x.shape[-1],), mode="linear", align_corners=False)
        asds = asds.repeat(sliced_x.shape[0], 1, 1)
        sliced_x = torch.cat((sliced_x.real, sliced_x.imag, 1/asds), dim=1)
        out_x = out_x + (sliced_x,)
        return out_x

def mm_add_streaming_input_preprocessor(
    input_shapes: list,
    ensemble: "EnsembleModel",
    input: list,
    psd_length: float,
    sample_rate: float,
    kernel_length: float,
    inference_sampling_rate: float,
    fduration: float,
    fftlength: float,
    resample_rates: Sequence[float], 
    kernel_lengths: Sequence[float], 
    high_passes: Sequence[float], 
    low_passes: Sequence[float],
    inference_sampling_rates: Sequence[float],
    starting_offsets: Sequence[int],
    num_ifos: int,
    q: Optional[float] = None,
    highpass: Optional[float] = None,
    lowpass: Optional[float] = None,
    preproc_instances: Optional[int] = None,
    streams_per_gpu: int = 1,
) -> "ExposedTensor":
    """Create a snapshotter model and add it to the repository"""

    augmentor = None

    snapshotter = BackgroundSnapshotter(
        psd_length=psd_length,
        kernel_length=kernel_length,
        fduration=fduration,
        sample_rate=sample_rate,
        inference_sampling_rate=inference_sampling_rate,
    )

    stride = int(sample_rate / inference_sampling_rate)
    state_shape = (2, num_ifos, snapshotter.state_size)
    input_shape = (2, num_ifos, batch_size * stride)
    streaming_model = streaming_utils.add_streaming_model(
        ensemble.repository,
        streaming_layer=snapshotter,
        name="snapshotter",
        input_name="stream",
        input_shape=input_shape,
        state_names=["snapshot"],
        state_shapes=[state_shape],
        output_names=["strain"],
        streams_per_gpu=streams_per_gpu,
    )
    ensemble.add_input(streaming_model.inputs["stream"])
    preprocessor = mm_BatchWhitener(
        resample_rates = resample_rates, 
        kernel_lengths = kernel_lengths, 
        high_passes = high_passes, 
        low_passes = low_passes,
        inference_sampling_rates = inference_sampling_rates,
        starting_offsets = starting_offsets,
        num_ifos = num_ifos,
        kernel_length=kernel_length,
        sample_rate=sample_rate,
        batch_size=batch_size,
        fduration=fduration,
        fftlength=fftlength,
    )
    preproc_model = ensemble.repository.add(
        "preprocessor", platform=Platform.TORCHSCRIPT
    )
    # if we specified a number of instances we want per-gpu
    # for each model at inference time, scale them now
    if preproc_instances is not None:
        scale_model(preproc_model, preproc_instances)

    input_shape = streaming_model.outputs["strain"].shape
    preproc_model.export_version(
        preprocessor,
        input_shapes={"strain": input_shape},
        output_names=[f"whitened_{i}" for i in range(len(input_shapes))],
    )
    ensemble.pipe(
        streaming_model.outputs["strain"],
        preproc_model.inputs["strain"],
    )
    return [preproc_model.outputs[f"whitened_{i}"] for i in range(len(input_shapes))]

In [2]:
batch_size = 128
weights = '/home/seiya.tsukamoto/aframe/mm_v1/mm_v1/training/model.pt'
batch_file = '/home/seiya.tsukamoto/aframe/mm_v1/mm_v1/training/batch.h5'
repository_directory = '/home/seiya.tsukamoto/aframe/mm_v1/mm_v1/results/model_repo/'
clean = False
platform = qv.Platform.TENSORRT
aframe_instances = None
num_ifos = 2
kernel_length = 2.375
sample_rate = 2048
classes = [64, 64, 64, 64]
#layers = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
inference_sampling_rates = [8, 4, 2, 8]
psd_length=8
sample_rate=2048
inference_sampling_rate=8
fduration=1
fftlength=None
q = None
highpass=1024
lowpass=32
preproc_instances=None
streams_per_gpu=6
resample_rates = [2048, 1024, 512, 2048]
kernel_lengths = [0.5, 1, 2, 1]
high_passes = [32, 32, 32, 32]
low_passes = [1024, 128, 64, 1024]
starting_offsets = [0, 1, 3, 0]

In [3]:
# load in the model graph
logging.info("Initializing model graph")

with open_file(weights, "rb") as f:
    graph = torch.jit.load(f, map_location="cpu")

graph.eval()
logging.info(f"Initialize:\n{graph}")

with open_file(batch_file, "rb") as f:
    batch_file = h5py.File(io.BytesIO(f.read()))

layers = sorted(batch_file.keys() - "y")
input_shapes = [(batch_size*inference_sampling_rates[i]//max(inference_sampling_rates), 
                 batch_file[layer].shape[-2], 
                 batch_file[layer].shape[-1]) for i, layer in enumerate(layers)]
n_layers = len(layers)

graphs = []
model_parent_dir = os.path.dirname(weights)
for i in range(n_layers):
    with open_file(os.path.join(model_parent_dir, f"resnets_{i}.pt"), "rb") as f:
        graphs.append(torch.jit.load(f, map_location="cpu"))
        graphs[-1].eval()

with open_file(os.path.join(model_parent_dir, f"fc.pt"), "rb") as f:
    fc = torch.jit.load(f, map_location="cpu")
    fc.eval()
# instantiate a model repository at the
# indicated location. Split up the preprocessor
# and the neural network (which we'll call aframe)
# to export/scale them separately, and start by
# seeing if either already exists in the model repo

In [4]:
repo = qv.ModelRepository(repository_directory, clean)
aframe = []
for i in range(n_layers):
    try:
        aframe.append(repo.models[f"resnet_{i}"])
    except KeyError:
        aframe.append(repo.add(f"resnet_{i}", platform=platform))

try:
    aframe.append(repo.models["fc"])
except KeyError:
    aframe.append(repo.add("fc", platform=platform))

try:
    concatenation = repo.models["concatenation_layer"]
except KeyError:
    concatenation = repo.add("concatenation_layer", platform=platform)

# if we specified a number of instances we want per-gpu
# for each model at inference time, scale them now
#if aframe_instances is not None:
#    scale_model(aframe, aframe_instances)

# the network will have some different keyword
# arguments required for export depending on
# the target inference platform
# TODO: hardcoding these kwargs for now, but worth
# thinking about a more robust way to handle this
kwargs = {}
if platform == qv.Platform.ONNX:
    kwargs["opset_version"] = 13

    # turn off graph optimization because of this error
    # https://github.com/triton-inference-server/server/issues/3418
    aframe.config.optimization.graph.level = -1
elif platform == qv.Platform.TENSORRT:
    kwargs["use_fp16"] = False

for i in range(n_layers):
    aframe[i].export_version(
        graphs[i],
        input_shapes={f"whitened_{i}": input_shapes[i]},
        output_names=[f"classes_{i}"],
        **kwargs,
    )
cl  = concatenation_layer(inference_sampling_rates)
concatenation.export_version(
    cl,
    input_shapes={f"classes_{i}": (input_shapes[i][0], classes[i]) 
                  for i in range(n_layers)},
    output_names=["concatenated"],
    **kwargs,
)

aframe[-1].export_version(
    fc,
    input_shapes={"concatenated": (batch_size, sum(classes))},
    output_names=["discriminator"],
    **kwargs,
)



[09/20/2025-21:47:27] [TRT] [W] onnx2trt_utils.cpp:377: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.




'fc/1/model.plan'

In [12]:
ensemble_name = "aframe-stream"
# if we don't, create one
ensemble = repo.add(ensemble_name, platform=qv.Platform.ENSEMBLE)
# if fftlength isn't specified, calculate the default value
fftlength = fftlength or kernel_length + fduration

In [13]:
whitened = mm_add_streaming_input_preprocessor(
    input_shapes = input_shapes,
    ensemble = ensemble,
    input = [aframe[i].inputs[f"whitened_{i}"] for i in range(n_layers)],
    psd_length=psd_length,
    sample_rate=sample_rate,
    kernel_length=kernel_length,
    inference_sampling_rate=inference_sampling_rate,
    fduration=fduration,
    fftlength=fftlength,
    q=q,
    highpass=highpass,
    lowpass=lowpass,
    preproc_instances=preproc_instances,
    streams_per_gpu=streams_per_gpu,
    resample_rates = resample_rates, 
    kernel_lengths = kernel_lengths, 
    high_passes = high_passes, 
    low_passes = low_passes,
    inference_sampling_rates = inference_sampling_rates,
    starting_offsets = starting_offsets,
    num_ifos = num_ifos,
)
for i in range(n_layers):
    ensemble.pipe(whitened[i], aframe[i].inputs[f"whitened_{i}"])

for i in range(n_layers):
    ensemble.pipe(aframe[i].outputs[f"classes_{i}"], concatenation.inputs[f"classes_{i}"])

ensemble.pipe(concatenation.outputs["concatenated"], aframe[-1].inputs["concatenated"])
ensemble.add_output(aframe[-1].outputs["discriminator"])
# export the ensemble model, which basically amounts
# to writing its config and creating an empty version entry
ensemble.export_version(None)

# keep snapshot states around for a long time in case there are
# unexpected bottlenecks which throttle update for a few seconds
snapshotter = repo.models["snapshotter"]
snapshotter.config.sequence_batching.max_sequence_idle_microseconds = int(
    6e10
)
snapshotter.config.write()

  export_path = exporter(
  output_shapes = self._get_output_shapes(model_fn, output_names)
  if X.ndim == 3 and X.size(0) == 2:
  if x.shape[-1] < nperseg:
  if N <= (2 * pad):
  if psd.size(-1) != num_freqs:
  idx = int(highpass / df)
  idx = int(lowpass / df)
  if inv_asd.size(-1) % 2:
  if 2 * pad < q.size(-1):
  if remainder == 0:


psd = torch.Size([2, 3457])
self.whiteners[i](x.double(), psd) = torch.Size([1, 2, 37376])
sliced_x[..., self.starting_offsets[i]:self.ending_offsets[i]] = torch.Size([1, 2, 33536])
unfold_windows(whitened, self.kernel_sizes[i], self.stride_sizes[i]) = torch.Size([128, 1, 2, 1024])
sliced_x.reshape(-1, self.num_ifos, self.kernel_sizes[i]) = torch.Size([128, 2, 1024])
sliced_x.reshape((self.num_ifos*bs, 1, self.kernel_sizes[i])).squeeze(-2) = torch.Size([256, 1024])
self.resamplers[i](sliced_x) = torch.Size([256, 1024])
sliced_x.reshape((bs, self.num_ifos, self.kernel_sizes[i]//self.resample_rate[i])) = torch.Size([128, 2, 1024])
self.whiteners[i](x.double(), psd) = torch.Size([1, 2, 37376])
sliced_x[..., self.starting_offsets[i]:self.ending_offsets[i]] = torch.Size([1, 2, 34560])
unfold_windows(whitened, self.kernel_sizes[i], self.stride_sizes[i]) = torch.Size([64, 1, 2, 2048])
sliced_x.reshape(-1, self.num_ifos, self.kernel_sizes[i]) = torch.Size([64, 2, 2048])
sliced_x.reshape((self.